aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
blob: 8dd1e3ff33d132bdba7961cd0badcc9734159264 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.TensorFunction;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;

/**
 * Onnx slice operation.
 *
 * Opset 1 to 9 accepts starts, ends, and axes tensors as attributes.
 *
 * Opset 10 and up accepts starts, ends, axes, and steps tensors as inputs. These must be
 * constants, otherwise we can't import this model because that would mean we
 * would not know the resulting tensor type until run-time, and that is currently
 * not supported in Vespa. Reading these values from inputs is not implemented yet:
 * currently only the attribute form is read, and every step is 1.
 */
public class Slice extends IntermediateOperation {

    private final AttributeMap attributes;

    /** Start index per data dimension; computed as a side effect of {@link #lazyGetType()}. */
    private int[] starts;
    /** Step size per data dimension; computed as a side effect of {@link #lazyGetType()}. */
    private int[] steps;

    public Slice(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes) {
        super(modelName, nodeName, inputs);
        this.attributes = attributes;
    }

    /**
     * Computes the output type of this slice, and stores the per-dimension start and step
     * values that {@link #lazyGetFunction()} uses to build the slicing expression.
     *
     * @return the sliced output type, or null if the data input or its type is not available
     * @throws IllegalArgumentException if the slice attributes are missing, inconsistent or out of range,
     *                                  or if an input dimension has no known size
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if (inputs.size() < 1 || inputs.get(0).type().isEmpty()) {
            return null;
        }
        OrderedTensorType dataType = inputs.get(0).type().get();
        int rank = dataType.rank();

        // required as we use tensor create
        inputs.get(0).exportAsRankingFunction = true;

        // Todo: only supports opsets 1-9, for >= 10 get starts, ends, axes and steps from inputs
        int[] startsInput = attributeListAsArray("starts");
        int[] endsInput = attributeListAsArray("ends");
        int[] axes = attributes.getList("axes").isPresent() ? attributeListAsArray("axes")
                                                            : defaultAxes(startsInput.length);
        // One step per sliced axis (parallel to starts/ends/axes), not per data dimension
        int[] stepsInput = new int[startsInput.length];
        Arrays.fill(stepsInput, 1);  // Todo: get from input when opset >= 10

        if (startsInput.length != endsInput.length) {
            throw new IllegalArgumentException("Slice in " + name + ": 'starts' and 'ends' indexes are not of the same size.");
        }
        if (startsInput.length != axes.length) {
            throw new IllegalArgumentException("Slice in " + name + ": 'axes' and 'starts' are not of same size.");
        }

        int[] dimensionSizes = new int[rank];
        for (int i = 0; i < rank; ++i) {
            TensorType.Dimension dimension = dataType.dimensions().get(i);
            dimensionSizes[i] = dimension.size().orElseThrow(() -> new IllegalArgumentException(
                    "Slice in " + name + ": input dimension '" + dimension.name() + "' must have a known size.")).intValue();
        }

        starts = new int[rank];  // dimensions that are not sliced start at 0
        steps = new int[rank];
        Arrays.fill(steps, 1);

        boolean[] alreadySliced = new boolean[rank];
        for (int i = 0; i < axes.length; ++i) {
            int axis = normalizeAxis(axes[i], rank);
            if (alreadySliced[axis]) {
                throw new IllegalArgumentException("Slice in " + name + ": axis " + axes[i] + " is given more than once.");
            }
            alreadySliced[axis] = true;

            int step = stepsInput[i];
            if (step == 0) {
                throw new IllegalArgumentException("Slice in " + name + ": illegal step size of 0.");
            }
            if (step < 0) {
                throw new IllegalArgumentException("Slice in " + name + ": negative step sizes are not supported.");
            }

            int start = clampIndex(startsInput[i], dimensionSizes[axis]);
            int end = clampIndex(endsInput[i], dimensionSizes[axis]);
            if ((end - start) < 1) {
                throw new IllegalArgumentException("Slice in " + name + ": illegal start (" + start + ") and end (" + end + ") index.");
            }

            starts[axis] = start;
            steps[axis] = step;
            // Number of indexes start, start + step, ... that are below end, i.e. ceil((end - start) / step).
            // Written this way (end - start >= 1 here) to avoid overflow for very large steps.
            dimensionSizes[axis] = 1 + (end - start - 1) / step;
        }

        OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
        for (int i = 0; i < rank; ++i) {
            addDimension(i, dimensionSizes[i], typeBuilder);
        }
        return typeBuilder.build();
    }

    /**
     * Resolves a possibly negative axis against the given rank.
     *
     * @throws IllegalArgumentException if the axis is outside [-rank, rank - 1], as ONNX requires
     */
    private int normalizeAxis(int axis, int rank) {
        if (axis < -rank || axis >= rank) {
            throw new IllegalArgumentException("Slice in " + name + ": axis " + axis + " is out of range for input of rank " + rank + ".");
        }
        return axis < 0 ? axis + rank : axis;
    }

    /**
     * Resolves a possibly negative start/end index (counted from the end of the dimension)
     * and clamps it to [0, size], as the ONNX Slice operator specifies for positive steps.
     */
    private static int clampIndex(int index, int size) {
        int resolved = index < 0 ? index + size : index;  // no overflow: index >= Integer.MIN_VALUE and size >= 0
        return Math.max(0, Math.min(resolved, size));
    }

    /** Returns the ONNX default axes [0, 1, ..., count - 1], used when 'axes' is not given. */
    private static int[] defaultAxes(int count) {
        int[] axes = new int[count];
        for (int i = 0; i < count; ++i) {
            axes[i] = i;
        }
        return axes;
    }

    /**
     * Returns the list attribute with the given name converted to ints.
     *
     * @throws IllegalArgumentException if the attribute is missing
     */
    private int[] attributeListAsArray(String attributeName) {
        List<Value> list = attributes.getList(attributeName).orElseThrow(() -> new IllegalArgumentException(
                "Slice in " + name + ": Required attribute '" + attributeName + "' is missing."));
        int[] result = new int[list.size()];
        for (int i = 0; i < list.size(); ++i) {
            // Narrowing double -> int saturates, so huge "slice to the end" values become Integer.MAX_VALUE
            result[i] = (int) list.get(i).asDouble();
        }
        return result;
    }

    /** Adds an indexed output dimension named after this operation and the dimension's position. */
    private void addDimension(int dimensionIndex, long size, OrderedTensorType.Builder typeBuilder) {
        String name = String.format("%s_%d", vespaName(), dimensionIndex);
        typeBuilder.add(TensorType.Dimension.indexed(name, size));
    }

    /**
     * Builds a tensor generate function over the output type which, for each output cell,
     * reads input cell (start + step * d) along every dimension d.
     * Relies on {@code starts}, {@code steps} and {@code type} computed by {@link #lazyGetType()}.
     *
     * @return the slicing function, or null if the data input's function is not available
     */
    @Override
    protected TensorFunction lazyGetFunction() {
        if (inputs.size() < 1 || inputs.get(0).function().isEmpty()) {
            return null;
        }

        IntermediateOperation data = inputs.get(0);
        OrderedTensorType dataType = data.type().get();
        String dataFunctionName = data.rankingExpressionFunctionName();

        List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();

        for (int axis = 0; axis < dataType.rank(); ++axis) {
            String inputDimensionName = dataType.dimensions().get(axis).name();
            String outputDimensionName = type.dimensions().get(axis).name();

            ExpressionNode startIndex = new ConstantNode(new DoubleValue(starts[axis]));
            ExpressionNode stepSize = new ConstantNode(new DoubleValue(steps[axis]));

            // start + (step * d0): output index d0 maps to input index start + step * d0
            ExpressionNode reference = new ReferenceNode(outputDimensionName);
            ExpressionNode offset = new EmbracedNode(new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, reference));
            ExpressionNode inputIndex = new ArithmeticNode(startIndex, ArithmeticOperator.PLUS, offset);

            dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(inputIndex))));
        }

        TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(dataFunctionName));
        com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
        ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);

        return Generate.bound(type.type(), wrapScalar(sliceExpression));
    }

    /** Requires the output dimensions to keep their relative order when renamed. */
    @Override
    public void addDimensionNameConstraints(DimensionRenamer renamer) {
        for (int i = 0; i < type.dimensions().size(); i++) {
            renamer.addDimension(type.dimensions().get(i).name());
            for (int j = i + 1; j < type.dimensions().size(); j++) {
                renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(),
                        DimensionRenamer.Constraint.lessThan(), this);
            }
        }
    }

    @Override
    public Slice withInputs(List<IntermediateOperation> inputs) {
        return new Slice(modelName(), name(), inputs, attributes);
    }

    @Override
    public String operationName() { return "Slice"; }

}