path: root/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
Diffstat (limited to 'eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp')
-rw-r--r--  eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp | 95
1 file changed, 95 insertions(+), 0 deletions(-)
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index b474d2458b9..4fd527a7dfb 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -1,12 +1,16 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/eval/eval/tensor_spec.h>
+#include <vespa/eval/eval/int8float.h>
#include <vespa/eval/onnx/onnx_wrapper.h>
+#include <vespa/vespalib/util/bfloat16.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/gtest/gtest.h>
using namespace vespalib::eval;
+using vespalib::BFloat16;
+
using vespalib::make_string_short::fmt;
using TensorInfo = Onnx::TensorInfo;
using ElementType = Onnx::ElementType;
@@ -21,6 +25,7 @@ std::string simple_model = source_dir + "/simple.onnx";
std::string dynamic_model = source_dir + "/dynamic.onnx";
std::string int_types_model = source_dir + "/int_types.onnx";
std::string guess_batch_model = source_dir + "/guess_batch.onnx";
+std::string unstable_types_model = source_dir + "/unstable_types.onnx";
void dump_info(const char *ctx, const std::vector<TensorInfo> &info) {
fprintf(stderr, "%s:\n", ctx);
@@ -306,4 +311,94 @@ TEST(OnnxTest, we_guess_batch_dimension_size_when_inference_fails) {
//-------------------------------------------------------------------------
}
+TEST(OnnxTest, zero_copy_unstable_types) {
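+ // The unstable_types model takes one int8 and one bfloat16 tensor and
+ // casts each input to the other cell type. Here the inputs already use
+ // those cell types, so the cells can be wired through without conversion.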
+ Onnx model(unstable_types_model, Onnx::Optimize::ENABLE);
+ ASSERT_EQ(model.inputs().size(), 2);
+ ASSERT_EQ(model.outputs().size(), 2);
+
+ ValueType in8_type = ValueType::from_spec("tensor<int8>(a[3])");
+ std::vector<Int8Float> in8_values({1.0, 2.0, 3.0});
+ DenseValueView in8(in8_type, TypedCells(in8_values));
+
+ ValueType in16_type = ValueType::from_spec("tensor<bfloat16>(a[3])");
+ std::vector<BFloat16> in16_values({4.0, 5.0, 6.0});
+ DenseValueView in16(in16_type, TypedCells(in16_values));
+
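+ // Plan the wiring: bind the vespa type of each input and check that the
+ // planner derives the expected int8/bfloat16 output types from the model.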
+ Onnx::WirePlanner planner;
+ EXPECT_TRUE(planner.bind_input_type(in8_type, model.inputs()[0]));
+ EXPECT_TRUE(planner.bind_input_type(in16_type, model.inputs()[1]));
+ EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), "tensor<int8>(d0[3])");
+ EXPECT_EQ(planner.make_output_type(model.outputs()[1]).to_spec(), "tensor<bfloat16>(d0[3])");
+
+ auto wire_info = planner.get_wire_info(model);
+ Onnx::EvalContext ctx(model, wire_info);
+
+ const Value &out8 = ctx.get_result(0);
+ const Value &out16 = ctx.get_result(1);
+ EXPECT_EQ(out8.type().to_spec(), "tensor<int8>(d0[3])");
+ EXPECT_EQ(out16.type().to_spec(), "tensor<bfloat16>(d0[3])");
+ //-------------------------------------------------------------------------
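+ // Bind the inputs, evaluate, and inspect the raw output cells: the int8
+ // output carries the bfloat16 input's values and vice versa.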
+ ctx.bind_param(0, in8);
+ ctx.bind_param(1, in16);
+ ctx.eval();
+ auto cells8 = out8.cells();
+ auto cells16 = out16.cells();
+ ASSERT_EQ(cells8.type, CellType::INT8);
+ ASSERT_EQ(cells16.type, CellType::BFLOAT16);
+ ASSERT_EQ(cells8.size, 3);
+ ASSERT_EQ(cells16.size, 3);
+ EXPECT_EQ(cells8.typify<Int8Float>()[0], 4.0);
+ EXPECT_EQ(cells8.typify<Int8Float>()[1], 5.0);
+ EXPECT_EQ(cells8.typify<Int8Float>()[2], 6.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[0], 1.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[1], 2.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[2], 3.0);
+ //-------------------------------------------------------------------------
+}
+
+TEST(OnnxTest, converted_unstable_types) {
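+ // Same model as above, but both inputs are plain float tensors, so the
+ // values must be converted to int8/bfloat16 on their way into the model.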
+ Onnx model(unstable_types_model, Onnx::Optimize::ENABLE);
+ ASSERT_EQ(model.inputs().size(), 2);
+ ASSERT_EQ(model.outputs().size(), 2);
+
+ ValueType in8_type = ValueType::from_spec("tensor<float>(a[3])");
+ std::vector<float> in8_values({1.0, 2.0, 3.0});
+ DenseValueView in8(in8_type, TypedCells(in8_values));
+
+ ValueType in16_type = ValueType::from_spec("tensor<float>(a[3])");
+ std::vector<float> in16_values({4.0, 5.0, 6.0});
+ DenseValueView in16(in16_type, TypedCells(in16_values));
+
+ Onnx::WirePlanner planner;
+ EXPECT_TRUE(planner.bind_input_type(in8_type, model.inputs()[0]));
+ EXPECT_TRUE(planner.bind_input_type(in16_type, model.inputs()[1]));
+ EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), "tensor<int8>(d0[3])");
+ EXPECT_EQ(planner.make_output_type(model.outputs()[1]).to_spec(), "tensor<bfloat16>(d0[3])");
+
+ auto wire_info = planner.get_wire_info(model);
+ Onnx::EvalContext ctx(model, wire_info);
+
+ const Value &out8 = ctx.get_result(0);
+ const Value &out16 = ctx.get_result(1);
+ EXPECT_EQ(out8.type().to_spec(), "tensor<int8>(d0[3])");
+ EXPECT_EQ(out16.type().to_spec(), "tensor<bfloat16>(d0[3])");
+ //-------------------------------------------------------------------------
+ ctx.bind_param(0, in8);
+ ctx.bind_param(1, in16);
+ ctx.eval();
+ auto cells8 = out8.cells();
+ auto cells16 = out16.cells();
+ ASSERT_EQ(cells8.type, CellType::INT8);
+ ASSERT_EQ(cells16.type, CellType::BFLOAT16);
+ ASSERT_EQ(cells8.size, 3);
+ ASSERT_EQ(cells16.size, 3);
+ EXPECT_EQ(cells8.typify<Int8Float>()[0], 4.0);
+ EXPECT_EQ(cells8.typify<Int8Float>()[1], 5.0);
+ EXPECT_EQ(cells8.typify<Int8Float>()[2], 6.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[0], 1.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[1], 2.0);
+ EXPECT_EQ(cells16.typify<BFloat16>()[2], 3.0);
+ //-------------------------------------------------------------------------
+}
+
GTEST_MAIN_RUN_ALL_TESTS()