summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
blob: d97235d11d251a74c90352f1238a39868f04d9e2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;

import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * An array context supporting function invocations implemented as lazy values.
 *
 * @author bratseth
 */
public final class LazyArrayContext extends Context implements ContextIndex {

    private final ExpressionFunction function;
    private final IndexedBindings indexedBindings;

    private LazyArrayContext(ExpressionFunction function, IndexedBindings indexedBindings) {
        this.function = function;
        this.indexedBindings = indexedBindings.copy(this);
    }

    /** Create a fast lookup, lazy context for a function */
    LazyArrayContext(ExpressionFunction function,
                     Map<FunctionReference, ExpressionFunction> referencedFunctions,
                     List<Constant> constants,
                     List<OnnxModel> onnxModels,
                     Model model) {
        this.function = function;
        this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, onnxModels, this, model);
    }

    /**
     * Sets the value returned by lookups of names which exist in this context but have no value set.
     * Until this is called such lookups return NaN.
     *
     * @param value the tensor to return in place of unset values
     */
    public void setMissingValue(Tensor value) {
        this.indexedBindings.setMissingValue(value);
    }

    /**
     * Puts a value by name. This resolves the name to its index and delegates to
     * {@link #put(int, Value)}, which freezes the value before storing it.
     *
     * @throws IllegalArgumentException if the name is not among the names of this context
     */
    @Override
    public void put(String name, Value value) {
        int index = requireIndexOf(name);
        put(index, value);
    }

    /** Puts a double by index: Same as put(index, DoubleValue.frozen(value)) */
    public final void put(int index, double value) {
        Value frozen = DoubleValue.frozen(value);
        put(index, frozen);
    }

    /**
     * Puts a value by index. The result of calling freeze() on the value is what gets stored.
     *
     * @param index the index of the name to bind, as returned by {@link #getIndex(String)}
     * @param value the value to bind at that index
     */
    public void put(int index, Value value) {
        Value frozen = value.freeze();
        indexedBindings.set(index, frozen);
    }

    /**
     * Returns the type of the value currently bound to the given reference,
     * looked up by the string form of the reference.
     *
     * @throws IllegalArgumentException if the reference is not among the names of this context
     */
    @Override
    public TensorType getType(Reference reference) {
        String name = reference.toString();
        Value bound = get(requireIndexOf(name));
        return bound.type();
    }

    /**
     * Performs a slow lookup by name: The name is first resolved to an index.
     *
     * @throws IllegalArgumentException if the name is not among the names of this context
     */
    @Override
    public Value get(String name) {
        int index = requireIndexOf(name);
        return get(index);
    }

    /** Performs a fast lookup by index. Returns the missing value if no value is set at the index. */
    @Override
    public Value get(int index) {
        Value bound = indexedBindings.get(index);
        return bound;
    }

    /** Returns the value at the given index converted to a double by asDouble() */
    @Override
    public double getDouble(int index) {
        Value bound = get(index);
        return bound.asDouble();
    }

    /**
     * Returns the index of the given name.
     *
     * @throws IllegalArgumentException if the name is not among the names of this context
     */
    @Override
    public int getIndex(String name) {
        int index = requireIndexOf(name);
        return index;
    }

    /** Returns the number of names in this context */
    @Override
    public int size() {
        Set<String> allNames = indexedBindings.names();
        return allNames.size();
    }

    /** Returns all the names which have an index in this context */
    @Override
    public Set<String> names() {
        return indexedBindings.names();
    }

    /** Returns the (immutable) subset of names in this which must be bound when invoking */
    public Set<String> arguments() {
        return indexedBindings.arguments();
    }

    /** Returns the ONNX models that need to be evaluated on this context, keyed by the feature invoking them */
    public Map<String, OnnxModel> onnxModels() {
        return indexedBindings.onnxModels();
    }

    /**
     * Returns the index of the given name.
     *
     * @throws IllegalArgumentException if the name has no index in this context
     */
    private Integer requireIndexOf(String name) {
        Integer found = indexedBindings.indexOf(name);
        if (found != null) return found;
        throw new IllegalArgumentException("Value '" + name + "' can not be bound in " + this);
    }

    /** Returns true if the given name has no index in this context, i.e. it cannot be bound or looked up */
    boolean isMissing(String name) {
        Integer index = indexedBindings.indexOf(name);
        return index == null;
    }

    /** Returns the value which is returned from lookups when no value is set */
    public Value defaultValue() {
        Value whenUnset = indexedBindings.missingValue;
        return whenUnset;
    }

    /**
     * Creates a copy of this context suitable for evaluating against the same ranking expression
     * in a different thread or for re-binding free variables.
     * The copy is for the same function and gets its own copy of the bindings of this.
     */
    LazyArrayContext copy() {
        LazyArrayContext duplicate = new LazyArrayContext(function, indexedBindings);
        return duplicate;
    }

    private static class IndexedBindings {

        /** The mapping from variable name to index */
        private final ImmutableMap<String, Integer> nameToIndex;

        /** The names which need to be bound externally when invoking this (i.e. not constants or invocations) */
        private final ImmutableSet<String> arguments;

        /** The current values set */
        private final Value[] values;

        /** ONNX models indexed by rank feature that calls them */
        private final ImmutableMap<String, OnnxModel> onnxModels;

        /** The object instance which encodes "no value is set". The actual value of this is never used. */
        private static final Value missing = new DoubleValue(Double.NaN).freeze();

        /** The value to return for lookups where no value is set (default: NaN) */
        private Value missingValue = new DoubleValue(Double.NaN).freeze();


        /**
         * Creates bindings directly from already prepared state.
         * The arguments are stored as given, without copying.
         */
        private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
                                Value[] values,
                                ImmutableSet<String> arguments,
                                ImmutableMap<String, OnnxModel> onnxModels) {
            this.nameToIndex = nameToIndex;
            this.arguments = arguments;
            this.onnxModels = onnxModels;
            this.values = values;
        }

        /**
         * Creates indexed bindings for the given expressions.
         * The given expression and functions may be inspected but cannot be stored.
         *
         * Each name found in the function body (and transitively in the functions it references)
         * gets an index in discovery order. Constants and referenced functions which are among
         * those names are then bound; all other names are left unset.
         */
        IndexedBindings(ExpressionFunction function,
                        Map<FunctionReference, ExpressionFunction> referencedFunctions,
                        List<Constant> constants,
                        List<OnnxModel> onnxModels,
                        LazyArrayContext owner,
                        Model model) {
            // 1. Determine the bind targets by walking the expression tree
            Set<String> targets = new LinkedHashSet<>();
            Set<String> externalArguments = new LinkedHashSet<>(); // Targets which need to be bound before invocation
            Map<String, OnnxModel> usedOnnxModels = new HashMap<>();
            ExpressionNode root = function.getBody().getRoot();
            extractBindTargets(root, referencedFunctions, targets, externalArguments, onnxModels, usedOnnxModels);

            this.onnxModels = ImmutableMap.copyOf(usedOnnxModels);
            this.arguments = ImmutableSet.copyOf(externalArguments);

            // 2. Give each target an index, with every value initially marked as not set
            this.values = new Value[targets.size()];
            Arrays.fill(this.values, missing);

            ImmutableMap.Builder<String, Integer> indexes = new ImmutableMap.Builder<>();
            int nextIndex = 0;
            for (String target : targets) {
                indexes.put(target, nextIndex);
                nextIndex++;
            }
            this.nameToIndex = indexes.build();

            // 3. Bind the constants which are referenced from the expressions
            for (Constant constant : constants) {
                Integer slot = nameToIndex.get("constant(" + constant.name() + ")");
                if (slot == null) continue; // Not referenced in this
                values[slot] = new TensorValue(constant.value());
            }

            // 4. Bind the functions which are referenced from the expressions as lazy values
            for (FunctionReference reference : referencedFunctions.keySet()) {
                Integer slot = nameToIndex.get(reference.serialForm());
                if (slot == null) continue; // Not referenced in this
                values[slot] = new LazyValue(reference, owner, model);
            }
        }

        /** Sets the value returned from lookups where no value is set to the given tensor, wrapped and frozen */
        private void setMissingValue(Tensor value) {
            Value frozen = new TensorValue(value).freeze();
            this.missingValue = frozen;
        }

        /**
         * Recursively walks the given expression node and collects the names which need an index.
         *
         * <ul>
         *   <li>Function references are added as bind targets, and the body of the referenced
         *       function is walked in turn.</li>
         *   <li>ONNX features are handled by {@link #extractOnnxTargets}.</li>
         *   <li>Constant references are added as bind targets only.</li>
         *   <li>Any other reference is added both as a bind target and as an argument.</li>
         *   <li>For other composite nodes, each child is walked.</li>
         * </ul>
         *
         * @param node the expression node to inspect
         * @param functions the functions which may be referenced from the node
         * @param bindTargets receives all the names found
         * @param arguments receives the names which must be bound externally
         * @param onnxModels the ONNX models available
         * @param onnxModelsInUse receives the ONNX models found to be used, by the feature invoking them
         * @throws IllegalArgumentException if a function reference cannot be parsed, or refers to
         *         a function which is not present in the given functions
         */
        private void extractBindTargets(ExpressionNode node,
                                        Map<FunctionReference, ExpressionFunction> functions,
                                        Set<String> bindTargets,
                                        Set<String> arguments,
                                        List<OnnxModel> onnxModels,
                                        Map<String, OnnxModel> onnxModelsInUse) {
            if (isFunctionReference(node)) {
                // Fail with a message naming the offending node rather than a bare NoSuchElementException
                FunctionReference reference = FunctionReference.fromSerial(node.toString())
                        .orElseThrow(() -> new IllegalArgumentException("Could not parse '" + node +
                                                                        "' as a function reference"));
                bindTargets.add(reference.serialForm());

                // Fail with a message naming the missing function rather than a bare NullPointerException
                ExpressionFunction referenced = functions.get(reference);
                if (referenced == null)
                    throw new IllegalArgumentException("'" + node + "' refers to a function which is not available");

                ExpressionNode functionNode = referenced.getBody().getRoot();
                extractBindTargets(functionNode, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
            }
            else if (isOnnx(node)) {
                extractOnnxTargets(node, bindTargets, arguments, onnxModels, onnxModelsInUse);
            }
            else if (isConstant(node)) {
                bindTargets.add(node.toString());
            }
            else if (node instanceof ReferenceNode) {
                bindTargets.add(node.toString());
                arguments.add(node.toString());
            }
            else if (node instanceof CompositeNode) {
                CompositeNode cNode = (CompositeNode)node;
                for (ExpressionNode child : cNode.children())
                    extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
            }
        }

        /**
         * Handles a node which invokes an ONNX model, e.g. onnxModel(name).
         * For every available model whose name equals the first argument of the node, the feature
         * itself (the string form of the node) is added as a bind target, each model input is added
         * both as a bind target and as an argument, and the model is registered as in use under
         * the feature. Nothing is added if the node has no usable first argument or no model matches.
         */
        private void extractOnnxTargets(ExpressionNode node,
                                        Set<String> bindTargets,
                                        Set<String> arguments,
                                        List<OnnxModel> onnxModels,
                                        Map<String, OnnxModel> onnxModelsInUse) {
            Optional<String> modelName = getArgument(node);
            if ( ! modelName.isPresent()) return;

            String wantedName = modelName.get();
            for (OnnxModel candidate : onnxModels) {
                if ( ! candidate.name().equals(wantedName)) continue;

                String onnxFeature = node.toString();
                bindTargets.add(onnxFeature);

                // Load the model (if not already loaded) to extract inputs
                candidate.load();

                for (String input : candidate.inputs().keySet()) {
                    bindTargets.add(input);
                    arguments.add(input);
                }
                onnxModelsInUse.put(onnxFeature, candidate);
            }
        }

        /**
         * Returns the first argument of the given node as a string: The string form with any
         * surrounding quotes stripped if it is a constant, or the name if it is a reference.
         * Returns empty if the node is not a reference, has no arguments, or its first argument
         * is neither a constant nor a reference.
         */
        private Optional<String> getArgument(ExpressionNode node) {
            if ( ! (node instanceof ReferenceNode)) return Optional.empty();

            ReferenceNode reference = (ReferenceNode) node;
            if (reference.getArguments().size() <= 0) return Optional.empty();

            ExpressionNode first = reference.getArguments().expressions().get(0);
            if (first instanceof ConstantNode)
                return Optional.of(stripQuotes(first.toString()));
            if (first instanceof ReferenceNode)
                return Optional.of(((ReferenceNode) first).getName());
            return Optional.empty();
        }

        /**
         * Removes one pair of matching surrounding quotes, double or single, from the given string.
         * Strings which do not both start and end with the same quote character are returned
         * unchanged. So are strings shorter than two characters, as they cannot hold a pair of quotes.
         *
         * @param s the string to strip quotes from
         * @return the string without its surrounding quotes, or the string itself if it is not quoted
         */
        public static String stripQuotes(String s) {
            // Guard: The empty string has no first character, and a lone quote is not a quoted string
            if (s.length() < 2) return s;

            char first = s.charAt(0);
            char last = s.charAt(s.length() - 1);
            boolean quoted = first == last && (first == '"' || first == '\'');
            return quoted ? s.substring(1, s.length() - 1) : s;
        }

        /** Returns whether the node is a reference named "rankingExpression" having exactly one argument */
        private boolean isFunctionReference(ExpressionNode node) {
            if (node instanceof ReferenceNode) {
                ReferenceNode candidate = (ReferenceNode) node;
                return candidate.getName().equals("rankingExpression") && candidate.getArguments().size() == 1;
            }
            return false;
        }

        /** Returns whether the node is a reference named "onnx" or "onnxModel" */
        private boolean isOnnx(ExpressionNode node) {
            if (node instanceof ReferenceNode) {
                String name = ((ReferenceNode) node).getName();
                return name.equals("onnx") || name.equals("onnxModel");
            }
            return false;
        }

        /** Returns whether the node is a reference named "constant" having exactly one argument */
        private boolean isConstant(ExpressionNode node) {
            if (node instanceof ReferenceNode) {
                ReferenceNode candidate = (ReferenceNode) node;
                return candidate.getName().equals("constant") && candidate.getArguments().size() == 1;
            }
            return false;
        }

        /** Returns the value set at the given index, or the missing value if none is set there */
        Value get(int index) {
            Value bound = values[index];
            // Identity comparison: Only the marker instance itself means "not set"
            if (bound == missing) return missingValue;
            return bound;
        }

        /** Stores the given value at the given index, replacing whatever was there */
        void set(int index, Value value) {
            this.values[index] = value;
        }

        /** Returns all the names which have an index in this */
        Set<String> names() {
            return nameToIndex.keySet();
        }

        /** Returns the names which need to be bound externally */
        Set<String> arguments() {
            return arguments;
        }

        /** Returns the index of the given name, or null if it has no index in this */
        Integer indexOf(String name) {
            return nameToIndex.get(name);
        }

        /** Returns the ONNX models in use, by the feature invoking them */
        Map<String, OnnxModel> onnxModels() {
            return onnxModels;
        }

        /**
         * Returns a copy of these bindings for use in the given context.
         * Lazy values are replaced by copies made for the given context, all other values are
         * shared with this. The name index, arguments and ONNX models are shared as they are immutable.
         * The value to return when no value is set is carried over to the copy, so that the copy
         * returns the same as this on all lookups until it is modified.
         *
         * @param context the context the copy will belong to
         */
        IndexedBindings copy(Context context) {
            Value[] valueCopy = new Value[values.length];
            for (int i = 0; i < values.length; i++)
                valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue) values[i]).copyFor(context) : values[i];

            IndexedBindings copied = new IndexedBindings(nameToIndex, valueCopy, arguments, onnxModels);
            // Without this the copy would silently fall back to NaN for unset values
            copied.missingValue = missingValue;
            return copied;
        }

    }

}