From a23fc5e8d4e9ef0f737041f6d4f2ebc50b38c40b Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 6 Feb 2018 15:05:05 +0100 Subject: Type check all expressions --- .../searchdefinition/RankProfileTestCase.java | 8 ++--- .../RankingExpressionConstantsTestCase.java | 3 ++ .../RankingExpressionShadowingTestCase.java | 40 ++++++++++++++++++--- .../processing/RankProfileSearchFixture.java | 4 +-- .../RankingExpressionTypeValidatorTestCase.java | 42 ++++++++++++++++++++++ .../RankingExpressionWithTensorFlowTestCase.java | 18 +++++----- .../processing/TensorTransformTestCase.java | 13 +++---- 7 files changed, 102 insertions(+), 26 deletions(-) create mode 100644 config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java (limited to 'config-model/src/test/java/com') diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 442c8bd41bd..11093d9f008 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -135,13 +135,13 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { @Test public void requireThatConfigIsDerivedForQueryFeatureTypeSettings() throws ParseException { RankProfileRegistry registry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(registry); + SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes()); builder.importString("search test {\n" + " document test { } \n" + " rank-profile p1 {}\n" + " rank-profile p2 {}\n" + "}"); - builder.build(new BaseDeployLogger(), setupQueryProfileTypes()); + builder.build(new BaseDeployLogger()); Search search = builder.getSearch(); assertEquals(4, registry.allRankProfiles().size()); @@ -151,7 +151,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { assertQueryFeatureTypeSettings(registry.getRankProfile(search, "p2"), search); } - private static QueryProfiles setupQueryProfileTypes() { + private static QueryProfileRegistry setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); @@ -164,7 +164,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { type.addField(new FieldDescription("ranking.features.query(numeric)", FieldType.fromString("integer", typeRegistry)), typeRegistry); typeRegistry.register(type); - return new QueryProfiles(registry); + return registry; } private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index e94880e61c7..82b9f5ac043 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -207,6 +207,9 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase builder.importString( "search test {\n" + " document test { \n" + + " field rating_yelp type int {" + + " indexing: attribute" + + " }" + " }\n" + " \n" + " rank-profile test {\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 5100ac15c40..ed1b00e2875 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -2,7 +2,10 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; +import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.types.FieldDescription; +import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; @@ -149,11 +152,12 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase censorBindingHash(testRankProperties.get(4).toString())); } - @Test public void testNeuralNetworkSetup() throws ParseException { + // Note: the type assigned to query profile and constant tensors here is not the correct type RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(x[])"); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles); builder.importString( "search test {\n" + " document test { \n" + @@ -176,13 +180,28 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase " expression: sum(final_layer)\n" + " }\n" + " }\n" + - "\n" + + " constant W_hidden {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant b_input {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant W_final {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant b_final {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); List> testRankProperties = new RawRankProfile(test, - new QueryProfileRegistry(), + queryProfiles, new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(0).toString()); @@ -198,6 +217,17 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase testRankProperties.get(5).toString()); } + private QueryProfileRegistry queryProfileWith(String field, String type) { + QueryProfileType queryProfileType = new QueryProfileType("root"); + queryProfileType.addField(new FieldDescription(field, type)); + QueryProfileRegistry queryProfileRegistry = new QueryProfileRegistry(); + queryProfileRegistry.getTypeRegistry().register(queryProfileType); + QueryProfile profile = new QueryProfile("default"); + profile.setType(queryProfileType); + queryProfileRegistry.register(profile); + return queryProfileRegistry; + } + private String censorBindingHash(String s) { StringBuilder b = new StringBuilder(); boolean areInHash = false; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 800697b3430..0ce6129ef7f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -38,7 +38,8 @@ class RankProfileSearchFixture { RankProfileSearchFixture(ApplicationPackage applicationpackage, QueryProfileRegistry queryProfileRegistry, String rankProfiles, String constant, String field) throws ParseException { - SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, new QueryProfileRegistry()); + this.queryProfileRegistry = queryProfileRegistry; + SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, queryProfileRegistry); String sdContent = "search test {\n" + " " + (constant != null ? constant : "") + "\n" + " document test {\n" + @@ -50,7 +51,6 @@ class RankProfileSearchFixture { builder.importString(sdContent); builder.build(); search = builder.getSearch(); - this.queryProfileRegistry = queryProfileRegistry; } public void assertFirstPhaseExpression(String expExpression, String rankProfile) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java new file mode 100644 index 00000000000..5c654f09c51 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java @@ -0,0 +1,42 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.yolean.Exceptions; +import org.junit.Test; +import static com.yahoo.config.model.test.TestUtil.joinLines; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class RankingExpressionTypeValidatorTestCase { + + @Test + public void tensorTypeValidation() throws Exception { + try { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); + searchBuilder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: attribute(a)", + " }", + " }", + "}" + )); + searchBuilder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])", + Exceptions.toMessageString(expected)); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 5203e686681..464772fc10d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -51,7 +51,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReference() throws ParseException { + public void testTensorFlowReference() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -60,7 +60,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceWithConstantFeature() throws ParseException { + public void testTensorFlowReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "tensorflow('mnist_softmax/saved')", "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", @@ -71,10 +71,10 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceWithQueryFeature() throws ParseException { + public void testTensorFlowReferenceWithQueryFeature() { String queryProfile = ""; String queryProfileType = "" + - " " + + " " + ""; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, @@ -90,7 +90,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceWithDocumentFeature() throws ParseException { + public void testTensorFlowReferenceWithDocumentFeature() { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", "tensorflow('mnist_softmax/saved')", @@ -103,10 +103,10 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceWithFeatureCombination() throws ParseException { + public void testTensorFlowReferenceWithFeatureCombination() { String queryProfile = ""; String queryProfileType = "" + - " " + + " " + ""; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, @@ -122,7 +122,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testNestedTensorFlowReference() throws ParseException { + public void testNestedTensorFlowReference() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(tensorflow('mnist_softmax/saved'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); @@ -131,7 +131,7 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { + public void testTensorFlowReferenceSpecifyingSignature() { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved', 'serving_default')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index c18cfcfe1aa..c1d987ef3ad 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -108,7 +108,8 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { private List> buildSearch(String expression) throws ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + QueryProfileRegistry queryProfiles = setupQueryProfileTypes(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles); builder.importString( "search test {\n" + " document test { \n" + @@ -167,16 +168,16 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { " }\n" + " }\n" + "}\n"); - builder.build(new BaseDeployLogger(), setupQueryProfileTypes()); + builder.build(new BaseDeployLogger()); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); List> testRankProperties = new RawRankProfile(test, - new QueryProfileRegistry(), + queryProfiles, new AttributeFields(s)).configProperties(); return testRankProperties; } - private static QueryProfiles setupQueryProfileTypes() { + private static QueryProfileRegistry setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); @@ -185,7 +186,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { type.addField(new FieldDescription("ranking.features.query(n)", FieldType.fromString("integer", typeRegistry)), typeRegistry); typeRegistry.register(type); - return new QueryProfiles(registry); + return registry; } private String censorBindingHash(String s) { -- cgit v1.2.3 From 063b679c7cac060c44121a2ee7ce5a5d4b81849b Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 6 Feb 2018 15:25:28 +0100 Subject: Merge with master & fix test --- .../processing/TensorTransformTestCase.java | 120 +++++++++++++-------- 1 file changed, 75 insertions(+), 45 deletions(-) (limited to 'config-model/src/test/java/com') diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index c1d987ef3ad..d2211b86c9e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,93 +17,123 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; import java.util.List; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class TensorTransformTestCase extends SearchDefinitionTestCase { @Test public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException { - assertContainsExpression("max(1.0,2.0)", "max(1.0,2.0)"); - assertContainsExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)"); - assertContainsExpression("max(attribute(double_field),attribute(double_array_field))", "max(attribute(double_field),attribute(double_array_field))"); - assertContainsExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))"); - assertContainsExpression("max(attribute(tensor_field_1),attribute(tensor_field_2))", "max(attribute(tensor_field_1),attribute(tensor_field_2))"); - assertContainsExpression("min(test_constant_tensor,1.0)", "min(constant(test_constant_tensor),1.0)"); - assertContainsExpression("max(base_constant_tensor,1.0)", "max(constant(base_constant_tensor),1.0)"); - assertContainsExpression("min(constant(file_constant_tensor),1.0)", "min(constant(file_constant_tensor),1.0)"); - assertContainsExpression("max(query(q),1.0)", "max(query(q),1.0)"); - assertContainsExpression("max(query(n),1.0)", "max(query(n),1.0)"); + assertTransformedExpression("max(1.0,2.0)", + "max(1.0,2.0)"); + assertTransformedExpression("min(attribute(double_field),x)", + "min(attribute(double_field),x)"); + assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))", + "max(attribute(double_field),attribute(double_array_field))"); + assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))", + "min(attribute(tensor_field_1),attribute(double_field))"); + assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)", + "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)"); + assertTransformedExpression("min(constant(test_constant_tensor),1.0)", + "min(test_constant_tensor,1.0)"); + assertTransformedExpression("max(constant(base_constant_tensor),1.0)", + "max(base_constant_tensor,1.0)"); + assertTransformedExpression("min(constant(file_constant_tensor),1.0)", + "min(constant(file_constant_tensor),1.0)"); + assertTransformedExpression("max(query(q),1.0)", + "max(query(q),1.0)"); + assertTransformedExpression("max(query(n),1.0)", + "max(query(n),1.0)"); } @Test public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException { - assertContainsExpression("max(attribute(tensor_field_1),x)", "reduce(attribute(tensor_field_1),max,x)"); - assertContainsExpression("1 + max(attribute(tensor_field_1),x)", "1+reduce(attribute(tensor_field_1),max,x)"); - assertContainsExpression("if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)", "if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)"); - assertContainsExpression("max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); - assertContainsExpression("max(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); - assertContainsExpression("max(max(attribute(tensor_field_1),x),x)", "max(reduce(attribute(tensor_field_1),max,x),x)"); // will result in deploy error. - assertContainsExpression("max(max(attribute(tensor_field_2),x),y)", "reduce(reduce(attribute(tensor_field_2),max,x),max,y)"); + assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)", + "max(attribute(tensor_field_1),x)"); + assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)", + "1 + max(attribute(tensor_field_1),x)"); + assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)", + "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)"); + assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)", + "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)"); + assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)", + "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)"); + assertTransformedExpression("max(reduce(attribute(tensor_field_1),max,x),x)", + "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error. + assertTransformedExpression("reduce(reduce(attribute(tensor_field_2),max,x),max,y)", + "max(max(attribute(tensor_field_2),x),y)"); } @Test public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException { - assertContainsExpression("max(test_constant_tensor,x)", "reduce(constant(test_constant_tensor),max,x)"); - assertContainsExpression("max(base_constant_tensor,x)", "reduce(constant(base_constant_tensor),max,x)"); - assertContainsExpression("min(constant(file_constant_tensor),x)", "reduce(constant(file_constant_tensor),min,x)"); + assertTransformedExpression("reduce(constant(test_constant_tensor),max,x)", + "max(test_constant_tensor,x)"); + assertTransformedExpression("reduce(constant(base_constant_tensor),max,x)", + "max(base_constant_tensor,x)"); + assertTransformedExpression("reduce(constant(file_constant_tensor),min,x)", + "min(constant(file_constant_tensor),x)"); } @Test public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException { - assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)"); - assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)"); - assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)"); - assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) - assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); + assertTransformedExpression("reduce(attribute(double_field)+attribute(tensor_field_1),min,x)", + "min(attribute(double_field) + attribute(tensor_field_1),x)"); + assertTransformedExpression("reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)", + "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)"); + assertTransformedExpression("reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)" + , "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)"); + assertTransformedExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", + "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) + assertTransformedExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", + "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); } @Test public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException { - assertContainsExpression("max(tensorFromLabels(attribute(double_array_field)),double_array_field)", "reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)"); - assertContainsExpression("max(tensorFromLabels(attribute(double_array_field),x),x)", "reduce(tensorFromLabels(attribute(double_array_field),x),max,x)"); - assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)", "reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)"); - assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field),x),x)", "reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)"); + assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)", + "max(tensorFromLabels(attribute(double_array_field)),double_array_field)"); + assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field),x),max,x)", + "max(tensorFromLabels(attribute(double_array_field),x),x)"); + assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)", + "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)"); + assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)", + "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)"); } @Test public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException { - assertContainsExpression("max(query(q),x)", "reduce(query(q),max,x)"); - assertContainsExpression("max(query(n),x)", "max(query(n),x)"); + assertTransformedExpression("reduce(query(q),max,x)", "max(query(q),x)"); + assertTransformedExpression("max(query(n),x)", "max(query(n),x)"); } @Test public void requireThatMaxAndMinWithTensoresReturnedFromMacrosAreReplaced() throws ParseException { - assertContainsExpression("max(returns_tensor,x)", "reduce(rankingExpression(returns_tensor),max,x)"); - assertContainsExpression("max(wraps_returns_tensor,x)", "reduce(rankingExpression(wraps_returns_tensor),max,x)"); - assertContainsExpression("max(tensor_inheriting,x)", "reduce(rankingExpression(tensor_inheriting),max,x)"); - assertContainsExpression("max(returns_tensor_with_arg(attribute(tensor_field_1)),x)", "reduce(rankingExpression(returns_tensor_with_arg@),max,x)"); + assertTransformedExpression("reduce(rankingExpression(returns_tensor),max,x)", + "max(returns_tensor,x)"); + assertTransformedExpression("reduce(rankingExpression(wraps_returns_tensor),max,x)", + "max(wraps_returns_tensor,x)"); + assertTransformedExpression("reduce(rankingExpression(tensor_inheriting),max,x)", + "max(tensor_inheriting,x)"); + assertTransformedExpression("reduce(rankingExpression(returns_tensor_with_arg@),max,x)", + "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)"); } - private void assertContainsExpression(String expr, String transformedExpression) throws ParseException { - assertTrue("Expected expression '" + transformedExpression + "' found", - containsExpression(expr, transformedExpression)); - } - - private boolean containsExpression(String expr, String transformedExpression) throws ParseException { - for (Pair rankPropertyExpression : buildSearch(expr)) { + private void assertTransformedExpression(String expected, String original) throws ParseException { + for (Pair rankPropertyExpression : buildSearch(original)) { String rankProperty = rankPropertyExpression.getFirst(); if (rankProperty.equals("rankingExpression(firstphase).rankingScript")) { String rankExpression = censorBindingHash(rankPropertyExpression.getSecond().replace(" ","")); - return rankExpression.equals(transformedExpression); + assertEquals(expected, rankExpression); + return; } } - return false; + fail("No 'rankingExpression(firstphase).rankingScript' property produced"); } private List> buildSearch(String expression) throws ParseException { -- cgit v1.2.3 From cb18a346ddb89b604251564f59968d4d62b065a3 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 6 Feb 2018 15:59:14 +0100 Subject: Cleanup --- .../processing/RankingExpressionTypeValidator.java | 3 - .../RankingExpressionTypeValidatorTestCase.java | 64 +++++++++++++++++++++- 2 files changed, 63 insertions(+), 4 deletions(-) (limited to 'config-model/src/test/java/com') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java index 494d8d56161..a7a5ad58430 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java @@ -39,8 +39,6 @@ public class RankingExpressionTypeValidator extends Processor { /** Throws an IllegalArgumentException if the given rank profile does not produce valid type */ private void validate(RankProfile profile) { profile.parseExpressions(); - System.out.println("Type checking " + profile + ":"); - System.out.println(" First-phase: " + profile.getFirstPhaseRanking()); TypeContext context = profile.typeContext(queryProfiles); for (RankProfile.Macro macro : profile.getMacros().values()) ensureValid(macro.getRankingExpression(), "macro '" + macro.getName() + "'", context); @@ -58,7 +56,6 @@ public class RankingExpressionTypeValidator extends Processor { catch (IllegalArgumentException e) { throw new IllegalArgumentException("The " + expressionDescription + " is invalid", e); } - System.out.println(" Type of " + expressionDescription + " " + expression.getRoot() + ": " + type); if (type == null) // Not expected to happen throw new IllegalStateException("Could not determine the type produced by " + expressionDescription); return type; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java index 5c654f09c51..db3b12db1bf 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java @@ -12,7 +12,7 @@ import static org.junit.Assert.fail; public class RankingExpressionTypeValidatorTestCase { @Test - public void tensorTypeValidation() throws Exception { + public void tensorFirstPhaseMustProduceDouble() throws Exception { try { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); @@ -39,4 +39,66 @@ public class RankingExpressionTypeValidatorTestCase { } } + @Test + public void tensorSecondPhaseMustProduceDouble() throws Exception { + try { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); + searchBuilder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: sum(attribute(a))", + " }", + " second-phase {", + " expression: attribute(a)", + " }", + " }", + "}" + )); + searchBuilder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The second-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void tensorConditionsMustHaveTypeCompatibleBranches() throws Exception { + try { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder searchBuilder = new SearchBuilder(rankProfileRegistry); + searchBuilder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " field b type tensor(z[10]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: sum(if(1>0, attribute(a), attribute(b)))", + " }", + " }", + "}" + )); + searchBuilder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression is invalid: An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[],y[]) while the 'false' type is tensor(z[10])", + Exceptions.toMessageString(expected)); + } + } + } -- cgit v1.2.3