aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-06-23 10:01:11 +0000
committerArne Juul <arnej@yahooinc.com>2023-06-23 10:01:15 +0000
commit83ba3d8b6226b8b109a07b470111a9c7581bcdb8 (patch)
tree6f7e26ea754d04a50ec934feb9813ee1f01841fe /model-integration/src/main/java/ai/vespa/rankingexpression
parent97b15016085dd6f2b515b7051803f92e34b29ab9 (diff)
update onnx.proto
* use latest version from https://github.com/onnx/onnx/blob/main/onnx/onnx.proto * track API changes (enum -> int32)
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java7
2 files changed, 7 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
index f12f60dcc8e..f690b8e8c8a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
@@ -32,8 +32,9 @@ class TensorConverter {
}
private static Values readValuesOf(Onnx.TensorProto tensorProto) {
+ var elemType = Onnx.TensorProto.DataType.forNumber(tensorProto.getDataType());
if (tensorProto.hasRawData()) {
- switch (tensorProto.getDataType()) {
+ switch (elemType) {
case BOOL: return new RawBoolValues(tensorProto);
case FLOAT: return new RawFloatValues(tensorProto);
case DOUBLE: return new RawDoubleValues(tensorProto);
@@ -41,7 +42,7 @@ class TensorConverter {
case INT64: return new RawLongValues(tensorProto);
}
} else {
- switch (tensorProto.getDataType()) {
+ switch (elemType) {
case FLOAT: return new FloatValues(tensorProto);
case DOUBLE: return new DoubleValues(tensorProto);
case INT32: return new IntValues(tensorProto);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 35ec1d8c54a..deac950d324 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -37,7 +37,8 @@ class TypeConverter {
static OrderedTensorType typeFrom(Onnx.TypeProto type) {
String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(type.getTensorType().getElemType()));
+ var elemType = Onnx.TensorProto.DataType.forNumber(type.getTensorType().getElemType());
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(elemType));
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
@@ -52,8 +53,8 @@ class TypeConverter {
}
static OrderedTensorType typeFrom(Onnx.TensorProto tensor) {
- return OrderedTensorType.fromDimensionList(toValueType(tensor.getDataType()),
- tensor.getDimsList());
+ var elemType = Onnx.TensorProto.DataType.forNumber(tensor.getDataType());
+ return OrderedTensorType.fromDimensionList(toValueType(elemType), tensor.getDimsList());
}
private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {