diff options
Diffstat (limited to 'model-integration/src/main/java')
2 files changed, 37 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java new file mode 100644 index 00000000000..8db237d15e4 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -0,0 +1,35 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +public class Softmax extends IntermediateOperation { + + public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(1)) return null; + return inputs.get(0).type().get(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().get(); + String dimension = inputType.dimensions().get(0).name(); + if (inputType.rank() == 2) { + dimension = inputType.dimensions().get(1).name(); // assumption: first dimension is batch dimension + } + + TensorFunction inputFunction = inputs.get(0).function().get(); + return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index 1abbd0063a1..357794faee2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -2,6 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; +import ai.vespa.rankingexpression.importer.operations.Softmax; import ai.vespa.rankingexpression.importer.operations.Sum; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import ai.vespa.rankingexpression.importer.IntermediateGraph; @@ -112,6 +113,7 @@ class GraphImporter { case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + case "softmax": return new Softmax(modelName, nodeName, inputs); // state ops case "variable": return new Constant(modelName, nodeName, nodeType); |