summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-10 11:04:48 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-10 11:04:48 +0000
commit286e1b632d316d9eadcefcce0d804fa4f8016fae (patch)
tree504b1dae716485aa76fe20a5083eac9226362b48
parent11dc6cb2309ea5c9981f1ccd50f3cf417986eb0e (diff)
cell-type conversions should match
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java23
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java36
2 files changed, 39 insertions, 20 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
index 9c79961eddf..eef75a32c0a 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -180,12 +180,25 @@ class TensorConverter {
}
static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) {
+ // NOTE:
+ // should match best_cell_type in onnx_wrapper.cpp
switch (onnxType) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT;
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE;
- }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
+ return TensorType.Value.INT8;
+
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
+ return TensorType.Value.BFLOAT16;
+
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
+ return TensorType.Value.FLOAT;
+
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
+ return TensorType.Value.DOUBLE;
+ }
return TensorType.Value.DOUBLE;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 35ec1d8c54a..2c008dbb922 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -56,21 +56,27 @@ class TypeConverter {
tensor.getDimsList());
}
- private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
- switch (dataType) {
- case FLOAT: return TensorType.Value.FLOAT;
- case DOUBLE: return TensorType.Value.DOUBLE;
- // Imperfect conversion, for now:
- case BOOL: return TensorType.Value.FLOAT;
- case INT8: return TensorType.Value.FLOAT;
- case INT16: return TensorType.Value.FLOAT;
- case INT32: return TensorType.Value.FLOAT;
- case INT64: return TensorType.Value.FLOAT;
- case UINT8: return TensorType.Value.FLOAT;
- case UINT16: return TensorType.Value.FLOAT;
- case UINT32: return TensorType.Value.FLOAT;
- case UINT64: return TensorType.Value.FLOAT;
- default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ private static TensorType.Value toValueType(Onnx.TensorProto.DataType onnxType) {
+ // NOTE:
+ // should match best_cell_type in onnx_wrapper.cpp
+ switch (onnxType) {
+ case BOOL: // Imperfect conversion fallthrough
+ case INT8:
+ return TensorType.Value.INT8;
+ case BFLOAT16:
+ return TensorType.Value.BFLOAT16;
+ case UINT8: // Imperfect conversion fallthrough
+ case INT16: // Imperfect conversion fallthrough
+ case UINT16: // Imperfect conversion fallthrough
+ case FLOAT:
+ return TensorType.Value.FLOAT;
+ case INT32: // Imperfect conversion fallthrough
+ case INT64: // Imperfect conversion fallthrough
+ case UINT32: // Imperfect conversion fallthrough
+ case UINT64: // Imperfect conversion fallthrough
+ case DOUBLE:
+ return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + onnxType +
" cannot be converted to a Vespa tensor type");
}
}