aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java9
1 files changed, 8 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index 83086926316..e2b83246bfc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
@@ -28,6 +29,10 @@ public class Softmax extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(1)) return null;
+
+ // input is referenced twice due to avoidance of overflow. so make this it's own function.
+ inputs.get(0).exportAsRankingFunction = true;
+
return inputs.get(0).type().get();
}
@@ -50,7 +55,9 @@ public class Softmax extends IntermediateOperation {
}
TensorFunction input = inputs.get(0).function().get();
- TensorFunction exp = new Map(input, ScalarFunctions.exp());
+ TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
+ TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow
+ TensorFunction exp = new Map(cap, ScalarFunctions.exp());
TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions);
TensorFunction div = new Join(exp, sum, ScalarFunctions.divide());