diff options
Diffstat (limited to 'config-model/src/test')
11 files changed, 246 insertions, 78 deletions
diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index 2b231e0cda2..b6ad5372c05 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -35,7 +35,7 @@ rankprofile[3].name "profile2" rankprofile[3].fef.property[0].name "vespa.rank.firstphase" rankprofile[3].fef.property[0].value "rankingExpression(firstphase)" rankprofile[3].fef.property[1].name "rankingExpression(firstphase).rankingScript" -rankprofile[3].fef.property[1].value "reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x)" +rankprofile[3].fef.property[1].value "reduce(reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)" rankprofile[3].fef.property[2].name "vespa.type.attribute.f2" rankprofile[3].fef.property[2].value "tensor(x[2],y[])" rankprofile[3].fef.property[3].name "vespa.type.attribute.f3" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index a6a9a98db3a..3d64f6b807e 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -28,7 +28,7 @@ search tensor { rank-profile profile2 { first-phase { - expression: matmul(attribute(f4), diag(x[2],y[2],z[3]), x) + expression: sum(matmul(attribute(f4), diag(x[2],y[2],z[3]), x)) } } diff --git a/config-model/src/test/examples/rankpropvars.sd b/config-model/src/test/examples/rankpropvars.sd index 40f9e73f35a..28959edbc09 100644 --- a/config-model/src/test/examples/rankpropvars.sd +++ b/config-model/src/test/examples/rankpropvars.sd @@ -18,8 +18,8 @@ first-phase { second-phase { expression { if (attribute(artist) == query(testvar1), - 0.0 * fieldMatch(title) + 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist), - 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) + 0.0 * fieldMatch(title) + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist), + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) } } @@ -42,8 +42,8 @@ first-phase { second-phase { expression { if (attribute(artist) == query(testvar1), - 0.0 * fieldMatch(title) + 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist), - 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) + 0.0 * fieldMatch(title) + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist), + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) } } } diff --git a/config-model/src/test/examples/simple.sd b/config-model/src/test/examples/simple.sd index 4fda7f5039e..96b0fa98098 100644 --- a/config-model/src/test/examples/simple.sd +++ b/config-model/src/test/examples/simple.sd @@ -116,7 +116,7 @@ search simple { first-phase { keep-rank-count:200 rank-score-drop-limit: -13.0 - expression: attribute(year) + expression: attribute(popularity) } second-phase { rerank-count: 99 diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 442c8bd41bd..11093d9f008 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -135,13 +135,13 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { @Test public void requireThatConfigIsDerivedForQueryFeatureTypeSettings() throws ParseException { RankProfileRegistry registry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(registry); + SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes()); builder.importString("search test {\n" + " document test { } \n" + " rank-profile p1 {}\n" + " rank-profile p2 {}\n" + "}"); - builder.build(new BaseDeployLogger(), setupQueryProfileTypes()); + builder.build(new BaseDeployLogger()); Search search = builder.getSearch(); assertEquals(4, registry.allRankProfiles().size()); @@ -151,7 +151,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { assertQueryFeatureTypeSettings(registry.getRankProfile(search, "p2"), search); } - private static QueryProfiles setupQueryProfileTypes() { + private static QueryProfileRegistry setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); @@ -164,7 +164,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { type.addField(new FieldDescription("ranking.features.query(numeric)", FieldType.fromString("integer", typeRegistry)), typeRegistry); typeRegistry.register(type); - return new QueryProfiles(registry); + return registry; } private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index e94880e61c7..82b9f5ac043 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -207,6 +207,9 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase builder.importString( "search test {\n" + " document test { \n" + + " field rating_yelp type int {" + + " indexing: attribute" + + " }" + " }\n" + " \n" + " rank-profile test {\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 5100ac15c40..ed1b00e2875 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -2,7 +2,10 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; +import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.types.FieldDescription; +import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; @@ -149,11 +152,12 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase censorBindingHash(testRankProperties.get(4).toString())); } - @Test public void testNeuralNetworkSetup() throws ParseException { + // Note: the type assigned to query profile and constant tensors here is not the correct type RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(x[])"); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles); builder.importString( "search test {\n" + " document test { \n" + @@ -176,13 +180,28 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase " expression: sum(final_layer)\n" + " }\n" + " }\n" + - "\n" + + " constant W_hidden {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant b_input {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant W_final {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant b_final {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, - new QueryProfileRegistry(), + queryProfiles, new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(0).toString()); @@ -198,6 +217,17 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase testRankProperties.get(5).toString()); } + private QueryProfileRegistry queryProfileWith(String field, String type) { + QueryProfileType queryProfileType = new QueryProfileType("root"); + queryProfileType.addField(new FieldDescription(field, type)); + QueryProfileRegistry queryProfileRegistry = new QueryProfileRegistry(); + queryProfileRegistry.getTypeRegistry().register(queryProfileType); + QueryProfile profile = new QueryProfile("default"); + profile.setType(queryProfileType); + queryProfileRegistry.register(profile); + return queryProfileRegistry; + } + private String censorBindingHash(String s) { StringBuilder b = new StringBuilder(); boolean areInHash = false; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 800697b3430..0ce6129ef7f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -38,7 +38,8 @@ class RankProfileSearchFixture { RankProfileSearchFixture(ApplicationPackage applicationpackage, QueryProfileRegistry queryProfileRegistry, String rankProfiles, String constant, String field) throws ParseException { - SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, new QueryProfileRegistry()); + this.queryProfileRegistry = queryProfileRegistry; + SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, queryProfileRegistry); String sdContent = "search test {\n" + " " + (constant != null ? constant : "") + "\n" + " document test {\n" + @@ -50,7 +51,6 @@ class RankProfileSearchFixture { builder.importString(sdContent); builder.build(); search = builder.getSearch(); - this.queryProfileRegistry = queryProfileRegistry; } public void assertFirstPhaseExpression(String expExpression, String rankProfile) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java new file mode 100644 index 00000000000..db3b12db1bf --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java @@ -0,0 +1,104 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.yolean.Exceptions; +import org.junit.Test; +import static com.yahoo.config.model.test.TestUtil.joinLines; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class RankingExpressionTypeValidatorTestCase { + + @Test + public void tensorFirstPhaseMustProduceDouble() throws Exception { + try { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); + searchBuilder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: attribute(a)", + " }", + " }", + "}" + )); + searchBuilder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void tensorSecondPhaseMustProduceDouble() throws Exception { + try { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); + searchBuilder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: sum(attribute(a))", + " }", + " second-phase {", + " expression: attribute(a)", + " }", + " }", + "}" + )); + searchBuilder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The second-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void tensorConditionsMustHaveTypeCompatibleBranches() throws Exception { + try { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); + searchBuilder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " field b type tensor(z[10]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: sum(if(1>0, attribute(a), attribute(b)))", + " }", + " }", + "}" + )); + searchBuilder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression is invalid: An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[],y[]) while the 'false' type is tensor(z[10])", + Exceptions.toMessageString(expected)); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 7246b22b0f8..58af8daf1b5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -51,7 +51,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReference() throws ParseException { + public void testTensorFlowReference() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -60,7 +60,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceWithConstantFeature() throws ParseException { + public void testTensorFlowReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "tensorflow('mnist_softmax/saved')", "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", @@ -71,10 +71,10 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceWithQueryFeature() throws ParseException { + public void testTensorFlowReferenceWithQueryFeature() { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='mytensor' type='tensor(d0[3],d1[784])'/>" + + " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, @@ -90,7 +90,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceWithDocumentFeature() throws ParseException { + public void testTensorFlowReferenceWithDocumentFeature() { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", "tensorflow('mnist_softmax/saved')", @@ -103,10 +103,10 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceWithFeatureCombination() throws ParseException { + public void testTensorFlowReferenceWithFeatureCombination() { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='mytensor' type='tensor(d0[3],d1[784],d2[10])'/>" + + " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, @@ -122,7 +122,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testNestedTensorFlowReference() throws ParseException { + public void testNestedTensorFlowReference() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(tensorflow('mnist_softmax/saved'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); @@ -131,7 +131,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { + public void testTensorFlowReferenceSpecifyingSignature() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved', 'serving_default')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index c18cfcfe1aa..d2211b86c9e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,98 +17,129 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; import java.util.List; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class TensorTransformTestCase extends SearchDefinitionTestCase { @Test public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException { - assertContainsExpression("max(1.0,2.0)", "max(1.0,2.0)"); - assertContainsExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)"); - assertContainsExpression("max(attribute(double_field),attribute(double_array_field))", "max(attribute(double_field),attribute(double_array_field))"); - assertContainsExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))"); - assertContainsExpression("max(attribute(tensor_field_1),attribute(tensor_field_2))", "max(attribute(tensor_field_1),attribute(tensor_field_2))"); - assertContainsExpression("min(test_constant_tensor,1.0)", "min(constant(test_constant_tensor),1.0)"); - assertContainsExpression("max(base_constant_tensor,1.0)", "max(constant(base_constant_tensor),1.0)"); - assertContainsExpression("min(constant(file_constant_tensor),1.0)", "min(constant(file_constant_tensor),1.0)"); - assertContainsExpression("max(query(q),1.0)", "max(query(q),1.0)"); - assertContainsExpression("max(query(n),1.0)", "max(query(n),1.0)"); + assertTransformedExpression("max(1.0,2.0)", + "max(1.0,2.0)"); + assertTransformedExpression("min(attribute(double_field),x)", + "min(attribute(double_field),x)"); + assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))", + "max(attribute(double_field),attribute(double_array_field))"); + assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))", + "min(attribute(tensor_field_1),attribute(double_field))"); + assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)", + "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)"); + assertTransformedExpression("min(constant(test_constant_tensor),1.0)", + "min(test_constant_tensor,1.0)"); + assertTransformedExpression("max(constant(base_constant_tensor),1.0)", + "max(base_constant_tensor,1.0)"); + assertTransformedExpression("min(constant(file_constant_tensor),1.0)", + "min(constant(file_constant_tensor),1.0)"); + assertTransformedExpression("max(query(q),1.0)", + "max(query(q),1.0)"); + assertTransformedExpression("max(query(n),1.0)", + "max(query(n),1.0)"); } @Test public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException { - assertContainsExpression("max(attribute(tensor_field_1),x)", "reduce(attribute(tensor_field_1),max,x)"); - assertContainsExpression("1 + max(attribute(tensor_field_1),x)", "1+reduce(attribute(tensor_field_1),max,x)"); - assertContainsExpression("if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)", "if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)"); - assertContainsExpression("max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); - assertContainsExpression("max(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); - assertContainsExpression("max(max(attribute(tensor_field_1),x),x)", "max(reduce(attribute(tensor_field_1),max,x),x)"); // will result in deploy error. - assertContainsExpression("max(max(attribute(tensor_field_2),x),y)", "reduce(reduce(attribute(tensor_field_2),max,x),max,y)"); + assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)", + "max(attribute(tensor_field_1),x)"); + assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)", + "1 + max(attribute(tensor_field_1),x)"); + assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)", + "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)"); + assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)", + "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)"); + assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)", + "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)"); + assertTransformedExpression("max(reduce(attribute(tensor_field_1),max,x),x)", + "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error. + assertTransformedExpression("reduce(reduce(attribute(tensor_field_2),max,x),max,y)", + "max(max(attribute(tensor_field_2),x),y)"); } @Test public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException { - assertContainsExpression("max(test_constant_tensor,x)", "reduce(constant(test_constant_tensor),max,x)"); - assertContainsExpression("max(base_constant_tensor,x)", "reduce(constant(base_constant_tensor),max,x)"); - assertContainsExpression("min(constant(file_constant_tensor),x)", "reduce(constant(file_constant_tensor),min,x)"); + assertTransformedExpression("reduce(constant(test_constant_tensor),max,x)", + "max(test_constant_tensor,x)"); + assertTransformedExpression("reduce(constant(base_constant_tensor),max,x)", + "max(base_constant_tensor,x)"); + assertTransformedExpression("reduce(constant(file_constant_tensor),min,x)", + "min(constant(file_constant_tensor),x)"); } @Test public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException { - assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)"); - assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)"); - assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)"); - assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) - assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); + assertTransformedExpression("reduce(attribute(double_field)+attribute(tensor_field_1),min,x)", + "min(attribute(double_field) + attribute(tensor_field_1),x)"); + assertTransformedExpression("reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)", + "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)"); + assertTransformedExpression("reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)" + , "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)"); + assertTransformedExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", + "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) + assertTransformedExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", + "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); } @Test public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException { - assertContainsExpression("max(tensorFromLabels(attribute(double_array_field)),double_array_field)", "reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)"); - assertContainsExpression("max(tensorFromLabels(attribute(double_array_field),x),x)", "reduce(tensorFromLabels(attribute(double_array_field),x),max,x)"); - assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)", "reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)"); - assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field),x),x)", "reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)"); + assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)", + "max(tensorFromLabels(attribute(double_array_field)),double_array_field)"); + assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field),x),max,x)", + "max(tensorFromLabels(attribute(double_array_field),x),x)"); + assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)", + "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)"); + assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)", + "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)"); } @Test public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException { - assertContainsExpression("max(query(q),x)", "reduce(query(q),max,x)"); - assertContainsExpression("max(query(n),x)", "max(query(n),x)"); + assertTransformedExpression("reduce(query(q),max,x)", "max(query(q),x)"); + assertTransformedExpression("max(query(n),x)", "max(query(n),x)"); } @Test public void requireThatMaxAndMinWithTensoresReturnedFromMacrosAreReplaced() throws ParseException { - assertContainsExpression("max(returns_tensor,x)", "reduce(rankingExpression(returns_tensor),max,x)"); - assertContainsExpression("max(wraps_returns_tensor,x)", "reduce(rankingExpression(wraps_returns_tensor),max,x)"); - assertContainsExpression("max(tensor_inheriting,x)", "reduce(rankingExpression(tensor_inheriting),max,x)"); - assertContainsExpression("max(returns_tensor_with_arg(attribute(tensor_field_1)),x)", "reduce(rankingExpression(returns_tensor_with_arg@),max,x)"); + assertTransformedExpression("reduce(rankingExpression(returns_tensor),max,x)", + "max(returns_tensor,x)"); + assertTransformedExpression("reduce(rankingExpression(wraps_returns_tensor),max,x)", + "max(wraps_returns_tensor,x)"); + assertTransformedExpression("reduce(rankingExpression(tensor_inheriting),max,x)", + "max(tensor_inheriting,x)"); + assertTransformedExpression("reduce(rankingExpression(returns_tensor_with_arg@),max,x)", + "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)"); } - private void assertContainsExpression(String expr, String transformedExpression) throws ParseException { - assertTrue("Expected expression '" + transformedExpression + "' found", - containsExpression(expr, transformedExpression)); - } - - private boolean containsExpression(String expr, String transformedExpression) throws ParseException { - for (Pair<String, String> rankPropertyExpression : buildSearch(expr)) { + private void assertTransformedExpression(String expected, String original) throws ParseException { + for (Pair<String, String> rankPropertyExpression : buildSearch(original)) { String rankProperty = rankPropertyExpression.getFirst(); if (rankProperty.equals("rankingExpression(firstphase).rankingScript")) { String rankExpression = censorBindingHash(rankPropertyExpression.getSecond().replace(" ","")); - return rankExpression.equals(transformedExpression); + assertEquals(expected, rankExpression); + return; } } - return false; + fail("No 'rankingExpression(firstphase).rankingScript' property produced"); } private List<Pair<String, String>> buildSearch(String expression) throws ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + QueryProfileRegistry queryProfiles = setupQueryProfileTypes(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles); builder.importString( "search test {\n" + " document test { \n" + @@ -167,16 +198,16 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { " }\n" + " }\n" + "}\n"); - builder.build(new BaseDeployLogger(), setupQueryProfileTypes()); + builder.build(new BaseDeployLogger()); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, - new QueryProfileRegistry(), + queryProfiles, new AttributeFields(s)).configProperties(); return testRankProperties; } - private static QueryProfiles setupQueryProfileTypes() { + private static QueryProfileRegistry setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); @@ -185,7 +216,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { type.addField(new FieldDescription("ranking.features.query(n)", FieldType.fromString("integer", typeRegistry)), typeRegistry); typeRegistry.register(type); - return new QueryProfiles(registry); + return registry; } private String censorBindingHash(String s) { |