summary | refs | log | tree | commit | diff | stats
path: root/searchlib/src
diff options
context:
space:
mode:
author Geir Storli <geirst@verizonmedia.com> 2020-03-04 08:54:10 +0100
committer GitHub <noreply@github.com> 2020-03-04 08:54:10 +0100
commit 1847f9936b76d78f6bd6b8f83cf4756b0ef855d7 (patch)
tree 12efce15586858687022c68e64d60d95e53bbf60 /searchlib/src
parent a347208cb7d5713d9b8035a26d2a04d9ba7b140b (diff)
parent f8a67179013e3ed6b373b227a3baf8db95057b27 (diff)
Merge pull request #12401 from vespa-engine/arnej/extend-nns-item
Arnej/extend nns item.
Diffstat (limited to 'searchlib/src')
-rw-r--r-- searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp 2
-rw-r--r-- searchlib/src/tests/query/query_visitor_test.cpp 2
-rw-r--r-- searchlib/src/tests/query/querybuilder_test.cpp 7
-rw-r--r-- searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp 4
-rw-r--r-- searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp 29
-rw-r--r-- searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h 24
-rw-r--r-- searchlib/src/vespa/searchlib/query/streaming/querynode.cpp 2
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/querybuilder.h 12
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/queryreplicator.h 3
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/simplequery.h 6
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp 2
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h 25
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/termnodes.h 11
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp 11
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h 4
15 files changed, 87 insertions, 57 deletions
diff --git a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
index 24be21f65ec..47728c9785c 100644
--- a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
+++ b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
@@ -295,7 +295,7 @@ public:
request_ctx.set_query_tensor("query_tensor", tensor_spec);
}
Blueprint::UP create_blueprint() {
- query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7);
+ query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33);
return source.createBlueprint(request_ctx, FieldSpec(attr_name, 0, 0), term);
}
};
diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp
index 39e381c0942..edbc29be784 100644
--- a/searchlib/src/tests/query/query_visitor_test.cpp
+++ b/searchlib/src/tests/query/query_visitor_test.cpp
@@ -99,7 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() {
checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0)));
checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0)));
checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0)));
- checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123));
+ checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321));
}
} // namespace
diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp
index 7f496b3493c..8560cb0e091 100644
--- a/searchlib/src/tests/query/querybuilder_test.cpp
+++ b/searchlib/src/tests/query/querybuilder_test.cpp
@@ -111,7 +111,7 @@ Node::UP createQueryTree() {
builder.addStringTerm(str[5], view[5], id[5], weight[6]);
builder.addStringTerm(str[6], view[6], id[6], weight[7]);
}
- builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7);
+ builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33);
}
Node::UP node = builder.build();
ASSERT_TRUE(node.get());
@@ -395,8 +395,9 @@ struct MyRegExpTerm : RegExpTerm {
};
struct MyNearestNeighborTerm : NearestNeighborTerm {
MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t i, Weight w, uint32_t target_num_hits)
- : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits)
+ int32_t i, Weight w, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
+ : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits)
{}
};
diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
index 8595b0eff7f..9af05059bef 100644
--- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
@@ -646,7 +646,9 @@ public:
query_tensor.release();
setResult(std::make_unique<queryeval::NearestNeighborBlueprint>(_field, *dense_attr_tensor,
std::move(dense_query_tensor_up),
- n.get_target_num_hits()));
+ n.get_target_num_hits(),
+ n.get_allow_approximate(),
+ n.get_explore_additional_hits()));
}
};
diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
index 70a3097ae05..c42cf8fc370 100644
--- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
+++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
@@ -21,9 +21,11 @@ SimpleQueryStackDumpIterator::SimpleQueryStackDumpIterator(vespalib::stringref b
_currUniqueId(0),
_currFlags(0),
_currArity(0),
- _currArg1(0),
- _currArg2(0),
- _currArg3(0),
+ _extraIntArg1(0),
+ _extraIntArg2(0),
+ _extraIntArg3(0),
+ _extraDoubleArg4(0),
+ _extraDoubleArg5(0),
_predicate_query_term(),
_curr_index_name(),
_curr_term(),
@@ -138,7 +140,6 @@ SimpleQueryStackDumpIterator::next()
case ParseItem::ITEM_ANY:
try {
_currArity = readCompressedPositiveInt(p);
- _currArg1 = 0;
_curr_index_name = vespalib::stringref();
_curr_term = vespalib::stringref();
} catch (...) {
@@ -150,7 +151,7 @@ SimpleQueryStackDumpIterator::next()
case ParseItem::ITEM_ONEAR:
try {
_currArity = readCompressedPositiveInt(p);
- _currArg1 = readCompressedPositiveInt(p);
+ _extraIntArg1 = readCompressedPositiveInt(p);
_curr_index_name = vespalib::stringref();
_curr_term = vespalib::stringref();
} catch (...) {
@@ -161,7 +162,7 @@ SimpleQueryStackDumpIterator::next()
case ParseItem::ITEM_WEAK_AND:
try {
_currArity = readCompressedPositiveInt(p);
- _currArg1 = readCompressedPositiveInt(p);
+ _extraIntArg1 = readCompressedPositiveInt(p); // targetNumHits
_curr_index_name = read_stringref(p);
_curr_term = vespalib::stringref();
} catch (...) {
@@ -171,7 +172,6 @@ SimpleQueryStackDumpIterator::next()
case ParseItem::ITEM_SAME_ELEMENT:
try {
_currArity = readCompressedPositiveInt(p);
- _currArg1 = 0;
_curr_index_name = read_stringref(p);
_curr_term = vespalib::stringref();
} catch (...) {
@@ -182,7 +182,6 @@ SimpleQueryStackDumpIterator::next()
case ParseItem::ITEM_PURE_WEIGHTED_STRING:
try {
_curr_term = read_stringref(p);
- _currArg1 = 0;
_currArity = 0;
} catch (...) {
return false;
@@ -196,7 +195,6 @@ SimpleQueryStackDumpIterator::next()
p += sizeof(int64_t);
if (p > _bufEnd) return false;
- _currArg1 = 0;
_currArity = 0;
break;
case ParseItem::ITEM_WORD_ALTERNATIVES:
@@ -218,7 +216,6 @@ SimpleQueryStackDumpIterator::next()
try {
_curr_index_name = read_stringref(p);
_curr_term = read_stringref(p);
- _currArg1 = 0;
_currArity = 0;
} catch (...) {
return false;
@@ -258,11 +255,9 @@ SimpleQueryStackDumpIterator::next()
_currArity = readCompressedPositiveInt(p);
_curr_index_name = read_stringref(p);
if (_currType == ParseItem::ITEM_WAND) {
- _currArg1 = readCompressedPositiveInt(p); // targetNumHits
- _currArg2 = read_double(p); // scoreThreshold
- _currArg3 = read_double(p); // thresholdBoostFactor
- } else {
- _currArg1 = 0;
+ _extraIntArg1 = readCompressedPositiveInt(p); // targetNumHits
+ _extraDoubleArg4 = read_double(p); // scoreThreshold
+ _extraDoubleArg5 = read_double(p); // thresholdBoostFactor
}
_curr_term = vespalib::stringref();
} catch (...) {
@@ -274,7 +269,9 @@ SimpleQueryStackDumpIterator::next()
try {
_curr_index_name = read_stringref(p);
_curr_term = read_stringref(p); // query_tensor_name
- _currArg1 = readCompressedPositiveInt(p); // target_num_hits;
+ _extraIntArg1 = readCompressedPositiveInt(p); // targetNumHits
+ _extraIntArg2 = readCompressedPositiveInt(p); // allow_approximate
+ _extraIntArg3 = readCompressedPositiveInt(p); // explore_additional_hits
_currArity = 0;
} catch (...) {
return false;
diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
index 6eb3fb7777d..73c97bb5fb3 100644
--- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
+++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
@@ -42,12 +42,13 @@ private:
/** The arity of the current item */
uint32_t _currArity;
- /** The first argument of the current item (length of NEAR/ONEAR area for example) */
- uint32_t _currArg1;
- /** The second argument of the current item (score threshold of WAND for example) */
- double _currArg2;
- /** The third argument of the current item (threshold boost factor of WAND for example) */
- double _currArg3;
+
+ /* extra arguments */
+ uint32_t _extraIntArg1;
+ uint32_t _extraIntArg2;
+ uint32_t _extraIntArg3;
+ double _extraDoubleArg4;
+ double _extraDoubleArg5;
/** The predicate query specification */
query::PredicateQueryTerm::UP _predicate_query_term;
/** The index name (field name) in the current item */
@@ -118,11 +119,12 @@ public:
uint32_t getArity() const { return _currArity; }
- uint32_t getArg1() const { return _currArg1; }
-
- double getArg2() const { return _currArg2; }
-
- double getArg3() const { return _currArg3; }
+ uint32_t getNearDistance() const { return _extraIntArg1; }
+ uint32_t getTargetNumHits() const { return _extraIntArg1; }
+ double getScoreThreshold() const { return _extraDoubleArg4; }
+ double getThresholdBoostFactor() const { return _extraDoubleArg5; }
+ bool getAllowApproximate() const { return (_extraIntArg2 != 0); }
+ uint32_t getExploreAdditionalHits() const { return _extraIntArg3; }
query::PredicateQueryTerm::UP getPredicateQueryTerm()
{ return std::move(_predicate_query_term); }
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
index 24c458c7e32..3db6c8e68c8 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
@@ -41,7 +41,7 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
QueryConnector * qc = dynamic_cast<QueryConnector *> (qn.get());
NearQueryNode * nqn = dynamic_cast<NearQueryNode *> (qc);
if (nqn) {
- nqn->distance(queryRep.getArg1());
+ nqn->distance(queryRep.getNearDistance());
}
if ((type == ParseItem::ITEM_WEAK_AND) ||
(type == ParseItem::ITEM_WEIGHTED_SET) ||
diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
index 797defc39f5..8e6f2944ec9 100644
--- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
+++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
@@ -205,8 +205,11 @@ createRegExpTerm(vespalib::stringref term, vespalib::stringref view, int32_t id,
template <class NodeTypes>
typename NodeTypes::NearestNeighborTerm *
create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits) {
- return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits);
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
+{
+ return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight,
+ target_num_hits, allow_approximate, explore_additional_hits);
}
template <class NodeTypes>
@@ -317,9 +320,10 @@ public:
return addTerm(createRegExpTerm<NodeTypes>(term, view, id, weight));
}
typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits) {
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits) {
adjustWeight(weight);
- return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits));
+ return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits));
}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
index 0bf923960b9..9289df7cbe9 100644
--- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
+++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
@@ -165,7 +165,8 @@ private:
void visit(NearestNeighborTerm &node) override {
replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(),
- node.getId(), node.getWeight(), node.get_target_num_hits()));
+ node.getId(), node.getWeight(), node.get_target_num_hits(),
+ node.get_allow_approximate(), node.get_explore_additional_hits()));
}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h
index 8663bede4d6..4953f1a5b7c 100644
--- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h
+++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h
@@ -105,8 +105,10 @@ struct SimpleRegExpTerm : RegExpTerm {
};
struct SimpleNearestNeighborTerm : NearestNeighborTerm {
SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits)
- : NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits)
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
+ : NearestNeighborTerm(query_tensor_name, field_name, id, weight,
+ target_num_hits, allow_approximate, explore_additional_hits)
{}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
index 63acf532144..aafeaa46a22 100644
--- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
+++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
@@ -263,6 +263,8 @@ class QueryNodeConverter : public QueryVisitor {
createTermNode(node, ParseItem::ITEM_NEAREST_NEIGHBOR);
appendString(node.get_query_tensor_name());
appendCompressedPositiveNumber(node.get_target_num_hits());
+ appendCompressedPositiveNumber(node.get_allow_approximate() ? 1 : 0);
+ appendCompressedPositiveNumber(node.get_explore_additional_hits());
}
public:
diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
index 791da010720..65d6abeeaad 100644
--- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
+++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
@@ -48,9 +48,6 @@ public:
private:
static Term * createQueryTerm(search::SimpleQueryStackDumpIterator &queryStack, QueryBuilder<NodeTypes> & builder, vespalib::stringref & pureTermView) {
uint32_t arity = queryStack.getArity();
- uint32_t arg1 = queryStack.getArg1();
- double arg2 = queryStack.getArg2();
- double arg3 = queryStack.getArg3();
ParseItem::ItemType type = queryStack.getType();
Node::UP node;
Term *t = 0;
@@ -68,16 +65,19 @@ private:
pureTermView = view;
} else if (type == ParseItem::ITEM_WEAK_AND) {
vespalib::stringref view = queryStack.getIndexName();
- builder.addWeakAnd(arity, arg1, view);
+ uint32_t targetNumHits = queryStack.getTargetNumHits();
+ builder.addWeakAnd(arity, targetNumHits, view);
pureTermView = view;
} else if (type == ParseItem::ITEM_EQUIV) {
int32_t id = queryStack.getUniqueId();
Weight weight = queryStack.GetWeight();
builder.addEquiv(arity, id, weight);
} else if (type == ParseItem::ITEM_NEAR) {
- builder.addNear(arity, arg1);
+ uint32_t nearDistance = queryStack.getNearDistance();
+ builder.addNear(arity, nearDistance);
} else if (type == ParseItem::ITEM_ONEAR) {
- builder.addONear(arity, arg1);
+ uint32_t nearDistance = queryStack.getNearDistance();
+ builder.addONear(arity, nearDistance);
} else if (type == ParseItem::ITEM_PHRASE) {
vespalib::stringref view = queryStack.getIndexName();
int32_t id = queryStack.getUniqueId();
@@ -104,17 +104,24 @@ private:
vespalib::stringref view = queryStack.getIndexName();
int32_t id = queryStack.getUniqueId();
Weight weight = queryStack.GetWeight();
- t = &builder.addWandTerm(arity, view, id, weight, arg1, arg2, arg3);
+ uint32_t targetNumHits = queryStack.getTargetNumHits();
+ double scoreThreshold = queryStack.getScoreThreshold();
+ double thresholdBoostFactor = queryStack.getThresholdBoostFactor();
+ t = &builder.addWandTerm(arity, view, id, weight,
+ targetNumHits, scoreThreshold, thresholdBoostFactor);
pureTermView = vespalib::stringref();
} else if (type == ParseItem::ITEM_NOT) {
builder.addAndNot(arity);
} else if (type == ParseItem::ITEM_NEAREST_NEIGHBOR) {
vespalib::stringref query_tensor_name = queryStack.getTerm();
vespalib::stringref field_name = queryStack.getIndexName();
- uint32_t target_num_hits = queryStack.getArg1();
+ uint32_t target_num_hits = queryStack.getTargetNumHits();
int32_t id = queryStack.getUniqueId();
Weight weight = queryStack.GetWeight();
- builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, target_num_hits);
+ bool allow_approximate = queryStack.getAllowApproximate();
+ uint32_t explore_additional_hits = queryStack.getExploreAdditionalHits();
+ builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight,
+ target_num_hits, allow_approximate, explore_additional_hits);
} else {
vespalib::stringref term = queryStack.getTerm();
vespalib::stringref view = queryStack.getIndexName();
diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h
index a82b1e14d76..9af424716fb 100644
--- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h
+++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h
@@ -128,17 +128,24 @@ class NearestNeighborTerm : public QueryNodeMixin<NearestNeighborTerm, TermNode>
private:
vespalib::string _query_tensor_name;
uint32_t _target_num_hits;
+ bool _allow_approximate;
+ uint32_t _explore_additional_hits;
public:
NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits)
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
: QueryNodeMixinType(field_name, id, weight),
_query_tensor_name(query_tensor_name),
- _target_num_hits(target_num_hits)
+ _target_num_hits(target_num_hits),
+ _allow_approximate(allow_approximate),
+ _explore_additional_hits(explore_additional_hits)
{}
virtual ~NearestNeighborTerm() {}
const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; }
uint32_t get_target_num_hits() const { return _target_num_hits; }
+ bool get_allow_approximate() const { return _allow_approximate; }
+ uint32_t get_explore_additional_hits() const { return _explore_additional_hits; }
};
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index d4aa2aaa1d7..d3b2925e075 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -13,11 +13,13 @@ namespace search::queryeval {
NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field,
const tensor::DenseTensorAttribute& attr_tensor,
std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor,
- uint32_t target_num_hits)
+ uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits)
: ComplexLeafBlueprint(field),
_attr_tensor(attr_tensor),
_query_tensor(std::move(query_tensor)),
_target_num_hits(target_num_hits),
+ _approximate(approximate),
+ _explore_additional_hits(explore_additional_hits),
_distance_heap(target_num_hits),
_found_hits()
{
@@ -34,15 +36,14 @@ void
NearestNeighborBlueprint::perform_top_k()
{
auto nns_index = _attr_tensor.nearest_neighbor_index();
- if (nns_index) {
+ if (_approximate && nns_index) {
auto lhs_type = _query_tensor->fast_type();
auto rhs_type = _attr_tensor.getTensorType();
// XXX deal with different cell types later
if (lhs_type == rhs_type) {
auto lhs = _query_tensor->cellsRef();
uint32_t k = _target_num_hits;
- uint32_t explore_k = k + 100; // XXX hardcoded for now
- _found_hits = nns_index->find_top_k(k, lhs, explore_k);
+ _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits);
}
}
}
@@ -73,6 +74,8 @@ NearestNeighborBlueprint::visitMembers(vespalib::ObjectVisitor& visitor) const
visitor.visitString("attribute_tensor", _attr_tensor.getTensorType().to_spec());
visitor.visitString("query_tensor", _query_tensor->type().to_spec());
visitor.visitInt("target_num_hits", _target_num_hits);
+ visitor.visitBool("approximate", _approximate);
+ visitor.visitInt("explore_additional_hits", _explore_additional_hits);
}
bool
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index ab4413c487a..39165b066be 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -21,6 +21,8 @@ private:
const tensor::DenseTensorAttribute& _attr_tensor;
std::unique_ptr<vespalib::tensor::DenseTensorView> _query_tensor;
uint32_t _target_num_hits;
+ bool _approximate;
+ uint32_t _explore_additional_hits;
mutable NearestNeighborDistanceHeap _distance_heap;
std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits;
@@ -29,7 +31,7 @@ public:
NearestNeighborBlueprint(const queryeval::FieldSpec& field,
const tensor::DenseTensorAttribute& attr_tensor,
std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor,
- uint32_t target_num_hits);
+ uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits);
NearestNeighborBlueprint(const NearestNeighborBlueprint&) = delete;
NearestNeighborBlueprint& operator=(const NearestNeighborBlueprint&) = delete;
~NearestNeighborBlueprint();