diff options
-rw-r--r-- | eval/src/vespa/eval/onnx/onnx_wrapper.cpp | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp index 2891b37ebe8..e9758f2ddc8 100644 --- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp @@ -1,6 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "onnx_wrapper.h" +#include <vespa/eval/eval/cell_type.h> #include <vespa/eval/eval/dense_cells_value.h> #include <vespa/eval/eval/value_type.h> #include <vespa/vespalib/util/arrayref.h> @@ -21,6 +22,14 @@ using vespalib::ConstArrayRef; using vespalib::make_string_short::fmt; +// as documented in onnxruntime_cxx_api.h : +namespace Ort { +template <> +struct TypeToTensorType<vespalib::BFloat16> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; }; +template <> +struct TypeToTensorType<vespalib::eval::Int8Float> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; }; +} + namespace vespalib::eval { namespace { |