diff options
author | Lester Solbakken <lesters@oath.com> | 2021-01-06 15:19:17 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-01-06 15:19:17 +0100 |
commit | 3578f2b70312e681b11db97e6ead8997e2dd7d3c (patch) | |
tree | 3368bbe2199393570c0554527da720c722bc86e0 /config-model | |
parent | 26540a4ca27bc8fcda64815988f17a6b3b9bb6ef (diff) |
Allow expressions as arguments
Diffstat (limited to 'config-model')
2 files changed, 114 insertions, 8 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionFeatureArgumentsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionFeatureArgumentsTestCase.java new file mode 100644 index 00000000000..14228968161 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionFeatureArgumentsTestCase.java @@ -0,0 +1,108 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition; + +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; +import com.yahoo.collections.Pair; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.derived.AttributeFields; +import com.yahoo.searchdefinition.derived.RawRankProfile; +import com.yahoo.searchdefinition.parser.ParseException; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * @author lesters + */ +public class RankingExpressionFeatureArgumentsTestCase extends SchemaTestCase { + + @Test + public void testFeatureWithExpressionArguments() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field t1 type tensor<float>(x{}) { \n" + + " indexing: attribute | summary \n" + + " }\n" + + " field t2 type tensor<float>(x{}) { \n" + + " indexing: attribute | summary \n" + + " }\n" + + " }\n" + + " rank-profile test {\n" + + " function my_func(t) {\n" + + " expression: sum(t, x) \n" + + " }\n" + + " function eval_func() {\n" + + " expression: my_func( attribute(t1) ) \n" + + " }\n" + + " function eval_func_with_expr() {\n" + + " expression: my_func( attribute(t1) * attribute(t2) ) \n" + + " }\n" + + " function eval_func_with_expr_2() {\n" + + " expression: my_func( attribute(t1){x:0} ) \n" + + " }\n" + + " function eval_func_via_func_with_expr() {\n" + + " expression: call_func_with_expr( attribute(t1), attribute(t2) ) \n" + + " }\n" + + " function call_func_with_expr(a, b) {\n" + + " expression: my_func( a * b ) \n" + + " }\n" + + " first-phase {\n" + + " expression: 42 \n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + builder.build(); + Search s = builder.getSearch(); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); + List<Pair<String, String>> testRankProperties = new RawRankProfile(test, + new QueryProfileRegistry(), + new ImportedMlModels(), + new AttributeFields(s)).configProperties(); + + for(Pair<String,String> prop : testRankProperties) { + System.out.println(prop); + } + + assertEquals("(rankingExpression(my_func).rankingScript, reduce(t, sum, x))", + testRankProperties.get(0).toString()); + + // eval_func + assertEquals("(rankingExpression(eval_func).rankingScript, rankingExpression(my_func@9bbaee2bad5a2fc0))", + testRankProperties.get(2).toString()); + assertEquals("(rankingExpression(my_func@9bbaee2bad5a2fc0).rankingScript, reduce(attribute(t1), sum, x))", + testRankProperties.get(1).toString()); + + // The following functions should generate features to evaluate the expression argument before passing to my_func + + // eval_func_with_expr + assertEquals("(rankingExpression(eval_func_with_expr).rankingScript, rankingExpression(my_func@45673ba956ae9b77))", + testRankProperties.get(5).toString()); + assertEquals("(rankingExpression(my_func@45673ba956ae9b77).rankingScript, reduce(autogenerated_ranking_feature@43bc412603c00a4a, sum, x))", + testRankProperties.get(4).toString()); + assertEquals("(rankingExpression(autogenerated_ranking_feature@43bc412603c00a4a).rankingScript, attribute(t1) * attribute(t2))", + testRankProperties.get(3).toString()); + + // eval_func_with_expr_2 + assertEquals("(rankingExpression(eval_func_with_expr_2).rankingScript, rankingExpression(my_func@2192533eaad2293d))", + testRankProperties.get(8).toString()); + assertEquals("(rankingExpression(my_func@2192533eaad2293d).rankingScript, reduce(autogenerated_ranking_feature@71a4196136b577cf, sum, x))", + testRankProperties.get(7).toString()); + assertEquals("(rankingExpression(autogenerated_ranking_feature@71a4196136b577cf).rankingScript, attribute(t1){x:0})", + testRankProperties.get(6).toString()); + + // eval_func_via_func_with_expr + assertEquals("(rankingExpression(eval_func_via_func_with_expr).rankingScript, rankingExpression(call_func_with_expr@640470df47a83000.c156faa8f98c0b0c))", + testRankProperties.get(10).toString()); + assertEquals("(rankingExpression(call_func_with_expr@640470df47a83000.c156faa8f98c0b0c).rankingScript, rankingExpression(my_func@45673ba956ae9b77))", + testRankProperties.get(9).toString()); + // my_func@45673ba956ae9b77 is the same as under eval_func_with_expr + + } + +}
\ No newline at end of file diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 84a6d2a154a..20182c89a8c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -207,22 +207,20 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { queryProfiles, new ImportedMlModels(), new AttributeFields(s)).configProperties(); - assertEquals("(rankingExpression(relu@).rankingScript, max(1.0,reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input)))", + assertEquals("(rankingExpression(autogenerated_ranking_feature@).rankingScript, reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input))", censorBindingHash(testRankProperties.get(0).toString())); - assertEquals("(rankingExpression(hidden_layer).rankingScript, rankingExpression(relu@))", + assertEquals("(rankingExpression(relu@).rankingScript, max(1.0,autogenerated_ranking_feature@))", censorBindingHash(testRankProperties.get(1).toString())); - assertEquals("(rankingExpression(hidden_layer).type, tensor(x[1]))", + assertEquals("(rankingExpression(hidden_layer).rankingScript, rankingExpression(relu@))", censorBindingHash(testRankProperties.get(2).toString())); assertEquals("(rankingExpression(final_layer).rankingScript, sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))", - testRankProperties.get(3).toString()); - assertEquals("(rankingExpression(final_layer).type, tensor(x[1]))", testRankProperties.get(4).toString()); assertEquals("(rankingExpression(relu).rankingScript, max(1.0,x))", - testRankProperties.get(5).toString()); + testRankProperties.get(6).toString()); assertEquals("(vespa.rank.secondphase, rankingExpression(secondphase))", - testRankProperties.get(6).toString()); + testRankProperties.get(7).toString()); assertEquals("(rankingExpression(secondphase).rankingScript, reduce(rankingExpression(final_layer), sum))", - testRankProperties.get(7).toString()); + testRankProperties.get(8).toString()); } private QueryProfileRegistry queryProfileWith(String field, String type) { |