diff options
author | Lester Solbakken <lesters@oath.com> | 2020-04-03 11:29:43 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-04-03 11:29:43 +0200 |
commit | 3789127189224d6cbd6f109b9a95f848869ea6cc (patch) | |
tree | 79cef74e6c61da059ed0eae79632fa001433ddc2 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java | |
parent | 706cb2d3b2d623318ba9c0a8db0e4355448af65a (diff) |
for testing only (branch: lesters/bert-testing)
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index 83086926316..e2b83246bfc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
 import com.yahoo.tensor.functions.Join;
 import com.yahoo.tensor.functions.Map;
 import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunction;
 import com.yahoo.tensor.functions.ScalarFunctions;
 import com.yahoo.tensor.functions.TensorFunction;
 
@@ -28,6 +29,10 @@ public class Softmax extends IntermediateOperation {
     @Override
     protected OrderedTensorType lazyGetType() {
         if ( ! allInputTypesPresent(1)) return null;
+
+        // input is referenced twice due to avoidance of overflow. so make this it's own function.
+        inputs.get(0).exportAsRankingFunction = true;
+
         return inputs.get(0).type().get();
     }
 
@@ -50,7 +55,9 @@ public class Softmax extends IntermediateOperation {
         }
 
         TensorFunction input = inputs.get(0).function().get();
-        TensorFunction exp = new Map(input, ScalarFunctions.exp());
+        TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
+        TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow
+        TensorFunction exp = new Map(cap, ScalarFunctions.exp());
         TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions);
         TensorFunction div = new Join(exp, sum, ScalarFunctions.divide());
 