summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-04 11:00:38 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-04 11:00:38 +0200
commit831d8f16f10d04c64d54f6d353e3e200c31dc703 (patch)
treec9678f67fce271ef8438a1f17d01c8255d4fb551 /model-integration
parent8ebca621899e388b48cef80a649d826088e6e64c (diff)
Use value type of ONNX tensor arguments
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java28
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java33
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java27
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java64
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java4
6 files changed, 75 insertions, 83 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 5cc1defc010..a469e666d93 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -16,10 +16,8 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Shape;
-import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
-import onnx.Onnx.TensorProto.DataType;
import java.util.List;
import java.util.stream.Collectors;
@@ -107,8 +105,8 @@ class GraphImporter {
if (isArgumentTensor(name, onnxGraph)) {
Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
if (valueInfoProto == null)
- throw new IllegalArgumentException("Could not find argument tensor: " + name);
- OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType());
+ throw new IllegalArgumentException("Could not find argument tensor '" + name + "'");
+ OrderedTensorType type = TypeConverter.typeFrom(valueInfoProto.getType());
operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type);
intermediateGraph.inputs(intermediateGraph.defaultSignature())
@@ -116,8 +114,7 @@ class GraphImporter {
} else if (isConstantTensor(name, onnxGraph)) {
Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
- OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(toValueType(tensorProto.getDataType()),
- tensorProto.getDimsList());
+ OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto);
operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
@@ -136,25 +133,6 @@ class GraphImporter {
return operation;
}
- private static TensorType.Value toValueType(DataType dataType) {
- switch (dataType) {
- case FLOAT: return TensorType.Value.FLOAT;
- case DOUBLE: return TensorType.Value.DOUBLE;
- // Imperfect conversion, for now:
- case BOOL: return TensorType.Value.FLOAT;
- case INT8: return TensorType.Value.FLOAT;
- case INT16: return TensorType.Value.FLOAT;
- case INT32: return TensorType.Value.FLOAT;
- case INT64: return TensorType.Value.DOUBLE;
- case UINT8: return TensorType.Value.FLOAT;
- case UINT16: return TensorType.Value.FLOAT;
- case UINT32: return TensorType.Value.FLOAT;
- case UINT64: return TensorType.Value.DOUBLE;
- default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
- " cannot be converted to a Vespa tensor type");
- }
- }
-
private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
Onnx.TensorProto tensor = getConstantTensor(name, graph);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 79b399f2c6f..29d600fa7c6 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -30,13 +30,10 @@ class TypeConverter {
}
}
- static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
- return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
- }
-
- private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
+ static OrderedTensorType typeFrom(Onnx.TypeProto type) {
+ String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(TensorType.Value.DOUBLE);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(type.getTensorType().getElemType()));
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
@@ -49,4 +46,28 @@ class TypeConverter {
return builder.build();
}
+ static OrderedTensorType typeFrom(Onnx.TensorProto tensor) {
+ return OrderedTensorType.fromDimensionList(toValueType(tensor.getDataType()),
+ tensor.getDimsList());
+ }
+
+ private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.DOUBLE;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
+ case UINT32: return TensorType.Value.FLOAT;
+ case UINT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index cb838cd67b1..a07c0fdf4dc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -51,7 +51,7 @@ class GraphImporter {
String nodeName = node.getName();
String modelName = graph.name();
int nodePort = IntermediateOperation.indexPartOf(nodeName);
- OrderedTensorType nodeType = TypeConverter.fromTensorFlowType(node);
+ OrderedTensorType nodeType = TypeConverter.typeFrom(node);
AttributeConverter attributes = AttributeConverter.convert(node);
switch (node.getOp().toLowerCase()) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
index a4fe38cce95..9cba388d00e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
@@ -28,7 +28,7 @@ public class TensorConverter {
}
private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
- TensorType type = toVespaTensorType(tfTensor, dimensionPrefix);
+ TensorType type = TypeConverter.typeFrom(tfTensor, dimensionPrefix);
Values values = readValuesOf(tfTensor);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
for (int i = 0; i < values.size(); i++)
@@ -54,16 +54,6 @@ public class TensorConverter {
return builder.build();
}
- private static TensorType toVespaTensorType(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
- TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType()));
- int dimensionIndex = 0;
- for (long dimensionSize : tfTensor.shape()) {
- if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
- b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
- }
- return b.build();
- }
-
public static Long tensorSize(TensorType type) {
Long size = 1L;
for (TensorType.Dimension dimension : type.dimensions()) {
@@ -108,21 +98,6 @@ public class TensorConverter {
throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
}
- /** TensorFlow has two different DataType classes. This must be kept in sync with TypeConverter.toValueType */
- static TensorType.Value toValueType(DataType dataType) {
- switch (dataType) {
- case FLOAT: return TensorType.Value.FLOAT;
- case DOUBLE: return TensorType.Value.DOUBLE;
- // Imperfect conversion, for now:
- case BOOL: return TensorType.Value.FLOAT;
- case INT32: return TensorType.Value.FLOAT;
- case UINT8: return TensorType.Value.FLOAT;
- case INT64: return TensorType.Value.DOUBLE;
- default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
- " cannot be converted to a Vespa tensor type");
- }
- }
-
/** Allows reading values from buffers of various numeric types as bytes */
private static abstract class Values {
abstract double get(int i);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
index 3e825026b0e..d8ddb01b650 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
@@ -9,8 +9,6 @@ import org.tensorflow.framework.DataType;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorShapeProto;
-import java.util.List;
-
/**
* Converts and verifies TensorFlow tensor types into Vespa tensor types.
*
@@ -37,6 +35,32 @@ class TypeConverter {
}
}
+ static OrderedTensorType typeFrom(NodeDef node) {
+ String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
+ TensorShapeProto shape = tensorFlowShape(node);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node)));
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
+ if (tensorFlowDimension.getSize() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+ static TensorType typeFrom(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
+ TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType()));
+ int dimensionIndex = 0;
+ for (long dimensionSize : tfTensor.shape()) {
+ if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
+ b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
+ }
+ return b.build();
+ }
+
private static TensorShapeProto tensorFlowShape(NodeDef node) {
AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
if (attrValueList == null)
@@ -59,27 +83,7 @@ class TypeConverter {
return attrValueList.getList().getType(0); // support multiple outputs?
}
- static OrderedTensorType fromTensorFlowType(NodeDef node) {
- return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
- }
-
- private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
- TensorShapeProto shape = tensorFlowShape(node);
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node)));
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
- if (tensorFlowDimension.getSize() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
- /** TensorFlow has two different DataType classes. This must be kept in sync with TensorConverter.toValueType */
- static TensorType.Value toValueType(DataType dataType) {
+ private static TensorType.Value toValueType(DataType dataType) {
switch (dataType) {
case DT_FLOAT: return TensorType.Value.FLOAT;
case DT_DOUBLE: return TensorType.Value.DOUBLE;
@@ -100,4 +104,18 @@ class TypeConverter {
}
}
+ private static TensorType.Value toValueType(org.tensorflow.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case UINT8: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 424e4d6c57c..07814687dc6 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -43,14 +43,14 @@ public class OnnxMnistSoftmaxImportTestCase {
// Check inputs
assertEquals(1, model.inputs().size());
assertTrue(model.inputs().containsKey("Placeholder"));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), model.inputs().get("Placeholder"));
// Check signature
ImportedMlFunction output = model.defaultSignature().outputFunction("add", "add");
assertNotNull(output);
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
output.expression());
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"),
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"),
model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
}