aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
blob: a0a3552eb929e744bcfcb73f5ae7006b8bd1e85c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.google.common.annotations.Beta;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Returns a subspace of a tensor: the cells whose address matches a (partial) subspace address,
 * with the sliced dimensions removed from the result type.
 *
 * @author bratseth
 */
@Beta
public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {

    private final TensorFunction<NAMETYPE> argument;
    private final List<DimensionValue<NAMETYPE>> subspaceAddress;

    /**
     * Creates a slice function
     *
     * @param argument the tensor to return a subspace of
     * @param subspaceAddress a description of the address of the subspace to return. This is not a TensorAddress
     *                        because those require a type, but a type is not resolved until this is evaluated.
     *                        The short form (values without a dimension name) is only allowed with a single value
     * @throws IllegalArgumentException if the short form is used with more than one value
     */
    public Slice(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> subspaceAddress) {
        this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
        // Copy so later changes to the caller's list cannot alter this function; this also rejects null elements
        List<DimensionValue<NAMETYPE>> address =
                List.copyOf(Objects.requireNonNull(subspaceAddress, "Subspace address cannot be null"));
        if (address.size() > 1 && address.stream().anyMatch(c -> c.dimension().isEmpty()))
            throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: " +
                                               "Specify dimension names explicitly instead");
        this.subspaceAddress = address;
    }

    @Override
    public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }

    @Override
    public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 1)
            throw new IllegalArgumentException("Slice takes exactly one argument but got " + arguments.size());
        return new Slice<>(arguments.get(0), subspaceAddress);
    }

    /** Returns the (unmodifiable) subspace address description of this */
    public List<DimensionValue<NAMETYPE>> getSubspaceAddress() {
        return subspaceAddress;
    }

    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }

    /**
     * Evaluates the argument and returns the cells matching the subspace address, keyed by the dimensions
     * that remain after slicing. When no dimensions remain, the single matching cell value is returned directly.
     */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor tensor = argument.evaluate(context);
        TensorType resultType = resultType(tensor.type()); // also validates the address against the argument type

        PartialAddress address = subspaceToAddress(tensor.type(), context);
        if (resultType.rank() == 0) // shortcut common case
            return Tensor.from(tensor.get(address.asAddress(tensor.type())));

        Tensor.Builder b = Tensor.Builder.of(resultType);
        for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
            Tensor.Cell cell = i.next();
            if (matches(address, cell.getKey(), tensor.type()))
                b.cell(remaining(resultType, cell.getKey(), tensor.type()), cell.getValue());
        }
        return b.build();
    }

    /** Resolves the subspace address description of this into a concrete partial address for the given type */
    private PartialAddress subspaceToAddress(TensorType type, EvaluationContext<NAMETYPE> context) {
        PartialAddress.Builder b = new PartialAddress.Builder(subspaceAddress.size());
        for (int i = 0; i < subspaceAddress.size(); i++) {
            DimensionValue<NAMETYPE> value = subspaceAddress.get(i);
            String dimension = dimensionName(value, i, type);
            if (value.label().isPresent())
                b.add(dimension, value.label().get());
            else
                b.add(dimension, value.index().get().apply(context).intValue());
        }
        return b.build();
    }

    /**
     * Returns the name of the dimension sliced by the given subspace address value.
     * An explicitly named dimension is returned as-is. A short form value (no dimension name) slices the
     * single indexed dimension of the type when it is an index, and the single mapped dimension when it is
     * a label; this is the same dimension resultType removes. If the type has no dimension of that kind,
     * the dimension at the given position is used.
     */
    private String dimensionName(DimensionValue<NAMETYPE> value, int position, TensorType type) {
        if (value.dimension().isPresent()) return value.dimension().get();

        boolean indexed = value.index().isPresent();
        return type.dimensions().stream()
                   .filter(d -> d.isIndexed() == indexed)
                   .map(TensorType.Dimension::name)
                   .findFirst()
                   .orElseGet(() -> type.dimensions().get(position).name());
    }

    /** Returns whether the given cell address has the same label as the partial address in all its dimensions */
    private boolean matches(PartialAddress partialAddress,
                            TensorAddress address, TensorType type) {
        for (int i = 0; i < partialAddress.size(); i++) {
            // The dimension is known to exist in type since resultType has validated the address
            String label = address.label(type.indexOfDimension(partialAddress.dimension(i)).get());
            if ( ! label.equals(partialAddress.label(i)))
                return false;
        }
        return true;
    }

    /** Returns the subset of the given address which is present in the subspace type */
    private TensorAddress remaining(TensorType subspaceType, TensorAddress address, TensorType type) {
        TensorAddress.Builder b = new TensorAddress.Builder(subspaceType);
        for (int i = 0; i < address.size(); i++) {
            String dimension = type.dimensions().get(i).name();
            if (subspaceType.dimension(dimension).isPresent())
                b.add(dimension, address.label(i));
        }
        return b.build();
    }

    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return resultType(argument.type(context));
    }

    /**
     * Returns the type of the result of slicing a tensor of the given type: the argument type
     * without the sliced dimensions.
     *
     * @throws IllegalArgumentException if the subspace address cannot be applied to the given type
     */
    private TensorType resultType(TensorType argumentType) {
        TensorType.Builder b = new TensorType.Builder();

        // Special case where a single indexed or mapped dimension is sliced
        if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
            if (subspaceAddress.get(0).index().isPresent()) {
                if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1)
                    throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied" +
                                                       " to " + argumentType + ", which has multiple");
                for (TensorType.Dimension dimension : argumentType.dimensions()) {
                    if ( ! dimension.isIndexed())
                        b.dimension(dimension);
                }
            }
            else {
                if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1)
                    throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied" +
                                                       " to " + argumentType + ", which has multiple");
                for (TensorType.Dimension dimension : argumentType.dimensions()) {
                    if (dimension.isIndexed())
                        b.dimension(dimension);
                }
            }
        }
        else { // general slicing: every value names its dimension explicitly
            Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet());
            for (TensorType.Dimension dimension : argumentType.dimensions()) {
                if ( ! slicedDimensions.contains(dimension.name()))
                    b.dimension(dimension);
            }
            // Computed without mutating slicedDimensions, as the mutability of a Collectors.toSet() result is unspecified
            Set<String> missingDimensions = slicedDimensions.stream()
                                                            .filter(name -> argumentType.dimension(name).isEmpty())
                                                            .collect(Collectors.toSet());
            if ( ! missingDimensions.isEmpty())
                throw new IllegalArgumentException(this + " slices " + missingDimensions + " which are not present in " +
                                                   argumentType);
        }
        return b.build();
    }

    @Override
    public String toString(ToStringContext context) {
        StringBuilder b = new StringBuilder(argument.toString(context));
        if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
            if (subspaceAddress.get(0).index().isPresent())
                b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]");
            else
                b.append("{").append(subspaceAddress.get(0).label().get()).append("}");
        }
        else {
            b.append("{").append(subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}");
        }
        return b.toString();
    }

    /**
     * One element of a subspace address: an optional dimension name together with either a label
     * or a function computing an index.
     */
    public static class DimensionValue<NAMETYPE extends Name>  {

        private final Optional<String> dimension;

        /** The label of this, or null if index is set */
        private final String label;

        /** The function returning the index of this, or null if label is set */
        private final ScalarFunction<NAMETYPE> index;

        public DimensionValue(String dimension, String label) {
            this(Optional.of(dimension), label, null);
        }

        public DimensionValue(String dimension, int index) {
            this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index));
        }

        public DimensionValue(int index) {
            this(Optional.empty(), null, new ConstantIntegerFunction<>(index));
        }

        public DimensionValue(String label) {
            this(Optional.empty(), label, null);
        }

        public DimensionValue(ScalarFunction<NAMETYPE> index) {
            this(Optional.empty(), null, index);
        }

        public DimensionValue(Optional<String> dimension, String label) {
            this(dimension, label, null);
        }

        public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) {
            this(dimension, null, index);
        }

        public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) {
            this(Optional.of(dimension), null, index);
        }

        private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) {
            this.dimension = dimension;
            this.label = label;
            this.index = index;
        }

        /**
         * Returns the given name of the dimension, or empty if the short form is used, such that the
         * dimension must be inferred from the type of the sliced tensor
         */
        public Optional<String> dimension() { return dimension; }

        /** Returns the label for this dimension or empty if it is provided by an index function */
        public Optional<String> label() { return Optional.ofNullable(label); }

        /** Returns the index expression for this dimension, or empty if it is not a number */
        public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); }

        @Override
        public String toString() {
            return toString(ToStringContext.empty());
        }

        /** Returns this as "dimension:value", or just "value" when no dimension name is given */
        public String toString(ToStringContext context) {
            StringBuilder b = new StringBuilder();
            dimension.ifPresent(d -> b.append(d).append(":"));
            if (label != null)
                b.append(label);
            else
                b.append(index.toString(context));
            return b.toString();
        }

    }

    /** A scalar function which always returns the same integer value */
    private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> {

        private final int value;

        public ConstantIntegerFunction(int value) {
            this.value = value;
        }

        @Override
        public Double apply(EvaluationContext<NAMETYPE> context) {
            return (double)value;
        }

        @Override
        public String toString() { return String.valueOf(value); }

    }

}