aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java35
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java2
2 files changed, 37 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
new file mode 100644
index 00000000000..8db237d15e4
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -0,0 +1,35 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+public class Softmax extends IntermediateOperation {
+
+ public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(1)) return null;
+ return inputs.get(0).type().get();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputFunctionsPresent(1)) return null;
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ String dimension = inputType.dimensions().get(0).name();
+ if (inputType.rank() == 2) {
+ dimension = inputType.dimensions().get(1).name(); // assumption: first dimension is batch dimension
+ }
+
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index 1abbd0063a1..357794faee2 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -2,6 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
+import ai.vespa.rankingexpression.importer.operations.Softmax;
import ai.vespa.rankingexpression.importer.operations.Sum;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
@@ -112,6 +113,7 @@ class GraphImporter {
case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+ case "softmax": return new Softmax(modelName, nodeName, inputs);
// state ops
case "variable": return new Constant(modelName, nodeName, nodeType);