summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/features
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-09-22 08:28:02 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-09-22 10:15:54 +0000
commit9a66f21d375e4fa07a96069644615581e55129d5 (patch)
tree06891b9c7f91fd0cf673cd1c540cb0dcafd5e950 /searchlib/src/tests/features
parent804e8057c2eca61ec9bc8985430613e0731922a2 (diff)
handle onnx model config for inputs and outputs
Diffstat (limited to 'searchlib/src/tests/features')
-rw-r--r--searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp26
1 files changed, 21 insertions, 5 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
index b49d9c365de..6a1e4ef9fa1 100644
--- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
+++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
@@ -58,8 +58,8 @@ struct OnnxFeatureTest : ::testing::Test {
vespalib::string expr_name = feature_name + ".rankingScript";
indexEnv.getProperties().add(expr_name, expr);
}
- void add_onnx(const vespalib::string &name, const vespalib::string &file) {
- indexEnv.addOnnxModel(name, file);
+ void add_onnx(const OnnxModel &model) {
+ indexEnv.addOnnxModel(model);
}
void compile(const vespalib::string &seed) {
resolver->addSeed(seed);
@@ -89,7 +89,7 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) {
add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]");
add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
add_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]");
- add_onnx("simple", simple_model);
+ add_onnx(OnnxModel("simple", simple_model));
compile(onnx_feature("simple"));
EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
EXPECT_EQ(get("onnxModel(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
@@ -101,7 +101,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) {
add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]");
add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
add_expr("bias_tensor", "tensor<float>(a[1],b[2]):[[4,5]]");
- add_onnx("dynamic", dynamic_model);
+ add_onnx(OnnxModel("dynamic", dynamic_model));
compile(onnx_feature("dynamic"));
EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
@@ -112,7 +112,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) {
TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) {
add_expr("input_0", "tensor<float>(a[2]):[10,20]");
add_expr("input_1", "tensor<float>(a[2]):[5,10]");
- add_onnx("strange_names", strange_names_model);
+ add_onnx(OnnxModel("strange_names", strange_names_model));
compile(onnx_feature("strange_names"));
auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30);
auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10);
@@ -121,4 +121,20 @@ TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) {
EXPECT_EQ(get("onnxModel(strange_names)._baz_0", 1), expect_sub);
}
+TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) {
+ add_expr("my_first_input", "tensor<float>(a[2]):[10,20]");
+ add_expr("my_second_input", "tensor<float>(a[2]):[5,10]");
+ add_onnx(OnnxModel("custom_names", strange_names_model)
+ .input_feature("input:0", "rankingExpression(my_first_input)")
+ .input_feature("input/1", "rankingExpression(my_second_input)")
+ .output_name("foo/bar", "my_first_output")
+ .output_name("-baz:0", "my_second_output"));
+ compile(onnx_feature("custom_names"));
+ auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30);
+ auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10);
+ EXPECT_EQ(get(1), expect_add);
+ EXPECT_EQ(get("onnxModel(custom_names).my_first_output", 1), expect_add);
+ EXPECT_EQ(get("onnxModel(custom_names).my_second_output", 1), expect_sub);
+}
+
GTEST_MAIN_RUN_ALL_TESTS()