aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-04-03 11:29:43 +0200
committerLester Solbakken <lesters@oath.com>2020-04-03 11:29:43 +0200
commit3789127189224d6cbd6f109b9a95f848869ea6cc (patch)
tree79cef74e6c61da059ed0eae79632fa001433ddc2 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
parent706cb2d3b2d623318ba9c0a8db0e4355448af65a (diff)
for testing onlylesters/bert-testing
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java9
1 files changed, 8 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index 83086926316..e2b83246bfc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
@@ -28,6 +29,10 @@ public class Softmax extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(1)) return null;
+
+ // input is referenced twice due to avoidance of overflow. so make this it's own function.
+ inputs.get(0).exportAsRankingFunction = true;
+
return inputs.get(0).type().get();
}
@@ -50,7 +55,9 @@ public class Softmax extends IntermediateOperation {
}
TensorFunction input = inputs.get(0).function().get();
- TensorFunction exp = new Map(input, ScalarFunctions.exp());
+ TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
+ TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow
+ TensorFunction exp = new Map(cap, ScalarFunctions.exp());
TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions);
TensorFunction div = new Join(exp, sum, ScalarFunctions.divide());