diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-04 11:00:38 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-04 11:00:38 +0200 |
commit | 831d8f16f10d04c64d54f6d353e3e200c31dc703 (patch) | |
tree | c9678f67fce271ef8438a1f17d01c8255d4fb551 /model-integration | |
parent | 8ebca621899e388b48cef80a649d826088e6e64c (diff) |
Use value type of ONNX tensor arguments
Diffstat (limited to 'model-integration')
6 files changed, 75 insertions, 83 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index 5cc1defc010..a469e666d93 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -16,10 +16,8 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import ai.vespa.rankingexpression.importer.operations.NoOp; import ai.vespa.rankingexpression.importer.operations.Reshape; import ai.vespa.rankingexpression.importer.operations.Shape; -import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; -import onnx.Onnx.TensorProto.DataType; import java.util.List; import java.util.stream.Collectors; @@ -107,8 +105,8 @@ class GraphImporter { if (isArgumentTensor(name, onnxGraph)) { Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph); if (valueInfoProto == null) - throw new IllegalArgumentException("Could not find argument tensor: " + name); - OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType()); + throw new IllegalArgumentException("Could not find argument tensor '" + name + "'"); + OrderedTensorType type = TypeConverter.typeFrom(valueInfoProto.getType()); operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type); intermediateGraph.inputs(intermediateGraph.defaultSignature()) @@ -116,8 +114,7 @@ class GraphImporter { } else if (isConstantTensor(name, onnxGraph)) { Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); - OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(toValueType(tensorProto.getDataType()), - tensorProto.getDimsList()); + OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto); operation = new Constant(intermediateGraph.name(), name, defaultType); operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); @@ -136,25 +133,6 @@ class GraphImporter { return operation; } - private static TensorType.Value toValueType(DataType dataType) { - switch (dataType) { - case FLOAT: return TensorType.Value.FLOAT; - case DOUBLE: return TensorType.Value.DOUBLE; - // Imperfect conversion, for now: - case BOOL: return TensorType.Value.FLOAT; - case INT8: return TensorType.Value.FLOAT; - case INT16: return TensorType.Value.FLOAT; - case INT32: return TensorType.Value.FLOAT; - case INT64: return TensorType.Value.DOUBLE; - case UINT8: return TensorType.Value.FLOAT; - case UINT16: return TensorType.Value.FLOAT; - case UINT32: return TensorType.Value.FLOAT; - case UINT64: return TensorType.Value.DOUBLE; - default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + - " cannot be converted to a Vespa tensor type"); - } - } - private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { Onnx.ValueInfoProto value = getArgumentTensor(name, graph); Onnx.TensorProto tensor = getConstantTensor(name, graph); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 79b399f2c6f..29d600fa7c6 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -30,13 +30,10 @@ class TypeConverter { } } - static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { - return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... - } - - private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { + static OrderedTensorType typeFrom(Onnx.TypeProto type) { + String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(TensorType.Value.DOUBLE); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(type.getTensorType().getElemType())); for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); @@ -49,4 +46,28 @@ class TypeConverter { return builder.build(); } + static OrderedTensorType typeFrom(Onnx.TensorProto tensor) { + return OrderedTensorType.fromDimensionList(toValueType(tensor.getDataType()), + tensor.getDimsList()); + } + + private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT8: return TensorType.Value.FLOAT; + case INT16: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.DOUBLE; + case UINT8: return TensorType.Value.FLOAT; + case UINT16: return TensorType.Value.FLOAT; + case UINT32: return TensorType.Value.FLOAT; + case UINT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index cb838cd67b1..a07c0fdf4dc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -51,7 +51,7 @@ class GraphImporter { String nodeName = node.getName(); String modelName = graph.name(); int nodePort = IntermediateOperation.indexPartOf(nodeName); - OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node); + OrderedTensorType nodeType = TypeConverter.typeFrom(node); AttributeConverter attributes = AttributeConverter.convert(node); switch (node.getOp().toLowerCase()) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java index a4fe38cce95..9cba388d00e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java @@ -28,7 +28,7 @@ public class TensorConverter { } private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { - TensorType type = toVespaTensorType(tfTensor, dimensionPrefix); + TensorType type = TypeConverter.typeFrom(tfTensor, dimensionPrefix); Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); for (int i = 0; i < values.size(); i++) @@ -54,16 +54,6 @@ public class TensorConverter { return builder.build(); } - private static TensorType toVespaTensorType(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { - TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType())); - int dimensionIndex = 0; - for (long dimensionSize : tfTensor.shape()) { - if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... - b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize); - } - return b.build(); - } - public static Long tensorSize(TensorType type) { Long size = 1L; for (TensorType.Dimension dimension : type.dimensions()) { @@ -108,21 +98,6 @@ public class TensorConverter { throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); } - /** TensorFlow has two different DataType classes. This must be kept in sync with TypeConverter.toValueType */ - static TensorType.Value toValueType(DataType dataType) { - switch (dataType) { - case FLOAT: return TensorType.Value.FLOAT; - case DOUBLE: return TensorType.Value.DOUBLE; - // Imperfect conversion, for now: - case BOOL: return TensorType.Value.FLOAT; - case INT32: return TensorType.Value.FLOAT; - case UINT8: return TensorType.Value.FLOAT; - case INT64: return TensorType.Value.DOUBLE; - default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + - " cannot be converted to a Vespa tensor type"); - } - } - /** Allows reading values from buffers of various numeric types as bytes */ private static abstract class Values { abstract double get(int i); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java index 3e825026b0e..d8ddb01b650 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java @@ -9,8 +9,6 @@ import org.tensorflow.framework.DataType; import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.TensorShapeProto; -import java.util.List; - /** * Converts and verifies TensorFlow tensor types into Vespa tensor types. * @@ -37,6 +35,32 @@ class TypeConverter { } } + static OrderedTensorType typeFrom(NodeDef node) { + String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... + TensorShapeProto shape = tensorFlowShape(node); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node))); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); + if (tensorFlowDimension.getSize() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + + static TensorType typeFrom(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { + TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType())); + int dimensionIndex = 0; + for (long dimensionSize : tfTensor.shape()) { + if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... + b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize); + } + return b.build(); + } + private static TensorShapeProto tensorFlowShape(NodeDef node) { AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); if (attrValueList == null) @@ -59,27 +83,7 @@ class TypeConverter { return attrValueList.getList().getType(0); // support multiple outputs? } - static OrderedTensorType fromTensorFlowType(NodeDef node) { - return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... - } - - private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { - TensorShapeProto shape = tensorFlowShape(node); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node))); - for (int i = 0; i < shape.getDimCount(); ++ i) { - String dimensionName = dimensionPrefix + i; - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); - if (tensorFlowDimension.getSize() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); - } else { - builder.add(TensorType.Dimension.indexed(dimensionName)); - } - } - return builder.build(); - } - - /** TensorFlow has two different DataType classes. This must be kept in sync with TensorConverter.toValueType */ - static TensorType.Value toValueType(DataType dataType) { + private static TensorType.Value toValueType(DataType dataType) { switch (dataType) { case DT_FLOAT: return TensorType.Value.FLOAT; case DT_DOUBLE: return TensorType.Value.DOUBLE; @@ -100,4 +104,18 @@ class TypeConverter { } } + private static TensorType.Value toValueType(org.tensorflow.DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 424e4d6c57c..07814687dc6 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -43,14 +43,14 @@ public class OnnxMnistSoftmaxImportTestCase { // Check inputs assertEquals(1, model.inputs().size()); assertTrue(model.inputs().containsKey("Placeholder")); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder")); + assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), model.inputs().get("Placeholder")); // Check signature ImportedMlFunction output = model.defaultSignature().outputFunction("add", "add"); assertNotNull(output); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", output.expression()); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), + assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), model.inputs().get(model.defaultSignature().inputs().get("Placeholder"))); assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); } |