From f3c29bf22e68691867934fed7635e32f31712654 Mon Sep 17 00:00:00 2001 From: Marcus Eltscheminov Date: Mon, 17 May 2021 22:34:24 +0200 Subject: Support nearest neighbor query --- .../src/main/java/ai/vespa/client/dsl/Field.java | 29 ++++++++++++++++++++++ .../test/groovy/ai/vespa/client/dsl/QTest.groovy | 20 +++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/client/src/main/java/ai/vespa/client/dsl/Field.java b/client/src/main/java/ai/vespa/client/dsl/Field.java index c540e844c7a..cc30b8aaded 100644 --- a/client/src/main/java/ai/vespa/client/dsl/Field.java +++ b/client/src/main/java/ai/vespa/client/dsl/Field.java @@ -550,6 +550,29 @@ public class Field extends QueryChain { return common("=", annotation, false); } + /** + * Nearest neighbor query. + * https://docs.vespa.ai/en/reference/query-language-reference.html#nearestneighbor + * + * @param rankFeature the rankfeature. + * @return the query + */ + public Query nearestNeighbor(String rankFeature) { + return common("nearestNeighbor", annotation, (Object) rankFeature); + } + + /** + * Nearest neighbor query. + * https://docs.vespa.ai/en/reference/query-language-reference.html#nearestneighbor + * + * @param annotation the annotation + * @param rankFeature the rankfeature. + * @return the query + */ + public Query nearestNeighbor(Annotation annotation, String rankFeature) { + return common("nearestNeighbor", annotation, (Object) rankFeature); + } + private Query common(String relation, Annotation annotation, Object value) { return common(relation, annotation, value, values.toArray()); } @@ -604,6 +627,12 @@ public class Field extends QueryChain { case "sameElement": return String.format("%s contains %s(%s)", fieldName, relation, ((Query) values.get(0)).toCommaSeparatedAndQueries()); + case "nearestNeighbor": + valuesStr = values.stream().map(i -> (String) i).collect(Collectors.joining(", ")); + + return hasAnnotation + ? String.format("([%s]nearestNeighbor(%s, %s))", annotation, fieldName, valuesStr) + : String.format("nearestNeighbor(%s, %s)", fieldName, valuesStr); default: Object value = values.get(0); valuesStr = value instanceof Long ? value.toString() + "L" : value.toString(); diff --git a/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy b/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy index 671405a9c73..d1560937fef 100644 --- a/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy +++ b/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy @@ -424,6 +424,26 @@ class QTest extends Specification { q == """yql=select * from sources * where f1 contains ([{"key":"value"}]uri("https://test.uri"));""" } + def "nearestNeighbor"() { + given: + def q = Q.p("f1").nearestNeighbor("query_vector") + .semicolon() + .build() + + expect: + q == """yql=select * from sources * where nearestNeighbor(f1, query_vector);""" + } + + def "nearestNeighbor with annotation"() { + given: + def q = Q.p("f1").nearestNeighbor(A.a("targetHits", 10), "query_vector") + .semicolon() + .build() + + expect: + q == """yql=select * from sources * where ([{"targetHits":10}]nearestNeighbor(f1, query_vector));""" + } + def "use contains instead of contains equiv when input size is 1"() { def q = Q.p("f1").containsEquiv(["p1"]) .semicolon() -- cgit v1.2.3