summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp2
-rw-r--r--searchlib/src/tests/query/query_visitor_test.cpp2
-rw-r--r--searchlib/src/tests/query/querybuilder_test.cpp7
-rw-r--r--searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/query/tree/querybuilder.h12
-rw-r--r--searchlib/src/vespa/searchlib/query/tree/queryreplicator.h3
-rw-r--r--searchlib/src/vespa/searchlib/query/tree/simplequery.h6
-rw-r--r--searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp2
-rw-r--r--searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h5
-rw-r--r--searchlib/src/vespa/searchlib/query/tree/termnodes.h11
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp12
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h4
13 files changed, 52 insertions, 22 deletions
diff --git a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
index 24be21f65ec..47728c9785c 100644
--- a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
+++ b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
@@ -295,7 +295,7 @@ public:
request_ctx.set_query_tensor("query_tensor", tensor_spec);
}
Blueprint::UP create_blueprint() {
- query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7);
+ query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33);
return source.createBlueprint(request_ctx, FieldSpec(attr_name, 0, 0), term);
}
};
diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp
index 39e381c0942..edbc29be784 100644
--- a/searchlib/src/tests/query/query_visitor_test.cpp
+++ b/searchlib/src/tests/query/query_visitor_test.cpp
@@ -99,7 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() {
checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0)));
checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0)));
checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0)));
- checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123));
+ checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321));
}
} // namespace
diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp
index 7f496b3493c..8560cb0e091 100644
--- a/searchlib/src/tests/query/querybuilder_test.cpp
+++ b/searchlib/src/tests/query/querybuilder_test.cpp
@@ -111,7 +111,7 @@ Node::UP createQueryTree() {
builder.addStringTerm(str[5], view[5], id[5], weight[6]);
builder.addStringTerm(str[6], view[6], id[6], weight[7]);
}
- builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7);
+ builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33);
}
Node::UP node = builder.build();
ASSERT_TRUE(node.get());
@@ -395,8 +395,9 @@ struct MyRegExpTerm : RegExpTerm {
};
struct MyNearestNeighborTerm : NearestNeighborTerm {
MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t i, Weight w, uint32_t target_num_hits)
- : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits)
+ int32_t i, Weight w, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
+ : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits)
{}
};
diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
index 8595b0eff7f..9af05059bef 100644
--- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
@@ -646,7 +646,9 @@ public:
query_tensor.release();
setResult(std::make_unique<queryeval::NearestNeighborBlueprint>(_field, *dense_attr_tensor,
std::move(dense_query_tensor_up),
- n.get_target_num_hits()));
+ n.get_target_num_hits(),
+ n.get_allow_approximate(),
+ n.get_explore_additional_hits()));
}
};
diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
index 70a3097ae05..f0fb53a5640 100644
--- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
+++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
@@ -274,7 +274,9 @@ SimpleQueryStackDumpIterator::next()
try {
_curr_index_name = read_stringref(p);
_curr_term = read_stringref(p); // query_tensor_name
- _currArg1 = readCompressedPositiveInt(p); // target_num_hits;
+ _currArg1 = readCompressedPositiveInt(p); // target_num_hits
+ _currArg2 = readCompressedPositiveInt(p); // allow_approximate
+ _currArg3 = readCompressedPositiveInt(p); // explore_additional_hits
_currArity = 0;
} catch (...) {
return false;
diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
index 797defc39f5..8e6f2944ec9 100644
--- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
+++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
@@ -205,8 +205,11 @@ createRegExpTerm(vespalib::stringref term, vespalib::stringref view, int32_t id,
template <class NodeTypes>
typename NodeTypes::NearestNeighborTerm *
create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits) {
- return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits);
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
+{
+ return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight,
+ target_num_hits, allow_approximate, explore_additional_hits);
}
template <class NodeTypes>
@@ -317,9 +320,10 @@ public:
return addTerm(createRegExpTerm<NodeTypes>(term, view, id, weight));
}
typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits) {
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits) {
adjustWeight(weight);
- return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits));
+ return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits));
}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
index 0bf923960b9..9289df7cbe9 100644
--- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
+++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
@@ -165,7 +165,8 @@ private:
void visit(NearestNeighborTerm &node) override {
replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(),
- node.getId(), node.getWeight(), node.get_target_num_hits()));
+ node.getId(), node.getWeight(), node.get_target_num_hits(),
+ node.get_allow_approximate(), node.get_explore_additional_hits()));
}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h
index 8663bede4d6..4953f1a5b7c 100644
--- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h
+++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h
@@ -105,8 +105,10 @@ struct SimpleRegExpTerm : RegExpTerm {
};
struct SimpleNearestNeighborTerm : NearestNeighborTerm {
SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits)
- : NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits)
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
+ : NearestNeighborTerm(query_tensor_name, field_name, id, weight,
+ target_num_hits, allow_approximate, explore_additional_hits)
{}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
index 63acf532144..aafeaa46a22 100644
--- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
+++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
@@ -263,6 +263,8 @@ class QueryNodeConverter : public QueryVisitor {
createTermNode(node, ParseItem::ITEM_NEAREST_NEIGHBOR);
appendString(node.get_query_tensor_name());
appendCompressedPositiveNumber(node.get_target_num_hits());
+ appendCompressedPositiveNumber(node.get_allow_approximate() ? 1 : 0);
+ appendCompressedPositiveNumber(node.get_explore_additional_hits());
}
public:
diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
index 791da010720..a57c24584cc 100644
--- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
+++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
@@ -114,7 +114,10 @@ private:
uint32_t target_num_hits = queryStack.getArg1();
int32_t id = queryStack.getUniqueId();
Weight weight = queryStack.GetWeight();
- builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, target_num_hits);
+ uint32_t allow_approximate = (queryStack.getArg2() != 0);
+ uint32_t explore_additional_hits = queryStack.getArg3();
+ builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight,
+ target_num_hits, allow_approximate, explore_additional_hits);
} else {
vespalib::stringref term = queryStack.getTerm();
vespalib::stringref view = queryStack.getIndexName();
diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h
index a82b1e14d76..9af424716fb 100644
--- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h
+++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h
@@ -128,17 +128,24 @@ class NearestNeighborTerm : public QueryNodeMixin<NearestNeighborTerm, TermNode>
private:
vespalib::string _query_tensor_name;
uint32_t _target_num_hits;
+ bool _allow_approximate;
+ uint32_t _explore_additional_hits;
public:
NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits)
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
: QueryNodeMixinType(field_name, id, weight),
_query_tensor_name(query_tensor_name),
- _target_num_hits(target_num_hits)
+ _target_num_hits(target_num_hits),
+ _allow_approximate(allow_approximate),
+ _explore_additional_hits(explore_additional_hits)
{}
virtual ~NearestNeighborTerm() {}
const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; }
uint32_t get_target_num_hits() const { return _target_num_hits; }
+ bool get_allow_approximate() const { return _allow_approximate; }
+ uint32_t get_explore_additional_hits() const { return _explore_additional_hits; }
};
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index d4aa2aaa1d7..c160f8d5485 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -13,17 +13,22 @@ namespace search::queryeval {
NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field,
const tensor::DenseTensorAttribute& attr_tensor,
std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor,
- uint32_t target_num_hits)
+ uint32_t target_num_hits, bool approximate, uint32_t explore_k)
: ComplexLeafBlueprint(field),
_attr_tensor(attr_tensor),
_query_tensor(std::move(query_tensor)),
_target_num_hits(target_num_hits),
+ _approximate(approximate),
+ _explore_k(explore_k),
_distance_heap(target_num_hits),
_found_hits()
{
uint32_t est_hits = _attr_tensor.getNumDocs();
if (_attr_tensor.nearest_neighbor_index()) {
est_hits = std::min(target_num_hits, est_hits);
+ if (_explore_k == 0) {
+ _explore_k = 100;
+ }
}
setEstimate(HitEstimate(est_hits, false));
}
@@ -34,15 +39,14 @@ void
NearestNeighborBlueprint::perform_top_k()
{
auto nns_index = _attr_tensor.nearest_neighbor_index();
- if (nns_index) {
+ if (_approximate && nns_index) {
auto lhs_type = _query_tensor->fast_type();
auto rhs_type = _attr_tensor.getTensorType();
// XXX deal with different cell types later
if (lhs_type == rhs_type) {
auto lhs = _query_tensor->cellsRef();
uint32_t k = _target_num_hits;
- uint32_t explore_k = k + 100; // XXX hardcoded for now
- _found_hits = nns_index->find_top_k(k, lhs, explore_k);
+ _found_hits = nns_index->find_top_k(k, lhs, k + _explore_k);
}
}
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index ab4413c487a..a782633ccc3 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -21,6 +21,8 @@ private:
const tensor::DenseTensorAttribute& _attr_tensor;
std::unique_ptr<vespalib::tensor::DenseTensorView> _query_tensor;
uint32_t _target_num_hits;
+ bool _approximate;
+ uint32_t _explore_k;
mutable NearestNeighborDistanceHeap _distance_heap;
std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits;
@@ -29,7 +31,7 @@ public:
NearestNeighborBlueprint(const queryeval::FieldSpec& field,
const tensor::DenseTensorAttribute& attr_tensor,
std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor,
- uint32_t target_num_hits);
+ uint32_t target_num_hits, bool approximate, uint32_t explore_k);
NearestNeighborBlueprint(const NearestNeighborBlueprint&) = delete;
NearestNeighborBlueprint& operator=(const NearestNeighborBlueprint&) = delete;
~NearestNeighborBlueprint();