diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java new file mode 100644 index 00000000000..8db237d15e4 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -0,0 +1,35 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +public class Softmax extends IntermediateOperation { + + public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(1)) return null; + return inputs.get(0).type().get(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().get(); + String dimension = inputType.dimensions().get(0).name(); + if (inputType.rank() == 2) { + dimension = inputType.dimensions().get(1).name(); // assumption: first dimension is batch dimension + } + + TensorFunction inputFunction = inputs.get(0).function().get(); + return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension); + } + +} |