summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/query
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-11-19 12:37:16 +0000
committerGeir Storli <geirst@verizonmedia.com>2019-11-19 12:38:27 +0000
commit5361472f495df07e9b1ae2af91c2f780ed8966df (patch)
tree78f044702c6dbf490c73fcc675ed0625478148a5 /searchlib/src/tests/query
parente8c2faeb2c1feac0a3712592f4a55ce276d2fc60 (diff)
Add skeleton for NearestNeighborTerm in C++.
Diffstat (limited to 'searchlib/src/tests/query')
-rw-r--r--searchlib/src/tests/query/customtypevisitor_test.cpp3
-rw-r--r--searchlib/src/tests/query/query_visitor_test.cpp2
-rw-r--r--searchlib/src/tests/query/querybuilder_test.cpp119
3 files changed, 67 insertions, 57 deletions
diff --git a/searchlib/src/tests/query/customtypevisitor_test.cpp b/searchlib/src/tests/query/customtypevisitor_test.cpp
index c5eeac8543d..3f7d57b7aa4 100644
--- a/searchlib/src/tests/query/customtypevisitor_test.cpp
+++ b/searchlib/src/tests/query/customtypevisitor_test.cpp
@@ -54,6 +54,7 @@ struct MyDotProduct : DotProduct { MyDotProduct() : DotProduct("view", 0, Weight
struct MyWandTerm : WandTerm { MyWandTerm() : WandTerm("view", 0, Weight(42), 57, 67, 77.7) {} };
struct MyPredicateQuery : InitTerm<PredicateQuery> {};
struct MyRegExpTerm : InitTerm<RegExpTerm> {};
+struct MyNearestNeighborTerm : NearestNeighborTerm {};
struct MyQueryNodeTypes {
typedef MyAnd And;
@@ -78,6 +79,7 @@ struct MyQueryNodeTypes {
typedef MyWandTerm WandTerm;
typedef MyPredicateQuery PredicateQuery;
typedef MyRegExpTerm RegExpTerm;
+ typedef MyNearestNeighborTerm NearestNeighborTerm;
};
class MyCustomVisitor : public CustomTypeVisitor<MyQueryNodeTypes>
@@ -113,6 +115,7 @@ public:
void visit(MyWandTerm &) override { setVisited<MyWandTerm>(); }
void visit(MyPredicateQuery &) override { setVisited<MyPredicateQuery>(); }
void visit(MyRegExpTerm &) override { setVisited<MyRegExpTerm>(); }
+ void visit(MyNearestNeighborTerm &) override { setVisited<MyNearestNeighborTerm>(); }
};
template <class T>
diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp
index f8922c54a4e..39e381c0942 100644
--- a/searchlib/src/tests/query/query_visitor_test.cpp
+++ b/searchlib/src/tests/query/query_visitor_test.cpp
@@ -65,6 +65,7 @@ public:
void visit(WandTerm &) override { isVisited<WandTerm>() = true; }
void visit(PredicateQuery &) override { isVisited<PredicateQuery>() = true; }
void visit(RegExpTerm &) override { isVisited<RegExpTerm>() = true; }
+ void visit(NearestNeighborTerm &) override { isVisited<NearestNeighborTerm>() = true; }
};
template <class T>
@@ -98,6 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() {
checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0)));
checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0)));
checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0)));
+ checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123));
}
} // namespace
diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp
index 6673a107d44..d47fb071d81 100644
--- a/searchlib/src/tests/query/querybuilder_test.cpp
+++ b/searchlib/src/tests/query/querybuilder_test.cpp
@@ -48,7 +48,7 @@ PredicateQueryTerm::UP getPredicateQueryTerm() {
template <class NodeTypes>
Node::UP createQueryTree() {
QueryBuilder<NodeTypes> builder;
- builder.addAnd(10);
+ builder.addAnd(11);
{
builder.addRank(2);
{
@@ -111,6 +111,7 @@ Node::UP createQueryTree() {
builder.addStringTerm(str[5], view[5], id[5], weight[6]);
builder.addStringTerm(str[6], view[6], id[6], weight[7]);
}
+ builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7);
}
Node::UP node = builder.build();
ASSERT_TRUE(node.get());
@@ -140,6 +141,15 @@ bool checkTerm(const Term *term, const typename Term::Type &t, const string &f,
EXPECT_EQUAL(use_position_data, term->usePositionData()));
}
+template <class NodeType>
+NodeType*
+as_node(Node* node)
+{
+ auto* result = dynamic_cast<NodeType*>(node);
+ ASSERT_TRUE(result != nullptr);
+ return result;
+}
+
template <class NodeTypes>
void checkQueryTreeTypes(Node *node) {
typedef typename NodeTypes::And And;
@@ -166,126 +176,114 @@ void checkQueryTreeTypes(Node *node) {
typedef typename NodeTypes::RegExpTerm RegExpTerm;
ASSERT_TRUE(node);
- And *and_node = dynamic_cast<And *>(node);
- ASSERT_TRUE(and_node);
- EXPECT_EQUAL(10u, and_node->getChildren().size());
+ auto* and_node = as_node<And>(node);
+ EXPECT_EQUAL(11u, and_node->getChildren().size());
-
- Rank *rank = dynamic_cast<Rank *>(and_node->getChildren()[0]);
- ASSERT_TRUE(rank);
+ auto* rank = as_node<Rank>(and_node->getChildren()[0]);
EXPECT_EQUAL(2u, rank->getChildren().size());
- Near *near = dynamic_cast<Near *>(rank->getChildren()[0]);
- ASSERT_TRUE(near);
+ auto* near = as_node<Near>(rank->getChildren()[0]);
EXPECT_EQUAL(2u, near->getChildren().size());
EXPECT_EQUAL(distance, near->getDistance());
- StringTerm *string_term = dynamic_cast<StringTerm *>(near->getChildren()[0]);
+ auto* string_term = as_node<StringTerm>(near->getChildren()[0]);
EXPECT_TRUE(checkTerm(string_term, str[0], view[0], id[0], weight[0]));
- SubstringTerm *substring_term = dynamic_cast<SubstringTerm *>(near->getChildren()[1]);
+ auto* substring_term = as_node<SubstringTerm>(near->getChildren()[1]);
EXPECT_TRUE(checkTerm(substring_term, str[1], view[1], id[1], weight[1]));
- ONear *onear = dynamic_cast<ONear *>(rank->getChildren()[1]);
- ASSERT_TRUE(onear);
+ auto* onear = as_node<ONear>(rank->getChildren()[1]);
EXPECT_EQUAL(2u, onear->getChildren().size());
EXPECT_EQUAL(distance, onear->getDistance());
- SuffixTerm *suffix_term = dynamic_cast<SuffixTerm *>(onear->getChildren()[0]);
+ auto* suffix_term = as_node<SuffixTerm>(onear->getChildren()[0]);
EXPECT_TRUE(checkTerm(suffix_term, str[2], view[2], id[2], weight[2]));
- PrefixTerm *prefix_term = dynamic_cast<PrefixTerm *>(onear->getChildren()[1]);
+ auto* prefix_term = as_node<PrefixTerm>(onear->getChildren()[1]);
EXPECT_TRUE(checkTerm(prefix_term, str[3], view[3], id[3], weight[3]));
-
- Or *or_node = dynamic_cast<Or *>(and_node->getChildren()[1]);
- ASSERT_TRUE(or_node);
+ auto* or_node = as_node<Or>(and_node->getChildren()[1]);
EXPECT_EQUAL(3u, or_node->getChildren().size());
- Phrase *phrase = dynamic_cast<Phrase *>(or_node->getChildren()[0]);
- ASSERT_TRUE(phrase);
+ auto* phrase = as_node<Phrase>(or_node->getChildren()[0]);
EXPECT_TRUE(phrase->isRanked());
EXPECT_EQUAL(weight[4].percent(), phrase->getWeight().percent());
EXPECT_EQUAL(3u, phrase->getChildren().size());
- string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[0]);
+ string_term = as_node<StringTerm>(phrase->getChildren()[0]);
EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4]));
- string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[1]);
+ string_term = as_node<StringTerm>(phrase->getChildren()[1]);
EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[4]));
- string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[2]);
+ string_term = as_node<StringTerm>(phrase->getChildren()[2]);
EXPECT_TRUE(checkTerm(string_term, str[6], view[6], id[6], weight[4]));
- phrase = dynamic_cast<Phrase *>(or_node->getChildren()[1]);
- ASSERT_TRUE(phrase);
+ phrase = as_node<Phrase>(or_node->getChildren()[1]);
EXPECT_TRUE(!phrase->isRanked());
EXPECT_EQUAL(weight[4].percent(), phrase->getWeight().percent());
EXPECT_EQUAL(2u, phrase->getChildren().size());
- string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[0]);
+ string_term = as_node<StringTerm>(phrase->getChildren()[0]);
EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4]));
- string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[1]);
+ string_term = as_node<StringTerm>(phrase->getChildren()[1]);
EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[4]));
- AndNot *and_not = dynamic_cast<AndNot *>(or_node->getChildren()[2]);
- ASSERT_TRUE(and_not);
+ auto* and_not = as_node<AndNot>(or_node->getChildren()[2]);
EXPECT_EQUAL(2u, and_not->getChildren().size());
- NumberTerm *integer_term = dynamic_cast<NumberTerm *>(and_not->getChildren()[0]);
+ auto* integer_term = as_node<NumberTerm>(and_not->getChildren()[0]);
EXPECT_TRUE(checkTerm(integer_term, int1, view[7], id[7], weight[7]));
- NumberTerm *float_term = dynamic_cast<NumberTerm *>(and_not->getChildren()[1]);
+ auto* float_term = as_node<NumberTerm>(and_not->getChildren()[1]);
EXPECT_TRUE(checkTerm(float_term, float1, view[8], id[8], weight[8], false));
-
- RangeTerm *range_term = dynamic_cast<RangeTerm *>(and_node->getChildren()[2]);
- ASSERT_TRUE(range_term);
+ auto* range_term = as_node<RangeTerm>(and_node->getChildren()[2]);
EXPECT_TRUE(checkTerm(range_term, range, view[9], id[9], weight[9]));
- LocationTerm *loc_term = dynamic_cast<LocationTerm *>(and_node->getChildren()[3]);
- ASSERT_TRUE(loc_term);
+ auto* loc_term = as_node<LocationTerm>(and_node->getChildren()[3]);
EXPECT_TRUE(checkTerm(loc_term, location, view[10], id[10], weight[10]));
- WeakAnd *wand = dynamic_cast<WeakAnd *>(and_node->getChildren()[4]);
- ASSERT_TRUE(wand != 0);
+ auto* wand = as_node<WeakAnd>(and_node->getChildren()[4]);
EXPECT_EQUAL(123u, wand->getMinHits());
EXPECT_EQUAL(2u, wand->getChildren().size());
- string_term = dynamic_cast<StringTerm *>(wand->getChildren()[0]);
+ string_term = as_node<StringTerm>(wand->getChildren()[0]);
EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4]));
- string_term = dynamic_cast<StringTerm *>(wand->getChildren()[1]);
+ string_term = as_node<StringTerm>(wand->getChildren()[1]);
EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[5]));
- PredicateQuery *predicateQuery = dynamic_cast<PredicateQuery *>(and_node->getChildren()[5]);
- ASSERT_TRUE(predicateQuery);
+ auto* predicateQuery = as_node<PredicateQuery>(and_node->getChildren()[5]);
PredicateQueryTerm::UP pqt(new PredicateQueryTerm);
EXPECT_TRUE(checkTerm(predicateQuery, getPredicateQueryTerm(), view[3], id[3], weight[3]));
- DotProduct *dotProduct = dynamic_cast<DotProduct *>(and_node->getChildren()[6]);
- ASSERT_TRUE(dotProduct);
+ auto* dotProduct = as_node<DotProduct>(and_node->getChildren()[6]);
EXPECT_EQUAL(3u, dotProduct->getChildren().size());
- string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[0]);
+ string_term = as_node<StringTerm>(dotProduct->getChildren()[0]);
EXPECT_TRUE(checkTerm(string_term, str[3], view[3], id[3], weight[3]));
- string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[1]);
+ string_term = as_node<StringTerm>(dotProduct->getChildren()[1]);
EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4]));
- string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[2]);
+ string_term = as_node<StringTerm>(dotProduct->getChildren()[2]);
EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[5]));
- WandTerm *wandTerm = dynamic_cast<WandTerm *>(and_node->getChildren()[7]);
- ASSERT_TRUE(wandTerm);
+ auto* wandTerm = as_node<WandTerm>(and_node->getChildren()[7]);
EXPECT_EQUAL(57u, wandTerm->getTargetNumHits());
EXPECT_EQUAL(67, wandTerm->getScoreThreshold());
EXPECT_EQUAL(77.7, wandTerm->getThresholdBoostFactor());
EXPECT_EQUAL(2u, wandTerm->getChildren().size());
- string_term = dynamic_cast<StringTerm *>(wandTerm->getChildren()[0]);
+ string_term = as_node<StringTerm>(wandTerm->getChildren()[0]);
EXPECT_TRUE(checkTerm(string_term, str[1], view[1], id[1], weight[1]));
- string_term = dynamic_cast<StringTerm *>(wandTerm->getChildren()[1]);
+ string_term = as_node<StringTerm>(wandTerm->getChildren()[1]);
EXPECT_TRUE(checkTerm(string_term, str[2], view[2], id[2], weight[2]));
- RegExpTerm *regexp_term = dynamic_cast<RegExpTerm *>(and_node->getChildren()[8]);
+ auto* regexp_term = as_node<RegExpTerm>(and_node->getChildren()[8]);
EXPECT_TRUE(checkTerm(regexp_term, str[5], view[5], id[5], weight[5]));
- SameElement *same = dynamic_cast<SameElement *>(and_node->getChildren()[9]);
- ASSERT_TRUE(same != nullptr);
+ auto* same = as_node<SameElement>(and_node->getChildren()[9]);
EXPECT_EQUAL(view[4], same->getView());
EXPECT_EQUAL(3u, same->getChildren().size());
- string_term = dynamic_cast<StringTerm *>(same->getChildren()[0]);
+ string_term = as_node<StringTerm>(same->getChildren()[0]);
EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[5]));
- string_term = dynamic_cast<StringTerm *>(same->getChildren()[1]);
+ string_term = as_node<StringTerm>(same->getChildren()[1]);
EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[6]));
- string_term = dynamic_cast<StringTerm *>(same->getChildren()[2]);
+ string_term = as_node<StringTerm>(same->getChildren()[2]);
EXPECT_TRUE(checkTerm(string_term, str[6], view[6], id[6], weight[7]));
+ auto* nearest_neighbor = as_node<NearestNeighborTerm>(and_node->getChildren()[10]);
+ EXPECT_EQUAL("query_tensor", nearest_neighbor->get_query_tensor_name());
+ EXPECT_EQUAL("doc_tensor", nearest_neighbor->getView());
+ EXPECT_EQUAL(id[3], nearest_neighbor->getId());
+ EXPECT_EQUAL(weight[5].percent(), nearest_neighbor->getWeight().percent());
+ EXPECT_EQUAL(7u, nearest_neighbor->get_target_num_hits());
}
struct AbstractTypes {
@@ -395,6 +393,12 @@ struct MyRegExpTerm : RegExpTerm {
: RegExpTerm(t, f, i, w) {
}
};
+struct MyNearestNeighborTerm : NearestNeighborTerm {
+ MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
+ int32_t i, Weight w, uint32_t target_num_hits)
+ : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits)
+ {}
+};
struct MyQueryNodeTypes {
typedef MyAnd And;
@@ -419,6 +423,7 @@ struct MyQueryNodeTypes {
typedef MyWandTerm WandTerm;
typedef MyPredicateQuery PredicateQuery;
typedef MyRegExpTerm RegExpTerm;
+ typedef MyNearestNeighborTerm NearestNeighborTerm;
};
TEST("require that Custom Query Trees Can Be Built") {