summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2018-02-06 20:09:31 +0100
committerGitHub <noreply@github.com>2018-02-06 20:09:31 +0100
commitc2797cb1f2745a1fee89610e6eb7a4c1d3215c18 (patch)
tree1dcdd2217ff440211e712aa331e102b1b626a7f7 /config-model/src/test/java/com
parentedefa8dd65dfb26f68d4d6b0c2019a3cbc9d0888 (diff)
parentd0eeb393e656e6ce5b1f461e802573e9ec145e72 (diff)
Merge pull request #4935 from vespa-engine/bratseth/typecheck-all
Bratseth/typecheck all
Diffstat (limited to 'config-model/src/test/java/com')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java40
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java104
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java18
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java133
7 files changed, 239 insertions, 71 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
index 442c8bd41bd..11093d9f008 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java
@@ -135,13 +135,13 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
@Test
public void requireThatConfigIsDerivedForQueryFeatureTypeSettings() throws ParseException {
RankProfileRegistry registry = new RankProfileRegistry();
- SearchBuilder builder = new SearchBuilder(registry);
+ SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes());
builder.importString("search test {\n" +
" document test { } \n" +
" rank-profile p1 {}\n" +
" rank-profile p2 {}\n" +
"}");
- builder.build(new BaseDeployLogger(), setupQueryProfileTypes());
+ builder.build(new BaseDeployLogger());
Search search = builder.getSearch();
assertEquals(4, registry.allRankProfiles().size());
@@ -151,7 +151,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
assertQueryFeatureTypeSettings(registry.getRankProfile(search, "p2"), search);
}
- private static QueryProfiles setupQueryProfileTypes() {
+ private static QueryProfileRegistry setupQueryProfileTypes() {
QueryProfileRegistry registry = new QueryProfileRegistry();
QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry();
QueryProfileType type = new QueryProfileType(new ComponentId("testtype"));
@@ -164,7 +164,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase {
type.addField(new FieldDescription("ranking.features.query(numeric)",
FieldType.fromString("integer", typeRegistry)), typeRegistry);
typeRegistry.register(type);
- return new QueryProfiles(registry);
+ return registry;
}
private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
index e94880e61c7..82b9f5ac043 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
@@ -207,6 +207,9 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase
builder.importString(
"search test {\n" +
" document test { \n" +
+ " field rating_yelp type int {" +
+ " indexing: attribute" +
+ " }" +
" }\n" +
" \n" +
" rank-profile test {\n" +
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index 5100ac15c40..ed1b00e2875 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -2,7 +2,10 @@
package com.yahoo.searchdefinition;
import com.yahoo.collections.Pair;
+import com.yahoo.search.query.profile.QueryProfile;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.search.query.profile.types.FieldDescription;
+import com.yahoo.search.query.profile.types.QueryProfileType;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
@@ -149,11 +152,12 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
censorBindingHash(testRankProperties.get(4).toString()));
}
-
@Test
public void testNeuralNetworkSetup() throws ParseException {
+ // Note: the type assigned to query profile and constant tensors here is not the correct type
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
- SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
+ QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(x[])");
+ SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles);
builder.importString(
"search test {\n" +
" document test { \n" +
@@ -176,13 +180,28 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
" expression: sum(final_layer)\n" +
" }\n" +
" }\n" +
- "\n" +
+ " constant W_hidden {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant b_input {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant W_final {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant b_final {\n" +
+ " type: tensor(x[])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
"}\n");
builder.build();
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles);
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
- new QueryProfileRegistry(),
+ queryProfiles,
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))",
testRankProperties.get(0).toString());
@@ -198,6 +217,17 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
testRankProperties.get(5).toString());
}
+ private QueryProfileRegistry queryProfileWith(String field, String type) {
+ QueryProfileType queryProfileType = new QueryProfileType("root");
+ queryProfileType.addField(new FieldDescription(field, type));
+ QueryProfileRegistry queryProfileRegistry = new QueryProfileRegistry();
+ queryProfileRegistry.getTypeRegistry().register(queryProfileType);
+ QueryProfile profile = new QueryProfile("default");
+ profile.setType(queryProfileType);
+ queryProfileRegistry.register(profile);
+ return queryProfileRegistry;
+ }
+
private String censorBindingHash(String s) {
StringBuilder b = new StringBuilder();
boolean areInHash = false;
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 800697b3430..0ce6129ef7f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -38,7 +38,8 @@ class RankProfileSearchFixture {
RankProfileSearchFixture(ApplicationPackage applicationpackage, QueryProfileRegistry queryProfileRegistry,
String rankProfiles, String constant, String field)
throws ParseException {
- SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, new QueryProfileRegistry());
+ this.queryProfileRegistry = queryProfileRegistry;
+ SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, queryProfileRegistry);
String sdContent = "search test {\n" +
" " + (constant != null ? constant : "") + "\n" +
" document test {\n" +
@@ -50,7 +51,6 @@ class RankProfileSearchFixture {
builder.importString(sdContent);
builder.build();
search = builder.getSearch();
- this.queryProfileRegistry = queryProfileRegistry;
}
public void assertFirstPhaseExpression(String expExpression, String rankProfile) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java
new file mode 100644
index 00000000000..db3b12db1bf
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java
@@ -0,0 +1,104 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.SearchBuilder;
+import com.yahoo.yolean.Exceptions;
+import org.junit.Test;
+import static com.yahoo.config.model.test.TestUtil.joinLines;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+public class RankingExpressionTypeValidatorTestCase {
+
+ @Test
+ public void tensorFirstPhaseMustProduceDouble() throws Exception {
+ try {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry);
+ searchBuilder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: attribute(a)",
+ " }",
+ " }",
+ "}"
+ ));
+ searchBuilder.build();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void tensorSecondPhaseMustProduceDouble() throws Exception {
+ try {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry);
+ searchBuilder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(attribute(a))",
+ " }",
+ " second-phase {",
+ " expression: attribute(a)",
+ " }",
+ " }",
+ "}"
+ ));
+ searchBuilder.build();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In search definition 'test', rank profile 'my_rank_profile': The second-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void tensorConditionsMustHaveTypeCompatibleBranches() throws Exception {
+ try {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry);
+ searchBuilder.importString(joinLines(
+ "search test {",
+ " document test { ",
+ " field a type tensor(x[],y[]) {",
+ " indexing: attribute",
+ " }",
+ " field b type tensor(z[10]) {",
+ " indexing: attribute",
+ " }",
+ " }",
+ " rank-profile my_rank_profile {",
+ " first-phase {",
+ " expression: sum(if(1>0, attribute(a), attribute(b)))",
+ " }",
+ " }",
+ "}"
+ ));
+ searchBuilder.build();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression is invalid: An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[],y[]) while the 'false' type is tensor(z[10])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 7246b22b0f8..58af8daf1b5 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -51,7 +51,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReference() throws ParseException {
+ public void testTensorFlowReference() {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -60,7 +60,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReferenceWithConstantFeature() throws ParseException {
+ public void testTensorFlowReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"tensorflow('mnist_softmax/saved')",
"constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
@@ -71,10 +71,10 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReferenceWithQueryFeature() throws ParseException {
+ public void testTensorFlowReferenceWithQueryFeature() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
- " <field name='mytensor' type='tensor(d0[3],d1[784])'/>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
@@ -90,7 +90,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReferenceWithDocumentFeature() throws ParseException {
+ public void testTensorFlowReferenceWithDocumentFeature() {
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
"tensorflow('mnist_softmax/saved')",
@@ -103,10 +103,10 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReferenceWithFeatureCombination() throws ParseException {
+ public void testTensorFlowReferenceWithFeatureCombination() {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType = "<query-profile-type id='root'>" +
- " <field name='mytensor' type='tensor(d0[3],d1[784],d2[10])'/>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
@@ -122,7 +122,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testNestedTensorFlowReference() throws ParseException {
+ public void testNestedTensorFlowReference() {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"5 + sum(tensorflow('mnist_softmax/saved'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
@@ -131,7 +131,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReferenceSpecifyingSignature() throws ParseException {
+ public void testTensorFlowReferenceSpecifyingSignature() {
RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved', 'serving_default')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index c18cfcfe1aa..d2211b86c9e 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -17,98 +17,129 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
-import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
import java.util.List;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
public class TensorTransformTestCase extends SearchDefinitionTestCase {
@Test
public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException {
- assertContainsExpression("max(1.0,2.0)", "max(1.0,2.0)");
- assertContainsExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)");
- assertContainsExpression("max(attribute(double_field),attribute(double_array_field))", "max(attribute(double_field),attribute(double_array_field))");
- assertContainsExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))");
- assertContainsExpression("max(attribute(tensor_field_1),attribute(tensor_field_2))", "max(attribute(tensor_field_1),attribute(tensor_field_2))");
- assertContainsExpression("min(test_constant_tensor,1.0)", "min(constant(test_constant_tensor),1.0)");
- assertContainsExpression("max(base_constant_tensor,1.0)", "max(constant(base_constant_tensor),1.0)");
- assertContainsExpression("min(constant(file_constant_tensor),1.0)", "min(constant(file_constant_tensor),1.0)");
- assertContainsExpression("max(query(q),1.0)", "max(query(q),1.0)");
- assertContainsExpression("max(query(n),1.0)", "max(query(n),1.0)");
+ assertTransformedExpression("max(1.0,2.0)",
+ "max(1.0,2.0)");
+ assertTransformedExpression("min(attribute(double_field),x)",
+ "min(attribute(double_field),x)");
+ assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))",
+ "max(attribute(double_field),attribute(double_array_field))");
+ assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))",
+ "min(attribute(tensor_field_1),attribute(double_field))");
+ assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)",
+ "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)");
+ assertTransformedExpression("min(constant(test_constant_tensor),1.0)",
+ "min(test_constant_tensor,1.0)");
+ assertTransformedExpression("max(constant(base_constant_tensor),1.0)",
+ "max(base_constant_tensor,1.0)");
+ assertTransformedExpression("min(constant(file_constant_tensor),1.0)",
+ "min(constant(file_constant_tensor),1.0)");
+ assertTransformedExpression("max(query(q),1.0)",
+ "max(query(q),1.0)");
+ assertTransformedExpression("max(query(n),1.0)",
+ "max(query(n),1.0)");
}
@Test
public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException {
- assertContainsExpression("max(attribute(tensor_field_1),x)", "reduce(attribute(tensor_field_1),max,x)");
- assertContainsExpression("1 + max(attribute(tensor_field_1),x)", "1+reduce(attribute(tensor_field_1),max,x)");
- assertContainsExpression("if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)", "if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)");
- assertContainsExpression("max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)");
- assertContainsExpression("max(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),max,x)");
- assertContainsExpression("max(max(attribute(tensor_field_1),x),x)", "max(reduce(attribute(tensor_field_1),max,x),x)"); // will result in deploy error.
- assertContainsExpression("max(max(attribute(tensor_field_2),x),y)", "reduce(reduce(attribute(tensor_field_2),max,x),max,y)");
+ assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)",
+ "max(attribute(tensor_field_1),x)");
+ assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)",
+ "1 + max(attribute(tensor_field_1),x)");
+ assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)",
+ "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)");
+ assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)",
+ "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)");
+ assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)",
+ "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)");
+ assertTransformedExpression("max(reduce(attribute(tensor_field_1),max,x),x)",
+ "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error.
+ assertTransformedExpression("reduce(reduce(attribute(tensor_field_2),max,x),max,y)",
+ "max(max(attribute(tensor_field_2),x),y)");
}
@Test
public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException {
- assertContainsExpression("max(test_constant_tensor,x)", "reduce(constant(test_constant_tensor),max,x)");
- assertContainsExpression("max(base_constant_tensor,x)", "reduce(constant(base_constant_tensor),max,x)");
- assertContainsExpression("min(constant(file_constant_tensor),x)", "reduce(constant(file_constant_tensor),min,x)");
+ assertTransformedExpression("reduce(constant(test_constant_tensor),max,x)",
+ "max(test_constant_tensor,x)");
+ assertTransformedExpression("reduce(constant(base_constant_tensor),max,x)",
+ "max(base_constant_tensor,x)");
+ assertTransformedExpression("reduce(constant(file_constant_tensor),min,x)",
+ "min(constant(file_constant_tensor),x)");
}
@Test
public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException {
- assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)");
- assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)");
- assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)");
- assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
- assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)");
+ assertTransformedExpression("reduce(attribute(double_field)+attribute(tensor_field_1),min,x)",
+ "min(attribute(double_field) + attribute(tensor_field_1),x)");
+ assertTransformedExpression("reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)",
+ "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)");
+ assertTransformedExpression("reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)"
+ , "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)");
+ assertTransformedExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)",
+ "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
+ assertTransformedExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)",
+ "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)");
}
@Test
public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException {
- assertContainsExpression("max(tensorFromLabels(attribute(double_array_field)),double_array_field)", "reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)");
- assertContainsExpression("max(tensorFromLabels(attribute(double_array_field),x),x)", "reduce(tensorFromLabels(attribute(double_array_field),x),max,x)");
- assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)", "reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)");
- assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field),x),x)", "reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)");
+ assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)",
+ "max(tensorFromLabels(attribute(double_array_field)),double_array_field)");
+ assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field),x),max,x)",
+ "max(tensorFromLabels(attribute(double_array_field),x),x)");
+ assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)",
+ "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)");
+ assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)",
+ "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)");
}
@Test
public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException {
- assertContainsExpression("max(query(q),x)", "reduce(query(q),max,x)");
- assertContainsExpression("max(query(n),x)", "max(query(n),x)");
+ assertTransformedExpression("reduce(query(q),max,x)", "max(query(q),x)");
+ assertTransformedExpression("max(query(n),x)", "max(query(n),x)");
}
@Test
public void requireThatMaxAndMinWithTensoresReturnedFromMacrosAreReplaced() throws ParseException {
- assertContainsExpression("max(returns_tensor,x)", "reduce(rankingExpression(returns_tensor),max,x)");
- assertContainsExpression("max(wraps_returns_tensor,x)", "reduce(rankingExpression(wraps_returns_tensor),max,x)");
- assertContainsExpression("max(tensor_inheriting,x)", "reduce(rankingExpression(tensor_inheriting),max,x)");
- assertContainsExpression("max(returns_tensor_with_arg(attribute(tensor_field_1)),x)", "reduce(rankingExpression(returns_tensor_with_arg@),max,x)");
+ assertTransformedExpression("reduce(rankingExpression(returns_tensor),max,x)",
+ "max(returns_tensor,x)");
+ assertTransformedExpression("reduce(rankingExpression(wraps_returns_tensor),max,x)",
+ "max(wraps_returns_tensor,x)");
+ assertTransformedExpression("reduce(rankingExpression(tensor_inheriting),max,x)",
+ "max(tensor_inheriting,x)");
+ assertTransformedExpression("reduce(rankingExpression(returns_tensor_with_arg@),max,x)",
+ "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)");
}
- private void assertContainsExpression(String expr, String transformedExpression) throws ParseException {
- assertTrue("Expected expression '" + transformedExpression + "' found",
- containsExpression(expr, transformedExpression));
- }
-
- private boolean containsExpression(String expr, String transformedExpression) throws ParseException {
- for (Pair<String, String> rankPropertyExpression : buildSearch(expr)) {
+ private void assertTransformedExpression(String expected, String original) throws ParseException {
+ for (Pair<String, String> rankPropertyExpression : buildSearch(original)) {
String rankProperty = rankPropertyExpression.getFirst();
if (rankProperty.equals("rankingExpression(firstphase).rankingScript")) {
String rankExpression = censorBindingHash(rankPropertyExpression.getSecond().replace(" ",""));
- return rankExpression.equals(transformedExpression);
+ assertEquals(expected, rankExpression);
+ return;
}
}
- return false;
+ fail("No 'rankingExpression(firstphase).rankingScript' property produced");
}
private List<Pair<String, String>> buildSearch(String expression) throws ParseException {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
- SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
+ QueryProfileRegistry queryProfiles = setupQueryProfileTypes();
+ SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles);
builder.importString(
"search test {\n" +
" document test { \n" +
@@ -167,16 +198,16 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
" }\n" +
" }\n" +
"}\n");
- builder.build(new BaseDeployLogger(), setupQueryProfileTypes());
+ builder.build(new BaseDeployLogger());
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles);
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
- new QueryProfileRegistry(),
+ queryProfiles,
new AttributeFields(s)).configProperties();
return testRankProperties;
}
- private static QueryProfiles setupQueryProfileTypes() {
+ private static QueryProfileRegistry setupQueryProfileTypes() {
QueryProfileRegistry registry = new QueryProfileRegistry();
QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry();
QueryProfileType type = new QueryProfileType(new ComponentId("testtype"));
@@ -185,7 +216,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
type.addField(new FieldDescription("ranking.features.query(n)",
FieldType.fromString("integer", typeRegistry)), typeRegistry);
typeRegistry.register(type);
- return new QueryProfiles(registry);
+ return registry;
}
private String censorBindingHash(String s) {