diff options
author | Håvard Pettersen <havardpe@oath.com> | 2021-06-23 13:53:28 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2021-06-23 14:57:27 +0000 |
commit | 6558fa641b2b762f710c02448c887c40e60b1d18 (patch) | |
tree | be7c65746de5bfb5dfc9d752843fc5d3dc6d74db /searchlib/src/tests | |
parent | 16a9339a6cfb78bb5177a80fc7463a2bcd994c9a (diff) |
dry run onnx models on setup
Diffstat (limited to 'searchlib/src/tests')
4 files changed, 78 insertions, 3 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/fragile.onnx b/searchlib/src/tests/features/onnx_feature/fragile.onnx new file mode 100644 index 00000000000..2a05500e95b --- /dev/null +++ b/searchlib/src/tests/features/onnx_feature/fragile.onnx @@ -0,0 +1,15 @@ + +fragile.py:b + +in1 +in2out"AddfragileZ +in1 + + +Z +in2 +
+batchb +out +
+batchB
\ No newline at end of file diff --git a/searchlib/src/tests/features/onnx_feature/fragile.py b/searchlib/src/tests/features/onnx_feature/fragile.py new file mode 100755 index 00000000000..e4eaf168e14 --- /dev/null +++ b/searchlib/src/tests/features/onnx_feature/fragile.py @@ -0,0 +1,30 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +INPUT1 = helper.make_tensor_value_info('in1', TensorProto.FLOAT, [2]) +INPUT2 = helper.make_tensor_value_info('in2', TensorProto.FLOAT, ['batch']) + +OUTPUT = helper.make_tensor_value_info('out', TensorProto.FLOAT, ['batch']) + +nodes = [ + helper.make_node( + 'Add', + ['in1', 'in2'], + ['out'], + ), +] +graph_def = helper.make_graph( + nodes, + 'fragile', + [ + INPUT1, + INPUT2, + ], + [ + OUTPUT, + ], +) +model_def = helper.make_model(graph_def, producer_name='fragile.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) +onnx.save(model_def, 'fragile.onnx') diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp index 6a1e4ef9fa1..c07ebc48604 100644 --- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -28,6 +28,7 @@ std::string vespa_dir = source_dir + "/" + "../../../../.."; std::string simple_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/simple.onnx"; std::string dynamic_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/dynamic.onnx"; std::string strange_names_model = source_dir + "/" + "strange_names.onnx"; +std::string fragile_model = source_dir + "/" + "fragile.onnx"; uint32_t default_docid = 1; @@ -61,13 +62,19 @@ struct OnnxFeatureTest : ::testing::Test { void add_onnx(const OnnxModel &model) { indexEnv.addOnnxModel(model); } - void compile(const vespalib::string &seed) { + bool try_compile(const vespalib::string &seed) { resolver->addSeed(seed); - ASSERT_TRUE(resolver->compile()); + if (!resolver->compile()) { + return false; + } MatchDataLayout mdl; QueryEnvironment queryEnv(&indexEnv); match_data = mdl.createMatchData(); program.setup(*match_data, queryEnv, overrides); + return true; + } + void compile(const vespalib::string &seed) { + ASSERT_TRUE(try_compile(seed)); } TensorSpec get(const vespalib::string &feature, uint32_t docid) { auto result = program.get_all_features(false); @@ -137,4 +144,27 @@ TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) { EXPECT_EQ(get("onnxModel(custom_names).my_second_output", 1), expect_sub); } +TEST_F(OnnxFeatureTest, fragile_model_can_be_evaluated) { + add_expr("in1", "tensor<float>(x[2]):[docid,5]"); + add_expr("in2", "tensor<float>(x[2]):[docid,10]"); + add_onnx(OnnxModel("fragile", fragile_model)); + EXPECT_TRUE(try_compile(onnx_feature("fragile"))); + EXPECT_EQ(get(1), TensorSpec::from_expr("tensor<float>(d0[2]):[2,15]")); + EXPECT_EQ(get(3), TensorSpec::from_expr("tensor<float>(d0[2]):[6,15]")); +} + +TEST_F(OnnxFeatureTest, runtime_broken_model_can_be_set_up_without_dry_run) { + add_expr("in1", "tensor<float>(x[2]):[docid,5]"); + add_expr("in2", "tensor<float>(x[3]):[docid,10,31515]"); + add_onnx(OnnxModel("fragile", fragile_model).dry_run_on_setup(false)); + EXPECT_TRUE(try_compile(onnx_feature("fragile"))); +} + +TEST_F(OnnxFeatureTest, runtime_broken_model_fails_with_dry_run) { + add_expr("in1", "tensor<float>(x[2]):[docid,5]"); + add_expr("in2", "tensor<float>(x[3]):[docid,10,31515]"); + add_onnx(OnnxModel("fragile", fragile_model).dry_run_on_setup(true)); + EXPECT_FALSE(try_compile(onnx_feature("fragile"))); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/features/onnx_feature/strange_names.py b/searchlib/src/tests/features/onnx_feature/strange_names.py index 681da641264..2e3d3fe4dd1 100755 --- a/searchlib/src/tests/features/onnx_feature/strange_names.py +++ b/searchlib/src/tests/features/onnx_feature/strange_names.py @@ -32,5 +32,5 @@ graph_def = helper.make_graph( OUTPUT2, ], ) -model_def = helper.make_model(graph_def, producer_name='strange_names.py') +model_def = helper.make_model(graph_def, producer_name='strange_names.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) onnx.save(model_def, 'strange_names.onnx') |