diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-01-09 14:52:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-01-09 14:52:02 +0100 |
commit | 3504451650259601f0f1c2dfdf8c406253e02aab (patch) | |
tree | 7fea3bbee4370b6e776a6f1971d0cecc8acad484 /container-search | |
parent | c584a25c487525b4325df7760e2d685e2832de4b (diff) | |
parent | feef42f47a71d1bc9028d2151d6725e4917db1bd (diff) |
Merge pull request #15950 from vespa-engine/arnej/add-distance-threshold
Arnej/add distance threshold
Diffstat (limited to 'container-search')
9 files changed, 51 insertions, 8 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 6f48ae5b41a..e2285a10672 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -884,10 +884,12 @@ "public void <init>(java.lang.String, java.lang.String)", "public int getTargetNumHits()", "public java.lang.String getIndexName()", + "public double getDistanceThreshold()", "public int getHnswExploreAdditionalHits()", "public boolean getAllowApproximate()", "public java.lang.String getQueryTensorName()", "public void setTargetNumHits(int)", + "public void setDistanceThreshold(double)", "public void setHnswExploreAdditionalHits(int)", "public void setAllowApproximate(boolean)", "public void setIndexName(java.lang.String)", diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java index e237463582f..bb95cbad178 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java @@ -22,6 +22,7 @@ public class NearestNeighborItem extends SimpleTaggableItem { private int targetNumHits = 0; private int hnswExploreAdditionalHits = 0; + private double distanceThreshold = Double.POSITIVE_INFINITY; private boolean approximate = true; private String field; private final String queryTensorName; @@ -37,6 +38,9 @@ public class NearestNeighborItem extends SimpleTaggableItem { /** Returns the field name */ public String getIndexName() { return field; } + /** Returns the distance threshold for nearest-neighbor hits */ + public double getDistanceThreshold () { return this.distanceThreshold ; } + /** Returns the number of extra hits to explore in HNSW algorithm */ public int getHnswExploreAdditionalHits() { return hnswExploreAdditionalHits; } @@ -49,6 +53,9 @@ public class NearestNeighborItem extends SimpleTaggableItem { /** Set the K number of hits to produce */ public void setTargetNumHits(int target) { this.targetNumHits = target; } + /** Set the distance threshold for nearest-neighbor hits */ + public void setDistanceThreshold(double threshold) { this.distanceThreshold = threshold; } + /** Set the number of extra hits to explore in HNSW algorithm */ public void setHnswExploreAdditionalHits(int num) { this.hnswExploreAdditionalHits = num; } @@ -72,9 +79,18 @@ public class NearestNeighborItem extends SimpleTaggableItem { super.encodeThis(buffer); putString(field, buffer); putString(queryTensorName, buffer); + int approxNum = (approximate ? 1 : 0); + // should become always-true later: + boolean sendDistanceThreshold = (distanceThreshold < Double.POSITIVE_INFINITY); + if (sendDistanceThreshold) { + approxNum |= 0x40; // temporary flag bit + } IntegerCompressor.putCompressedPositiveNumber(targetNumHits, buffer); - IntegerCompressor.putCompressedPositiveNumber((approximate ? 1 : 0), buffer); + IntegerCompressor.putCompressedPositiveNumber(approxNum, buffer); IntegerCompressor.putCompressedPositiveNumber(hnswExploreAdditionalHits, buffer); + if (sendDistanceThreshold) { + buffer.putDouble(distanceThreshold); + } return 1; // number of encoded stack dump items } @@ -83,6 +99,7 @@ public class NearestNeighborItem extends SimpleTaggableItem { buffer.append("{field=").append(field); buffer.append(",queryTensorName=").append(queryTensorName); buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits); + buffer.append(",distanceThreshold=").append(distanceThreshold); buffer.append(",approximate=").append(approximate); buffer.append(",targetHits=").append(targetNumHits).append("}"); } @@ -93,6 +110,7 @@ public class NearestNeighborItem extends SimpleTaggableItem { discloser.addProperty("field", field); discloser.addProperty("queryTensorName", queryTensorName); discloser.addProperty("hnsw.exploreAdditionalHits", hnswExploreAdditionalHits); + discloser.addProperty("distanceThreshold", distanceThreshold); discloser.addProperty("approximate", approximate); discloser.addProperty("targetHits", targetNumHits); } diff --git a/container-search/src/main/java/com/yahoo/search/query/SelectParser.java b/container-search/src/main/java/com/yahoo/search/query/SelectParser.java index 30d741f465c..5f1f26b77e9 100644 --- a/container-search/src/main/java/com/yahoo/search/query/SelectParser.java +++ b/container-search/src/main/java/com/yahoo/search/query/SelectParser.java @@ -78,6 +78,7 @@ import static com.yahoo.search.yql.YqlParser.CONNECTIVITY; import static com.yahoo.search.yql.YqlParser.DEFAULT_TARGET_NUM_HITS; import static com.yahoo.search.yql.YqlParser.DESCENDING_HITS_ORDER; import static com.yahoo.search.yql.YqlParser.DISTANCE; +import static com.yahoo.search.yql.YqlParser.DISTANCE_THRESHOLD; import static com.yahoo.search.yql.YqlParser.DOT_PRODUCT; import static com.yahoo.search.yql.YqlParser.EQUIV; import static com.yahoo.search.yql.YqlParser.FILTER; @@ -481,6 +482,10 @@ public class SelectParser implements Parser { if (TARGET_NUM_HITS.equals(annotation_name)){ item.setTargetNumHits((int)(annotation_value.asDouble())); } + if (DISTANCE_THRESHOLD.equals(annotation_name)) { + double distanceThreshold = annotation_value.asDouble(); + item.setDistanceThreshold(distanceThreshold); + } if (HNSW_EXPLORE_ADDITIONAL_HITS.equals(annotation_name)) { int hnswExploreAdditionalHits = (int)(annotation_value.asDouble()); item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits); diff --git a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java index a38e48fd89d..f4a36ea51ab 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java +++ b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java @@ -723,7 +723,13 @@ public class VespaSerializer { destination.append(leafAnnotations(item)); comma(destination, initLen); int targetNumHits = item.getTargetNumHits(); - annotationKey(destination, "targetNumHits").append(targetNumHits); + annotationKey(destination, YqlParser.TARGET_NUM_HITS).append(targetNumHits); + double distanceThreshold = item.getDistanceThreshold(); + if (distanceThreshold < Double.POSITIVE_INFINITY) { + comma(destination, initLen); + String key = YqlParser.DISTANCE_THRESHOLD; + annotationKey(destination, key).append(distanceThreshold); + } int explore = item.getHnswExploreAdditionalHits(); if (explore != 0) { comma(destination, initLen); diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java index 739aae0e277..f37aeb4c1e0 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java +++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java @@ -156,6 +156,7 @@ public class YqlParser implements Parser { public static final String FILTER = "filter"; public static final String GEO_LOCATION = "geoLocation"; public static final String HIT_LIMIT = "hitLimit"; + public static final String DISTANCE_THRESHOLD = "distanceThreshold"; public static final String HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.exploreAdditionalHits"; public static final String IMPLICIT_TRANSFORMS = "implicitTransforms"; public static final String LABEL = "label"; @@ -459,6 +460,11 @@ public class YqlParser implements Parser { if (targetNumHits != null) { item.setTargetNumHits(targetNumHits); } + Double distanceThreshold = getAnnotation(ast, DISTANCE_THRESHOLD, + Double.class, null, "maximum distance allowed from query point"); + if (distanceThreshold != null) { + item.setDistanceThreshold(distanceThreshold); + } Integer hnswExploreAdditionalHits = getAnnotation(ast, HNSW_EXPLORE_ADDITIONAL_HITS, Integer.class, null, "number of extra hits to explore for HNSW algorithm"); if (hnswExploreAdditionalHits != null) { diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index c49603737a6..72956b5b6eb 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -138,6 +138,7 @@ public class ValidateNearestNeighborTestCase { r.append("field=").append(field); r.append(",queryTensorName=").append(qt); r.append(",hnsw.exploreAdditionalHits=0"); + r.append(",distanceThreshold=Infinity"); r.append(",approximate=true"); r.append(",targetHits=").append(th); r.append("} ").append(errmsg); diff --git a/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java index 63840b0f5ec..a44a9f25b62 100644 --- a/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java @@ -139,6 +139,7 @@ public class VespaSerializerTestCase { parseAndConfirm("[{\"targetNumHits\": 1, \"hnsw.exploreAdditionalHits\": 76}]nearestNeighbor(semantic_embedding, my_property)"); parseAndConfirm("[{\"targetNumHits\": 2, \"approximate\": false}]nearestNeighbor(semantic_embedding, my_property)"); parseAndConfirm("[{\"targetNumHits\": 3, \"hnsw.exploreAdditionalHits\": 67, \"approximate\": false}]nearestNeighbor(semantic_embedding, my_property)"); + parseAndConfirm("[{\"targetNumHits\": 4, \"distanceThreshold\": 100100.25}]nearestNeighbor(semantic_embedding, my_property)"); } @Test diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java index f5e22e30f45..2d88351f9ea 100644 --- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java @@ -568,11 +568,15 @@ public class YqlParserTestCase { @Test public void testNearestNeighbor() { assertParse("select foo from bar where nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=0}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,distanceThreshold=Infinity,approximate=true,targetHits=0}"); assertParse("select foo from bar where [{\"targetHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=37}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,distanceThreshold=Infinity,approximate=true,targetHits=37}"); assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetHits=3}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,distanceThreshold=Infinity,approximate=false,targetHits=3}"); + + assertParse("select foo from bar where [{\"targetHits\": 7, \"distanceThreshold\": 100100.25}]nearestNeighbor(semantic_embedding, my_vector);", + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,distanceThreshold=100100.25,approximate=true,targetHits=7}"); + } @Test diff --git a/container-search/src/test/java/com/yahoo/select/SelectTestCase.java b/container-search/src/test/java/com/yahoo/select/SelectTestCase.java index c802eb18c0f..3239a97a094 100644 --- a/container-search/src/test/java/com/yahoo/select/SelectTestCase.java +++ b/container-search/src/test/java/com/yahoo/select/SelectTestCase.java @@ -537,10 +537,10 @@ public class SelectTestCase { @Test public void testNearestNeighbor() { assertParse("{ \"nearestNeighbor\": [ \"f1field\", \"q2prop\" ] }", - "NEAREST_NEIGHBOR {field=f1field,queryTensorName=q2prop,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=0}"); + "NEAREST_NEIGHBOR {field=f1field,queryTensorName=q2prop,hnsw.exploreAdditionalHits=0,distanceThreshold=Infinity,approximate=true,targetHits=0}"); - assertParse("{ \"nearestNeighbor\": { \"children\" : [ \"f3field\", \"q4prop\" ], \"attributes\" : {\"targetHits\": 37} }}", - "NEAREST_NEIGHBOR {field=f3field,queryTensorName=q4prop,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=37}"); + assertParse("{ \"nearestNeighbor\": { \"children\" : [ \"f3field\", \"q4prop\" ], \"attributes\" : {\"targetHits\": 37, \"hnsw.exploreAdditionalHits\": 42, \"distanceThreshold\": 100100.25 } }}", + "NEAREST_NEIGHBOR {field=f3field,queryTensorName=q4prop,hnsw.exploreAdditionalHits=42,distanceThreshold=100100.25,approximate=true,targetHits=37}"); } @Test |