summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java4
1 files changed, 2 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index 4f656d86929..0d2ba0cc714 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -64,7 +64,7 @@ class GraphImporter {
case "identity": return new Identity(modelName, nodeName, inputs);
case "placeholder": return new Argument(modelName, nodeName, nodeType);
case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
- case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "reshape": return new Reshape(modelName, nodeName, inputs, attributes);
case "shape": return new Shape(modelName, nodeName, inputs);
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
@@ -113,7 +113,7 @@ class GraphImporter {
case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
- case "softmax": return new Softmax(modelName, nodeName, inputs);
+ case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
// state ops
case "variable": return new Constant(modelName, nodeName, nodeType);