aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
blob: 85452d16a77c37bb47659561b70bebc9344d4df0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.TensorFunction;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

/**
 * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
 *
 * @author bratseth
 */
class OperationMapper {

    /*
       A note on conversion from implicitly numbered to explicitly named dimensions:
       Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
       'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
       comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
       around dimension renaming operations which mirror those built into the TF operation definitions.

       To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
       dimension is named 'd0', the second outermost 'd1' and so on. Arguments are renamed to match the operation
       and the result is then renamed again (if necessary) to recover this convention across a full nested
       computation.

       This requires us to track tensor types throughout the conversion.
     */

    private final TensorConverter tensorConverter = new TensorConverter();

    /**
     * Maps a binary elementwise TF operation to a Vespa join.
     * When the arguments have different (non-zero) ranks, the lower-rank argument is renamed so that its
     * dimensions line up with the last dimensions of the higher-rank one (see the broadcasting note below).
     *
     * @param arguments exactly two typed arguments
     * @param doubleFunction the cell-level operation to apply
     * @return the join, typed as the higher-rank argument (or the first argument when ranks are equal)
     * @throws IllegalArgumentException if there are not exactly two arguments
     */
    TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
        ensureArguments(2, arguments, "join");
        TypedTensorFunction a = arguments.get(0);
        TypedTensorFunction b = arguments.get(1);

        if (a.type().rank() == 0 && b.type().rank() > 0) {
            return new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction));
        }
        if (b.type().rank() == 0 && a.type().rank() > 0) {
            return new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction));
        }
        if (a.type().rank() == b.type().rank()) {
            return new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction));
        }

        // Well now we have entered the wonderful world of "broadcasting"
        // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
        // I'm not able to extract from that any unambiguous specification of which dimensions
        // should be "stretched" when the tensors do not have the same number of dimensions.
        // From trying this with TensorFlow it appears that the second tensor is matched to the
        // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
        // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).

        if (a.type().rank() > b.type().rank()) {
            TensorFunction renameFunction = renameForBroadcast(a, b);
            return new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction));
        }
        TensorFunction renameFunction = renameForBroadcast(b, a);
        return new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction));
    }

    /**
     * Renames the dimensions of the lower-rank tensor {@code b} to the last dimension names of the
     * higher-rank tensor {@code a} (d(rankA - rankB) ... d(rankA - 1)).
     */
    private TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) {
        List<String> renameFrom = new ArrayList<>();
        List<String> renameTo = new ArrayList<>();
        int sizeDifference = a.type().rank() - b.type().rank();
        for (int i = 0; i < b.type().rank(); i++) {
            renameFrom.add(b.type().dimensions().get(i).name());
            renameTo.add("d" + (sizeDifference + i));
        }
        return new Rename(b.function(), renameFrom, renameTo);
    }

    /** Maps a unary elementwise TF operation to a Vespa map over a single argument. */
    TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
        ensureArguments(1, arguments, "apply");
        TypedTensorFunction a = arguments.get(0);

        TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
        com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
        return new TypedTensorFunction(resultType, function);
    }

    /**
     * Maps a TF placeholder to a Vespa variable with the same name, typed by the argument
     * types already registered in {@code result}.
     *
     * @throws IllegalArgumentException if {@code result} has no argument with the node's name
     */
    TypedTensorFunction placeholder(NodeDef tfNode, TensorFlowModel result) {
        String name = tfNode.getName();
        TensorType type = result.arguments().get(name);
        if (type == null)
            throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
                                               "', but there is no such placeholder");
        // Included literally in the expression and so must be produced by a separate macro in the rank profile
        return new TypedTensorFunction(type, new VariableTensor(name));
    }

    /**
     * Maps a TF PlaceholderWithDefault node: the default value (the node's first input) is fetched
     * from the model and registered both as a constant and as a macro returning that constant.
     */
    TypedTensorFunction placeholderWithDefault(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) {
        String name = tfNode.getInput(0);
        Tensor defaultValue = getConstantTensor(model, name);
        result.constant(name, defaultValue);
        result.macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")")));
        // The default value will be provided by the macro. Users can override macro to change value.
        return new TypedTensorFunction(defaultValue.type(), new VariableTensor(name));
    }

    /**
     * Imports a TF Const node as a Vespa constant.
     *
     * @throws IllegalArgumentException if the node has any inputs or lacks an output shape
     */
    TypedTensorFunction constant(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) {
        String name = tfNode.getName();
        if (tfNode.getInputList().size() != 0) {
            throw new IllegalArgumentException("A constant node must have zero inputs but '" + name + "' has " +
                    tfNode.getInputList().size());
        }
        return importConstantTensor(tfNode, model, result, name);
    }

    /**
     * Imports an Identity node. Only the "variable read" form (node name ending in "/read", one input)
     * is supported, in which case the input variable is imported as a constant.
     *
     * @throws IllegalArgumentException for any other kind of identity node
     */
    TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) {
        if ( ! tfNode.getName().endsWith("/read"))
            throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identity " +
                                               "nodes are only supported when reading variables");
        if (tfNode.getInputList().size() != 1)
            throw new IllegalArgumentException("A Variable/read node must have one input but '" +
                                                tfNode.getName() + "' has " + tfNode.getInputList().size());

        String name = tfNode.getInput(0);
        return importConstantTensor(tfNode, model, result, name);
    }

    /**
     * Fetches the tensor {@code name} from the model, registers it as a constant in {@code result}
     * and returns a reference to it.
     *
     * @throws IllegalArgumentException if {@code tfNode} has no "_output_shapes" attribute
     */
    private TypedTensorFunction importConstantTensor(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result, String name) {
        AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
        if (shapes == null)
            throw new IllegalArgumentException("'" + name + "' is missing a tensor shape");
        Tensor constant = getConstantTensor(model, name);
        result.constant(name, constant);
        return new TypedTensorFunction(constant.type(),
                new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(\"" + name + "\")")));
    }

    /**
     * Runs the model session to fetch the value of {@code name} and converts it to a Vespa tensor.
     *
     * @throws IllegalStateException if the fetch does not return exactly one tensor
     */
    private Tensor getConstantTensor(SavedModelBundle model, String name) {
        Session.Runner fetched = model.session().runner().fetch(name);
        List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
        if (importedTensors.size() != 1)
            throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " +
                    importedTensors.size());
        return tensorConverter.toVespaTensor(importedTensors.get(0));
    }

    /**
     * Maps TF MatMul to a Vespa matmul over "d1", after renaming the second argument so that its
     * first dimension is joined with the first argument's second dimension.
     *
     * @throws IllegalArgumentException unless both arguments have the same rank of at least 2
     */
    TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
        ensureArguments(2, arguments, "matmul");
        TypedTensorFunction a = arguments.get(0);
        TypedTensorFunction b = arguments.get(1);
        if (a.type().rank() < 2 || b.type().rank() < 2)
            throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
        if (a.type().rank() != b.type().rank())
            throw new IllegalArgumentException("Tensors in matmul must have the same rank");

        String afterLastDim = "d" + (a.type().rank() + 1);
        // Let the first dimension of the second tensor be the same as the second dimension of the first
        // and the second dimension of the second argument be not present in the first argument, while leaving the
        // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.

        // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly

        Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"),
                                     ImmutableList.of("d1", afterLastDim));
        Matmul matmul = new Matmul(a.function(), renamedB, "d1");
        return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"),
                                       new Rename(matmul, afterLastDim, "d1"));
    }

    /**
     * Maps TF Mean to a Vespa average reduction. The reduction indices (second input) must be a
     * constant in the model; negative indices count from the last dimension.
     * If the node's "keep_dims" attribute is true, reduced dimensions are kept with size 1.
     *
     * @throws IllegalArgumentException if a reduction index is out of range
     */
    TypedTensorFunction mean(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) {
        ensureArguments(2, arguments, "mean");
        Tensor reductionIndices = getConstantTensor(model, tfNode.getInput(1));

        TensorFunction inputFunction = arguments.get(0).function();
        TensorType inputType = arguments.get(0).type();

        List<String> reduceDimensions = new ArrayList<>();
        for (Iterator<Tensor.Cell> cellIterator = reductionIndices.cellIterator(); cellIterator.hasNext();) {
            Tensor.Cell cell = cellIterator.next();
            // A negative index counts from the end: -1 is the last dimension (index rank - 1)
            int dimensionIndex = toNonNegativeIndex(cell.getValue().intValue(), inputType.rank(), "mean");
            reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
        }

        TensorType outputType = Reduce.outputType(inputType, reduceDimensions);
        TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);

        if (shouldKeepDimensions(tfNode)) {
            return reshape(outputFunction, outputType, keepDimensionType(inputType, reduceDimensions));
        }
        return checkNamingConvention(outputType, outputFunction);
    }

    /** Returns whether the node's "keep_dims" attribute is present and true. */
    private boolean shouldKeepDimensions(NodeDef tfNode) {
        AttrValue keepDimsAttr = tfNode.getAttrMap().get("keep_dims");
        return keepDimsAttr != null && keepDimsAttr.getB();
    }

    /**
     * Returns the type of the input after reduction with "keep_dims": every dimension is kept, but
     * the reduced ones get size 1.
     * Temporary dimension names (temp_0, temp_1, ...) are used, as in expandDims and reshape, so that they
     * do not overlap with the names of the reduced tensor when the reshape transformation tensor is built.
     * The reshape renames them back to the d0, d1, ... convention.
     */
    private TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) {
        TensorType.Builder builder = new TensorType.Builder();
        List<TensorType.Dimension> dimensions = inputType.dimensions();
        for (int i = 0; i < dimensions.size(); ++i) {
            TensorType.Dimension dimension = dimensions.get(i);
            long size = reduceDimensions.contains(dimension.name()) ? 1L : dimensionSize(dimension);
            builder.indexed(String.format("temp_%d", i), size);
        }
        return builder.build();
    }

    /**
     * Returns the function unchanged if its dimensions are already named d0, d1, ... in order.
     * Otherwise it is wrapped in a rename that restores that convention.
     */
    private TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) {
        for (int i = 0; i < type.dimensions().size(); ++i) {
            String correct = String.format("d%d", i);
            String current = type.dimensions().get(i).name();
            if (!current.equals(correct)) {
                return fixNamingConvention(type, function);
            }
        }
        return new TypedTensorFunction(type, function);
    }

    /** Renames every dimension at position i that is not named "d&lt;i&gt;" to that name. */
    private TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) {
        TensorType.Builder correctType = new TensorType.Builder();
        List<String> from = new ArrayList<>();
        List<String> to = new ArrayList<>();
        for (int i = 0; i < type.dimensions().size(); ++i) {
            String correct = String.format("d%d", i);
            String current = type.dimensions().get(i).name();
            if (!current.equals(correct)) {
                from.add(current);
                to.add(correct);
            }
            correctType.indexed(correct, dimensionSize(type.dimensions().get(i)));
        }
        if (from.size() > 0) {
            function = new Rename(function, from, to);
            type = correctType.build();
        }
        return new TypedTensorFunction(type, function);
    }

    /** Returns the single argument unchanged. */
    TypedTensorFunction noOp(List<TypedTensorFunction> arguments) {
        ensureArguments(1, arguments, "noOp");
        return arguments.get(0);
    }

    /**
     * Maps TF ExpandDims: inserts a dimension of size 1 at the (constant, scalar) axis given as the
     * second input. Valid axes are [-rank - 1, rank]; a negative axis counts from the end, so -1 appends
     * a dimension after the last one.
     *
     * @throws IllegalArgumentException if the axis is not a scalar or is out of range
     */
    TypedTensorFunction expandDims(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) {
        ensureArguments(2, arguments, "expandDims");
        Tensor axis = getConstantTensor(model, tfNode.getInput(1));
        if (axis.type().rank() != 0) {
            throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar");
        }

        TensorFunction inputFunction = arguments.get(0).function();
        TensorType inputType = arguments.get(0).type();

        // The output has rank + 1 positions, so that is the bound for negative axis resolution
        int dimensionToInsert = toNonNegativeIndex((long) axis.asDouble(), inputType.rank() + 1, "expandDims");

        TensorType.Builder outputTypeBuilder = new TensorType.Builder();
        int dimensionIndex = 0;
        for (int i = 0; i < inputType.dimensions().size() + 1; ++i) {
            String name = String.format("temp_%d", i);
            long size;
            if (i == dimensionToInsert) {
                size = 1L;
            } else {
                size = dimensionSize(inputType.dimensions().get(dimensionIndex));
                dimensionIndex++;
            }
            outputTypeBuilder.indexed(name, size);
        }

        return reshape(inputFunction, inputType, outputTypeBuilder.build());
    }

    /**
     * Maps TF Reshape, where the target shape (second input) must be a constant in the model.
     * At most one entry of the shape may be -1; its size is inferred so that the total size is preserved.
     *
     * @throws IllegalArgumentException if a -1 entry cannot be inferred or the sizes do not match
     */
    TypedTensorFunction reshape(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) {
        ensureArguments(2, arguments, "reshape");
        Tensor shape = getConstantTensor(model, tfNode.getInput(1));

        TensorFunction inputFunction = arguments.get(0).function();
        TensorType inputType = arguments.get(0).type();

        TensorType.Builder outputTypeBuilder = new TensorType.Builder();
        int dimensionIndex = 0;
        for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
            Tensor.Cell cell = cellIterator.next();
            int size = cell.getValue().intValue();
            if (size < 0) {
                // With exactly one -1 in the shape, the product of all entries is -(product of known sizes),
                // and the unknown size is the total input size divided by that product.
                long knownSize = -1L * (long) shape.reduce(Reduce.Aggregator.prod).asDouble();
                if (knownSize <= 0)
                    throw new IllegalArgumentException("Cannot infer the size of the -1 dimension when reshaping '" +
                                                       tfNode.getName() + "'");
                size = (int) (tensorSize(inputType) / knownSize);
            }
            outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size);
            dimensionIndex++;
        }
        return reshape(inputFunction, inputType, outputTypeBuilder.build());
    }

    /**
     * Reshapes {@code inputFunction} from {@code inputType} to {@code outputType}, and renames the result
     * to the d0, d1, ... convention. The two types must not share dimension names.
     *
     * @throws IllegalArgumentException if the two types have different total sizes
     */
    private TypedTensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
        if (!tensorSize(inputType).equals(tensorSize(outputType))) {
            throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
        }

        // Conceptually, reshaping consists of unrolling a tensor to an array using the dimension order,
        // then using the dimension order of the new shape to roll back into a tensor.
        // Here we create a transformation tensor that is multiplied with the from tensor to map into
        // the new shape. Callers must use dimension names in the new type that do not overlap the old ones
        // (e.g. temp_0, temp_1, ...); the result is renamed back to the d0, d1, ... convention below.

        ExpressionNode unrollFrom = unrollTensorExpression(inputType);
        ExpressionNode unrollTo = unrollTensorExpression(outputType);
        ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);

        TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
        Generate transformTensor = new Generate(transformationType,
                new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());

        TensorFunction outputFunction = new Reduce(
                new Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
                Reduce.Aggregator.sum,
                inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
        return checkNamingConvention(outputType, outputFunction);
    }

    /**
     * Builds the row-major flat index expression of a type, e.g. {@code 6*a + 2*b + c} for dimensions
     * a[?], b[3], c[2]. A rank-0 type unrolls to the constant 0.
     */
    private ExpressionNode unrollTensorExpression(TensorType type) {
        if (type.rank() == 0) {
            return new ConstantNode(DoubleValue.zero);
        }
        List<ExpressionNode> children = new ArrayList<>();
        List<ArithmeticOperator> operators = new ArrayList<>();
        int size = 1;
        for (int i = type.dimensions().size() - 1; i >= 0; --i) {
            TensorType.Dimension dimension = type.dimensions().get(i);
            children.add(0, new ReferenceNode(dimension.name()));
            if (size > 1) {
                operators.add(0, ArithmeticOperator.MULTIPLY);
                children.add(0, new ConstantNode(new DoubleValue(size)));
            }
            size *= dimensionSize(dimension);
            if (i > 0) {
                operators.add(0, ArithmeticOperator.PLUS);
            }
        }
        return new ArithmeticNode(children, operators);
    }

    /**
     * Maps TF Select, where the condition (first input) must be a constant in the model.
     * A scalar condition, or one with a single cell of a size-1 vector, selects one argument directly.
     * Otherwise the result is computed cellwise as {@code x * cond + y * (1 - cond)}.
     *
     * @throws IllegalArgumentException if x and y differ in rank or total size
     */
    TypedTensorFunction select(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result, List<TypedTensorFunction> arguments) {
        ensureArguments(3, arguments, "select");
        Tensor condition = getConstantTensor(model, tfNode.getInput(0));

        TypedTensorFunction x = arguments.get(1);
        TypedTensorFunction y = arguments.get(2);
        if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) {
            throw new IllegalArgumentException("'Select': input tensors must have the same shape");
        }

        if (condition.type().rank() == 0) {
            return (int)condition.asDouble() == 0 ? y : x;
        }
        if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
            return condition.cellIterator().next().getValue().intValue() == 0 ? y : x;
        }

        // The task is to select cells from 'x' or 'y' based on 'condition'.
        // If 'condition' is 0 (false), select from 'y', if 1 (true) select
        // from 'x'. We do this by individually joining 'x' and 'y' with
        // 'condition', and then joining the resulting two tensors.

        TypedTensorFunction conditionFunction = importConstantTensor(tfNode, model, result, tfNode.getInput(0));
        TensorFunction xCond = new Join(x.function(), conditionFunction.function(), ScalarFunctions.multiply());
        TensorFunction yCond = new Join(y.function(), conditionFunction.function(), new DoubleBinaryOperator() {
            @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
            @Override public String toString() { return "f(a,b)(a * (1-b))"; }
        });
        TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add());
        return new TypedTensorFunction(x.type(), outputFunction);
    }

    /** Maps TF Softmax to a Vespa softmax over the last dimension of the argument. */
    TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
        ensureArguments(1, arguments, "softmax");
        TypedTensorFunction a = arguments.get(0);
        // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
        String dimension = "d" + (a.type().rank() - 1);
        Softmax softmax = new Softmax(a.function(), dimension);
        return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
    }

    /**
     * Maps TF Squeeze: removes size-1 dimensions. These are either all of them, or those listed in the
     * "squeeze_dims" attribute, where negative indices count from the last dimension. Listed dimensions
     * whose size is not 1 are ignored.
     *
     * @throws IllegalArgumentException if a listed index is out of range
     */
    TypedTensorFunction squeeze(NodeDef tfNode, List<TypedTensorFunction> arguments) {
        ensureArguments(1, arguments, "squeeze");

        TensorFunction inputFunction = arguments.get(0).function();
        TensorType inputType = arguments.get(0).type();
        List<String> squeezeDimensions;

        AttrValue squeezeDimsAttr = tfNode.getAttrMap().get("squeeze_dims");
        if (squeezeDimsAttr == null) {
            squeezeDimensions = inputType.dimensions().stream().
                    filter(dim -> dimensionSize(dim) == 1).
                    map(TensorType.Dimension::name).
                    collect(Collectors.toList());
        } else {
            squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
                    map(i -> toNonNegativeIndex(i, inputType.rank(), "squeeze")).
                    map(i -> inputType.dimensions().get(i)).
                    filter(dim -> dimensionSize(dim) == 1).
                    map(TensorType.Dimension::name).
                    collect(Collectors.toList());
        }

        if (squeezeDimensions.isEmpty()) {
            return arguments.get(0);
        }

        TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
        TensorType outputType = Reduce.outputType(inputType, squeezeDimensions);
        return checkNamingConvention(outputType, outputFunction);
    }

    /**
     * Resolves a TF-style axis index, where negative values count from the end, to a non-negative index.
     *
     * @param index the index as given in the TF graph; valid values are [-bound, bound - 1]
     * @param bound the number of valid positions (the rank, or rank + 1 when inserting a dimension)
     * @param operationName the operation name, used in the error message
     * @return the equivalent index in [0, bound - 1]
     * @throws IllegalArgumentException if the index is out of range
     */
    private static int toNonNegativeIndex(long index, int bound, String operationName) {
        long resolved = index < 0 ? index + bound : index;
        if (resolved < 0 || resolved >= bound)
            throw new IllegalArgumentException("Axis index " + index + " is out of range for " + operationName +
                                               ": expected a value in [" + (-bound) + ", " + (bound - 1) + "]");
        return (int) resolved;
    }

    /** Returns the total number of cells in a type, i.e. the product of its dimension sizes. */
    private Long tensorSize(TensorType type) {
        long size = 1L;
        for (TensorType.Dimension dimension : type.dimensions()) {
            size *= dimensionSize(dimension);
        }
        return size;
    }

    /**
     * Returns the size of a dimension.
     *
     * @throws IllegalArgumentException if the dimension has no size
     */
    private Long dimensionSize(TensorType.Dimension dim) {
        return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
    }

    /** Throws IllegalArgumentException unless exactly {@code count} arguments are given. */
    private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
        if ( arguments.size() != count)
            throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName +
                                               ", but got " + arguments.size());
    }

}