diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-03-10 12:40:34 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-10 12:40:34 +0100 |
commit | 285c792c0ae5a22ccaf6e4a5cf852bf8d5326b07 (patch) | |
tree | 8dabed8237c2b061b1579618dcad15c296aec718 | |
parent | ce7bcde58f5a8b80d631a11f6b19c13c36c72450 (diff) | |
parent | 286e1b632d316d9eadcefcce0d804fa4f8016fae (diff) |
Merge pull request #26396 from vespa-engine/arnej/unify-cell-type-conversion
Arnej/unify cell type conversion
3 files changed, 44 insertions, 21 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 9c79961eddf..eef75a32c0a 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -180,12 +180,25 @@ class TensorConverter { } static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) { + // NOTE: + // should match best_cell_type in onnx_wrapper.cpp switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE; - } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return TensorType.Value.INT8; + + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return TensorType.Value.BFLOAT16; + + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return TensorType.Value.FLOAT; + + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return TensorType.Value.DOUBLE; + } return TensorType.Value.DOUBLE; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 35ec1d8c54a..2c008dbb922 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -56,21 +56,27 @@ class TypeConverter { tensor.getDimsList()); } - private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { - switch (dataType) { - case FLOAT: return TensorType.Value.FLOAT; - case DOUBLE: return TensorType.Value.DOUBLE; - // Imperfect conversion, for now: - case BOOL: return TensorType.Value.FLOAT; - case INT8: return TensorType.Value.FLOAT; - case INT16: return TensorType.Value.FLOAT; - case INT32: return TensorType.Value.FLOAT; - case INT64: return TensorType.Value.FLOAT; - case UINT8: return TensorType.Value.FLOAT; - case UINT16: return TensorType.Value.FLOAT; - case UINT32: return TensorType.Value.FLOAT; - case UINT64: return TensorType.Value.FLOAT; - default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + private static TensorType.Value toValueType(Onnx.TensorProto.DataType onnxType) { + // NOTE: + // should match best_cell_type in onnx_wrapper.cpp + switch (onnxType) { + case BOOL: // Imperfect conversion fallthrough + case INT8: + return TensorType.Value.INT8; + case BFLOAT16: + return TensorType.Value.BFLOAT16; + case UINT8: // Imperfect conversion fallthrough + case INT16: // Imperfect conversion fallthrough + case UINT16: // Imperfect conversion fallthrough + case FLOAT: + return TensorType.Value.FLOAT; + case INT32: // Imperfect conversion fallthrough + case INT64: // Imperfect conversion fallthrough + case UINT32: // Imperfect conversion fallthrough + case UINT64: // Imperfect conversion fallthrough + case DOUBLE: + return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + onnxType + " cannot be converted to a Vespa tensor type"); } } diff --git a/model-integration/src/main/protobuf/onnx.proto b/model-integration/src/main/protobuf/onnx.proto index dc6542867e0..27f1fdef4b3 100644 --- a/model-integration/src/main/protobuf/onnx.proto +++ b/model-integration/src/main/protobuf/onnx.proto @@ -298,6 +298,10 @@ message TensorProto { UINT64 = 13; COMPLEX64 = 14; // complex with float32 real and imaginary components COMPLEX128 = 15; // complex with float64 real and imaginary components + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; // Future extensions go here. } @@ -461,4 +465,4 @@ message OperatorSetIdProto { // The version of the operator set being identified. // This field MUST be present in this version of the IR. optional int64 version = 2; -}
\ No newline at end of file +} |