summaryrefslogtreecommitdiffstats
path: root/config-model/src/main
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-22 14:18:15 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-22 14:18:15 +0000
commitde315c066a6dd0b78431230a5035fa519ebac601 (patch)
treef986435917ca55ba9db95c7dfddd9dfef11fa9d8 /config-model/src/main
parent39ef03dbfea4665dbe06b187d8d9d19d65e6d660 (diff)
use ExpressionFunction.Instance for consistency
Diffstat (limited to 'config-model/src/main')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java36
1 files changed, 27 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
index 83ae6048051..f454d941e31 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
@@ -1,6 +1,8 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.expressiontransforms;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
@@ -11,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.TensorType;
@@ -18,7 +21,9 @@ import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
+import java.io.StringReader;
import java.util.ArrayList;
+import java.util.ArrayDeque;
import java.util.List;
import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
@@ -207,8 +212,21 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
}
}
- private String lengthFunctionName(ReferenceNode arg) {
- return "__token_length@" + Integer.toHexString(arg.hashCode());
+ private static final ExpressionFunction commonLengthFunction = makeLengthFunction();
+ private static ExpressionFunction makeLengthFunction() {
+ String func = "sum(map(input, f(x)(x > 0)))";
+ String name = "__token_length";
+ try (var r = new StringReader(func)) {
+ return new ExpressionFunction(name, List.of("input"), new RankingExpression(name, r));
+ }
+ catch (com.yahoo.searchlib.rankingexpression.parser.ParseException e) {
+ throw new IllegalStateException("unexpected", e);
+ }
+ }
+
+ private ExpressionFunction.Instance lengthFunctionFor(ReferenceNode arg) {
+ var ctx = new SerializationContext();
+ return commonLengthFunction.expand(ctx, List.of(arg), new ArrayDeque<String>());
}
private List<ExpressionNode> createTokenSequence(ReferenceNode feature) {
@@ -229,14 +247,13 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
private void createTokenLengthFunctions(ReferenceNode feature, RankProfileTransformContext context) {
for (int i = 1; i < feature.getArguments().size(); ++i) {
ExpressionNode arg = feature.getArguments().expressions().get(i);
- if ( ! (arg instanceof ReferenceNode)) {
+ if ( ! (arg instanceof ReferenceNode ref)) {
throw new IllegalArgumentException("Invalid argument " + i + " to " + feature.getName() + ": " +
"the argument must be a reference. Got " + arg.toString());
}
- ReferenceNode ref = (ReferenceNode) arg;
- String functionName = lengthFunctionName(ref);
- if ( ! context.rankProfile().getFunctions().containsKey(functionName)) {
- context.rankProfile().addFunction(functionName, List.of(), "sum(map(" + ref + ", f(x)(x > 0)))", false);
+ var f = lengthFunctionFor(ref);
+ if ( ! context.rankProfile().getFunctions().containsKey(f.getName())) {
+ context.rankProfile().addFunction(f.getName(), List.of(), f.getExpressionString(), false);
}
}
}
@@ -282,8 +299,9 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
for (int i = 0; i < iter + 1; ++i) {
if (sequence.get(i) instanceof ConstantNode) {
factors.add(ONE);
- } else if (sequence.get(i) instanceof ReferenceNode) {
- factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i))));
+ } else if (sequence.get(i) instanceof ReferenceNode ref) {
+ var f = lengthFunctionFor(ref);
+ factors.add(new ReferenceNode(f.getName()));
}
if (i >= 1) {
operators.add(Operator.plus);