summary refs log tree commit diff stats
path: root/config-model
diff options
context:
space:
mode:
author Arne H Juul <arnej27959@users.noreply.github.com> 2023-03-23 13:17:32 +0100
committer GitHub <noreply@github.com> 2023-03-23 13:17:32 +0100
commit 51de23a458a1dd7239ebfca4c1dcbe502b1e6a03 (patch)
tree d1aa54b1d0b86269a5e4cad092063af0a8c74715 /config-model
parent b74558f7b462faa498eae165ee28fb1e4f6932f9 (diff)
parent de315c066a6dd0b78431230a5035fa519ebac601 (diff)
Merge pull request #26535 from vespa-engine/arnej/use-common-expressionfunction-instance
use ExpressionFunction.Instance for consistency
Diffstat (limited to 'config-model')
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java 36
-rw-r--r-- config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg 10
-rw-r--r-- config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java 2
3 files changed, 33 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
index 83ae6048051..f454d941e31 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
@@ -1,6 +1,8 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.expressiontransforms;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
@@ -11,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.TensorType;
@@ -18,7 +21,9 @@ import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
+import java.io.StringReader;
import java.util.ArrayList;
+import java.util.ArrayDeque;
import java.util.List;
import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
@@ -207,8 +212,21 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
}
}
- private String lengthFunctionName(ReferenceNode arg) {
- return "__token_length@" + Integer.toHexString(arg.hashCode());
+ private static final ExpressionFunction commonLengthFunction = makeLengthFunction();
+ private static ExpressionFunction makeLengthFunction() {
+ String func = "sum(map(input, f(x)(x > 0)))";
+ String name = "__token_length";
+ try (var r = new StringReader(func)) {
+ return new ExpressionFunction(name, List.of("input"), new RankingExpression(name, r));
+ }
+ catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) {
+ throw new IllegalStateException("unexpected", e);
+ }
+ }
+
+ private ExpressionFunction.Instance lengthFunctionFor(ReferenceNode arg) {
+ var ctx = new SerializationContext();
+ return commonLengthFunction.expand(ctx, List.of(arg), new ArrayDeque<String>());
}
private List<ExpressionNode> createTokenSequence(ReferenceNode feature) {
@@ -229,14 +247,13 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context) {
for (int i = 1; i < feature.getArguments().size(); ++i) {
ExpressionNode arg = feature.getArguments().expressions().get(i);
- if ( ! (arg instanceof ReferenceNode)) {
+ if ( ! (arg instanceof ReferenceNode ref)) {
throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " +
"the argument must be a reference. Got " + arg.toString());
}
- ReferenceNode ref = (ReferenceNode) arg;
- String functionName = lengthFunctionName(ref);
- if ( ! context.rankProfile().getFunctions().containsKey(functionName)) {
- context.rankProfile().addFunction(functionName, List.of(), "sum(map(" + ref + ", f(x)(x > 0)))", false);
+ var f = lengthFunctionFor(ref);
+ if ( ! context.rankProfile().getFunctions().containsKey(f.getName())) {
+ context.rankProfile().addFunction(f.getName(), List.of(), f.getExpressionString(), false);
}
}
}
@@ -282,8 +299,9 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
for (int i = 0; i < iter + 1; ++i) {
if (sequence.get(i) instanceof ConstantNode) {
factors.add(ONE);
- } else if (sequence.get(i) instanceof ReferenceNode) {
- factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i))));
+ } else if (sequence.get(i) instanceof ReferenceNode ref) {
+ var f = lengthFunctionFor(ref);
+ factors.add(new ReferenceNode(f.getName()));
}
if (i >= 1) {
operators.add(Operator.plus);
diff --git a/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg
index be6ec0e3a51..fc2453f3aa9 100644
--- a/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg
+++ b/config-model/src/test/derived/globalphase_token_functions/rank-profiles.cfg
@@ -13,20 +13,20 @@ rankprofile[].fef.property[].value "true"
rankprofile[].fef.property[].name "vespa.type.attribute.tokens"
rankprofile[].fef.property[].value "tensor(d0[128])"
rankprofile[].name "using_model"
-rankprofile[].fef.property[].name "rankingExpression(__token_length@3cbfb934).rankingScript"
+rankprofile[].fef.property[].name "rankingExpression(__token_length@4d7c1b66085df918).rankingScript"
rankprofile[].fef.property[].value "reduce(map(query(input), f(x)(x > 0)), sum)"
-rankprofile[].fef.property[].name "rankingExpression(__token_length@cf90db10).rankingScript"
+rankprofile[].fef.property[].name "rankingExpression(__token_length@a16087c578950aea).rankingScript"
rankprofile[].fef.property[].value "reduce(map(attribute(tokens), f(x)(x > 0)), sum)"
rankprofile[].fef.property[].name "rankingExpression(input_ids).rankingScript"
-rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < 1.0, 101.0, if (d1 < 1.0 + rankingExpression(__token_length@3cbfb934), query(input){d0:(d1 - (1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0), 102.0, if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0 + rankingExpression(__token_length@cf90db10)), attribute(tokens){d0:(d1 - (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0 + rankingExpression(__token_length@cf90db10) + 1.0), 102.0, 0.0)))))))"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < 1.0, 101.0, if (d1 < 1.0 + rankingExpression(__token_length@4d7c1b66085df918), query(input){d0:(d1 - (1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0), 102.0, if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0 + rankingExpression(__token_length@a16087c578950aea)), attribute(tokens){d0:(d1 - (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0))}, if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0 + rankingExpression(__token_length@a16087c578950aea) + 1.0), 102.0, 0.0)))))))"
rankprofile[].fef.property[].name "rankingExpression(input_ids).type"
rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])"
rankprofile[].fef.property[].name "rankingExpression(token_type_ids).rankingScript"
-rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0 + rankingExpression(__token_length@cf90db10) + 1.0), 1.0, 0.0))))"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0 + rankingExpression(__token_length@a16087c578950aea) + 1.0), 1.0, 0.0))))"
rankprofile[].fef.property[].name "rankingExpression(token_type_ids).type"
rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])"
rankprofile[].fef.property[].name "rankingExpression(attention_mask).rankingScript"
-rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@3cbfb934) + 1.0 + rankingExpression(__token_length@cf90db10) + 1.0), 1.0, 0.0)))"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])((if (d1 < (1.0 + rankingExpression(__token_length@4d7c1b66085df918) + 1.0 + rankingExpression(__token_length@a16087c578950aea) + 1.0), 1.0, 0.0)))"
rankprofile[].fef.property[].name "rankingExpression(attention_mask).type"
rankprofile[].fef.property[].value "tensor<float>(d0[1],d1[128])"
rankprofile[].fef.property[].name "vespa.rank.globalphase"
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
index 69c68124908..b094130ed0f 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -171,7 +171,7 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name());
- assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@892e3154) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@892e3154) + 1.0 + rankingExpression(__token_length@892e3154) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value());
+ assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@a2e4b6abdeb5fb3a) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@a2e4b6abdeb5fb3a) + 1.0 + rankingExpression(__token_length@a2e4b6abdeb5fb3a) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value());
assertEquals("test_unbound_model", config.rankprofile(8).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name());