diff options
author | Lester Solbakken <lesters@oath.com> | 2019-08-22 11:16:39 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-08-22 11:16:39 +0200 |
commit | 9eb33e8ced84849db9b4c75c20caf11bb053abb2 (patch) | |
tree | e2ea21f92d54b294fa1c48ae5762d2eee92479aa /model-integration | |
parent | a3689d2d45e426bbfb4b924ae4994ed8aad46361 (diff) |
Add varhandleop and readvariableop to supported TF import operations
Diffstat (limited to 'model-integration')
2 files changed, 13 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index 0e9c98b2b56..4f656d86929 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -34,7 +34,6 @@ import org.tensorflow.framework.MetaGraphDef; import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.SignatureDef; import org.tensorflow.framework.TensorInfo; -import org.tensorflow.op.core.DecodeRaw; import java.io.IOException; import java.util.List; @@ -119,6 +118,8 @@ class GraphImporter { // state ops case "variable": return new Constant(modelName, nodeName, nodeType); case "variablev2": return new Constant(modelName, nodeName, nodeType); + case "varhandleop": return new Constant(modelName, nodeName, nodeType); + case "readvariableop":return new Identity(modelName, nodeName, inputs); // evaluation no-ops case "stopgradient":return new Identity(modelName, nodeName, inputs); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java index 08c0564ed8a..d9bb5c2fe45 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java @@ -62,15 +62,22 @@ class TypeConverter { } private static TensorShapeProto tensorFlowShape(NodeDef node) { - AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); - if (attrValueList == null) + // Use specific shape if available... + AttrValue attrShape = node.getAttrMap().get("shape"); + if (attrShape != null && attrShape.getValueCase() == AttrValue.ValueCase.SHAPE) { + return attrShape.getShape(); + } + + // ... else use inferred shape + AttrValue attrOutputShapes = node.getAttrMap().get("_output_shapes"); + if (attrOutputShapes == null) throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + "does not exist"); - if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) + if (attrOutputShapes.getValueCase() != AttrValue.ValueCase.LIST) throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + "is not of expected type"); - return attrValueList.getList().getShape(0); // support multiple outputs? + return attrOutputShapes.getList().getShape(0); // support multiple outputs? } private static DataType tensorFlowValueType(NodeDef node) { |