diff options
author | Jon Bratseth <bratseth@oath.com> | 2021-09-13 23:21:16 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-13 23:21:16 +0200 |
commit | 407296d4091a5cfd41d0a51f1f1a518cb35fc373 (patch) | |
tree | 3d44d3bb51b455b4eed27c277a52e79effd8c19b | |
parent | 1f5a503f7a3ec8dd54e9accbc34c49a5c84bed51 (diff) | |
parent | 43b76205b8622f07069b96bd4d061a219154c705 (diff) |
Merge pull request #19021 from yehzu/new_syntax
add supports of geoLocation and nearestNeighbor operators to client library
5 files changed, 166 insertions, 0 deletions
diff --git a/client/src/main/java/ai/vespa/client/dsl/Annotation.java b/client/src/main/java/ai/vespa/client/dsl/Annotation.java index 906f2abcca0..1949bc7d3f9 100644 --- a/client/src/main/java/ai/vespa/client/dsl/Annotation.java +++ b/client/src/main/java/ai/vespa/client/dsl/Annotation.java @@ -20,6 +20,10 @@ public class Annotation { return this; } + public boolean contains(String key) { + return annotations.containsKey(key); + } + @Override public String toString() { return annotations == null || annotations.isEmpty() diff --git a/client/src/main/java/ai/vespa/client/dsl/GeoLocation.java b/client/src/main/java/ai/vespa/client/dsl/GeoLocation.java new file mode 100644 index 00000000000..c0d8fabc42f --- /dev/null +++ b/client/src/main/java/ai/vespa/client/dsl/GeoLocation.java @@ -0,0 +1,44 @@ +package ai.vespa.client.dsl; + +import org.apache.commons.text.StringEscapeUtils; + +public class GeoLocation extends QueryChain { + + private String fieldName; + private Double longitude; + private Double latitude; + private String radius; + + public GeoLocation(String fieldName, Double longitude, Double latitude, String radius) { + this.fieldName = fieldName; + this.longitude = longitude; + this.latitude = latitude; + this.radius = radius; + this.nonEmpty = true; + } + + @Override + boolean hasPositiveSearchField(String fieldName) { + return this.fieldName.equals(fieldName); + } + + @Override + boolean hasPositiveSearchField(String fieldName, Object value) { + return false; + } + + @Override + boolean hasNegativeSearchField(String fieldName) { + return false; + } + + @Override + boolean hasNegativeSearchField(String fieldName, Object value) { + return false; + } + + @Override + public String toString() { + return Text.format("geoLocation(%s, %f, %f, \"%s\")", fieldName, longitude, latitude, StringEscapeUtils.escapeJava(radius)); + } +} diff --git a/client/src/main/java/ai/vespa/client/dsl/NearestNeighbor.java b/client/src/main/java/ai/vespa/client/dsl/NearestNeighbor.java new file mode 100644 index 00000000000..6c95d2b6fd7 --- /dev/null +++ b/client/src/main/java/ai/vespa/client/dsl/NearestNeighbor.java @@ -0,0 +1,52 @@ +package ai.vespa.client.dsl; + +import java.util.stream.Collectors; + +public class NearestNeighbor extends QueryChain { + + private Annotation annotation; + private String docVectorName; + private String queryVectorName; + + + public NearestNeighbor(String docVectorName, String queryVectorName) { + this.docVectorName = docVectorName; + this.queryVectorName = queryVectorName; + this.nonEmpty = true; + } + + NearestNeighbor annotate(Annotation annotation) { + this.annotation = annotation; + return this; + } + + @Override + boolean hasPositiveSearchField(String fieldName) { + return this.docVectorName.equals(fieldName); + } + + @Override + boolean hasPositiveSearchField(String fieldName, Object value) { + return this.docVectorName.equals(fieldName) && queryVectorName.equals(value); + } + + @Override + boolean hasNegativeSearchField(String fieldName) { + return false; + } + + @Override + boolean hasNegativeSearchField(String fieldName, Object value) { + return false; + } + + @Override + public String toString() { + boolean hasAnnotation = A.hasAnnotation(annotation); + if (!hasAnnotation || !annotation.contains("targetHits")) { + throw new IllegalArgumentException("must specify target hits in nearest neighbor query"); + } + String s = Text.format("nearestNeighbor(%s, %s)", docVectorName, queryVectorName); + return Text.format("([%s]%s)", annotation, s); + } +} diff --git a/client/src/main/java/ai/vespa/client/dsl/Q.java b/client/src/main/java/ai/vespa/client/dsl/Q.java index 7637f76f095..f15ffed1ea9 100644 --- a/client/src/main/java/ai/vespa/client/dsl/Q.java +++ b/client/src/main/java/ai/vespa/client/dsl/Q.java @@ -169,4 +169,30 @@ public final class Q { public static WeakAnd weakand(String field, Query query) { return new WeakAnd(field, query); } + + /** + * GeoLocation geo locatoin + * https://docs.vespa.ai/en/reference/query-language-reference.html#geoLocation + * + * @param field the field + * @param longitude longitude + * @param latitude latitude + * @param radius a string specifying the radius and it's unit + * @return the geo-location query + */ + public static GeoLocation geoLocation(String field, Double longitude, Double latitude, String radius) { + return new GeoLocation(field, longitude, latitude, radius); + } + + /** + * NearestNeighbor nearest neighbor + * https://docs.vespa.ai/en/reference/query-language-reference.html#nearestneighbor + * + * @param docVectorName the vector name defined in the vespa schema + * @param queryVectorName the vector name in this query + * @return the nearest neighbor query + */ + public static NearestNeighbor nearestNeighbor(String docVectorName, String queryVectorName) { + return new NearestNeighbor(docVectorName, queryVectorName); + } } diff --git a/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy b/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy index d1560937fef..1bada4e8f59 100644 --- a/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy +++ b/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy @@ -244,6 +244,46 @@ class QTest extends Specification { q == """yql=select * from sd1 where weakAnd(f1, f1 contains "v1", f2 contains "v2") and ([{"scoreThreshold":0.13}]weakAnd(f3, f1 contains "v1", f2 contains "v2"));""" } + def "geo location"() { + given: + def q = Q.select("*") + .from("sd1") + .where("a").contains("b").and(Q.geoLocation("taiwan", 25.105497, 121.597366, "200km")) + .semicolon() + .build() + + expect: + q == """yql=select * from sd1 where a contains "b" and geoLocation(taiwan, 25.105497, 121.597366, "200km");""" + } + + def "nearest neighbor query"() { + when: + def q = Q.select("*") + .from("sd1") + .where("a").contains("b") + .and(Q.nearestNeighbor("vec1", "vec2") + .annotate(A.a("targetHits", 10, "approximate", false)) + ) + .semicolon() + .build() + + then: + q == """yql=select * from sd1 where a contains "b" and ([{"approximate":false,"targetHits":10}]nearestNeighbor(vec1, vec2));""" + } + + def "invalid nearest neighbor should throws an exception (targetHits annotation is required)"() { + when: + def q = Q.select("*") + .from("sd1") + .where("a").contains("b").and(Q.nearestNeighbor("vec1", "vec2")) + .semicolon() + .build() + + then: + thrown(IllegalArgumentException) + } + + def "rank with only query"() { given: def q = Q.select("*") |