summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
blob: cc40f84ccd3347ae8e8203e6aa8a46920285cf4d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;

import com.google.common.annotations.Beta;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.L1Normalize;
import com.yahoo.tensor.functions.L2Normalize;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;

/**
 * A multidimensional array which can be used in computations.
 * <p>
 * A tensor consists of a set of <i>dimension</i> names and a set of <i>cells</i> containing scalar <i>values</i>.
 * Each cell is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines
 * the location of that cell. Both dimensions and labels are strings on the form of an identifier or integer.
 * <p>
 * The size of the set of dimensions of a tensor is called its <i>order</i>.
 * <p>
 * In contrast to regular mathematical formulations of tensors, this definition of a tensor allows <i>sparseness</i>
 * as there is no built-in notion of a contiguous space, and even in cases where a space is implied (such as when
 * address labels are integers), there is no requirement that every implied cell has a defined value.
 * Undefined values have no defined representation as they are never observed.
 * <p>
 * Tensors can be read and serialized to and from a string form documented in the {@link #toString} method.
 *
 * @author bratseth
 */
@Beta
public interface Tensor {

    // ----------------- Accessors
    
    /** Returns the type of this tensor, which provides its dimensions */
    TensorType type();

    /** Returns true if this tensor has no cells, i.e. if {@link #size()} is zero */
    default boolean isEmpty() {
        return size() == 0;
    }

    /** Returns the number of cells in this */
    int size();

    /** Returns the value of a cell, or NaN if this cell does not exist or has no value */
    double get(TensorAddress address);

    /** Returns an iterator over the cells of this, each given as an entry mapping a cell address to its value */
    Iterator<Map.Entry<TensorAddress, Double>> cellIterator();

    /** Returns an immutable map of the cells of this. This may be expensive for some implementations - avoid when possible */
    Map<TensorAddress, Double> cells();

    /**
     * Returns the single value of this dimensionless tensor as a double.
     *
     * @return the value of the only cell of this, or NaN if this has no cells
     * @throws IllegalStateException if this has one or more dimensions, or has more than one cell
     */
    default double asDouble() {
        if ( ! type().dimensions().isEmpty()) {
            throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " +
                                            type().dimensions().size());
        }
        if (size() == 0) {
            return Double.NaN; // no cells: there is no value to return, but this is not treated as an error
        }
        if (size() > 1) {
            throw new IllegalStateException("This tensor does not have a single value, it has " + size());
        }
        return cellIterator().next().getValue();
    }
    
    // ----------------- Primitive tensor functions
    
    /**
     * Returns the result of evaluating a {@link com.yahoo.tensor.functions.Map} function
     * over this tensor with the given mapper.
     *
     * @param mapper the function handed to the Map tensor function
     * @return the tensor produced by evaluating the function
     */
    default Tensor map(DoubleUnaryOperator mapper) {
        ConstantTensor self = new ConstantTensor(this);
        return new com.yahoo.tensor.functions.Map(self, mapper).evaluate();
    }

    /**
     * Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified.
     *
     * @param aggregator the aggregator handed to the Reduce tensor function
     * @param dimensions the names of the dimensions to aggregate over
     * @return the tensor produced by evaluating a {@link Reduce} function over this
     */
    default Tensor reduce(Reduce.Aggregator aggregator, String ... dimensions) {
        ConstantTensor self = new ConstantTensor(this);
        List<String> dimensionList = Arrays.asList(dimensions);
        return new Reduce(self, aggregator, dimensionList).evaluate();
    }

    /**
     * Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified.
     *
     * @param aggregator the aggregator handed to the Reduce tensor function
     * @param dimensions the names of the dimensions to aggregate over
     * @return the tensor produced by evaluating a {@link Reduce} function over this
     */
    default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) {
        ConstantTensor self = new ConstantTensor(this);
        return new Reduce(self, aggregator, dimensions).evaluate();
    }

    /**
     * Returns the result of evaluating a {@link Join} function over this tensor and the given tensor.
     *
     * @param argument the tensor to join with this
     * @param combinator the function handed to the Join tensor function to combine cell values
     * @return the tensor produced by evaluating the function
     */
    default Tensor join(Tensor argument, DoubleBinaryOperator combinator) {
        ConstantTensor left = new ConstantTensor(this);
        ConstantTensor right = new ConstantTensor(argument);
        return new Join(left, right, combinator).evaluate();
    }

    /**
     * Returns the result of evaluating a {@link Rename} function over this for a single dimension.
     *
     * @param fromDimension the current name of the dimension
     * @param toDimension the new name of the dimension
     * @return the tensor produced by evaluating the function
     */
    default Tensor rename(String fromDimension, String toDimension) {
        ConstantTensor self = new ConstantTensor(this);
        List<String> from = Collections.singletonList(fromDimension);
        List<String> to = Collections.singletonList(toDimension);
        return new Rename(self, from, to).evaluate();
    }

    /**
     * Returns the result of evaluating a {@link Rename} function over this for a list of dimensions.
     *
     * @param fromDimensions the current names of the dimensions
     * @param toDimensions the new names of the dimensions
     *                     (presumably matched to fromDimensions by position - confirm against Rename)
     * @return the tensor produced by evaluating the function
     */
    default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
        ConstantTensor self = new ConstantTensor(this);
        return new Rename(self, fromDimensions, toDimensions).evaluate();
    }
    
    /**
     * Returns the tensor produced by evaluating a {@link Generate} function of the given type.
     *
     * @param type the type handed to the Generate function
     * @param valueSupplier the function handed to the Generate function
     *                      (presumably maps the indexes of a cell to its value - confirm against Generate)
     * @return the generated tensor
     */
    static Tensor from(TensorType type, Function<List<Integer>, Double> valueSupplier) {
        Generate generator = new Generate(type, valueSupplier);
        return generator.evaluate();
    }
    
    // ----------------- Composite tensor functions which have a defined primitive mapping
    
    /** Returns the result of evaluating an {@link L1Normalize} function over this with the given dimension */
    default Tensor l1Normalize(String dimension) {
        ConstantTensor self = new ConstantTensor(this);
        return new L1Normalize(self, dimension).evaluate();
    }

    /** Returns the result of evaluating an {@link L2Normalize} function over this with the given dimension */
    default Tensor l2Normalize(String dimension) {
        ConstantTensor self = new ConstantTensor(this);
        return new L2Normalize(self, dimension).evaluate();
    }

    /**
     * Returns the result of evaluating a {@link Matmul} function over this and the given tensor
     * with the given dimension.
     */
    default Tensor matmul(Tensor argument, String dimension) {
        ConstantTensor left = new ConstantTensor(this);
        ConstantTensor right = new ConstantTensor(argument);
        return new Matmul(left, right, dimension).evaluate();
    }

    /** Returns the result of evaluating a {@link Softmax} function over this with the given dimension */
    default Tensor softmax(String dimension) {
        ConstantTensor self = new ConstantTensor(this);
        return new Softmax(self, dimension).evaluate();
    }

    // ----------------- Composite tensor functions mapped to primitives here on the fly

    /** Joins this with the argument, combining values by multiplication */
    default Tensor multiply(Tensor argument) {
        return join(argument, (left, right) -> left * right);
    }

    /** Joins this with the argument, combining values by addition */
    default Tensor add(Tensor argument) {
        return join(argument, (left, right) -> left + right);
    }

    /** Joins this with the argument, combining values by dividing the first operand by the second */
    default Tensor divide(Tensor argument) {
        return join(argument, (left, right) -> left / right);
    }

    /** Joins this with the argument, combining values by subtracting the second operand from the first */
    default Tensor subtract(Tensor argument) {
        return join(argument, (left, right) -> left - right);
    }

    /** Joins this with the argument, keeping the first operand if it is strictly larger and otherwise the second */
    default Tensor max(Tensor argument) {
        return join(argument, (left, right) -> left > right ? left : right);
    }

    /** Joins this with the argument, keeping the first operand if it is strictly smaller and otherwise the second */
    default Tensor min(Tensor argument) {
        return join(argument, (left, right) -> left < right ? left : right);
    }

    /** Joins this with the argument, combining values by {@link Math#atan2} */
    default Tensor atan2(Tensor argument) {
        return join(argument, (left, right) -> Math.atan2(left, right));
    }

    /** Joins this with the argument, producing 1.0 where the first operand is larger than the second and 0.0 otherwise */
    default Tensor larger(Tensor argument) {
        return join(argument, (left, right) -> left > right ? 1.0 : 0.0);
    }

    /** Joins this with the argument, producing 1.0 where the first operand is larger than or equal to the second and 0.0 otherwise */
    default Tensor largerOrEqual(Tensor argument) {
        return join(argument, (left, right) -> left >= right ? 1.0 : 0.0);
    }

    /** Joins this with the argument, producing 1.0 where the first operand is smaller than the second and 0.0 otherwise */
    default Tensor smaller(Tensor argument) {
        return join(argument, (left, right) -> left < right ? 1.0 : 0.0);
    }

    /** Joins this with the argument, producing 1.0 where the first operand is smaller than or equal to the second and 0.0 otherwise */
    default Tensor smallerOrEqual(Tensor argument) {
        return join(argument, (left, right) -> left <= right ? 1.0 : 0.0);
    }

    /** Joins this with the argument, producing 1.0 where the operands are equal (by ==) and 0.0 otherwise */
    default Tensor equal(Tensor argument) {
        return join(argument, (left, right) -> left == right ? 1.0 : 0.0);
    }

    /** Joins this with the argument, producing 1.0 where the operands are not equal (by !=) and 0.0 otherwise */
    default Tensor notEqual(Tensor argument) {
        return join(argument, (left, right) -> left != right ? 1.0 : 0.0);
    }

    /** Reduces this over the given dimensions using the avg aggregator */
    default Tensor avg(List<String> dimensions) {
        return reduce(Reduce.Aggregator.avg, dimensions);
    }

    /** Reduces this over the given dimensions using the count aggregator */
    default Tensor count(List<String> dimensions) {
        return reduce(Reduce.Aggregator.count, dimensions);
    }

    /** Reduces this over the given dimensions using the max aggregator */
    default Tensor max(List<String> dimensions) {
        return reduce(Reduce.Aggregator.max, dimensions);
    }

    /** Reduces this over the given dimensions using the min aggregator */
    default Tensor min(List<String> dimensions) {
        return reduce(Reduce.Aggregator.min, dimensions);
    }

    /** Reduces this over the given dimensions using the prod aggregator */
    default Tensor prod(List<String> dimensions) {
        return reduce(Reduce.Aggregator.prod, dimensions);
    }

    /** Reduces this over the given dimensions using the sum aggregator */
    default Tensor sum(List<String> dimensions) {
        return reduce(Reduce.Aggregator.sum, dimensions);
    }

    // ----------------- serialization

    /**
     * Returns this tensor on the form
     * <code>{address1:value1,address2:value2,...}</code>
     * where each address is on the form <code>{dimension1:label1,dimension2:label2,...}</code>,
     * and values are numbers.
     * <p>
     * Cells are listed in the natural order of tensor addresses: Increasing size primarily
     * and by element lexical order secondarily.
     * <p>
     * Note that while this is suggestive of JSON, it is not JSON.
     * <p>
     * Implementations can delegate to {@link #toStandardString} to produce this form.
     *
     * @return this tensor on the standard string form
     */
    @Override
    String toString();

    /**
     * Returns the standard string format of the given tensor. Implementations should call this from toString
     * (toString cannot be a default method because default methods cannot override super methods).
     * <p>
     * The content is prefixed by the tensor type and a colon only when the tensor has no cells
     * but does have dimensions.
     *
     * @param tensor the tensor to return the standard string format of
     * @return the tensor on the standard string format
     */
    static String toStandardString(Tensor tensor) {
        // explicitly output type TODO: Never do that?
        boolean outputType = tensor.isEmpty() && ! tensor.type().dimensions().isEmpty();
        if ( ! outputType) {
            return contentToString(tensor);
        }
        return tensor.type() + ":" + contentToString(tensor);
    }

    /**
     * Returns the cells of the given tensor on the standard string form, without any type prefix.
     * <p>
     * A tensor without dimensions is written as <code>{value}</code>, or as <code>{}</code> if it has no cells
     * or if its value is 0.0. Any other tensor is written as <code>{address1:value1,address2:value2,...}</code>
     * with the cells sorted by address.
     * <p>
     * NOTE(review): a dimensionless 0.0 is written the same way as an empty tensor - confirm this is intended.
     *
     * @param tensor the tensor to return the content of
     * @return the cells of the tensor as a string
     */
    static String contentToString(Tensor tensor) {
        List<java.util.Map.Entry<TensorAddress, Double>> cells = new ArrayList<>(tensor.cells().entrySet());

        boolean dimensionless = tensor.type().dimensions().isEmpty();
        if (dimensionless) { // TODO: Decide on one way to represent degeneration to number
            if (cells.isEmpty()) return "{}";
            double value = cells.get(0).getValue();
            if (value == 0.0) return "{}";
            return "{" + value + "}";
        }

        cells.sort(java.util.Map.Entry.<TensorAddress, Double>comparingByKey());

        StringBuilder text = new StringBuilder("{");
        String separator = ""; // nothing before the first cell, a comma before each of the others
        for (java.util.Map.Entry<TensorAddress, Double> cell : cells) {
            text.append(separator);
            text.append(cell.getKey().toString(tensor.type())).append(":").append(cell.getValue());
            separator = ",";
        }
        return text.append("}").toString();
    }

    // ----------------- equality

    /**
     * Returns true if the given tensor is mathematically equal to this:
     * Both are of type Tensor and have the same content.
     * <p>
     * The static {@link #equals(Tensor, Tensor)} method is available for comparing the content of two tensors.
     *
     * @param o the object to compare to this
     * @return true if the given object is a tensor which is equal to this
     */
    @Override
    boolean equals(Object o);

    /**
     * Returns true if the two given tensors are mathematically equivalent, that is whether both have the same content.
     * Two tensors are equivalent if they are the same instance or if their cell maps are equal.
     *
     * @param a the first tensor to compare
     * @param b the second tensor to compare
     * @return true if the tensors have the same cells
     */
    static boolean equals(Tensor a, Tensor b) {
        return a == b || a.cells().equals(b.cells());
    }

    // ----------------- Factories

    /**
     * Returns a tensor instance containing the given data on the standard string format returned by toString
     *
     * @param type the type of the tensor to return
     * @param tensorString the tensor on the standard tensor string format
     * @return the tensor parsed from the string
     */
    static Tensor from(TensorType type, String tensorString) {
        Optional<TensorType> explicitType = Optional.of(type);
        return TensorParser.tensorFrom(tensorString, explicitType);
    }

    /**
     * Returns a tensor instance containing the given data on the standard string format returned by toString
     *
     * @param tensorType the type of the tensor to return, as a string on the tensor type format, given in
     *        {@link TensorType#fromSpec}
     * @param tensorString the tensor on the standard tensor string format
     * @return the tensor parsed from the string
     */
    static Tensor from(String tensorType, String tensorString) {
        // Parse the type spec, then take the same path as when the type is given as an object
        return from(TensorType.fromSpec(tensorType), tensorString);
    }

    /**
     * Returns a tensor instance containing the given data on the standard string format returned by toString.
     * If a type is not specified it is derived from the first cell of the tensor
     *
     * @param tensorString the tensor on the standard tensor string format
     * @return the tensor parsed from the string
     */
    static Tensor from(String tensorString) {
        Optional<TensorType> noExplicitType = Optional.empty();
        return TensorParser.tensorFrom(tensorString, noExplicitType);
    }
    
    /** A builder of tensors of a given type: Cells are added one by one before the tensor is built */
    interface Builder {

        /**
         * Creates a suitable builder for the given type.
         *
         * @param type the type of the tensor to build
         * @return a mapped tensor builder if the type has dimensions which are not indexed, and
         *         an indexed tensor builder if all dimensions are indexed or the type has no dimensions
         * @throws IllegalArgumentException if the type has both indexed dimensions and dimensions
         *         which are not indexed
         */
        static Builder of(TensorType type) {
            boolean hasMapped = type.dimensions().stream().anyMatch(dimension -> ! dimension.isIndexed());
            if ( ! hasMapped) { // indexed or empty
                return IndexedTensor.Builder.of(type);
            }

            boolean hasIndexed = type.dimensions().stream().anyMatch(dimension -> dimension.isIndexed());
            if (hasIndexed) {
                throw new IllegalArgumentException("Combining indexed and mapped dimensions is not supported yet");
            }
            return MappedTensor.Builder.of(type);
        }

        /** Returns the type this is building */
        TensorType type();

        /** Returns a cell builder, which is used to add a cell by specifying its labels and then its value */
        CellBuilder cell();

        /** Adds a cell with the given address and value, and returns a builder to continue building on */
        Builder cell(TensorAddress address, double value);

        /**
         * Adds a cell with the given value at the address given by the integer labels
         * (presumably one label per dimension of the type - confirm against implementations).
         */
        Builder cell(double value, int ... labels);

        /** Returns the tensor built from the cells added to this */
        Tensor build();

        /** Collects the labels of a single cell and adds the cell to a tensor builder when given its value */
        class CellBuilder {

            private final TensorAddress.Builder address;
            private final Tensor.Builder owner;

            CellBuilder(TensorType type, Tensor.Builder tensorBuilder) {
                this.address = new TensorAddress.Builder(type);
                this.owner = tensorBuilder;
            }

            /** Adds the label of this cell in the given dimension, and returns this for chaining */
            public CellBuilder label(String dimension, String label) {
                address.add(dimension, label);
                return this;
            }

            /** Adds an integer label of this cell in the given dimension, as its string form */
            public CellBuilder label(String dimension, int label) {
                String labelString = String.valueOf(label);
                return label(dimension, labelString);
            }

            /**
             * Adds this cell, with the labels given so far and the given value, to the tensor builder.
             *
             * @return the builder returned by the tensor builder when adding the cell
             */
            public Builder value(double cellValue) {
                TensorAddress cellAddress = address.build();
                return owner.cell(cellAddress, cellValue);
            }

        }

    }
    
}