diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-09-22 08:28:02 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-09-22 10:15:54 +0000 |
commit | 9a66f21d375e4fa07a96069644615581e55129d5 (patch) | |
tree | 06891b9c7f91fd0cf673cd1c540cb0dcafd5e950 /searchlib/src/tests/features | |
parent | 804e8057c2eca61ec9bc8985430613e0731922a2 (diff) |
handle onnx model config for inputs and outputs
Diffstat (limited to 'searchlib/src/tests/features')
-rw-r--r-- | searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp | 26 |
1 files changed, 21 insertions, 5 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp index b49d9c365de..6a1e4ef9fa1 100644 --- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -58,8 +58,8 @@ struct OnnxFeatureTest : ::testing::Test { vespalib::string expr_name = feature_name + ".rankingScript"; indexEnv.getProperties().add(expr_name, expr); } - void add_onnx(const vespalib::string &name, const vespalib::string &file) { - indexEnv.addOnnxModel(name, file); + void add_onnx(const OnnxModel &model) { + indexEnv.addOnnxModel(model); } void compile(const vespalib::string &seed) { resolver->addSeed(seed); @@ -89,7 +89,7 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) { add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]"); add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); add_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]"); - add_onnx("simple", simple_model); + add_onnx(OnnxModel("simple", simple_model)); compile(onnx_feature("simple")); EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); EXPECT_EQ(get("onnxModel(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); @@ -101,7 +101,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) { add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]"); add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); add_expr("bias_tensor", "tensor<float>(a[1],b[2]):[[4,5]]"); - add_onnx("dynamic", dynamic_model); + add_onnx(OnnxModel("dynamic", dynamic_model)); compile(onnx_feature("dynamic")); EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); @@ -112,7 +112,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) { TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) { add_expr("input_0", "tensor<float>(a[2]):[10,20]"); add_expr("input_1", "tensor<float>(a[2]):[5,10]"); - add_onnx("strange_names", strange_names_model); + add_onnx(OnnxModel("strange_names", strange_names_model)); compile(onnx_feature("strange_names")); auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30); auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10); @@ -121,4 +121,20 @@ TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) { EXPECT_EQ(get("onnxModel(strange_names)._baz_0", 1), expect_sub); } +TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) { + add_expr("my_first_input", "tensor<float>(a[2]):[10,20]"); + add_expr("my_second_input", "tensor<float>(a[2]):[5,10]"); + add_onnx(OnnxModel("custom_names", strange_names_model) + .input_feature("input:0", "rankingExpression(my_first_input)") + .input_feature("input/1", "rankingExpression(my_second_input)") + .output_name("foo/bar", "my_first_output") + .output_name("-baz:0", "my_second_output")); + compile(onnx_feature("custom_names")); + auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30); + auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10); + EXPECT_EQ(get(1), expect_add); + EXPECT_EQ(get("onnxModel(custom_names).my_first_output", 1), expect_add); + EXPECT_EQ(get("onnxModel(custom_names).my_second_output", 1), expect_sub); +} + GTEST_MAIN_RUN_ALL_TESTS() |