about summary refs log tree commit diff stats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
blob: e80e5a91bfde13d3ddcb9bc0fef6e03bd82ebf97 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.tensor;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.logging.Level;
import java.util.logging.Logger;
import static com.yahoo.tensor.TensorType.Dimension;
import static com.yahoo.tensor.TensorType.Value;

/**
 * Common type resolving for basic tensor operations.
 *
 * Every public method takes the type(s) of the input(s) to an operation and
 * returns the type of the result, throwing IllegalArgumentException when the
 * inputs cannot be combined by that operation.
 *
 * @author arnej
 */
public class TypeResolver {

    private static final Logger logger = Logger.getLogger(TypeResolver.class.getName());

    /** Returns the type used whenever a result has no dimensions left. */
    private static TensorType scalar() {
        return TensorType.empty;
    }

    /** Returns a new, mutable map from dimension name to dimension for the given type. */
    private static Map<String, Dimension> dimensionsByName(TensorType type) {
        Map<String, Dimension> byName = new HashMap<>();
        for (Dimension dim : type.dimensions()) {
            byName.put(dim.name(), dim);
        }
        return byName;
    }

    /**
     * Resolves the result type of a cell-wise map.
     *
     * @param inputType the type of the tensor being mapped
     * @return the input type itself if its cell type is already at least float,
     *         otherwise a type with the same dimensions and a widened cell type
     */
    public static TensorType map(TensorType inputType) {
        Value inputCellType = inputType.valueType();
        // result of computation must be at least float
        Value resultCellType = Value.largestOf(inputCellType, Value.FLOAT);
        if (resultCellType == inputCellType) {
            return inputType;
        }
        return new TensorType(resultCellType, inputType.dimensions());
    }

    /**
     * Resolves the result type of reducing over the given dimensions.
     *
     * @param inputType the type of the tensor being reduced
     * @param reduceDimensions the names of the dimensions to remove; an empty list removes all
     * @return the scalar type if no dimensions remain, otherwise the remaining dimensions
     *         with a cell type which is at least float
     */
    public static TensorType reduce(TensorType inputType, List<String> reduceDimensions) {
        if (reduceDimensions.isEmpty()) {
            return scalar();
        }
        Map<String, Dimension> remaining = dimensionsByName(inputType);
        for (String name : reduceDimensions) {
            if (remaining.remove(name) == null) {
                // Unknown dimensions are reported but tolerated; this does not throw
                logger.log(Level.WARNING, "reducing non-existing dimension "+name+" in type "+inputType);
            }
        }
        if (remaining.isEmpty()) {
            return scalar();
        }
        Value cellType = Value.largestOf(inputType.valueType(), Value.FLOAT);
        return new TensorType(cellType, remaining.values());
    }

    /**
     * Resolves the result type of peeking into the given dimensions.
     *
     * @param inputType the type of the tensor being peeked into
     * @param peekDimensions the names of the dimensions which are addressed, and hence removed
     * @return the scalar type if no dimensions remain, otherwise the remaining dimensions
     *         with the cell type of the input
     * @throws IllegalArgumentException if no dimensions are given,
     *         or one of them is not (or no longer) present in the input type
     */
    public static TensorType peek(TensorType inputType, List<String> peekDimensions) {
        if (peekDimensions.isEmpty()) {
            throw new IllegalArgumentException("Peeking no dimensions makes no sense");
        }
        Map<String, Dimension> remaining = dimensionsByName(inputType);
        for (String name : peekDimensions) {
            if (remaining.remove(name) == null) {
                throw new IllegalArgumentException("Peeking non-existing dimension '" + name + "'");
            }
        }
        if (remaining.isEmpty()) {
            return scalar();
        }
        return new TensorType(inputType.valueType(), remaining.values());
    }

    /**
     * Resolves the result type of renaming dimensions.
     *
     * @param inputType the type of the tensor whose dimensions are renamed
     * @param from the current dimension names
     * @param to the new dimension names, where to.get(i) replaces from.get(i)
     * @return a type with the cell type of the input and the renamed dimensions
     * @throws IllegalArgumentException if from is empty, the lists differ in size,
     *         or the result has a different number of dimensions than the input
     */
    public static TensorType rename(TensorType inputType, List<String> from, List<String> to) {
        if (from.isEmpty()) {
            throw new IllegalArgumentException("Renaming no dimensions");
        }
        if (from.size() != to.size()) {
            throw new IllegalArgumentException("Bad rename, from size "+from.size()+" != to.size "+to.size());
        }
        Map<String, Dimension> untouched = dimensionsByName(inputType);
        Map<String, Dimension> renamed = new HashMap<>();
        for (int i = 0; i < from.size(); ++i) {
            String oldName = from.get(i);
            String newName = to.get(i);
            Dimension dim = untouched.remove(oldName);
            if (dim != null) {
                renamed.put(newName, dim.withName(newName));
            } else {
                // Unknown dimensions are reported but tolerated; this does not throw
                logger.log(Level.WARNING, "Renaming non-existing dimension "+oldName+" in type "+inputType);
            }
        }
        for (Dimension keep : untouched.values()) {
            renamed.put(keep.name(), keep);
        }
        // Name collisions overwrite entries in the map, which shows up as a smaller dimension count
        if (inputType.dimensions().size() != renamed.size()) {
            throw new IllegalArgumentException("Bad rename, lost some dimensions");
        }
        return new TensorType(inputType.valueType(), renamed.values());
    }

    /**
     * Resolves the result type of casting the cells of a tensor.
     *
     * @param inputType the type of the tensor being cast
     * @param toCellType the cell type to cast to
     * @return a type with the dimensions of the input and the given cell type
     * @throws IllegalArgumentException if the input has no dimensions and the target is not double
     */
    public static TensorType cell_cast(TensorType inputType, Value toCellType) {
        if (toCellType != Value.DOUBLE && inputType.dimensions().isEmpty()) {
            throw new IllegalArgumentException("Cannot cast " + inputType + " to valueType " + toCellType);
        }
        return new TensorType(toCellType, inputType.dimensions());
    }

    /** Returns whether the dimensions have the same name, the first is indexed bound and the second indexed unbound. */
    private static boolean firstIsBoundSecond(Dimension first, Dimension second) {
        return (first.type() == Dimension.Type.indexedBound &&
                second.type() == Dimension.Type.indexedUnbound &&
                first.name().equals(second.name()));
    }

    /** Returns whether the dimensions have the same name, both are indexed bound, and the first has the smaller size. */
    private static boolean firstIsSmaller(Dimension first, Dimension second) {
        return (first.type() == Dimension.Type.indexedBound &&
                second.type() == Dimension.Type.indexedBound &&
                first.name().equals(second.name()) &&
                first.size().isPresent() && second.size().isPresent() &&
                first.size().get() < second.size().get());
    }

    /**
     * Resolves the result type of joining two tensors.
     *
     * The result has the union of the dimensions of the inputs. For a dimension present
     * in both but not equal: bound and unbound gives bound, mapped and indexed gives mapped,
     * and anything else is an error.
     *
     * @param lhs the type of the left argument
     * @param rhs the type of the right argument
     * @return the joined type, with a cell type which is at least float
     * @throws IllegalArgumentException if a shared dimension cannot be reconciled
     */
    public static TensorType join(TensorType lhs, TensorType rhs) {
        Value cellType = Value.DOUBLE;
        if (lhs.rank() > 0 && rhs.rank() > 0) {
            // both types decide the new cell type
            cellType = Value.largestOf(lhs.valueType(), rhs.valueType());
        } else if (lhs.rank() > 0) {
            // only the tensor decides the new cell type
            cellType = lhs.valueType();
        } else if (rhs.rank() > 0) {
            // only the tensor decides the new cell type
            cellType = rhs.valueType();
        }
        // result of computation must be at least float
        cellType = Value.largestOf(cellType, Value.FLOAT);

        Map<String, Dimension> joined = dimensionsByName(lhs);
        for (Dimension dim : rhs.dimensions()) {
            Dimension other = joined.get(dim.name());
            if (other == null) {
                joined.put(dim.name(), dim);
                continue;
            }
            if (other.equals(dim)) {
                continue;
            }
            if (firstIsBoundSecond(dim, other)) {
                joined.put(dim.name(), dim);  // [N] and [] -> [N]
            } else if (firstIsBoundSecond(other, dim)) {
                joined.put(dim.name(), other);  // [N] and [] -> [N]
            } else if (dim.isMapped() && other.isIndexed()) {
                joined.put(dim.name(), dim);  // {} and [] -> {}. Note: this is not allowed in C++
            } else if (dim.isIndexed() && other.isMapped()) {
                joined.put(dim.name(), other);  // {} and [] -> {}. Note: this is not allowed in C++
            } else {
                throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs);
            }
        }
        return new TensorType(cellType, joined.values());
    }

    /**
     * Resolves the result type of merging two tensors.
     *
     * @param lhs the type of the left argument
     * @param rhs the type of the right argument
     * @return the result of joining the two types
     * @throws IllegalArgumentException if the types do not have the same dimension names
     *         in the same positions, or the join itself fails
     */
    public static TensorType merge(TensorType lhs, TensorType rhs) {
        var lhsDimensions = lhs.dimensions();
        var rhsDimensions = rhs.dimensions();
        boolean sameNames = (lhsDimensions.size() == rhsDimensions.size());
        // Stops at the first mismatch
        for (int i = 0; sameNames && i < lhsDimensions.size(); i++) {
            sameNames = lhsDimensions.get(i).name().equals(rhsDimensions.get(i).name());
        }
        if ( ! sameNames) {
            throw new IllegalArgumentException("Types in merge() dimensions mismatch: "+lhs+" != "+rhs);
        }
        return join(lhs, rhs);
    }

    /**
     * Resolves the result type of concatenating two tensors along a dimension.
     *
     * An argument which does not have the concat dimension is treated as having it with size 1.
     * For other dimensions present in both but not equal: bound and unbound gives unbound,
     * two bound dimensions give the smaller one, and anything else is an error.
     *
     * @param lhs the type of the left argument
     * @param rhs the type of the right argument
     * @param concatDimension the name of the dimension to concatenate along
     * @return the concatenated type; the concat dimension is unbound if it is unbound in either
     *         argument, and otherwise has the sum of the two sizes
     * @throws IllegalArgumentException if another shared dimension cannot be reconciled,
     *         or the concat dimension is mapped in either argument
     */
    public static TensorType concat(TensorType lhs, TensorType rhs, String concatDimension) {
        Value cellType = Value.DOUBLE;
        if (lhs.rank() > 0 && rhs.rank() > 0) {
            if (lhs.valueType() == rhs.valueType()) {
                cellType = lhs.valueType();
            } else {
                cellType = Value.largestOf(lhs.valueType(), rhs.valueType());
                // when changing cell type, make it at least float
                cellType = Value.largestOf(cellType, Value.FLOAT);
            }
        } else if (lhs.rank() > 0) {
            cellType = lhs.valueType();
        } else if (rhs.rank() > 0) {
            cellType = rhs.valueType();
        }

        // Defaults used when an argument lacks the concat dimension
        Dimension first = Dimension.indexed(concatDimension, 1);
        Dimension second = Dimension.indexed(concatDimension, 1);
        Map<String, Dimension> result = new HashMap<>();
        for (Dimension dim : lhs.dimensions()) {
            if (dim.name().equals(concatDimension)) {
                first = dim;
            } else {
                result.put(dim.name(), dim);
            }
        }
        for (Dimension dim : rhs.dimensions()) {
            if (dim.name().equals(concatDimension)) {
                second = dim;
                continue;
            }
            Dimension other = result.get(dim.name());
            if (other == null) {
                result.put(dim.name(), dim);
                continue;
            }
            if (other.equals(dim)) {
                continue;
            }
            if (firstIsBoundSecond(dim, other)) {
                result.put(dim.name(), other);  // [N] and [] -> []
            } else if (firstIsBoundSecond(other, dim)) {
                result.put(dim.name(), dim);  // [N] and [] -> []
            } else if (firstIsSmaller(dim, other)) {
                result.put(dim.name(), dim);  // [N] and [M] -> [min(N,M)]
            } else if (firstIsSmaller(other, dim)) {
                result.put(dim.name(), other);  // [N] and [M] -> [min(N,M)]
            } else {
                throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs);
            }
        }

        if (first.type() == Dimension.Type.mapped) {
            throw new IllegalArgumentException("Bad concat dimension "+concatDimension+" in lhs: "+lhs);
        }
        if (second.type() == Dimension.Type.mapped) {
            throw new IllegalArgumentException("Bad concat dimension "+concatDimension+" in rhs: "+rhs);
        }
        if (first.type() == Dimension.Type.indexedUnbound) {
            result.put(concatDimension, first);
        } else if (second.type() == Dimension.Type.indexedUnbound) {
            result.put(concatDimension, second);
        } else {
            // Neither is mapped nor unbound at this point, so both are expected to carry a size
            long concatSize = first.size().get() + second.size().get();
            result.put(concatDimension, Dimension.indexed(concatDimension, concatSize));
        }
        return new TensorType(cellType, result.values());
    }

}