aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-03-12 23:19:21 +0100
committerGitHub <noreply@github.com>2023-03-12 23:19:21 +0100
commit9553cd2709c0791aa530f5388cf156116e857795 (patch)
tree31c351bc0402817fbd7037f37b1fa92bd168bfcb
parentb06d77bb7433d750fbc02446bab00af8c6ce7fcc (diff)
parentce5d8913e957151c7cd2c0e184ae8e310e31e06e (diff)
Merge pull request #26405 from vespa-engine/revert-26396-arnej/unify-cell-type-conversionv8.136.37
Revert "Arnej/unify cell type conversion"
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java23
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java36
-rw-r--r--model-integration/src/main/protobuf/onnx.proto6
3 files changed, 21 insertions, 44 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
index eef75a32c0a..9c79961eddf 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -180,25 +180,12 @@ class TensorConverter {
}
static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) {
- // NOTE:
- // should match best_cell_type in onnx_wrapper.cpp
switch (onnxType) {
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
- return TensorType.Value.INT8;
-
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16:
- return TensorType.Value.BFLOAT16;
-
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
- return TensorType.Value.FLOAT;
-
- case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
- return TensorType.Value.DOUBLE;
- }
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE;
+ }
return TensorType.Value.DOUBLE;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 2c008dbb922..35ec1d8c54a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -56,27 +56,21 @@ class TypeConverter {
tensor.getDimsList());
}
- private static TensorType.Value toValueType(Onnx.TensorProto.DataType onnxType) {
- // NOTE:
- // should match best_cell_type in onnx_wrapper.cpp
- switch (onnxType) {
- case BOOL: // Imperfect conversion fallthrough
- case INT8:
- return TensorType.Value.INT8;
- case BFLOAT16:
- return TensorType.Value.BFLOAT16;
- case UINT8: // Imperfect conversion fallthrough
- case INT16: // Imperfect conversion fallthrough
- case UINT16: // Imperfect conversion fallthrough
- case FLOAT:
- return TensorType.Value.FLOAT;
- case INT32: // Imperfect conversion fallthrough
- case INT64: // Imperfect conversion fallthrough
- case UINT32: // Imperfect conversion fallthrough
- case UINT64: // Imperfect conversion fallthrough
- case DOUBLE:
- return TensorType.Value.DOUBLE;
- default: throw new IllegalArgumentException("A ONNX tensor with data type " + onnxType +
+ private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.FLOAT;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
+ case UINT32: return TensorType.Value.FLOAT;
+ case UINT64: return TensorType.Value.FLOAT;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
" cannot be converted to a Vespa tensor type");
}
}
diff --git a/model-integration/src/main/protobuf/onnx.proto b/model-integration/src/main/protobuf/onnx.proto
index 27f1fdef4b3..dc6542867e0 100644
--- a/model-integration/src/main/protobuf/onnx.proto
+++ b/model-integration/src/main/protobuf/onnx.proto
@@ -298,10 +298,6 @@ message TensorProto {
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
- // Non-IEEE floating-point format based on IEEE754 single-precision
- // floating-point number truncated to 16 bits.
- // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
- BFLOAT16 = 16;
// Future extensions go here.
}
@@ -465,4 +461,4 @@ message OperatorSetIdProto {
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
optional int64 version = 2;
-}
+} \ No newline at end of file