diff options
Diffstat (limited to 'eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp')
-rw-r--r-- | eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp index b474d2458b9..4fd527a7dfb 100644 --- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp +++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp @@ -1,12 +1,16 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/eval/eval/tensor_spec.h> +#include <vespa/eval/eval/int8float.h> #include <vespa/eval/onnx/onnx_wrapper.h> +#include <vespa/vespalib/util/bfloat16.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/gtest/gtest.h> using namespace vespalib::eval; +using vespalib::BFloat16; + using vespalib::make_string_short::fmt; using TensorInfo = Onnx::TensorInfo; using ElementType = Onnx::ElementType; @@ -21,6 +25,7 @@ std::string simple_model = source_dir + "/simple.onnx"; std::string dynamic_model = source_dir + "/dynamic.onnx"; std::string int_types_model = source_dir + "/int_types.onnx"; std::string guess_batch_model = source_dir + "/guess_batch.onnx"; +std::string unstable_types_model = source_dir + "/unstable_types.onnx"; void dump_info(const char *ctx, const std::vector<TensorInfo> &info) { fprintf(stderr, "%s:\n", ctx); @@ -306,4 +311,94 @@ TEST(OnnxTest, we_guess_batch_dimension_size_when_inference_fails) { //------------------------------------------------------------------------- } +TEST(OnnxTest, zero_copy_unstable_types) { + Onnx model(unstable_types_model, Onnx::Optimize::ENABLE); + ASSERT_EQ(model.inputs().size(), 2); + ASSERT_EQ(model.outputs().size(), 2); + + ValueType in8_type = ValueType::from_spec("tensor<int8>(a[3])"); + std::vector<Int8Float> in8_values({1.0, 2.0, 3.0}); + DenseValueView in8(in8_type, TypedCells(in8_values)); + + ValueType in16_type = ValueType::from_spec("tensor<bfloat16>(a[3])"); + std::vector<BFloat16> in16_values({4.0, 5.0, 6.0}); + DenseValueView in16(in16_type, TypedCells(in16_values)); + + Onnx::WirePlanner planner; + EXPECT_TRUE(planner.bind_input_type(in8_type, model.inputs()[0])); + EXPECT_TRUE(planner.bind_input_type(in16_type, model.inputs()[1])); + EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), "tensor<int8>(d0[3])"); + EXPECT_EQ(planner.make_output_type(model.outputs()[1]).to_spec(), "tensor<bfloat16>(d0[3])"); + + auto wire_info = planner.get_wire_info(model); + Onnx::EvalContext ctx(model, wire_info); + + const Value &out8 = ctx.get_result(0); + const Value &out16 = ctx.get_result(1); + EXPECT_EQ(out8.type().to_spec(), "tensor<int8>(d0[3])"); + EXPECT_EQ(out16.type().to_spec(), "tensor<bfloat16>(d0[3])"); + //------------------------------------------------------------------------- + ctx.bind_param(0, in8); + ctx.bind_param(1, in16); + ctx.eval(); + auto cells8 = out8.cells(); + auto cells16 = out16.cells(); + ASSERT_EQ(cells8.type, CellType::INT8); + ASSERT_EQ(cells16.type, CellType::BFLOAT16); + ASSERT_EQ(cells8.size, 3); + ASSERT_EQ(cells16.size, 3); + EXPECT_EQ(cells8.typify<Int8Float>()[0], 4.0); + EXPECT_EQ(cells8.typify<Int8Float>()[1], 5.0); + EXPECT_EQ(cells8.typify<Int8Float>()[2], 6.0); + EXPECT_EQ(cells16.typify<BFloat16>()[0], 1.0); + EXPECT_EQ(cells16.typify<BFloat16>()[1], 2.0); + EXPECT_EQ(cells16.typify<BFloat16>()[2], 3.0); + //------------------------------------------------------------------------- +} + +TEST(OnnxTest, converted_unstable_types) { + Onnx model(unstable_types_model, Onnx::Optimize::ENABLE); + ASSERT_EQ(model.inputs().size(), 2); + ASSERT_EQ(model.outputs().size(), 2); + + ValueType in8_type = ValueType::from_spec("tensor<float>(a[3])"); + std::vector<float> in8_values({1.0, 2.0, 3.0}); + DenseValueView in8(in8_type, TypedCells(in8_values)); + + ValueType in16_type = ValueType::from_spec("tensor<float>(a[3])"); + std::vector<float> in16_values({4.0, 5.0, 6.0}); + DenseValueView in16(in16_type, TypedCells(in16_values)); + + Onnx::WirePlanner planner; + EXPECT_TRUE(planner.bind_input_type(in8_type, model.inputs()[0])); + EXPECT_TRUE(planner.bind_input_type(in16_type, model.inputs()[1])); + EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), "tensor<int8>(d0[3])"); + EXPECT_EQ(planner.make_output_type(model.outputs()[1]).to_spec(), "tensor<bfloat16>(d0[3])"); + + auto wire_info = planner.get_wire_info(model); + Onnx::EvalContext ctx(model, wire_info); + + const Value &out8 = ctx.get_result(0); + const Value &out16 = ctx.get_result(1); + EXPECT_EQ(out8.type().to_spec(), "tensor<int8>(d0[3])"); + EXPECT_EQ(out16.type().to_spec(), "tensor<bfloat16>(d0[3])"); + //------------------------------------------------------------------------- + ctx.bind_param(0, in8); + ctx.bind_param(1, in16); + ctx.eval(); + auto cells8 = out8.cells(); + auto cells16 = out16.cells(); + ASSERT_EQ(cells8.type, CellType::INT8); + ASSERT_EQ(cells16.type, CellType::BFLOAT16); + ASSERT_EQ(cells8.size, 3); + ASSERT_EQ(cells16.size, 3); + EXPECT_EQ(cells8.typify<Int8Float>()[0], 4.0); + EXPECT_EQ(cells8.typify<Int8Float>()[1], 5.0); + EXPECT_EQ(cells8.typify<Int8Float>()[2], 6.0); + EXPECT_EQ(cells16.typify<BFloat16>()[0], 1.0); + EXPECT_EQ(cells16.typify<BFloat16>()[1], 2.0); + EXPECT_EQ(cells16.typify<BFloat16>()[2], 3.0); + //------------------------------------------------------------------------- +} + GTEST_MAIN_RUN_ALL_TESTS() |