aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp')
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp76
1 files changed, 56 insertions, 20 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index db2415e9969..23c41167266 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -11,6 +11,7 @@ using namespace vespalib::tensor;
using vespalib::make_string_short::fmt;
using TensorInfo = Onnx::TensorInfo;
+using ElementType = Onnx::ElementType;
using DZ = Onnx::DimSize;
std::string get_source_dir() {
@@ -20,6 +21,7 @@ std::string get_source_dir() {
std::string source_dir = get_source_dir();
std::string simple_model = source_dir + "/simple.onnx";
std::string dynamic_model = source_dir + "/dynamic.onnx";
+std::string int_types_model = source_dir + "/int_types.onnx";
void dump_info(const char *ctx, const std::vector<TensorInfo> &info) {
fprintf(stderr, "%s:\n", ctx);
@@ -28,24 +30,12 @@ void dump_info(const char *ctx, const std::vector<TensorInfo> &info) {
}
}
-TEST(WirePlannerTest, element_types_must_match) {
- Onnx::WirePlanner planner;
- ValueType type1 = ValueType::from_spec("tensor<float>(a[5])");
- ValueType type2 = ValueType::from_spec("tensor<double>(a[5])");
- TensorInfo info1 = TensorInfo{"info", {DZ(5)}, TensorInfo::ElementType::FLOAT};
- TensorInfo info2 = TensorInfo{"info", {DZ(5)}, TensorInfo::ElementType::DOUBLE};
- EXPECT_TRUE(planner.bind_input_type(type1, info1));
- EXPECT_FALSE(planner.bind_input_type(type2, info1));
- EXPECT_FALSE(planner.bind_input_type(type1, info2));
- EXPECT_TRUE(planner.bind_input_type(type2, info2));
-}
-
TEST(WirePlannerTest, known_dimension_sizes_must_match) {
Onnx::WirePlanner planner;
ValueType type1 = ValueType::from_spec("tensor<float>(a[5],b[10])");
ValueType type2 = ValueType::from_spec("tensor<float>(a[10],b[5])");
ValueType type3 = ValueType::from_spec("tensor<float>(a[5],b[5])");
- TensorInfo info = TensorInfo{"info", {DZ(5),DZ(5)}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info = TensorInfo{"info", {DZ(5),DZ(5)}, ElementType::FLOAT};
EXPECT_FALSE(planner.bind_input_type(type1, info));
EXPECT_FALSE(planner.bind_input_type(type2, info));
EXPECT_TRUE(planner.bind_input_type(type3, info));
@@ -55,7 +45,7 @@ TEST(WirePlannerTest, symbolic_dimension_sizes_must_match) {
Onnx::WirePlanner planner;
ValueType type1 = ValueType::from_spec("tensor<float>(a[5])");
ValueType type2 = ValueType::from_spec("tensor<float>(a[10])");
- TensorInfo info = TensorInfo{"info", {DZ("dim")}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info = TensorInfo{"info", {DZ("dim")}, ElementType::FLOAT};
EXPECT_TRUE(planner.bind_input_type(type1, info)); // binds 'dim' to 5
EXPECT_FALSE(planner.bind_input_type(type2, info));
EXPECT_TRUE(planner.bind_input_type(type1, info));
@@ -65,7 +55,7 @@ TEST(WirePlannerTest, unknown_dimension_sizes_match_anything) {
Onnx::WirePlanner planner;
ValueType type1 = ValueType::from_spec("tensor<float>(a[5])");
ValueType type2 = ValueType::from_spec("tensor<float>(a[10])");
- TensorInfo info = TensorInfo{"info", {DZ()}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info = TensorInfo{"info", {DZ()}, ElementType::FLOAT};
EXPECT_TRUE(planner.bind_input_type(type1, info));
EXPECT_TRUE(planner.bind_input_type(type2, info));
}
@@ -73,9 +63,9 @@ TEST(WirePlannerTest, unknown_dimension_sizes_match_anything) {
TEST(WirePlannerTest, all_output_dimensions_must_be_bound) {
Onnx::WirePlanner planner;
ValueType type = ValueType::from_spec("tensor<float>(a[5],b[10])");
- TensorInfo info1 = TensorInfo{"info", {DZ()}, TensorInfo::ElementType::FLOAT};
- TensorInfo info2 = TensorInfo{"info", {DZ("dim")}, TensorInfo::ElementType::FLOAT};
- TensorInfo info3 = TensorInfo{"info", {DZ("dim"),DZ()}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info1 = TensorInfo{"info", {DZ()}, ElementType::FLOAT};
+ TensorInfo info2 = TensorInfo{"info", {DZ("dim")}, ElementType::FLOAT};
+ TensorInfo info3 = TensorInfo{"info", {DZ("dim"),DZ()}, ElementType::FLOAT};
EXPECT_TRUE(planner.make_output_type(info1).is_error());
EXPECT_TRUE(planner.make_output_type(info2).is_error());
EXPECT_TRUE(planner.make_output_type(info3).is_error());
@@ -90,7 +80,7 @@ TEST(WirePlannerTest, dimensions_resolve_left_to_right) {
ValueType type1 = ValueType::from_spec("tensor<float>(a[5],b[10])");
ValueType type2 = ValueType::from_spec("tensor<float>(a[10],b[10])");
ValueType type3 = ValueType::from_spec("tensor<float>(a[5],b[5])");
- TensorInfo info = TensorInfo{"info", {DZ("dim"),DZ("dim")}, TensorInfo::ElementType::FLOAT};
+ TensorInfo info = TensorInfo{"info", {DZ("dim"),DZ("dim")}, ElementType::FLOAT};
EXPECT_FALSE(planner.bind_input_type(type1, info)); // binds 'dim' to 5, then fails (5 != 10)
EXPECT_FALSE(planner.bind_input_type(type2, info));
EXPECT_TRUE(planner.bind_input_type(type3, info));
@@ -180,7 +170,7 @@ TEST(OnnxTest, simple_onnx_model_can_be_evaluated)
DenseTensorView new_bias(bias_type, TypedCells(new_bias_values));
ctx.bind_param(2, new_bias);
ctx.eval();
- EXPECT_EQ(static_cast<const DenseTensorView&>(output).cellsRef().get(0), 80.0);
+ EXPECT_EQ(static_cast<const DenseTensorView&>(output).cellsRef().get(0), 80.0);
//-------------------------------------------------------------------------
}
@@ -230,4 +220,50 @@ TEST(OnnxTest, dynamic_onnx_model_can_be_evaluated)
//-------------------------------------------------------------------------
}
+TEST(OnnxTest, int_types_onnx_model_can_be_evaluated)
+{
+ Onnx model(int_types_model, Onnx::Optimize::ENABLE);
+ Onnx::WirePlanner planner;
+
+ ValueType query_type = ValueType::from_spec("tensor<float>(a[1],b[4])");
+ std::vector<float> query_values({1.0, 2.0, 3.0, 4.0});
+ DenseTensorView query(query_type, TypedCells(query_values));
+ EXPECT_TRUE(planner.bind_input_type(query_type, model.inputs()[0]));
+
+ ValueType attribute_type = ValueType::from_spec("tensor<double>(a[4],b[1])");
+ std::vector<double> attribute_values({5.0, 6.0, 7.0, 8.0});
+ DenseTensorView attribute(attribute_type, TypedCells(attribute_values));
+ EXPECT_TRUE(planner.bind_input_type(attribute_type, model.inputs()[1]));
+
+ ValueType bias_type = ValueType::from_spec("tensor<double>(a[1],b[1])");
+ std::vector<double> bias_values({9.0});
+ DenseTensorView bias(bias_type, TypedCells(bias_values));
+ EXPECT_TRUE(planner.bind_input_type(bias_type, model.inputs()[2]));
+
+ EXPECT_EQ(planner.make_output_type(model.outputs()[0]),
+ ValueType::from_spec("tensor<double>(d0[1],d1[1])"));
+
+ Onnx::WireInfo wire_info = planner.get_wire_info(model);
+ Onnx::EvalContext ctx(model, wire_info);
+
+ const Value &output = ctx.get_result(0);
+ EXPECT_EQ(output.type(), ValueType::from_spec("tensor<double>(d0[1],d1[1])"));
+ //-------------------------------------------------------------------------
+ ctx.bind_param(0, query);
+ ctx.bind_param(1, attribute);
+ ctx.bind_param(2, bias);
+ ctx.eval();
+ auto cells = static_cast<const DenseTensorView&>(output).cellsRef();
+ EXPECT_EQ(cells.type, ValueType::CellType::DOUBLE);
+ EXPECT_EQ(cells.size, 1);
+ EXPECT_EQ(cells.get(0), 79.0);
+ //-------------------------------------------------------------------------
+ std::vector<double> new_bias_values({10.0});
+ DenseTensorView new_bias(bias_type, TypedCells(new_bias_values));
+ ctx.bind_param(2, new_bias);
+ ctx.eval();
+ EXPECT_EQ(static_cast<const DenseTensorView&>(output).cellsRef().get(0), 80.0);
+ //-------------------------------------------------------------------------
+}
+
GTEST_MAIN_RUN_ALL_TESTS()