aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
blob: 94e40f2437a5844591a271861963dfddcba8337d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;

import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;

import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * An evaluator which can be used to evaluate a single function once.
 * <p>
 * Usage: bind the function's arguments with the {@code bind} methods (and optionally set a default for
 * unbound variables with {@code setMissingValue}), then call {@link #evaluate()}. Once {@link #evaluate()}
 * has passed its argument checks this evaluator is marked as used, and further binding or changing the
 * missing value throws an {@link IllegalStateException}.
 *
 * @author bratseth
 */
// This wraps all access to the context and the ranking expression to avoid incorrect usage
public class FunctionEvaluator {

    private final ExpressionFunction function;
    private final LazyArrayContext context;
    /** Set by evaluate() once all arguments have been checked; blocks further binding. */
    private boolean evaluated = false;

    FunctionEvaluator(ExpressionFunction function, LazyArrayContext context) {
        this.function = function;
        this.context = context;
    }

    /**
     * Binds the given variable referred in this expression to the given value.
     *
     * @param name the variable to bind
     * @param value the value this becomes bound to
     * @return this for chaining
     * @throws IllegalStateException if this evaluator has already been used by {@link #evaluate()}
     * @throws IllegalArgumentException if {@code name} is not an argument of the function, or the type of
     *         {@code value} is not assignable to the type declared for that argument
     */
    public FunctionEvaluator bind(String name, Tensor value) {
        requireNotEvaluated("Cannot bind a new value in a used evaluator");
        TensorType requiredType = function.getArgumentType(name);
        if (requiredType == null)
            throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function +
                                               ". Expected arguments: " + describeArguments());
        if ( ! value.type().isAssignableTo(requiredType))
            throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type());
        context.put(name, new TensorValue(value));
        return this;
    }

    /**
     * Binds the given variable referred in this expression to the given value.
     * This is equivalent to <code>bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build())</code>
     *
     * @param name the variable to bind
     * @param value the value this becomes bound to
     * @return this for chaining
     * @throws IllegalStateException if this evaluator has already been used by {@link #evaluate()}
     * @throws IllegalArgumentException if {@code name} is not an argument of the function, or the argument's
     *         declared type does not accept an empty-typed (scalar) tensor
     */
    public FunctionEvaluator bind(String name, double value) {
        return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build());
    }

    /**
     * Binds the given variable referred in this expression to the given value.
     * String values are not yet supported in tensors.
     * <p>
     * Unlike the tensor variant, {@code name} is not validated against the function's declared arguments
     * here; a mismatch is only detected (if at all) by {@link #evaluate()}.
     *
     * @param name the variable to bind
     * @param value the value this becomes bound to
     * @return this for chaining
     * @throws IllegalStateException if this evaluator has already been used by {@link #evaluate()}
     */
    public FunctionEvaluator bind(String name, String value) {
        requireNotEvaluated("Cannot bind a new value in a used evaluator");
        context.put(name, new StringValue(value));
        return this;
    }

    /**
     * Sets the default value to use for variables which are not bound
     *
     * @param value the default value
     * @return this for chaining
     * @throws IllegalStateException if this evaluator has already been used by {@link #evaluate()}
     */
    public FunctionEvaluator setMissingValue(Tensor value) {
        requireNotEvaluated("Cannot change the missing value in a used evaluator");
        context.setMissingValue(value);
        return this;
    }

    /**
     * Sets the default value to use for variables which are not bound.
     * This is equivalent to <code>setMissingValue(Tensor.Builder.of(TensorType.empty).cell(value).build())</code>
     *
     * @param value the default value
     * @return this for chaining
     * @throws IllegalStateException if this evaluator has already been used by {@link #evaluate()}
     */
    public FunctionEvaluator setMissingValue(double value) {
        return setMissingValue(Tensor.Builder.of(TensorType.empty).cell(value).build());
    }

    /**
     * Evaluates the function with the values bound so far.
     * <p>
     * Every declared argument must be bound to a value whose type is assignable to the argument's type.
     * Arguments are checked in name order, so the argument reported on failure is deterministic.
     * When all checks pass, this evaluator is marked as used, so no further values can be bound.
     * Any ONNX models in the context whose output is not yet present are then evaluated, and their
     * results are added to the context before the function body is evaluated.
     *
     * @return the result of evaluating the function body, as a tensor
     * @throws IllegalStateException if an argument is unbound or bound to a value of an incompatible type
     */
    public Tensor evaluate() {
        function.argumentTypes().entrySet().stream()
                .sorted(Map.Entry.comparingByKey())
                .forEach(argument -> checkArgument(argument.getKey(), argument.getValue()));
        evaluated = true;
        evaluateOnnxModels();
        return function.getBody().evaluate(context).asTensor();
    }

    /** Throws an IllegalStateException with the given message if this evaluator has already been used. */
    private void requireNotEvaluated(String message) {
        if (evaluated)
            throw new IllegalStateException(message);
    }

    /** Returns the function's declared arguments as "name: type" pairs sorted by name, for error messages. */
    private String describeArguments() {
        return function.argumentTypes().entrySet().stream()
                       .sorted(Map.Entry.comparingByKey())
                       .map(argument -> argument.getKey() + ": " + argument.getValue())
                       .collect(Collectors.joining(", "));
    }

    /**
     * Verifies that the given argument is bound in the context to a value assignable to the given type.
     *
     * @throws IllegalStateException if the argument is missing or bound to a value of an incompatible type
     */
    private void checkArgument(String name, TensorType type) {
        if (context.isMissing(name))
            throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type);
        if ( ! context.get(name).type().isAssignableTo(type))
            throw new IllegalStateException("Argument '" + name + "' must be bound to a value of type " + type);
    }

    /**
     * Evaluates ONNX models (if not already evaluated) and adds the results back to the context.
     * A model counts as not yet evaluated while its feature still resolves to the context's default value.
     */
    private void evaluateOnnxModels() {
        for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) {
            String onnxFeature = entry.getKey();
            if ( ! context.get(onnxFeature).equals(context.defaultValue())) continue; // already has a value

            OnnxModel onnxModel = entry.getValue();
            Map<String, Tensor> inputs = new HashMap<>();
            // NOTE(review): inputs not bound in the context presumably resolve to the missing value — confirm
            for (Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet())
                inputs.put(input.getKey(), context.get(input.getKey()).asTensor());
            Tensor result = onnxModel.evaluate(inputs, onnxOutputName(onnxFeature));
            context.put(onnxFeature, new TensorValue(result));
        }
    }

    /**
     * Returns the name of the model output to compute for the given ONNX feature.
     * If the feature is written as onnx(modelname).outputname, this is the part after ")."; otherwise it is
     * the name of this function, which is the model's output name in some cases.
     */
    private String onnxOutputName(String onnxFeature) {
        int idx = onnxFeature.indexOf(").");
        if (idx > 0 && idx + 2 < onnxFeature.length())
            return onnxFeature.substring(idx + 2);
        return function.getName();
    }

    /** Returns the function evaluated by this */
    public ExpressionFunction function() { return function; }

    /** Returns the context holding the values bound to this evaluator */
    public LazyArrayContext context() { return context; }

}