diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index f12f60dcc8e..f690b8e8c8a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -32,8 +32,9 @@ class TensorConverter { } private static Values readValuesOf(Onnx.TensorProto tensorProto) { + var elemType = Onnx.TensorProto.DataType.forNumber(tensorProto.getDataType()); if (tensorProto.hasRawData()) { - switch (tensorProto.getDataType()) { + switch (elemType) { case BOOL: return new RawBoolValues(tensorProto); case FLOAT: return new RawFloatValues(tensorProto); case DOUBLE: return new RawDoubleValues(tensorProto); @@ -41,7 +42,7 @@ class TensorConverter { case INT64: return new RawLongValues(tensorProto); } } else { - switch (tensorProto.getDataType()) { + switch (elemType) { case FLOAT: return new FloatValues(tensorProto); case DOUBLE: return new DoubleValues(tensorProto); case INT32: return new IntValues(tensorProto); |