From 5361472f495df07e9b1ae2af91c2f780ed8966df Mon Sep 17 00:00:00 2001 From: Geir Storli Date: Tue, 19 Nov 2019 12:37:16 +0000 Subject: Add skeleton for NearestNeighborTerm in C++. --- .../src/tests/query/customtypevisitor_test.cpp | 3 + searchlib/src/tests/query/query_visitor_test.cpp | 2 + searchlib/src/tests/query/querybuilder_test.cpp | 119 +++++++++++---------- searchlib/src/vespa/searchlib/parsequery/parse.h | 3 +- .../searchlib/parsequery/stackdumpiterator.cpp | 11 ++ .../vespa/searchlib/query/tree/customtypevisitor.h | 3 + .../src/vespa/searchlib/query/tree/querybuilder.h | 12 +++ .../vespa/searchlib/query/tree/queryreplicator.h | 5 + .../src/vespa/searchlib/query/tree/queryvisitor.h | 2 + .../src/vespa/searchlib/query/tree/simplequery.h | 51 +++++---- .../searchlib/query/tree/stackdumpcreator.cpp | 14 ++- .../searchlib/query/tree/stackdumpquerycreator.h | 7 ++ .../searchlib/query/tree/templatetermvisitor.h | 1 + .../src/vespa/searchlib/query/tree/termnodes.h | 28 +++++ .../queryeval/create_blueprint_visitor_helper.cpp | 8 ++ .../queryeval/create_blueprint_visitor_helper.h | 2 + .../src/vespa/searchlib/queryeval/termasstring.cpp | 10 +- 17 files changed, 195 insertions(+), 86 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/tests/query/customtypevisitor_test.cpp b/searchlib/src/tests/query/customtypevisitor_test.cpp index c5eeac8543d..3f7d57b7aa4 100644 --- a/searchlib/src/tests/query/customtypevisitor_test.cpp +++ b/searchlib/src/tests/query/customtypevisitor_test.cpp @@ -54,6 +54,7 @@ struct MyDotProduct : DotProduct { MyDotProduct() : DotProduct("view", 0, Weight struct MyWandTerm : WandTerm { MyWandTerm() : WandTerm("view", 0, Weight(42), 57, 67, 77.7) {} }; struct MyPredicateQuery : InitTerm {}; struct MyRegExpTerm : InitTerm {}; +struct MyNearestNeighborTerm : NearestNeighborTerm {}; struct MyQueryNodeTypes { typedef MyAnd And; @@ -78,6 +79,7 @@ struct MyQueryNodeTypes { typedef MyWandTerm WandTerm; typedef MyPredicateQuery PredicateQuery; typedef MyRegExpTerm RegExpTerm; + typedef MyNearestNeighborTerm NearestNeighborTerm; }; class MyCustomVisitor : public CustomTypeVisitor @@ -113,6 +115,7 @@ public: void visit(MyWandTerm &) override { setVisited(); } void visit(MyPredicateQuery &) override { setVisited(); } void visit(MyRegExpTerm &) override { setVisited(); } + void visit(MyNearestNeighborTerm &) override { setVisited(); } }; template diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp index f8922c54a4e..39e381c0942 100644 --- a/searchlib/src/tests/query/query_visitor_test.cpp +++ b/searchlib/src/tests/query/query_visitor_test.cpp @@ -65,6 +65,7 @@ public: void visit(WandTerm &) override { isVisited() = true; } void visit(PredicateQuery &) override { isVisited() = true; } void visit(RegExpTerm &) override { isVisited() = true; } + void visit(NearestNeighborTerm &) override { isVisited() = true; } }; template @@ -98,6 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() { checkVisit(new SimpleSuffixTerm("t", "field", 0, Weight(0))); checkVisit(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0))); checkVisit(new SimpleRegExpTerm("t", "field", 0, Weight(0))); + checkVisit(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123)); } } // namespace diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp index 6673a107d44..d47fb071d81 100644 --- a/searchlib/src/tests/query/querybuilder_test.cpp +++ b/searchlib/src/tests/query/querybuilder_test.cpp @@ -48,7 +48,7 @@ PredicateQueryTerm::UP getPredicateQueryTerm() { template Node::UP createQueryTree() { QueryBuilder builder; - builder.addAnd(10); + builder.addAnd(11); { builder.addRank(2); { @@ -111,6 +111,7 @@ Node::UP createQueryTree() { builder.addStringTerm(str[5], view[5], id[5], weight[6]); builder.addStringTerm(str[6], view[6], id[6], weight[7]); } + builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7); } Node::UP node = builder.build(); ASSERT_TRUE(node.get()); @@ -140,6 +141,15 @@ bool checkTerm(const Term *term, const typename Term::Type &t, const string &f, EXPECT_EQUAL(use_position_data, term->usePositionData())); } +template +NodeType* +as_node(Node* node) +{ + auto* result = dynamic_cast(node); + ASSERT_TRUE(result != nullptr); + return result; +} + template void checkQueryTreeTypes(Node *node) { typedef typename NodeTypes::And And; @@ -166,126 +176,114 @@ void checkQueryTreeTypes(Node *node) { typedef typename NodeTypes::RegExpTerm RegExpTerm; ASSERT_TRUE(node); - And *and_node = dynamic_cast(node); - ASSERT_TRUE(and_node); - EXPECT_EQUAL(10u, and_node->getChildren().size()); + auto* and_node = as_node(node); + EXPECT_EQUAL(11u, and_node->getChildren().size()); - - Rank *rank = dynamic_cast(and_node->getChildren()[0]); - ASSERT_TRUE(rank); + auto* rank = as_node(and_node->getChildren()[0]); EXPECT_EQUAL(2u, rank->getChildren().size()); - Near *near = dynamic_cast(rank->getChildren()[0]); - ASSERT_TRUE(near); + auto* near = as_node(rank->getChildren()[0]); EXPECT_EQUAL(2u, near->getChildren().size()); EXPECT_EQUAL(distance, near->getDistance()); - StringTerm *string_term = dynamic_cast(near->getChildren()[0]); + auto* string_term = as_node(near->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[0], view[0], id[0], weight[0])); - SubstringTerm *substring_term = dynamic_cast(near->getChildren()[1]); + auto* substring_term = as_node(near->getChildren()[1]); EXPECT_TRUE(checkTerm(substring_term, str[1], view[1], id[1], weight[1])); - ONear *onear = dynamic_cast(rank->getChildren()[1]); - ASSERT_TRUE(onear); + auto* onear = as_node(rank->getChildren()[1]); EXPECT_EQUAL(2u, onear->getChildren().size()); EXPECT_EQUAL(distance, onear->getDistance()); - SuffixTerm *suffix_term = dynamic_cast(onear->getChildren()[0]); + auto* suffix_term = as_node(onear->getChildren()[0]); EXPECT_TRUE(checkTerm(suffix_term, str[2], view[2], id[2], weight[2])); - PrefixTerm *prefix_term = dynamic_cast(onear->getChildren()[1]); + auto* prefix_term = as_node(onear->getChildren()[1]); EXPECT_TRUE(checkTerm(prefix_term, str[3], view[3], id[3], weight[3])); - - Or *or_node = dynamic_cast(and_node->getChildren()[1]); - ASSERT_TRUE(or_node); + auto* or_node = as_node(and_node->getChildren()[1]); EXPECT_EQUAL(3u, or_node->getChildren().size()); - Phrase *phrase = dynamic_cast(or_node->getChildren()[0]); - ASSERT_TRUE(phrase); + auto* phrase = as_node(or_node->getChildren()[0]); EXPECT_TRUE(phrase->isRanked()); EXPECT_EQUAL(weight[4].percent(), phrase->getWeight().percent()); EXPECT_EQUAL(3u, phrase->getChildren().size()); - string_term = dynamic_cast(phrase->getChildren()[0]); + string_term = as_node(phrase->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast(phrase->getChildren()[1]); + string_term = as_node(phrase->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[4])); - string_term = dynamic_cast(phrase->getChildren()[2]); + string_term = as_node(phrase->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[6], view[6], id[6], weight[4])); - phrase = dynamic_cast(or_node->getChildren()[1]); - ASSERT_TRUE(phrase); + phrase = as_node(or_node->getChildren()[1]); EXPECT_TRUE(!phrase->isRanked()); EXPECT_EQUAL(weight[4].percent(), phrase->getWeight().percent()); EXPECT_EQUAL(2u, phrase->getChildren().size()); - string_term = dynamic_cast(phrase->getChildren()[0]); + string_term = as_node(phrase->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast(phrase->getChildren()[1]); + string_term = as_node(phrase->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[4])); - AndNot *and_not = dynamic_cast(or_node->getChildren()[2]); - ASSERT_TRUE(and_not); + auto* and_not = as_node(or_node->getChildren()[2]); EXPECT_EQUAL(2u, and_not->getChildren().size()); - NumberTerm *integer_term = dynamic_cast(and_not->getChildren()[0]); + auto* integer_term = as_node(and_not->getChildren()[0]); EXPECT_TRUE(checkTerm(integer_term, int1, view[7], id[7], weight[7])); - NumberTerm *float_term = dynamic_cast(and_not->getChildren()[1]); + auto* float_term = as_node(and_not->getChildren()[1]); EXPECT_TRUE(checkTerm(float_term, float1, view[8], id[8], weight[8], false)); - - RangeTerm *range_term = dynamic_cast(and_node->getChildren()[2]); - ASSERT_TRUE(range_term); + auto* range_term = as_node(and_node->getChildren()[2]); EXPECT_TRUE(checkTerm(range_term, range, view[9], id[9], weight[9])); - LocationTerm *loc_term = dynamic_cast(and_node->getChildren()[3]); - ASSERT_TRUE(loc_term); + auto* loc_term = as_node(and_node->getChildren()[3]); EXPECT_TRUE(checkTerm(loc_term, location, view[10], id[10], weight[10])); - WeakAnd *wand = dynamic_cast(and_node->getChildren()[4]); - ASSERT_TRUE(wand != 0); + auto* wand = as_node(and_node->getChildren()[4]); EXPECT_EQUAL(123u, wand->getMinHits()); EXPECT_EQUAL(2u, wand->getChildren().size()); - string_term = dynamic_cast(wand->getChildren()[0]); + string_term = as_node(wand->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast(wand->getChildren()[1]); + string_term = as_node(wand->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[5])); - PredicateQuery *predicateQuery = dynamic_cast(and_node->getChildren()[5]); - ASSERT_TRUE(predicateQuery); + auto* predicateQuery = as_node(and_node->getChildren()[5]); PredicateQueryTerm::UP pqt(new PredicateQueryTerm); EXPECT_TRUE(checkTerm(predicateQuery, getPredicateQueryTerm(), view[3], id[3], weight[3])); - DotProduct *dotProduct = dynamic_cast(and_node->getChildren()[6]); - ASSERT_TRUE(dotProduct); + auto* dotProduct = as_node(and_node->getChildren()[6]); EXPECT_EQUAL(3u, dotProduct->getChildren().size()); - string_term = dynamic_cast(dotProduct->getChildren()[0]); + string_term = as_node(dotProduct->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[3], view[3], id[3], weight[3])); - string_term = dynamic_cast(dotProduct->getChildren()[1]); + string_term = as_node(dotProduct->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast(dotProduct->getChildren()[2]); + string_term = as_node(dotProduct->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[5])); - WandTerm *wandTerm = dynamic_cast(and_node->getChildren()[7]); - ASSERT_TRUE(wandTerm); + auto* wandTerm = as_node(and_node->getChildren()[7]); EXPECT_EQUAL(57u, wandTerm->getTargetNumHits()); EXPECT_EQUAL(67, wandTerm->getScoreThreshold()); EXPECT_EQUAL(77.7, wandTerm->getThresholdBoostFactor()); EXPECT_EQUAL(2u, wandTerm->getChildren().size()); - string_term = dynamic_cast(wandTerm->getChildren()[0]); + string_term = as_node(wandTerm->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[1], view[1], id[1], weight[1])); - string_term = dynamic_cast(wandTerm->getChildren()[1]); + string_term = as_node(wandTerm->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[2], view[2], id[2], weight[2])); - RegExpTerm *regexp_term = dynamic_cast(and_node->getChildren()[8]); + auto* regexp_term = as_node(and_node->getChildren()[8]); EXPECT_TRUE(checkTerm(regexp_term, str[5], view[5], id[5], weight[5])); - SameElement *same = dynamic_cast(and_node->getChildren()[9]); - ASSERT_TRUE(same != nullptr); + auto* same = as_node(and_node->getChildren()[9]); EXPECT_EQUAL(view[4], same->getView()); EXPECT_EQUAL(3u, same->getChildren().size()); - string_term = dynamic_cast(same->getChildren()[0]); + string_term = as_node(same->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[5])); - string_term = dynamic_cast(same->getChildren()[1]); + string_term = as_node(same->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[6])); - string_term = dynamic_cast(same->getChildren()[2]); + string_term = as_node(same->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[6], view[6], id[6], weight[7])); + auto* nearest_neighbor = as_node(and_node->getChildren()[10]); + EXPECT_EQUAL("query_tensor", nearest_neighbor->get_query_tensor_name()); + EXPECT_EQUAL("doc_tensor", nearest_neighbor->getView()); + EXPECT_EQUAL(id[3], nearest_neighbor->getId()); + EXPECT_EQUAL(weight[5].percent(), nearest_neighbor->getWeight().percent()); + EXPECT_EQUAL(7u, nearest_neighbor->get_target_num_hits()); } struct AbstractTypes { @@ -395,6 +393,12 @@ struct MyRegExpTerm : RegExpTerm { : RegExpTerm(t, f, i, w) { } }; +struct MyNearestNeighborTerm : NearestNeighborTerm { + MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t i, Weight w, uint32_t target_num_hits) + : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits) + {} +}; struct MyQueryNodeTypes { typedef MyAnd And; @@ -419,6 +423,7 @@ struct MyQueryNodeTypes { typedef MyWandTerm WandTerm; typedef MyPredicateQuery PredicateQuery; typedef MyRegExpTerm RegExpTerm; + typedef MyNearestNeighborTerm NearestNeighborTerm; }; TEST("require that Custom Query Trees Can Be Built") { diff --git a/searchlib/src/vespa/searchlib/parsequery/parse.h b/searchlib/src/vespa/searchlib/parsequery/parse.h index 9c0e76d2441..83352b571c8 100644 --- a/searchlib/src/vespa/searchlib/parsequery/parse.h +++ b/searchlib/src/vespa/searchlib/parsequery/parse.h @@ -60,7 +60,8 @@ public: ITEM_PREDICATE_QUERY = 23, ITEM_REGEXP = 24, ITEM_WORD_ALTERNATIVES = 25, - ITEM_MAX = 26, // Indicates how long tables must be. + ITEM_NEAREST_NEIGHBOR = 26, + ITEM_MAX = 27, // Indicates how long tables must be. ITEM_UNDEF = 31, }; diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp index a70fe07cf81..70a3097ae05 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp @@ -270,6 +270,17 @@ SimpleQueryStackDumpIterator::next() } break; + case ParseItem::ITEM_NEAREST_NEIGHBOR: + try { + _curr_index_name = read_stringref(p); + _curr_term = read_stringref(p); // query_tensor_name + _currArg1 = readCompressedPositiveInt(p); // target_num_hits; + _currArity = 0; + } catch (...) { + return false; + } + break; + default: // Unknown item, so report that no more are available return false; diff --git a/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h b/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h index cdeebcaf9e5..3882bc41b2b 100644 --- a/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h +++ b/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h @@ -49,6 +49,7 @@ public: virtual void visit(typename NodeTypes::WandTerm &) = 0; virtual void visit(typename NodeTypes::PredicateQuery &) = 0; virtual void visit(typename NodeTypes::RegExpTerm &) = 0; + virtual void visit(typename NodeTypes::NearestNeighborTerm &) = 0; private: // Route QueryVisit requests to the correct custom type. @@ -75,6 +76,7 @@ private: typedef typename NodeTypes::WandTerm TWandTerm; typedef typename NodeTypes::PredicateQuery TPredicateQuery; typedef typename NodeTypes::RegExpTerm TRegExpTerm; + typedef typename NodeTypes::NearestNeighborTerm TNearestNeighborTerm; void visit(And &n) override { visit(static_cast(n)); } void visit(AndNot &n) override { visit(static_cast(n)); } @@ -98,6 +100,7 @@ private: void visit(WandTerm &n) override { visit(static_cast(n)); } void visit(PredicateQuery &n) override { visit(static_cast(n)); } void visit(RegExpTerm &n) override { visit(static_cast(n)); } + void visit(NearestNeighborTerm &n) override { visit(static_cast(n)); } }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h index a2ad8eae84b..797defc39f5 100644 --- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h +++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h @@ -202,6 +202,13 @@ createRegExpTerm(vespalib::stringref term, vespalib::stringref view, int32_t id, return new typename NodeTypes::RegExpTerm(term, view, id, weight); } +template +typename NodeTypes::NearestNeighborTerm * +create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) { + return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits); +} + template class QueryBuilder : public QueryBuilderBase { template @@ -309,6 +316,11 @@ public: adjustWeight(weight); return addTerm(createRegExpTerm(term, view, id, weight)); } + typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) { + adjustWeight(weight); + return addTerm(create_nearest_neighbor_term(query_tensor_name, field_name, id, weight, target_num_hits)); + } }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h index e7c3fd8c73b..d2249a53f18 100644 --- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h +++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h @@ -163,6 +163,11 @@ private: node.getTerm(), node.getView(), node.getId(), node.getWeight())); } + + void visit(NearestNeighborTerm &node) override { + replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(), + node.getId(), node.getWeight(), node.get_target_num_hits())); + } }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h b/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h index 0cb56f9127a..533e240e088 100644 --- a/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h +++ b/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h @@ -26,6 +26,7 @@ class WandTerm; class PredicateQuery; class RegExpTerm; class SameElement; +class NearestNeighborTerm; struct QueryVisitor { virtual ~QueryVisitor() {} @@ -52,6 +53,7 @@ struct QueryVisitor { virtual void visit(WandTerm &) = 0; virtual void visit(PredicateQuery &) = 0; virtual void visit(RegExpTerm &) = 0; + virtual void visit(NearestNeighborTerm &) = 0; }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h index 707ed2aa0db..8663bede4d6 100644 --- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h +++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h @@ -103,31 +103,38 @@ struct SimpleRegExpTerm : RegExpTerm { : RegExpTerm(term, view, id, weight) { } }; +struct SimpleNearestNeighborTerm : NearestNeighborTerm { + SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) + : NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits) + {} +}; struct SimpleQueryNodeTypes { - typedef SimpleAnd And; - typedef SimpleAndNot AndNot; - typedef SimpleEquiv Equiv; - typedef SimpleNumberTerm NumberTerm; - typedef SimpleLocationTerm LocationTerm; - typedef SimpleNear Near; - typedef SimpleONear ONear; - typedef SimpleOr Or; - typedef SimplePhrase Phrase; - typedef SimpleSameElement SameElement; - typedef SimplePrefixTerm PrefixTerm; - typedef SimpleRangeTerm RangeTerm; - typedef SimpleRank Rank; - typedef SimpleStringTerm StringTerm; - typedef SimpleSubstringTerm SubstringTerm; - typedef SimpleSuffixTerm SuffixTerm; - typedef SimpleWeakAnd WeakAnd; - typedef SimpleWeightedSetTerm WeightedSetTerm; - typedef SimpleDotProduct DotProduct; - typedef SimpleWandTerm WandTerm; - typedef SimplePredicateQuery PredicateQuery; - typedef SimpleRegExpTerm RegExpTerm; + using And = SimpleAnd; + using AndNot = SimpleAndNot; + using Equiv = SimpleEquiv; + using NumberTerm = SimpleNumberTerm; + using LocationTerm = SimpleLocationTerm; + using Near = SimpleNear; + using ONear = SimpleONear; + using Or = SimpleOr; + using Phrase = SimplePhrase; + using SameElement = SimpleSameElement; + using PrefixTerm = SimplePrefixTerm; + using RangeTerm = SimpleRangeTerm; + using Rank = SimpleRank; + using StringTerm = SimpleStringTerm; + using SubstringTerm = SimpleSubstringTerm; + using SuffixTerm = SimpleSuffixTerm; + using WeakAnd = SimpleWeakAnd; + using WeightedSetTerm = SimpleWeightedSetTerm; + using DotProduct = SimpleDotProduct; + using WandTerm = SimpleWandTerm; + using PredicateQuery = SimplePredicateQuery; + using RegExpTerm = SimpleRegExpTerm; + using NearestNeighborTerm = SimpleNearestNeighborTerm; }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp index 645750b8576..63acf532144 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp @@ -196,8 +196,7 @@ class QueryNodeConverter : public QueryVisitor { template void appendTerm(const TermBase &node); - template - void createTerm(const Term &node, size_t type) { + void createTermNode(const TermNode &node, size_t type) { uint8_t typefield = type | ParseItem::IF_WEIGHT | ParseItem::IF_UNIQUEID; uint8_t flags = 0; if (!node.isRanked()) { @@ -216,6 +215,11 @@ class QueryNodeConverter : public QueryVisitor { appendByte(flags); } appendString(node.getView()); + } + + template + void createTerm(const Term &node, size_t type) { + createTermNode(node, type); appendTerm(node); } @@ -255,6 +259,12 @@ class QueryNodeConverter : public QueryVisitor { createTerm(node, ParseItem::ITEM_REGEXP); } + void visit(NearestNeighborTerm &node) override { + createTermNode(node, ParseItem::ITEM_NEAREST_NEIGHBOR); + appendString(node.get_query_tensor_name()); + appendCompressedPositiveNumber(node.get_target_num_hits()); + } + public: QueryNodeConverter() : _buf(4096) diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h index dfb0c75a695..a5f25d81400 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h @@ -109,6 +109,13 @@ private: pureTermView = vespalib::stringref(); } else if (type == ParseItem::ITEM_NOT) { builder.addAndNot(arity); + } else if (type == ParseItem::ITEM_NEAREST_NEIGHBOR) { + vespalib::stringref query_tensor_name = queryStack.getTerm(); + vespalib::stringref field_name = queryStack.getIndexName(); + uint32_t target_num_hits = queryStack.getArg1(); + int32_t id = queryStack.getUniqueId(); + Weight weight = queryStack.GetWeight(); + builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, target_num_hits); } else { vespalib::stringref term = queryStack.getTerm(); vespalib::stringref view = queryStack.getIndexName(); diff --git a/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h b/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h index 0cdaca82572..d1abc816838 100644 --- a/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h +++ b/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h @@ -31,6 +31,7 @@ class TemplateTermVisitor : public CustomTypeTermVisitor { void visit(typename NodeTypes::SuffixTerm &n) override { myVisit(n); } void visit(typename NodeTypes::PredicateQuery &n) override { myVisit(n); } void visit(typename NodeTypes::RegExpTerm &n) override { myVisit(n); } + void visit(typename NodeTypes::NearestNeighborTerm &n) override { myVisit(n); } // Phrases are terms with children. This visitor will not visit // the phrase's children, unless this member function is diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h index 35c23dde985..a82b1e14d76 100644 --- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h +++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h @@ -113,5 +113,33 @@ public: virtual ~RegExpTerm() = 0; }; +/** + * Term matching the K nearest neighbors in a multi-dimensional vector space. + * + * The query point is specified as a dense tensor of order 1. + * This is found in fef::IQueryEnvironment using the query tensor name as key. + * The field name is the name of a dense document tensor of order 1. + * Both tensors are validated to have the same tensor type before the query is sent to the backend. + * + * Target num hits (K) is a hint to how many neighbors to return. + * The actual returned number might be higher (or lower if the query returns fewer hits). + */ +class NearestNeighborTerm : public QueryNodeMixin { +private: + vespalib::string _query_tensor_name; + uint32_t _target_num_hits; + +public: + NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) + : QueryNodeMixinType(field_name, id, weight), + _query_tensor_name(query_tensor_name), + _target_num_hits(target_num_hits) + {} + virtual ~NearestNeighborTerm() {} + const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; } + uint32_t get_target_num_hits() const { return _target_num_hits; } +}; + } diff --git a/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.cpp b/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.cpp index 3731b2ff6a8..22a1c2166ea 100644 --- a/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.cpp @@ -90,4 +90,12 @@ CreateBlueprintVisitorHelper::visitWandTerm(query::WandTerm &n) { n); } +void +CreateBlueprintVisitorHelper::visitNearestNeighborTerm(query::NearestNeighborTerm &n) +{ + (void) n; + // TODO (geirst): implement + setResult(std::make_unique()); +} + } diff --git a/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h b/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h index 84830111fde..ae1f8938a27 100644 --- a/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h +++ b/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h @@ -41,6 +41,7 @@ public: void visitWeightedSetTerm(query::WeightedSetTerm &n); void visitDotProduct(query::DotProduct &n); void visitWandTerm(query::WandTerm &n); + void visitNearestNeighborTerm(query::NearestNeighborTerm &n); void handleNumberTermAsText(query::NumberTerm &n); @@ -62,6 +63,7 @@ public: void visit(query::WeightedSetTerm &n) override { visitWeightedSetTerm(n); } void visit(query::DotProduct &n) override { visitDotProduct(n); } void visit(query::WandTerm &n) override { visitWandTerm(n); } + void visit(query::NearestNeighborTerm &n) override { visitNearestNeighborTerm(n); } void visit(query::NumberTerm &n) override = 0; void visit(query::LocationTerm &n) override = 0; diff --git a/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp b/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp index 3829ea45e2b..7a97110713d 100644 --- a/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp @@ -14,28 +14,29 @@ LOG_SETUP(".termasstring"); using search::query::And; using search::query::AndNot; +using search::query::DotProduct; using search::query::Equiv; -using search::query::NumberTerm; using search::query::LocationTerm; using search::query::Near; +using search::query::NearestNeighborTerm; using search::query::Node; +using search::query::NumberTerm; using search::query::ONear; using search::query::Or; using search::query::Phrase; -using search::query::SameElement; using search::query::PredicateQuery; using search::query::PrefixTerm; using search::query::QueryVisitor; using search::query::RangeTerm; using search::query::Rank; using search::query::RegExpTerm; +using search::query::SameElement; using search::query::StringTerm; using search::query::SubstringTerm; using search::query::SuffixTerm; +using search::query::WandTerm; using search::query::WeakAnd; using search::query::WeightedSetTerm; -using search::query::DotProduct; -using search::query::WandTerm; using vespalib::string; namespace search::queryeval { @@ -101,6 +102,7 @@ struct TermAsStringVisitor : public QueryVisitor { void visit(SuffixTerm &n) override {visitTerm(n); } void visit(RegExpTerm &n) override {visitTerm(n); } void visit(PredicateQuery &) override {illegalVisit(); } + void visit(NearestNeighborTerm &) override { illegalVisit(); } }; } // namespace -- cgit v1.2.3