summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java5
1 files changed, 3 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
index f12f60dcc8e..f690b8e8c8a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
@@ -32,8 +32,9 @@ class TensorConverter {
}
private static Values readValuesOf(Onnx.TensorProto tensorProto) {
+ var elemType = Onnx.TensorProto.DataType.forNumber(tensorProto.getDataType());
if (tensorProto.hasRawData()) {
- switch (tensorProto.getDataType()) {
+ switch (elemType) {
case BOOL: return new RawBoolValues(tensorProto);
case FLOAT: return new RawFloatValues(tensorProto);
case DOUBLE: return new RawDoubleValues(tensorProto);
@@ -41,7 +42,7 @@ class TensorConverter {
case INT64: return new RawLongValues(tensorProto);
}
} else {
- switch (tensorProto.getDataType()) {
+ switch (elemType) {
case FLOAT: return new FloatValues(tensorProto);
case DOUBLE: return new DoubleValues(tensorProto);
case INT32: return new IntValues(tensorProto);