aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java133
1 files changed, 51 insertions, 82 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index d2211b86c9e..c18cfcfe1aa 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -17,129 +17,98 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.vespa.model.container.search.QueryProfiles;
import org.junit.Test;
import java.util.List;
-import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
public class TensorTransformTestCase extends SearchDefinitionTestCase {
@Test
public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException {
- assertTransformedExpression("max(1.0,2.0)",
- "max(1.0,2.0)");
- assertTransformedExpression("min(attribute(double_field),x)",
- "min(attribute(double_field),x)");
- assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))",
- "max(attribute(double_field),attribute(double_array_field))");
- assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))",
- "min(attribute(tensor_field_1),attribute(double_field))");
- assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)",
- "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)");
- assertTransformedExpression("min(constant(test_constant_tensor),1.0)",
- "min(test_constant_tensor,1.0)");
- assertTransformedExpression("max(constant(base_constant_tensor),1.0)",
- "max(base_constant_tensor,1.0)");
- assertTransformedExpression("min(constant(file_constant_tensor),1.0)",
- "min(constant(file_constant_tensor),1.0)");
- assertTransformedExpression("max(query(q),1.0)",
- "max(query(q),1.0)");
- assertTransformedExpression("max(query(n),1.0)",
- "max(query(n),1.0)");
+ assertContainsExpression("max(1.0,2.0)", "max(1.0,2.0)");
+ assertContainsExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)");
+ assertContainsExpression("max(attribute(double_field),attribute(double_array_field))", "max(attribute(double_field),attribute(double_array_field))");
+ assertContainsExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))");
+ assertContainsExpression("max(attribute(tensor_field_1),attribute(tensor_field_2))", "max(attribute(tensor_field_1),attribute(tensor_field_2))");
+ assertContainsExpression("min(test_constant_tensor,1.0)", "min(constant(test_constant_tensor),1.0)");
+ assertContainsExpression("max(base_constant_tensor,1.0)", "max(constant(base_constant_tensor),1.0)");
+ assertContainsExpression("min(constant(file_constant_tensor),1.0)", "min(constant(file_constant_tensor),1.0)");
+ assertContainsExpression("max(query(q),1.0)", "max(query(q),1.0)");
+ assertContainsExpression("max(query(n),1.0)", "max(query(n),1.0)");
}
@Test
public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException {
- assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)",
- "max(attribute(tensor_field_1),x)");
- assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)",
- "1 + max(attribute(tensor_field_1),x)");
- assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)",
- "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)");
- assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)",
- "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)");
- assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)",
- "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)");
- assertTransformedExpression("max(reduce(attribute(tensor_field_1),max,x),x)",
- "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error.
- assertTransformedExpression("reduce(reduce(attribute(tensor_field_2),max,x),max,y)",
- "max(max(attribute(tensor_field_2),x),y)");
+ assertContainsExpression("max(attribute(tensor_field_1),x)", "reduce(attribute(tensor_field_1),max,x)");
+ assertContainsExpression("1 + max(attribute(tensor_field_1),x)", "1+reduce(attribute(tensor_field_1),max,x)");
+ assertContainsExpression("if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)", "if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)");
+ assertContainsExpression("max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)");
+ assertContainsExpression("max(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),max,x)");
+ assertContainsExpression("max(max(attribute(tensor_field_1),x),x)", "max(reduce(attribute(tensor_field_1),max,x),x)"); // will result in deploy error.
+ assertContainsExpression("max(max(attribute(tensor_field_2),x),y)", "reduce(reduce(attribute(tensor_field_2),max,x),max,y)");
}
@Test
public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException {
- assertTransformedExpression("reduce(constant(test_constant_tensor),max,x)",
- "max(test_constant_tensor,x)");
- assertTransformedExpression("reduce(constant(base_constant_tensor),max,x)",
- "max(base_constant_tensor,x)");
- assertTransformedExpression("reduce(constant(file_constant_tensor),min,x)",
- "min(constant(file_constant_tensor),x)");
+ assertContainsExpression("max(test_constant_tensor,x)", "reduce(constant(test_constant_tensor),max,x)");
+ assertContainsExpression("max(base_constant_tensor,x)", "reduce(constant(base_constant_tensor),max,x)");
+ assertContainsExpression("min(constant(file_constant_tensor),x)", "reduce(constant(file_constant_tensor),min,x)");
}
@Test
public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException {
- assertTransformedExpression("reduce(attribute(double_field)+attribute(tensor_field_1),min,x)",
- "min(attribute(double_field) + attribute(tensor_field_1),x)");
- assertTransformedExpression("reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)",
- "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)");
- assertTransformedExpression("reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)"
- , "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)");
- assertTransformedExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)",
- "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
- assertTransformedExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)",
- "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)");
+ assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)");
+ assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)");
+ assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)");
+ assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
+ assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)");
}
@Test
public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException {
- assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)",
- "max(tensorFromLabels(attribute(double_array_field)),double_array_field)");
- assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field),x),max,x)",
- "max(tensorFromLabels(attribute(double_array_field),x),x)");
- assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)",
- "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)");
- assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)",
- "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)");
+ assertContainsExpression("max(tensorFromLabels(attribute(double_array_field)),double_array_field)", "reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)");
+ assertContainsExpression("max(tensorFromLabels(attribute(double_array_field),x),x)", "reduce(tensorFromLabels(attribute(double_array_field),x),max,x)");
+ assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)", "reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)");
+ assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field),x),x)", "reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)");
}
@Test
public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException {
- assertTransformedExpression("reduce(query(q),max,x)", "max(query(q),x)");
- assertTransformedExpression("max(query(n),x)", "max(query(n),x)");
+ assertContainsExpression("max(query(q),x)", "reduce(query(q),max,x)");
+ assertContainsExpression("max(query(n),x)", "max(query(n),x)");
}
@Test
public void requireThatMaxAndMinWithTensoresReturnedFromMacrosAreReplaced() throws ParseException {
- assertTransformedExpression("reduce(rankingExpression(returns_tensor),max,x)",
- "max(returns_tensor,x)");
- assertTransformedExpression("reduce(rankingExpression(wraps_returns_tensor),max,x)",
- "max(wraps_returns_tensor,x)");
- assertTransformedExpression("reduce(rankingExpression(tensor_inheriting),max,x)",
- "max(tensor_inheriting,x)");
- assertTransformedExpression("reduce(rankingExpression(returns_tensor_with_arg@),max,x)",
- "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)");
+ assertContainsExpression("max(returns_tensor,x)", "reduce(rankingExpression(returns_tensor),max,x)");
+ assertContainsExpression("max(wraps_returns_tensor,x)", "reduce(rankingExpression(wraps_returns_tensor),max,x)");
+ assertContainsExpression("max(tensor_inheriting,x)", "reduce(rankingExpression(tensor_inheriting),max,x)");
+ assertContainsExpression("max(returns_tensor_with_arg(attribute(tensor_field_1)),x)", "reduce(rankingExpression(returns_tensor_with_arg@),max,x)");
}
- private void assertTransformedExpression(String expected, String original) throws ParseException {
- for (Pair<String, String> rankPropertyExpression : buildSearch(original)) {
+ private void assertContainsExpression(String expr, String transformedExpression) throws ParseException {
+ assertTrue("Expected expression '" + transformedExpression + "' found",
+ containsExpression(expr, transformedExpression));
+ }
+
+ private boolean containsExpression(String expr, String transformedExpression) throws ParseException {
+ for (Pair<String, String> rankPropertyExpression : buildSearch(expr)) {
String rankProperty = rankPropertyExpression.getFirst();
if (rankProperty.equals("rankingExpression(firstphase).rankingScript")) {
String rankExpression = censorBindingHash(rankPropertyExpression.getSecond().replace(" ",""));
- assertEquals(expected, rankExpression);
- return;
+ return rankExpression.equals(transformedExpression);
}
}
- fail("No 'rankingExpression(firstphase).rankingScript' property produced");
+ return false;
}
private List<Pair<String, String>> buildSearch(String expression) throws ParseException {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
- QueryProfileRegistry queryProfiles = setupQueryProfileTypes();
- SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles);
+ SearchBuilder builder = new SearchBuilder(rankProfileRegistry);
builder.importString(
"search test {\n" +
" document test { \n" +
@@ -198,16 +167,16 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
" }\n" +
" }\n" +
"}\n");
- builder.build(new BaseDeployLogger());
+ builder.build(new BaseDeployLogger(), setupQueryProfileTypes());
Search s = builder.getSearch();
- RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles);
+ RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry());
List<Pair<String, String>> testRankProperties = new RawRankProfile(test,
- queryProfiles,
+ new QueryProfileRegistry(),
new AttributeFields(s)).configProperties();
return testRankProperties;
}
- private static QueryProfileRegistry setupQueryProfileTypes() {
+ private static QueryProfiles setupQueryProfileTypes() {
QueryProfileRegistry registry = new QueryProfileRegistry();
QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry();
QueryProfileType type = new QueryProfileType(new ComponentId("testtype"));
@@ -216,7 +185,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
type.addField(new FieldDescription("ranking.features.query(n)",
FieldType.fromString("integer", typeRegistry)), typeRegistry);
typeRegistry.register(type);
- return registry;
+ return new QueryProfiles(registry);
}
private String censorBindingHash(String s) {