diff options
15 files changed, 333 insertions, 32 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 2b4424654a2..ba52826cd3f 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -869,7 +869,8 @@ "public java.lang.String getName()", "public int getTermCount()", "public int encode(java.nio.ByteBuffer)", - "protected void appendBodyString(java.lang.StringBuilder)" + "protected void appendBodyString(java.lang.StringBuilder)", + "public void disclose(com.yahoo.prelude.query.textualrepresentation.Discloser)" ], "fields": [] }, diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java index 52ef6c40a6a..be3ae913476 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java @@ -4,6 +4,7 @@ package com.yahoo.prelude.query; import com.google.common.annotations.Beta; import com.yahoo.compress.IntegerCompressor; +import com.yahoo.prelude.query.textualrepresentation.Discloser; import java.nio.ByteBuffer; @@ -83,7 +84,17 @@ public class NearestNeighborItem extends SimpleTaggableItem { buffer.append(",queryTensorName=").append(queryTensorName); buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits); buffer.append(",approximate=").append(approximate); - buffer.append(",targetNumHits=").append(targetNumHits).append("}"); + buffer.append(",targetHits=").append(targetNumHits).append("}"); + } + + @Override + public void disclose(Discloser discloser) { + super.disclose(discloser); + discloser.addProperty("field", field); + discloser.addProperty("queryTensorName", queryTensorName); + discloser.addProperty("hnsw.exploreAdditionalHits", hnswExploreAdditionalHits); + discloser.addProperty("approximate", approximate); + discloser.addProperty("targetHits", targetNumHits); } } diff --git a/container-search/src/main/java/com/yahoo/search/query/SelectParser.java b/container-search/src/main/java/com/yahoo/search/query/SelectParser.java index 775dca7c444..9910eb9532d 100644 --- a/container-search/src/main/java/com/yahoo/search/query/SelectParser.java +++ b/container-search/src/main/java/com/yahoo/search/query/SelectParser.java @@ -16,6 +16,7 @@ import com.yahoo.prelude.query.IntItem; import com.yahoo.prelude.query.Item; import com.yahoo.prelude.query.Limit; import com.yahoo.prelude.query.NearItem; +import com.yahoo.prelude.query.NearestNeighborItem; import com.yahoo.prelude.query.NotItem; import com.yahoo.prelude.query.ONearItem; import com.yahoo.prelude.query.OrItem; @@ -93,14 +94,17 @@ public class SelectParser implements Parser { private static final String ACCENT_DROP = "accentDrop"; private static final String ALTERNATIVES = "alternatives"; private static final String AND_SEGMENTING = "andSegmenting"; + private static final String APPROXIMATE = "approximate"; private static final String DISTANCE = "distance"; private static final String DOT_PRODUCT = "dotProduct"; private static final String EQUIV = "equiv"; private static final String FILTER = "filter"; private static final String HIT_LIMIT = "hitLimit"; + private static final String HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.exploreAdditionalHits"; private static final String IMPLICIT_TRANSFORMS = "implicitTransforms"; private static final String LABEL = "label"; private static final String NEAR = "near"; + private static final String NEAREST_NEIGHBOR = "nearestNeighbor"; private static final String NORMALIZE_CASE = "normalizeCase"; private static final String ONEAR = "onear"; private static final String PHRASE = "phrase"; @@ -114,6 +118,7 @@ public class SelectParser implements Parser { private static final String STEM = "stem"; private static final String SUBSTRING = "substring"; private static final String SUFFIX = "suffix"; + private static final String TARGET_HITS = "targetHits"; private static final String TARGET_NUM_HITS = "targetNumHits"; private static final String THRESHOLD_BOOST_FACTOR = "thresholdBoostFactor"; private static final String UNIQUE_ID = "id"; @@ -130,7 +135,7 @@ public class SelectParser implements Parser { private static final String CONTAINS = "contains"; private static final String MATCHES = "matches"; private static final String CALL = "call"; - private static final List<String> FUNCTION_CALLS = Arrays.asList(WAND, WEIGHTED_SET, DOT_PRODUCT, PREDICATE, RANK, WEAK_AND); + private static final List<String> FUNCTION_CALLS = Arrays.asList(WAND, WEIGHTED_SET, DOT_PRODUCT, NEAREST_NEIGHBOR, PREDICATE, RANK, WEAK_AND); public SelectParser(ParserEnvironment environment) { indexFacts = environment.getIndexFacts(); @@ -259,6 +264,8 @@ public class SelectParser implements Parser { return buildWeightedSet(key, value); case DOT_PRODUCT: return buildDotProduct(key, value); + case NEAREST_NEIGHBOR: + return buildNearestNeighbor(key, value); case PREDICATE: return buildPredicate(key, value); case RANK: @@ -266,7 +273,7 @@ public class SelectParser implements Parser { case WEAK_AND: return buildWeakAnd(key, value); default: - throw newUnexpectedArgumentException(key, DOT_PRODUCT, RANK, WAND, WEAK_AND, WEIGHTED_SET, PREDICATE); + throw newUnexpectedArgumentException(key, DOT_PRODUCT, NEAREST_NEIGHBOR, RANK, WAND, WEAK_AND, WEIGHTED_SET, PREDICATE); } } @@ -403,6 +410,38 @@ public class SelectParser implements Parser { return orItem; } + private Item buildNearestNeighbor(String key, Inspector value) { + + HashMap<Integer, Inspector> children = childMap(value); + Preconditions.checkArgument(children.size() == 2, "Expected 2 arguments, got %s.", children.size()); + String field = children.get(0).asString(); + String property = children.get(1).asString(); + NearestNeighborItem item = new NearestNeighborItem(field, property); + Inspector annotations = getAnnotations(value); + if (annotations != null){ + annotations.traverse((ObjectTraverser) (annotation_name, annotation_value) -> { + if (TARGET_HITS.equals(annotation_name)){ + item.setTargetNumHits((int)(annotation_value.asDouble())); + } + if (TARGET_NUM_HITS.equals(annotation_name)){ + item.setTargetNumHits((int)(annotation_value.asDouble())); + } + if (HNSW_EXPLORE_ADDITIONAL_HITS.equals(annotation_name)) { + int hnswExploreAdditionalHits = (int)(annotation_value.asDouble()); + item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits); + } + if (APPROXIMATE.equals(annotation_name)) { + boolean allowApproximate = annotation_value.asBool(); + item.setAllowApproximate(allowApproximate); + } + if (LABEL.equals(annotation_name)) { + item.setLabel(annotation_value.asString()); + } + }); + } + return item; + } + private CompositeItem buildWeakAnd(String key, Inspector value) { WeakAndItem weakAnd = new WeakAndItem(); addItemsFromInspector(weakAnd, value); @@ -410,6 +449,9 @@ public class SelectParser implements Parser { if (annotations != null){ annotations.traverse((ObjectTraverser) (annotation_name, annotation_value) -> { + if (TARGET_HITS.equals(annotation_name)){ + weakAnd.setN((int)(annotation_value.asDouble())); + } if (TARGET_NUM_HITS.equals(annotation_name)){ weakAnd.setN((int)(annotation_value.asDouble())); } @@ -662,7 +704,10 @@ public class SelectParser implements Parser { HashMap<Integer, Inspector> children = childMap(value); Preconditions.checkArgument(children.size() == 2, "Expected 2 arguments, got %s.", children.size()); - Integer target_num_hits= getIntegerAnnotation(TARGET_NUM_HITS, annotations, DEFAULT_TARGET_NUM_HITS); + Integer target_num_hits= getIntegerAnnotation(TARGET_HITS, annotations, null); + if (target_num_hits == null) { + target_num_hits= getIntegerAnnotation(TARGET_NUM_HITS, annotations, DEFAULT_TARGET_NUM_HITS); + } WandItem out = new WandItem(children.get(0).asString(), target_num_hits); diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java index 76b8c1ef8a2..aca2998cba3 100644 --- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java @@ -97,7 +97,7 @@ public class ValidateNearestNeighborSearcher extends Searcher { /** Returns an error message if this is invalid, or null if it is valid */ private String validate(NearestNeighborItem item) { if (item.getTargetNumHits() < 1) - return item + " has invalid targetNumHits " + item.getTargetNumHits() + ": Must be >= 1"; + return item + " has invalid targetHits " + item.getTargetNumHits() + ": Must be >= 1"; String queryFeatureName = "query(" + item.getQueryTensorName() + ")"; Optional<Tensor> queryTensor = query.getRanking().getFeatures().getTensor(queryFeatureName); diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java index f4560806dd2..7d17fe4f09d 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java +++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java @@ -173,6 +173,7 @@ public class YqlParser implements Parser { static final String STEM = "stem"; static final String SUBSTRING = "substring"; static final String SUFFIX = "suffix"; + static final String TARGET_HITS = "targetHits"; static final String TARGET_NUM_HITS = "targetNumHits"; static final String THRESHOLD_BOOST_FACTOR = "thresholdBoostFactor"; static final String UNIQUE_ID = "id"; @@ -418,8 +419,12 @@ public class YqlParser implements Parser { String field = fetchFieldRead(args.get(0)); String property = fetchFieldRead(args.get(1)); NearestNeighborItem item = new NearestNeighborItem(field, property); - Integer targetNumHits = getAnnotation(ast, TARGET_NUM_HITS, + Integer targetNumHits = getAnnotation(ast, TARGET_HITS, Integer.class, null, "desired minimum hits to produce"); + if (targetNumHits == null) { + targetNumHits = getAnnotation(ast, TARGET_NUM_HITS, + Integer.class, null, "desired minimum hits to produce"); + } if (targetNumHits != null) { item.setTargetNumHits(targetNumHits); } @@ -504,9 +509,13 @@ public class YqlParser implements Parser { List<OperatorNode<ExpressionOperator>> args = ast.getArgument(1); Preconditions.checkArgument(args.size() == 2, "Expected 2 arguments, got %s.", args.size()); - WandItem out = new WandItem(getIndex(args.get(0)), getAnnotation(ast, - TARGET_NUM_HITS, Integer.class, DEFAULT_TARGET_NUM_HITS, - "desired number of hits to accumulate in wand")); + Integer targetNumHits = getAnnotation(ast, TARGET_HITS, + Integer.class, null, "desired number of hits to accumulate in wand"); + if (targetNumHits == null) { + targetNumHits = getAnnotation(ast, TARGET_NUM_HITS, + Integer.class, DEFAULT_TARGET_NUM_HITS, "desired number of hits to accumulate in wand"); + } + WandItem out = new WandItem(getIndex(args.get(0)), targetNumHits); Double scoreThreshold = getAnnotation(ast, SCORE_THRESHOLD, Double.class, null, "min score for hit inclusion"); if (scoreThreshold != null) { @@ -1028,8 +1037,12 @@ public class YqlParser implements Parser { private CompositeItem buildWeakAnd(OperatorNode<ExpressionOperator> spec) { WeakAndItem weakAnd = new WeakAndItem(); - Integer targetNumHits = getAnnotation(spec, TARGET_NUM_HITS, + Integer targetNumHits = getAnnotation(spec, TARGET_HITS, Integer.class, null, "desired minimum hits to produce"); + if (targetNumHits == null) { + targetNumHits = getAnnotation(spec, TARGET_NUM_HITS, + Integer.class, null, "desired minimum hits to produce"); + } if (targetNumHits != null) { weakAnd.setN(targetNumHits); } diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index 2c849a9b52c..c49603737a6 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -93,7 +93,7 @@ public class ValidateNearestNeighborTestCase { } private String makeQuery(String attributeTensor, String queryTensor) { - return "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ");"; + return "select * from sources * where [{\"targetHits\":1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ");"; } @Test @@ -139,7 +139,7 @@ public class ValidateNearestNeighborTestCase { r.append(",queryTensorName=").append(qt); r.append(",hnsw.exploreAdditionalHits=0"); r.append(",approximate=true"); - r.append(",targetNumHits=").append(th); + r.append(",targetHits=").append(th); r.append("} ").append(errmsg); return r.toString(); } @@ -149,7 +149,7 @@ public class ValidateNearestNeighborTestCase { String q = "select * from sources * where nearestNeighbor(dvector,qvector);"; Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); - assertErrMsg(desc("dvector", "qvector", 0, "has invalid targetNumHits 0: Must be >= 1"), r); + assertErrMsg(desc("dvector", "qvector", 0, "has invalid targetHits 0: Must be >= 1"), r); } @Test diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java index e43dbd4e266..2ace21daace 100644 --- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java @@ -513,7 +513,7 @@ public class YqlParserTestCase { public void testWand() { assertParse("select foo from bar where wand(description, {\"a\":1, \"b\":2});", "WAND(10,0.0,1.0) description{[1]:\"a\",[2]:\"b\"}"); - assertParse("select foo from bar where [ {\"scoreThreshold\": 13.3, \"targetNumHits\": 7, " + + assertParse("select foo from bar where [ {\"scoreThreshold\": 13.3, \"targetHits\": 7, " + "\"thresholdBoostFactor\": 2.3} ]wand(description, {\"a\":1, \"b\":2});", "WAND(7,13.3,2.3) description{[1]:\"a\",[2]:\"b\"}"); } @@ -550,11 +550,11 @@ public class YqlParserTestCase { @Test public void testNearestNeighbor() { assertParse("select foo from bar where nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=0}"); - assertParse("select foo from bar where [{\"targetNumHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=37}"); - assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetNumHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetNumHits=3}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=0}"); + assertParse("select foo from bar where [{\"targetHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);", + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=37}"); + assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);", + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetHits=3}"); } @Test @@ -597,7 +597,7 @@ public class YqlParserTestCase { public void testWeakAnd() { assertParse("select foo from bar where weakAnd(a contains \"A\", b contains \"B\");", "WAND(100) a:A b:B"); - assertParse("select foo from bar where [{\"targetNumHits\": 37}]weakAnd(a contains \"A\", " + + assertParse("select foo from bar where [{\"targetHits\": 37}]weakAnd(a contains \"A\", " + "b contains \"B\");", "WAND(37) a:A b:B"); diff --git a/container-search/src/test/java/com/yahoo/select/SelectTestCase.java b/container-search/src/test/java/com/yahoo/select/SelectTestCase.java index 1715ed38964..4691ef42e55 100644 --- a/container-search/src/test/java/com/yahoo/select/SelectTestCase.java +++ b/container-search/src/test/java/com/yahoo/select/SelectTestCase.java @@ -473,7 +473,7 @@ public class SelectTestCase { public void testWand() { assertParse("{ \"wand\": [\"description\", { \"a\": 1, \"b\": 2 }] }", "WAND(10,0.0,1.0) description{[1]:\"a\",[2]:\"b\"}"); - assertParse("{ \"wand\": { \"children\": [\"description\", { \"a\": 1, \"b\": 2 }], \"attributes\": { \"scoreThreshold\": 13.3, \"targetNumHits\": 7, \"thresholdBoostFactor\": 2.3 } } }", + assertParse("{ \"wand\": { \"children\": [\"description\", { \"a\": 1, \"b\": 2 }], \"attributes\": { \"scoreThreshold\": 13.3, \"targetHits\": 7, \"thresholdBoostFactor\": 2.3 } } }", "WAND(7,13.3,2.3) description{[1]:\"a\",[2]:\"b\"}"); } @@ -522,10 +522,19 @@ public class SelectTestCase { } @Test + public void testNearestNeighbor() { + assertParse("{ \"nearestNeighbor\": [ \"f1field\", \"q2prop\" ] }", + "NEAREST_NEIGHBOR {field=f1field,queryTensorName=q2prop,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=0}"); + + assertParse("{ \"nearestNeighbor\": { \"children\" : [ \"f3field\", \"q4prop\" ], \"attributes\" : {\"targetHits\": 37} }}", + "NEAREST_NEIGHBOR {field=f3field,queryTensorName=q4prop,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=37}"); + } + + @Test public void testWeakAnd() { assertParse("{ \"weakAnd\": [{ \"contains\": [\"a\", \"A\"] }, { \"contains\": [\"b\", \"B\"] } ] }", "WAND(100) a:A b:B"); - assertParse("{ \"weakAnd\": { \"children\" : [{ \"contains\": [\"a\", \"A\"] }, { \"contains\": [\"b\", \"B\"] } ], \"attributes\" : {\"targetNumHits\": 37} }}", + assertParse("{ \"weakAnd\": { \"children\" : [{ \"contains\": [\"a\", \"A\"] }, { \"contains\": [\"b\", \"B\"] } ], \"attributes\" : {\"targetHits\": 37} }}", "WAND(37) a:A b:B"); QueryTree tree = parseWhere("{ \"weakAnd\": { \"children\" : [{ \"contains\": [\"a\", \"A\"] }, { \"contains\": [\"b\", \"B\"] } ], \"attributes\" : {\"scoreThreshold\": 41}}}"); diff --git a/hosted-api/pom.xml b/hosted-api/pom.xml index 2a42f890ba4..b066cb158e0 100644 --- a/hosted-api/pom.xml +++ b/hosted-api/pom.xml @@ -33,6 +33,12 @@ <version>${project.version}</version> <scope>provided</scope> </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>yolean</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> <dependency> <groupId>org.junit.jupiter</groupId> @@ -44,6 +50,12 @@ <artifactId>junit-vintage-engine</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>testutil</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> </dependencies> </project> diff --git a/hosted-api/src/main/java/ai/vespa/hosted/api/TestDescriptor.java b/hosted-api/src/main/java/ai/vespa/hosted/api/TestDescriptor.java index 37858148ef0..08cd3932ae7 100644 --- a/hosted-api/src/main/java/ai/vespa/hosted/api/TestDescriptor.java +++ b/hosted-api/src/main/java/ai/vespa/hosted/api/TestDescriptor.java @@ -3,18 +3,30 @@ package ai.vespa.hosted.api; import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; +import com.yahoo.slime.JsonFormat; +import com.yahoo.slime.Slime; import com.yahoo.slime.SlimeStream; import com.yahoo.slime.SlimeUtils; +import java.io.ByteArrayOutputStream; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import static com.yahoo.yolean.Exceptions.uncheck; + /** * @author mortent */ public class TestDescriptor { public static final String DEFAULT_FILENAME = "META-INF/ai.vespa/testDescriptor.json"; + public static final String CURRENT_VERSION = "1.0"; + + private static final String JSON_FIELD_VERSION = "version"; + private static final String JSON_FIELD_CONFIGURED_TESTS = "configuredTests"; + private static final String JSON_FIELD_SYSTEM_TESTS = "systemTests"; + private static final String JSON_FIELD_STAGING_TESTS = "stagingTests"; + private static final String JSON_FIELD_PRODUCTION_TESTS = "productionTests"; private final Map<TestCategory, List<String>> configuredTestClasses; private final String version; @@ -27,16 +39,26 @@ public class TestDescriptor { public static TestDescriptor fromJsonString(String testDescriptor) { var slime = SlimeUtils.jsonToSlime(testDescriptor); var root = slime.get(); - var version = root.field("version").asString(); - var testRoot = root.field("configuredTests"); - var systemTests = getJsonArray(testRoot, "systemTests"); - var stagingTests = getJsonArray(testRoot, "stagingTests"); - var productionTests = getJsonArray(testRoot, "productionTests"); - return new TestDescriptor(version, Map.of( + var version = root.field(JSON_FIELD_VERSION).asString(); + var testRoot = root.field(JSON_FIELD_CONFIGURED_TESTS); + var systemTests = getJsonArray(testRoot, JSON_FIELD_SYSTEM_TESTS); + var stagingTests = getJsonArray(testRoot, JSON_FIELD_STAGING_TESTS); + var productionTests = getJsonArray(testRoot, JSON_FIELD_PRODUCTION_TESTS); + return new TestDescriptor(version, toMap(systemTests, stagingTests, productionTests)); + } + + public static TestDescriptor from( + String version, List<String> systemTests, List<String> stagingTests, List<String> productionTests) { + return new TestDescriptor(version, toMap(systemTests, stagingTests, productionTests)); + } + + private static Map<TestCategory, List<String>> toMap( + List<String> systemTests, List<String> stagingTests, List<String> productionTests) { + return Map.of( TestCategory.systemtest, systemTests, TestCategory.stagingtest, stagingTests, TestCategory.productiontest, productionTests - )); + ); } private static List<String> getJsonArray(Cursor cursor, String field) { @@ -51,6 +73,26 @@ public class TestDescriptor { return List.copyOf(configuredTestClasses.get(category)); } + public String toJson() { + Slime slime = new Slime(); + Cursor root = slime.setObject(); + root.setString(JSON_FIELD_VERSION, this.version); + Cursor tests = root.setObject(JSON_FIELD_CONFIGURED_TESTS); + addJsonArrayForTests(tests, JSON_FIELD_SYSTEM_TESTS, TestCategory.systemtest); + addJsonArrayForTests(tests, JSON_FIELD_STAGING_TESTS, TestCategory.stagingtest); + addJsonArrayForTests(tests, JSON_FIELD_PRODUCTION_TESTS, TestCategory.productiontest); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + uncheck(() -> new JsonFormat(/*compact*/false).encode(out, slime)); + return out.toString(); + } + + private void addJsonArrayForTests(Cursor testsRoot, String fieldName, TestCategory category) { + List<String> tests = configuredTestClasses.get(category); + if (tests.isEmpty()) return; + Cursor cursor = testsRoot.setArray(fieldName); + tests.forEach(cursor::addString); + } + @Override public String toString() { return "TestClassDescriptor{" + diff --git a/hosted-api/src/test/java/ai/vespa/hosted/api/TestDescriptorTest.java b/hosted-api/src/test/java/ai/vespa/hosted/api/TestDescriptorTest.java index 2676d9d79da..7e59af9ced8 100644 --- a/hosted-api/src/test/java/ai/vespa/hosted/api/TestDescriptorTest.java +++ b/hosted-api/src/test/java/ai/vespa/hosted/api/TestDescriptorTest.java @@ -1,6 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.hosted.api; +import com.yahoo.test.json.JsonTestHelper; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; @@ -67,4 +68,20 @@ public class TestDescriptorTest { var productionTests = testClassDescriptor.getConfiguredTests(TestDescriptor.TestCategory.productiontest); Assertions.assertIterableEquals(List.of("ai.vespa.test.ProductionTest1", "ai.vespa.test.ProductionTest2"), productionTests); } + + @Test + public void generatesCorrectJson() { + String json = "{\n" + + " \"version\": \"1.0\",\n" + + " \"configuredTests\": {\n" + + " \"systemTests\": [\n" + + " \"ai.vespa.test.SystemTest1\",\n" + + " \"ai.vespa.test.SystemTest2\"\n" + + " ]\n" + + " " + + " }\n" + + "}\n"; + var descriptor = TestDescriptor.fromJsonString(json); + JsonTestHelper.assertJsonEquals(json, descriptor.toJson()); + } } diff --git a/tenant-base/pom.xml b/tenant-base/pom.xml index 767119d2a02..490dcf6371b 100644 --- a/tenant-base/pom.xml +++ b/tenant-base/pom.xml @@ -117,6 +117,11 @@ </exclusion> </exclusions> </dependency> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-engine</artifactId> + <scope>test</scope> + </dependency> </dependencies> <profiles> diff --git a/vespa-maven-plugin/pom.xml b/vespa-maven-plugin/pom.xml index 9f1d6f5ff6b..0910f38d5e5 100644 --- a/vespa-maven-plugin/pom.xml +++ b/vespa-maven-plugin/pom.xml @@ -44,10 +44,21 @@ <artifactId>config-application-package</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>tenant-cd-api</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>yolean</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> - <groupId>commons-cli</groupId> - <artifactId>commons-cli</artifactId> + <groupId>org.ow2.asm</groupId> + <artifactId>asm</artifactId> </dependency> <dependency> <groupId>org.apache.maven</groupId> diff --git a/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/GenerateTestDescriptorMojo.java b/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/GenerateTestDescriptorMojo.java new file mode 100644 index 00000000000..8309b7a8124 --- /dev/null +++ b/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/GenerateTestDescriptorMojo.java @@ -0,0 +1,61 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.hosted.plugin; + +import ai.vespa.hosted.api.TestDescriptor; +import org.apache.maven.plugin.AbstractMojo; +import org.apache.maven.plugin.MojoExecutionException; +import org.apache.maven.plugins.annotations.Mojo; +import org.apache.maven.plugins.annotations.Parameter; +import org.apache.maven.project.MavenProject; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.stream.Stream; + +/** + * Generates a test descriptor file based on content of the compiled test classes + * + * @author bjorncs + */ +@Mojo(name = "generateTestDescriptor", threadSafe = true) +public class GenerateTestDescriptorMojo extends AbstractMojo { + + @Parameter(defaultValue = "${project}", readonly = true) + protected MavenProject project; + + @Override + public void execute() throws MojoExecutionException { + TestAnnotationAnalyzer analyzer = new TestAnnotationAnalyzer(); + analyzeTestClasses(analyzer); + TestDescriptor descriptor = TestDescriptor.from( + TestDescriptor.CURRENT_VERSION, + analyzer.systemTests(), + analyzer.stagingTests(), + analyzer.productionTests()); + writeDescriptorFile(descriptor); + } + + private void analyzeTestClasses(TestAnnotationAnalyzer analyzer) throws MojoExecutionException { + try (Stream<Path> files = Files.walk(testClassesDirectory())) { + files + .filter(f -> f.toString().endsWith(".class")) + .forEach(analyzer::analyzeClass); + } catch (Exception e) { + throw new MojoExecutionException("Failed to analyze test classes: " + e.getMessage(), e); + } + } + + private void writeDescriptorFile(TestDescriptor descriptor) throws MojoExecutionException { + try { + Path descriptorFile = testClassesDirectory().resolve(TestDescriptor.DEFAULT_FILENAME); + Files.createDirectories(descriptorFile.getParent()); + Files.write(descriptorFile, descriptor.toJson().getBytes()); + } catch (IOException e) { + throw new MojoExecutionException("Failed to write test descriptor file: " + e.getMessage(), e); + } + } + + private Path testClassesDirectory() { return Paths.get(project.getBuild().getTestOutputDirectory()); } +} diff --git a/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/TestAnnotationAnalyzer.java b/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/TestAnnotationAnalyzer.java new file mode 100644 index 00000000000..c45ef21bc31 --- /dev/null +++ b/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/TestAnnotationAnalyzer.java @@ -0,0 +1,74 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.hosted.plugin; + + +import ai.vespa.hosted.cd.ProductionTest; +import ai.vespa.hosted.cd.StagingTest; +import ai.vespa.hosted.cd.SystemTest; +import org.objectweb.asm.AnnotationVisitor; +import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassVisitor; +import org.objectweb.asm.Opcodes; +import org.objectweb.asm.Type; + +import java.io.IOException; +import java.io.InputStream; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; + +/** + * Analyzes test classes and tracks all classes containing hosted Vespa test annotations ({@link ai.vespa.hosted.cd}). + * + * @author bjorncs + */ +class TestAnnotationAnalyzer { + + private final List<String> systemTests = new ArrayList<>(); + private final List<String> stagingTests = new ArrayList<>(); + private final List<String> productionTests = new ArrayList<>(); + + List<String> systemTests() { return systemTests; } + List<String> stagingTests() { return stagingTests; } + List<String> productionTests() { return productionTests; } + + void analyzeClass(Path classFile) { + try (InputStream in = Files.newInputStream(classFile)) { + new ClassReader(in).accept(new AsmClassVisitor(), ClassReader.SKIP_DEBUG); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private class AsmClassVisitor extends ClassVisitor { + + private String className; + + AsmClassVisitor() { super(Opcodes.ASM7); } + + @Override + public void visit( + int version, int access, String name, String signature, String superName, String[] interfaces) { + Type type = Type.getObjectType(name); + if (type.getSort() == Type.OBJECT) { + this.className = type.getClassName(); + super.visit(version, access, name, signature, superName, interfaces); + } + } + + @Override + public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) { + String annotationClassName = Type.getType(descriptor).getClassName(); + if (ProductionTest.class.getName().equals(annotationClassName)) { + productionTests.add(className); + } else if (StagingTest.class.getName().equals(annotationClassName)) { + stagingTests.add(className); + } else if (SystemTest.class.getName().equals(annotationClassName)) { + systemTests.add(className); + } + return null; + } + } +} |