aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-03-23 12:42:54 +0000
committerArne Juul <arnej@verizonmedia.com>2021-03-23 16:57:41 +0000
commita851a849bab624a7a1fdc221332b5bb4c69aadd0 (patch)
tree8f7414b320b91995d187adfede58f6dbb306a373 /eval
parent19711977575c14de515ff1dca768ba6eb11c6be6 (diff)
handle new cell types (per doc in onnxruntime_cxx_api.h)
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/onnx/onnx_wrapper.cpp9
1 files changed, 9 insertions, 0 deletions
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
index 2891b37ebe8..e9758f2ddc8 100644
--- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
@@ -1,6 +1,7 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "onnx_wrapper.h"
+#include <vespa/eval/eval/cell_type.h>
#include <vespa/eval/eval/dense_cells_value.h>
#include <vespa/eval/eval/value_type.h>
#include <vespa/vespalib/util/arrayref.h>
@@ -21,6 +22,14 @@ using vespalib::ConstArrayRef;
using vespalib::make_string_short::fmt;
+// as documented in onnxruntime_cxx_api.h :
+namespace Ort {
+template <>
+struct TypeToTensorType<vespalib::BFloat16> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; };
+template <>
+struct TypeToTensorType<vespalib::eval::Int8Float> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; };
+}
+
namespace vespalib::eval {
namespace {