summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2021-01-09 14:52:02 +0100
committerGitHub <noreply@github.com>2021-01-09 14:52:02 +0100
commit3504451650259601f0f1c2dfdf8c406253e02aab (patch)
tree7fea3bbee4370b6e776a6f1971d0cecc8acad484 /container-search
parentc584a25c487525b4325df7760e2d685e2832de4b (diff)
parentfeef42f47a71d1bc9028d2151d6725e4917db1bd (diff)
Merge pull request #15950 from vespa-engine/arnej/add-distance-threshold
Arnej/add distance threshold
Diffstat (limited to 'container-search')
-rw-r--r--container-search/abi-spec.json2
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java20
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/SelectParser.java5
-rw-r--r--container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java8
-rw-r--r--container-search/src/main/java/com/yahoo/search/yql/YqlParser.java6
-rw-r--r--container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java1
-rw-r--r--container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java1
-rw-r--r--container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java10
-rw-r--r--container-search/src/test/java/com/yahoo/select/SelectTestCase.java6
9 files changed, 51 insertions, 8 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 6f48ae5b41a..e2285a10672 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -884,10 +884,12 @@
"public void <init>(java.lang.String, java.lang.String)",
"public int getTargetNumHits()",
"public java.lang.String getIndexName()",
+ "public double getDistanceThreshold()",
"public int getHnswExploreAdditionalHits()",
"public boolean getAllowApproximate()",
"public java.lang.String getQueryTensorName()",
"public void setTargetNumHits(int)",
+ "public void setDistanceThreshold(double)",
"public void setHnswExploreAdditionalHits(int)",
"public void setAllowApproximate(boolean)",
"public void setIndexName(java.lang.String)",
diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
index e237463582f..bb95cbad178 100644
--- a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
+++ b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
@@ -22,6 +22,7 @@ public class NearestNeighborItem extends SimpleTaggableItem {
private int targetNumHits = 0;
private int hnswExploreAdditionalHits = 0;
+ private double distanceThreshold = Double.POSITIVE_INFINITY;
private boolean approximate = true;
private String field;
private final String queryTensorName;
@@ -37,6 +38,9 @@ public class NearestNeighborItem extends SimpleTaggableItem {
/** Returns the field name */
public String getIndexName() { return field; }
+ /** Returns the distance threshold for nearest-neighbor hits */
+ public double getDistanceThreshold () { return this.distanceThreshold ; }
+
/** Returns the number of extra hits to explore in HNSW algorithm */
public int getHnswExploreAdditionalHits() { return hnswExploreAdditionalHits; }
@@ -49,6 +53,9 @@ public class NearestNeighborItem extends SimpleTaggableItem {
/** Set the K number of hits to produce */
public void setTargetNumHits(int target) { this.targetNumHits = target; }
+ /** Set the distance threshold for nearest-neighbor hits */
+ public void setDistanceThreshold(double threshold) { this.distanceThreshold = threshold; }
+
/** Set the number of extra hits to explore in HNSW algorithm */
public void setHnswExploreAdditionalHits(int num) { this.hnswExploreAdditionalHits = num; }
@@ -72,9 +79,18 @@ public class NearestNeighborItem extends SimpleTaggableItem {
super.encodeThis(buffer);
putString(field, buffer);
putString(queryTensorName, buffer);
+ int approxNum = (approximate ? 1 : 0);
+ // should become always-true later:
+ boolean sendDistanceThreshold = (distanceThreshold < Double.POSITIVE_INFINITY);
+ if (sendDistanceThreshold) {
+ approxNum |= 0x40; // temporary flag bit
+ }
IntegerCompressor.putCompressedPositiveNumber(targetNumHits, buffer);
- IntegerCompressor.putCompressedPositiveNumber((approximate ? 1 : 0), buffer);
+ IntegerCompressor.putCompressedPositiveNumber(approxNum, buffer);
IntegerCompressor.putCompressedPositiveNumber(hnswExploreAdditionalHits, buffer);
+ if (sendDistanceThreshold) {
+ buffer.putDouble(distanceThreshold);
+ }
return 1; // number of encoded stack dump items
}
@@ -83,6 +99,7 @@ public class NearestNeighborItem extends SimpleTaggableItem {
buffer.append("{field=").append(field);
buffer.append(",queryTensorName=").append(queryTensorName);
buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits);
+ buffer.append(",distanceThreshold=").append(distanceThreshold);
buffer.append(",approximate=").append(approximate);
buffer.append(",targetHits=").append(targetNumHits).append("}");
}
@@ -93,6 +110,7 @@ public class NearestNeighborItem extends SimpleTaggableItem {
discloser.addProperty("field", field);
discloser.addProperty("queryTensorName", queryTensorName);
discloser.addProperty("hnsw.exploreAdditionalHits", hnswExploreAdditionalHits);
+ discloser.addProperty("distanceThreshold", distanceThreshold);
discloser.addProperty("approximate", approximate);
discloser.addProperty("targetHits", targetNumHits);
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/SelectParser.java b/container-search/src/main/java/com/yahoo/search/query/SelectParser.java
index 30d741f465c..5f1f26b77e9 100644
--- a/container-search/src/main/java/com/yahoo/search/query/SelectParser.java
+++ b/container-search/src/main/java/com/yahoo/search/query/SelectParser.java
@@ -78,6 +78,7 @@ import static com.yahoo.search.yql.YqlParser.CONNECTIVITY;
import static com.yahoo.search.yql.YqlParser.DEFAULT_TARGET_NUM_HITS;
import static com.yahoo.search.yql.YqlParser.DESCENDING_HITS_ORDER;
import static com.yahoo.search.yql.YqlParser.DISTANCE;
+import static com.yahoo.search.yql.YqlParser.DISTANCE_THRESHOLD;
import static com.yahoo.search.yql.YqlParser.DOT_PRODUCT;
import static com.yahoo.search.yql.YqlParser.EQUIV;
import static com.yahoo.search.yql.YqlParser.FILTER;
@@ -481,6 +482,10 @@ public class SelectParser implements Parser {
if (TARGET_NUM_HITS.equals(annotation_name)){
item.setTargetNumHits((int)(annotation_value.asDouble()));
}
+ if (DISTANCE_THRESHOLD.equals(annotation_name)) {
+ double distanceThreshold = annotation_value.asDouble();
+ item.setDistanceThreshold(distanceThreshold);
+ }
if (HNSW_EXPLORE_ADDITIONAL_HITS.equals(annotation_name)) {
int hnswExploreAdditionalHits = (int)(annotation_value.asDouble());
item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits);
diff --git a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
index a38e48fd89d..f4a36ea51ab 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java
@@ -723,7 +723,13 @@ public class VespaSerializer {
destination.append(leafAnnotations(item));
comma(destination, initLen);
int targetNumHits = item.getTargetNumHits();
- annotationKey(destination, "targetNumHits").append(targetNumHits);
+ annotationKey(destination, YqlParser.TARGET_NUM_HITS).append(targetNumHits);
+ double distanceThreshold = item.getDistanceThreshold();
+ if (distanceThreshold < Double.POSITIVE_INFINITY) {
+ comma(destination, initLen);
+ String key = YqlParser.DISTANCE_THRESHOLD;
+ annotationKey(destination, key).append(distanceThreshold);
+ }
int explore = item.getHnswExploreAdditionalHits();
if (explore != 0) {
comma(destination, initLen);
diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
index 739aae0e277..f37aeb4c1e0 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
@@ -156,6 +156,7 @@ public class YqlParser implements Parser {
public static final String FILTER = "filter";
public static final String GEO_LOCATION = "geoLocation";
public static final String HIT_LIMIT = "hitLimit";
+ public static final String DISTANCE_THRESHOLD = "distanceThreshold";
public static final String HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.exploreAdditionalHits";
public static final String IMPLICIT_TRANSFORMS = "implicitTransforms";
public static final String LABEL = "label";
@@ -459,6 +460,11 @@ public class YqlParser implements Parser {
if (targetNumHits != null) {
item.setTargetNumHits(targetNumHits);
}
+ Double distanceThreshold = getAnnotation(ast, DISTANCE_THRESHOLD,
+ Double.class, null, "maximum distance allowed from query point");
+ if (distanceThreshold != null) {
+ item.setDistanceThreshold(distanceThreshold);
+ }
Integer hnswExploreAdditionalHits = getAnnotation(ast, HNSW_EXPLORE_ADDITIONAL_HITS,
Integer.class, null, "number of extra hits to explore for HNSW algorithm");
if (hnswExploreAdditionalHits != null) {
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
index c49603737a6..72956b5b6eb 100644
--- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -138,6 +138,7 @@ public class ValidateNearestNeighborTestCase {
r.append("field=").append(field);
r.append(",queryTensorName=").append(qt);
r.append(",hnsw.exploreAdditionalHits=0");
+ r.append(",distanceThreshold=Infinity");
r.append(",approximate=true");
r.append(",targetHits=").append(th);
r.append("} ").append(errmsg);
diff --git a/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java
index 63840b0f5ec..a44a9f25b62 100644
--- a/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java
@@ -139,6 +139,7 @@ public class VespaSerializerTestCase {
parseAndConfirm("[{\"targetNumHits\": 1, \"hnsw.exploreAdditionalHits\": 76}]nearestNeighbor(semantic_embedding, my_property)");
parseAndConfirm("[{\"targetNumHits\": 2, \"approximate\": false}]nearestNeighbor(semantic_embedding, my_property)");
parseAndConfirm("[{\"targetNumHits\": 3, \"hnsw.exploreAdditionalHits\": 67, \"approximate\": false}]nearestNeighbor(semantic_embedding, my_property)");
+ parseAndConfirm("[{\"targetNumHits\": 4, \"distanceThreshold\": 100100.25}]nearestNeighbor(semantic_embedding, my_property)");
}
@Test
diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
index f5e22e30f45..2d88351f9ea 100644
--- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
@@ -568,11 +568,15 @@ public class YqlParserTestCase {
@Test
public void testNearestNeighbor() {
assertParse("select foo from bar where nearestNeighbor(semantic_embedding, my_vector);",
- "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=0}");
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,distanceThreshold=Infinity,approximate=true,targetHits=0}");
assertParse("select foo from bar where [{\"targetHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);",
- "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=37}");
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,distanceThreshold=Infinity,approximate=true,targetHits=37}");
assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);",
- "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetHits=3}");
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,distanceThreshold=Infinity,approximate=false,targetHits=3}");
+
+ assertParse("select foo from bar where [{\"targetHits\": 7, \"distanceThreshold\": 100100.25}]nearestNeighbor(semantic_embedding, my_vector);",
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,distanceThreshold=100100.25,approximate=true,targetHits=7}");
+
}
@Test
diff --git a/container-search/src/test/java/com/yahoo/select/SelectTestCase.java b/container-search/src/test/java/com/yahoo/select/SelectTestCase.java
index c802eb18c0f..3239a97a094 100644
--- a/container-search/src/test/java/com/yahoo/select/SelectTestCase.java
+++ b/container-search/src/test/java/com/yahoo/select/SelectTestCase.java
@@ -537,10 +537,10 @@ public class SelectTestCase {
@Test
public void testNearestNeighbor() {
assertParse("{ \"nearestNeighbor\": [ \"f1field\", \"q2prop\" ] }",
- "NEAREST_NEIGHBOR {field=f1field,queryTensorName=q2prop,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=0}");
+ "NEAREST_NEIGHBOR {field=f1field,queryTensorName=q2prop,hnsw.exploreAdditionalHits=0,distanceThreshold=Infinity,approximate=true,targetHits=0}");
- assertParse("{ \"nearestNeighbor\": { \"children\" : [ \"f3field\", \"q4prop\" ], \"attributes\" : {\"targetHits\": 37} }}",
- "NEAREST_NEIGHBOR {field=f3field,queryTensorName=q4prop,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=37}");
+ assertParse("{ \"nearestNeighbor\": { \"children\" : [ \"f3field\", \"q4prop\" ], \"attributes\" : {\"targetHits\": 37, \"hnsw.exploreAdditionalHits\": 42, \"distanceThreshold\": 100100.25 } }}",
+ "NEAREST_NEIGHBOR {field=f3field,queryTensorName=q4prop,hnsw.exploreAdditionalHits=42,distanceThreshold=100100.25,approximate=true,targetHits=37}");
}
@Test