summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java36
1 files changed, 15 insertions, 21 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 2c008dbb922..35ec1d8c54a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -56,27 +56,21 @@ class TypeConverter {
tensor.getDimsList());
}
- private static TensorType.Value toValueType(Onnx.TensorProto.DataType onnxType) {
- // NOTE:
- // should match best_cell_type in onnx_wrapper.cpp
- switch (onnxType) {
- case BOOL: // Imperfect conversion fallthrough
- case INT8:
- return TensorType.Value.INT8;
- case BFLOAT16:
- return TensorType.Value.BFLOAT16;
- case UINT8: // Imperfect conversion fallthrough
- case INT16: // Imperfect conversion fallthrough
- case UINT16: // Imperfect conversion fallthrough
- case FLOAT:
- return TensorType.Value.FLOAT;
- case INT32: // Imperfect conversion fallthrough
- case INT64: // Imperfect conversion fallthrough
- case UINT32: // Imperfect conversion fallthrough
- case UINT64: // Imperfect conversion fallthrough
- case DOUBLE:
- return TensorType.Value.DOUBLE;
- default: throw new IllegalArgumentException("A ONNX tensor with data type " + onnxType +
+ private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.FLOAT;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
+ case UINT32: return TensorType.Value.FLOAT;
+ case UINT64: return TensorType.Value.FLOAT;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
" cannot be converted to a Vespa tensor type");
}
}