diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index 4f656d86929..0d2ba0cc714 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -64,7 +64,7 @@ class GraphImporter { case "identity": return new Identity(modelName, nodeName, inputs); case "placeholder": return new Argument(modelName, nodeName, nodeType); case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs); - case "reshape": return new Reshape(modelName, nodeName, inputs); + case "reshape": return new Reshape(modelName, nodeName, inputs, attributes); case "shape": return new Shape(modelName, nodeName, inputs); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); @@ -113,7 +113,7 @@ class GraphImporter { case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); - case "softmax": return new Softmax(modelName, nodeName, inputs); + case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); // state ops case "variable": return new Constant(modelName, nodeName, nodeType); |