diff options
47 files changed, 1758 insertions, 875 deletions
diff --git a/client/pom.xml b/client/pom.xml index 3dee909b932..ea33b9f3adf 100644 --- a/client/pom.xml +++ b/client/pom.xml @@ -28,15 +28,8 @@ <version>1.6</version> </dependency> <dependency> - <groupId>org.spockframework</groupId> - <artifactId>spock-core</artifactId> - <version>1.3-groovy-2.5</version> - <scope>test</scope> - </dependency> - <dependency> - <groupId>org.codehaus.groovy</groupId> - <artifactId>groovy</artifactId> - <version>3.0.8</version> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter</artifactId> <scope>test</scope> </dependency> </dependencies> @@ -44,19 +37,6 @@ <build> <plugins> <plugin> - <groupId>org.codehaus.gmavenplus</groupId> - <artifactId>gmavenplus-plugin</artifactId> - <version>1.13.0</version> - <executions> - <execution> - <goals> - <goal>addTestSources</goal> - <goal>compileTests</goal> - </goals> - </execution> - </executions> - </plugin> - <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> <configuration> diff --git a/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy b/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy deleted file mode 100644 index 0d6e2ca3506..00000000000 --- a/client/src/test/groovy/ai/vespa/client/dsl/QTest.groovy +++ /dev/null @@ -1,677 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.client.dsl - -import spock.lang.Specification - -class QTest extends Specification { - - def "select specific fields"() { - given: - def q = Q.select("f1", "f2") - .from("sd1") - .where("f1").contains("v1") - .semicolon() - .build() - - expect: - q == """yql=select f1, f2 from sd1 where f1 contains "v1";""" - } - - def "select from specific sources"() { - given: - def q = Q.select("*") - .from("sd1") - .where("f1").contains("v1") - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where f1 contains "v1";""" - } - - def "select from multiples sources"() { - given: - def q = Q.select("*") - .from("sd1", "sd2") - .where("f1").contains("v1") - .semicolon() - .build() - - expect: - q == """yql=select * from sources sd1, sd2 where f1 contains "v1";""" - } - - def "basic 'and', 'andnot', 'or', 'offset', 'limit', 'param', 'order by', and 'contains'"() { - given: - def q = Q.select("*") - .from("sd1") - .where("f1").contains("v1") - .and("f2").contains("v2") - .or("f3").contains("v3") - .andnot("f4").contains("v4") - .offset(1) - .limit(2) - .timeout(3) - .orderByDesc("f1") - .orderByAsc("f2") - .semicolon() - .param("paramk1", "paramv1") - .build() - - expect: - q == """yql=select * from sd1 where f1 contains "v1" and f2 contains "v2" or f3 contains "v3" and !(f4 contains "v4") order by f1 desc, f2 asc limit 2 offset 1 timeout 3;¶mk1=paramv1""" - } - - def "matches"() { - given: - def q = Q.select("*") - .from("sd1") - .where("f1").matches("v1") - .and("f2").matches("v2") - .or("f3").matches("v3") - .andnot("f4").matches("v4") - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where f1 matches "v1" and f2 matches "v2" or f3 matches "v3" and !(f4 matches "v4");""" - } - - def "numeric operations"() { - given: - def q = Q.select("*") - .from("sd1") - .where("f1").le(1) - .and("f2").lt(2) - .and("f3").ge(3) - .and("f4").gt(4) - .and("f5").eq(5) - .and("f6").inRange(6, 7) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where f1 <= 1 and f2 < 2 and f3 >= 3 and f4 > 4 and f5 = 5 and range(f6, 6, 7);""" - } - - def "long numeric operations"() { - given: - def q = Q.select("*") - .from("sd1") - .where("f1").le(1L) - .and("f2").lt(2L) - .and("f3").ge(3L) - .and("f4").gt(4L) - .and("f5").eq(5L) - .and("f6").inRange(6L, 7L) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where f1 <= 1L and f2 < 2L and f3 >= 3L and f4 > 4L and f5 = 5L and range(f6, 6L, 7L);""" - } - - def "float numeric operations"() { - given: - def q = Q.select("*") - .from("sd1") - .where("f1").le(1.1) - .and("f2").lt(2.2) - .and("f3").ge(3.3) - .and("f4").gt(4.4) - .and("f5").eq(5.5) - .and("f6").inRange(6.6, 7.7) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where f1 <= 1.1 and f2 < 2.2 and f3 >= 3.3 and f4 > 4.4 and f5 = 5.5 and range(f6, 6.6, 7.7);""" - } - - def "double numeric operations"() { - given: - def q = Q.select("*") - .from("sd1") - .where("f1").le(1.1D) - .and("f2").lt(2.2D) - .and("f3").ge(3.3D) - .and("f4").gt(4.4D) - .and("f5").eq(5.5D) - .and("f6").inRange(6.6D, 7.7D) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where f1 <= 1.1 and f2 < 2.2 and f3 >= 3.3 and f4 > 4.4 and f5 = 5.5 and range(f6, 6.6, 7.7);""" - } - - def "nested queries"() { - given: - def q = Q.select("*") - .from("sd1") - .where("f1").contains("1") - .andnot(Q.p(Q.p("f2").contains("2").and("f3").contains("3")) - .or(Q.p("f2").contains("4").andnot("f3").contains("5"))) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where f1 contains "1" and !((f2 contains "2" and f3 contains "3") or (f2 contains "4" and !(f3 contains "5")));""" - } - - def "userInput (with and with out defaultIndex)"() { - given: - def q = Q.select("*") - .from("sd1") - .where(Q.ui("value")) - .and(Q.ui("index", "value2")) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where userInput(@_1) and ([{"defaultIndex":"index"}]userInput(@_2_index));&_2_index=value2&_1=value""" - } - - def "dot product"() { - given: - def q = Q.select("*") - .from("sd1") - .where(Q.dotPdt("f1", [a: 1, b: 2, c: 3])) - .and("f2").contains("1") - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where dotProduct(f1, {"a":1,"b":2,"c":3}) and f2 contains "1";""" - } - - def "weighted set"() { - given: - def q = Q.select("*") - .from("sd1") - .where(Q.wtdSet("f1", [a: 1, b: 2, c: 3])) - .and("f2").contains("1") - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where weightedSet(f1, {"a":1,"b":2,"c":3}) and f2 contains "1";""" - } - - def "non empty"() { - given: - def q = Q.select("*") - .from("sd1") - .where(Q.nonEmpty(Q.p("f1").contains("v1"))) - .and("f2").contains("v2") - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where nonEmpty(f1 contains "v1") and f2 contains "v2";""" - } - - - def "wand (with and without annotation)"() { - given: - def q = Q.select("*") - .from("sd1") - .where(Q.wand("f1", [a: 1, b: 2, c: 3])) - .and(Q.wand("f2", [[1, 1], [2, 2]])) - .and( - Q.wand("f3", [[1, 1], [2, 2]]) - .annotate(A.a("scoreThreshold", 0.13)) - ) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where wand(f1, {"a":1,"b":2,"c":3}) and wand(f2, [[1,1],[2,2]]) and ([{"scoreThreshold":0.13}]wand(f3, [[1,1],[2,2]]));""" - } - - def "weak and (with and without annotation)"() { - given: - def q = Q.select("*") - .from("sd1") - .where(Q.weakand(Q.p("f1").contains("v1").and("f2").contains("v2"))) - .and(Q.weakand(Q.p("f1").contains("v1").and("f2").contains("v2")) - .annotate(A.a("scoreThreshold", 0.13)) - ) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where weakAnd(f1 contains "v1", f2 contains "v2") and ([{"scoreThreshold":0.13}]weakAnd(f1 contains "v1", f2 contains "v2"));""" - } - - def "geo location"() { - given: - def q = Q.select("*") - .from("sd1") - .where("a").contains("b").and(Q.geoLocation("taiwan", 25.105497, 121.597366, "200km")) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where a contains "b" and geoLocation(taiwan, 25.105497, 121.597366, "200km");""" - } - - def "nearest neighbor query"() { - when: - def q = Q.select("*") - .from("sd1") - .where("a").contains("b") - .and(Q.nearestNeighbor("vec1", "vec2") - .annotate(A.a("targetHits", 10, "approximate", false)) - ) - .semicolon() - .build() - - then: - q == """yql=select * from sd1 where a contains "b" and ([{"approximate":false,"targetHits":10}]nearestNeighbor(vec1, vec2));""" - } - - def "invalid nearest neighbor should throws an exception (targetHits annotation is required)"() { - when: - def q = Q.select("*") - .from("sd1") - .where("a").contains("b").and(Q.nearestNeighbor("vec1", "vec2")) - .semicolon() - .build() - - then: - thrown(IllegalArgumentException) - } - - - def "rank with only query"() { - given: - def q = Q.select("*") - .from("sd1") - .where(Q.rank( - Q.p("f1").contains("v1") - ) - ) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where rank(f1 contains "v1");""" - } - - def "rank"() { - given: - def q = Q.select("*") - .from("sd1") - .where(Q.rank( - Q.p("f1").contains("v1"), - Q.p("f2").contains("v2"), - Q.p("f3").eq(3)) - ) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where rank(f1 contains "v1", f2 contains "v2", f3 = 3);""" - } - - def "rank with rank query array"() { - given: - Query[] ranks = [Q.p("f2").contains("v2"), Q.p("f3").eq(3)].toArray() - def q = Q.select("*") - .from("sd1") - .where(Q.rank( - Q.p("f1").contains("v1"), - ranks) - ) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where rank(f1 contains "v1", f2 contains "v2", f3 = 3);""" - } - - def "string/function annotations"() { - given: - def q = Q.select("*") - .from("sd1") - .where("f1").contains(annotation, "v1") - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where f1 contains (${expected}"v1");""" - - where: - annotation | expected - A.filter() | """[{"filter":true}]""" - A.defaultIndex("idx") | """[{"defaultIndex":"idx"}]""" - A.a([a1: [k1: "v1", k2: 2]]) | """[{"a1":{"k1":"v1","k2":2}}]""" - } - - def "sub-expression annotations"() { - given: - def q = Q.select("*") - .from("sd1") - .where("f1").contains("v1").annotate(A.a("ak1", "av1")) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where ([{"ak1":"av1"}](f1 contains "v1"));""" - } - - def "sub-expressions annotations (annotate in the middle of query)"() { - given: - def q = Q.select("*") - .from("sd1") - .where(Q.p("f1").contains("v1").annotate(A.a("ak1", "av1")).and("f2").contains("v2")) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where ([{"ak1":"av1"}](f1 contains "v1" and f2 contains "v2"));""" - } - - def "sub-expressions annotations (annotate in nested queries)"() { - given: - def q = Q.select("*") - .from("sd1") - .where(Q.p( - Q.p("f1").contains("v1").annotate(A.a("ak1", "av1"))) - .and("f2").contains("v2") - ) - .semicolon() - .build() - - expect: - q == """yql=select * from sd1 where (([{"ak1":"av1"}](f1 contains "v1")) and f2 contains "v2");""" - } - - def "build query which created from Q.b without select and sources"() { - given: - def q = Q.p("f1").contains("v1") - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where f1 contains "v1";""" - } - - def "order by"() { - given: - def q = Q.p("f1").contains("v1") - .orderByAsc("f2") - .orderByAsc(A.a([function: "uca", locale: "en_US", strength: "IDENTICAL"]), "f3") - .orderByDesc("f4") - .orderByDesc(A.a([function: "lowercase"]), "f5") - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where f1 contains "v1" order by f2 asc, [{"function":"uca","locale":"en_US","strength":"IDENTICAL"}]f3 asc, f4 desc, [{"function":"lowercase"}]f5 desc;""" - } - - def "contains sameElement"() { - given: - def q = Q.p("f1").containsSameElement(Q.p("stime").le(1).and("etime").gt(2)) - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where f1 contains sameElement(stime <= 1, etime > 2);""" - } - - def "contains phrase/near/onear/equiv"() { - given: - def funcName = "contains${operator.capitalize()}" - def q1 = Q.p("f1")."$funcName"("p1", "p2", "p3") - .semicolon() - .build() - def q2 = Q.p("f1")."$funcName"(["p1", "p2", "p3"]) - .semicolon() - .build() - - expect: - q1 == """yql=select * from sources * where f1 contains ${operator}("p1", "p2", "p3");""" - q2 == """yql=select * from sources * where f1 contains ${operator}("p1", "p2", "p3");""" - - where: - operator | _ - "phrase" | _ - "near" | _ - "onear" | _ - "equiv" | _ - } - - def "contains uri"() { - given: - def q = Q.p("f1").containsUri("https://test.uri") - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where f1 contains uri("https://test.uri");""" - } - - def "contains uri with annotation"() { - given: - def q = Q.p("f1").containsUri(A.a("key", "value"), "https://test.uri") - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where f1 contains ([{"key":"value"}]uri("https://test.uri"));""" - } - - def "nearestNeighbor"() { - given: - def q = Q.p("f1").nearestNeighbor("query_vector") - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where nearestNeighbor(f1, query_vector);""" - } - - def "nearestNeighbor with annotation"() { - given: - def q = Q.p("f1").nearestNeighbor(A.a("targetHits", 10), "query_vector") - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where ([{"targetHits":10}]nearestNeighbor(f1, query_vector));""" - } - - def "use contains instead of contains equiv when input size is 1"() { - def q = Q.p("f1").containsEquiv(["p1"]) - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where f1 contains "p1";""" - } - - def "contains phrase/near/onear/equiv empty list should throw illegal argument exception"() { - given: - def funcName = "contains${operator.capitalize()}" - - when: - def q = Q.p("f1")."$funcName"([]) - .semicolon() - .build() - - then: - thrown(IllegalArgumentException) - - where: - operator | _ - "phrase" | _ - "near" | _ - "onear" | _ - "equiv" | _ - } - - - def "contains near/onear with annotation"() { - given: - def funcName = "contains${operator.capitalize()}" - def q = Q.p("f1")."$funcName"(A.a("distance", 5), "p1", "p2", "p3") - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where f1 contains ([{"distance":5}]${operator}("p1", "p2", "p3"));""" - - where: - operator | _ - "near" | _ - "onear" | _ - } - - def "basic group syntax"() { - /* - example from vespa document: - https://docs.vespa.ai/en/grouping.html - all( group(a) max(5) each(output(count()) - all(max(1) each(output(summary()))) - all(group(b) each(output(count()) - all(max(1) each(output(summary()))) - all(group(c) each(output(count()) - all(max(1) each(output(summary())))))))) ); - */ - given: - def q = Q.p("f1").contains("v1") - .group( - G.all(G.group("a"), G.maxRtn(5), G.each(G.output(G.count()), - G.all(G.maxRtn(1), G.each(G.output(G.summary()))), - G.all(G.group("b"), G.each(G.output(G.count()), - G.all(G.maxRtn(1), G.each(G.output(G.summary()))), - G.all(G.group("c"), G.each(G.output(G.count()), - G.all(G.maxRtn(1), G.each(G.output(G.summary()))) - )) - )) - )) - ) - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where f1 contains "v1" | all(group(a) max(5) each(output(count()) all(max(1) each(output(summary()))) all(group(b) each(output(count()) all(max(1) each(output(summary()))) all(group(c) each(output(count()) all(max(1) each(output(summary())))))))));""" - } - - def "set group syntax string directly"() { - /* - example from vespa document: - https://docs.vespa.ai/en/grouping.html - all( group(a) max(5) each(output(count()) - all(max(1) each(output(summary()))) - all(group(b) each(output(count()) - all(max(1) each(output(summary()))) - all(group(c) each(output(count()) - all(max(1) each(output(summary())))))))) ); - */ - given: - def q = Q.p("f1").contains("v1") - .group("all(group(a) max(5) each(output(count()) all(max(1) each(output(summary()))) all(group(b) each(output(count()) all(max(1) each(output(summary()))) all(group(c) each(output(count()) all(max(1) each(output(summary())))))))))") - .semicolon() - .build() - - expect: - q == """yql=select * from sources * where f1 contains "v1" | all(group(a) max(5) each(output(count()) all(max(1) each(output(summary()))) all(group(b) each(output(count()) all(max(1) each(output(summary()))) all(group(c) each(output(count()) all(max(1) each(output(summary())))))))));""" - } - - def "arbitrary annotations"() { - given: - def a = A.a("a1", "v1", "a2", 2, "a3", [k: "v", k2: 1], "a4", 4D, "a5", [1, 2, 3]) - expect: - a.toString() == """{"a1":"v1","a2":2,"a3":{"k":"v","k2":1},"a4":4.0,"a5":[1,2,3]}""" - } - - def "test programmability"() { - given: - def map = [a: "1", b: "2", c: "3"] - - when: - Query q = map - .entrySet() - .stream() - .map { entry -> Q.p(entry.key).contains(entry.value) } - .reduce { q1, q2 -> q1.and(q2) } - .get() - - then: - q.semicolon().build() == """yql=select * from sources * where a contains "1" and b contains "2" and c contains "3";""" - } - - def "test programmability 2"() { - given: - def map = [a: "1", b: "2", c: "3"] - def q = Q.p() - - when: - map.each { k, v -> - q.and(Q.p(k).contains(v)) - } - - then: - q.semicolon().build() == """yql=select * from sources * where a contains "1" and b contains "2" and c contains "3";""" - } - - def "empty queries should not print out"() { - given: - def q = Q.p(Q.p(Q.p().andnot(Q.p()).and(Q.p()))).and("a").contains("1").semicolon().build() - - expect: - q == """yql=select * from sources * where a contains "1";""" - } - - def "validate positive search term of strings"() { - given: - def q = Q.p(Q.p("k1").contains("v1").and("k2").contains("v2").andnot("k3").contains("v3")) - .andnot(Q.p("nk1").contains("nv1").and("nk2").contains("nv2").andnot("nk3").contains("nv3")) - .and(Q.p("k4").contains("v4") - .andnot(Q.p("k5").contains("v5").andnot("k6").contains("v6")) - ) - - expect: - q.hasPositiveSearchField("k1") - q.hasPositiveSearchField("k2") - q.hasPositiveSearchField("nk3") - q.hasPositiveSearchField("k6") - q.hasPositiveSearchField("k6", "v6") - !q.hasPositiveSearchField("k6", "v5") - - q.hasNegativeSearchField("k3") - q.hasNegativeSearchField("nk1") - q.hasNegativeSearchField("nk2") - q.hasNegativeSearchField("k5") - q.hasNegativeSearchField("k5", "v5") - !q.hasNegativeSearchField("k5", "v4") - } - - def "validate positive search term of user input"() { - given: - def q = Q.p(Q.ui("k1", "v1")).and(Q.ui("k2", "v2")).andnot(Q.ui("k3", "v3")) - .andnot(Q.p(Q.ui("nk1", "nv1")).and(Q.ui("nk2", "nv2")).andnot(Q.ui("nk3", "nv3"))) - .and(Q.p(Q.ui("k4", "v4")) - .andnot(Q.p(Q.ui("k5", "v5")).andnot(Q.ui("k6", "v6"))) - ) - - expect: - q.hasPositiveSearchField("k1") - q.hasPositiveSearchField("k2") - q.hasPositiveSearchField("nk3") - q.hasPositiveSearchField("k6") - q.hasPositiveSearchField("k6", "v6") - !q.hasPositiveSearchField("k6", "v5") - - q.hasNegativeSearchField("k3") - q.hasNegativeSearchField("nk1") - q.hasNegativeSearchField("nk2") - q.hasNegativeSearchField("k5") - q.hasNegativeSearchField("k5", "v5") - !q.hasNegativeSearchField("k5", "v4") - } -} diff --git a/client/src/test/java/ai/vespa/client/dsl/QTest.java b/client/src/test/java/ai/vespa/client/dsl/QTest.java new file mode 100644 index 00000000000..08ab603fa04 --- /dev/null +++ b/client/src/test/java/ai/vespa/client/dsl/QTest.java @@ -0,0 +1,727 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.client.dsl; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * @author unknown contributor + * @author bjorncs + */ +class QTest { + + @Test + void select_specific_fields() { + String q = Q.select("f1", "f2") + .from("sd1") + .where("f1").contains("v1") + .semicolon() + .build(); + + assertEquals(q, "yql=select f1, f2 from sd1 where f1 contains \"v1\";"); + } + + @Test + void select_from_specific_sources() { + String q = Q.select("*") + .from("sd1") + .where("f1").contains("v1") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 contains \"v1\";"); + } + + @Test + void select_from_multiples_sources() { + String q = Q.select("*") + .from("sd1", "sd2") + .where("f1").contains("v1") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources sd1, sd2 where f1 contains \"v1\";"); + } + + @Test + void basic_and_andnot_or_offset_limit_param_order_by_and_contains() { + String q = Q.select("*") + .from("sd1") + .where("f1").contains("v1") + .and("f2").contains("v2") + .or("f3").contains("v3") + .andnot("f4").contains("v4") + .offset(1) + .limit(2) + .timeout(3) + .orderByDesc("f1") + .orderByAsc("f2") + .semicolon() + .param("paramk1", "paramv1") + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 contains \"v1\" and f2 contains \"v2\" or f3 contains \"v3\" and !(f4 contains \"v4\") order by f1 desc, f2 asc limit 2 offset 1 timeout 3;¶mk1=paramv1"); + } + + @Test + void matches() { + String q = Q.select("*") + .from("sd1") + .where("f1").matches("v1") + .and("f2").matches("v2") + .or("f3").matches("v3") + .andnot("f4").matches("v4") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 matches \"v1\" and f2 matches \"v2\" or f3 matches \"v3\" and !(f4 matches \"v4\");"); + } + + @Test + void numeric_operations() { + String q = Q.select("*") + .from("sd1") + .where("f1").le(1) + .and("f2").lt(2) + .and("f3").ge(3) + .and("f4").gt(4) + .and("f5").eq(5) + .and("f6").inRange(6, 7) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 <= 1 and f2 < 2 and f3 >= 3 and f4 > 4 and f5 = 5 and range(f6, 6, 7);"); + } + + @Test + void long_numeric_operations() { + String q = Q.select("*") + .from("sd1") + .where("f1").le(1L) + .and("f2").lt(2L) + .and("f3").ge(3L) + .and("f4").gt(4L) + .and("f5").eq(5L) + .and("f6").inRange(6L, 7L) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 <= 1L and f2 < 2L and f3 >= 3L and f4 > 4L and f5 = 5L and range(f6, 6L, 7L);"); + } + + @Test + void float_numeric_operations() { + String q = Q.select("*") + .from("sd1") + .where("f1").le(1.1) + .and("f2").lt(2.2) + .and("f3").ge(3.3) + .and("f4").gt(4.4) + .and("f5").eq(5.5) + .and("f6").inRange(6.6, 7.7) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 <= 1.1 and f2 < 2.2 and f3 >= 3.3 and f4 > 4.4 and f5 = 5.5 and range(f6, 6.6, 7.7);"); + } + + @Test + void double_numeric_operations() { + String q = Q.select("*") + .from("sd1") + .where("f1").le(1.1D) + .and("f2").lt(2.2D) + .and("f3").ge(3.3D) + .and("f4").gt(4.4D) + .and("f5").eq(5.5D) + .and("f6").inRange(6.6D, 7.7D) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 <= 1.1 and f2 < 2.2 and f3 >= 3.3 and f4 > 4.4 and f5 = 5.5 and range(f6, 6.6, 7.7);"); + } + + @Test + void nested_queries() { + String q = Q.select("*") + .from("sd1") + .where("f1").contains("1") + .andnot(Q.p(Q.p("f2").contains("2").and("f3").contains("3")) + .or(Q.p("f2").contains("4").andnot("f3").contains("5"))) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 contains \"1\" and !((f2 contains \"2\" and f3 contains \"3\") or (f2 contains \"4\" and !(f3 contains \"5\")));"); + } + + @Test + void userInput_with_and_with_out_defaultIndex() { + String q = Q.select("*") + .from("sd1") + .where(Q.ui("value")) + .and(Q.ui("index", "value2")) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where userInput(@_1) and ([{\"defaultIndex\":\"index\"}]userInput(@_2_index));&_2_index=value2&_1=value"); + } + + @Test + void dot_product() { + String q = Q.select("*") + .from("sd1") + .where(Q.dotPdt("f1", stringIntMap("a", 1, "b", 2, "c", 3))) + .and("f2").contains("1") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where dotProduct(f1, {\"a\":1,\"b\":2,\"c\":3}) and f2 contains \"1\";"); + } + + @Test + void weighted_set() { + String q = Q.select("*") + .from("sd1") + .where(Q.wtdSet("f1", stringIntMap("a", 1, "b", 2, "c", 3))) + .and("f2").contains("1") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where weightedSet(f1, {\"a\":1,\"b\":2,\"c\":3}) and f2 contains \"1\";"); + } + + @Test + void non_empty() { + String q = Q.select("*") + .from("sd1") + .where(Q.nonEmpty(Q.p("f1").contains("v1"))) + .and("f2").contains("v2") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where nonEmpty(f1 contains \"v1\") and f2 contains \"v2\";"); + } + + + @Test + void wand_with_and_without_annotation() { + String q = Q.select("*") + .from("sd1") + .where(Q.wand("f1", stringIntMap("a", 1, "b", 2, "c", 3))) + .and(Q.wand("f2", Arrays.asList(Arrays.asList(1, 1), Arrays.asList(2, 2)))) + .and( + Q.wand("f3", Arrays.asList(Arrays.asList(1, 1), Arrays.asList(2, 2))) + .annotate(A.a("scoreThreshold", 0.13)) + ) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where wand(f1, {\"a\":1,\"b\":2,\"c\":3}) and wand(f2, [[1,1],[2,2]]) and ([{\"scoreThreshold\":0.13}]wand(f3, [[1,1],[2,2]]));"); + } + + @Test + void weak_and_with_and_without_annotation() { + String q = Q.select("*") + .from("sd1") + .where(Q.weakand(Q.p("f1").contains("v1").and("f2").contains("v2"))) + .and(Q.weakand(Q.p("f1").contains("v1").and("f2").contains("v2")) + .annotate(A.a("scoreThreshold", 0.13)) + ) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where weakAnd(f1 contains \"v1\", f2 contains \"v2\") and ([{\"scoreThreshold\":0.13}]weakAnd(f1 contains \"v1\", f2 contains \"v2\"));"); + } + + @Test + void geo_location() { + String q = Q.select("*") + .from("sd1") + .where("a").contains("b").and(Q.geoLocation("taiwan", 25.105497, 121.597366, "200km")) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where a contains \"b\" and geoLocation(taiwan, 25.105497, 121.597366, \"200km\");"); + } + + @Test + void nearest_neighbor_query() { + String q = Q.select("*") + .from("sd1") + .where("a").contains("b") + .and(Q.nearestNeighbor("vec1", "vec2") + .annotate(A.a("targetHits", 10, "approximate", false)) + ) + .semicolon() + .build(); + assertEquals(q, "yql=select * from sd1 where a contains \"b\" and ([{\"approximate\":false,\"targetHits\":10}]nearestNeighbor(vec1, vec2));"); + } + + @Test + void invalid_nearest_neighbor_should_throws_an_exception_targetHits_annotation_is_required() { + assertThrows(IllegalArgumentException.class, + () -> Q.select("*") + .from("sd1") + .where("a").contains("b").and(Q.nearestNeighbor("vec1", "vec2")) + .semicolon() + .build()); + } + + + @Test + void rank_with_only_query() { + String q = Q.select("*") + .from("sd1") + .where(Q.rank( + Q.p("f1").contains("v1") + ) + ) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where rank(f1 contains \"v1\");"); + } + + @Test + void rank() { + String q = Q.select("*") + .from("sd1") + .where(Q.rank( + Q.p("f1").contains("v1"), + Q.p("f2").contains("v2"), + Q.p("f3").eq(3)) + ) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where rank(f1 contains \"v1\", f2 contains \"v2\", f3 = 3);"); + } + + @Test + void rank_with_rank_query_array() { + Query[] ranks = new Query[]{Q.p("f2").contains("v2"), Q.p("f3").eq(3)}; + String q = Q.select("*") + .from("sd1") + .where(Q.rank( + Q.p("f1").contains("v1"), + ranks) + ) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where rank(f1 contains \"v1\", f2 contains \"v2\", f3 = 3);"); + } + + @Test + void stringfunction_annotations() { + + { + Annotation annotation = A.filter(); + String expected = "[{\"filter\":true}]"; + String q = Q.select("*") + .from("sd1") + .where("f1").contains(annotation, "v1") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 contains (" + expected + "\"v1\");"); + } + { + Annotation annotation = A.defaultIndex("idx"); + String expected = "[{\"defaultIndex\":\"idx\"}]"; + String q = Q.select("*") + .from("sd1") + .where("f1").contains(annotation, "v1") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 contains (" + expected + "\"v1\");"); + } + { + Annotation annotation = A.a(stringObjMap("a1", stringObjMap("k1", "v1", "k2", 2))); + String expected = "[{\"a1\":{\"k1\":\"v1\",\"k2\":2}}]"; + String q = Q.select("*") + .from("sd1") + .where("f1").contains(annotation, "v1") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where f1 contains (" + expected + "\"v1\");"); + } + + } + + @Test + void sub_expression_annotations() { + String q = Q.select("*") + .from("sd1") + .where("f1").contains("v1").annotate(A.a("ak1", "av1")) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where ([{\"ak1\":\"av1\"}](f1 contains \"v1\"));"); + } + + @Test + void sub_expressions_annotations_annotate_in_the_middle_of_query() { + String q = Q.select("*") + .from("sd1") + .where(Q.p("f1").contains("v1").annotate(A.a("ak1", "av1")).and("f2").contains("v2")) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where ([{\"ak1\":\"av1\"}](f1 contains \"v1\" and f2 contains \"v2\"));"); + } + + @Test + void sub_expressions_annotations_annotate_in_nested_queries() { + String q = Q.select("*") + .from("sd1") + .where(Q.p( + Q.p("f1").contains("v1").annotate(A.a("ak1", "av1"))) + .and("f2").contains("v2") + ) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sd1 where (([{\"ak1\":\"av1\"}](f1 contains \"v1\")) and f2 contains \"v2\");"); + } + + @Test + void build_query_which_created_from_Q_b_without_select_and_sources() { + String q = Q.p("f1").contains("v1") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where f1 contains \"v1\";"); + } + + @Test + void order_by() { + String q = Q.p("f1").contains("v1") + .orderByAsc("f2") + .orderByAsc(A.a(stringObjMap("function", "uca", "locale", "en_US", "strength", "IDENTICAL")), "f3") + .orderByDesc("f4") + .orderByDesc(A.a(stringObjMap("function", "lowercase")), "f5") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where f1 contains \"v1\" order by f2 asc, [{\"function\":\"uca\",\"locale\":\"en_US\",\"strength\":\"IDENTICAL\"}]f3 asc, f4 desc, [{\"function\":\"lowercase\"}]f5 desc;"); + } + + @Test + void contains_sameElement() { + String q = Q.p("f1").containsSameElement(Q.p("stime").le(1).and("etime").gt(2)) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where f1 contains sameElement(stime <= 1, etime > 2);"); + } + + @Test + void contains_phrase_near_onear_equiv() { + { + String q1 = Q.p("f1").containsPhrase("p1", "p2", "p3") + .semicolon() + .build(); + String q2 = Q.p("f1").containsPhrase(Arrays.asList("p1", "p2", "p3")) + .semicolon() + .build(); + assertEquals(q1, "yql=select * from sources * where f1 contains phrase(\"p1\", \"p2\", \"p3\");"); + assertEquals(q2, "yql=select * from sources * where f1 contains phrase(\"p1\", \"p2\", \"p3\");"); + } + { + String q1 = Q.p("f1").containsNear("p1", "p2", "p3") + .semicolon() + .build(); + String q2 = Q.p("f1").containsNear(Arrays.asList("p1", "p2", "p3")) + .semicolon() + .build(); + assertEquals(q1, "yql=select * from sources * where f1 contains near(\"p1\", \"p2\", \"p3\");"); + assertEquals(q2, "yql=select * from sources * where f1 contains near(\"p1\", \"p2\", \"p3\");"); + } + { + String q1 = Q.p("f1").containsOnear("p1", "p2", "p3") + .semicolon() + .build(); + String q2 = Q.p("f1").containsOnear(Arrays.asList("p1", "p2", "p3")) + .semicolon() + .build(); + assertEquals(q1, "yql=select * from sources * where f1 contains onear(\"p1\", \"p2\", \"p3\");"); + assertEquals(q2, "yql=select * from sources * where f1 contains onear(\"p1\", \"p2\", \"p3\");"); + } + { + String q1 = Q.p("f1").containsEquiv("p1", "p2", "p3") + .semicolon() + .build(); + String q2 = Q.p("f1").containsEquiv(Arrays.asList("p1", "p2", "p3")) + .semicolon() + .build(); + assertEquals(q1, "yql=select * from sources * where f1 contains equiv(\"p1\", \"p2\", \"p3\");"); + assertEquals(q2, "yql=select * from sources * where f1 contains equiv(\"p1\", \"p2\", \"p3\");"); + } + } + + @Test + void contains_uri() { + String q = Q.p("f1").containsUri("https://test.uri") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where f1 contains uri(\"https://test.uri\");"); + } + + @Test + void contains_uri_with_annotation() { + String q = Q.p("f1").containsUri(A.a("key", "value"), "https://test.uri") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where f1 contains ([{\"key\":\"value\"}]uri(\"https://test.uri\"));"); + } + + @Test + void nearestNeighbor() { + String q = Q.p("f1").nearestNeighbor("query_vector") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where nearestNeighbor(f1, query_vector);"); + } + + @Test + void nearestNeighbor_with_annotation() { + String q = Q.p("f1").nearestNeighbor(A.a("targetHits", 10), "query_vector") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where ([{\"targetHits\":10}]nearestNeighbor(f1, query_vector));"); + } + + @Test + void use_contains_instead_of_contains_equiv_when_input_size_is_1() { + String q = Q.p("f1").containsEquiv(Collections.singletonList("p1")) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where f1 contains \"p1\";"); + } + + @Test + void contains_phrase_near_onear_equiv_empty_list_should_throw_illegal_argument_exception() { + assertThrows(IllegalArgumentException.class, () -> Q.p("f1").containsPhrase(Collections.emptyList()) + .semicolon() + .build()); + + assertThrows(IllegalArgumentException.class, () -> Q.p("f1").containsNear(Collections.emptyList()) + .semicolon() + .build()); + + assertThrows(IllegalArgumentException.class, () -> Q.p("f1").containsOnear(Collections.emptyList()) + .semicolon() + .build()); + + assertThrows(IllegalArgumentException.class, () -> Q.p("f1").containsEquiv(Collections.emptyList()) + .semicolon() + .build()); + } + + + @Test + void contains_near_onear_with_annotation() { + { + String q = Q.p("f1").containsNear(A.a("distance", 5), "p1", "p2", "p3") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where f1 contains ([{\"distance\":5}]near(\"p1\", \"p2\", \"p3\"));"); + } + { + String q = Q.p("f1").containsOnear(A.a("distance", 5), "p1", "p2", "p3") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where f1 contains ([{\"distance\":5}]onear(\"p1\", \"p2\", \"p3\"));"); + } + } + + @Test + void basic_group_syntax() { + /* + example from vespa document: + https://docs.vespa.ai/en/grouping.html + all( group(a) max(5) each(output(count()) + all(max(1) each(output(summary()))) + all(group(b) each(output(count()) + all(max(1) each(output(summary()))) + all(group(c) each(output(count()) + all(max(1) each(output(summary())))))))) ); + */ + String q = Q.p("f1").contains("v1") + .group( + G.all(G.group("a"), G.maxRtn(5), G.each(G.output(G.count()), + G.all(G.maxRtn(1), G.each(G.output(G.summary()))), + G.all(G.group("b"), G.each(G.output(G.count()), + G.all(G.maxRtn(1), G.each(G.output(G.summary()))), + G.all(G.group("c"), G.each(G.output(G.count()), + G.all(G.maxRtn(1), G.each(G.output(G.summary()))) + )) + )) + )) + ) + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where f1 contains \"v1\" | all(group(a) max(5) each(output(count()) all(max(1) each(output(summary()))) all(group(b) each(output(count()) all(max(1) each(output(summary()))) all(group(c) each(output(count()) all(max(1) each(output(summary())))))))));"); + } + + @Test + void set_group_syntax_string_directly() { + /* + example from vespa document: + https://docs.vespa.ai/en/grouping.html + all( group(a) max(5) each(output(count()) + all(max(1) each(output(summary()))) + all(group(b) each(output(count()) + all(max(1) each(output(summary()))) + all(group(c) each(output(count()) + all(max(1) each(output(summary())))))))) ); + */ + String q = Q.p("f1").contains("v1") + .group("all(group(a) max(5) each(output(count()) all(max(1) each(output(summary()))) all(group(b) each(output(count()) all(max(1) each(output(summary()))) all(group(c) each(output(count()) all(max(1) each(output(summary())))))))))") + .semicolon() + .build(); + + assertEquals(q, "yql=select * from sources * where f1 contains \"v1\" | all(group(a) max(5) each(output(count()) all(max(1) each(output(summary()))) all(group(b) each(output(count()) all(max(1) each(output(summary()))) all(group(c) each(output(count()) all(max(1) each(output(summary())))))))));"); + } + +@Test + void arbitrary_annotations() { + Annotation a = A.a("a1", "v1", "a2", 2, "a3", stringObjMap("k", "v", "k2", 1), "a4", 4D, "a5", Arrays.asList(1, 2, 3)); + assertEquals(a.toString(), "{\"a1\":\"v1\",\"a2\":2,\"a3\":{\"k\":\"v\",\"k2\":1},\"a4\":4.0,\"a5\":[1,2,3]}"); + } + + @Test + void test_programmability() { + Map<String, String> map = stringStringMap("a", "1", "b", "2", "c", "3"); + + Query q = map + .entrySet() + .stream() + .map(entry -> Q.p(entry.getKey()).contains(entry.getValue())) + .reduce(Query::and) + .get(); + + assertEquals(q.semicolon().build(), "yql=select * from sources * where a contains \"1\" and b contains \"2\" and c contains \"3\";"); + } + + @Test + void test_programmability_2() { + Map<String, String> map = stringStringMap("a", "1", "b", "2", "c", "3"); + Query q = Q.p(); + + map.forEach((k, v) -> q.and(Q.p(k).contains(v))); + + assertEquals(q.semicolon().build(), "yql=select * from sources * where a contains \"1\" and b contains \"2\" and c contains \"3\";"); + } + + @Test + void empty_queries_should_not_print_out() { + String q = Q.p(Q.p(Q.p().andnot(Q.p()).and(Q.p()))).and("a").contains("1").semicolon().build(); + + assertEquals(q, "yql=select * from sources * where a contains \"1\";"); + } + + @Test + void validate_positive_search_term_of_strings() { + Query q = Q.p(Q.p("k1").contains("v1").and("k2").contains("v2").andnot("k3").contains("v3")) + .andnot(Q.p("nk1").contains("nv1").and("nk2").contains("nv2").andnot("nk3").contains("nv3")) + .and(Q.p("k4").contains("v4") + .andnot(Q.p("k5").contains("v5").andnot("k6").contains("v6")) + ); + + assertTrue(q.hasPositiveSearchField("k1")); + assertTrue(q.hasPositiveSearchField("k2")); + assertTrue(q.hasPositiveSearchField("nk3")); + assertTrue(q.hasPositiveSearchField("k6")); + assertTrue(q.hasPositiveSearchField("k6", "v6")); + assertFalse(q.hasPositiveSearchField("k6", "v5")); + + assertTrue(q.hasNegativeSearchField("k3")); + assertTrue(q.hasNegativeSearchField("nk1")); + assertTrue(q.hasNegativeSearchField("nk2")); + assertTrue(q.hasNegativeSearchField("k5")); + assertTrue(q.hasNegativeSearchField("k5", "v5")); + assertFalse(q.hasNegativeSearchField("k5", "v4")); + } + + @Test + void validate_positive_search_term_of_user_input() { + Query q = Q.p(Q.ui("k1", "v1")).and(Q.ui("k2", "v2")).andnot(Q.ui("k3", "v3")) + .andnot(Q.p(Q.ui("nk1", "nv1")).and(Q.ui("nk2", "nv2")).andnot(Q.ui("nk3", "nv3"))) + .and(Q.p(Q.ui("k4", "v4")) + .andnot(Q.p(Q.ui("k5", "v5")).andnot(Q.ui("k6", "v6"))) + ); + + assertTrue(q.hasPositiveSearchField("k1")); + assertTrue(q.hasPositiveSearchField("k2")); + assertTrue(q.hasPositiveSearchField("nk3")); + assertTrue(q.hasPositiveSearchField("k6")); + assertTrue(q.hasPositiveSearchField("k6", "v6")); + assertFalse(q.hasPositiveSearchField("k6", "v5")); + + assertTrue(q.hasNegativeSearchField("k3")); + assertTrue(q.hasNegativeSearchField("nk1")); + assertTrue(q.hasNegativeSearchField("nk2")); + assertTrue(q.hasNegativeSearchField("k5")); + assertTrue(q.hasNegativeSearchField("k5", "v5")); + assertFalse(q.hasNegativeSearchField("k5", "v4")); + } + + private static Map<String, Integer> stringIntMap(String k1, int v1, String k2, int v2, String k3, int v3) { + HashMap<String, Integer> m = new HashMap<>(); + m.put(k1, v1); + m.put(k2, v2); + m.put(k3, v3); + return m; + } + + private static Map<String, Object> stringObjMap(String k, Object v) { + HashMap<String, Object> m = new HashMap<>(); + m.put(k, v); + return m; + } + + private static Map<String, Object> stringObjMap(String k1, Object v1, String k2, Object v2) { + Map<String, Object> m = new LinkedHashMap<>(); + m.put(k1, v1); + m.put(k2, v2); + return m; + } + + private static Map<String, Object> stringObjMap(String k1, Object v1, String k2, Object v2, String k3, Object v3) { + Map<String, Object> m = new LinkedHashMap<>(); + m.put(k1, v1); + m.put(k2, v2); + m.put(k3, v3); + return m; + } + + private static Map<String, String> stringStringMap(String k1, String v1, String k2, String v2, String k3, String v3) { + Map<String, String> m = new LinkedHashMap<>(); + m.put(k1, v1); + m.put(k2, v2); + m.put(k3, v3); + return m; + } +}
\ No newline at end of file diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ApplicationClusterEndpoint.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ApplicationClusterEndpoint.java index ebd4059f9f9..f6583468322 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ApplicationClusterEndpoint.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ApplicationClusterEndpoint.java @@ -145,7 +145,11 @@ public class ApplicationClusterEndpoint { return name; } - // TODO: remove + // TODO: remove when 7.508 is latest version + public static DnsName sharedNameFrom(ClusterSpec.Id cluster, ApplicationId applicationId, String suffix) { + return sharedNameFrom(SystemName.main, cluster, applicationId, suffix); + } + public static DnsName sharedNameFrom(SystemName systemName, ClusterSpec.Id cluster, ApplicationId applicationId, String suffix) { String name = dnsParts(systemName, cluster, applicationId) .filter(Objects::nonNull) // remove null values that were "default" @@ -153,6 +157,11 @@ public class ApplicationClusterEndpoint { return new DnsName(sanitize(name) + suffix); // Need to sanitize name since it is considered one label } + // TODO remove this method when 7.508 is latest version + public static DnsName sharedL4NameFrom(ClusterSpec.Id cluster, ApplicationId applicationId, String suffix) { + return sharedL4NameFrom(SystemName.main, cluster, applicationId, suffix); + } + public static DnsName sharedL4NameFrom(SystemName systemName, ClusterSpec.Id cluster, ApplicationId applicationId, String suffix) { String name = dnsParts(systemName, cluster, applicationId) .filter(Objects::nonNull) // remove null values that were "default" diff --git a/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/FileDistributionRpcServer.java b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/FileDistributionRpcServer.java index 8b9d1f34154..dfbd605ab50 100644 --- a/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/FileDistributionRpcServer.java +++ b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/filedistribution/FileDistributionRpcServer.java @@ -9,6 +9,7 @@ import com.yahoo.jrt.Request; import com.yahoo.jrt.StringArray; import com.yahoo.jrt.StringValue; import com.yahoo.jrt.Supervisor; +import com.yahoo.net.HostName; import com.yahoo.vespa.filedistribution.FileDownloader; import java.io.File; @@ -101,7 +102,7 @@ class FileDistributionRpcServer { private void downloadFile(Request req) { FileReference fileReference = new FileReference(req.parameters().get(0).asString()); log.log(Level.FINE, () -> "getFile() called for file reference '" + fileReference.value() + "'"); - Optional<File> file = downloader.getFile(fileReference); + Optional<File> file = downloader.getFile(fileReference, HostName.getLocalhost()); if (file.isPresent()) { new RequestTracker().trackRequest(file.get().getParentFile()); req.returnValues().add(new StringValue(file.get().getAbsolutePath())); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v1/RoutingStatusApiHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v1/RoutingStatusApiHandler.java new file mode 100644 index 00000000000..bc44093b89a --- /dev/null +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v1/RoutingStatusApiHandler.java @@ -0,0 +1,277 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.config.server.http.v1; + +import com.google.inject.Inject; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.Deployer; +import com.yahoo.jdisc.http.HttpRequest; +import com.yahoo.path.Path; +import com.yahoo.restapi.RestApi; +import com.yahoo.restapi.RestApiException; +import com.yahoo.restapi.RestApiRequestHandler; +import com.yahoo.restapi.SlimeJsonResponse; +import com.yahoo.slime.Cursor; +import com.yahoo.slime.Slime; +import com.yahoo.slime.SlimeUtils; +import com.yahoo.vespa.curator.Curator; +import com.yahoo.yolean.Exceptions; + +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Locale; +import java.util.Objects; +import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +/** + * This implements the /routing/v1/status REST API on the config server, providing explicit control over the routing + * status of a deployment or zone (all deployments). The routing status manipulated by this is only respected by the + * shared routing layer. + * + * @author bjorncs + * @author mpolden + */ +public class RoutingStatusApiHandler extends RestApiRequestHandler<RoutingStatusApiHandler> { + + private static final Logger log = Logger.getLogger(RoutingStatusApiHandler.class.getName()); + + private static final Path ROUTING_ROOT = Path.fromString("/routing/v1/"); + private static final Path DEPLOYMENT_STATUS_ROOT = ROUTING_ROOT.append("status"); + private static final Path ZONE_STATUS_ROOT = ROUTING_ROOT.append("zone-inactive"); + + private final Curator curator; + private final Clock clock; + private final Deployer deployer; + + @Inject + public RoutingStatusApiHandler(Context context, Curator curator, Deployer deployer) { + this(context, curator, Clock.systemUTC(), deployer); + } + + RoutingStatusApiHandler(Context context, Curator curator, Clock clock, Deployer deployer) { + super(context, RoutingStatusApiHandler::createRestApiDefinition); + this.curator = Objects.requireNonNull(curator); + this.clock = Objects.requireNonNull(clock); + this.deployer = Objects.requireNonNull(deployer); + } + + private static RestApi createRestApiDefinition(RoutingStatusApiHandler self) { + return RestApi.builder() + .addRoute(RestApi.route("/routing/v1/status") + .get(self::listInactiveDeployments)) + .addRoute(RestApi.route("/routing/v1/status/zone") + .get(self::zoneStatus) + .put(self::changeZoneStatus) + .delete(self::changeZoneStatus)) + .addRoute(RestApi.route("/routing/v1/status/{upstreamName}") + .get(self::getDeploymentStatus) + .put(self::changeDeploymentStatus)) + .build(); + } + + /** Get upstream of all deployments with status OUT */ + private SlimeJsonResponse listInactiveDeployments(RestApi.RequestContext context) { + List<String> inactiveDeployments = curator.getChildren(DEPLOYMENT_STATUS_ROOT).stream() + .filter(upstreamName -> deploymentStatus(upstreamName).status() == RoutingStatus.out) + .collect(Collectors.toUnmodifiableList()); + Slime slime = new Slime(); + Cursor rootArray = slime.setArray(); + inactiveDeployments.forEach(rootArray::addString); + return new SlimeJsonResponse(slime); + } + + /** Get the routing status of a deployment */ + private SlimeJsonResponse getDeploymentStatus(RestApi.RequestContext context) { + String upstreamName = upstreamName(context); + DeploymentRoutingStatus deploymentRoutingStatus = deploymentStatus(upstreamName); + // If the entire zone is out, we always return OUT regardless of the actual routing status + if (zoneStatus() == RoutingStatus.out) { + String reason = String.format("Rotation is OUT because the zone is OUT (actual deployment status is %s)", + deploymentRoutingStatus.status().name().toUpperCase(Locale.ENGLISH)); + deploymentRoutingStatus = new DeploymentRoutingStatus(RoutingStatus.out, "operator", reason, + clock.instant()); + } + return new SlimeJsonResponse(toSlime(deploymentRoutingStatus)); + } + + /** Change routing status of a deployment */ + private SlimeJsonResponse changeDeploymentStatus(RestApi.RequestContext context) { + String upstreamName = upstreamName(context); + ApplicationId instance = instance(context); + Path path = deploymentStatusPath(upstreamName); + + RestApi.RequestContext.RequestContent requestContent = context.requestContentOrThrow(); + Slime requestBody = Exceptions.uncheck(() -> SlimeUtils.jsonToSlime(requestContent.content().readAllBytes())); + DeploymentRoutingStatus wantedStatus = deploymentRoutingStatusFromSlime(requestBody, clock.instant()); + DeploymentRoutingStatus currentStatus = deploymentStatus(upstreamName); + if (wantedStatus.status() == currentStatus.status()) { // No change + return new SlimeJsonResponse(toSlime(currentStatus)); + } + + // Redeploy application so that a new LbServicesConfig containing the updated status is generated and consumed + // by routing layer. This is required to update weights for application endpoints when routing status for a + // deployment is changed + curator.set(path, toJsonBytes(wantedStatus)); + try { + deployer.deployFromLocalActive(instance, Duration.ofMinutes(1)); + } catch (Exception e) { + log.log(Level.SEVERE, "Failed to redeploy " + instance + ". Reverting routing status to " + + currentStatus.status(), e); + curator.set(path, toJsonBytes(currentStatus)); + throw new RestApiException.InternalServerError("Failed to change status to " + + wantedStatus.status() + ", reverting to " + + currentStatus.status() + + " because redeployment of " + + instance + " failed: " + + Exceptions.toMessageString(e)); + } + return new SlimeJsonResponse(toSlime(wantedStatus)); + } + + /** Change routing status of a zone */ + private SlimeJsonResponse changeZoneStatus(RestApi.RequestContext context) { + boolean in = context.request().getMethod() == HttpRequest.Method.DELETE; + if (in) { + curator.delete(ZONE_STATUS_ROOT); + return new SlimeJsonResponse(toSlime(RoutingStatus.in)); + } else { + curator.create(ZONE_STATUS_ROOT); + return new SlimeJsonResponse(toSlime(RoutingStatus.out)); + } + } + + /** Read the status for zone */ + private SlimeJsonResponse zoneStatus(RestApi.RequestContext context) { + return new SlimeJsonResponse(toSlime(zoneStatus())); + } + + /** Read the status for a deployment */ + private DeploymentRoutingStatus deploymentStatus(String upstreamName) { + Instant changedAt = clock.instant(); + Path path = deploymentStatusPath(upstreamName); + Optional<byte[]> data = curator.getData(path); + if (data.isEmpty()) { + return new DeploymentRoutingStatus(RoutingStatus.in, "", "", changedAt); + } + String agent = ""; + String reason = ""; + RoutingStatus status = RoutingStatus.out; + if (data.get().length > 0) { // Compatibility with old format, where no data is stored + Slime slime = SlimeUtils.jsonToSlime(data.get()); + Cursor root = slime.get(); + status = asRoutingStatus(root.field("status").asString()); + agent = root.field("agent").asString(); + reason = root.field("cause").asString(); + changedAt = Instant.ofEpochSecond(root.field("lastUpdate").asLong()); + } + return new DeploymentRoutingStatus(status, agent, reason, changedAt); + } + + private RoutingStatus zoneStatus() { + return curator.exists(ZONE_STATUS_ROOT) ? RoutingStatus.out : RoutingStatus.in; + } + + protected Path deploymentStatusPath(String upstreamName) { + return DEPLOYMENT_STATUS_ROOT.append(upstreamName); + } + + private static String upstreamName(RestApi.RequestContext context) { + String upstreamName = context.pathParameters().getStringOrThrow("upstreamName"); + if (upstreamName.contains(" ")) { + throw new RestApiException.BadRequest("Invalid upstream name: '" + upstreamName + "'"); + } + return upstreamName; + } + + private static ApplicationId instance(RestApi.RequestContext context) { + return context.queryParameters().getString("application") + .map(ApplicationId::fromSerializedForm) + .orElseThrow(() -> new RestApiException.BadRequest("Missing application parameter")); + } + + private byte[] toJsonBytes(DeploymentRoutingStatus status) { + return Exceptions.uncheck(() -> SlimeUtils.toJsonBytes(toSlime(status))); + } + + private Slime toSlime(DeploymentRoutingStatus status) { + Slime slime = new Slime(); + Cursor root = slime.setObject(); + root.setString("status", asString(status.status())); + root.setString("cause", status.reason()); + root.setString("agent", status.agent()); + root.setLong("lastUpdate", status.changedAt().getEpochSecond()); + return slime; + } + + private static Slime toSlime(RoutingStatus status) { + Slime slime = new Slime(); + Cursor root = slime.setObject(); + root.setString("status", asString(status)); + return slime; + } + + private static RoutingStatus asRoutingStatus(String s) { + switch (s) { + case "IN": return RoutingStatus.in; + case "OUT": return RoutingStatus.out; + } + throw new IllegalArgumentException("Unknown status: '" + s + "'"); + } + + private static String asString(RoutingStatus status) { + switch (status) { + case in: return "IN"; + case out: return "OUT"; + } + throw new IllegalArgumentException("Unknown status: " + status); + } + + private static DeploymentRoutingStatus deploymentRoutingStatusFromSlime(Slime slime, Instant changedAt) { + Cursor root = slime.get(); + return new DeploymentRoutingStatus(asRoutingStatus(root.field("status").asString()), + root.field("agent").asString(), + root.field("cause").asString(), + changedAt); + } + + private static class DeploymentRoutingStatus { + + private final RoutingStatus status; + private final String agent; + private final String reason; + private final Instant changedAt; + + public DeploymentRoutingStatus(RoutingStatus status, String agent, String reason, Instant changedAt) { + this.status = Objects.requireNonNull(status); + this.agent = Objects.requireNonNull(agent); + this.reason = Objects.requireNonNull(reason); + this.changedAt = Objects.requireNonNull(changedAt); + } + + public RoutingStatus status() { + return status; + } + + public String agent() { + return agent; + } + + public String reason() { + return reason; + } + + public Instant changedAt() { + return changedAt; + } + + } + + private enum RoutingStatus { + in, out + } + +} diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v1/RoutingStatusApiHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v1/RoutingStatusApiHandlerTest.java new file mode 100644 index 00000000000..3eed93ce131 --- /dev/null +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v1/RoutingStatusApiHandlerTest.java @@ -0,0 +1,204 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.config.server.http.v1; + +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.Deployer; +import com.yahoo.config.provision.Deployment; +import com.yahoo.container.jdisc.HttpRequestBuilder; +import com.yahoo.container.jdisc.HttpResponse; +import com.yahoo.jdisc.http.HttpRequest.Method; +import com.yahoo.path.Path; +import com.yahoo.restapi.RestApiTestDriver; +import com.yahoo.test.ManualClock; +import com.yahoo.vespa.curator.Curator; +import com.yahoo.vespa.curator.mock.MockCurator; +import org.junit.Before; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.yahoo.yolean.Exceptions.uncheck; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * @author bjorncs + * @author mpolden + */ +public class RoutingStatusApiHandlerTest { + + private static final ApplicationId instance = ApplicationId.from("t1", "a1", "i1"); + private static final String upstreamName = "test-upstream-name"; + + private final Curator curator = new MockCurator(); + private final ManualClock clock = new ManualClock(); + private final MockDeployer deployer = new MockDeployer(clock); + + private RestApiTestDriver testDriver; + + @Before + public void before() { + RoutingStatusApiHandler requestHandler = new RoutingStatusApiHandler(RestApiTestDriver.createHandlerTestContext(), + curator, + clock, + deployer); + testDriver = RestApiTestDriver.newBuilder(requestHandler).build(); + } + + @Test + public void list_deployment_status() { + List<String> expected = List.of("foo", "bar"); + for (String upstreamName : expected) { + executeRequest(Method.PUT, "/routing/v1/status/" + upstreamName + "?application=" + instance.serializedForm(), + statusOut()); + } + String actual = responseAsString(executeRequest(Method.GET, "/routing/v1/status", null)); + assertEquals("[\"foo\",\"bar\"]", actual); + } + + @Test + public void get_deployment_status() { + String response = responseAsString(executeRequest(Method.GET, "/routing/v1/status/" + upstreamName + "?application=" + instance.serializedForm(), null)); + assertEquals(response("IN", "", "", clock.instant()), response); + } + + @Test + public void set_deployment_status() { + String response = responseAsString(executeRequest(Method.PUT, "/routing/v1/status/" + upstreamName + "?application=" + instance.serializedForm(), + statusOut())); + assertEquals(response("OUT", "issue-XXX", "operator", clock.instant()), response); + assertTrue("Re-deployed " + instance, deployer.lastDeployed.containsKey(instance)); + + // Status is reverted if redeployment fails + deployer.failNextDeployment(true); + response = responseAsString(executeRequest(Method.PUT, "/routing/v1/status/" + upstreamName + "?application=" + instance.serializedForm(), + requestContent("IN", "all good"))); + assertEquals("{\"error-code\":\"INTERNAL_SERVER_ERROR\",\"message\":\"Failed to change status to in, reverting to out because redeployment of t1.a1.i1 failed: Deployment failed\"}", + response); + + // Read status stored in old format (path exists, but without content) + curator.set(Path.fromString("/routing/v1/status/" + upstreamName), new byte[0]); + response = responseAsString(executeRequest(Method.GET, "/routing/v1/status/" + upstreamName + "?application=" + instance.serializedForm(), null)); + + assertEquals(response("OUT", "", "", clock.instant()), response); + } + + @Test + public void fail_on_invalid_upstream_name() { + HttpResponse response = executeRequest(Method.GET, "/routing/v1/status/" + upstreamName + "%20invalid", null); + assertEquals(400, response.getStatus()); + } + + @Test + public void fail_on_changing_routing_status_without_request_content() { + HttpResponse response = executeRequest(Method.PUT, "/routing/v1/status/" + upstreamName + "?application=" + instance.serializedForm(), null); + assertEquals(400, response.getStatus()); + } + + @Test + public void zone_status_out_overrides_deployment_status() { + // Setting zone out overrides deployment status + executeRequest(Method.PUT, "/routing/v1/status/zone", null); + String response = responseAsString(executeRequest(Method.GET, "/routing/v1/status/" + upstreamName + "?application=" + instance.serializedForm(), null)); + assertEquals(response("OUT", "Rotation is OUT because the zone is OUT (actual deployment status is IN)", "operator", clock.instant()), response); + + // Setting zone back in falls back to deployment status, which is also out + executeRequest(Method.DELETE, "/routing/v1/status/zone", null); + String response2 = responseAsString(executeRequest(Method.PUT, "/routing/v1/status/" + upstreamName + "?application=" + instance.serializedForm(), + statusOut())); + assertEquals(response("OUT", "issue-XXX", "operator", clock.instant()), response2); + + // Deployment status is changed to in + String response3 = responseAsString(executeRequest(Method.PUT, "/routing/v1/status/" + upstreamName + "?application=" + instance.serializedForm(), + requestContent("IN", "all good"))); + assertEquals(response("IN", "all good", "operator", clock.instant()), response3); + } + + @Test + public void set_zone_status() { + executeRequest(Method.PUT, "/routing/v1/status/zone", null); + String response = responseAsString(executeRequest(Method.GET, "/routing/v1/status/zone", null)); + assertEquals("{\"status\":\"OUT\"}", response); + executeRequest(Method.DELETE, "/routing/v1/status/zone", null); + response = responseAsString(executeRequest(Method.GET, "/routing/v1/status/zone", null)); + assertEquals("{\"status\":\"IN\"}", response); + } + + private HttpResponse executeRequest(Method method, String path, String requestContent) { + var builder = HttpRequestBuilder.create(method, path); + if (requestContent != null) { + builder.withRequestContent(new ByteArrayInputStream(requestContent.getBytes(StandardCharsets.UTF_8))); + } + return testDriver.executeRequest(builder.build()); + } + + private static String responseAsString(HttpResponse response) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + uncheck(() -> response.render(out)); + return out.toString(StandardCharsets.UTF_8); + } + + private static String statusOut() { + return requestContent("OUT", "issue-XXX"); + } + + private static String requestContent(String status, String cause) { + return "{\"status\": \"" + status + "\", \"agent\":\"operator\", \"cause\": \"" + cause + "\"}"; + } + + private static String response(String status, String reason, String agent, Instant instant) { + return "{\"status\":\"" + status + "\",\"cause\":\"" + reason + "\",\"agent\":\"" + agent + "\",\"lastUpdate\":" + instant.getEpochSecond() + "}"; + } + + private static class MockDeployer implements Deployer { + + private final Map<ApplicationId, Instant> lastDeployed = new HashMap<>(); + private final Clock clock; + + private boolean failNextDeployment = false; + + public MockDeployer(Clock clock) { + this.clock = clock; + } + + public MockDeployer failNextDeployment(boolean fail) { + this.failNextDeployment = fail; + return this; + } + + @Override + public Optional<Deployment> deployFromLocalActive(ApplicationId application, boolean bootstrap) { + return deployFromLocalActive(application, Duration.ZERO, false); + } + + @Override + public Optional<Deployment> deployFromLocalActive(ApplicationId application, Duration timeout, boolean bootstrap) { + if (failNextDeployment) { + throw new RuntimeException("Deployment failed"); + } + lastDeployed.put(application, clock.instant()); + return Optional.empty(); + } + + @Override + public Optional<Instant> lastDeployTime(ApplicationId application) { + return Optional.ofNullable(lastDeployed.get(application)); + } + + @Override + public Duration serverDeployTimeout() { + return Duration.ZERO; + } + + } + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ContainerEndpoint.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ContainerEndpoint.java index 159e4aa15da..7246903a51b 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ContainerEndpoint.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ContainerEndpoint.java @@ -1,6 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.configserver; +import com.yahoo.config.provision.zone.RoutingMethod; + import java.util.List; import java.util.Objects; import java.util.OptionalInt; @@ -16,12 +18,14 @@ public class ContainerEndpoint { private final String scope; private final List<String> names; private final OptionalInt weight; + private final RoutingMethod routingMethod; - public ContainerEndpoint(String clusterId, String scope, List<String> names, OptionalInt weight) { + public ContainerEndpoint(String clusterId, String scope, List<String> names, OptionalInt weight, RoutingMethod routingMethod) { this.clusterId = nonEmpty(clusterId, "clusterId must be non-empty"); this.scope = Objects.requireNonNull(scope, "scope must be non-null"); this.names = List.copyOf(Objects.requireNonNull(names, "names must be non-null")); this.weight = Objects.requireNonNull(weight, "weight must be non-null"); + this.routingMethod = Objects.requireNonNull(routingMethod, "routingMethod must be non-null"); } /** ID of the cluster to which this points */ @@ -47,23 +51,28 @@ public class ContainerEndpoint { return weight; } + /** The routing method used by this endpoint */ + public RoutingMethod routingMethod() { + return routingMethod; + } + @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ContainerEndpoint that = (ContainerEndpoint) o; - return clusterId.equals(that.clusterId) && scope.equals(that.scope) && names.equals(that.names) && weight.equals(that.weight); + return clusterId.equals(that.clusterId) && scope.equals(that.scope) && names.equals(that.names) && weight.equals(that.weight) && routingMethod == that.routingMethod; } @Override public int hashCode() { - return Objects.hash(clusterId, scope, names, weight); + return Objects.hash(clusterId, scope, names, weight, routingMethod); } @Override public String toString() { return "container endpoint for cluster " + clusterId + ": " + String.join(", ", names) + - " [scope=" + scope + ",weight=" + + " [method=" + routingMethod + ",scope=" + scope + ",weight=" + weight.stream().boxed().map(Object::toString).findFirst().orElse("<none>") + "]"; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java index 1a9dc90b1b8..6ef0df9f099 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java @@ -35,9 +35,10 @@ import com.yahoo.vespa.hosted.controller.routing.RoutingPolicies; import com.yahoo.vespa.hosted.controller.routing.context.DeploymentRoutingContext; import com.yahoo.vespa.hosted.controller.routing.context.DeploymentRoutingContext.ExclusiveDeploymentRoutingContext; import com.yahoo.vespa.hosted.controller.routing.context.DeploymentRoutingContext.SharedDeploymentRoutingContext; -import com.yahoo.vespa.hosted.controller.routing.context.ExclusiveRoutingContext; +import com.yahoo.vespa.hosted.controller.routing.context.ExclusiveZoneRoutingContext; import com.yahoo.vespa.hosted.controller.routing.context.RoutingContext; -import com.yahoo.vespa.hosted.controller.routing.context.SharedRoutingContext; +import com.yahoo.vespa.hosted.controller.routing.context.SharedZoneRoutingContext; +import com.yahoo.vespa.hosted.controller.routing.rotation.Rotation; import com.yahoo.vespa.hosted.controller.routing.rotation.RotationLock; import com.yahoo.vespa.hosted.controller.routing.rotation.RotationRepository; import com.yahoo.vespa.hosted.rotation.config.RotationsConfig; @@ -97,9 +98,9 @@ public class RoutingController { /** Create a routing context for given zone */ public RoutingContext of(ZoneId zone) { if (usesSharedRouting(zone)) { - return new SharedRoutingContext(zone, controller.serviceRegistry().configServer()); + return new SharedZoneRoutingContext(zone, controller.serviceRegistry().configServer()); } - return new ExclusiveRoutingContext(zone, routingPolicies); + return new ExclusiveZoneRoutingContext(zone, routingPolicies); } public RoutingPolicies policies() { @@ -259,7 +260,6 @@ public class RoutingController { EndpointList endpoints = declaredEndpointsOf(application.get()).targets(deployment); EndpointList globalEndpoints = endpoints.scope(Endpoint.Scope.global); for (var assignedRotation : instance.rotations()) { - var names = new ArrayList<String>(); EndpointList rotationEndpoints = globalEndpoints.named(assignedRotation.endpointId()) .requiresRotation(); @@ -274,22 +274,21 @@ public class RoutingController { } // Register names in DNS - var rotation = rotationRepository.getRotation(assignedRotation.rotationId()); - if (rotation.isPresent()) { - rotationEndpoints.forEach(endpoint -> { - controller.nameServiceForwarder().createCname(RecordName.from(endpoint.dnsName()), - RecordData.fqdn(rotation.get().name()), - Priority.normal); - names.add(endpoint.dnsName()); - }); + Rotation rotation = rotationRepository.requireRotation(assignedRotation.rotationId()); + for (var endpoint : rotationEndpoints) { + controller.nameServiceForwarder().createCname(RecordName.from(endpoint.dnsName()), + RecordData.fqdn(rotation.name()), + Priority.normal); + List<String> names = List.of(endpoint.dnsName(), + // Include rotation ID as a valid name of this container endpoint + // (required by global routing health checks) + assignedRotation.rotationId().asString()); + containerEndpoints.add(new ContainerEndpoint(assignedRotation.clusterId().value(), + asString(Endpoint.Scope.global), + names, + OptionalInt.empty(), + endpoint.routingMethod())); } - - // Include rotation ID as a valid name of this container endpoint (required by global routing health checks) - names.add(assignedRotation.rotationId().asString()); - containerEndpoints.add(new ContainerEndpoint(assignedRotation.clusterId().value(), - asString(Endpoint.Scope.global), - names, - OptionalInt.empty())); } // Add endpoints not backed by a rotation (i.e. other routing methods so that the config server always knows // about global names, even when not using rotations) @@ -299,7 +298,8 @@ public class RoutingController { containerEndpoints.add(new ContainerEndpoint(clusterId.value(), asString(Endpoint.Scope.global), clusterEndpoints.mapToList(Endpoint::dnsName), - OptionalInt.empty())); + OptionalInt.empty(), + RoutingMethod.exclusive)); }); // Add application endpoints EndpointList applicationEndpoints = endpoints.scope(Endpoint.Scope.application); @@ -329,7 +329,8 @@ public class RoutingController { containerEndpoints.add(new ContainerEndpoint(clusterId.value(), asString(Endpoint.Scope.application), List.of(endpoint.dnsName()), - OptionalInt.of(matchingTarget.get().weight()))); + OptionalInt.of(matchingTarget.get().weight()), + endpoint.routingMethod())); } } return Collections.unmodifiableSet(containerEndpoints); @@ -389,8 +390,8 @@ public class RoutingController { var deploymentsByMethod = new HashMap<RoutingMethod, Set<DeploymentId>>(); for (var deployment : deployments) { for (var method : controller.zoneRegistry().routingMethods(deployment.zoneId())) { - deploymentsByMethod.putIfAbsent(method, new LinkedHashSet<>()); - deploymentsByMethod.get(method).add(deployment); + deploymentsByMethod.computeIfAbsent(method, k -> new LinkedHashSet<>()) + .add(deployment); } } var routingMethods = new ArrayList<RoutingMethod>(); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/ExclusiveRoutingContext.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/ExclusiveZoneRoutingContext.java index e949c45f2fd..e29fb5ab404 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/ExclusiveRoutingContext.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/ExclusiveZoneRoutingContext.java @@ -13,12 +13,12 @@ import java.util.Objects; * * @author mpolden */ -public class ExclusiveRoutingContext implements RoutingContext { +public class ExclusiveZoneRoutingContext implements RoutingContext { private final RoutingPolicies policies; private final ZoneId zone; - public ExclusiveRoutingContext(ZoneId zone, RoutingPolicies policies) { + public ExclusiveZoneRoutingContext(ZoneId zone, RoutingPolicies policies) { this.policies = Objects.requireNonNull(policies); this.zone = Objects.requireNonNull(zone); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/SharedRoutingContext.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/SharedZoneRoutingContext.java index e38212d7f80..2923c8dff5c 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/SharedRoutingContext.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/SharedZoneRoutingContext.java @@ -15,12 +15,12 @@ import java.util.Objects; * * @author mpolden */ -public class SharedRoutingContext implements RoutingContext { +public class SharedZoneRoutingContext implements RoutingContext { private final ConfigServer configServer; private final ZoneId zone; - public SharedRoutingContext(ZoneId zone, ConfigServer configServer) { + public SharedZoneRoutingContext(ZoneId zone, ConfigServer configServer) { this.configServer = Objects.requireNonNull(configServer); this.zone = Objects.requireNonNull(zone); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/rotation/RotationRepository.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/rotation/RotationRepository.java index 961fdc6dd9c..39a0b6a8858 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/rotation/RotationRepository.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/rotation/RotationRepository.java @@ -21,7 +21,6 @@ import java.util.Comparator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.Set; import java.util.function.Function; import java.util.logging.Logger; @@ -56,9 +55,11 @@ public class RotationRepository { return new RotationLock(curator.lockRotations()); } - /** Get rotation by given rotationId */ - public Optional<Rotation> getRotation(RotationId rotationId) { - return Optional.of(allRotations.get(rotationId)); + /** Get rotation with given id */ + public Rotation requireRotation(RotationId id) { + Rotation rotation = allRotations.get(id); + if (rotation == null) throw new IllegalArgumentException("No such rotation: '" + id.asString() + "'"); + return rotation; } /** diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java index 2e0bcf78838..1215ddbc2ad 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java @@ -233,22 +233,41 @@ public class ControllerTest { public void testDnsUpdatesForGlobalEndpoint() { var betaContext = tester.newDeploymentContext("tenant1", "app1", "beta"); var defaultContext = tester.newDeploymentContext("tenant1", "app1", "default"); + + ZoneId usWest = ZoneId.from("prod.us-west-1"); + ZoneId usCentral = ZoneId.from("prod.us-central-1"); ApplicationPackage applicationPackage = new ApplicationPackageBuilder() + .athenzIdentity(AthenzDomain.from("domain"), AthenzService.from("service")) .instances("beta,default") .endpoint("default", "foo") - .region("us-west-1") - .region("us-central-1") // Two deployments should result in each DNS alias being registered once + .region(usWest.region()) + .region(usCentral.region()) // Two deployments should result in each DNS alias being registered once .build(); + tester.controllerTester().zoneRegistry().setRoutingMethod(List.of(ZoneApiMock.from(usWest), ZoneApiMock.from(usCentral)), + RoutingMethod.shared, + RoutingMethod.sharedLayer4); betaContext.submit(applicationPackage).deploy(); { // Expected rotation names are passed to beta instance deployments Collection<Deployment> betaDeployments = betaContext.instance().deployments().values(); assertFalse(betaDeployments.isEmpty()); + Set<ContainerEndpoint> containerEndpoints = Set.of(new ContainerEndpoint("foo", + "global", + List.of("beta--app1--tenant1.global.vespa.oath.cloud", + "rotation-id-01"), + OptionalInt.empty(), + RoutingMethod.shared), + new ContainerEndpoint("foo", + "global", + List.of("beta.app1.tenant1.global.vespa.oath.cloud", + "rotation-id-01"), + OptionalInt.empty(), + RoutingMethod.sharedLayer4)); + for (Deployment deployment : betaDeployments) { - assertEquals("Rotation names are passed to config server in " + deployment.zone(), - Set.of("rotation-id-01", - "beta--app1--tenant1.global.vespa.oath.cloud"), - tester.configServer().containerEndpointNames(betaContext.deploymentIdIn(deployment.zone()))); + assertEquals(containerEndpoints, + tester.configServer().containerEndpoints() + .get(betaContext.deploymentIdIn(deployment.zone()))); } betaContext.flushDnsUpdates(); } @@ -256,11 +275,21 @@ public class ControllerTest { { // Expected rotation names are passed to default instance deployments Collection<Deployment> defaultDeployments = defaultContext.instance().deployments().values(); assertFalse(defaultDeployments.isEmpty()); + Set<ContainerEndpoint> containerEndpoints = Set.of(new ContainerEndpoint("foo", + "global", + List.of("app1--tenant1.global.vespa.oath.cloud", + "rotation-id-02"), + OptionalInt.empty(), + RoutingMethod.shared), + new ContainerEndpoint("foo", + "global", + List.of("app1.tenant1.global.vespa.oath.cloud", + "rotation-id-02"), + OptionalInt.empty(), + RoutingMethod.sharedLayer4)); for (Deployment deployment : defaultDeployments) { - assertEquals("Rotation names are passed to config server in " + deployment.zone(), - Set.of("rotation-id-02", - "app1--tenant1.global.vespa.oath.cloud"), - tester.configServer().containerEndpointNames(defaultContext.deploymentIdIn(deployment.zone()))); + assertEquals(containerEndpoints, + tester.configServer().containerEndpoints().get(defaultContext.deploymentIdIn(deployment.zone()))); } defaultContext.flushDnsUpdates(); } @@ -274,13 +303,17 @@ public class ControllerTest { assertEquals(data, record.get().data().asString()); }); - Map<ApplicationId, List<String>> globalDnsNamesByInstance = Map.of(betaContext.instanceId(), List.of("beta--app1--tenant1.global.vespa.oath.cloud"), - defaultContext.instanceId(), List.of("app1--tenant1.global.vespa.oath.cloud")); + Map<ApplicationId, Set<String>> globalDnsNamesByInstance = Map.of(betaContext.instanceId(), Set.of("beta--app1--tenant1.global.vespa.oath.cloud", + "beta.app1.tenant1.global.vespa.oath.cloud"), + defaultContext.instanceId(), Set.of("app1--tenant1.global.vespa.oath.cloud", + "app1.tenant1.global.vespa.oath.cloud")); globalDnsNamesByInstance.forEach((instance, dnsNames) -> { - List<String> actualDnsNames = tester.controller().routing().readDeclaredEndpointsOf(instance) - .scope(Endpoint.Scope.global) - .mapToList(Endpoint::dnsName); + Set<String> actualDnsNames = tester.controller().routing().readDeclaredEndpointsOf(instance) + .scope(Endpoint.Scope.global) + .asList().stream() + .map(Endpoint::dnsName) + .collect(Collectors.toSet()); assertEquals("Global DNS names for " + instance, dnsNames, actualDnsNames); }); } @@ -651,7 +684,8 @@ public class ControllerTest { Set<ContainerEndpoint> expected = endpoints.entrySet().stream() .map(kv -> new ContainerEndpoint("default", "application", List.of(kv.getKey()), - OptionalInt.of(kv.getValue()))) + OptionalInt.of(kv.getValue()), + RoutingMethod.sharedLayer4)) .collect(Collectors.toSet()); assertEquals("Endpoint names for " + deployment + " are passed to config server", expected, diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/rotation/RotationRepositoryTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/rotation/RotationRepositoryTest.java index 9a3ac8b547d..9a56123e8e3 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/rotation/RotationRepositoryTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/rotation/RotationRepositoryTest.java @@ -205,10 +205,9 @@ public class RotationRepositoryTest { private void assertSingleRotation(Rotation expected, List<AssignedRotation> assignedRotations, RotationRepository repository) { assertEquals(1, assignedRotations.size()); - var rotationId = assignedRotations.get(0).rotationId(); - var rotation = repository.getRotation(rotationId); - assertTrue(rotationId + " exists", rotation.isPresent()); - assertEquals(expected, rotation.get()); + RotationId rotationId = assignedRotations.get(0).rotationId(); + Rotation rotation = repository.requireRotation(rotationId); + assertEquals(expected, rotation); } private static List<RotationId> rotationIds(List<AssignedRotation> assignedRotations) { diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java index 3674cba0d97..b2efd35e41e 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileDownloader.java @@ -69,8 +69,8 @@ public class FileDownloader implements AutoCloseable { downloadDirectory); } - public Optional<File> getFile(FileReference fileReference) { - return getFile(new FileReferenceDownload(fileReference)); + public Optional<File> getFile(FileReference fileReference, String client) { + return getFile(new FileReferenceDownload(fileReference, client)); } public Optional<File> getFile(FileReferenceDownload fileReferenceDownload) { diff --git a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownload.java b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownload.java index 21e35bf67af..796f6ad2ebf 100644 --- a/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownload.java +++ b/filedistribution/src/main/java/com/yahoo/vespa/filedistribution/FileReferenceDownload.java @@ -22,6 +22,10 @@ public class FileReferenceDownload { this(fileReference, true, "unknown"); } + public FileReferenceDownload(FileReference fileReference, String client) { + this(fileReference, true, client); + } + public FileReferenceDownload(FileReference fileReference, boolean downloadFromOtherSourceIfNotFound, String client) { Objects.requireNonNull(fileReference, "file reference cannot be null"); this.fileReference = fileReference; diff --git a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java index 97b948ef5d4..460a1ee593a 100644 --- a/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java +++ b/filedistribution/src/test/java/com/yahoo/vespa/filedistribution/FileDownloaderTest.java @@ -79,7 +79,7 @@ public class FileDownloaderTest { fileDownloader.downloads().completedDownloading(fileReference, fileReferenceFullPath); // Check that we get correct path and content when asking for file reference - Optional<File> pathToFile = fileDownloader.getFile(fileReference); + Optional<File> pathToFile = getFile(fileReference); assertTrue(pathToFile.isPresent()); String downloadedFile = new File(fileReferenceFullPath, filename).getAbsolutePath(); assertEquals(new File(fileReferenceFullPath, filename).getAbsolutePath(), downloadedFile); @@ -96,7 +96,7 @@ public class FileDownloaderTest { FileReference fileReference = new FileReference("bar"); File fileReferenceFullPath = fileReferenceFullPath(downloadDir, fileReference); - assertFalse(fileReferenceFullPath.getAbsolutePath(), fileDownloader.getFile(fileReference).isPresent()); + assertFalse(fileReferenceFullPath.getAbsolutePath(), getFile(fileReference).isPresent()); // Verify download status when unable to download assertDownloadStatus(fileReference, 0.0); @@ -107,7 +107,7 @@ public class FileDownloaderTest { FileReference fileReference = new FileReference("baz"); File fileReferenceFullPath = fileReferenceFullPath(downloadDir, fileReference); - assertFalse(fileReferenceFullPath.getAbsolutePath(), fileDownloader.getFile(fileReference).isPresent()); + assertFalse(fileReferenceFullPath.getAbsolutePath(), getFile(fileReference).isPresent()); // Verify download status assertDownloadStatus(fileReference, 0.0); @@ -115,7 +115,7 @@ public class FileDownloaderTest { // Receives fileReference, should return and make it available to caller String filename = "abc.jar"; receiveFile(fileReference, filename, FileReferenceData.Type.file, "some other content"); - Optional<File> downloadedFile = fileDownloader.getFile(fileReference); + Optional<File> downloadedFile = getFile(fileReference); assertTrue(downloadedFile.isPresent()); File downloadedFileFullPath = new File(fileReferenceFullPath, filename); @@ -132,7 +132,7 @@ public class FileDownloaderTest { FileReference fileReference = new FileReference("fileReferenceToDirWithManyFiles"); File fileReferenceFullPath = fileReferenceFullPath(downloadDir, fileReference); - assertFalse(fileReferenceFullPath.getAbsolutePath(), fileDownloader.getFile(fileReference).isPresent()); + assertFalse(fileReferenceFullPath.getAbsolutePath(), getFile(fileReference).isPresent()); // Verify download status assertDownloadStatus(fileReference, 0.0); @@ -150,7 +150,7 @@ public class FileDownloaderTest { File tarFile = CompressedFileReference.compress(tempPath.toFile(), Arrays.asList(fooFile, barFile), new File(tempPath.toFile(), filename)); byte[] tarredContent = IOUtils.readFileBytes(tarFile); receiveFile(fileReference, filename, FileReferenceData.Type.compressed, tarredContent); - Optional<File> downloadedFile = fileDownloader.getFile(fileReference); + Optional<File> downloadedFile = getFile(fileReference); assertTrue(downloadedFile.isPresent()); File downloadedFoo = new File(fileReferenceFullPath, tempPath.relativize(fooFile.toPath()).toString()); @@ -174,7 +174,7 @@ public class FileDownloaderTest { FileReference fileReference = new FileReference("fileReference"); File fileReferenceFullPath = fileReferenceFullPath(downloadDir, fileReference); - assertFalse(fileReferenceFullPath.getAbsolutePath(), fileDownloader.getFile(fileReference).isPresent()); + assertFalse(fileReferenceFullPath.getAbsolutePath(), getFile(fileReference).isPresent()); // Getting file failed, verify download status and since there was an error is not downloading ATM assertDownloadStatus(fileReference, 0.0); @@ -183,7 +183,7 @@ public class FileDownloaderTest { // Receives fileReference, should return and make it available to caller String filename = "abc.jar"; receiveFile(fileReference, filename, FileReferenceData.Type.file, "some other content"); - Optional<File> downloadedFile = fileDownloader.getFile(fileReference); + Optional<File> downloadedFile = getFile(fileReference); assertTrue(downloadedFile.isPresent()); File downloadedFileFullPath = new File(fileReferenceFullPath, filename); assertEquals(downloadedFileFullPath.getAbsolutePath(), downloadedFile.get().getAbsolutePath()); @@ -244,13 +244,13 @@ public class FileDownloaderTest { // Should download since we do not have the file on disk fileDownloader.downloadIfNeeded(new FileReferenceDownload(xyzzy)); assertTrue(fileDownloader.isDownloading(xyzzy)); - assertFalse(fileDownloader.getFile(xyzzy).isPresent()); + assertFalse(getFile(xyzzy).isPresent()); // Receive files to simulate download receiveFile(xyzzy, "xyzzy.jar", FileReferenceData.Type.file, "content"); // Should not download, since file has already been downloaded fileDownloader.downloadIfNeeded(new FileReferenceDownload(xyzzy)); // and file should be available - assertTrue(fileDownloader.getFile(xyzzy).isPresent()); + assertTrue(getFile(xyzzy).isPresent()); } @Test @@ -296,6 +296,10 @@ public class FileDownloaderTest { fileDownloader.downloads().completedDownloading(fileReference, file); } + private Optional<File> getFile(FileReference fileReference) { + return fileDownloader.getFile(fileReference, "test"); + } + private static class MockConnection implements ConnectionPool, com.yahoo.vespa.config.Connection { private ResponseHandler responseHandler; diff --git a/searchcore/src/apps/tests/persistenceconformance_test.cpp b/searchcore/src/apps/tests/persistenceconformance_test.cpp index 8238eb21831..be9d394a2b6 100644 --- a/searchcore/src/apps/tests/persistenceconformance_test.cpp +++ b/searchcore/src/apps/tests/persistenceconformance_test.cpp @@ -3,15 +3,19 @@ #include <vespa/vespalib/testkit/testapp.h> #include <tests/proton/common/dummydbowner.h> +#include <vespa/config-attributes.h> +#include <vespa/config-bucketspaces.h> #include <vespa/config-imported-fields.h> +#include <vespa/config-indexschema.h> #include <vespa/config-rank-profiles.h> +#include <vespa/config-summary.h> #include <vespa/config-summarymap.h> #include <vespa/document/base/testdocman.h> +#include <vespa/document/repo/documenttyperepo.h> +#include <vespa/document/test/make_bucket_space.h> #include <vespa/fastos/file.h> #include <vespa/persistence/conformancetest/conformancetest.h> #include <vespa/persistence/dummyimpl/dummy_bucket_executor.h> -#include <vespa/document/repo/documenttyperepo.h> -#include <vespa/document/test/make_bucket_space.h> #include <vespa/searchcommon/common/schemaconfigurer.h> #include <vespa/searchcore/proton/common/alloc_config.h> #include <vespa/searchcore/proton/common/hw_info.h> @@ -28,13 +32,10 @@ #include <vespa/searchcore/proton/server/persistencehandlerproxy.h> #include <vespa/searchcore/proton/server/threading_service_config.h> #include <vespa/searchcore/proton/test/disk_mem_usage_notifier.h> +#include <vespa/searchcore/proton/test/mock_shared_threading_service.h> #include <vespa/searchlib/index/dummyfileheadercontext.h> #include <vespa/searchlib/transactionlog/translogserver.h> #include <vespa/searchsummary/config/config-juniperrc.h> -#include <vespa/config-bucketspaces.h> -#include <vespa/config-attributes.h> -#include <vespa/config-indexschema.h> -#include <vespa/config-summary.h> #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/util/size_literals.h> @@ -174,6 +175,7 @@ private: mutable DummyWireService _metricsWireService; mutable MemoryConfigStores _config_stores; vespalib::ThreadStackExecutor _summaryExecutor; + MockSharedThreadingService _shared_service; storage::spi::dummy::DummyBucketExecutor _bucketExecutor; public: @@ -202,7 +204,7 @@ public: mgr.nextGeneration(0ms); return DocumentDB::create(_baseDir, mgr.getConfig(), _tlsSpec, _queryLimiter, _clock, docType, bucketSpace, *b->getProtonConfigSP(), const_cast<DocumentDBFactory &>(*this), - _summaryExecutor, _summaryExecutor, _bucketExecutor, _tls, _metricsWireService, + _shared_service, _bucketExecutor, _tls, _metricsWireService, _fileHeaderContext, _config_stores.getConfigStore(docType.toString()), std::make_shared<vespalib::ThreadStackExecutor>(16, 128_Ki), HwInfo()); } @@ -218,6 +220,7 @@ DocumentDBFactory::DocumentDBFactory(const vespalib::string &baseDir, int tlsLis _clock(), _metricsWireService(), _summaryExecutor(8, 128_Ki), + _shared_service(_summaryExecutor, _summaryExecutor), _bucketExecutor(2) {} DocumentDBFactory::~DocumentDBFactory() = default; diff --git a/searchcore/src/tests/proton/docsummary/docsummary.cpp b/searchcore/src/tests/proton/docsummary/docsummary.cpp index 5c3fe94a8d7..c5a01de6b3b 100644 --- a/searchcore/src/tests/proton/docsummary/docsummary.cpp +++ b/searchcore/src/tests/proton/docsummary/docsummary.cpp @@ -1,45 +1,46 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <tests/proton/common/dummydbowner.h> +#include <vespa/config-bucketspaces.h> #include <vespa/config/helper/configgetter.hpp> +#include <vespa/document/repo/documenttyperepo.h> +#include <vespa/document/test/make_bucket_space.h> #include <vespa/eval/eval/simple_value.h> #include <vespa/eval/eval/tensor_spec.h> -#include <vespa/eval/eval/value.h> #include <vespa/eval/eval/test/value_compare.h> -#include <vespa/document/repo/documenttyperepo.h> -#include <vespa/document/test/make_bucket_space.h> +#include <vespa/eval/eval/value.h> +#include <vespa/persistence/dummyimpl/dummy_bucket_executor.h> #include <vespa/searchcore/proton/attribute/attribute_writer.h> -#include <vespa/searchcore/proton/test/bucketfactory.h> #include <vespa/searchcore/proton/docsummary/docsumcontext.h> #include <vespa/searchcore/proton/docsummary/documentstoreadapter.h> #include <vespa/searchcore/proton/docsummary/summarymanager.h> #include <vespa/searchcore/proton/documentmetastore/documentmetastore.h> #include <vespa/searchcore/proton/feedoperation/putoperation.h> +#include <vespa/searchcore/proton/matching/querylimiter.h> #include <vespa/searchcore/proton/metrics/metricswireservice.h> #include <vespa/searchcore/proton/server/bootstrapconfig.h> #include <vespa/searchcore/proton/server/documentdb.h> -#include <vespa/searchcore/proton/server/feedhandler.h> #include <vespa/searchcore/proton/server/documentdbconfigmanager.h> +#include <vespa/searchcore/proton/server/feedhandler.h> #include <vespa/searchcore/proton/server/idocumentsubdb.h> #include <vespa/searchcore/proton/server/memoryconfigstore.h> #include <vespa/searchcore/proton/server/searchview.h> #include <vespa/searchcore/proton/server/summaryadapter.h> -#include <vespa/searchcore/proton/matching/querylimiter.h> -#include <vespa/persistence/dummyimpl/dummy_bucket_executor.h> -#include <vespa/vespalib/util/destructor_callbacks.h> +#include <vespa/searchcore/proton/test/bucketfactory.h> +#include <vespa/searchcore/proton/test/mock_shared_threading_service.h> #include <vespa/searchlib/engine/docsumapi.h> #include <vespa/searchlib/index/docbuilder.h> #include <vespa/searchlib/index/dummyfileheadercontext.h> #include <vespa/searchlib/tensor/tensor_attribute.h> #include <vespa/searchlib/transactionlog/nosyncproxy.h> #include <vespa/searchlib/transactionlog/translogserver.h> -#include <vespa/vespalib/data/slime/slime.h> -#include <vespa/vespalib/data/slime/json_format.h> #include <vespa/vespalib/data/simple_buffer.h> +#include <vespa/vespalib/data/slime/json_format.h> +#include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/encoding/base64.h> -#include <vespa/vespalib/util/size_literals.h> -#include <vespa/config-bucketspaces.h> #include <vespa/vespalib/testkit/testapp.h> +#include <vespa/vespalib/util/destructor_callbacks.h> +#include <vespa/vespalib/util/size_literals.h> #include <regex> #include <vespa/log/log.h> @@ -176,6 +177,7 @@ public: DummyFileHeaderContext _fileHeaderContext; TransLogServer _tls; vespalib::ThreadStackExecutor _summaryExecutor; + MockSharedThreadingService _shared_service; storage::spi::dummy::DummyBucketExecutor _bucketExecutor; bool _mkdirOk; matching::QueryLimiter _queryLimiter; @@ -196,6 +198,7 @@ public: _fileHeaderContext(), _tls("tmp", 9013, ".", _fileHeaderContext), _summaryExecutor(8, 128_Ki), + _shared_service(_summaryExecutor, _summaryExecutor), _bucketExecutor(2), _mkdirOk(FastOS_File::MakeDirectory("tmpdb")), _queryLimiter(), @@ -224,7 +227,7 @@ public: } _ddb = DocumentDB::create("tmpdb", _configMgr.getConfig(), "tcp/localhost:9013", _queryLimiter, _clock, DocTypeName(docTypeName), makeBucketSpace(), *b->getProtonConfigSP(), *this, - _summaryExecutor, _summaryExecutor, _bucketExecutor, _tls, _dummy, _fileHeaderContext, + _shared_service, _bucketExecutor, _tls, _dummy, _fileHeaderContext, std::make_unique<MemoryConfigStore>(), std::make_shared<vespalib::ThreadStackExecutor>(16, 128_Ki), _hwInfo), _ddb->start(); diff --git a/searchcore/src/tests/proton/documentdb/documentdb_test.cpp b/searchcore/src/tests/proton/documentdb/documentdb_test.cpp index a24eeb262ab..b31534c011c 100644 --- a/searchcore/src/tests/proton/documentdb/documentdb_test.cpp +++ b/searchcore/src/tests/proton/documentdb/documentdb_test.cpp @@ -1,10 +1,12 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <tests/proton/common/dummydbowner.h> +#include <vespa/config-bucketspaces.h> #include <vespa/document/datatype/documenttype.h> #include <vespa/document/repo/documenttyperepo.h> -#include <vespa/fastos/file.h> #include <vespa/document/test/make_bucket_space.h> +#include <vespa/fastos/file.h> +#include <vespa/persistence/dummyimpl/dummy_bucket_executor.h> #include <vespa/searchcore/proton/attribute/flushableattribute.h> #include <vespa/searchcore/proton/common/statusreport.h> #include <vespa/searchcore/proton/docsummary/summaryflushtarget.h> @@ -22,17 +24,16 @@ #include <vespa/searchcore/proton/server/feedhandler.h> #include <vespa/searchcore/proton/server/fileconfigmanager.h> #include <vespa/searchcore/proton/server/memoryconfigstore.h> -#include <vespa/persistence/dummyimpl/dummy_bucket_executor.h> +#include <vespa/searchcore/proton/test/mock_shared_threading_service.h> #include <vespa/searchcorespi/index/indexflushtarget.h> #include <vespa/searchlib/attribute/attribute_read_guard.h> #include <vespa/searchlib/index/dummyfileheadercontext.h> #include <vespa/searchlib/transactionlog/translogserver.h> #include <vespa/vespalib/data/slime/slime.h> -#include <vespa/vespalib/util/size_literals.h> -#include <vespa/config-bucketspaces.h> #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/util/size_literals.h> #include <iostream> using namespace cloud::config::filedistribution; @@ -118,6 +119,7 @@ struct Fixture : public FixtureBase { DummyWireService _dummy; MyDBOwner _myDBOwner; vespalib::ThreadStackExecutor _summaryExecutor; + MockSharedThreadingService _shared_service; HwInfo _hwInfo; storage::spi::dummy::DummyBucketExecutor _bucketExecutor; DocumentDB::SP _db; @@ -142,6 +144,7 @@ Fixture::Fixture(bool file_config) _dummy(), _myDBOwner(), _summaryExecutor(8, 128_Ki), + _shared_service(_summaryExecutor, _summaryExecutor), _hwInfo(), _bucketExecutor(2), _db(), @@ -165,7 +168,7 @@ Fixture::Fixture(bool file_config) mgr.nextGeneration(0ms); _db = DocumentDB::create(".", mgr.getConfig(), "tcp/localhost:9014", _queryLimiter, _clock, DocTypeName("typea"), makeBucketSpace(), - *b->getProtonConfigSP(), _myDBOwner, _summaryExecutor, _summaryExecutor, _bucketExecutor, _tls, _dummy, + *b->getProtonConfigSP(), _myDBOwner, _shared_service, _bucketExecutor, _tls, _dummy, _fileHeaderContext, make_config_store(), std::make_shared<vespalib::ThreadStackExecutor>(16, 128_Ki), _hwInfo); _db->start(); diff --git a/searchcore/src/vespa/searchcore/bmcluster/bm_node.cpp b/searchcore/src/vespa/searchcore/bmcluster/bm_node.cpp index 0882153edd6..62d86ce895d 100644 --- a/searchcore/src/vespa/searchcore/bmcluster/bm_node.cpp +++ b/searchcore/src/vespa/searchcore/bmcluster/bm_node.cpp @@ -1,19 +1,20 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "bm_node.h" #include "bm_cluster.h" #include "bm_cluster_params.h" #include "bm_message_bus.h" +#include "bm_node.h" #include "bm_node_stats.h" #include "bm_storage_chain_builder.h" #include "bm_storage_link_context.h" -#include "storage_api_chain_bm_feed_handler.h" -#include "storage_api_message_bus_bm_feed_handler.h" -#include "storage_api_rpc_bm_feed_handler.h" #include "document_api_message_bus_bm_feed_handler.h" #include "i_bm_distribution.h" #include "i_bm_feed_handler.h" #include "spi_bm_feed_handler.h" +#include "storage_api_chain_bm_feed_handler.h" +#include "storage_api_message_bus_bm_feed_handler.h" +#include "storage_api_rpc_bm_feed_handler.h" +#include <tests/proton/common/dummydbowner.h> #include <vespa/config-attributes.h> #include <vespa/config-bucketspaces.h> #include <vespa/config-imported-fields.h> @@ -39,18 +40,19 @@ #include <vespa/searchcore/proton/common/alloc_config.h> #include <vespa/searchcore/proton/matching/querylimiter.h> #include <vespa/searchcore/proton/metrics/metricswireservice.h> -#include <vespa/searchcore/proton/persistenceengine/ipersistenceengineowner.h> #include <vespa/searchcore/proton/persistenceengine/i_resource_write_filter.h> +#include <vespa/searchcore/proton/persistenceengine/ipersistenceengineowner.h> #include <vespa/searchcore/proton/persistenceengine/persistenceengine.h> #include <vespa/searchcore/proton/server/bootstrapconfig.h> -#include <vespa/searchcore/proton/server/documentdb.h> #include <vespa/searchcore/proton/server/document_db_maintenance_config.h> #include <vespa/searchcore/proton/server/document_meta_store_read_guards.h> +#include <vespa/searchcore/proton/server/documentdb.h> #include <vespa/searchcore/proton/server/documentdbconfigmanager.h> #include <vespa/searchcore/proton/server/fileconfigmanager.h> #include <vespa/searchcore/proton/server/memoryconfigstore.h> #include <vespa/searchcore/proton/server/persistencehandlerproxy.h> #include <vespa/searchcore/proton/test/disk_mem_usage_notifier.h> +#include <vespa/searchcore/proton/test/mock_shared_threading_service.h> #include <vespa/searchlib/index/dummyfileheadercontext.h> #include <vespa/searchlib/transactionlog/translogserver.h> #include <vespa/searchsummary/config/config-juniperrc.h> @@ -75,7 +77,6 @@ #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/size_literals.h> -#include <tests/proton/common/dummydbowner.h> #include <vespa/log/log.h> LOG_SETUP(".bmcluster.bm_node"); @@ -459,6 +460,7 @@ class MyBmNode : public BmNode proton::DummyWireService _metrics_wire_service; proton::MemoryConfigStores _config_stores; vespalib::ThreadStackExecutor _summary_executor; + proton::MockSharedThreadingService _shared_service; proton::DummyDBOwner _document_db_owner; BucketSpace _bucket_space; std::shared_ptr<DocumentDB> _document_db; @@ -523,6 +525,7 @@ MyBmNode::MyBmNode(const vespalib::string& base_dir, int base_port, uint32_t nod _metrics_wire_service(), _config_stores(), _summary_executor(8, 128_Ki), + _shared_service(_summary_executor, _summary_executor), _document_db_owner(), _bucket_space(document::test::makeBucketSpace(_doc_type_name.getName())), _document_db(), @@ -594,7 +597,7 @@ MyBmNode::create_document_db(const BmClusterParams& params) mgr.nextGeneration(0ms); _document_db = DocumentDB::create(_base_dir, mgr.getConfig(), _tls_spec, _query_limiter, _clock, _doc_type_name, _bucket_space, *bootstrap_config->getProtonConfigSP(), _document_db_owner, - _summary_executor, _summary_executor, *_persistence_engine, _tls, + _shared_service, *_persistence_engine, _tls, _metrics_wire_service, _file_header_context, _config_stores.getConfigStore(_doc_type_name.toString()), std::make_shared<vespalib::ThreadStackExecutor>(16, 128_Ki), HwInfo()); diff --git a/searchcore/src/vespa/searchcore/proton/server/CMakeLists.txt b/searchcore/src/vespa/searchcore/proton/server/CMakeLists.txt index 511adbe66e9..efa22be6533 100644 --- a/searchcore/src/vespa/searchcore/proton/server/CMakeLists.txt +++ b/searchcore/src/vespa/searchcore/proton/server/CMakeLists.txt @@ -97,6 +97,8 @@ vespa_add_library(searchcore_server STATIC searchhandlerproxy.cpp searchview.cpp simpleflush.cpp + shared_threading_service.cpp + shared_threading_service_config.cpp storeonlydocsubdb.cpp storeonlyfeedview.cpp summaryadapter.cpp diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp index e5bf5013528..d491f4ab364 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp @@ -7,6 +7,7 @@ #include "documentdb.h" #include "documentdbconfigscout.h" #include "feedhandler.h" +#include "i_shared_threading_service.h" #include "idocumentdbowner.h" #include "idocumentsubdb.h" #include "maintenance_jobs_injector.h" @@ -131,8 +132,7 @@ DocumentDB::create(const vespalib::string &baseDir, document::BucketSpace bucketSpace, const ProtonConfig &protonCfg, IDocumentDBOwner &owner, - vespalib::ThreadExecutor &warmupExecutor, - vespalib::ThreadExecutor &sharedExecutor, + ISharedThreadingService& shared_service, storage::spi::BucketExecutor &bucketExecutor, const search::transactionlog::WriterFactory &tlsWriterFactory, MetricsWireService &metricsWireService, @@ -143,7 +143,7 @@ DocumentDB::create(const vespalib::string &baseDir, { return DocumentDB::SP( new DocumentDB(baseDir, std::move(currentSnapshot), tlsSpec, queryLimiter, clock, docTypeName, bucketSpace, - protonCfg, owner, warmupExecutor, sharedExecutor, bucketExecutor, tlsWriterFactory, + protonCfg, owner, shared_service, bucketExecutor, tlsWriterFactory, metricsWireService, fileHeaderContext, std::move(config_store), initializeThreads, hwInfo)); } DocumentDB::DocumentDB(const vespalib::string &baseDir, @@ -155,8 +155,7 @@ DocumentDB::DocumentDB(const vespalib::string &baseDir, document::BucketSpace bucketSpace, const ProtonConfig &protonCfg, IDocumentDBOwner &owner, - vespalib::Executor &warmupExecutor, - vespalib::ThreadExecutor &sharedExecutor, + ISharedThreadingService& shared_service, storage::spi::BucketExecutor & bucketExecutor, const search::transactionlog::WriterFactory &tlsWriterFactory, MetricsWireService &metricsWireService, @@ -176,7 +175,7 @@ DocumentDB::DocumentDB(const vespalib::string &baseDir, _baseDir(baseDir + "/" + _docTypeName.toString()), // Only one thread per executor, or performDropFeedView() will fail. _writeServiceConfig(configSnapshot->get_threading_service_config()), - _writeService(sharedExecutor, _writeServiceConfig, indexing_thread_stack_size), + _writeService(shared_service.shared(), _writeServiceConfig, indexing_thread_stack_size), _initializeThreads(std::move(initializeThreads)), _initConfigSnapshot(), _initConfigSerialNum(0u), @@ -204,9 +203,9 @@ DocumentDB::DocumentDB(const vespalib::string &baseDir, _writeFilter(), _transient_usage_provider(std::make_shared<DocumentDBResourceUsageProvider>(*this)), _feedHandler(std::make_unique<FeedHandler>(_writeService, tlsSpec, docTypeName, *this, _writeFilter, *this, tlsWriterFactory)), - _subDBs(*this, *this, *_feedHandler, _docTypeName, _writeService, warmupExecutor, fileHeaderContext, + _subDBs(*this, *this, *_feedHandler, _docTypeName, _writeService, shared_service.warmup(), fileHeaderContext, metricsWireService, getMetrics(), queryLimiter, clock, _configMutex, _baseDir, hwInfo), - _maintenanceController(_writeService.master(), sharedExecutor, _refCount, _docTypeName), + _maintenanceController(_writeService.master(), shared_service.shared(), _refCount, _docTypeName), _jobTrackers(), _calc(), _metricsUpdater(_subDBs, _writeService, _jobTrackers, *_sessionManager, _writeFilter) diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdb.h b/searchcore/src/vespa/searchcore/proton/server/documentdb.h index ee414db28bf..e829f477e8a 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdb.h +++ b/searchcore/src/vespa/searchcore/proton/server/documentdb.h @@ -47,12 +47,13 @@ namespace storage::spi { struct BucketExecutor; } namespace proton { class AttributeConfigInspector; +class ExecutorThreadingServiceStats; class IDocumentDBOwner; +class ISharedThreadingService; class ITransientResourceUsageProvider; -struct MetricsWireService; class StatusReport; -class ExecutorThreadingServiceStats; class TransientResourceUsageProvider; +struct MetricsWireService; namespace matching { class SessionManager; } @@ -200,8 +201,7 @@ private: document::BucketSpace bucketSpace, const ProtonConfig &protonCfg, IDocumentDBOwner &owner, - vespalib::Executor &warmupExecutor, - vespalib::ThreadExecutor &sharedExecutor, + ISharedThreadingService& shared_service, storage::spi::BucketExecutor &bucketExecutor, const search::transactionlog::WriterFactory &tlsWriterFactory, MetricsWireService &metricsWireService, @@ -232,8 +232,7 @@ public: document::BucketSpace bucketSpace, const ProtonConfig &protonCfg, IDocumentDBOwner &owner, - vespalib::ThreadExecutor &warmupExecutor, - vespalib::ThreadExecutor &sharedExecutor, + ISharedThreadingService& shared_service, storage::spi::BucketExecutor & bucketExecutor, const search::transactionlog::WriterFactory &tlsWriterFactory, MetricsWireService &metricsWireService, diff --git a/searchcore/src/vespa/searchcore/proton/server/i_shared_threading_service.h b/searchcore/src/vespa/searchcore/proton/server/i_shared_threading_service.h new file mode 100644 index 00000000000..5145dbec43e --- /dev/null +++ b/searchcore/src/vespa/searchcore/proton/server/i_shared_threading_service.h @@ -0,0 +1,33 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +namespace vespalib { class ThreadExecutor; } + +namespace proton { + +/** + * Interface containing the thread executors that are shared across all document dbs. + */ +class ISharedThreadingService { +public: + virtual ~ISharedThreadingService() {} + + /** + * Returns the executor used for warmup (e.g. index warmup). + */ + virtual vespalib::ThreadExecutor& warmup() = 0; + + /** + * Returns the shared executor used for various assisting tasks in a document db. + * + * Example usages include: + * - Disk index fusion. + * - Updating nearest neighbor index (in DenseTensorAttribute). + * - Loading nearest neighbor index (in DenseTensorAttribute). + * - Writing of data in the document store. + */ + virtual vespalib::ThreadExecutor& shared() = 0; +}; + +} + diff --git a/searchcore/src/vespa/searchcore/proton/server/proton.cpp b/searchcore/src/vespa/searchcore/proton/server/proton.cpp index 5b21242e397..e056325e0d3 100644 --- a/searchcore/src/vespa/searchcore/proton/server/proton.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/proton.cpp @@ -111,15 +111,6 @@ diskMemUsageSamplerConfig(const ProtonConfig &proton, const HwInfo &hwInfo) hwInfo); } -size_t -derive_shared_threads(const ProtonConfig &proton, const HwInfo::Cpu &cpuInfo) { - size_t scaledCores = (size_t)std::ceil(cpuInfo.cores() * proton.feeding.concurrency); - - // We need at least 1 guaranteed free worker in order to ensure progress so #documentsdbs + 1 should suffice, - // but we will not be cheap and give it one extra. - return std::max(scaledCores, proton.documentdb.size() + proton.flush.maxconcurrent + 1); -} - uint32_t computeRpcTransportThreads(const ProtonConfig & cfg, const HwInfo::Cpu &cpuInfo) { bool areSearchAndDocsumAsync = cfg.docsum.async && cfg.search.async; @@ -144,8 +135,6 @@ struct MetricsUpdateHook : metrics::UpdateHook const vespalib::string CUSTOM_COMPONENT_API_PATH = "/state/v1/custom/component"; -VESPA_THREAD_STACK_TAG(proton_shared_executor) -VESPA_THREAD_STACK_TAG(index_warmup_executor) VESPA_THREAD_STACK_TAG(initialize_executor) VESPA_THREAD_STACK_TAG(close_executor) @@ -240,8 +229,7 @@ Proton::Proton(const config::ConfigUri & configUri, _protonDiskLayout(), _protonConfigurer(_executor, *this, _protonDiskLayout), _protonConfigFetcher(configUri, _protonConfigurer, subscribeTimeout), - _warmupExecutor(), - _sharedExecutor(), + _shared_service(), _compile_cache_executor_binding(), _queryLimiter(), _clock(0.001), @@ -333,11 +321,8 @@ Proton::init(const BootstrapConfig::SP & configSnapshot) protonConfig.visit.ignoremaxbytes); vespalib::string fileConfigId; - _warmupExecutor = std::make_unique<vespalib::ThreadStackExecutor>(4, 128_Ki, index_warmup_executor); - - const size_t sharedThreads = derive_shared_threads(protonConfig, hwInfo.cpu()); - _sharedExecutor = std::make_shared<vespalib::BlockingThreadStackExecutor>(sharedThreads, 128_Ki, sharedThreads*16, proton_shared_executor); - _compile_cache_executor_binding = vespalib::eval::CompileCache::bind(_sharedExecutor); + _shared_service = std::make_unique<SharedThreadingService>(SharedThreadingServiceConfig::make(protonConfig, hwInfo.cpu())); + _compile_cache_executor_binding = vespalib::eval::CompileCache::bind(_shared_service->shared_raw()); InitializeThreads initializeThreads; if (protonConfig.initialize.threads > 0) { initializeThreads = std::make_shared<vespalib::ThreadStackExecutor>(protonConfig.initialize.threads, 128_Ki, initialize_executor); @@ -460,11 +445,9 @@ Proton::~Proton() if (_flushEngine) { _flushEngine->close(); } - if (_warmupExecutor) { - _warmupExecutor->sync(); - } - if (_sharedExecutor) { - _sharedExecutor->sync(); + if (_shared_service) { + _shared_service->warmup_raw().sync(); + _shared_service->shared_raw()->sync(); } if ( ! _documentDBMap.empty()) { @@ -483,9 +466,8 @@ Proton::~Proton() _documentDBMap.clear(); _persistenceEngine.reset(); _tls.reset(); - _warmupExecutor.reset(); _compile_cache_executor_binding.reset(); - _sharedExecutor.reset(); + _shared_service.reset(); _clock.stop(); LOG(debug, "Explicit destructor done"); } @@ -619,11 +601,23 @@ Proton::addDocumentDB(const document::DocumentType &docType, // 1 thread per document type. initializeThreads = std::make_shared<vespalib::ThreadStackExecutor>(1, 128_Ki); } - auto ret = DocumentDB::create(config.basedir + "/documents", documentDBConfig, config.tlsspec, - _queryLimiter, _clock, docTypeName, bucketSpace, config, *this, - *_warmupExecutor, *_sharedExecutor, *_persistenceEngine, *_tls->getTransLogServer(), - *_metricsEngine, _fileHeaderContext, std::move(config_store), - initializeThreads, bootstrapConfig->getHwInfo()); + auto ret = DocumentDB::create(config.basedir + "/documents", + documentDBConfig, + config.tlsspec, + _queryLimiter, + _clock, + docTypeName, + bucketSpace, + config, + *this, + *_shared_service, + *_persistenceEngine, + *_tls->getTransLogServer(), + *_metricsEngine, + _fileHeaderContext, + std::move(config_store), + initializeThreads, + bootstrapConfig->getHwInfo()); try { ret->start(); } catch (vespalib::Exception &e) { @@ -791,11 +785,9 @@ Proton::updateMetrics(const metrics::MetricLockGuard &) if (_summaryEngine) { updateExecutorMetrics(metrics.docsum, _summaryEngine->getExecutorStats()); } - if (_sharedExecutor) { - metrics.shared.update(_sharedExecutor->getStats()); - } - if (_warmupExecutor) { - metrics.warmup.update(_warmupExecutor->getStats()); + if (_shared_service) { + metrics.shared.update(_shared_service->shared().getStats()); + metrics.warmup.update(_shared_service->warmup().getStats()); } } } @@ -947,12 +939,12 @@ Proton::get_child(vespalib::stringref name) const return std::make_unique<ResourceUsageExplorer>(_diskMemUsageSampler->writeFilter(), _persistenceEngine->get_resource_usage_tracker()); } else if (name == THREAD_POOLS) { - return std::make_unique<ProtonThreadPoolsExplorer>(_sharedExecutor.get(), + return std::make_unique<ProtonThreadPoolsExplorer>((_shared_service) ? _shared_service->shared_raw().get() : nullptr, (_matchEngine) ? &_matchEngine->get_executor() : nullptr, (_summaryEngine) ? &_summaryEngine->get_executor() : nullptr, (_flushEngine) ? &_flushEngine->get_executor() : nullptr, &_executor, - _warmupExecutor.get()); + (_shared_service) ? &_shared_service->warmup() : nullptr); } return Explorer_UP(nullptr); } diff --git a/searchcore/src/vespa/searchcore/proton/server/proton.h b/searchcore/src/vespa/searchcore/proton/server/proton.h index fffc16089fa..91635dc7497 100644 --- a/searchcore/src/vespa/searchcore/proton/server/proton.h +++ b/searchcore/src/vespa/searchcore/proton/server/proton.h @@ -12,6 +12,8 @@ #include "proton_config_fetcher.h" #include "proton_configurer.h" #include "rpc_hooks.h" +#include "shared_threading_service.h" +#include <vespa/eval/eval/llvm/compile_cache.h> #include <vespa/searchcore/proton/matching/querylimiter.h> #include <vespa/searchcore/proton/metrics/metrics_engine.h> #include <vespa/searchcore/proton/persistenceengine/i_resource_write_filter.h> @@ -24,7 +26,6 @@ #include <vespa/vespalib/net/json_handler_repo.h> #include <vespa/vespalib/net/state_explorer.h> #include <vespa/vespalib/util/varholder.h> -#include <vespa/eval/eval/llvm/compile_cache.h> #include <mutex> #include <shared_mutex> @@ -100,8 +101,7 @@ private: std::unique_ptr<IProtonDiskLayout> _protonDiskLayout; ProtonConfigurer _protonConfigurer; ProtonConfigFetcher _protonConfigFetcher; - std::unique_ptr<vespalib::ThreadStackExecutorBase> _warmupExecutor; - std::shared_ptr<vespalib::ThreadStackExecutorBase> _sharedExecutor; + std::unique_ptr<SharedThreadingService> _shared_service; vespalib::eval::CompileCache::ExecutorBinding::UP _compile_cache_executor_binding; matching::QueryLimiter _queryLimiter; vespalib::Clock _clock; diff --git a/searchcore/src/vespa/searchcore/proton/server/shared_threading_service.cpp b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service.cpp new file mode 100644 index 00000000000..04e775674b4 --- /dev/null +++ b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service.cpp @@ -0,0 +1,19 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "shared_threading_service.h" +#include <vespa/vespalib/util/size_literals.h> +#include <vespa/vespalib/util/blockingthreadstackexecutor.h> + +VESPA_THREAD_STACK_TAG(proton_shared_executor) +VESPA_THREAD_STACK_TAG(proton_warmup_executor) + +namespace proton { + +SharedThreadingService::SharedThreadingService(const SharedThreadingServiceConfig& cfg) + : _warmup(cfg.warmup_threads(), 128_Ki, proton_warmup_executor), + _shared(std::make_shared<vespalib::BlockingThreadStackExecutor>(cfg.shared_threads(), 128_Ki, + cfg.shared_task_limit(), proton_shared_executor)) +{ +} + +} diff --git a/searchcore/src/vespa/searchcore/proton/server/shared_threading_service.h b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service.h new file mode 100644 index 00000000000..ef0ff31c389 --- /dev/null +++ b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service.h @@ -0,0 +1,30 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "i_shared_threading_service.h" +#include "shared_threading_service_config.h" +#include <vespa/vespalib/util/threadstackexecutor.h> +#include <vespa/vespalib/util/syncable.h> +#include <memory> + +namespace proton { + +/** + * Class containing the thread executors that are shared across all document dbs. + */ +class SharedThreadingService : public ISharedThreadingService { +private: + vespalib::ThreadStackExecutor _warmup; + std::shared_ptr<vespalib::SyncableThreadExecutor> _shared; + +public: + SharedThreadingService(const SharedThreadingServiceConfig& cfg); + + vespalib::SyncableThreadExecutor& warmup_raw() { return _warmup; } + std::shared_ptr<vespalib::SyncableThreadExecutor> shared_raw() { return _shared; } + + vespalib::ThreadExecutor& warmup() override { return _warmup; } + vespalib::ThreadExecutor& shared() override { return *_shared; } +}; + +} diff --git a/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp new file mode 100644 index 00000000000..cf62cf3b76c --- /dev/null +++ b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp @@ -0,0 +1,41 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "shared_threading_service_config.h" +#include <vespa/searchcore/config/config-proton.h> +#include <cmath> + +namespace proton { + +using ProtonConfig = SharedThreadingServiceConfig::ProtonConfig; + +SharedThreadingServiceConfig::SharedThreadingServiceConfig(uint32_t shared_threads_in, + uint32_t shared_task_limit_in, + uint32_t warmup_threads_in) + : _shared_threads(shared_threads_in), + _shared_task_limit(shared_task_limit_in), + _warmup_threads(warmup_threads_in) +{ +} + +namespace { + +size_t +derive_shared_threads(const ProtonConfig& cfg, const HwInfo::Cpu& cpu_info) +{ + size_t scaled_cores = (size_t)std::ceil(cpu_info.cores() * cfg.feeding.concurrency); + + // We need at least 1 guaranteed free worker in order to ensure progress. + return std::max(scaled_cores, cfg.documentdb.size() + cfg.flush.maxconcurrent + 1); +} + +} + +SharedThreadingServiceConfig +SharedThreadingServiceConfig::make(const proton::SharedThreadingServiceConfig::ProtonConfig& cfg, + const proton::HwInfo::Cpu& cpu_info) +{ + size_t shared_threads = derive_shared_threads(cfg, cpu_info); + return proton::SharedThreadingServiceConfig(shared_threads, shared_threads * 16, 4); +} + +} diff --git a/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.h b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.h new file mode 100644 index 00000000000..02966e0efeb --- /dev/null +++ b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.h @@ -0,0 +1,36 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <vespa/searchcore/proton/common/hw_info.h> + +namespace vespa::config::search::core::internal { class InternalProtonType; } + +namespace proton { + +/** + * Config for the thread executors that are shared across all document dbs. + */ +class SharedThreadingServiceConfig { +public: + using ProtonConfig = const vespa::config::search::core::internal::InternalProtonType; + +private: + uint32_t _shared_threads; + uint32_t _shared_task_limit; + uint32_t _warmup_threads; + +public: + SharedThreadingServiceConfig(uint32_t shared_threads_in, + uint32_t shared_task_limit_in, + uint32_t warmup_threads_in); + + static SharedThreadingServiceConfig make(const ProtonConfig& cfg, const HwInfo::Cpu& cpu_info); + + uint32_t shared_threads() const { return _shared_threads; } + uint32_t shared_task_limit() const { return _shared_task_limit; } + uint32_t warmup_threads() const { return _warmup_threads; } + +}; + +} + diff --git a/searchcore/src/vespa/searchcore/proton/test/mock_shared_threading_service.h b/searchcore/src/vespa/searchcore/proton/test/mock_shared_threading_service.h new file mode 100644 index 00000000000..f21f43ed5ad --- /dev/null +++ b/searchcore/src/vespa/searchcore/proton/test/mock_shared_threading_service.h @@ -0,0 +1,23 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <vespa/searchcore/proton/server/i_shared_threading_service.h> + +namespace proton { + +class MockSharedThreadingService : public ISharedThreadingService { +private: + vespalib::ThreadExecutor& _warmup; + vespalib::ThreadExecutor& _shared; + +public: + MockSharedThreadingService(vespalib::ThreadExecutor& warmup_in, + vespalib::ThreadExecutor& shared_in) + : _warmup(warmup_in), + _shared(shared_in) + {} + vespalib::ThreadExecutor& warmup() override { return _warmup; } + vespalib::ThreadExecutor& shared() override { return _shared; } +}; + +} diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index a54f981352b..f0e156a96ed 100644 --- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -71,7 +71,6 @@ TEST(DistanceFunctionsTest, euclidean_int8_smoketest) auto euclid = make_distance_function(DistanceMetric::Euclidean, ct); - std::vector<double> p00{0.0, 0.0, 0.0}; std::vector<Int8Float> p0{0.0, 0.0, 0.0}; std::vector<Int8Float> p1{1.0, 0.0, 0.0}; std::vector<Int8Float> p5{0.0,-1.0, 0.0}; @@ -85,9 +84,6 @@ TEST(DistanceFunctionsTest, euclidean_int8_smoketest) EXPECT_DOUBLE_EQ(12.0, euclid->calc(t(p1), t(p7))); EXPECT_DOUBLE_EQ(14.0, euclid->calc(t(p5), t(p7))); - EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p00), t(p1))); - EXPECT_DOUBLE_EQ(1.0, euclid->calc(t(p00), t(p5))); - EXPECT_DOUBLE_EQ(9.0, euclid->calc(t(p00), t(p7))); } TEST(DistanceFunctionsTest, angular_gives_expected_score) diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index 315d4c8535c..96dfc580d87 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -21,6 +21,7 @@ make_distance_function(DistanceMetric variant, CellType cell_type) switch (cell_type) { case CellType::FLOAT: return std::make_unique<SquaredEuclideanDistanceHW<float>>(); case CellType::DOUBLE: return std::make_unique<SquaredEuclideanDistanceHW<double>>(); + case CellType::INT8: return std::make_unique<SquaredEuclideanDistanceHW<vespalib::eval::Int8Float>>(); default: return std::make_unique<SquaredEuclideanDistance>(CellType::FLOAT); } case DistanceMetric::Angular: diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h index 517ef68511b..6505ea119ea 100644 --- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h +++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.h @@ -44,6 +44,9 @@ public: assert(expected_cell_type() == vespalib::eval::get_cell_type<FloatType>()); } + static const double *cast(const double * p) { return p; } + static const float *cast(const float * p) { return p; } + static const int8_t *cast(const vespalib::eval::Int8Float * p) { return reinterpret_cast<const int8_t *>(p); } double calc(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs) const override { constexpr vespalib::eval::CellType expected = vespalib::eval::get_cell_type<FloatType>(); assert(lhs.type == expected && rhs.type == expected); @@ -51,7 +54,7 @@ public: auto rhs_vector = rhs.typify<FloatType>(); size_t sz = lhs_vector.size(); assert(sz == rhs_vector.size()); - return _computer.squaredEuclideanDistance(&lhs_vector[0], &rhs_vector[0], sz); + return _computer.squaredEuclideanDistance(cast(&lhs_vector[0]), cast(&rhs_vector[0]), sz); } double calc_with_limit(const vespalib::eval::TypedCells& lhs, diff --git a/vespalib/src/tests/hwaccelrated/.gitignore b/vespalib/src/tests/hwaccelrated/.gitignore new file mode 100644 index 00000000000..42f73a39d78 --- /dev/null +++ b/vespalib/src/tests/hwaccelrated/.gitignore @@ -0,0 +1 @@ +vespalib_hwaccelrated_bench_app diff --git a/vespalib/src/tests/hwaccelrated/CMakeLists.txt b/vespalib/src/tests/hwaccelrated/CMakeLists.txt index 960ae840995..9edea9c4472 100644 --- a/vespalib/src/tests/hwaccelrated/CMakeLists.txt +++ b/vespalib/src/tests/hwaccelrated/CMakeLists.txt @@ -6,3 +6,10 @@ vespa_add_executable(vespalib_hwaccelrated_test_app TEST vespalib ) vespa_add_test(NAME vespalib_hwaccelrated_test_app COMMAND vespalib_hwaccelrated_test_app) + +vespa_add_executable(vespalib_hwaccelrated_bench_app + SOURCES + hwaccelrated_bench.cpp + DEPENDS + vespalib +) diff --git a/vespalib/src/tests/hwaccelrated/hwaccelrated_bench.cpp b/vespalib/src/tests/hwaccelrated/hwaccelrated_bench.cpp new file mode 100644 index 00000000000..9984cfca440 --- /dev/null +++ b/vespalib/src/tests/hwaccelrated/hwaccelrated_bench.cpp @@ -0,0 +1,59 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/hwaccelrated/iaccelrated.h> +#include <vespa/vespalib/hwaccelrated/generic.h> +#include <vespa/vespalib/util/time.h> +# +using namespace vespalib; + +template<typename T> +std::vector<T> createAndFill(size_t sz) { + std::vector<T> v(sz); + for (size_t i(0); i < sz; i++) { + v[i] = rand()%128; + } + return v; +} + +template<typename T> +void +benchmarkEuclideanDistance(const hwaccelrated::IAccelrated & accel, size_t sz, size_t count) { + srand(1); + std::vector<T> a = createAndFill<T>(sz); + std::vector<T> b = createAndFill<T>(sz); + steady_time start = steady_clock::now(); + double sumOfSums(0); + for (size_t j(0); j < count; j++) { + double sum = accel.squaredEuclideanDistance(&a[0], &b[0], sz); + sumOfSums += sum; + } + duration elapsed = steady_clock::now() - start; + printf("sum=%f of N=%zu and vector length=%zu took %ld\n", sumOfSums, count, sz, count_ms(elapsed)); +} + +void +benchMarkEuclidianDistance(const hwaccelrated::IAccelrated & accelrator, size_t sz, size_t count) { + printf("double : "); + benchmarkEuclideanDistance<double>(accelrator, sz, count); + printf("float : "); + benchmarkEuclideanDistance<float>(accelrator, sz, count); + printf("int8_t : "); + benchmarkEuclideanDistance<int8_t>(accelrator, sz, count); +} + +int main(int argc, char *argv[]) { + int length = 1000; + int count = 1000000; + if (argc > 1) { + length = atol(argv[1]); + } + if (argc > 2) { + count = atol(argv[2]); + } + printf("%s %d %d\n", argv[0], length, count); + printf("Squared Euclidian Distance - Generic\n"); + benchMarkEuclidianDistance(hwaccelrated::GenericAccelrator(), length, count); + printf("Squared Euclidian Distance - Optimized for this cpu\n"); + benchMarkEuclidianDistance(hwaccelrated::IAccelrated::getAccelerator(), length, count); + return 0; +} diff --git a/vespalib/src/tests/hwaccelrated/hwaccelrated_test.cpp b/vespalib/src/tests/hwaccelrated/hwaccelrated_test.cpp index 3d66769c15a..bbe0ff6663a 100644 --- a/vespalib/src/tests/hwaccelrated/hwaccelrated_test.cpp +++ b/vespalib/src/tests/hwaccelrated/hwaccelrated_test.cpp @@ -3,6 +3,8 @@ #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/vespalib/hwaccelrated/iaccelrated.h> #include <vespa/vespalib/hwaccelrated/generic.h> +#include <vespa/log/log.h> +LOG_SETUP("hwaccelrated_test"); using namespace vespalib; @@ -15,26 +17,34 @@ std::vector<T> createAndFill(size_t sz) { return v; } -template<typename T> -void verifyEuclideanDistance(const hwaccelrated::IAccelrated & accel) { - const size_t testLength(255); +template<typename T, typename P> +void verifyEuclideanDistance(const hwaccelrated::IAccelrated & accel, size_t testLength, double approxFactor) { srand(1); std::vector<T> a = createAndFill<T>(testLength); std::vector<T> b = createAndFill<T>(testLength); for (size_t j(0); j < 0x20; j++) { - T sum(0); + P sum(0); for (size_t i(j); i < testLength; i++) { - sum += (a[i] - b[i]) * (a[i] - b[i]); + P d = P(a[i]) - P(b[i]); + sum += d * d; } - T hwComputedSum(accel.squaredEuclideanDistance(&a[j], &b[j], testLength - j)); - EXPECT_EQUAL(sum, hwComputedSum); + P hwComputedSum(accel.squaredEuclideanDistance(&a[j], &b[j], testLength - j)); + EXPECT_APPROX(sum, hwComputedSum, sum*approxFactor); } } +void +verifyEuclideanDistance(const hwaccelrated::IAccelrated & accelrator, size_t testLength) { + verifyEuclideanDistance<int8_t, double>(accelrator, testLength, 0.0); + verifyEuclideanDistance<float, double>(accelrator, testLength, 0.0001); // Small deviation requiring EXPECT_APPROX + verifyEuclideanDistance<double, double>(accelrator, testLength, 0.0); +} + TEST("test euclidean distance") { hwaccelrated::GenericAccelrator genericAccelrator; - verifyEuclideanDistance<float>(genericAccelrator); - verifyEuclideanDistance<double >(genericAccelrator); + constexpr size_t TEST_LENGTH = 140000; // must be longer than 64k + TEST_DO(verifyEuclideanDistance(hwaccelrated::GenericAccelrator(), TEST_LENGTH)); + TEST_DO(verifyEuclideanDistance(hwaccelrated::IAccelrated::getAccelerator(), TEST_LENGTH)); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp index 6a6421ad016..590223ed13a 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp @@ -11,6 +11,11 @@ Avx2Accelrator::populationCount(const uint64_t *a, size_t sz) const { } double +Avx2Accelrator::squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const { + return helper::squaredEuclideanDistance(a, b, sz); +} + +double Avx2Accelrator::squaredEuclideanDistance(const float * a, const float * b, size_t sz) const { return avx::euclideanDistanceSelectAlignment<float, 32>(a, b, sz); } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h index 44752dd9270..2949e81fd36 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h @@ -13,6 +13,7 @@ class Avx2Accelrator : public GenericAccelrator { public: size_t populationCount(const uint64_t *a, size_t sz) const override; + double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const override; double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const override; double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const override; void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const override; diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp index 94a6637a072..5878165bb6d 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp @@ -23,6 +23,11 @@ Avx512Accelrator::populationCount(const uint64_t *a, size_t sz) const { } double +Avx512Accelrator::squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const { + return helper::squaredEuclideanDistance(a, b, sz); +} + +double Avx512Accelrator::squaredEuclideanDistance(const float * a, const float * b, size_t sz) const { return avx::euclideanDistanceSelectAlignment<float, 64>(a, b, sz); } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h index 826cf63be70..4989f72e698 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h @@ -15,6 +15,7 @@ public: float dotProduct(const float * a, const float * b, size_t sz) const override; double dotProduct(const double * a, const double * b, size_t sz) const override; size_t populationCount(const uint64_t *a, size_t sz) const override; + double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const override; double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const override; double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const override; void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const override; diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp index fb6ec167cf4..13946fa3398 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp @@ -34,7 +34,7 @@ multiplyAdd(const T * a, const T * b, size_t sz) template <typename T, size_t UNROLL> double -euclideanDistanceT(const T * a, const T * b, size_t sz) +squaredEuclideanDistanceT(const T * a, const T * b, size_t sz) { T partial[UNROLL]; for (size_t i(0); i < UNROLL; i++) { @@ -43,11 +43,13 @@ euclideanDistanceT(const T * a, const T * b, size_t sz) size_t i(0); for (; i + UNROLL <= sz; i += UNROLL) { for (size_t j(0); j < UNROLL; j++) { - partial[j] += (a[i+j] - b[i+j]) * (a[i+j] - b[i+j]); + T d = a[i+j] - b[i+j]; + partial[j] += d * d; } } for (;i < sz; i++) { - partial[i%UNROLL] += (a[i] - b[i]) * (a[i] - b[i]); + T d = a[i] - b[i]; + partial[i%UNROLL] += d * d; } double sum(0); for (size_t j(0); j < UNROLL; j++) { @@ -156,13 +158,18 @@ GenericAccelrator::populationCount(const uint64_t *a, size_t sz) const { } double +GenericAccelrator::squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const { + return helper::squaredEuclideanDistance(a, b, sz); +} + +double GenericAccelrator::squaredEuclideanDistance(const float * a, const float * b, size_t sz) const { - return euclideanDistanceT<float, 8>(a, b, sz); + return squaredEuclideanDistanceT<float, 2>(a, b, sz); } double GenericAccelrator::squaredEuclideanDistance(const double * a, const double * b, size_t sz) const { - return euclideanDistanceT<double, 4>(a, b, sz); + return squaredEuclideanDistanceT<double, 2>(a, b, sz); } void diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/generic.h b/vespalib/src/vespa/vespalib/hwaccelrated/generic.h index c6b75bbcaf0..315e807da07 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/generic.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/generic.h @@ -23,6 +23,7 @@ public: void andNotBit(void * a, const void * b, size_t bytes) const override; void notBit(void * a, size_t bytes) const override; size_t populationCount(const uint64_t *a, size_t sz) const override; + double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const override; double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const override; double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const override; void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const override; diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h b/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h index afb2024b322..6eae41ead4b 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h @@ -28,6 +28,7 @@ public: virtual void andNotBit(void * a, const void * b, size_t bytes) const = 0; virtual void notBit(void * a, size_t bytes) const = 0; virtual size_t populationCount(const uint64_t *a, size_t sz) const = 0; + virtual double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const = 0; virtual double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const = 0; virtual double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const = 0; // AND 64 bytes from multiple, optionally inverted sources diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp b/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp index 824e0e1ebd9..3b063ce6805 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/private_helpers.hpp @@ -74,5 +74,31 @@ orChunks(size_t offset, const std::vector<std::pair<const void *, bool>> & src, } } +template<typename TemporaryT=int32_t> +double squaredEuclideanDistanceT(const int8_t * a, const int8_t * b, size_t sz) __attribute__((noinline)); +template<typename TemporaryT> +double squaredEuclideanDistanceT(const int8_t * a, const int8_t * b, size_t sz) +{ + //Note that this is 3 times faster with int32_t than with int64_t and 16x faster than float + TemporaryT sum = 0; + for (size_t i(0); i < sz; i++) { + int16_t d = int16_t(a[i]) - int16_t(b[i]); + sum += d * d; + } + return sum; +} + +inline double +squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) { + constexpr size_t LOOP_COUNT = 0x10000; + double sum(0); + size_t i=0; + for (; i + LOOP_COUNT <= sz; i += LOOP_COUNT) { + sum += squaredEuclideanDistanceT<int32_t>(a + i, b + i, LOOP_COUNT); + } + sum += squaredEuclideanDistanceT<int32_t>(a + i, b + i, sz - i); + return sum; +} + } } |