diff options
author | Håvard Pettersen <havardpe@oath.com> | 2021-11-02 09:49:03 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2021-11-02 09:49:03 +0000 |
commit | ea3b0172daf1538863aa5cc5eec701cad18dd6db (patch) | |
tree | 49a1e41856304012bd09466a91397083ff67e399 /searchlib | |
parent | bcf0896ee427cb584fc4af466e3cdb94eb06c073 (diff) |
added new rank feature "onnx(model_name)"
Works the same as "onnxModel(model_name)", but is not treated as the
exact same feature (features are currently not allowed to have
multiple base names). If both variants are used at the same time, the
model may be calculated twice, but the model cache will still make
sure that the model itself is only loaded once.
The plan is to deprecate and possibly remove the
"onnxModel(model_name)" variant at some point in the future.
Diffstat (limited to 'searchlib')
4 files changed, 30 insertions, 11 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp index 5017df69192..54b574abbf3 100644 --- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -39,6 +39,10 @@ vespalib::string expr_feature(const vespalib::string &name) { } vespalib::string onnx_feature(const vespalib::string &name) { + return fmt("onnx(%s)", name.c_str()); +} + +vespalib::string onnx_feature_old(const vespalib::string &name) { return fmt("onnxModel(%s)", name.c_str()); } @@ -54,7 +58,8 @@ struct OnnxFeatureTest : ::testing::Test { { factory.addPrototype(std::make_shared<DocidBlueprint>()); factory.addPrototype(std::make_shared<RankingExpressionBlueprint>()); - factory.addPrototype(std::make_shared<OnnxBlueprint>()); + factory.addPrototype(std::make_shared<OnnxBlueprint>("onnx")); + factory.addPrototype(std::make_shared<OnnxBlueprint>("onnxModel")); } ~OnnxFeatureTest(); void add_expr(const vespalib::string &name, const vespalib::string &expr) { @@ -104,6 +109,18 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) { add_onnx(OnnxModel("simple", simple_model)); compile(onnx_feature("simple")); EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); + EXPECT_EQ(get("onnx(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); + EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0)); + EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0)); +} + +TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated_with_old_name) { + add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]"); + add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); + add_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]"); + add_onnx(OnnxModel("simple", simple_model)); + compile(onnx_feature_old("simple")); + EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); EXPECT_EQ(get("onnxModel(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0)); EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0)); @@ -116,7 +133,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) { add_onnx(OnnxModel("dynamic", dynamic_model)); compile(onnx_feature("dynamic")); EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); - EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); + EXPECT_EQ(get("onnx(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0)); EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0)); } @@ -129,8 +146,8 @@ TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) { auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30); auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10); EXPECT_EQ(get(1), expect_add); - EXPECT_EQ(get("onnxModel(strange_names).foo_bar", 1), expect_add); - EXPECT_EQ(get("onnxModel(strange_names)._baz_0", 1), expect_sub); + EXPECT_EQ(get("onnx(strange_names).foo_bar", 1), expect_add); + EXPECT_EQ(get("onnx(strange_names)._baz_0", 1), expect_sub); } TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) { @@ -145,8 +162,8 @@ TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) { auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30); auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10); EXPECT_EQ(get(1), expect_add); - EXPECT_EQ(get("onnxModel(custom_names).my_first_output", 1), expect_add); - EXPECT_EQ(get("onnxModel(custom_names).my_second_output", 1), expect_sub); + EXPECT_EQ(get("onnx(custom_names).my_first_output", 1), expect_add); + EXPECT_EQ(get("onnx(custom_names).my_second_output", 1), expect_sub); } TEST_F(OnnxFeatureTest, fragile_model_can_be_evaluated) { diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp index dd0215e1d53..c28abe10b4a 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp @@ -100,13 +100,14 @@ public: } }; -OnnxBlueprint::OnnxBlueprint() - : Blueprint("onnxModel"), +OnnxBlueprint::OnnxBlueprint(vespalib::stringref baseName) + : Blueprint(baseName), _cache_token(), _debug_model(), _model(nullptr), _wire_info() { + assert((baseName == "onnx") || (baseName == "onnxModel")); } OnnxBlueprint::~OnnxBlueprint() = default; diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h index 63a260b8427..ebbc22d4eb2 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.h +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h @@ -20,11 +20,11 @@ private: const Onnx *_model; Onnx::WireInfo _wire_info; public: - OnnxBlueprint(); + OnnxBlueprint(vespalib::stringref baseName); ~OnnxBlueprint() override; void visitDumpFeatures(const fef::IIndexEnvironment &, fef::IDumpFeatureVisitor &) const override {} fef::Blueprint::UP createInstance() const override { - return std::make_unique<OnnxBlueprint>(); + return std::make_unique<OnnxBlueprint>(getBaseName()); } fef::ParameterDescriptions getDescriptions() const override { return fef::ParameterDescriptions().desc().string(); diff --git a/searchlib/src/vespa/searchlib/features/setup.cpp b/searchlib/src/vespa/searchlib/features/setup.cpp index af232238366..f2d5bd745ac 100644 --- a/searchlib/src/vespa/searchlib/features/setup.cpp +++ b/searchlib/src/vespa/searchlib/features/setup.cpp @@ -124,7 +124,8 @@ void setup_search_features(fef::IBlueprintRegistry & registry) registry.addPrototype(std::make_shared<TermFieldMdBlueprint>()); registry.addPrototype(std::make_shared<ConstantBlueprint>()); registry.addPrototype(std::make_shared<GlobalSequenceBlueprint>()); - registry.addPrototype(std::make_shared<OnnxBlueprint>()); + registry.addPrototype(std::make_shared<OnnxBlueprint>("onnx")); + registry.addPrototype(std::make_shared<OnnxBlueprint>("onnxModel")); // Ranking Expression auto replacers = std::make_unique<ListExpressionReplacer>(); |