diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2023-03-12 23:10:54 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-12 23:10:54 +0100 |
commit | ce5d8913e957151c7cd2c0e184ae8e310e31e06e (patch) | |
tree | 31c351bc0402817fbd7037f37b1fa92bd168bfcb /model-integration/src/main/java/ai | |
parent | b06d77bb7433d750fbc02446bab00af8c6ce7fcc (diff) |
Revert "Arnej/unify cell type conversion"
Diffstat (limited to 'model-integration/src/main/java/ai')
2 files changed, 20 insertions, 39 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index eef75a32c0a..9c79961eddf 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -180,25 +180,12 @@ class TensorConverter { } static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) { - // NOTE: - // should match best_cell_type in onnx_wrapper.cpp switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: - return TensorType.Value.INT8; - - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: - return TensorType.Value.BFLOAT16; - - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: - return TensorType.Value.FLOAT; - - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: - return TensorType.Value.DOUBLE; - } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE; + } return TensorType.Value.DOUBLE; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 2c008dbb922..35ec1d8c54a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -56,27 +56,21 @@ class TypeConverter { tensor.getDimsList()); } - private static TensorType.Value toValueType(Onnx.TensorProto.DataType onnxType) { - // NOTE: - // should match best_cell_type in onnx_wrapper.cpp - switch (onnxType) { - case BOOL: // Imperfect conversion fallthrough - case INT8: - return TensorType.Value.INT8; - case BFLOAT16: - return TensorType.Value.BFLOAT16; - case UINT8: // Imperfect conversion fallthrough - case INT16: // Imperfect conversion fallthrough - case UINT16: // Imperfect conversion fallthrough - case FLOAT: - return TensorType.Value.FLOAT; - case INT32: // Imperfect conversion fallthrough - case INT64: // Imperfect conversion fallthrough - case UINT32: // Imperfect conversion fallthrough - case UINT64: // Imperfect conversion fallthrough - case DOUBLE: - return TensorType.Value.DOUBLE; - default: throw new IllegalArgumentException("A ONNX tensor with data type " + onnxType + + private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT8: return TensorType.Value.FLOAT; + case INT16: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.FLOAT; + case UINT16: return TensorType.Value.FLOAT; + case UINT32: return TensorType.Value.FLOAT; + case UINT64: return TensorType.Value.FLOAT; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + " cannot be converted to a Vespa tensor type"); } } |