diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-22 14:18:15 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-22 14:18:15 +0000 |
commit | de315c066a6dd0b78431230a5035fa519ebac601 (patch) | |
tree | f986435917ca55ba9db95c7dfddd9dfef11fa9d8 /config-model/src | |
parent | 39ef03dbfea4665dbe06b187d8d9d19d65e6d660 (diff) |
use ExpressionFunction.Instance for consistency
Diffstat (limited to 'config-model/src')
3 files changed, 33 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java index 83ae6048051..f454d941e31 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java @@ -1,6 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.OperationNode; @@ -11,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.TensorType; @@ -18,7 +21,9 @@ import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Slice; import com.yahoo.tensor.functions.TensorFunction; +import java.io.StringReader; import java.util.ArrayList; +import java.util.ArrayDeque; import java.util.List; import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; @@ -207,8 +212,21 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform } } - private String lengthFunctionName(ReferenceNode arg) { - return "__token_length@" + Integer.toHexString(arg.hashCode()); + private static final ExpressionFunction commonLengthFunction = makeLengthFunction(); + private static ExpressionFunction makeLengthFunction() { + String func = "sum(map(input, f(x)(x > 0)))"; + String name = "__token_length"; + try (var r = new StringReader(func)) { + return new ExpressionFunction(name, List.of("input"), new RankingExpression(name, r)); + } + catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) { + throw new IllegalStateException("unexpected", e); + } + } + + private ExpressionFunction.Instance lengthFunctionFor(ReferenceNode arg) { + var ctx = new SerializationContext(); + return commonLengthFunction.expand(ctx, List.of(arg), new ArrayDeque<String>()); } private List<ExpressionNode> createTokenSequence(ReferenceNode feature) { @@ -229,14 +247,13 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context) { for (int i = 1; i < feature.getArguments().size(); ++i) { ExpressionNode arg = feature.getArguments().expressions().get(i); - if ( ! (arg instanceof ReferenceNode)) { + if ( ! (arg instanceof ReferenceNode ref)) { throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " + "the argument must be a reference. Got " + arg.toString()); } - ReferenceNode ref = (ReferenceNode) arg; - String functionName = lengthFunctionName(ref); - if ( ! context.rankProfile().getFunctions().containsKey(functionName)) { - context.rankProfile().addFunction(functionName, List.of(), "sum(map(" + ref + ", f(x)(x > 0)))", false); + var f = lengthFunctionFor(ref); + if ( ! context.rankProfile().getFunctions().containsKey(f.getName())) { + context.rankProfile().addFunction(f.getName(), List.of(), f.getExpressionString(), false); } } } @@ -282,8 +299,9 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform for (int i = 0; i < iter + 1; ++i) { if (sequence.get(i) instanceof ConstantNode) { factors.add(ONE); - } else if (sequence.get(i) instanceof ReferenceNode) { - factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i)))); + } else if (sequence.get(i) instanceof ReferenceNode ref) { + var f = lengthFunctionFor(ref); + factors.add(new ReferenceNode(f.getName())); } if (i >= 1) { operators.add(Operator.plus); diff --git a/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg index be6ec0e3a51..fc2453f3aa9 100644 --- a/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg +++ b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg @@ -13,20 +13,20 @@ rankprofile[].fef.property[].value "true" rankprofile[].fef.property[].name "vespa.type.attribute.tokens" rankprofile[].fef.property[].value "tensor(d0[128])" rankprofile[].name "using_model" -rankprofile[].fef.property[].name "rankingExpression(__token_length@3cbfb934).rankingScript" +rankprofile[].fef.property[].name "rankingExpression(__token_length@4d7c1b66085df918).rankingScript" rankprofile[].fef.property[].value "reduce(map(query(input), f(x)(x > 0)), sum)" -rankprofile[].fef.property[].name "rankingExpression(__token_length@cf90db10).rankingScript" +rankprofile[].fef.property[].name "rankingExpression(__token_length@a16087c578950aea).rankingScript" rankprofile[].fef.property[].value "reduce(map(attribute(tokens), f(x)(x > 0)), sum)" rankprofile[].fef.property[].name "rankingExpression(input_ids).rankingScript" -rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < 1.0, 101.0, if (d1 < 1.0 + rankingExpression(__token_length@3cbfb934), query(input){d0:(d1 - (1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0), 102.0, if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0 + rankingExpression(__token_length@cf90db10)), attribute(tokens){d0:(d1 - (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0 + rankingExpression(__token_length@cf90db10) + 1.0), 102.0, 0.0)))))))" +rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < 1.0, 101.0, if (d1 < 1.0 + rankingExpression(__token_length@4d7c1b66085df918), query(input){d0:(d1 - (1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0), 102.0, if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0 + rankingExpression(__token_length@a16087c578950aea)), attribute(tokens){d0:(d1 - (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0 + rankingExpression(__token_length@a16087c578950aea) + 1.0), 102.0, 0.0)))))))" rankprofile[].fef.property[].name "rankingExpression(input_ids).type" rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])" rankprofile[].fef.property[].name "rankingExpression(token_type_ids).rankingScript" -rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0 + rankingExpression(__token_length@cf90db10) + 1.0), 1.0, 0.0))))" +rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0 + rankingExpression(__token_length@a16087c578950aea) + 1.0), 1.0, 0.0))))" rankprofile[].fef.property[].name "rankingExpression(token_type_ids).type" rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])" rankprofile[].fef.property[].name "rankingExpression(attention_mask).rankingScript" -rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0 + rankingExpression(__token_length@cf90db10) + 1.0), 1.0, 0.0)))" +rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0 + rankingExpression(__token_length@a16087c578950aea) + 1.0), 1.0, 0.0)))" rankprofile[].fef.property[].name "rankingExpression(attention_mask).type" rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])" rankprofile[].fef.property[].name "vespa.rank.globalphase" diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java index 69c68124908..b094130ed0f 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java @@ -171,7 +171,7 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name()); - assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@892e3154) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@892e3154) + 1.0 + rankingExpression(__token_length@892e3154) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value()); + assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@a2e4b6abdeb5fb3a) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@a2e4b6abdeb5fb3a) + 1.0 + rankingExpression(__token_length@a2e4b6abdeb5fb3a) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value()); assertEquals("test_unbound_model", config.rankprofile(8).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); |