aboutsummaryrefslogtreecommitdiffstats
path: root/client/src
diff options
context:
space:
mode:
authoryehzu <yehzu2@gmail.com>2021-09-08 22:29:13 +0800
committeryehzu <yehzu2@gmail.com>2021-09-08 22:40:45 +0800
commit43b76205b8622f07069b96bd4d061a219154c705 (patch)
tree4037b183082acd7cbf53f74045980c5718031456 /client/src
parent022a5c2a87b22d77911e0a6eef02f872ddcc0b15 (diff)
feat: support nearestNeighbor operator
Diffstat (limited to 'client/src')
-rw-r--r--client/src/main/java/ai/vespa/client/dsl/Annotation.java4
-rw-r--r--client/src/main/java/ai/vespa/client/dsl/NearestNeighbor.java52
-rw-r--r--client/src/main/java/ai/vespa/client/dsl/Q.java12
-rw-r--r--client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy28
4 files changed, 96 insertions, 0 deletions
diff --git a/client/src/main/java/ai/vespa/client/dsl/Annotation.java b/client/src/main/java/ai/vespa/client/dsl/Annotation.java
index 906f2abcca0..1949bc7d3f9 100644
--- a/client/src/main/java/ai/vespa/client/dsl/Annotation.java
+++ b/client/src/main/java/ai/vespa/client/dsl/Annotation.java
@@ -20,6 +20,10 @@ public class Annotation {
return this;
}
+ public boolean contains(String key) {
+ return annotations.containsKey(key);
+ }
+
@Override
public String toString() {
return annotations == null || annotations.isEmpty()
diff --git a/client/src/main/java/ai/vespa/client/dsl/NearestNeighbor.java b/client/src/main/java/ai/vespa/client/dsl/NearestNeighbor.java
new file mode 100644
index 00000000000..6c95d2b6fd7
--- /dev/null
+++ b/client/src/main/java/ai/vespa/client/dsl/NearestNeighbor.java
@@ -0,0 +1,52 @@
+package ai.vespa.client.dsl;
+
+import java.util.stream.Collectors;
+
+public class NearestNeighbor extends QueryChain {
+
+ private Annotation annotation;
+ private String docVectorName;
+ private String queryVectorName;
+
+
+ public NearestNeighbor(String docVectorName, String queryVectorName) {
+ this.docVectorName = docVectorName;
+ this.queryVectorName = queryVectorName;
+ this.nonEmpty = true;
+ }
+
+ NearestNeighbor annotate(Annotation annotation) {
+ this.annotation = annotation;
+ return this;
+ }
+
+ @Override
+ boolean hasPositiveSearchField(String fieldName) {
+ return this.docVectorName.equals(fieldName);
+ }
+
+ @Override
+ boolean hasPositiveSearchField(String fieldName, Object value) {
+ return this.docVectorName.equals(fieldName) && queryVectorName.equals(value);
+ }
+
+ @Override
+ boolean hasNegativeSearchField(String fieldName) {
+ return false;
+ }
+
+ @Override
+ boolean hasNegativeSearchField(String fieldName, Object value) {
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ boolean hasAnnotation = A.hasAnnotation(annotation);
+ if (!hasAnnotation || !annotation.contains("targetHits")) {
+ throw new IllegalArgumentException("must specify target hits in nearest neighbor query");
+ }
+ String s = Text.format("nearestNeighbor(%s, %s)", docVectorName, queryVectorName);
+ return Text.format("([%s]%s)", annotation, s);
+ }
+}
diff --git a/client/src/main/java/ai/vespa/client/dsl/Q.java b/client/src/main/java/ai/vespa/client/dsl/Q.java
index 2d957dcfb92..f15ffed1ea9 100644
--- a/client/src/main/java/ai/vespa/client/dsl/Q.java
+++ b/client/src/main/java/ai/vespa/client/dsl/Q.java
@@ -183,4 +183,16 @@ public final class Q {
public static GeoLocation geoLocation(String field, Double longitude, Double latitude, String radius) {
return new GeoLocation(field, longitude, latitude, radius);
}
+
+ /**
+ * NearestNeighbor nearest neighbor
+ * https://docs.vespa.ai/en/reference/query-language-reference.html#nearestneighbor
+ *
+ * @param docVectorName the vector name defined in the vespa schema
+ * @param queryVectorName the vector name in this query
+ * @return the nearest neighbor query
+ */
+ public static NearestNeighbor nearestNeighbor(String docVectorName, String queryVectorName) {
+ return new NearestNeighbor(docVectorName, queryVectorName);
+ }
}
diff --git a/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy b/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy
index e07e5d6cefc..1bada4e8f59 100644
--- a/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy
+++ b/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy
@@ -256,6 +256,34 @@ class QTest extends Specification {
q == """yql=select * from sd1 where a contains "b" and geoLocation(taiwan, 25.105497, 121.597366, "200km");"""
}
+ def "nearest neighbor query"() {
+ when:
+ def q = Q.select("*")
+ .from("sd1")
+ .where("a").contains("b")
+ .and(Q.nearestNeighbor("vec1", "vec2")
+ .annotate(A.a("targetHits", 10, "approximate", false))
+ )
+ .semicolon()
+ .build()
+
+ then:
+ q == """yql=select * from sd1 where a contains "b" and ([{"approximate":false,"targetHits":10}]nearestNeighbor(vec1, vec2));"""
+ }
+
+ def "invalid nearest neighbor should throws an exception (targetHits annotation is required)"() {
+ when:
+ def q = Q.select("*")
+ .from("sd1")
+ .where("a").contains("b").and(Q.nearestNeighbor("vec1", "vec2"))
+ .semicolon()
+ .build()
+
+ then:
+ thrown(IllegalArgumentException)
+ }
+
+
def "rank with only query"() {
given:
def q = Q.select("*")