diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-22 14:18:15 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-22 14:18:15 +0000 |
commit | de315c066a6dd0b78431230a5035fa519ebac601 (patch) | |
tree | f986435917ca55ba9db95c7dfddd9dfef11fa9d8 /config-model/src/main | |
parent | 39ef03dbfea4665dbe06b187d8d9d19d65e6d660 (diff) |
use ExpressionFunction.Instance for consistency
Diffstat (limited to 'config-model/src/main')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java | 36 |
1 files changed, 27 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java index 83ae6048051..f454d941e31 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java @@ -1,6 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema.expressiontransforms; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.OperationNode; @@ -11,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.TensorType; @@ -18,7 +21,9 @@ import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Slice; import com.yahoo.tensor.functions.TensorFunction; +import java.io.StringReader; import java.util.ArrayList; +import java.util.ArrayDeque; import java.util.List; import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; @@ -207,8 +212,21 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform } } - private String lengthFunctionName(ReferenceNode arg) { - return "__token_length@" + Integer.toHexString(arg.hashCode()); + private static final ExpressionFunction commonLengthFunction = makeLengthFunction(); + private static ExpressionFunction makeLengthFunction() { + String func = "sum(map(input, f(x)(x > 0)))"; + String name = "__token_length"; + try (var r = new StringReader(func)) { + return new ExpressionFunction(name, List.of("input"), new RankingExpression(name, r)); + } + catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) { + throw new IllegalStateException("unexpected", e); + } + } + + private ExpressionFunction.Instance lengthFunctionFor(ReferenceNode arg) { + var ctx = new SerializationContext(); + return commonLengthFunction.expand(ctx, List.of(arg), new ArrayDeque<String>()); } private List<ExpressionNode> createTokenSequence(ReferenceNode feature) { @@ -229,14 +247,13 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context) { for (int i = 1; i < feature.getArguments().size(); ++i) { ExpressionNode arg = feature.getArguments().expressions().get(i); - if ( ! (arg instanceof ReferenceNode)) { + if ( ! (arg instanceof ReferenceNode ref)) { throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " + "the argument must be a reference. Got " + arg.toString()); } - ReferenceNode ref = (ReferenceNode) arg; - String functionName = lengthFunctionName(ref); - if ( ! context.rankProfile().getFunctions().containsKey(functionName)) { - context.rankProfile().addFunction(functionName, List.of(), "sum(map(" + ref + ", f(x)(x > 0)))", false); + var f = lengthFunctionFor(ref); + if ( ! context.rankProfile().getFunctions().containsKey(f.getName())) { + context.rankProfile().addFunction(f.getName(), List.of(), f.getExpressionString(), false); } } } @@ -282,8 +299,9 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform for (int i = 0; i < iter + 1; ++i) { if (sequence.get(i) instanceof ConstantNode) { factors.add(ONE); - } else if (sequence.get(i) instanceof ReferenceNode) { - factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i)))); + } else if (sequence.get(i) instanceof ReferenceNode ref) { + var f = lengthFunctionFor(ref); + factors.add(new ReferenceNode(f.getName())); } if (i >= 1) { operators.add(Operator.plus); |