author     Arne H Juul <arnej27959@users.noreply.github.com>  2020-09-06 08:19:59 +0200
committer  GitHub <noreply@github.com>                        2020-09-06 08:19:59 +0200
commit     95a44feac17b0ccc1d41693e6ad9de99227d3d8b (patch)
tree       f589e6fc4b288677c69ccbbe000e59f2a6d820ee
parent     ec0e8716abbf29d958e768d40b57b23002221c4a (diff)
parent     b5b4b47229990035260806ed644c0f46d0c8a282 (diff)
Merge pull request #14301 from vespa-engine/havardpe/onnx-element-type-adapters-and-converters
Adapt and convert between Vespa and ONNX element types.
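In short: the ONNX wrapper now accepts integer element types (int8/16/32/64 and
uint8/16/32/64) in addition to float and double. When the ONNX element type of an
input or output matches the cell type of the bound Vespa value, it is wired
directly (zero-copy); otherwise an explicit per-cell conversion step is added and
a warning is logged, since the conversion might be lossy. A minimal sketch of the
wiring decision, using names taken from the diff below:

    // zero-copy adapter when types agree, per-cell converter otherwise
    if (is_same_type(vespa.cell_type(), onnx.elements)) {
        _param_values.push_back(Ort::Value(nullptr));  // filled in later by bind_param
        _param_binders.push_back(create_param_adapter(vespa.cell_type(), onnx, _cpu_memory));
    } else {
        _param_values.push_back(create_onnx_tensor(onnx, _alloc));
        _param_binders.push_back(create_param_converter(vespa.cell_type(), onnx.elements));
    }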
-rw-r--r--  eval/src/tests/tensor/onnx_wrapper/int_types.onnx                         23
-rwxr-xr-x  eval/src/tests/tensor/onnx_wrapper/int_types.py                           33
-rw-r--r--  eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp                  76
-rw-r--r--  eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp                        373
-rw-r--r--  eval/src/vespa/eval/tensor/dense/onnx_wrapper.h                           50
-rw-r--r--  searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp     2
-rw-r--r--  vespalib/src/vespa/vespalib/util/classname.h                               8
7 files changed, 431 insertions, 134 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/int_types.onnx b/eval/src/tests/tensor/onnx_wrapper/int_types.onnx
new file mode 100644
index 00000000000..65c0765a31b
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/int_types.onnx
@@ -0,0 +1,23 @@
+[binary ONNX protobuf: graph 'int_types_scoring' (MatMul + Add over query_tensor,
+attribute_tensor, bias_tensor), generated by int_types.py; raw bytes omitted]
diff --git a/eval/src/tests/tensor/onnx_wrapper/int_types.py b/eval/src/tests/tensor/onnx_wrapper/int_types.py
new file mode 100755
index 00000000000..cd82bfd44b5
--- /dev/null
+++ b/eval/src/tests/tensor/onnx_wrapper/int_types.py
@@ -0,0 +1,33 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+QUERY_TENSOR = helper.make_tensor_value_info('query_tensor', TensorProto.INT32, [1, 4])
+ATTRIBUTE_TENSOR = helper.make_tensor_value_info('attribute_tensor', TensorProto.INT32, [4, 1])
+BIAS_TENSOR = helper.make_tensor_value_info('bias_tensor', TensorProto.INT32, [1, 1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.INT32, [1, 1])
+
+nodes = [
+ helper.make_node(
+ 'MatMul',
+ ['query_tensor', 'attribute_tensor'],
+ ['matmul'],
+ ),
+ helper.make_node(
+ 'Add',
+ ['matmul', 'bias_tensor'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'int_types_scoring',
+ [
+ QUERY_TENSOR,
+ ATTRIBUTE_TENSOR,
+ BIAS_TENSOR,
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='int_types.py')
+onnx.save(model_def, 'int_types.onnx')
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index db2415e9969..23c41167266 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -11,6 +11,7 @@ using namespace vespalib::tensor;
using vespalib::make_string_short::fmt;
using TensorInfo = Onnx::TensorInfo;
+using ElementType = Onnx::ElementType;
using DZ = Onnx::DimSize;
std::string get_source_dir() {
@@ -20,6 +21,7 @@ std::string get_source_dir() {
std::string source_dir = get_source_dir();
std::string simple_model = source_dir + "/simple.onnx";
std::string dynamic_model = source_dir + "/dynamic.onnx";
+std::string int_types_model = source_dir + "/int_types.onnx";
void dump_info(const char *ctx, const std::vector<TensorInfo> &info) {
fprintf(stderr, "%s:\n", ctx);
@@ -28,24 +30,12 @@ void dump_info(const char *ctx, const std::vector<TensorInfo> &info) {
}
}
-TEST(WirePlannerTest, element_types_must_match) {
- Onnx::WirePlanner planner;
- ValueType type1 = ValueType::from_spec("tensor<float>(a[5])");
- ValueType type2 = ValueType::from_spec("tensor<double>(a[5])");
- TensorInfo info1 = TensorInfo{"info", {DZ(5)}, TensorInfo::ElementType::FLOAT};
- TensorInfo info2 = TensorInfo{"info", {DZ(5)}, TensorInfo::ElementType::DOUBLE};
- EXPECT_TRUE(planner.bind_input_type(type1, info1));
- EXPECT_FALSE(planner.bind_input_type(type2, info1));
- EXPECT_FALSE(planner.bind_input_type(type1, info2));
- EXPECT_TRUE(planner.bind_input_type(type2, info2));
-}
-
TEST(WirePlannerTest, known_dimension_sizes_must_match) {
Onnx::WirePlanner planner;
ValueType type1 = ValueType::from_spec("tensor<float>(a[5],b[10])");
ValueType type2 = ValueType::from_spec("tensor<float>(a[10],b[5])");
ValueType type3 = ValueType::from_spec("tensor<float>(a[5],b[5])");
- TensorInfo info = TensorInfo{"info", {DZ(5),DZ(5)}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info = TensorInfo{"info", {DZ(5),DZ(5)}, ElementType::FLOAT};
EXPECT_FALSE(planner.bind_input_type(type1, info));
EXPECT_FALSE(planner.bind_input_type(type2, info));
EXPECT_TRUE(planner.bind_input_type(type3, info));
@@ -55,7 +45,7 @@ TEST(WirePlannerTest, symbolic_dimension_sizes_must_match) {
Onnx::WirePlanner planner;
ValueType type1 = ValueType::from_spec("tensor<float>(a[5])");
ValueType type2 = ValueType::from_spec("tensor<float>(a[10])");
- TensorInfo info = TensorInfo{"info", {DZ("dim")}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info = TensorInfo{"info", {DZ("dim")}, ElementType::FLOAT};
EXPECT_TRUE(planner.bind_input_type(type1, info)); // binds 'dim' to 5
EXPECT_FALSE(planner.bind_input_type(type2, info));
EXPECT_TRUE(planner.bind_input_type(type1, info));
@@ -65,7 +55,7 @@ TEST(WirePlannerTest, unknown_dimension_sizes_match_anything) {
Onnx::WirePlanner planner;
ValueType type1 = ValueType::from_spec("tensor<float>(a[5])");
ValueType type2 = ValueType::from_spec("tensor<float>(a[10])");
- TensorInfo info = TensorInfo{"info", {DZ()}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info = TensorInfo{"info", {DZ()}, ElementType::FLOAT};
EXPECT_TRUE(planner.bind_input_type(type1, info));
EXPECT_TRUE(planner.bind_input_type(type2, info));
}
@@ -73,9 +63,9 @@ TEST(WirePlannerTest, unknown_dimension_sizes_match_anything) {
TEST(WirePlannerTest, all_output_dimensions_must_be_bound) {
Onnx::WirePlanner planner;
ValueType type = ValueType::from_spec("tensor<float>(a[5],b[10])");
- TensorInfo info1 = TensorInfo{"info", {DZ()}, TensorInfo::ElementType::FLOAT};
- TensorInfo info2 = TensorInfo{"info", {DZ("dim")}, TensorInfo::ElementType::FLOAT};
- TensorInfo info3 = TensorInfo{"info", {DZ("dim"),DZ()}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info1 = TensorInfo{"info", {DZ()}, ElementType::FLOAT};
+ TensorInfo info2 = TensorInfo{"info", {DZ("dim")}, ElementType::FLOAT};
+ TensorInfo info3 = TensorInfo{"info", {DZ("dim"),DZ()}, ElementType::FLOAT};
EXPECT_TRUE(planner.make_output_type(info1).is_error());
EXPECT_TRUE(planner.make_output_type(info2).is_error());
EXPECT_TRUE(planner.make_output_type(info3).is_error());
@@ -90,7 +80,7 @@ TEST(WirePlannerTest, dimensions_resolve_left_to_right) {
ValueType type1 = ValueType::from_spec("tensor<float>(a[5],b[10])");
ValueType type2 = ValueType::from_spec("tensor<float>(a[10],b[10])");
ValueType type3 = ValueType::from_spec("tensor<float>(a[5],b[5])");
- TensorInfo info = TensorInfo{"info", {DZ("dim"),DZ("dim")}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info = TensorInfo{"info", {DZ("dim"),DZ("dim")}, ElementType::FLOAT};
EXPECT_FALSE(planner.bind_input_type(type1, info)); // binds 'dim' to 5, then fails (5 != 10)
EXPECT_FALSE(planner.bind_input_type(type2, info));
EXPECT_TRUE(planner.bind_input_type(type3, info));
@@ -180,7 +170,7 @@ TEST(OnnxTest, simple_onnx_model_can_be_evaluated)
DenseTensorView new_bias(bias_type, TypedCells(new_bias_values));
ctx.bind_param(2, new_bias);
ctx.eval();
- EXPECT_EQ(static_cast<const DenseTensorView&>(output).cellsRef().get(0), 80.0);
+ EXPECT_EQ(static_cast<const DenseTensorView&>(output).cellsRef().get(0), 80.0);
//-------------------------------------------------------------------------
}
@@ -230,4 +220,50 @@ TEST(OnnxTest, dynamic_onnx_model_can_be_evaluated)
//-------------------------------------------------------------------------
}
+TEST(OnnxTest, int_types_onnx_model_can_be_evaluated)
+{
+ Onnx model(int_types_model, Onnx::Optimize::ENABLE);
+ Onnx::WirePlanner planner;
+
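+ // all three inputs are int32 in the model (see int_types.py); binding float and
+ // double vespa tensors exercises the new conversion path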
+ ValueType query_type = ValueType::from_spec("tensor<float>(a[1],b[4])");
+ std::vector<float> query_values({1.0, 2.0, 3.0, 4.0});
+ DenseTensorView query(query_type, TypedCells(query_values));
+ EXPECT_TRUE(planner.bind_input_type(query_type, model.inputs()[0]));
+
+ ValueType attribute_type = ValueType::from_spec("tensor<double>(a[4],b[1])");
+ std::vector<double> attribute_values({5.0, 6.0, 7.0, 8.0});
+ DenseTensorView attribute(attribute_type, TypedCells(attribute_values));
+ EXPECT_TRUE(planner.bind_input_type(attribute_type, model.inputs()[1]));
+
+ ValueType bias_type = ValueType::from_spec("tensor<double>(a[1],b[1])");
+ std::vector<double> bias_values({9.0});
+ DenseTensorView bias(bias_type, TypedCells(bias_values));
+ EXPECT_TRUE(planner.bind_input_type(bias_type, model.inputs()[2]));
+
+ EXPECT_EQ(planner.make_output_type(model.outputs()[0]),
+ ValueType::from_spec("tensor<double>(d0[1],d1[1])"));
+
+ Onnx::WireInfo wire_info = planner.get_wire_info(model);
+ Onnx::EvalContext ctx(model, wire_info);
+
+ const Value &output = ctx.get_result(0);
+ EXPECT_EQ(output.type(), ValueType::from_spec("tensor<double>(d0[1],d1[1])"));
+ //-------------------------------------------------------------------------
+ ctx.bind_param(0, query);
+ ctx.bind_param(1, attribute);
+ ctx.bind_param(2, bias);
+ ctx.eval();
+ auto cells = static_cast<const DenseTensorView&>(output).cellsRef();
+ EXPECT_EQ(cells.type, ValueType::CellType::DOUBLE);
+ EXPECT_EQ(cells.size, 1);
+ EXPECT_EQ(cells.get(0), 79.0);
+ //-------------------------------------------------------------------------
+ std::vector<double> new_bias_values({10.0});
+ DenseTensorView new_bias(bias_type, TypedCells(new_bias_values));
+ ctx.bind_param(2, new_bias);
+ ctx.eval();
+ EXPECT_EQ(static_cast<const DenseTensorView&>(output).cellsRef().get(0), 80.0);
+ //-------------------------------------------------------------------------
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
index 88346213901..7c29b20f2f4 100644
--- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
@@ -3,50 +3,228 @@
#include "onnx_wrapper.h"
#include <vespa/eval/eval/value_type.h>
#include "dense_tensor_view.h"
-#include "mutable_dense_tensor_view.h"
+#include "dense_tensor.h"
#include <vespa/vespalib/util/arrayref.h>
#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/typify.h>
+#include <vespa/vespalib/util/classname.h>
#include <assert.h>
#include <cmath>
#include <stdlib.h>
#include <stdio.h>
+#include <type_traits>
+#include <vespa/log/log.h>
+LOG_SETUP(".eval.onnx_wrapper");
+
+using vespalib::ArrayRef;
+using vespalib::ConstArrayRef;
using vespalib::eval::ValueType;
+using vespalib::eval::TypifyCellType;
+
using vespalib::make_string_short::fmt;
namespace vespalib::tensor {
+using ParamBinder = Onnx::EvalContext::ParamBinder;
+using EvalHook = Onnx::EvalContext::EvalHook;
+
namespace {
-vespalib::string to_str(Onnx::TensorInfo::ElementType element_type) {
- if (element_type == Onnx::TensorInfo::ElementType::FLOAT) {
- return "float";
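+// maps each Onnx::ElementType to the corresponding C++ type for typify-based dispatch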
+struct TypifyOnnxElementType {
+ template <typename T> using Result = TypifyResultType<T>;
+ template <typename F> static decltype(auto) resolve(Onnx::ElementType value, F &&f) {
+ switch(value) {
+ case Onnx::ElementType::INT8: return f(Result<int8_t>());
+ case Onnx::ElementType::INT16: return f(Result<int16_t>());
+ case Onnx::ElementType::INT32: return f(Result<int32_t>());
+ case Onnx::ElementType::INT64: return f(Result<int64_t>());
+ case Onnx::ElementType::UINT8: return f(Result<uint8_t>());
+ case Onnx::ElementType::UINT16: return f(Result<uint16_t>());
+ case Onnx::ElementType::UINT32: return f(Result<uint32_t>());
+ case Onnx::ElementType::UINT64: return f(Result<uint64_t>());
+ case Onnx::ElementType::FLOAT: return f(Result<float>());
+ case Onnx::ElementType::DOUBLE: return f(Result<double>());
+ }
+ abort();
}
- if (element_type == Onnx::TensorInfo::ElementType::DOUBLE) {
- return "double";
+};
+
+using MyTypify = TypifyValue<TypifyCellType,TypifyOnnxElementType>;
+
+//-----------------------------------------------------------------------------
+
+struct TypeToString {
+ template <typename T> static vespalib::string invoke() { return getClassName<T>(); }
+};
+
+struct IsSameType {
+ template <typename A, typename B> static bool invoke() { return std::is_same<A,B>(); }
+};
+
+struct CreateOnnxTensor {
+ template <typename T> static Ort::Value invoke(const std::vector<int64_t> &sizes, OrtAllocator *alloc) {
+ return Ort::Value::CreateTensor<T>(alloc, sizes.data(), sizes.size());
+ }
+};
+
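+// wraps the memory of an onnx tensor as a vespa tensor view (no copy)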
+struct CreateVespaTensorRef {
+ template <typename T> static eval::Value::UP invoke(const eval::ValueType &type_ref, Ort::Value &value) {
+ size_t num_cells = type_ref.dense_subspace_size();
+ ConstArrayRef<T> cells(value.GetTensorMutableData<T>(), num_cells);
+ return std::make_unique<DenseTensorView>(type_ref, TypedCells(cells));
}
- return "???";
+};
+
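+// allocates a fresh vespa tensor, used as conversion target when result types differ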
+struct CreateVespaTensor {
+ template <typename T> static eval::Value::UP invoke(const eval::ValueType &type) {
+ size_t num_cells = type.dense_subspace_size();
+ std::vector<T> cells(num_cells, T{});
+ return std::make_unique<DenseTensor<T>>(type, std::move(cells));
+ }
+};
+
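+// zero-copy input binding: exposes the vespa tensor's cells directly as an onnx tensor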
+template <typename T>
+struct ParamAdapter : ParamBinder {
+ const Onnx::TensorType &type;
+ const Ort::MemoryInfo &memory;
+ ParamAdapter(const Onnx::TensorType &type_in, const Ort::MemoryInfo &memory_in)
+ : type(type_in), memory(memory_in) {}
+ void bind(const eval::Value &vespa, Ort::Value &onnx) override {
+ const auto &cells_ref = static_cast<const DenseTensorView &>(vespa).cellsRef();
+ auto cells = unconstify(cells_ref.typify<T>());
+ onnx = Ort::Value::CreateTensor<T>(memory, cells.begin(), cells.size(), type.dimensions.data(), type.dimensions.size());
+ }
+};
+
+struct CreateParamAdapter {
+ template <typename T> static ParamBinder::UP invoke(const Onnx::TensorType &type, const Ort::MemoryInfo &memory) {
+ return std::make_unique<ParamAdapter<T>>(type, memory);
+ }
+};
+
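+// converting input binding: copies cells one by one, casting SRC to DST, into a
+// pre-allocated onnx tensor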
+template <typename SRC, typename DST>
+struct ParamConverter : ParamBinder {
+ void bind(const eval::Value &vespa, Ort::Value &onnx) override {
+ auto cells = static_cast<const DenseTensorView &>(vespa).cellsRef().typify<SRC>();
+ size_t n = cells.size();
+ const SRC *src = cells.begin();
+ DST *dst = onnx.GetTensorMutableData<DST>();
+ for (size_t i = 0; i < n; ++i) {
+ dst[i] = DST(src[i]);
+ }
+ }
+};
+
+struct CreateParamConverter {
+ template <typename SRC, typename DST> static ParamBinder::UP invoke() {
+ return std::make_unique<ParamConverter<SRC,DST>>();
+ }
+};
+
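+// post-eval hook: copies the onnx result back into the vespa tensor, casting SRC to DST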
+template <typename SRC, typename DST>
+struct ResultConverter : EvalHook {
+ Ort::Value &onnx;
+ const eval::Value &vespa;
+ ResultConverter(Ort::Value &onnx_in, const eval::Value &vespa_in)
+ : onnx(onnx_in), vespa(vespa_in) {}
+ void invoke() override {
+ const auto &cells_ref = static_cast<const DenseTensorView &>(vespa).cellsRef();
+ auto cells = unconstify(cells_ref.typify<DST>());
+ size_t n = cells.size();
+ DST *dst = cells.begin();
+ const SRC *src = onnx.GetTensorMutableData<SRC>();
+ for (size_t i = 0; i < n; ++i) {
+ dst[i] = DST(src[i]);
+ }
+ }
+};
+
+struct CreateResultConverter {
+ template <typename SRC, typename DST> static EvalHook::UP invoke(Ort::Value &onnx, const eval::Value &vespa) {
+ return std::make_unique<ResultConverter<SRC,DST>>(onnx, vespa);
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+template <typename E> vespalib::string type_name(E enum_value) {
+ return typify_invoke<1,MyTypify,TypeToString>(enum_value);
+}
+
+template <typename E1, typename E2> bool is_same_type(E1 e1, E2 e2) {
+ return typify_invoke<2,MyTypify,IsSameType>(e1, e2);
}
-ValueType::CellType as_cell_type(Onnx::TensorInfo::ElementType type) {
- if (type == Onnx::TensorInfo::ElementType::FLOAT) {
- return ValueType::CellType::FLOAT;
+Ort::Value create_onnx_tensor(const Onnx::TensorType &type, OrtAllocator *alloc) {
+ return typify_invoke<1,MyTypify,CreateOnnxTensor>(type.elements, type.dimensions, alloc);
+}
+
+eval::Value::UP create_vespa_tensor_ref(const eval::ValueType &type_ref, Ort::Value &value) {
+ return typify_invoke<1,MyTypify,CreateVespaTensorRef>(type_ref.cell_type(), type_ref, value);
+}
+
+eval::Value::UP create_vespa_tensor(const eval::ValueType &type) {
+ return typify_invoke<1,MyTypify,CreateVespaTensor>(type.cell_type(), type);
+}
+
+ParamBinder::UP create_param_adapter(eval::ValueType::CellType ct, const Onnx::TensorType &type, const Ort::MemoryInfo &memory) {
+ return typify_invoke<1,MyTypify,CreateParamAdapter>(ct, type, memory);
+}
+
+ParamBinder::UP create_param_converter(eval::ValueType::CellType ct, Onnx::ElementType et) {
+ return typify_invoke<2,MyTypify,CreateParamConverter>(ct, et);
+}
+
+EvalHook::UP create_result_converter(Onnx::ElementType et, Ort::Value &onnx, const eval::Value &vespa) {
+ return typify_invoke<2,MyTypify,CreateResultConverter>(et, vespa.type().cell_type(), onnx, vespa);
+}
+
+//-----------------------------------------------------------------------------
+
+auto convert_optimize(Onnx::Optimize optimize) {
+ switch (optimize) {
+ case Onnx::Optimize::ENABLE: return ORT_ENABLE_ALL;
+ case Onnx::Optimize::DISABLE: return ORT_DISABLE_ALL;
}
- if (type == Onnx::TensorInfo::ElementType::DOUBLE) {
- return ValueType::CellType::DOUBLE;
+ abort();
+}
+
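+// narrow integers (8/16 bit) fit exactly in float cells; wider integers map to double cells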
+ValueType::CellType to_cell_type(Onnx::ElementType type) {
+ switch (type) {
+ case Onnx::ElementType::INT8: [[fallthrough]];
+ case Onnx::ElementType::INT16: [[fallthrough]];
+ case Onnx::ElementType::UINT8: [[fallthrough]];
+ case Onnx::ElementType::UINT16: [[fallthrough]];
+ case Onnx::ElementType::FLOAT: return ValueType::CellType::FLOAT;
+ case Onnx::ElementType::INT32: [[fallthrough]];
+ case Onnx::ElementType::INT64: [[fallthrough]];
+ case Onnx::ElementType::UINT32: [[fallthrough]];
+ case Onnx::ElementType::UINT64: [[fallthrough]];
+ case Onnx::ElementType::DOUBLE: return ValueType::CellType::DOUBLE;
}
abort();
}
-auto convert_optimize(Onnx::Optimize optimize) {
- if (optimize == Onnx::Optimize::ENABLE) {
- return ORT_ENABLE_ALL;
- } else {
- assert(optimize == Onnx::Optimize::DISABLE);
- return ORT_DISABLE_ALL;
+Onnx::ElementType make_element_type(ONNXTensorElementDataType element_type) {
+ switch (element_type) {
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: return Onnx::ElementType::INT8;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: return Onnx::ElementType::INT16;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: return Onnx::ElementType::INT32;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: return Onnx::ElementType::INT64;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: return Onnx::ElementType::UINT8;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: return Onnx::ElementType::UINT16;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: return Onnx::ElementType::UINT32;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: return Onnx::ElementType::UINT64;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: return Onnx::ElementType::FLOAT;
+ case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: return Onnx::ElementType::DOUBLE;
+ default:
+ throw Ort::Exception(fmt("[onnx wrapper] unsupported element type: %d", element_type), ORT_FAIL);
}
}
+//-----------------------------------------------------------------------------
+
class OnnxString {
private:
static Ort::AllocatorWithDefaultOptions _alloc;
@@ -98,22 +276,20 @@ std::vector<Onnx::DimSize> make_dimensions(const Ort::TensorTypeAndShapeInfo &te
return result;
}
-Onnx::TensorInfo::ElementType make_element_type(ONNXTensorElementDataType element_type) {
- if (element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
- return Onnx::TensorInfo::ElementType::FLOAT;
- } else if (element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
- return Onnx::TensorInfo::ElementType::DOUBLE;
- } else {
- return Onnx::TensorInfo::ElementType::UNKNOWN;
- }
-}
-
Onnx::TensorInfo make_tensor_info(const OnnxString &name, const Ort::TypeInfo &type_info) {
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
auto element_type = tensor_info.GetElementType();
return Onnx::TensorInfo{vespalib::string(name.get()), make_dimensions(tensor_info), make_element_type(element_type)};
}
+std::vector<int64_t> extract_sizes(const eval::ValueType &type) {
+ std::vector<int64_t> sizes;
+ for (const auto &dim: type.dimensions()) {
+ sizes.push_back(dim.size);
+ }
+ return sizes;
+}
+
}
vespalib::string
@@ -131,7 +307,7 @@ Onnx::DimSize::as_string() const
vespalib::string
Onnx::TensorInfo::type_as_string() const
{
- vespalib::string res = to_str(elements);
+ vespalib::string res = type_name(elements);
for (const auto &dim: dimensions) {
res += dim.as_string();
}
@@ -142,6 +318,8 @@ Onnx::TensorInfo::~TensorInfo() = default;
//-----------------------------------------------------------------------------
+Onnx::WireInfo::~WireInfo() = default;
+
Onnx::WirePlanner::~WirePlanner() = default;
bool
@@ -150,13 +328,6 @@ Onnx::WirePlanner::bind_input_type(const eval::ValueType &vespa_in, const Tensor
const auto &type = vespa_in;
const auto &name = onnx_in.name;
const auto &dimensions = onnx_in.dimensions;
- const auto &elements = onnx_in.elements;
- if ((elements == TensorInfo::ElementType::UNKNOWN) || dimensions.empty()) {
- return false;
- }
- if (type.cell_type() != as_cell_type(elements)) {
- return false;
- }
if (type.dimensions().size() != dimensions.size()) {
return false;
}
@@ -172,10 +343,9 @@ Onnx::WirePlanner::bind_input_type(const eval::ValueType &vespa_in, const Tensor
} else if (bound_size != type.dimensions()[i].size) {
return false;
}
- } else {
- _unknown_sizes[std::make_pair(name,i)] = type.dimensions()[i].size;
}
}
+ _input_types.emplace(name, type);
return true;
}
@@ -184,9 +354,6 @@ Onnx::WirePlanner::make_output_type(const TensorInfo &onnx_out) const
{
const auto &dimensions = onnx_out.dimensions;
const auto &elements = onnx_out.elements;
- if ((elements == TensorInfo::ElementType::UNKNOWN) || dimensions.empty()) {
- return ValueType::error_type();
- }
std::vector<ValueType::Dimension> dim_list;
for (const auto &dim: dimensions) {
size_t dim_size = dim.value;
@@ -201,7 +368,7 @@ Onnx::WirePlanner::make_output_type(const TensorInfo &onnx_out) const
}
dim_list.emplace_back(fmt("d%zu", dim_list.size()), dim_size);
}
- return ValueType::tensor_type(std::move(dim_list), as_cell_type(elements));
+ return ValueType::tensor_type(std::move(dim_list), to_cell_type(elements));
}
Onnx::WireInfo
@@ -209,26 +376,32 @@ Onnx::WirePlanner::get_wire_info(const Onnx &model) const
{
WireInfo info;
for (const auto &input: model.inputs()) {
- size_t input_idx = 0;
- std::vector<int64_t> sizes;
- for (const auto &dim: input.dimensions) {
- if (dim.is_known()) {
- sizes.push_back(dim.value);
- } else if (dim.is_symbolic()) {
- const auto &pos = _symbolic_sizes.find(dim.name);
- assert(pos != _symbolic_sizes.end());
- sizes.push_back(pos->second);
- } else {
- const auto &pos = _unknown_sizes.find(std::make_pair(input.name, input_idx));
- assert(pos != _unknown_sizes.end());
- sizes.push_back(pos->second);
- }
- ++input_idx;
+ const auto &pos = _input_types.find(input.name);
+ assert(pos != _input_types.end());
+ auto vespa_type = pos->second;
+ info.onnx_inputs.emplace_back(input.elements, extract_sizes(vespa_type));
+ info.vespa_inputs.push_back(std::move(vespa_type));
+ if (!is_same_type(info.vespa_inputs.back().cell_type(),
+ info.onnx_inputs.back().elements))
+ {
+ LOG(warning, "input '%s' with element type '%s' is bound to vespa value with cell type '%s'; "
+ "adding explicit conversion step (this conversion might be lossy)",
+ input.name.c_str(), type_name(info.onnx_inputs.back().elements).c_str(),
+ type_name(info.vespa_inputs.back().cell_type()).c_str());
}
- info.input_sizes.push_back(sizes);
}
for (const auto &output: model.outputs()) {
- info.output_types.push_back(make_output_type(output));
+ auto vespa_type = make_output_type(output);
+ info.onnx_outputs.emplace_back(output.elements, extract_sizes(vespa_type));
+ info.vespa_outputs.push_back(std::move(vespa_type));
+ if (!is_same_type(info.vespa_outputs.back().cell_type(),
+ info.onnx_outputs.back().elements))
+ {
+ LOG(warning, "output '%s' with element type '%s' is bound to vespa value with cell type '%s'; "
+ "adding explicit conversion step (this conversion might be lossy)",
+ output.name.c_str(), type_name(info.onnx_outputs.back().elements).c_str(),
+ type_name(info.vespa_outputs.back().cell_type()).c_str());
+ }
}
return info;
}
@@ -243,36 +416,42 @@ Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info)
_cpu_memory(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)),
_param_values(),
_result_values(),
- _result_views()
+ _results(),
+ _param_binders(),
+ _eval_hooks()
{
- assert(_wire_info.input_sizes.size() == _model.inputs().size());
- assert(_wire_info.output_types.size() == _model.outputs().size());
- for (const auto &input: _wire_info.input_sizes) {
- (void) input;
- _param_values.push_back(Ort::Value(nullptr));
- }
- std::vector<int64_t> dim_sizes;
- size_t num_cells;
- dim_sizes.reserve(16);
- // NB: output type must be reference inside vector since the view does not copy it
- for (const auto &output: _wire_info.output_types) {
- num_cells = 1;
- dim_sizes.clear();
- for (const auto &dim: output.dimensions()) {
- dim_sizes.push_back(dim.size);
- num_cells *= dim.size;
+ assert(_wire_info.vespa_inputs.size() == _model.inputs().size());
+ assert(_wire_info.onnx_inputs.size() == _model.inputs().size());
+ assert(_wire_info.onnx_outputs.size() == _model.outputs().size());
+ assert(_wire_info.vespa_outputs.size() == _model.outputs().size());
+ _param_values.reserve(_model.inputs().size());
+ _result_values.reserve(_model.outputs().size());
+ _results.reserve(_model.outputs().size());
+ auto result_guard = _result_values.begin();
+ for (size_t i = 0; i < _model.inputs().size(); ++i) {
+ const auto &vespa = _wire_info.vespa_inputs[i];
+ const auto &onnx = _wire_info.onnx_inputs[i];
+ if (is_same_type(vespa.cell_type(), onnx.elements)) {
+ _param_values.push_back(Ort::Value(nullptr));
+ _param_binders.push_back(create_param_adapter(vespa.cell_type(), onnx, _cpu_memory));
+ } else {
+ _param_values.push_back(create_onnx_tensor(onnx, _alloc));
+ _param_binders.push_back(create_param_converter(vespa.cell_type(), onnx.elements));
}
- if (output.cell_type() == ValueType::CellType::FLOAT) {
- _result_values.push_back(Ort::Value::CreateTensor<float>(_alloc, dim_sizes.data(), dim_sizes.size()));
- ConstArrayRef<float> cells(_result_values.back().GetTensorMutableData<float>(), num_cells);
- _result_views.emplace_back(output, TypedCells(cells));
+ }
+ for (size_t i = 0; i < _model.outputs().size(); ++i) {
+ const auto &vespa = _wire_info.vespa_outputs[i];
+ const auto &onnx = _wire_info.onnx_outputs[i];
+ _result_values.push_back(create_onnx_tensor(onnx, _alloc));
+ if (is_same_type(vespa.cell_type(), onnx.elements)) {
+ _results.push_back(create_vespa_tensor_ref(vespa, _result_values.back()));
} else {
- assert(output.cell_type() == ValueType::CellType::DOUBLE);
- _result_values.push_back(Ort::Value::CreateTensor<double>(_alloc, dim_sizes.data(), dim_sizes.size()));
- ConstArrayRef<double> cells(_result_values.back().GetTensorMutableData<double>(), num_cells);
- _result_views.emplace_back(output, TypedCells(cells));
+ _results.push_back(create_vespa_tensor(vespa));
+ _eval_hooks.push_back(create_result_converter(onnx.elements, _result_values.back(), *_results.back().get()));
}
}
+ // make sure references to Ort::Value inside _result_values are safe
+ assert(result_guard == _result_values.begin());
}
Onnx::EvalContext::~EvalContext() = default;
@@ -280,36 +459,26 @@ Onnx::EvalContext::~EvalContext() = default;
void
Onnx::EvalContext::bind_param(size_t i, const eval::Value &param)
{
- // NB: dense tensors are always (sub)classes of DenseTensorView
- const auto &cells_ref = static_cast<const DenseTensorView &>(param).cellsRef();
- const auto &input_sizes = _wire_info.input_sizes;
- if (cells_ref.type == ValueType::CellType::FLOAT) {
- // NB: create requires non-const input
- auto cells = unconstify(cells_ref.typify<float>());
- _param_values[i] = Ort::Value::CreateTensor<float>(_cpu_memory, cells.begin(), cells.size(), input_sizes[i].data(), input_sizes[i].size());
- } else {
- assert(cells_ref.type == ValueType::CellType::DOUBLE);
- // NB: create requires non-const input
- auto cells = unconstify(cells_ref.typify<double>());
- _param_values[i] = Ort::Value::CreateTensor<double>(_cpu_memory, cells.begin(), cells.size(), input_sizes[i].data(), input_sizes[i].size());
- }
+ _param_binders[i]->bind(param, _param_values[i]);
}
void
Onnx::EvalContext::eval()
{
- // NB: Run requires non-const session
Ort::Session &session = const_cast<Ort::Session&>(_model._session);
Ort::RunOptions run_opts(nullptr);
session.Run(run_opts,
_model._input_name_refs.data(), _param_values.data(), _param_values.size(),
_model._output_name_refs.data(), _result_values.data(), _result_values.size());
+ for (const auto &hook: _eval_hooks) {
+ hook->invoke();
+ }
}
const eval::Value &
Onnx::EvalContext::get_result(size_t i) const
{
- return _result_views[i];
+ return *_results[i];
}
//-----------------------------------------------------------------------------
@@ -334,10 +503,16 @@ Onnx::extract_meta_data()
size_t num_inputs = _session.GetInputCount();
for (size_t i = 0; i < num_inputs; ++i) {
_inputs.push_back(make_tensor_info(OnnxString::get_input_name(_session, i), _session.GetInputTypeInfo(i)));
+ if (_inputs.back().dimensions.empty()) {
+ throw Ort::Exception(fmt("[onnx wrapper] input '%s' has unspecified type, this is not supported", _inputs.back().name.c_str()), ORT_FAIL);
+ }
}
size_t num_outputs = _session.GetOutputCount();
for (size_t i = 0; i < num_outputs; ++i) {
_outputs.push_back(make_tensor_info(OnnxString::get_output_name(_session, i), _session.GetOutputTypeInfo(i)));
+ if (_outputs.back().dimensions.empty()) {
+ throw Ort::Exception(fmt("[onnx wrapper] output '%s' has unspecified type, this is not supported", _outputs.back().name.c_str()), ORT_FAIL);
+ }
}
for (const auto &input: _inputs) {
_input_name_refs.push_back(input.name.c_str());
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
index 857098d0e3e..4d2ef6ba50d 100644
--- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
+++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
@@ -49,9 +49,11 @@ public:
vespalib::string as_string() const;
};
+ // supported onnx element types
+ enum class ElementType { INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64, FLOAT, DOUBLE };
+
// information about a single input or output tensor
struct TensorInfo {
- enum class ElementType { FLOAT, DOUBLE, UNKNOWN };
vespalib::string name;
std::vector<DimSize> dimensions;
ElementType elements;
@@ -59,20 +61,30 @@ public:
~TensorInfo();
};
+ // concrete tensor type with known dimension sizes
+ struct TensorType {
+ ElementType elements;
+ std::vector<int64_t> dimensions;
+ TensorType(ElementType elements_in, std::vector<int64_t> dimensions_in)
+ : elements(elements_in), dimensions(std::move(dimensions_in)) {}
+ };
+
// how the model should be wired with inputs/outputs
struct WireInfo {
- std::vector<std::vector<int64_t>> input_sizes;
- std::vector<eval::ValueType> output_types;
- WireInfo() : input_sizes(), output_types() {}
+ std::vector<eval::ValueType> vespa_inputs;
+ std::vector<Onnx::TensorType> onnx_inputs;
+ std::vector<Onnx::TensorType> onnx_outputs;
+ std::vector<eval::ValueType> vespa_outputs;
+ ~WireInfo();
};
// planning how we should wire the model based on input types
class WirePlanner {
private:
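+ // full vespa type bound to each input; concrete wire sizes are derived from these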
+ std::map<vespalib::string,eval::ValueType> _input_types;
std::map<vespalib::string,size_t> _symbolic_sizes;
- std::map<std::pair<vespalib::string,size_t>,size_t> _unknown_sizes;
public:
- WirePlanner() : _symbolic_sizes(), _unknown_sizes() {}
+ WirePlanner() : _input_types(), _symbolic_sizes() {}
~WirePlanner();
bool bind_input_type(const eval::ValueType &vespa_in, const TensorInfo &onnx_in);
eval::ValueType make_output_type(const TensorInfo &onnx_out) const;
@@ -83,15 +95,29 @@ public:
// all parameter values are expected to be bound per evaluation
// output values are pre-allocated and will not change
class EvalContext {
+ public:
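+ // binds a vespa parameter value to its onnx input value before evaluation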
+ struct ParamBinder {
+ using UP = std::unique_ptr<ParamBinder>;
+ virtual void bind(const eval::Value &vespa, Ort::Value &onnx) = 0;
+ virtual ~ParamBinder() {}
+ };
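+ // hook invoked after Session::Run, used to convert results when types differ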
+ struct EvalHook {
+ using UP = std::unique_ptr<EvalHook>;
+ virtual void invoke() = 0;
+ virtual ~EvalHook() {}
+ };
+
private:
static Ort::AllocatorWithDefaultOptions _alloc;
- const Onnx &_model;
- const WireInfo &_wire_info;
- Ort::MemoryInfo _cpu_memory;
- std::vector<Ort::Value> _param_values;
- std::vector<Ort::Value> _result_values;
- std::vector<DenseTensorView> _result_views;
+ const Onnx &_model;
+ const WireInfo &_wire_info;
+ Ort::MemoryInfo _cpu_memory;
+ std::vector<Ort::Value> _param_values;
+ std::vector<Ort::Value> _result_values;
+ std::vector<eval::Value::UP> _results;
+ std::vector<ParamBinder::UP> _param_binders;
+ std::vector<EvalHook::UP> _eval_hooks;
public:
EvalContext(const Onnx &model, const WireInfo &wire_info);
diff --git a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp
index 21f66735748..1cc8d0280f6 100644
--- a/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp
+++ b/searchcore/src/tests/proton/verify_ranksetup/verify_ranksetup_test.cpp
@@ -344,7 +344,7 @@ TEST_F("require that onnx model can be verified", OnnxSetup()) {
}
TEST_F("require that input type mismatch makes onnx model fail verification", OnnxSetup()) {
- f.rank_expr("query_tensor", "tensor<double>(a[1],b[4]):[[1,2,3,4]]"); // <- double vs float
+ f.rank_expr("query_tensor", "tensor<float>(a[1],b[3]):[[1,2,3]]"); // <- 3 vs 4
f.rank_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
f.rank_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]");
f.verify_invalid({"onnxModel(simple)"});
diff --git a/vespalib/src/vespa/vespalib/util/classname.h b/vespalib/src/vespa/vespalib/util/classname.h
index aea371b995d..23dc3659fdd 100644
--- a/vespalib/src/vespa/vespalib/util/classname.h
+++ b/vespalib/src/vespa/vespalib/util/classname.h
@@ -8,9 +8,13 @@ namespace vespalib {
string demangle(const char * native);
template <typename T>
-string
-getClassName(const T & obj) {
+string getClassName(const T & obj) {
return demangle(typeid(obj).name());
}
+template <typename T>
+string getClassName() {
+ return demangle(typeid(T).name());
+}
+
}