summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-11-02 09:49:03 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-11-02 09:49:03 +0000
commitea3b0172daf1538863aa5cc5eec701cad18dd6db (patch)
tree49a1e41856304012bd09466a91397083ff67e399 /searchlib
parentbcf0896ee427cb584fc4af466e3cdb94eb06c073 (diff)
added new rank feature "onnx(model_name)"
Works the same as "onnxModel(model_name)", but is not treated as the exact same feature (features are currently not allowed to have multiple base names). If both variants are used at the same time, the model may be calculated twice, but the model cache will still make sure that the model itself is only loaded once. The plan is to deprecate and possibly remove the "onnxModel(model_name)" variant at some point in the future.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp29
-rw-r--r--searchlib/src/vespa/searchlib/features/onnx_feature.cpp5
-rw-r--r--searchlib/src/vespa/searchlib/features/onnx_feature.h4
-rw-r--r--searchlib/src/vespa/searchlib/features/setup.cpp3
4 files changed, 30 insertions, 11 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
index 5017df69192..54b574abbf3 100644
--- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
+++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
@@ -39,6 +39,10 @@ vespalib::string expr_feature(const vespalib::string &name) {
}
vespalib::string onnx_feature(const vespalib::string &name) {
+ return fmt("onnx(%s)", name.c_str());
+}
+
+vespalib::string onnx_feature_old(const vespalib::string &name) {
return fmt("onnxModel(%s)", name.c_str());
}
@@ -54,7 +58,8 @@ struct OnnxFeatureTest : ::testing::Test {
{
factory.addPrototype(std::make_shared<DocidBlueprint>());
factory.addPrototype(std::make_shared<RankingExpressionBlueprint>());
- factory.addPrototype(std::make_shared<OnnxBlueprint>());
+ factory.addPrototype(std::make_shared<OnnxBlueprint>("onnx"));
+ factory.addPrototype(std::make_shared<OnnxBlueprint>("onnxModel"));
}
~OnnxFeatureTest();
void add_expr(const vespalib::string &name, const vespalib::string &expr) {
@@ -104,6 +109,18 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) {
add_onnx(OnnxModel("simple", simple_model));
compile(onnx_feature("simple"));
EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
+ EXPECT_EQ(get("onnx(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
+ EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0));
+ EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0));
+}
+
+TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated_with_old_name) {
+ add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]");
+ add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
+ add_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]");
+ add_onnx(OnnxModel("simple", simple_model));
+ compile(onnx_feature_old("simple"));
+ EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
EXPECT_EQ(get("onnxModel(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0));
EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0));
@@ -116,7 +133,7 @@ TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) {
add_onnx(OnnxModel("dynamic", dynamic_model));
compile(onnx_feature("dynamic"));
EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
- EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
+ EXPECT_EQ(get("onnx(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0));
EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0));
}
@@ -129,8 +146,8 @@ TEST_F(OnnxFeatureTest, strange_input_and_output_names_are_normalized) {
auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30);
auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10);
EXPECT_EQ(get(1), expect_add);
- EXPECT_EQ(get("onnxModel(strange_names).foo_bar", 1), expect_add);
- EXPECT_EQ(get("onnxModel(strange_names)._baz_0", 1), expect_sub);
+ EXPECT_EQ(get("onnx(strange_names).foo_bar", 1), expect_add);
+ EXPECT_EQ(get("onnx(strange_names)._baz_0", 1), expect_sub);
}
TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) {
@@ -145,8 +162,8 @@ TEST_F(OnnxFeatureTest, input_features_and_output_names_can_be_specified) {
auto expect_add = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},15).add({{"d0",1}},30);
auto expect_sub = TensorSpec("tensor<float>(d0[2])").add({{"d0",0}},5).add({{"d0",1}},10);
EXPECT_EQ(get(1), expect_add);
- EXPECT_EQ(get("onnxModel(custom_names).my_first_output", 1), expect_add);
- EXPECT_EQ(get("onnxModel(custom_names).my_second_output", 1), expect_sub);
+ EXPECT_EQ(get("onnx(custom_names).my_first_output", 1), expect_add);
+ EXPECT_EQ(get("onnx(custom_names).my_second_output", 1), expect_sub);
}
TEST_F(OnnxFeatureTest, fragile_model_can_be_evaluated) {
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
index dd0215e1d53..c28abe10b4a 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
@@ -100,13 +100,14 @@ public:
}
};
-OnnxBlueprint::OnnxBlueprint()
- : Blueprint("onnxModel"),
+OnnxBlueprint::OnnxBlueprint(vespalib::stringref baseName)
+ : Blueprint(baseName),
_cache_token(),
_debug_model(),
_model(nullptr),
_wire_info()
{
+ assert((baseName == "onnx") || (baseName == "onnxModel"));
}
OnnxBlueprint::~OnnxBlueprint() = default;
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h
index 63a260b8427..ebbc22d4eb2 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.h
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h
@@ -20,11 +20,11 @@ private:
const Onnx *_model;
Onnx::WireInfo _wire_info;
public:
- OnnxBlueprint();
+ OnnxBlueprint(vespalib::stringref baseName);
~OnnxBlueprint() override;
void visitDumpFeatures(const fef::IIndexEnvironment &, fef::IDumpFeatureVisitor &) const override {}
fef::Blueprint::UP createInstance() const override {
- return std::make_unique<OnnxBlueprint>();
+ return std::make_unique<OnnxBlueprint>(getBaseName());
}
fef::ParameterDescriptions getDescriptions() const override {
return fef::ParameterDescriptions().desc().string();
diff --git a/searchlib/src/vespa/searchlib/features/setup.cpp b/searchlib/src/vespa/searchlib/features/setup.cpp
index af232238366..f2d5bd745ac 100644
--- a/searchlib/src/vespa/searchlib/features/setup.cpp
+++ b/searchlib/src/vespa/searchlib/features/setup.cpp
@@ -124,7 +124,8 @@ void setup_search_features(fef::IBlueprintRegistry & registry)
registry.addPrototype(std::make_shared<TermFieldMdBlueprint>());
registry.addPrototype(std::make_shared<ConstantBlueprint>());
registry.addPrototype(std::make_shared<GlobalSequenceBlueprint>());
- registry.addPrototype(std::make_shared<OnnxBlueprint>());
+ registry.addPrototype(std::make_shared<OnnxBlueprint>("onnx"));
+ registry.addPrototype(std::make_shared<OnnxBlueprint>("onnxModel"));
// Ranking Expression
auto replacers = std::make_unique<ListExpressionReplacer>();