diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-11-19 12:37:16 +0000 |
---|---|---|
committer | Geir Storli <geirst@verizonmedia.com> | 2019-11-19 12:38:27 +0000 |
commit | 5361472f495df07e9b1ae2af91c2f780ed8966df (patch) | |
tree | 78f044702c6dbf490c73fcc675ed0625478148a5 /searchlib/src/tests/query | |
parent | e8c2faeb2c1feac0a3712592f4a55ce276d2fc60 (diff) |
Add skeleton for NearestNeighborTerm in C++.
Diffstat (limited to 'searchlib/src/tests/query')
-rw-r--r-- | searchlib/src/tests/query/customtypevisitor_test.cpp | 3 | ||||
-rw-r--r-- | searchlib/src/tests/query/query_visitor_test.cpp | 2 | ||||
-rw-r--r-- | searchlib/src/tests/query/querybuilder_test.cpp | 119 |
3 files changed, 67 insertions, 57 deletions
diff --git a/searchlib/src/tests/query/customtypevisitor_test.cpp b/searchlib/src/tests/query/customtypevisitor_test.cpp index c5eeac8543d..3f7d57b7aa4 100644 --- a/searchlib/src/tests/query/customtypevisitor_test.cpp +++ b/searchlib/src/tests/query/customtypevisitor_test.cpp @@ -54,6 +54,7 @@ struct MyDotProduct : DotProduct { MyDotProduct() : DotProduct("view", 0, Weight struct MyWandTerm : WandTerm { MyWandTerm() : WandTerm("view", 0, Weight(42), 57, 67, 77.7) {} }; struct MyPredicateQuery : InitTerm<PredicateQuery> {}; struct MyRegExpTerm : InitTerm<RegExpTerm> {}; +struct MyNearestNeighborTerm : NearestNeighborTerm {}; struct MyQueryNodeTypes { typedef MyAnd And; @@ -78,6 +79,7 @@ struct MyQueryNodeTypes { typedef MyWandTerm WandTerm; typedef MyPredicateQuery PredicateQuery; typedef MyRegExpTerm RegExpTerm; + typedef MyNearestNeighborTerm NearestNeighborTerm; }; class MyCustomVisitor : public CustomTypeVisitor<MyQueryNodeTypes> @@ -113,6 +115,7 @@ public: void visit(MyWandTerm &) override { setVisited<MyWandTerm>(); } void visit(MyPredicateQuery &) override { setVisited<MyPredicateQuery>(); } void visit(MyRegExpTerm &) override { setVisited<MyRegExpTerm>(); } + void visit(MyNearestNeighborTerm &) override { setVisited<MyNearestNeighborTerm>(); } }; template <class T> diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp index f8922c54a4e..39e381c0942 100644 --- a/searchlib/src/tests/query/query_visitor_test.cpp +++ b/searchlib/src/tests/query/query_visitor_test.cpp @@ -65,6 +65,7 @@ public: void visit(WandTerm &) override { isVisited<WandTerm>() = true; } void visit(PredicateQuery &) override { isVisited<PredicateQuery>() = true; } void visit(RegExpTerm &) override { isVisited<RegExpTerm>() = true; } + void visit(NearestNeighborTerm &) override { isVisited<NearestNeighborTerm>() = true; } }; template <class T> @@ -98,6 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() { checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0))); checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0))); checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0))); + checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123)); } } // namespace diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp index 6673a107d44..d47fb071d81 100644 --- a/searchlib/src/tests/query/querybuilder_test.cpp +++ b/searchlib/src/tests/query/querybuilder_test.cpp @@ -48,7 +48,7 @@ PredicateQueryTerm::UP getPredicateQueryTerm() { template <class NodeTypes> Node::UP createQueryTree() { QueryBuilder<NodeTypes> builder; - builder.addAnd(10); + builder.addAnd(11); { builder.addRank(2); { @@ -111,6 +111,7 @@ Node::UP createQueryTree() { builder.addStringTerm(str[5], view[5], id[5], weight[6]); builder.addStringTerm(str[6], view[6], id[6], weight[7]); } + builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7); } Node::UP node = builder.build(); ASSERT_TRUE(node.get()); @@ -140,6 +141,15 @@ bool checkTerm(const Term *term, const typename Term::Type &t, const string &f, EXPECT_EQUAL(use_position_data, term->usePositionData())); } +template <class NodeType> +NodeType* +as_node(Node* node) +{ + auto* result = dynamic_cast<NodeType*>(node); + ASSERT_TRUE(result != nullptr); + return result; +} + template <class NodeTypes> void checkQueryTreeTypes(Node *node) { typedef typename NodeTypes::And And; @@ -166,126 +176,114 @@ void checkQueryTreeTypes(Node *node) { typedef typename NodeTypes::RegExpTerm RegExpTerm; ASSERT_TRUE(node); - And *and_node = dynamic_cast<And *>(node); - ASSERT_TRUE(and_node); - EXPECT_EQUAL(10u, and_node->getChildren().size()); + auto* and_node = as_node<And>(node); + EXPECT_EQUAL(11u, and_node->getChildren().size()); - - Rank *rank = dynamic_cast<Rank *>(and_node->getChildren()[0]); - ASSERT_TRUE(rank); + auto* rank = as_node<Rank>(and_node->getChildren()[0]); EXPECT_EQUAL(2u, rank->getChildren().size()); - Near *near = dynamic_cast<Near *>(rank->getChildren()[0]); - ASSERT_TRUE(near); + auto* near = as_node<Near>(rank->getChildren()[0]); EXPECT_EQUAL(2u, near->getChildren().size()); EXPECT_EQUAL(distance, near->getDistance()); - StringTerm *string_term = dynamic_cast<StringTerm *>(near->getChildren()[0]); + auto* string_term = as_node<StringTerm>(near->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[0], view[0], id[0], weight[0])); - SubstringTerm *substring_term = dynamic_cast<SubstringTerm *>(near->getChildren()[1]); + auto* substring_term = as_node<SubstringTerm>(near->getChildren()[1]); EXPECT_TRUE(checkTerm(substring_term, str[1], view[1], id[1], weight[1])); - ONear *onear = dynamic_cast<ONear *>(rank->getChildren()[1]); - ASSERT_TRUE(onear); + auto* onear = as_node<ONear>(rank->getChildren()[1]); EXPECT_EQUAL(2u, onear->getChildren().size()); EXPECT_EQUAL(distance, onear->getDistance()); - SuffixTerm *suffix_term = dynamic_cast<SuffixTerm *>(onear->getChildren()[0]); + auto* suffix_term = as_node<SuffixTerm>(onear->getChildren()[0]); EXPECT_TRUE(checkTerm(suffix_term, str[2], view[2], id[2], weight[2])); - PrefixTerm *prefix_term = dynamic_cast<PrefixTerm *>(onear->getChildren()[1]); + auto* prefix_term = as_node<PrefixTerm>(onear->getChildren()[1]); EXPECT_TRUE(checkTerm(prefix_term, str[3], view[3], id[3], weight[3])); - - Or *or_node = dynamic_cast<Or *>(and_node->getChildren()[1]); - ASSERT_TRUE(or_node); + auto* or_node = as_node<Or>(and_node->getChildren()[1]); EXPECT_EQUAL(3u, or_node->getChildren().size()); - Phrase *phrase = dynamic_cast<Phrase *>(or_node->getChildren()[0]); - ASSERT_TRUE(phrase); + auto* phrase = as_node<Phrase>(or_node->getChildren()[0]); EXPECT_TRUE(phrase->isRanked()); EXPECT_EQUAL(weight[4].percent(), phrase->getWeight().percent()); EXPECT_EQUAL(3u, phrase->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[0]); + string_term = as_node<StringTerm>(phrase->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[1]); + string_term = as_node<StringTerm>(phrase->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[4])); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[2]); + string_term = as_node<StringTerm>(phrase->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[6], view[6], id[6], weight[4])); - phrase = dynamic_cast<Phrase *>(or_node->getChildren()[1]); - ASSERT_TRUE(phrase); + phrase = as_node<Phrase>(or_node->getChildren()[1]); EXPECT_TRUE(!phrase->isRanked()); EXPECT_EQUAL(weight[4].percent(), phrase->getWeight().percent()); EXPECT_EQUAL(2u, phrase->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[0]); + string_term = as_node<StringTerm>(phrase->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[1]); + string_term = as_node<StringTerm>(phrase->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[4])); - AndNot *and_not = dynamic_cast<AndNot *>(or_node->getChildren()[2]); - ASSERT_TRUE(and_not); + auto* and_not = as_node<AndNot>(or_node->getChildren()[2]); EXPECT_EQUAL(2u, and_not->getChildren().size()); - NumberTerm *integer_term = dynamic_cast<NumberTerm *>(and_not->getChildren()[0]); + auto* integer_term = as_node<NumberTerm>(and_not->getChildren()[0]); EXPECT_TRUE(checkTerm(integer_term, int1, view[7], id[7], weight[7])); - NumberTerm *float_term = dynamic_cast<NumberTerm *>(and_not->getChildren()[1]); + auto* float_term = as_node<NumberTerm>(and_not->getChildren()[1]); EXPECT_TRUE(checkTerm(float_term, float1, view[8], id[8], weight[8], false)); - - RangeTerm *range_term = dynamic_cast<RangeTerm *>(and_node->getChildren()[2]); - ASSERT_TRUE(range_term); + auto* range_term = as_node<RangeTerm>(and_node->getChildren()[2]); EXPECT_TRUE(checkTerm(range_term, range, view[9], id[9], weight[9])); - LocationTerm *loc_term = dynamic_cast<LocationTerm *>(and_node->getChildren()[3]); - ASSERT_TRUE(loc_term); + auto* loc_term = as_node<LocationTerm>(and_node->getChildren()[3]); EXPECT_TRUE(checkTerm(loc_term, location, view[10], id[10], weight[10])); - WeakAnd *wand = dynamic_cast<WeakAnd *>(and_node->getChildren()[4]); - ASSERT_TRUE(wand != 0); + auto* wand = as_node<WeakAnd>(and_node->getChildren()[4]); EXPECT_EQUAL(123u, wand->getMinHits()); EXPECT_EQUAL(2u, wand->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(wand->getChildren()[0]); + string_term = as_node<StringTerm>(wand->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(wand->getChildren()[1]); + string_term = as_node<StringTerm>(wand->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[5])); - PredicateQuery *predicateQuery = dynamic_cast<PredicateQuery *>(and_node->getChildren()[5]); - ASSERT_TRUE(predicateQuery); + auto* predicateQuery = as_node<PredicateQuery>(and_node->getChildren()[5]); PredicateQueryTerm::UP pqt(new PredicateQueryTerm); EXPECT_TRUE(checkTerm(predicateQuery, getPredicateQueryTerm(), view[3], id[3], weight[3])); - DotProduct *dotProduct = dynamic_cast<DotProduct *>(and_node->getChildren()[6]); - ASSERT_TRUE(dotProduct); + auto* dotProduct = as_node<DotProduct>(and_node->getChildren()[6]); EXPECT_EQUAL(3u, dotProduct->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[0]); + string_term = as_node<StringTerm>(dotProduct->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[3], view[3], id[3], weight[3])); - string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[1]); + string_term = as_node<StringTerm>(dotProduct->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[2]); + string_term = as_node<StringTerm>(dotProduct->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[5])); - WandTerm *wandTerm = dynamic_cast<WandTerm *>(and_node->getChildren()[7]); - ASSERT_TRUE(wandTerm); + auto* wandTerm = as_node<WandTerm>(and_node->getChildren()[7]); EXPECT_EQUAL(57u, wandTerm->getTargetNumHits()); EXPECT_EQUAL(67, wandTerm->getScoreThreshold()); EXPECT_EQUAL(77.7, wandTerm->getThresholdBoostFactor()); EXPECT_EQUAL(2u, wandTerm->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(wandTerm->getChildren()[0]); + string_term = as_node<StringTerm>(wandTerm->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[1], view[1], id[1], weight[1])); - string_term = dynamic_cast<StringTerm *>(wandTerm->getChildren()[1]); + string_term = as_node<StringTerm>(wandTerm->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[2], view[2], id[2], weight[2])); - RegExpTerm *regexp_term = dynamic_cast<RegExpTerm *>(and_node->getChildren()[8]); + auto* regexp_term = as_node<RegExpTerm>(and_node->getChildren()[8]); EXPECT_TRUE(checkTerm(regexp_term, str[5], view[5], id[5], weight[5])); - SameElement *same = dynamic_cast<SameElement *>(and_node->getChildren()[9]); - ASSERT_TRUE(same != nullptr); + auto* same = as_node<SameElement>(and_node->getChildren()[9]); EXPECT_EQUAL(view[4], same->getView()); EXPECT_EQUAL(3u, same->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(same->getChildren()[0]); + string_term = as_node<StringTerm>(same->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[5])); - string_term = dynamic_cast<StringTerm *>(same->getChildren()[1]); + string_term = as_node<StringTerm>(same->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[6])); - string_term = dynamic_cast<StringTerm *>(same->getChildren()[2]); + string_term = as_node<StringTerm>(same->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[6], view[6], id[6], weight[7])); + auto* nearest_neighbor = as_node<NearestNeighborTerm>(and_node->getChildren()[10]); + EXPECT_EQUAL("query_tensor", nearest_neighbor->get_query_tensor_name()); + EXPECT_EQUAL("doc_tensor", nearest_neighbor->getView()); + EXPECT_EQUAL(id[3], nearest_neighbor->getId()); + EXPECT_EQUAL(weight[5].percent(), nearest_neighbor->getWeight().percent()); + EXPECT_EQUAL(7u, nearest_neighbor->get_target_num_hits()); } struct AbstractTypes { @@ -395,6 +393,12 @@ struct MyRegExpTerm : RegExpTerm { : RegExpTerm(t, f, i, w) { } }; +struct MyNearestNeighborTerm : NearestNeighborTerm { + MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t i, Weight w, uint32_t target_num_hits) + : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits) + {} +}; struct MyQueryNodeTypes { typedef MyAnd And; @@ -419,6 +423,7 @@ struct MyQueryNodeTypes { typedef MyWandTerm WandTerm; typedef MyPredicateQuery PredicateQuery; typedef MyRegExpTerm RegExpTerm; + typedef MyNearestNeighborTerm NearestNeighborTerm; }; TEST("require that Custom Query Trees Can Be Built") { |