summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-08-22 11:16:39 +0200
committerLester Solbakken <lesters@oath.com>2019-08-22 11:16:39 +0200
commit9eb33e8ced84849db9b4c75c20caf11bb053abb2 (patch)
treee2ea21f92d54b294fa1c48ae5762d2eee92479aa /model-integration
parenta3689d2d45e426bbfb4b924ae4994ed8aad46361 (diff)
Add varhandleop and readvariableop to supported TF import operations
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java15
2 files changed, 13 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index 0e9c98b2b56..4f656d86929 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -34,7 +34,6 @@ import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
-import org.tensorflow.op.core.DecodeRaw;
import java.io.IOException;
import java.util.List;
@@ -119,6 +118,8 @@ class GraphImporter {
// state ops
case "variable": return new Constant(modelName, nodeName, nodeType);
case "variablev2": return new Constant(modelName, nodeName, nodeType);
+ case "varhandleop": return new Constant(modelName, nodeName, nodeType);
+ case "readvariableop":return new Identity(modelName, nodeName, inputs);
// evaluation no-ops
case "stopgradient":return new Identity(modelName, nodeName, inputs);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
index 08c0564ed8a..d9bb5c2fe45 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
@@ -62,15 +62,22 @@ class TypeConverter {
}
private static TensorShapeProto tensorFlowShape(NodeDef node) {
- AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
- if (attrValueList == null)
+ // Use specific shape if available...
+ AttrValue attrShape = node.getAttrMap().get("shape");
+ if (attrShape != null && attrShape.getValueCase() == AttrValue.ValueCase.SHAPE) {
+ return attrShape.getShape();
+ }
+
+ // ... else use inferred shape
+ AttrValue attrOutputShapes = node.getAttrMap().get("_output_shapes");
+ if (attrOutputShapes == null)
throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
"does not exist");
- if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST)
+ if (attrOutputShapes.getValueCase() != AttrValue.ValueCase.LIST)
throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
"is not of expected type");
- return attrValueList.getList().getShape(0); // support multiple outputs?
+ return attrOutputShapes.getList().getShape(0); // support multiple outputs?
}
private static DataType tensorFlowValueType(NodeDef node) {