diff options
author | Lester Solbakken <lesters@oath.com> | 2020-04-21 15:26:58 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-04-21 15:26:58 +0200 |
commit | aad5c7184f37e1441c928efa77b434620742ff88 (patch) | |
tree | 34a92e7f954aa92e21d48816335771ff607fe404 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java | |
parent | 6f5ca49e45cdc8262fcf360b1c731a393385ffa8 (diff) |
Update model-integration for supporting BERT-type models
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index 83086926316..d03827f4c72 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -28,6 +28,10 @@ public class Softmax extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(1)) return null; + + // input is referenced twice due to overflow avoidance, so make this its own function. + inputs.get(0).exportAsRankingFunction = true; + return inputs.get(0).type().get(); } @@ -50,7 +54,9 @@ public class Softmax extends IntermediateOperation { } TensorFunction input = inputs.get(0).function().get(); - TensorFunction exp = new Map(input, ScalarFunctions.exp()); + TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions); + TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow + TensorFunction exp = new Map(cap, ScalarFunctions.exp()); TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions); TensorFunction div = new Join(exp, sum, ScalarFunctions.divide()); |