summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-06-23 13:53:28 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-06-23 14:57:27 +0000
commit6558fa641b2b762f710c02448c887c40e60b1d18 (patch)
treebe7c65746de5bfb5dfc9d752843fc5d3dc6d74db /searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
parent16a9339a6cfb78bb5177a80fc7463a2bcd994c9a (diff)
dry run onnx models on setup
Diffstat (limited to 'searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp')
-rw-r--r--searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp34
1 files changed, 32 insertions, 2 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
index 6a1e4ef9fa1..c07ebc48604 100644
--- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
+++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
@@ -28,6 +28,7 @@ std::string vespa_dir = source_dir + "/" + "../../../../..";
std::string simple_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/simple.onnx";
std::string dynamic_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/dynamic.onnx";
std::string strange_names_model = source_dir + "/" + "strange_names.onnx";
+std::string fragile_model = source_dir + "/" + "fragile.onnx";
uint32_t default_docid = 1;
@@ -61,13 +62,19 @@ struct OnnxFeatureTest : ::testing::Test {
void add_onnx(const OnnxModel &model) {
indexEnv.addOnnxModel(model);
}
- void compile(const vespalib::string &seed) {
+ bool try_compile(const vespalib::string &seed) {
resolver->addSeed(seed);
- ASSERT_TRUE(resolver->compile());
+ if (!resolver->compile()) {
+ return false;
+ }
MatchDataLayout mdl;
QueryEnvironment queryEnv(&indexEnv);
match_data = mdl.createMatchData();
program.setup(*match_data, queryEnv, overrides);
+ return true;
+ }
+ void compile(const vespalib::string &seed) {
+ ASSERT_TRUE(try_compile(seed));
}
TensorSpec get(const vespalib::string &feature, uint32_t docid) {
auto result = program.get_all_features(false);
@@ -137,4 +144,27 @@ TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) {
EXPECT_EQ(get("onnxModel(custom_names).my_second_output", 1), expect_sub);
}
+TEST_F(OnnxFeatureTest, fragile_model_can_be_evaluated) {
+ add_expr("in1", "tensor<float>(x[2]):[docid,5]");
+ add_expr("in2", "tensor<float>(x[2]):[docid,10]");
+ add_onnx(OnnxModel("fragile", fragile_model));
+ EXPECT_TRUE(try_compile(onnx_feature("fragile")));
+ EXPECT_EQ(get(1), TensorSpec::from_expr("tensor<float>(d0[2]):[2,15]"));
+ EXPECT_EQ(get(3), TensorSpec::from_expr("tensor<float>(d0[2]):[6,15]"));
+}
+
+TEST_F(OnnxFeatureTest, runtime_broken_model_can_be_set_up_without_dry_run) {
+ add_expr("in1", "tensor<float>(x[2]):[docid,5]");
+ add_expr("in2", "tensor<float>(x[3]):[docid,10,31515]");
+ add_onnx(OnnxModel("fragile", fragile_model).dry_run_on_setup(false));
+ EXPECT_TRUE(try_compile(onnx_feature("fragile")));
+}
+
+TEST_F(OnnxFeatureTest, runtime_broken_model_fails_with_dry_run) {
+ add_expr("in1", "tensor<float>(x[2]):[docid,5]");
+ add_expr("in2", "tensor<float>(x[3]):[docid,10,31515]");
+ add_onnx(OnnxModel("fragile", fragile_model).dry_run_on_setup(true));
+ EXPECT_FALSE(try_compile(onnx_feature("fragile")));
+}
+
GTEST_MAIN_RUN_ALL_TESTS()