summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--container-search/abi-spec.json3
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java13
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/SelectParser.java51
-rw-r--r--container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/yql/YqlParser.java23
-rw-r--r--container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java6
-rw-r--r--container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java14
-rw-r--r--container-search/src/test/java/com/yahoo/select/SelectTestCase.java13
-rw-r--r--hosted-api/pom.xml12
-rw-r--r--hosted-api/src/main/java/ai/vespa/hosted/api/TestDescriptor.java56
-rw-r--r--hosted-api/src/test/java/ai/vespa/hosted/api/TestDescriptorTest.java17
-rw-r--r--tenant-base/pom.xml5
-rw-r--r--vespa-maven-plugin/pom.xml15
-rw-r--r--vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/GenerateTestDescriptorMojo.java61
-rw-r--r--vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/TestAnnotationAnalyzer.java74
15 files changed, 333 insertions, 32 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 2b4424654a2..ba52826cd3f 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -869,7 +869,8 @@
"public java.lang.String getName()",
"public int getTermCount()",
"public int encode(java.nio.ByteBuffer)",
- "protected void appendBodyString(java.lang.StringBuilder)"
+ "protected void appendBodyString(java.lang.StringBuilder)",
+ "public void disclose(com.yahoo.prelude.query.textualrepresentation.Discloser)"
],
"fields": []
},
diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
index 52ef6c40a6a..be3ae913476 100644
--- a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
+++ b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
@@ -4,6 +4,7 @@ package com.yahoo.prelude.query;
import com.google.common.annotations.Beta;
import com.yahoo.compress.IntegerCompressor;
+import com.yahoo.prelude.query.textualrepresentation.Discloser;
import java.nio.ByteBuffer;
@@ -83,7 +84,17 @@ public class NearestNeighborItem extends SimpleTaggableItem {
buffer.append(",queryTensorName=").append(queryTensorName);
buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits);
buffer.append(",approximate=").append(approximate);
- buffer.append(",targetNumHits=").append(targetNumHits).append("}");
+ buffer.append(",targetHits=").append(targetNumHits).append("}");
+ }
+
+ @Override
+ public void disclose(Discloser discloser) {
+ super.disclose(discloser);
+ discloser.addProperty("field", field);
+ discloser.addProperty("queryTensorName", queryTensorName);
+ discloser.addProperty("hnsw.exploreAdditionalHits", hnswExploreAdditionalHits);
+ discloser.addProperty("approximate", approximate);
+ discloser.addProperty("targetHits", targetNumHits);
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/SelectParser.java b/container-search/src/main/java/com/yahoo/search/query/SelectParser.java
index 775dca7c444..9910eb9532d 100644
--- a/container-search/src/main/java/com/yahoo/search/query/SelectParser.java
+++ b/container-search/src/main/java/com/yahoo/search/query/SelectParser.java
@@ -16,6 +16,7 @@ import com.yahoo.prelude.query.IntItem;
import com.yahoo.prelude.query.Item;
import com.yahoo.prelude.query.Limit;
import com.yahoo.prelude.query.NearItem;
+import com.yahoo.prelude.query.NearestNeighborItem;
import com.yahoo.prelude.query.NotItem;
import com.yahoo.prelude.query.ONearItem;
import com.yahoo.prelude.query.OrItem;
@@ -93,14 +94,17 @@ public class SelectParser implements Parser {
private static final String ACCENT_DROP = "accentDrop";
private static final String ALTERNATIVES = "alternatives";
private static final String AND_SEGMENTING = "andSegmenting";
+ private static final String APPROXIMATE = "approximate";
private static final String DISTANCE = "distance";
private static final String DOT_PRODUCT = "dotProduct";
private static final String EQUIV = "equiv";
private static final String FILTER = "filter";
private static final String HIT_LIMIT = "hitLimit";
+ private static final String HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.exploreAdditionalHits";
private static final String IMPLICIT_TRANSFORMS = "implicitTransforms";
private static final String LABEL = "label";
private static final String NEAR = "near";
+ private static final String NEAREST_NEIGHBOR = "nearestNeighbor";
private static final String NORMALIZE_CASE = "normalizeCase";
private static final String ONEAR = "onear";
private static final String PHRASE = "phrase";
@@ -114,6 +118,7 @@ public class SelectParser implements Parser {
private static final String STEM = "stem";
private static final String SUBSTRING = "substring";
private static final String SUFFIX = "suffix";
+ private static final String TARGET_HITS = "targetHits";
private static final String TARGET_NUM_HITS = "targetNumHits";
private static final String THRESHOLD_BOOST_FACTOR = "thresholdBoostFactor";
private static final String UNIQUE_ID = "id";
@@ -130,7 +135,7 @@ public class SelectParser implements Parser {
private static final String CONTAINS = "contains";
private static final String MATCHES = "matches";
private static final String CALL = "call";
- private static final List<String> FUNCTION_CALLS = Arrays.asList(WAND, WEIGHTED_SET, DOT_PRODUCT, PREDICATE, RANK, WEAK_AND);
+ private static final List<String> FUNCTION_CALLS = Arrays.asList(WAND, WEIGHTED_SET, DOT_PRODUCT, NEAREST_NEIGHBOR, PREDICATE, RANK, WEAK_AND);
public SelectParser(ParserEnvironment environment) {
indexFacts = environment.getIndexFacts();
@@ -259,6 +264,8 @@ public class SelectParser implements Parser {
return buildWeightedSet(key, value);
case DOT_PRODUCT:
return buildDotProduct(key, value);
+ case NEAREST_NEIGHBOR:
+ return buildNearestNeighbor(key, value);
case PREDICATE:
return buildPredicate(key, value);
case RANK:
@@ -266,7 +273,7 @@ public class SelectParser implements Parser {
case WEAK_AND:
return buildWeakAnd(key, value);
default:
- throw newUnexpectedArgumentException(key, DOT_PRODUCT, RANK, WAND, WEAK_AND, WEIGHTED_SET, PREDICATE);
+ throw newUnexpectedArgumentException(key, DOT_PRODUCT, NEAREST_NEIGHBOR, RANK, WAND, WEAK_AND, WEIGHTED_SET, PREDICATE);
}
}
@@ -403,6 +410,38 @@ public class SelectParser implements Parser {
return orItem;
}
+ private Item buildNearestNeighbor(String key, Inspector value) {
+
+ HashMap<Integer, Inspector> children = childMap(value);
+ Preconditions.checkArgument(children.size() == 2, "Expected 2 arguments, got %s.", children.size());
+ String field = children.get(0).asString();
+ String property = children.get(1).asString();
+ NearestNeighborItem item = new NearestNeighborItem(field, property);
+ Inspector annotations = getAnnotations(value);
+ if (annotations != null){
+ annotations.traverse((ObjectTraverser) (annotation_name, annotation_value) -> {
+ if (TARGET_HITS.equals(annotation_name)){
+ item.setTargetNumHits((int)(annotation_value.asDouble()));
+ }
+ if (TARGET_NUM_HITS.equals(annotation_name)){
+ item.setTargetNumHits((int)(annotation_value.asDouble()));
+ }
+ if (HNSW_EXPLORE_ADDITIONAL_HITS.equals(annotation_name)) {
+ int hnswExploreAdditionalHits = (int)(annotation_value.asDouble());
+ item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits);
+ }
+ if (APPROXIMATE.equals(annotation_name)) {
+ boolean allowApproximate = annotation_value.asBool();
+ item.setAllowApproximate(allowApproximate);
+ }
+ if (LABEL.equals(annotation_name)) {
+ item.setLabel(annotation_value.asString());
+ }
+ });
+ }
+ return item;
+ }
+
private CompositeItem buildWeakAnd(String key, Inspector value) {
WeakAndItem weakAnd = new WeakAndItem();
addItemsFromInspector(weakAnd, value);
@@ -410,6 +449,9 @@ public class SelectParser implements Parser {
if (annotations != null){
annotations.traverse((ObjectTraverser) (annotation_name, annotation_value) -> {
+ if (TARGET_HITS.equals(annotation_name)){
+ weakAnd.setN((int)(annotation_value.asDouble()));
+ }
if (TARGET_NUM_HITS.equals(annotation_name)){
weakAnd.setN((int)(annotation_value.asDouble()));
}
@@ -662,7 +704,10 @@ public class SelectParser implements Parser {
HashMap<Integer, Inspector> children = childMap(value);
Preconditions.checkArgument(children.size() == 2, "Expected 2 arguments, got %s.", children.size());
- Integer target_num_hits= getIntegerAnnotation(TARGET_NUM_HITS, annotations, DEFAULT_TARGET_NUM_HITS);
+ Integer target_num_hits= getIntegerAnnotation(TARGET_HITS, annotations, null);
+ if (target_num_hits == null) {
+ target_num_hits= getIntegerAnnotation(TARGET_NUM_HITS, annotations, DEFAULT_TARGET_NUM_HITS);
+ }
WandItem out = new WandItem(children.get(0).asString(), target_num_hits);
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
index 76b8c1ef8a2..aca2998cba3 100644
--- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
+++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
@@ -97,7 +97,7 @@ public class ValidateNearestNeighborSearcher extends Searcher {
/** Returns an error message if this is invalid, or null if it is valid */
private String validate(NearestNeighborItem item) {
if (item.getTargetNumHits() < 1)
- return item + " has invalid targetNumHits " + item.getTargetNumHits() + ": Must be >= 1";
+ return item + " has invalid targetHits " + item.getTargetNumHits() + ": Must be >= 1";
String queryFeatureName = "query(" + item.getQueryTensorName() + ")";
Optional<Tensor> queryTensor = query.getRanking().getFeatures().getTensor(queryFeatureName);
diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
index f4560806dd2..7d17fe4f09d 100644
--- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
+++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java
@@ -173,6 +173,7 @@ public class YqlParser implements Parser {
static final String STEM = "stem";
static final String SUBSTRING = "substring";
static final String SUFFIX = "suffix";
+ static final String TARGET_HITS = "targetHits";
static final String TARGET_NUM_HITS = "targetNumHits";
static final String THRESHOLD_BOOST_FACTOR = "thresholdBoostFactor";
static final String UNIQUE_ID = "id";
@@ -418,8 +419,12 @@ public class YqlParser implements Parser {
String field = fetchFieldRead(args.get(0));
String property = fetchFieldRead(args.get(1));
NearestNeighborItem item = new NearestNeighborItem(field, property);
- Integer targetNumHits = getAnnotation(ast, TARGET_NUM_HITS,
+ Integer targetNumHits = getAnnotation(ast, TARGET_HITS,
Integer.class, null, "desired minimum hits to produce");
+ if (targetNumHits == null) {
+ targetNumHits = getAnnotation(ast, TARGET_NUM_HITS,
+ Integer.class, null, "desired minimum hits to produce");
+ }
if (targetNumHits != null) {
item.setTargetNumHits(targetNumHits);
}
@@ -504,9 +509,13 @@ public class YqlParser implements Parser {
List<OperatorNode<ExpressionOperator>> args = ast.getArgument(1);
Preconditions.checkArgument(args.size() == 2, "Expected 2 arguments, got %s.", args.size());
- WandItem out = new WandItem(getIndex(args.get(0)), getAnnotation(ast,
- TARGET_NUM_HITS, Integer.class, DEFAULT_TARGET_NUM_HITS,
- "desired number of hits to accumulate in wand"));
+ Integer targetNumHits = getAnnotation(ast, TARGET_HITS,
+ Integer.class, null, "desired number of hits to accumulate in wand");
+ if (targetNumHits == null) {
+ targetNumHits = getAnnotation(ast, TARGET_NUM_HITS,
+ Integer.class, DEFAULT_TARGET_NUM_HITS, "desired number of hits to accumulate in wand");
+ }
+ WandItem out = new WandItem(getIndex(args.get(0)), targetNumHits);
Double scoreThreshold = getAnnotation(ast, SCORE_THRESHOLD, Double.class, null,
"min score for hit inclusion");
if (scoreThreshold != null) {
@@ -1028,8 +1037,12 @@ public class YqlParser implements Parser {
private CompositeItem buildWeakAnd(OperatorNode<ExpressionOperator> spec) {
WeakAndItem weakAnd = new WeakAndItem();
- Integer targetNumHits = getAnnotation(spec, TARGET_NUM_HITS,
+ Integer targetNumHits = getAnnotation(spec, TARGET_HITS,
Integer.class, null, "desired minimum hits to produce");
+ if (targetNumHits == null) {
+ targetNumHits = getAnnotation(spec, TARGET_NUM_HITS,
+ Integer.class, null, "desired minimum hits to produce");
+ }
if (targetNumHits != null) {
weakAnd.setN(targetNumHits);
}
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
index 2c849a9b52c..c49603737a6 100644
--- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -93,7 +93,7 @@ public class ValidateNearestNeighborTestCase {
}
private String makeQuery(String attributeTensor, String queryTensor) {
- return "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ");";
+ return "select * from sources * where [{\"targetHits\":1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ");";
}
@Test
@@ -139,7 +139,7 @@ public class ValidateNearestNeighborTestCase {
r.append(",queryTensorName=").append(qt);
r.append(",hnsw.exploreAdditionalHits=0");
r.append(",approximate=true");
- r.append(",targetNumHits=").append(th);
+ r.append(",targetHits=").append(th);
r.append("} ").append(errmsg);
return r.toString();
}
@@ -149,7 +149,7 @@ public class ValidateNearestNeighborTestCase {
String q = "select * from sources * where nearestNeighbor(dvector,qvector);";
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
- assertErrMsg(desc("dvector", "qvector", 0, "has invalid targetNumHits 0: Must be >= 1"), r);
+ assertErrMsg(desc("dvector", "qvector", 0, "has invalid targetHits 0: Must be >= 1"), r);
}
@Test
diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
index e43dbd4e266..2ace21daace 100644
--- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java
@@ -513,7 +513,7 @@ public class YqlParserTestCase {
public void testWand() {
assertParse("select foo from bar where wand(description, {\"a\":1, \"b\":2});",
"WAND(10,0.0,1.0) description{[1]:\"a\",[2]:\"b\"}");
- assertParse("select foo from bar where [ {\"scoreThreshold\": 13.3, \"targetNumHits\": 7, " +
+ assertParse("select foo from bar where [ {\"scoreThreshold\": 13.3, \"targetHits\": 7, " +
"\"thresholdBoostFactor\": 2.3} ]wand(description, {\"a\":1, \"b\":2});",
"WAND(7,13.3,2.3) description{[1]:\"a\",[2]:\"b\"}");
}
@@ -550,11 +550,11 @@ public class YqlParserTestCase {
@Test
public void testNearestNeighbor() {
assertParse("select foo from bar where nearestNeighbor(semantic_embedding, my_vector);",
- "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=0}");
- assertParse("select foo from bar where [{\"targetNumHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);",
- "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetNumHits=37}");
- assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetNumHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);",
- "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetNumHits=3}");
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=0}");
+ assertParse("select foo from bar where [{\"targetHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);",
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=37}");
+ assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);",
+ "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetHits=3}");
}
@Test
@@ -597,7 +597,7 @@ public class YqlParserTestCase {
public void testWeakAnd() {
assertParse("select foo from bar where weakAnd(a contains \"A\", b contains \"B\");",
"WAND(100) a:A b:B");
- assertParse("select foo from bar where [{\"targetNumHits\": 37}]weakAnd(a contains \"A\", " +
+ assertParse("select foo from bar where [{\"targetHits\": 37}]weakAnd(a contains \"A\", " +
"b contains \"B\");",
"WAND(37) a:A b:B");
diff --git a/container-search/src/test/java/com/yahoo/select/SelectTestCase.java b/container-search/src/test/java/com/yahoo/select/SelectTestCase.java
index 1715ed38964..4691ef42e55 100644
--- a/container-search/src/test/java/com/yahoo/select/SelectTestCase.java
+++ b/container-search/src/test/java/com/yahoo/select/SelectTestCase.java
@@ -473,7 +473,7 @@ public class SelectTestCase {
public void testWand() {
assertParse("{ \"wand\": [\"description\", { \"a\": 1, \"b\": 2 }] }",
"WAND(10,0.0,1.0) description{[1]:\"a\",[2]:\"b\"}");
- assertParse("{ \"wand\": { \"children\": [\"description\", { \"a\": 1, \"b\": 2 }], \"attributes\": { \"scoreThreshold\": 13.3, \"targetNumHits\": 7, \"thresholdBoostFactor\": 2.3 } } }",
+ assertParse("{ \"wand\": { \"children\": [\"description\", { \"a\": 1, \"b\": 2 }], \"attributes\": { \"scoreThreshold\": 13.3, \"targetHits\": 7, \"thresholdBoostFactor\": 2.3 } } }",
"WAND(7,13.3,2.3) description{[1]:\"a\",[2]:\"b\"}");
}
@@ -522,10 +522,19 @@ public class SelectTestCase {
}
@Test
+ public void testNearestNeighbor() {
+ assertParse("{ \"nearestNeighbor\": [ \"f1field\", \"q2prop\" ] }",
+ "NEAREST_NEIGHBOR {field=f1field,queryTensorName=q2prop,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=0}");
+
+ assertParse("{ \"nearestNeighbor\": { \"children\" : [ \"f3field\", \"q4prop\" ], \"attributes\" : {\"targetHits\": 37} }}",
+ "NEAREST_NEIGHBOR {field=f3field,queryTensorName=q4prop,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=37}");
+ }
+
+ @Test
public void testWeakAnd() {
assertParse("{ \"weakAnd\": [{ \"contains\": [\"a\", \"A\"] }, { \"contains\": [\"b\", \"B\"] } ] }",
"WAND(100) a:A b:B");
- assertParse("{ \"weakAnd\": { \"children\" : [{ \"contains\": [\"a\", \"A\"] }, { \"contains\": [\"b\", \"B\"] } ], \"attributes\" : {\"targetNumHits\": 37} }}",
+ assertParse("{ \"weakAnd\": { \"children\" : [{ \"contains\": [\"a\", \"A\"] }, { \"contains\": [\"b\", \"B\"] } ], \"attributes\" : {\"targetHits\": 37} }}",
"WAND(37) a:A b:B");
QueryTree tree = parseWhere("{ \"weakAnd\": { \"children\" : [{ \"contains\": [\"a\", \"A\"] }, { \"contains\": [\"b\", \"B\"] } ], \"attributes\" : {\"scoreThreshold\": 41}}}");
diff --git a/hosted-api/pom.xml b/hosted-api/pom.xml
index 2a42f890ba4..b066cb158e0 100644
--- a/hosted-api/pom.xml
+++ b/hosted-api/pom.xml
@@ -33,6 +33,12 @@
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>yolean</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
@@ -44,6 +50,12 @@
<artifactId>junit-vintage-engine</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>testutil</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
</project>
diff --git a/hosted-api/src/main/java/ai/vespa/hosted/api/TestDescriptor.java b/hosted-api/src/main/java/ai/vespa/hosted/api/TestDescriptor.java
index 37858148ef0..08cd3932ae7 100644
--- a/hosted-api/src/main/java/ai/vespa/hosted/api/TestDescriptor.java
+++ b/hosted-api/src/main/java/ai/vespa/hosted/api/TestDescriptor.java
@@ -3,18 +3,30 @@ package ai.vespa.hosted.api;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Inspector;
+import com.yahoo.slime.JsonFormat;
+import com.yahoo.slime.Slime;
import com.yahoo.slime.SlimeStream;
import com.yahoo.slime.SlimeUtils;
+import java.io.ByteArrayOutputStream;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
+import static com.yahoo.yolean.Exceptions.uncheck;
+
/**
* @author mortent
*/
public class TestDescriptor {
public static final String DEFAULT_FILENAME = "META-INF/ai.vespa/testDescriptor.json";
+ public static final String CURRENT_VERSION = "1.0";
+
+ private static final String JSON_FIELD_VERSION = "version";
+ private static final String JSON_FIELD_CONFIGURED_TESTS = "configuredTests";
+ private static final String JSON_FIELD_SYSTEM_TESTS = "systemTests";
+ private static final String JSON_FIELD_STAGING_TESTS = "stagingTests";
+ private static final String JSON_FIELD_PRODUCTION_TESTS = "productionTests";
private final Map<TestCategory, List<String>> configuredTestClasses;
private final String version;
@@ -27,16 +39,26 @@ public class TestDescriptor {
public static TestDescriptor fromJsonString(String testDescriptor) {
var slime = SlimeUtils.jsonToSlime(testDescriptor);
var root = slime.get();
- var version = root.field("version").asString();
- var testRoot = root.field("configuredTests");
- var systemTests = getJsonArray(testRoot, "systemTests");
- var stagingTests = getJsonArray(testRoot, "stagingTests");
- var productionTests = getJsonArray(testRoot, "productionTests");
- return new TestDescriptor(version, Map.of(
+ var version = root.field(JSON_FIELD_VERSION).asString();
+ var testRoot = root.field(JSON_FIELD_CONFIGURED_TESTS);
+ var systemTests = getJsonArray(testRoot, JSON_FIELD_SYSTEM_TESTS);
+ var stagingTests = getJsonArray(testRoot, JSON_FIELD_STAGING_TESTS);
+ var productionTests = getJsonArray(testRoot, JSON_FIELD_PRODUCTION_TESTS);
+ return new TestDescriptor(version, toMap(systemTests, stagingTests, productionTests));
+ }
+
+ public static TestDescriptor from(
+ String version, List<String> systemTests, List<String> stagingTests, List<String> productionTests) {
+ return new TestDescriptor(version, toMap(systemTests, stagingTests, productionTests));
+ }
+
+ private static Map<TestCategory, List<String>> toMap(
+ List<String> systemTests, List<String> stagingTests, List<String> productionTests) {
+ return Map.of(
TestCategory.systemtest, systemTests,
TestCategory.stagingtest, stagingTests,
TestCategory.productiontest, productionTests
- ));
+ );
}
private static List<String> getJsonArray(Cursor cursor, String field) {
@@ -51,6 +73,26 @@ public class TestDescriptor {
return List.copyOf(configuredTestClasses.get(category));
}
+ public String toJson() {
+ Slime slime = new Slime();
+ Cursor root = slime.setObject();
+ root.setString(JSON_FIELD_VERSION, this.version);
+ Cursor tests = root.setObject(JSON_FIELD_CONFIGURED_TESTS);
+ addJsonArrayForTests(tests, JSON_FIELD_SYSTEM_TESTS, TestCategory.systemtest);
+ addJsonArrayForTests(tests, JSON_FIELD_STAGING_TESTS, TestCategory.stagingtest);
+ addJsonArrayForTests(tests, JSON_FIELD_PRODUCTION_TESTS, TestCategory.productiontest);
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ uncheck(() -> new JsonFormat(/*compact*/false).encode(out, slime));
+ return out.toString();
+ }
+
+ private void addJsonArrayForTests(Cursor testsRoot, String fieldName, TestCategory category) {
+ List<String> tests = configuredTestClasses.get(category);
+ if (tests.isEmpty()) return;
+ Cursor cursor = testsRoot.setArray(fieldName);
+ tests.forEach(cursor::addString);
+ }
+
@Override
public String toString() {
return "TestClassDescriptor{" +
diff --git a/hosted-api/src/test/java/ai/vespa/hosted/api/TestDescriptorTest.java b/hosted-api/src/test/java/ai/vespa/hosted/api/TestDescriptorTest.java
index 2676d9d79da..7e59af9ced8 100644
--- a/hosted-api/src/test/java/ai/vespa/hosted/api/TestDescriptorTest.java
+++ b/hosted-api/src/test/java/ai/vespa/hosted/api/TestDescriptorTest.java
@@ -1,6 +1,7 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.hosted.api;
+import com.yahoo.test.json.JsonTestHelper;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
@@ -67,4 +68,20 @@ public class TestDescriptorTest {
var productionTests = testClassDescriptor.getConfiguredTests(TestDescriptor.TestCategory.productiontest);
Assertions.assertIterableEquals(List.of("ai.vespa.test.ProductionTest1", "ai.vespa.test.ProductionTest2"), productionTests);
}
+
+ @Test
+ public void generatesCorrectJson() {
+ String json = "{\n" +
+ " \"version\": \"1.0\",\n" +
+ " \"configuredTests\": {\n" +
+ " \"systemTests\": [\n" +
+ " \"ai.vespa.test.SystemTest1\",\n" +
+ " \"ai.vespa.test.SystemTest2\"\n" +
+ " ]\n" +
+ " " +
+ " }\n" +
+ "}\n";
+ var descriptor = TestDescriptor.fromJsonString(json);
+ JsonTestHelper.assertJsonEquals(json, descriptor.toJson());
+ }
}
diff --git a/tenant-base/pom.xml b/tenant-base/pom.xml
index 767119d2a02..490dcf6371b 100644
--- a/tenant-base/pom.xml
+++ b/tenant-base/pom.xml
@@ -117,6 +117,11 @@
</exclusion>
</exclusions>
</dependency>
+ <dependency>
+ <groupId>org.junit.jupiter</groupId>
+ <artifactId>junit-jupiter-engine</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<profiles>
diff --git a/vespa-maven-plugin/pom.xml b/vespa-maven-plugin/pom.xml
index 9f1d6f5ff6b..0910f38d5e5 100644
--- a/vespa-maven-plugin/pom.xml
+++ b/vespa-maven-plugin/pom.xml
@@ -44,10 +44,21 @@
<artifactId>config-application-package</artifactId>
<version>${project.version}</version>
</dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>tenant-cd-api</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>yolean</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+
<dependency>
- <groupId>commons-cli</groupId>
- <artifactId>commons-cli</artifactId>
+ <groupId>org.ow2.asm</groupId>
+ <artifactId>asm</artifactId>
</dependency>
<dependency>
<groupId>org.apache.maven</groupId>
diff --git a/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/GenerateTestDescriptorMojo.java b/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/GenerateTestDescriptorMojo.java
new file mode 100644
index 00000000000..8309b7a8124
--- /dev/null
+++ b/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/GenerateTestDescriptorMojo.java
@@ -0,0 +1,61 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.hosted.plugin;
+
+import ai.vespa.hosted.api.TestDescriptor;
+import org.apache.maven.plugin.AbstractMojo;
+import org.apache.maven.plugin.MojoExecutionException;
+import org.apache.maven.plugins.annotations.Mojo;
+import org.apache.maven.plugins.annotations.Parameter;
+import org.apache.maven.project.MavenProject;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.stream.Stream;
+
+/**
+ * Generates a test descriptor file based on content of the compiled test classes
+ *
+ * @author bjorncs
+ */
+@Mojo(name = "generateTestDescriptor", threadSafe = true)
+public class GenerateTestDescriptorMojo extends AbstractMojo {
+
+ @Parameter(defaultValue = "${project}", readonly = true)
+ protected MavenProject project;
+
+ @Override
+ public void execute() throws MojoExecutionException {
+ TestAnnotationAnalyzer analyzer = new TestAnnotationAnalyzer();
+ analyzeTestClasses(analyzer);
+ TestDescriptor descriptor = TestDescriptor.from(
+ TestDescriptor.CURRENT_VERSION,
+ analyzer.systemTests(),
+ analyzer.stagingTests(),
+ analyzer.productionTests());
+ writeDescriptorFile(descriptor);
+ }
+
+ private void analyzeTestClasses(TestAnnotationAnalyzer analyzer) throws MojoExecutionException {
+ try (Stream<Path> files = Files.walk(testClassesDirectory())) {
+ files
+ .filter(f -> f.toString().endsWith(".class"))
+ .forEach(analyzer::analyzeClass);
+ } catch (Exception e) {
+ throw new MojoExecutionException("Failed to analyze test classes: " + e.getMessage(), e);
+ }
+ }
+
+ private void writeDescriptorFile(TestDescriptor descriptor) throws MojoExecutionException {
+ try {
+ Path descriptorFile = testClassesDirectory().resolve(TestDescriptor.DEFAULT_FILENAME);
+ Files.createDirectories(descriptorFile.getParent());
+ Files.write(descriptorFile, descriptor.toJson().getBytes());
+ } catch (IOException e) {
+ throw new MojoExecutionException("Failed to write test descriptor file: " + e.getMessage(), e);
+ }
+ }
+
+ private Path testClassesDirectory() { return Paths.get(project.getBuild().getTestOutputDirectory()); }
+}
diff --git a/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/TestAnnotationAnalyzer.java b/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/TestAnnotationAnalyzer.java
new file mode 100644
index 00000000000..c45ef21bc31
--- /dev/null
+++ b/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/TestAnnotationAnalyzer.java
@@ -0,0 +1,74 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.hosted.plugin;
+
+
+import ai.vespa.hosted.cd.ProductionTest;
+import ai.vespa.hosted.cd.StagingTest;
+import ai.vespa.hosted.cd.SystemTest;
+import org.objectweb.asm.AnnotationVisitor;
+import org.objectweb.asm.ClassReader;
+import org.objectweb.asm.ClassVisitor;
+import org.objectweb.asm.Opcodes;
+import org.objectweb.asm.Type;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.UncheckedIOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Analyzes test classes and tracks all classes containing hosted Vespa test annotations ({@link ai.vespa.hosted.cd}).
+ *
+ * @author bjorncs
+ */
+class TestAnnotationAnalyzer {
+
+ private final List<String> systemTests = new ArrayList<>();
+ private final List<String> stagingTests = new ArrayList<>();
+ private final List<String> productionTests = new ArrayList<>();
+
+ List<String> systemTests() { return systemTests; }
+ List<String> stagingTests() { return stagingTests; }
+ List<String> productionTests() { return productionTests; }
+
+ void analyzeClass(Path classFile) {
+ try (InputStream in = Files.newInputStream(classFile)) {
+ new ClassReader(in).accept(new AsmClassVisitor(), ClassReader.SKIP_DEBUG);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ private class AsmClassVisitor extends ClassVisitor {
+
+ private String className;
+
+ AsmClassVisitor() { super(Opcodes.ASM7); }
+
+ @Override
+ public void visit(
+ int version, int access, String name, String signature, String superName, String[] interfaces) {
+ Type type = Type.getObjectType(name);
+ if (type.getSort() == Type.OBJECT) {
+ this.className = type.getClassName();
+ super.visit(version, access, name, signature, superName, interfaces);
+ }
+ }
+
+ @Override
+ public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
+ String annotationClassName = Type.getType(descriptor).getClassName();
+ if (ProductionTest.class.getName().equals(annotationClassName)) {
+ productionTests.add(className);
+ } else if (StagingTest.class.getName().equals(annotationClassName)) {
+ stagingTests.add(className);
+ } else if (SystemTest.class.getName().equals(annotationClassName)) {
+ systemTests.add(className);
+ }
+ return null;
+ }
+ }
+}