diff options
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java | 133 |
1 files changed, 51 insertions, 82 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index d2211b86c9e..c18cfcfe1aa 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,129 +17,98 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; import java.util.List; -import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; public class TensorTransformTestCase extends SearchDefinitionTestCase { @Test public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException { - assertTransformedExpression("max(1.0,2.0)", - "max(1.0,2.0)"); - assertTransformedExpression("min(attribute(double_field),x)", - "min(attribute(double_field),x)"); - assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))", - "max(attribute(double_field),attribute(double_array_field))"); - assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))", - "min(attribute(tensor_field_1),attribute(double_field))"); - assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)", - "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)"); - assertTransformedExpression("min(constant(test_constant_tensor),1.0)", - "min(test_constant_tensor,1.0)"); - assertTransformedExpression("max(constant(base_constant_tensor),1.0)", - "max(base_constant_tensor,1.0)"); - assertTransformedExpression("min(constant(file_constant_tensor),1.0)", - "min(constant(file_constant_tensor),1.0)"); - assertTransformedExpression("max(query(q),1.0)", - "max(query(q),1.0)"); - assertTransformedExpression("max(query(n),1.0)", - "max(query(n),1.0)"); + assertContainsExpression("max(1.0,2.0)", "max(1.0,2.0)"); + assertContainsExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)"); + assertContainsExpression("max(attribute(double_field),attribute(double_array_field))", "max(attribute(double_field),attribute(double_array_field))"); + assertContainsExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))"); + assertContainsExpression("max(attribute(tensor_field_1),attribute(tensor_field_2))", "max(attribute(tensor_field_1),attribute(tensor_field_2))"); + assertContainsExpression("min(test_constant_tensor,1.0)", "min(constant(test_constant_tensor),1.0)"); + assertContainsExpression("max(base_constant_tensor,1.0)", "max(constant(base_constant_tensor),1.0)"); + assertContainsExpression("min(constant(file_constant_tensor),1.0)", "min(constant(file_constant_tensor),1.0)"); + assertContainsExpression("max(query(q),1.0)", "max(query(q),1.0)"); + assertContainsExpression("max(query(n),1.0)", "max(query(n),1.0)"); } @Test public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException { - assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)", - "max(attribute(tensor_field_1),x)"); - assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)", - "1 + max(attribute(tensor_field_1),x)"); - assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)", - "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)"); - assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)", - "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)"); - assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)", - "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)"); - assertTransformedExpression("max(reduce(attribute(tensor_field_1),max,x),x)", - "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error. - assertTransformedExpression("reduce(reduce(attribute(tensor_field_2),max,x),max,y)", - "max(max(attribute(tensor_field_2),x),y)"); + assertContainsExpression("max(attribute(tensor_field_1),x)", "reduce(attribute(tensor_field_1),max,x)"); + assertContainsExpression("1 + max(attribute(tensor_field_1),x)", "1+reduce(attribute(tensor_field_1),max,x)"); + assertContainsExpression("if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)", "if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)"); + assertContainsExpression("max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); + assertContainsExpression("max(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); + assertContainsExpression("max(max(attribute(tensor_field_1),x),x)", "max(reduce(attribute(tensor_field_1),max,x),x)"); // will result in deploy error. + assertContainsExpression("max(max(attribute(tensor_field_2),x),y)", "reduce(reduce(attribute(tensor_field_2),max,x),max,y)"); } @Test public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException { - assertTransformedExpression("reduce(constant(test_constant_tensor),max,x)", - "max(test_constant_tensor,x)"); - assertTransformedExpression("reduce(constant(base_constant_tensor),max,x)", - "max(base_constant_tensor,x)"); - assertTransformedExpression("reduce(constant(file_constant_tensor),min,x)", - "min(constant(file_constant_tensor),x)"); + assertContainsExpression("max(test_constant_tensor,x)", "reduce(constant(test_constant_tensor),max,x)"); + assertContainsExpression("max(base_constant_tensor,x)", "reduce(constant(base_constant_tensor),max,x)"); + assertContainsExpression("min(constant(file_constant_tensor),x)", "reduce(constant(file_constant_tensor),min,x)"); } @Test public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException { - assertTransformedExpression("reduce(attribute(double_field)+attribute(tensor_field_1),min,x)", - "min(attribute(double_field) + attribute(tensor_field_1),x)"); - assertTransformedExpression("reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)", - "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)"); - assertTransformedExpression("reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)" - , "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)"); - assertTransformedExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", - "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) - assertTransformedExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", - "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); + assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)"); + assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)"); + assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)"); + assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) + assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); } @Test public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException { - assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)", - "max(tensorFromLabels(attribute(double_array_field)),double_array_field)"); - assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field),x),max,x)", - "max(tensorFromLabels(attribute(double_array_field),x),x)"); - assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)", - "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)"); - assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)", - "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)"); + assertContainsExpression("max(tensorFromLabels(attribute(double_array_field)),double_array_field)", "reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)"); + assertContainsExpression("max(tensorFromLabels(attribute(double_array_field),x),x)", "reduce(tensorFromLabels(attribute(double_array_field),x),max,x)"); + assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)", "reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)"); + assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field),x),x)", "reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)"); } @Test public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException { - assertTransformedExpression("reduce(query(q),max,x)", "max(query(q),x)"); - assertTransformedExpression("max(query(n),x)", "max(query(n),x)"); + assertContainsExpression("max(query(q),x)", "reduce(query(q),max,x)"); + assertContainsExpression("max(query(n),x)", "max(query(n),x)"); } @Test public void requireThatMaxAndMinWithTensoresReturnedFromMacrosAreReplaced() throws ParseException { - assertTransformedExpression("reduce(rankingExpression(returns_tensor),max,x)", - "max(returns_tensor,x)"); - assertTransformedExpression("reduce(rankingExpression(wraps_returns_tensor),max,x)", - "max(wraps_returns_tensor,x)"); - assertTransformedExpression("reduce(rankingExpression(tensor_inheriting),max,x)", - "max(tensor_inheriting,x)"); - assertTransformedExpression("reduce(rankingExpression(returns_tensor_with_arg@),max,x)", - "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)"); + assertContainsExpression("max(returns_tensor,x)", "reduce(rankingExpression(returns_tensor),max,x)"); + assertContainsExpression("max(wraps_returns_tensor,x)", "reduce(rankingExpression(wraps_returns_tensor),max,x)"); + assertContainsExpression("max(tensor_inheriting,x)", "reduce(rankingExpression(tensor_inheriting),max,x)"); + assertContainsExpression("max(returns_tensor_with_arg(attribute(tensor_field_1)),x)", "reduce(rankingExpression(returns_tensor_with_arg@),max,x)"); } - private void assertTransformedExpression(String expected, String original) throws ParseException { - for (Pair<String, String> rankPropertyExpression : buildSearch(original)) { + private void assertContainsExpression(String expr, String transformedExpression) throws ParseException { + assertTrue("Expected expression '" + transformedExpression + "' found", + containsExpression(expr, transformedExpression)); + } + + private boolean containsExpression(String expr, String transformedExpression) throws ParseException { + for (Pair<String, String> rankPropertyExpression : buildSearch(expr)) { String rankProperty = rankPropertyExpression.getFirst(); if (rankProperty.equals("rankingExpression(firstphase).rankingScript")) { String rankExpression = censorBindingHash(rankPropertyExpression.getSecond().replace(" ","")); - assertEquals(expected, rankExpression); - return; + return rankExpression.equals(transformedExpression); } } - fail("No 'rankingExpression(firstphase).rankingScript' property produced"); + return false; } private List<Pair<String, String>> buildSearch(String expression) throws ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - QueryProfileRegistry queryProfiles = setupQueryProfileTypes(); - SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); builder.importString( "search test {\n" + " document test { \n" + @@ -198,16 +167,16 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { " }\n" + " }\n" + "}\n"); - builder.build(new BaseDeployLogger()); + builder.build(new BaseDeployLogger(), setupQueryProfileTypes()); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, - queryProfiles, + new QueryProfileRegistry(), new AttributeFields(s)).configProperties(); return testRankProperties; } - private static QueryProfileRegistry setupQueryProfileTypes() { + private static QueryProfiles setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); @@ -216,7 +185,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { type.addField(new FieldDescription("ranking.features.query(n)", FieldType.fromString("integer", typeRegistry)), typeRegistry); typeRegistry.register(type); - return registry; + return new QueryProfiles(registry); } private String censorBindingHash(String s) { |