diff options
5 files changed, 128 insertions, 31 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionFeatureArgumentsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionFeatureArgumentsTestCase.java new file mode 100644 index 00000000000..14228968161 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionFeatureArgumentsTestCase.java @@ -0,0 +1,108 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition; + +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; +import com.yahoo.collections.Pair; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.derived.AttributeFields; +import com.yahoo.searchdefinition.derived.RawRankProfile; +import com.yahoo.searchdefinition.parser.ParseException; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * @author lesters + */ +public class RankingExpressionFeatureArgumentsTestCase extends SchemaTestCase { + + @Test + public void testFeatureWithExpressionArguments() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field t1 type tensor<float>(x{}) { \n" + + " indexing: attribute | summary \n" + + " }\n" + + " field t2 type tensor<float>(x{}) { \n" + + " indexing: attribute | summary \n" + + " }\n" + + " }\n" + + " rank-profile test {\n" + + " function my_func(t) {\n" + + " expression: sum(t, x) \n" + + " }\n" + + " function eval_func() {\n" + + " expression: my_func( attribute(t1) ) \n" + + " }\n" + + " function eval_func_with_expr() {\n" + + " expression: my_func( attribute(t1) * attribute(t2) ) \n" + + " }\n" + + " function eval_func_with_expr_2() {\n" + + " expression: my_func( attribute(t1){x:0} ) \n" + + " }\n" + + " function eval_func_via_func_with_expr() {\n" + + " expression: call_func_with_expr( attribute(t1), attribute(t2) ) \n" + + " }\n" + + " function call_func_with_expr(a, b) {\n" + + " expression: my_func( a * b ) \n" + + " }\n" + + " first-phase {\n" + + " expression: 42 \n" + + " }\n" + + " }\n" + + "\n" + + "}\n"); + builder.build(); + Search s = builder.getSearch(); + RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); + List<Pair<String, String>> testRankProperties = new RawRankProfile(test, + new QueryProfileRegistry(), + new ImportedMlModels(), + new AttributeFields(s)).configProperties(); + + for(Pair<String,String> prop : testRankProperties) { + System.out.println(prop); + } + + assertEquals("(rankingExpression(my_func).rankingScript, reduce(t, sum, x))", + testRankProperties.get(0).toString()); + + // eval_func + assertEquals("(rankingExpression(eval_func).rankingScript, rankingExpression(my_func@9bbaee2bad5a2fc0))", + testRankProperties.get(2).toString()); + assertEquals("(rankingExpression(my_func@9bbaee2bad5a2fc0).rankingScript, reduce(attribute(t1), sum, x))", + testRankProperties.get(1).toString()); + + // The following functions should generate features to evaluate the expression argument before passing to my_func + + // eval_func_with_expr + assertEquals("(rankingExpression(eval_func_with_expr).rankingScript, rankingExpression(my_func@45673ba956ae9b77))", + testRankProperties.get(5).toString()); + assertEquals("(rankingExpression(my_func@45673ba956ae9b77).rankingScript, reduce(autogenerated_ranking_feature@43bc412603c00a4a, sum, x))", + testRankProperties.get(4).toString()); + assertEquals("(rankingExpression(autogenerated_ranking_feature@43bc412603c00a4a).rankingScript, attribute(t1) * attribute(t2))", + testRankProperties.get(3).toString()); + + // eval_func_with_expr_2 + assertEquals("(rankingExpression(eval_func_with_expr_2).rankingScript, rankingExpression(my_func@2192533eaad2293d))", + testRankProperties.get(8).toString()); + assertEquals("(rankingExpression(my_func@2192533eaad2293d).rankingScript, reduce(autogenerated_ranking_feature@71a4196136b577cf, sum, x))", + testRankProperties.get(7).toString()); + assertEquals("(rankingExpression(autogenerated_ranking_feature@71a4196136b577cf).rankingScript, attribute(t1){x:0})", + testRankProperties.get(6).toString()); + + // eval_func_via_func_with_expr + assertEquals("(rankingExpression(eval_func_via_func_with_expr).rankingScript, rankingExpression(call_func_with_expr@640470df47a83000.c156faa8f98c0b0c))", + testRankProperties.get(10).toString()); + assertEquals("(rankingExpression(call_func_with_expr@640470df47a83000.c156faa8f98c0b0c).rankingScript, rankingExpression(my_func@45673ba956ae9b77))", + testRankProperties.get(9).toString()); + // my_func@45673ba956ae9b77 is the same as under eval_func_with_expr + + } + +}
\ No newline at end of file diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 84a6d2a154a..20182c89a8c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -207,22 +207,20 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { queryProfiles, new ImportedMlModels(), new AttributeFields(s)).configProperties(); - assertEquals("(rankingExpression(relu@).rankingScript, max(1.0,reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input)))", + assertEquals("(rankingExpression(autogenerated_ranking_feature@).rankingScript, reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input))", censorBindingHash(testRankProperties.get(0).toString())); - assertEquals("(rankingExpression(hidden_layer).rankingScript, rankingExpression(relu@))", + assertEquals("(rankingExpression(relu@).rankingScript, max(1.0,autogenerated_ranking_feature@))", censorBindingHash(testRankProperties.get(1).toString())); - assertEquals("(rankingExpression(hidden_layer).type, tensor(x[1]))", + assertEquals("(rankingExpression(hidden_layer).rankingScript, rankingExpression(relu@))", censorBindingHash(testRankProperties.get(2).toString())); assertEquals("(rankingExpression(final_layer).rankingScript, sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))", - testRankProperties.get(3).toString()); - assertEquals("(rankingExpression(final_layer).type, tensor(x[1]))", testRankProperties.get(4).toString()); assertEquals("(rankingExpression(relu).rankingScript, max(1.0,x))", - testRankProperties.get(5).toString()); + testRankProperties.get(6).toString()); assertEquals("(vespa.rank.secondphase, rankingExpression(secondphase))", - testRankProperties.get(6).toString()); + testRankProperties.get(7).toString()); assertEquals("(rankingExpression(secondphase).rankingScript, reduce(rankingExpression(final_layer), sum))", - testRankProperties.get(7).toString()); + testRankProperties.get(8).toString()); } private QueryProfileRegistry queryProfileWith(String field, String type) { diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 88eccb4559f..d412f408350 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -875,7 +875,6 @@ "public final java.lang.String outs()", "public final java.lang.String out()", "public final java.util.List args()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode arg()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode function()", "public final com.yahoo.searchlib.rankingexpression.rule.FunctionNode scalarOrTensorFunction()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorFunction()", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 9f900ffed36..b97c8316c9b 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -3,7 +3,10 @@ package com.yahoo.searchlib.rankingexpression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.tensor.TensorType; import com.yahoo.text.Utf8; @@ -131,7 +134,16 @@ public class ExpressionFunction { public Instance expand(SerializationContext context, List<ExpressionNode> argumentValues, Deque<String> path) { Map<String, String> argumentBindings = new HashMap<>(); for (int i = 0; i < arguments.size() && i < argumentValues.size(); ++i) { - argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(new StringBuilder(), context, path, null).toString()); + String key = arguments.get(i); + ExpressionNode expr = argumentValues.get(i); + String binding = expr.toString(new StringBuilder(), context, path, null).toString(); + + if ( ! (expr instanceof ReferenceNode) && ! (expr instanceof ConstantNode) && ! (expr instanceof FunctionNode) ) { + String funcName = "autogenerated_ranking_feature@" + Long.toHexString(symbolCode(key + "=" + binding)); + context.addFunctionSerialization(RankingExpression.propertyName(funcName), binding); + binding = funcName; + } + argumentBindings.put(key, binding); } context = argumentBindings.isEmpty() ? context.withoutBindings() : context.withBindings(argumentBindings); String symbol = toSymbol(argumentBindings); diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 09880b8dfc3..36b1f9627bb 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -328,30 +328,10 @@ List<ExpressionNode> args() : ExpressionNode argument; } { - ( ( argument = arg() { arguments.add(argument); } ( <COMMA> argument = arg() { arguments.add(argument); } )* )? ) + ( ( argument = expression() { arguments.add(argument); } ( <COMMA> argument = expression() { arguments.add(argument); } )* )? ) { return arguments; } } -// TODO: Replace use of this for function arguments with value() -// For that to work with the current search execution framework -// we need to generate another function for the argument such that we can replace -// instances of the argument with the reference to that function in the same way -// as we replace by constants/names today (this can make for some fun combinatorial explosion). -// We should also stop doing function expansion in the toString of a function. -// - Jon 2014-05-02 -ExpressionNode arg() : -{ - ExpressionNode ret; - String name; - Function fnc; -} -{ - ( ret = constantPrimitive() | - LOOKAHEAD(2) ret = feature() | - name = identifier() { ret = new NameNode(name); } ) - { return ret; } -} - ExpressionNode function() : { ExpressionNode function; |