about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Arne Juul <arnej@verizonmedia.com> 2020-03-02 09:24:02 +0000
committer: Arne Juul <arnej@verizonmedia.com> 2020-03-02 11:44:22 +0000
commit: feeb478f356b0c2d6c3b7e0d80ef15620dd019b1 (patch)
tree: e097f5e897bed98452ff9bb8a26fb37bab59c5c5
parent: 25707b0248f895e17058c09a782e1c88914a542f (diff)
extend NearestNeighborItem
-rw-r--r-- container-search/abi-spec.json | 4
-rw-r--r-- container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java | 18
-rw-r--r-- container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java | 5
-rw-r--r-- container-search/src/main/java/com/yahoo/search/yql/YqlParser.java | 10
-rw-r--r-- container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java | 30
-rw-r--r-- container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java | 6
-rw-r--r-- searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp | 2
-rw-r--r-- searchlib/src/tests/query/query_visitor_test.cpp | 2
-rw-r--r-- searchlib/src/tests/query/querybuilder_test.cpp | 7
-rw-r--r-- searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp | 4
-rw-r--r-- searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp | 4
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/querybuilder.h | 12
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/queryreplicator.h | 3
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/simplequery.h | 6
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp | 2
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h | 5
-rw-r--r-- searchlib/src/vespa/searchlib/query/tree/termnodes.h | 11
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp | 12
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h | 4
19 files changed, 114 insertions, 33 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 82d3223c8fe..51fee99a743 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -858,8 +858,12 @@
"public void <init>(java.lang.String, java.lang.String)",
"public int getTargetNumHits()",
"public java.lang.String getIndexName()",
+ "public int getHnswExploreAdditionalHits()",
+ "public boolean getAllowApproximate()",
"public java.lang.String getQueryTensorName()",
"public void setTargetNumHits(int)",
+ "public void setHnswExploreAdditionalHits(int)",
+ "public void setAllowApproximate(boolean)",
"public void setIndexName(java.lang.String)",
"public com.yahoo.prelude.query.Item$ItemType getItemType()",
"public java.lang.String getName()",
diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
index 35b87ec0190..836107138d0 100644
--- a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
+++ b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
@@ -20,6 +20,8 @@ import java.nio.ByteBuffer;
public class NearestNeighborItem extends SimpleTaggableItem {
private int targetNumHits = 0;
+ private int hnswExploreAdditionalHits = 0;
+ private boolean approximate = true;
private String field;
private String queryTensorName;
@@ -34,12 +36,24 @@ public class NearestNeighborItem extends SimpleTaggableItem {
/** Returns the field name */
public String getIndexName() { return field; }
+ /** Returns the number of extra hits to explore in HNSW algorithm */
+ public int getHnswExploreAdditionalHits() { return hnswExploreAdditionalHits; }
+
+ /** Returns whether approximation is allowed */
+ public boolean getAllowApproximate() { return approximate; }
+
/** Returns the name of the query tensor */
public String getQueryTensorName() { return queryTensorName; }
/** Set the K number of hits to produce */
public void setTargetNumHits(int target) { this.targetNumHits = target; }
+ /** Set the number of extra hits to explore in HNSW algorithm */
+ public void setHnswExploreAdditionalHits(int num) { this.hnswExploreAdditionalHits = num; }
+
+ /** Set whether approximation is allowed */
+ public void setAllowApproximate(boolean value) { this.approximate = value; }
+
@Override
public void setIndexName(String index) { this.field = index; }
@@ -58,6 +72,8 @@ public class NearestNeighborItem extends SimpleTaggableItem {
putString(field, buffer);
putString(queryTensorName, buffer);
IntegerCompressor.putCompressedPositiveNumber(targetNumHits, buffer);
+ IntegerCompressor.putCompressedPositiveNumber((approximate ? 1 : 0), buffer);
+ IntegerCompressor.putCompressedPositiveNumber(hnswExploreAdditionalHits, buffer);
return 1; // number of encoded stack dump items
}
@@ -65,6 +81,8 @@ public class NearestNeighborItem extends SimpleTaggableItem {
protected void appendBodyString(StringBuilder buffer) {
buffer.append("{field=").append(field);
buffer.append(",queryTensorName=").append(queryTensorName);
+ buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits);
+ buffer.append(",approximate=").append(String.valueOf(approximate));
buffer.append(",targetNumHits=").append(targetNumHits).append("}");
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
index 6eef1252998..38b207cc7eb 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
@@ -702,6 +702,11 @@ public class VespaSerializer {
comma(destination, initLen);
int targetNumHits = item.getTargetNumHits();
destination.append("\"targetNumHits\": ").append(targetNumHits);
+ int explore = item.getHnswExploreAdditionalHits();
+ if (explore != 0) {
+ destination.append(",\"hnsw.exploreAdditionalHits\": ").append(explore);
+ }
+ destination.append(",\"approximate\": ").append(item.getAllowApproximate());
destination.append("}]");
destination.append(NEAREST_NEIGHBOR).append('(');
destination.append(item.getIndexName()).append(", ");
diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
index 8d013e501e8..f4560806dd2 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
@@ -137,6 +137,7 @@ public class YqlParser implements Parser {
static final String ACCENT_DROP = "accentDrop";
static final String ALTERNATIVES = "alternatives";
static final String AND_SEGMENTING = "andSegmenting";
+ static final String APPROXIMATE = "approximate";
static final String BOUNDS = "bounds";
static final String BOUNDS_LEFT_OPEN = "leftOpen";
static final String BOUNDS_OPEN = "open";
@@ -149,6 +150,7 @@ public class YqlParser implements Parser {
static final String EQUIV = "equiv";
static final String FILTER = "filter";
static final String HIT_LIMIT = "hitLimit";
+ static final String HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.exploreAdditionalHits";
static final String IMPLICIT_TRANSFORMS = "implicitTransforms";
static final String LABEL = "label";
static final String NEAR = "near";
@@ -421,6 +423,14 @@ public class YqlParser implements Parser {
if (targetNumHits != null) {
item.setTargetNumHits(targetNumHits);
}
+ Integer hnswExploreAdditionalHits = getAnnotation(ast, HNSW_EXPLORE_ADDITIONAL_HITS,
+ Integer.class, null, "number of extra hits to explore for HNSW algorithm");
+ if (hnswExploreAdditionalHits != null) {
+ item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits);
+ }
+ Boolean allowApproximate = getAnnotation(ast, APPROXIMATE,
+ Boolean.class, Boolean.TRUE, "allow approximate nearest neighbor search");
+ item.setAllowApproximate(allowApproximate);
String label = getAnnotation(ast, LABEL, String.class, null, "item label");
if (label != null) {
item.setLabel(label);
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
index 0cbf3a6f92c..c6233ffa50b 100644
--- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -139,12 +139,24 @@ public class ValidateNearestNeighborTestCase {
assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError());
}
+ static String desc(String field, String qt, int th, String errmsg) {
+ StringBuilder r = new StringBuilder();
+ r.append("NEAREST_NEIGHBOR {");
+ r.append("field=").append(field);
+ r.append(",queryTensorName=").append(qt);
+ r.append(",hnsw.exploreAdditionalHits=0");
+ r.append(",approximate=true");
+ r.append(",targetNumHits=").append(th);
+ r.append("} ").append(errmsg);
+ return r.toString();
+ }
+
@Test
public void testMissingTargetNumHits() {
String q = "select * from sources * where nearestNeighbor(dvector,qvector);";
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=0} has invalid targetNumHits", r);
+ assertErrMsg(desc("dvector", "qvector", 0, "has invalid targetNumHits"), r);
}
@Test
@@ -152,16 +164,16 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("dvector", "foo");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r);
+ assertErrMsg(desc("dvector", "foo", 1, "query tensor not found"), r);
}
@Test
public void testQueryTensorWrongType() {
String q = makeQuery("dvector", "qvector");
Result r = doSearch(searcher, q, "tensor string");
- assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r);
+ assertErrMsg(desc("dvector", "qvector", 1, "query tensor should be a tensor, was: class java.lang.String"), r);
r = doSearch(searcher, q, null);
- assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: null", r);
+ assertErrMsg(desc("dvector", "qvector", 1, "query tensor should be a tensor, was: null"), r);
}
@Test
@@ -169,7 +181,7 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("dvector", "qvector");
Tensor t = makeTensor(tt_dense_dvector_2, 2);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} field type tensor(x[3]) does not match query tensor type tensor(x[2])", r);
+ assertErrMsg(desc("dvector", "qvector", 1, "field type tensor(x[3]) does not match query tensor type tensor(x[2])"), r);
}
@Test
@@ -177,7 +189,7 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("foo", "qvector");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r);
+ assertErrMsg(desc("foo", "qvector", 1, "field is not an attribute"), r);
}
@Test
@@ -185,7 +197,7 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("simple", "qvector");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r);
+ assertErrMsg(desc("simple", "qvector", 1, "field is not a tensor"), r);
}
@Test
@@ -193,7 +205,7 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("sparse", "qvector");
Tensor t = makeTensor(tt_sparse_vector_x);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r);
+ assertErrMsg(desc("sparse", "qvector", 1, "tensor type tensor(x{}) is not a dense vector"), r);
}
@Test
@@ -201,7 +213,7 @@ public class ValidateNearestNeighborTestCase {
String q = makeQuery("matrix", "qvector");
Tensor t = makeMatrix(tt_dense_matrix_xy);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=matrix,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x[3],y[1]) is not a dense vector", r);
+ assertErrMsg(desc("matrix", "qvector", 1, "tensor type tensor(x[3],y[1]) is not a dense vector"), r);
}
private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Object qTensor) {
diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
index 5eb1f3e3de1..e43dbd4e266 100644
--- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
@@ -550,9 +550,11 @@ public class YqlParserTestCase {
@Test
public void testNearestNeighbor() {
assertParse("select foo from bar where nearestNeighbor(semantic_embedding, my_vector);",
- "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,targetNumHits=0}");
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=0}");
assertParse("select foo from bar where [{\"targetNumHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);",
- "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,targetNumHits=37}");
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=37}");
+ assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetNumHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);",
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetNumHits=3}");
}
@Test
diff --git a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
index 24be21f65ec..47728c9785c 100644
--- a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
+++ b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
@@ -295,7 +295,7 @@ public:
request_ctx.set_query_tensor("query_tensor", tensor_spec);
}
Blueprint::UP create_blueprint() {
- query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7);
+ query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33);
return source.createBlueprint(request_ctx, FieldSpec(attr_name, 0, 0), term);
}
};
diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp
index 39e381c0942..edbc29be784 100644
--- a/searchlib/src/tests/query/query_visitor_test.cpp
+++ b/searchlib/src/tests/query/query_visitor_test.cpp
@@ -99,7 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() {
checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0)));
checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0)));
checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0)));
- checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123));
+ checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321));
}
} // namespace
diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp
index 7f496b3493c..8560cb0e091 100644
--- a/searchlib/src/tests/query/querybuilder_test.cpp
+++ b/searchlib/src/tests/query/querybuilder_test.cpp
@@ -111,7 +111,7 @@ Node::UP createQueryTree() {
builder.addStringTerm(str[5], view[5], id[5], weight[6]);
builder.addStringTerm(str[6], view[6], id[6], weight[7]);
}
- builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7);
+ builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33);
}
Node::UP node = builder.build();
ASSERT_TRUE(node.get());
@@ -395,8 +395,9 @@ struct MyRegExpTerm : RegExpTerm {
};
struct MyNearestNeighborTerm : NearestNeighborTerm {
MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t i, Weight w, uint32_t target_num_hits)
- : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits)
+ int32_t i, Weight w, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
+ : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits)
{}
};
diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
index 8595b0eff7f..9af05059bef 100644
--- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
@@ -646,7 +646,9 @@ public:
query_tensor.release();
setResult(std::make_unique<queryeval::NearestNeighborBlueprint>(_field, *dense_attr_tensor,
std::move(dense_query_tensor_up),
- n.get_target_num_hits()));
+ n.get_target_num_hits(),
+ n.get_allow_approximate(),
+ n.get_explore_additional_hits()));
}
};
diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
index 70a3097ae05..f0fb53a5640 100644
--- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
+++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp
@@ -274,7 +274,9 @@ SimpleQueryStackDumpIterator::next()
try {
_curr_index_name = read_stringref(p);
_curr_term = read_stringref(p); // query_tensor_name
- _currArg1 = readCompressedPositiveInt(p); // target_num_hits;
+ _currArg1 = readCompressedPositiveInt(p); // target_num_hits
+ _currArg2 = readCompressedPositiveInt(p); // allow_approximate
+ _currArg3 = readCompressedPositiveInt(p); // explore_additional_hits
_currArity = 0;
} catch (...) {
return false;
diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
index 797defc39f5..8e6f2944ec9 100644
--- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
+++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h
@@ -205,8 +205,11 @@ createRegExpTerm(vespalib::stringref term, vespalib::stringref view, int32_t id,
template <class NodeTypes>
typename NodeTypes::NearestNeighborTerm *
create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits) {
- return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits);
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
+{
+ return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight,
+ target_num_hits, allow_approximate, explore_additional_hits);
}
template <class NodeTypes>
@@ -317,9 +320,10 @@ public:
return addTerm(createRegExpTerm<NodeTypes>(term, view, id, weight));
}
typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits) {
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits) {
adjustWeight(weight);
- return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits));
+ return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits));
}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
index 0bf923960b9..9289df7cbe9 100644
--- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
+++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h
@@ -165,7 +165,8 @@ private:
void visit(NearestNeighborTerm &node) override {
replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(),
- node.getId(), node.getWeight(), node.get_target_num_hits()));
+ node.getId(), node.getWeight(), node.get_target_num_hits(),
+ node.get_allow_approximate(), node.get_explore_additional_hits()));
}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h
index 8663bede4d6..4953f1a5b7c 100644
--- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h
+++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h
@@ -105,8 +105,10 @@ struct SimpleRegExpTerm : RegExpTerm {
};
struct SimpleNearestNeighborTerm : NearestNeighborTerm {
SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits)
- : NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits)
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
+ : NearestNeighborTerm(query_tensor_name, field_name, id, weight,
+ target_num_hits, allow_approximate, explore_additional_hits)
{}
};
diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
index 63acf532144..aafeaa46a22 100644
--- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
+++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp
@@ -263,6 +263,8 @@ class QueryNodeConverter : public QueryVisitor {
createTermNode(node, ParseItem::ITEM_NEAREST_NEIGHBOR);
appendString(node.get_query_tensor_name());
appendCompressedPositiveNumber(node.get_target_num_hits());
+ appendCompressedPositiveNumber(node.get_allow_approximate() ? 1 : 0);
+ appendCompressedPositiveNumber(node.get_explore_additional_hits());
}
public:
diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
index 791da010720..a57c24584cc 100644
--- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
+++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
@@ -114,7 +114,10 @@ private:
uint32_t target_num_hits = queryStack.getArg1();
int32_t id = queryStack.getUniqueId();
Weight weight = queryStack.GetWeight();
- builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, target_num_hits);
+ uint32_t allow_approximate = (queryStack.getArg2() != 0);
+ uint32_t explore_additional_hits = queryStack.getArg3();
+ builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight,
+ target_num_hits, allow_approximate, explore_additional_hits);
} else {
vespalib::stringref term = queryStack.getTerm();
vespalib::stringref view = queryStack.getIndexName();
diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h
index a82b1e14d76..9af424716fb 100644
--- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h
+++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h
@@ -128,17 +128,24 @@ class NearestNeighborTerm : public QueryNodeMixin<NearestNeighborTerm, TermNode>
private:
vespalib::string _query_tensor_name;
uint32_t _target_num_hits;
+ bool _allow_approximate;
+ uint32_t _explore_additional_hits;
public:
NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name,
- int32_t id, Weight weight, uint32_t target_num_hits)
+ int32_t id, Weight weight, uint32_t target_num_hits,
+ bool allow_approximate, uint32_t explore_additional_hits)
: QueryNodeMixinType(field_name, id, weight),
_query_tensor_name(query_tensor_name),
- _target_num_hits(target_num_hits)
+ _target_num_hits(target_num_hits),
+ _allow_approximate(allow_approximate),
+ _explore_additional_hits(explore_additional_hits)
{}
virtual ~NearestNeighborTerm() {}
const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; }
uint32_t get_target_num_hits() const { return _target_num_hits; }
+ bool get_allow_approximate() const { return _allow_approximate; }
+ uint32_t get_explore_additional_hits() const { return _explore_additional_hits; }
};
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index d4aa2aaa1d7..c160f8d5485 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -13,17 +13,22 @@ namespace search::queryeval {
NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field,
const tensor::DenseTensorAttribute& attr_tensor,
std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor,
- uint32_t target_num_hits)
+ uint32_t target_num_hits, bool approximate, uint32_t explore_k)
: ComplexLeafBlueprint(field),
_attr_tensor(attr_tensor),
_query_tensor(std::move(query_tensor)),
_target_num_hits(target_num_hits),
+ _approximate(approximate),
+ _explore_k(explore_k),
_distance_heap(target_num_hits),
_found_hits()
{
uint32_t est_hits = _attr_tensor.getNumDocs();
if (_attr_tensor.nearest_neighbor_index()) {
est_hits = std::min(target_num_hits, est_hits);
+ if (_explore_k == 0) {
+ _explore_k = 100;
+ }
}
setEstimate(HitEstimate(est_hits, false));
}
@@ -34,15 +39,14 @@ void
NearestNeighborBlueprint::perform_top_k()
{
auto nns_index = _attr_tensor.nearest_neighbor_index();
- if (nns_index) {
+ if (_approximate && nns_index) {
auto lhs_type = _query_tensor->fast_type();
auto rhs_type = _attr_tensor.getTensorType();
// XXX deal with different cell types later
if (lhs_type == rhs_type) {
auto lhs = _query_tensor->cellsRef();
uint32_t k = _target_num_hits;
- uint32_t explore_k = k + 100; // XXX hardcoded for now
- _found_hits = nns_index->find_top_k(k, lhs, explore_k);
+ _found_hits = nns_index->find_top_k(k, lhs, k + _explore_k);
}
}
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index ab4413c487a..a782633ccc3 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -21,6 +21,8 @@ private:
const tensor::DenseTensorAttribute& _attr_tensor;
std::unique_ptr<vespalib::tensor::DenseTensorView> _query_tensor;
uint32_t _target_num_hits;
+ bool _approximate;
+ uint32_t _explore_k;
mutable NearestNeighborDistanceHeap _distance_heap;
std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits;
@@ -29,7 +31,7 @@ public:
NearestNeighborBlueprint(const queryeval::FieldSpec& field,
const tensor::DenseTensorAttribute& attr_tensor,
std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor,
- uint32_t target_num_hits);
+ uint32_t target_num_hits, bool approximate, uint32_t explore_k);
NearestNeighborBlueprint(const NearestNeighborBlueprint&) = delete;
NearestNeighborBlueprint& operator=(const NearestNeighborBlueprint&) = delete;
~NearestNeighborBlueprint();