diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
commit | 5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch) | |
tree | 2b65d4f48b92bf7ec846b3efd5d5259244bc234a /model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx | |
parent | 6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff) |
Add tensor value type
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx')
2 files changed, 24 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index dd2add973e4..5cc1defc010 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -16,8 +16,10 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import ai.vespa.rankingexpression.importer.operations.NoOp; import ai.vespa.rankingexpression.importer.operations.Reshape; import ai.vespa.rankingexpression.importer.operations.Shape; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; +import onnx.Onnx.TensorProto.DataType; import java.util.List; import java.util.stream.Collectors; @@ -114,7 +116,8 @@ class GraphImporter { } else if (isConstantTensor(name, onnxGraph)) { Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); - OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList()); + OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(toValueType(tensorProto.getDataType()), + tensorProto.getDimsList()); operation = new Constant(intermediateGraph.name(), name, defaultType); operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); @@ -133,6 +136,25 @@ class GraphImporter { return operation; } + private static TensorType.Value toValueType(DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT8: return TensorType.Value.FLOAT; + case INT16: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.DOUBLE; + case UINT8: return TensorType.Value.FLOAT; + case UINT16: return TensorType.Value.FLOAT; + case UINT32: return TensorType.Value.FLOAT; + case UINT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { Onnx.ValueInfoProto value = getArgumentTensor(name, graph); Onnx.TensorProto tensor = getConstantTensor(name, graph); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index f251a14213b..79b399f2c6f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -36,7 +36,7 @@ class TypeConverter { private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(TensorType.Value.DOUBLE); for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); |