diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-06-23 10:01:11 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-06-23 10:01:15 +0000 |
commit | 83ba3d8b6226b8b109a07b470111a9c7581bcdb8 (patch) | |
tree | 6f7e26ea754d04a50ec934feb9813ee1f01841fe /model-integration/src/main/java/ai | |
parent | 97b15016085dd6f2b515b7051803f92e34b29ab9 (diff) |
update onnx.proto
* use latest version from https://github.com/onnx/onnx/blob/main/onnx/onnx.proto
* track API changes (enum -> int32)
Diffstat (limited to 'model-integration/src/main/java/ai')
2 files changed, 7 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index f12f60dcc8e..f690b8e8c8a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -32,8 +32,9 @@ class TensorConverter { } private static Values readValuesOf(Onnx.TensorProto tensorProto) { + var elemType = Onnx.TensorProto.DataType.forNumber(tensorProto.getDataType()); if (tensorProto.hasRawData()) { - switch (tensorProto.getDataType()) { + switch (elemType) { case BOOL: return new RawBoolValues(tensorProto); case FLOAT: return new RawFloatValues(tensorProto); case DOUBLE: return new RawDoubleValues(tensorProto); @@ -41,7 +42,7 @@ class TensorConverter { case INT64: return new RawLongValues(tensorProto); } } else { - switch (tensorProto.getDataType()) { + switch (elemType) { case FLOAT: return new FloatValues(tensorProto); case DOUBLE: return new DoubleValues(tensorProto); case INT32: return new IntValues(tensorProto); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 35ec1d8c54a..deac950d324 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -37,7 +37,8 @@ class TypeConverter { static OrderedTensorType typeFrom(Onnx.TypeProto type) { String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(type.getTensorType().getElemType())); + var elemType = Onnx.TensorProto.DataType.forNumber(type.getTensorType().getElemType()); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(elemType)); for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); @@ -52,8 +53,8 @@ class TypeConverter { } static OrderedTensorType typeFrom(Onnx.TensorProto tensor) { - return OrderedTensorType.fromDimensionList(toValueType(tensor.getDataType()), - tensor.getDimsList()); + var elemType = Onnx.TensorProto.DataType.forNumber(tensor.getDataType()); + return OrderedTensorType.fromDimensionList(toValueType(elemType), tensor.getDimsList()); } private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { |