aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/expressiontransforms/NormalizerFunctionExpander.java
blob: a8fee966656bf4da904e790eb2b5f31f4a014949 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.expressiontransforms;

import com.yahoo.schema.FeatureNames;
import com.yahoo.schema.RankProfile.RankFeatureNormalizer;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.functions.Generate;

import java.io.StringReader;
import java.util.HashSet;
import java.util.Set;
import java.util.logging.Logger;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import java.util.ArrayList;

/**
 * Recognizes the ranking pseudo-functions {@code normalize_linear}, {@code reciprocal_rank} and
 * {@code reciprocal_rank_fusion}, registers a global-phase feature normalizer on the rank profile
 * for each use, and replaces the pseudo-function call by a reference to that normalizer.
 *
 * <p>A name is only treated as a pseudo-function when the reference has no output and the rank
 * profile does not declare a function with the same name.
 *
 * @author arnej
 */
public class NormalizerFunctionExpander extends ExpressionTransformer<RankProfileTransformContext> {

    public final static String NORMALIZE_LINEAR = "normalize_linear";
    public final static String RECIPROCAL_RANK = "reciprocal_rank";
    public final static String RECIPROCAL_RANK_FUSION = "reciprocal_rank_fusion";

    /** The k passed to the reciprocal-rank normalizer when no second argument is given. */
    private static final double DEFAULT_RRANK_K = 60.0;

    /**
     * Transforms an expression tree, expanding pseudo-function references.
     *
     * <p>A pseudo-function reference is replaced by its expansion, which is already fully transformed
     * and therefore returned as-is. Any other composite node (including references that were left
     * unchanged) has its children transformed.
     */
    @Override
    public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
        if (node instanceof ReferenceNode reference) {
            ExpressionNode expanded = transformReference(reference, context);
            // transformReference returns the very same instance unless it expanded a pseudo-function;
            // an expansion is complete, so it must not be traversed a second time
            if (expanded != reference) {
                return expanded;
            }
        }
        if (node instanceof CompositeNode composite) {
            return transformChildren(composite, context);
        }
        return node;
    }

    /**
     * Expands the given reference if it names a pseudo-function.
     *
     * @return the fully transformed expansion, or {@code node} itself (same instance) if the reference
     *         has an output, names a function declared in the rank profile, or is not a pseudo-function
     * @throws IllegalArgumentException if a pseudo-function is called with invalid arguments
     */
    private ExpressionNode transformReference(ReferenceNode node, RankProfileTransformContext context) {
        Reference ref = node.reference();
        if (ref.output() != null) {
            return node;
        }
        String name = ref.name();
        if (context.rankProfile().getFunctions().get(name) != null) {
            // never transform declared functions: a user-declared function shadows a pseudo-function
            return node;
        }
        return switch (name) {
            // the fusion is rewritten into reciprocal_rank calls, which must themselves be expanded
            case RECIPROCAL_RANK_FUSION -> transform(expandRRF(ref), context);
            case NORMALIZE_LINEAR -> transformNormLin(ref, context);
            case RECIPROCAL_RANK -> transformRRank(ref, context);
            default -> node;
        };
    }

    /**
     * Rewrites {@code reciprocal_rank_fusion(a, b, ...)} into
     * {@code reciprocal_rank(a) + reciprocal_rank(b) + ...}.
     * The result still contains pseudo-function references and must be transformed further.
     *
     * @throws IllegalArgumentException if fewer than 2 arguments are given
     */
    private ExpressionNode expandRRF(Reference ref) {
        var args = ref.arguments();
        if (args.size() < 2) {
            throw new IllegalArgumentException("must have at least 2 arguments: " + ref);
        }
        List<ExpressionNode> children = new ArrayList<>();
        List<Operator> operators = new ArrayList<>();
        for (var arg : args.expressions()) {
            if (! children.isEmpty()) {
                operators.add(Operator.plus);
            }
            children.add(new ReferenceNode(RECIPROCAL_RANK, List.of(arg), null));
        }
        return new OperationNode(children, operators);
    }

    /**
     * Expands {@code normalize_linear(feature)} into a reference to a newly registered linear normalizer.
     *
     * @throws IllegalArgumentException unless there is exactly one argument and it is a reference
     */
    private ExpressionNode transformNormLin(Reference ref, RankProfileTransformContext context) {
        var args = ref.arguments();
        if (args.size() != 1) {
            throw new IllegalArgumentException("must have exactly 1 argument: " + ref);
        }
        var input = args.expressions().get(0);
        if (! (input instanceof ReferenceNode inputRefNode)) {
            throw new IllegalArgumentException("the first argument must be a simple feature: " + ref + " => " + input.getClass());
        }
        return registerAndRefer(RankFeatureNormalizer.linear(ref, inputRefNode.reference()), context);
    }

    /**
     * Expands {@code reciprocal_rank(feature)} or {@code reciprocal_rank(feature, k)} into a reference
     * to a newly registered reciprocal-rank normalizer. When k is omitted, 60.0 is used.
     *
     * @throws IllegalArgumentException unless there are 1 or 2 arguments, the first is a reference
     *         and the optional second is a constant
     */
    private ExpressionNode transformRRank(Reference ref, RankProfileTransformContext context) {
        var args = ref.arguments();
        if (args.size() < 1 || args.size() > 2) {
            throw new IllegalArgumentException("must have 1 or 2 arguments: " + ref);
        }
        double k = DEFAULT_RRANK_K;
        if (args.size() == 2) {
            var kArg = args.expressions().get(1);
            if (kArg instanceof ConstantNode kNode) {
                k = kNode.getValue().asDouble();
            } else {
                throw new IllegalArgumentException("the second argument (k) must be a constant in: " + ref);
            }
        }
        var input = args.expressions().get(0);
        if (! (input instanceof ReferenceNode inputRefNode)) {
            throw new IllegalArgumentException("the first argument must be a simple feature: " + ref);
        }
        return registerAndRefer(RankFeatureNormalizer.rrank(ref, inputRefNode.reference(), k), context);
    }

    /** Adds the normalizer to the rank profile and returns a reference to it by its name. */
    private static ExpressionNode registerAndRefer(RankFeatureNormalizer normalizer, RankProfileTransformContext context) {
        context.rankProfile().addFeatureNormalizer(normalizer);
        return new ReferenceNode(Reference.fromIdentifier(normalizer.name()));
    }
}