aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
commit5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch)
tree2b65d4f48b92bf7ec846b3efd5d5259244bc234a /model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx
parent6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff)
Add tensor value type
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java24
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java2
2 files changed, 24 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index dd2add973e4..5cc1defc010 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -16,8 +16,10 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Shape;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
+import onnx.Onnx.TensorProto.DataType;
import java.util.List;
import java.util.stream.Collectors;
@@ -114,7 +116,8 @@ class GraphImporter {
} else if (isConstantTensor(name, onnxGraph)) {
Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
- OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList());
+ OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(toValueType(tensorProto.getDataType()),
+ tensorProto.getDimsList());
operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
@@ -133,6 +136,25 @@ class GraphImporter {
return operation;
}
+ private static TensorType.Value toValueType(DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.DOUBLE;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
+ case UINT32: return TensorType.Value.FLOAT;
+ case UINT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
Onnx.TensorProto tensor = getConstantTensor(name, graph);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index f251a14213b..79b399f2c6f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -36,7 +36,7 @@ class TypeConverter {
private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(TensorType.Value.DOUBLE);
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);