diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-11-19 18:12:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-11-19 18:12:52 +0100 |
commit | a8c482c3c85f55f45f805781522fe70dd7a6a532 (patch) | |
tree | 04c28080034a78a2f1ee0b63736dc7ccbf45ff8a | |
parent | 8cac1747efdcb3a76b27c13985486b08c40e95d8 (diff) | |
parent | 9f2c0a443e61d0790b1b87b7126fcfe6c7d6e951 (diff) |
Merge pull request #11352 from vespa-engine/geirst/nearest-neighbor-term-skeleton-cpp
Add skeleton for NearestNeighborTerm in C++.
28 files changed, 214 insertions, 90 deletions
diff --git a/searchcore/src/tests/proton/matching/query_test.cpp b/searchcore/src/tests/proton/matching/query_test.cpp index 03f20876adc..a80b08badf8 100644 --- a/searchcore/src/tests/proton/matching/query_test.cpp +++ b/searchcore/src/tests/proton/matching/query_test.cpp @@ -272,6 +272,7 @@ public: void visit(ProtonWandTerm &) override {} void visit(ProtonPredicateQuery &) override {} void visit(ProtonRegExpTerm &) override {} + void visit(ProtonNearestNeighborTerm &) override {} }; void Test::requireThatTermsAreLookedUp() { @@ -423,6 +424,7 @@ public: void visit(ProtonWandTerm &) override {} void visit(ProtonPredicateQuery &) override {} void visit(ProtonRegExpTerm &) override {} + void visit(ProtonNearestNeighborTerm &) override {} }; void Test::requireThatTermDataIsFilledIn() { diff --git a/searchcore/src/tests/proton/matching/unpacking_iterators_optimizer/unpacking_iterators_optimizer_test.cpp b/searchcore/src/tests/proton/matching/unpacking_iterators_optimizer/unpacking_iterators_optimizer_test.cpp index c0418d82359..9ecbd532389 100644 --- a/searchcore/src/tests/proton/matching/unpacking_iterators_optimizer/unpacking_iterators_optimizer_test.cpp +++ b/searchcore/src/tests/proton/matching/unpacking_iterators_optimizer/unpacking_iterators_optimizer_test.cpp @@ -65,6 +65,7 @@ struct DumpQuery : QueryVisitor { void visit(WandTerm &) override {} void visit(PredicateQuery &) override {} void visit(RegExpTerm &) override {} + void visit(NearestNeighborTerm &) override {} }; std::string dump_query(Node &root) { diff --git a/searchcore/src/vespa/searchcore/proton/matching/blueprintbuilder.cpp b/searchcore/src/vespa/searchcore/proton/matching/blueprintbuilder.cpp index 7e55c8f778c..c8c5a3a427b 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/blueprintbuilder.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/blueprintbuilder.cpp @@ -159,6 +159,7 @@ protected: void visit(ProtonSuffixTerm &n) override { buildTerm(n); } void visit(ProtonPredicateQuery &n) override { buildTerm(n); } void visit(ProtonRegExpTerm &n) override { buildTerm(n); } + void visit(ProtonNearestNeighborTerm &n) override { buildTerm(n); } public: BlueprintBuilderVisitor(const IRequestContext & requestContext, ISearchContext &context) : diff --git a/searchcore/src/vespa/searchcore/proton/matching/querynodes.h b/searchcore/src/vespa/searchcore/proton/matching/querynodes.h index 6454845b247..d7ec24edb8f 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/querynodes.h +++ b/searchcore/src/vespa/searchcore/proton/matching/querynodes.h @@ -137,6 +137,7 @@ typedef ProtonTerm<search::query::DotProduct> ProtonDotProduct; typedef ProtonTerm<search::query::WandTerm> ProtonWandTerm; typedef ProtonTerm<search::query::PredicateQuery> ProtonPredicateQuery; typedef ProtonTerm<search::query::RegExpTerm> ProtonRegExpTerm; +typedef ProtonTerm<search::query::NearestNeighborTerm> ProtonNearestNeighborTerm; struct ProtonNodeTypes { typedef ProtonAnd And; @@ -161,6 +162,7 @@ struct ProtonNodeTypes { typedef ProtonWandTerm WandTerm; typedef ProtonPredicateQuery PredicateQuery; typedef ProtonRegExpTerm RegExpTerm; + typedef ProtonNearestNeighborTerm NearestNeighborTerm; }; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/same_element_builder.cpp b/searchcore/src/vespa/searchcore/proton/matching/same_element_builder.cpp index 241ab53874f..1e5df97a659 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/same_element_builder.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/same_element_builder.cpp @@ -67,6 +67,7 @@ public: void visit(ProtonSuffixTerm &n) override { visitTerm(n); } void visit(ProtonPredicateQuery &) override {} void visit(ProtonRegExpTerm &n) override { visitTerm(n); } + void visit(ProtonNearestNeighborTerm &) override {} }; } // namespace proton::matching::<unnamed> diff --git a/searchcore/src/vespa/searchcore/proton/matching/termdatafromnode.cpp b/searchcore/src/vespa/searchcore/proton/matching/termdatafromnode.cpp index 3fd4000bf9f..3184b5cc061 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/termdatafromnode.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/termdatafromnode.cpp @@ -42,6 +42,7 @@ struct TermDataFromTermVisitor void visit(ProtonSuffixTerm &n) override { visitTerm(n); } void visit(ProtonPredicateQuery &) override { } void visit(ProtonRegExpTerm &n) override { visitTerm(n); } + void visit(ProtonNearestNeighborTerm &n) override { visitTerm(n); } }; } // namespace diff --git a/searchcore/src/vespa/searchcore/proton/matching/unpacking_iterators_optimizer.cpp b/searchcore/src/vespa/searchcore/proton/matching/unpacking_iterators_optimizer.cpp index af355452c73..eada88010dd 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/unpacking_iterators_optimizer.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/unpacking_iterators_optimizer.cpp @@ -56,6 +56,7 @@ struct TermExpander : QueryVisitor { void visit(WandTerm &) override {} void visit(PredicateQuery &) override {} void visit(RegExpTerm &) override {} + void visit(NearestNeighborTerm &) override {} void flush(Intermediate &parent) { for (Node::UP &term: terms) { parent.append(std::move(term)); diff --git a/searchcorespi/src/vespa/searchcorespi/index/indexcollection.cpp b/searchcorespi/src/vespa/searchcorespi/index/indexcollection.cpp index b71927c714f..12128a3df18 100644 --- a/searchcorespi/src/vespa/searchcorespi/index/indexcollection.cpp +++ b/searchcorespi/src/vespa/searchcorespi/index/indexcollection.cpp @@ -189,7 +189,6 @@ private: void visit(ONear &) override { } void visit(SameElement &) override { } - void visit(WeightedSetTerm &n) override { visitTerm(n); } void visit(DotProduct &n) override { visitTerm(n); } void visit(WandTerm &n) override { visitTerm(n); } @@ -203,6 +202,7 @@ private: void visit(SuffixTerm &n) override { visitTerm(n); } void visit(PredicateQuery &n) override { visitTerm(n); } void visit(RegExpTerm &n) override { visitTerm(n); } + void visit(NearestNeighborTerm &n) override { visitTerm(n); } public: CreateBlueprintVisitor(const IIndexCollection &indexes, diff --git a/searchlib/src/tests/query/customtypevisitor_test.cpp b/searchlib/src/tests/query/customtypevisitor_test.cpp index c5eeac8543d..3f7d57b7aa4 100644 --- a/searchlib/src/tests/query/customtypevisitor_test.cpp +++ b/searchlib/src/tests/query/customtypevisitor_test.cpp @@ -54,6 +54,7 @@ struct MyDotProduct : DotProduct { MyDotProduct() : DotProduct("view", 0, Weight struct MyWandTerm : WandTerm { MyWandTerm() : WandTerm("view", 0, Weight(42), 57, 67, 77.7) {} }; struct MyPredicateQuery : InitTerm<PredicateQuery> {}; struct MyRegExpTerm : InitTerm<RegExpTerm> {}; +struct MyNearestNeighborTerm : NearestNeighborTerm {}; struct MyQueryNodeTypes { typedef MyAnd And; @@ -78,6 +79,7 @@ struct MyQueryNodeTypes { typedef MyWandTerm WandTerm; typedef MyPredicateQuery PredicateQuery; typedef MyRegExpTerm RegExpTerm; + typedef MyNearestNeighborTerm NearestNeighborTerm; }; class MyCustomVisitor : public CustomTypeVisitor<MyQueryNodeTypes> @@ -113,6 +115,7 @@ public: void visit(MyWandTerm &) override { setVisited<MyWandTerm>(); } void visit(MyPredicateQuery &) override { setVisited<MyPredicateQuery>(); } void visit(MyRegExpTerm &) override { setVisited<MyRegExpTerm>(); } + void visit(MyNearestNeighborTerm &) override { setVisited<MyNearestNeighborTerm>(); } }; template <class T> diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp index f8922c54a4e..39e381c0942 100644 --- a/searchlib/src/tests/query/query_visitor_test.cpp +++ b/searchlib/src/tests/query/query_visitor_test.cpp @@ -65,6 +65,7 @@ public: void visit(WandTerm &) override { isVisited<WandTerm>() = true; } void visit(PredicateQuery &) override { isVisited<PredicateQuery>() = true; } void visit(RegExpTerm &) override { isVisited<RegExpTerm>() = true; } + void visit(NearestNeighborTerm &) override { isVisited<NearestNeighborTerm>() = true; } }; template <class T> @@ -98,6 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() { checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0))); checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0))); checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0))); + checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123)); } } // namespace diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp index 6673a107d44..d47fb071d81 100644 --- a/searchlib/src/tests/query/querybuilder_test.cpp +++ b/searchlib/src/tests/query/querybuilder_test.cpp @@ -48,7 +48,7 @@ PredicateQueryTerm::UP getPredicateQueryTerm() { template <class NodeTypes> Node::UP createQueryTree() { QueryBuilder<NodeTypes> builder; - builder.addAnd(10); + builder.addAnd(11); { builder.addRank(2); { @@ -111,6 +111,7 @@ Node::UP createQueryTree() { builder.addStringTerm(str[5], view[5], id[5], weight[6]); builder.addStringTerm(str[6], view[6], id[6], weight[7]); } + builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7); } Node::UP node = builder.build(); ASSERT_TRUE(node.get()); @@ -140,6 +141,15 @@ bool checkTerm(const Term *term, const typename Term::Type &t, const string &f, EXPECT_EQUAL(use_position_data, term->usePositionData())); } +template <class NodeType> +NodeType* +as_node(Node* node) +{ + auto* result = dynamic_cast<NodeType*>(node); + ASSERT_TRUE(result != nullptr); + return result; +} + template <class NodeTypes> void checkQueryTreeTypes(Node *node) { typedef typename NodeTypes::And And; @@ -166,126 +176,114 @@ void checkQueryTreeTypes(Node *node) { typedef typename NodeTypes::RegExpTerm RegExpTerm; ASSERT_TRUE(node); - And *and_node = dynamic_cast<And *>(node); - ASSERT_TRUE(and_node); - EXPECT_EQUAL(10u, and_node->getChildren().size()); + auto* and_node = as_node<And>(node); + EXPECT_EQUAL(11u, and_node->getChildren().size()); - - Rank *rank = dynamic_cast<Rank *>(and_node->getChildren()[0]); - ASSERT_TRUE(rank); + auto* rank = as_node<Rank>(and_node->getChildren()[0]); EXPECT_EQUAL(2u, rank->getChildren().size()); - Near *near = dynamic_cast<Near *>(rank->getChildren()[0]); - ASSERT_TRUE(near); + auto* near = as_node<Near>(rank->getChildren()[0]); EXPECT_EQUAL(2u, near->getChildren().size()); EXPECT_EQUAL(distance, near->getDistance()); - StringTerm *string_term = dynamic_cast<StringTerm *>(near->getChildren()[0]); + auto* string_term = as_node<StringTerm>(near->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[0], view[0], id[0], weight[0])); - SubstringTerm *substring_term = dynamic_cast<SubstringTerm *>(near->getChildren()[1]); + auto* substring_term = as_node<SubstringTerm>(near->getChildren()[1]); EXPECT_TRUE(checkTerm(substring_term, str[1], view[1], id[1], weight[1])); - ONear *onear = dynamic_cast<ONear *>(rank->getChildren()[1]); - ASSERT_TRUE(onear); + auto* onear = as_node<ONear>(rank->getChildren()[1]); EXPECT_EQUAL(2u, onear->getChildren().size()); EXPECT_EQUAL(distance, onear->getDistance()); - SuffixTerm *suffix_term = dynamic_cast<SuffixTerm *>(onear->getChildren()[0]); + auto* suffix_term = as_node<SuffixTerm>(onear->getChildren()[0]); EXPECT_TRUE(checkTerm(suffix_term, str[2], view[2], id[2], weight[2])); - PrefixTerm *prefix_term = dynamic_cast<PrefixTerm *>(onear->getChildren()[1]); + auto* prefix_term = as_node<PrefixTerm>(onear->getChildren()[1]); EXPECT_TRUE(checkTerm(prefix_term, str[3], view[3], id[3], weight[3])); - - Or *or_node = dynamic_cast<Or *>(and_node->getChildren()[1]); - ASSERT_TRUE(or_node); + auto* or_node = as_node<Or>(and_node->getChildren()[1]); EXPECT_EQUAL(3u, or_node->getChildren().size()); - Phrase *phrase = dynamic_cast<Phrase *>(or_node->getChildren()[0]); - ASSERT_TRUE(phrase); + auto* phrase = as_node<Phrase>(or_node->getChildren()[0]); EXPECT_TRUE(phrase->isRanked()); EXPECT_EQUAL(weight[4].percent(), phrase->getWeight().percent()); EXPECT_EQUAL(3u, phrase->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[0]); + string_term = as_node<StringTerm>(phrase->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[1]); + string_term = as_node<StringTerm>(phrase->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[4])); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[2]); + string_term = as_node<StringTerm>(phrase->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[6], view[6], id[6], weight[4])); - phrase = dynamic_cast<Phrase *>(or_node->getChildren()[1]); - ASSERT_TRUE(phrase); + phrase = as_node<Phrase>(or_node->getChildren()[1]); EXPECT_TRUE(!phrase->isRanked()); EXPECT_EQUAL(weight[4].percent(), phrase->getWeight().percent()); EXPECT_EQUAL(2u, phrase->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[0]); + string_term = as_node<StringTerm>(phrase->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[1]); + string_term = as_node<StringTerm>(phrase->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[4])); - AndNot *and_not = dynamic_cast<AndNot *>(or_node->getChildren()[2]); - ASSERT_TRUE(and_not); + auto* and_not = as_node<AndNot>(or_node->getChildren()[2]); EXPECT_EQUAL(2u, and_not->getChildren().size()); - NumberTerm *integer_term = dynamic_cast<NumberTerm *>(and_not->getChildren()[0]); + auto* integer_term = as_node<NumberTerm>(and_not->getChildren()[0]); EXPECT_TRUE(checkTerm(integer_term, int1, view[7], id[7], weight[7])); - NumberTerm *float_term = dynamic_cast<NumberTerm *>(and_not->getChildren()[1]); + auto* float_term = as_node<NumberTerm>(and_not->getChildren()[1]); EXPECT_TRUE(checkTerm(float_term, float1, view[8], id[8], weight[8], false)); - - RangeTerm *range_term = dynamic_cast<RangeTerm *>(and_node->getChildren()[2]); - ASSERT_TRUE(range_term); + auto* range_term = as_node<RangeTerm>(and_node->getChildren()[2]); EXPECT_TRUE(checkTerm(range_term, range, view[9], id[9], weight[9])); - LocationTerm *loc_term = dynamic_cast<LocationTerm *>(and_node->getChildren()[3]); - ASSERT_TRUE(loc_term); + auto* loc_term = as_node<LocationTerm>(and_node->getChildren()[3]); EXPECT_TRUE(checkTerm(loc_term, location, view[10], id[10], weight[10])); - WeakAnd *wand = dynamic_cast<WeakAnd *>(and_node->getChildren()[4]); - ASSERT_TRUE(wand != 0); + auto* wand = as_node<WeakAnd>(and_node->getChildren()[4]); EXPECT_EQUAL(123u, wand->getMinHits()); EXPECT_EQUAL(2u, wand->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(wand->getChildren()[0]); + string_term = as_node<StringTerm>(wand->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(wand->getChildren()[1]); + string_term = as_node<StringTerm>(wand->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[5])); - PredicateQuery *predicateQuery = dynamic_cast<PredicateQuery *>(and_node->getChildren()[5]); - ASSERT_TRUE(predicateQuery); + auto* predicateQuery = as_node<PredicateQuery>(and_node->getChildren()[5]); PredicateQueryTerm::UP pqt(new PredicateQueryTerm); EXPECT_TRUE(checkTerm(predicateQuery, getPredicateQueryTerm(), view[3], id[3], weight[3])); - DotProduct *dotProduct = dynamic_cast<DotProduct *>(and_node->getChildren()[6]); - ASSERT_TRUE(dotProduct); + auto* dotProduct = as_node<DotProduct>(and_node->getChildren()[6]); EXPECT_EQUAL(3u, dotProduct->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[0]); + string_term = as_node<StringTerm>(dotProduct->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[3], view[3], id[3], weight[3])); - string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[1]); + string_term = as_node<StringTerm>(dotProduct->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[2]); + string_term = as_node<StringTerm>(dotProduct->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[5])); - WandTerm *wandTerm = dynamic_cast<WandTerm *>(and_node->getChildren()[7]); - ASSERT_TRUE(wandTerm); + auto* wandTerm = as_node<WandTerm>(and_node->getChildren()[7]); EXPECT_EQUAL(57u, wandTerm->getTargetNumHits()); EXPECT_EQUAL(67, wandTerm->getScoreThreshold()); EXPECT_EQUAL(77.7, wandTerm->getThresholdBoostFactor()); EXPECT_EQUAL(2u, wandTerm->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(wandTerm->getChildren()[0]); + string_term = as_node<StringTerm>(wandTerm->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[1], view[1], id[1], weight[1])); - string_term = dynamic_cast<StringTerm *>(wandTerm->getChildren()[1]); + string_term = as_node<StringTerm>(wandTerm->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[2], view[2], id[2], weight[2])); - RegExpTerm *regexp_term = dynamic_cast<RegExpTerm *>(and_node->getChildren()[8]); + auto* regexp_term = as_node<RegExpTerm>(and_node->getChildren()[8]); EXPECT_TRUE(checkTerm(regexp_term, str[5], view[5], id[5], weight[5])); - SameElement *same = dynamic_cast<SameElement *>(and_node->getChildren()[9]); - ASSERT_TRUE(same != nullptr); + auto* same = as_node<SameElement>(and_node->getChildren()[9]); EXPECT_EQUAL(view[4], same->getView()); EXPECT_EQUAL(3u, same->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(same->getChildren()[0]); + string_term = as_node<StringTerm>(same->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[5])); - string_term = dynamic_cast<StringTerm *>(same->getChildren()[1]); + string_term = as_node<StringTerm>(same->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[6])); - string_term = dynamic_cast<StringTerm *>(same->getChildren()[2]); + string_term = as_node<StringTerm>(same->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[6], view[6], id[6], weight[7])); + auto* nearest_neighbor = as_node<NearestNeighborTerm>(and_node->getChildren()[10]); + EXPECT_EQUAL("query_tensor", nearest_neighbor->get_query_tensor_name()); + EXPECT_EQUAL("doc_tensor", nearest_neighbor->getView()); + EXPECT_EQUAL(id[3], nearest_neighbor->getId()); + EXPECT_EQUAL(weight[5].percent(), nearest_neighbor->getWeight().percent()); + EXPECT_EQUAL(7u, nearest_neighbor->get_target_num_hits()); } struct AbstractTypes { @@ -395,6 +393,12 @@ struct MyRegExpTerm : RegExpTerm { : RegExpTerm(t, f, i, w) { } }; +struct MyNearestNeighborTerm : NearestNeighborTerm { + MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t i, Weight w, uint32_t target_num_hits) + : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits) + {} +}; struct MyQueryNodeTypes { typedef MyAnd And; @@ -419,6 +423,7 @@ struct MyQueryNodeTypes { typedef MyWandTerm WandTerm; typedef MyPredicateQuery PredicateQuery; typedef MyRegExpTerm RegExpTerm; + typedef MyNearestNeighborTerm NearestNeighborTerm; }; TEST("require that Custom Query Trees Can Be Built") { diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 5261f568673..9c132527abc 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -592,6 +592,11 @@ public: createShallowWeightedSet(bp, n, _field, _attr.isIntegerType()); } } + void visit(query::NearestNeighborTerm &n) override { + (void) n; + // TODO (geirst): implement + setResult(std::make_unique<queryeval::EmptyBlueprint>(_field)); + } }; } // namespace diff --git a/searchlib/src/vespa/searchlib/diskindex/diskindex.cpp b/searchlib/src/vespa/searchlib/diskindex/diskindex.cpp index 7d8bcf032ba..ddcee50c219 100644 --- a/searchlib/src/vespa/searchlib/diskindex/diskindex.cpp +++ b/searchlib/src/vespa/searchlib/diskindex/diskindex.cpp @@ -398,6 +398,8 @@ public: handleNumberTermAsText(n); } + void not_supported(Node &) {} + void visit(LocationTerm &n) override { visitTerm(n); } void visit(PrefixTerm &n) override { visitTerm(n); } void visit(RangeTerm &n) override { visitTerm(n); } @@ -405,7 +407,8 @@ public: void visit(SubstringTerm &n) override { visitTerm(n); } void visit(SuffixTerm &n) override { visitTerm(n); } void visit(RegExpTerm &n) override { visitTerm(n); } - void visit(PredicateQuery &) override { } + void visit(PredicateQuery &n) override { not_supported(n); } + void visit(NearestNeighborTerm &n) override { not_supported(n); } }; Blueprint::UP diff --git a/searchlib/src/vespa/searchlib/memoryindex/memory_index.cpp b/searchlib/src/vespa/searchlib/memoryindex/memory_index.cpp index d3d3004100c..d8e48e84fb7 100644 --- a/searchlib/src/vespa/searchlib/memoryindex/memory_index.cpp +++ b/searchlib/src/vespa/searchlib/memoryindex/memory_index.cpp @@ -28,6 +28,7 @@ using index::IndexBuilder; using index::Schema; using index::SchemaUtil; using query::LocationTerm; +using query::NearestNeighborTerm; using query::Node; using query::NumberTerm; using query::PredicateQuery; @@ -163,6 +164,8 @@ public: setResult(fieldIndex->make_term_blueprint(termStr, _field, _fieldId)); } + void not_supported(Node &) {} + void visit(LocationTerm &n) override { visitTerm(n); } void visit(PrefixTerm &n) override { visitTerm(n); } void visit(RangeTerm &n) override { visitTerm(n); } @@ -170,7 +173,8 @@ public: void visit(SubstringTerm &n) override { visitTerm(n); } void visit(SuffixTerm &n) override { visitTerm(n); } void visit(RegExpTerm &n) override { visitTerm(n); } - void visit(PredicateQuery &) override { } + void visit(PredicateQuery &n) override { not_supported(n); } + void visit(NearestNeighborTerm &n) override { not_supported(n); } void visit(NumberTerm &n) override { handleNumberTermAsText(n); diff --git a/searchlib/src/vespa/searchlib/parsequery/parse.h b/searchlib/src/vespa/searchlib/parsequery/parse.h index 9c0e76d2441..83352b571c8 100644 --- a/searchlib/src/vespa/searchlib/parsequery/parse.h +++ b/searchlib/src/vespa/searchlib/parsequery/parse.h @@ -60,7 +60,8 @@ public: ITEM_PREDICATE_QUERY = 23, ITEM_REGEXP = 24, ITEM_WORD_ALTERNATIVES = 25, - ITEM_MAX = 26, // Indicates how long tables must be. + ITEM_NEAREST_NEIGHBOR = 26, + ITEM_MAX = 27, // Indicates how long tables must be. ITEM_UNDEF = 31, }; diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp index a70fe07cf81..70a3097ae05 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp @@ -270,6 +270,17 @@ SimpleQueryStackDumpIterator::next() } break; + case ParseItem::ITEM_NEAREST_NEIGHBOR: + try { + _curr_index_name = read_stringref(p); + _curr_term = read_stringref(p); // query_tensor_name + _currArg1 = readCompressedPositiveInt(p); // target_num_hits; + _currArity = 0; + } catch (...) { + return false; + } + break; + default: // Unknown item, so report that no more are available return false; diff --git a/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h b/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h index cdeebcaf9e5..3882bc41b2b 100644 --- a/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h +++ b/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h @@ -49,6 +49,7 @@ public: virtual void visit(typename NodeTypes::WandTerm &) = 0; virtual void visit(typename NodeTypes::PredicateQuery &) = 0; virtual void visit(typename NodeTypes::RegExpTerm &) = 0; + virtual void visit(typename NodeTypes::NearestNeighborTerm &) = 0; private: // Route QueryVisit requests to the correct custom type. @@ -75,6 +76,7 @@ private: typedef typename NodeTypes::WandTerm TWandTerm; typedef typename NodeTypes::PredicateQuery TPredicateQuery; typedef typename NodeTypes::RegExpTerm TRegExpTerm; + typedef typename NodeTypes::NearestNeighborTerm TNearestNeighborTerm; void visit(And &n) override { visit(static_cast<TAnd&>(n)); } void visit(AndNot &n) override { visit(static_cast<TAndNot&>(n)); } @@ -98,6 +100,7 @@ private: void visit(WandTerm &n) override { visit(static_cast<TWandTerm&>(n)); } void visit(PredicateQuery &n) override { visit(static_cast<TPredicateQuery&>(n)); } void visit(RegExpTerm &n) override { visit(static_cast<TRegExpTerm&>(n)); } + void visit(NearestNeighborTerm &n) override { visit(static_cast<TNearestNeighborTerm&>(n)); } }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h index a2ad8eae84b..797defc39f5 100644 --- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h +++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h @@ -203,6 +203,13 @@ createRegExpTerm(vespalib::stringref term, vespalib::stringref view, int32_t id, } template <class NodeTypes> +typename NodeTypes::NearestNeighborTerm * +create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) { + return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits); +} + +template <class NodeTypes> class QueryBuilder : public QueryBuilderBase { template <class T> T &addIntermediate(T *node, int child_count) { @@ -309,6 +316,11 @@ public: adjustWeight(weight); return addTerm(createRegExpTerm<NodeTypes>(term, view, id, weight)); } + typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) { + adjustWeight(weight); + return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits)); + } }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h index e7c3fd8c73b..d2249a53f18 100644 --- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h +++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h @@ -163,6 +163,11 @@ private: node.getTerm(), node.getView(), node.getId(), node.getWeight())); } + + void visit(NearestNeighborTerm &node) override { + replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(), + node.getId(), node.getWeight(), node.get_target_num_hits())); + } }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h b/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h index 0cb56f9127a..533e240e088 100644 --- a/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h +++ b/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h @@ -26,6 +26,7 @@ class WandTerm; class PredicateQuery; class RegExpTerm; class SameElement; +class NearestNeighborTerm; struct QueryVisitor { virtual ~QueryVisitor() {} @@ -52,6 +53,7 @@ struct QueryVisitor { virtual void visit(WandTerm &) = 0; virtual void visit(PredicateQuery &) = 0; virtual void visit(RegExpTerm &) = 0; + virtual void visit(NearestNeighborTerm &) = 0; }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h index 707ed2aa0db..8663bede4d6 100644 --- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h +++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h @@ -103,31 +103,38 @@ struct SimpleRegExpTerm : RegExpTerm { : RegExpTerm(term, view, id, weight) { } }; +struct SimpleNearestNeighborTerm : NearestNeighborTerm { + SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) + : NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits) + {} +}; struct SimpleQueryNodeTypes { - typedef SimpleAnd And; - typedef SimpleAndNot AndNot; - typedef SimpleEquiv Equiv; - typedef SimpleNumberTerm NumberTerm; - typedef SimpleLocationTerm LocationTerm; - typedef SimpleNear Near; - typedef SimpleONear ONear; - typedef SimpleOr Or; - typedef SimplePhrase Phrase; - typedef SimpleSameElement SameElement; - typedef SimplePrefixTerm PrefixTerm; - typedef SimpleRangeTerm RangeTerm; - typedef SimpleRank Rank; - typedef SimpleStringTerm StringTerm; - typedef SimpleSubstringTerm SubstringTerm; - typedef SimpleSuffixTerm SuffixTerm; - typedef SimpleWeakAnd WeakAnd; - typedef SimpleWeightedSetTerm WeightedSetTerm; - typedef SimpleDotProduct DotProduct; - typedef SimpleWandTerm WandTerm; - typedef SimplePredicateQuery PredicateQuery; - typedef SimpleRegExpTerm RegExpTerm; + using And = SimpleAnd; + using AndNot = SimpleAndNot; + using Equiv = SimpleEquiv; + using NumberTerm = SimpleNumberTerm; + using LocationTerm = SimpleLocationTerm; + using Near = SimpleNear; + using ONear = SimpleONear; + using Or = SimpleOr; + using Phrase = SimplePhrase; + using SameElement = SimpleSameElement; + using PrefixTerm = SimplePrefixTerm; + using RangeTerm = SimpleRangeTerm; + using Rank = SimpleRank; + using StringTerm = SimpleStringTerm; + using SubstringTerm = SimpleSubstringTerm; + using SuffixTerm = SimpleSuffixTerm; + using WeakAnd = SimpleWeakAnd; + using WeightedSetTerm = SimpleWeightedSetTerm; + using DotProduct = SimpleDotProduct; + using WandTerm = SimpleWandTerm; + using PredicateQuery = SimplePredicateQuery; + using RegExpTerm = SimpleRegExpTerm; + using NearestNeighborTerm = SimpleNearestNeighborTerm; }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp index 645750b8576..63acf532144 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp @@ -196,8 +196,7 @@ class QueryNodeConverter : public QueryVisitor { template <typename T> void appendTerm(const TermBase<T> &node); - template <class Term> - void createTerm(const Term &node, size_t type) { + void createTermNode(const TermNode &node, size_t type) { uint8_t typefield = type | ParseItem::IF_WEIGHT | ParseItem::IF_UNIQUEID; uint8_t flags = 0; if (!node.isRanked()) { @@ -216,6 +215,11 @@ class QueryNodeConverter : public QueryVisitor { appendByte(flags); } appendString(node.getView()); + } + + template <class Term> + void createTerm(const Term &node, size_t type) { + createTermNode(node, type); appendTerm(node); } @@ -255,6 +259,12 @@ class QueryNodeConverter : public QueryVisitor { createTerm(node, ParseItem::ITEM_REGEXP); } + void visit(NearestNeighborTerm &node) override { + createTermNode(node, ParseItem::ITEM_NEAREST_NEIGHBOR); + appendString(node.get_query_tensor_name()); + appendCompressedPositiveNumber(node.get_target_num_hits()); + } + public: QueryNodeConverter() : _buf(4096) diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h index dfb0c75a695..a5f25d81400 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h @@ -109,6 +109,13 @@ private: pureTermView = vespalib::stringref(); } else if (type == ParseItem::ITEM_NOT) { builder.addAndNot(arity); + } else if (type == ParseItem::ITEM_NEAREST_NEIGHBOR) { + vespalib::stringref query_tensor_name = queryStack.getTerm(); + vespalib::stringref field_name = queryStack.getIndexName(); + uint32_t target_num_hits = queryStack.getArg1(); + int32_t id = queryStack.getUniqueId(); + Weight weight = queryStack.GetWeight(); + builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, target_num_hits); } else { vespalib::stringref term = queryStack.getTerm(); vespalib::stringref view = queryStack.getIndexName(); diff --git a/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h b/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h index 0cdaca82572..d1abc816838 100644 --- a/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h +++ b/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h @@ -31,6 +31,7 @@ class TemplateTermVisitor : public CustomTypeTermVisitor<NodeTypes> { void visit(typename NodeTypes::SuffixTerm &n) override { myVisit(n); } void visit(typename NodeTypes::PredicateQuery &n) override { myVisit(n); } void visit(typename NodeTypes::RegExpTerm &n) override { myVisit(n); } + void visit(typename NodeTypes::NearestNeighborTerm &n) override { myVisit(n); } // Phrases are terms with children. This visitor will not visit // the phrase's children, unless this member function is diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h index 35c23dde985..a82b1e14d76 100644 --- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h +++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h @@ -113,5 +113,33 @@ public: virtual ~RegExpTerm() = 0; }; +/** + * Term matching the K nearest neighbors in a multi-dimensional vector space. + * + * The query point is specified as a dense tensor of order 1. + * This is found in fef::IQueryEnvironment using the query tensor name as key. + * The field name is the name of a dense document tensor of order 1. + * Both tensors are validated to have the same tensor type before the query is sent to the backend. + * + * Target num hits (K) is a hint to how many neighbors to return. + * The actual returned number might be higher (or lower if the query returns fewer hits). + */ +class NearestNeighborTerm : public QueryNodeMixin<NearestNeighborTerm, TermNode> { +private: + vespalib::string _query_tensor_name; + uint32_t _target_num_hits; + +public: + NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) + : QueryNodeMixinType(field_name, id, weight), + _query_tensor_name(query_tensor_name), + _target_num_hits(target_num_hits) + {} + virtual ~NearestNeighborTerm() {} + const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; } + uint32_t get_target_num_hits() const { return _target_num_hits; } +}; + } diff --git a/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h b/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h index 84830111fde..4fd8f64cc99 100644 --- a/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h +++ b/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h @@ -41,6 +41,7 @@ public: void visitWeightedSetTerm(query::WeightedSetTerm &n); void visitDotProduct(query::DotProduct &n); void visitWandTerm(query::WandTerm &n); + void visitNearestNeighborTerm(query::NearestNeighborTerm &n); void handleNumberTermAsText(query::NumberTerm &n); @@ -71,6 +72,7 @@ public: void visit(query::SubstringTerm &n) override = 0; void visit(query::SuffixTerm &n) override = 0; void visit(query::RegExpTerm &n) override = 0; + void visit(query::NearestNeighborTerm &n) override = 0; }; } diff --git a/searchlib/src/vespa/searchlib/queryeval/fake_searchable.cpp b/searchlib/src/vespa/searchlib/queryeval/fake_searchable.cpp index 4c678a9902f..fc3a6399e00 100644 --- a/searchlib/src/vespa/searchlib/queryeval/fake_searchable.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/fake_searchable.cpp @@ -6,9 +6,10 @@ #include "create_blueprint_visitor_helper.h" #include <vespa/vespalib/objects/visit.h> -using search::query::NumberTerm; using search::query::LocationTerm; +using search::query::NearestNeighborTerm; using search::query::Node; +using search::query::NumberTerm; using search::query::PredicateQuery; using search::query::PrefixTerm; using search::query::RangeTerm; @@ -64,6 +65,7 @@ public: void visit(SuffixTerm &n) override { visitTerm(n); } void visit(PredicateQuery &n) override { visitTerm(n); } void visit(RegExpTerm &n) override { visitTerm(n); } + void visit(NearestNeighborTerm &n) override { visitTerm(n); } }; template <class Map> diff --git a/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp b/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp index 3829ea45e2b..7a97110713d 100644 --- a/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp @@ -14,28 +14,29 @@ LOG_SETUP(".termasstring"); using search::query::And; using search::query::AndNot; +using search::query::DotProduct; using search::query::Equiv; -using search::query::NumberTerm; using search::query::LocationTerm; using search::query::Near; +using search::query::NearestNeighborTerm; using search::query::Node; +using search::query::NumberTerm; using search::query::ONear; using search::query::Or; using search::query::Phrase; -using search::query::SameElement; using search::query::PredicateQuery; using search::query::PrefixTerm; using search::query::QueryVisitor; using search::query::RangeTerm; using search::query::Rank; using search::query::RegExpTerm; +using search::query::SameElement; using search::query::StringTerm; using search::query::SubstringTerm; using search::query::SuffixTerm; +using search::query::WandTerm; using search::query::WeakAnd; using search::query::WeightedSetTerm; -using search::query::DotProduct; -using search::query::WandTerm; using vespalib::string; namespace search::queryeval { @@ -101,6 +102,7 @@ struct TermAsStringVisitor : public QueryVisitor { void visit(SuffixTerm &n) override {visitTerm(n); } void visit(RegExpTerm &n) override {visitTerm(n); } void visit(PredicateQuery &) override {illegalVisit(); } + void visit(NearestNeighborTerm &) override { illegalVisit(); } }; } // namespace |