diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-10 11:04:48 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-10 11:04:48 +0000 |
commit | 286e1b632d316d9eadcefcce0d804fa4f8016fae (patch) | |
tree | 504b1dae716485aa76fe20a5083eac9226362b48 /model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java | |
parent | 11dc6cb2309ea5c9981f1ccd50f3cf417986eb0e (diff) |
cell-type conversions should match
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java | 23 |
1 files changed, 18 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 9c79961eddf..eef75a32c0a 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -180,12 +180,25 @@ class TensorConverter { } static private TensorType.Value toVespaValueType(TensorInfo.OnnxTensorType onnxType) { + // NOTE: + // should match best_cell_type in onnx_wrapper.cpp switch (onnxType) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return TensorType.Value.INT8; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: return TensorType.Value.BFLOAT16; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return TensorType.Value.FLOAT; - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return TensorType.Value.DOUBLE; - } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + return TensorType.Value.INT8; + + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16: + return TensorType.Value.BFLOAT16; + + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + return TensorType.Value.FLOAT; + + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + return TensorType.Value.DOUBLE; + } return TensorType.Value.DOUBLE; } |