aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java134
1 files changed, 134 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java
new file mode 100644
index 00000000000..a8fee966656
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java
@@ -0,0 +1,134 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema.expressiontransforms;
+
+import com.yahoo.schema.FeatureNames;
+import com.yahoo.schema.RankProfile.RankFeatureNormalizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.IfNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.functions.Generate;
+
+import java.io.StringReader;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.logging.Logger;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.Iterator;
+import java.util.List;
+import java.util.ArrayList;
+
+/**
+ * Recognizes pseudo-functions and creates global-phase normalizers
+ * @author arnej
+ */
+public class NormalizerFunctionExpander extends ExpressionTransformer<RankProfileTransformContext> {
+
+ public final static String NORMALIZE_LINEAR = "normalize_linear";
+ public final static String RECIPROCAL_RANK = "reciprocal_rank";
+ public final static String RECIPROCAL_RANK_FUSION = "reciprocal_rank_fusion";
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ if (node instanceof ReferenceNode r) {
+ node = transformReference(r, context);
+ }
+ if (node instanceof CompositeNode composite) {
+ node = transformChildren(composite, context);
+ }
+ return node;
+ }
+
+ private ExpressionNode transformReference(ReferenceNode node, RankProfileTransformContext context) {
+ Reference ref = node.reference();
+ String name = ref.name();
+ if (ref.output() != null) {
+ return node;
+ }
+ var f = context.rankProfile().getFunctions().get(name);
+ if (f != null) {
+ // never transform declared functions
+ return node;
+ }
+ return switch(name) {
+ case RECIPROCAL_RANK_FUSION -> transform(expandRRF(ref), context);
+ case NORMALIZE_LINEAR -> transformNormLin(ref, context);
+ case RECIPROCAL_RANK -> transformRRank(ref, context);
+ default -> node;
+ };
+ }
+
+ private ExpressionNode expandRRF(Reference ref) {
+ var args = ref.arguments();
+ if (args.size() < 2) {
+ throw new IllegalArgumentException("must have at least 2 arguments: " + ref);
+ }
+ List<ExpressionNode> children = new ArrayList<>();
+ List<Operator> operators = new ArrayList<>();
+ for (var arg : args.expressions()) {
+ if (! children.isEmpty()) operators.add(Operator.plus);
+ children.add(new ReferenceNode(RECIPROCAL_RANK, List.of(arg), null));
+ }
+ // must be further transformed (see above)
+ return new OperationNode(children, operators);
+ }
+
+ private ExpressionNode transformNormLin(Reference ref, RankProfileTransformContext context) {
+ var args = ref.arguments();
+ if (args.size() != 1) {
+ throw new IllegalArgumentException("must have exactly 1 argument: " + ref);
+ }
+ var input = args.expressions().get(0);
+ if (input instanceof ReferenceNode inputRefNode) {
+ var inputRef = inputRefNode.reference();
+ RankFeatureNormalizer normalizer = RankFeatureNormalizer.linear(ref, inputRef);
+ context.rankProfile().addFeatureNormalizer(normalizer);
+ var newRef = Reference.fromIdentifier(normalizer.name());
+ return new ReferenceNode(newRef);
+ } else {
+ throw new IllegalArgumentException("the first argument must be a simple feature: " + ref + " => " + input.getClass());
+ }
+ }
+
+ private ExpressionNode transformRRank(Reference ref, RankProfileTransformContext context) {
+ var args = ref.arguments();
+ if (args.size() < 1 || args.size() > 2) {
+ throw new IllegalArgumentException("must have 1 or 2 arguments: " + ref);
+ }
+ double k = 60.0;
+ if (args.size() == 2) {
+ var kArg = args.expressions().get(1);
+ if (kArg instanceof ConstantNode kNode) {
+ k = kNode.getValue().asDouble();
+ } else {
+ throw new IllegalArgumentException("the second argument (k) must be a constant in: " + ref);
+ }
+ }
+ var input = args.expressions().get(0);
+ if (input instanceof ReferenceNode inputRefNode) {
+ var inputRef = inputRefNode.reference();
+ RankFeatureNormalizer normalizer = RankFeatureNormalizer.rrank(ref, inputRef, k);
+ context.rankProfile().addFeatureNormalizer(normalizer);
+ var newRef = Reference.fromIdentifier(normalizer.name());
+ return new ReferenceNode(newRef);
+ } else {
+ throw new IllegalArgumentException("the first argument must be a simple feature: " + ref);
+ }
+ }
+}