aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
blob: 89847228b3137a213f2ec71e97b136d5b3bfe3f5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;

import com.google.common.annotations.Beta;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.UnaryOperator;

/**
 * A multidimensional array which can be used in computations.
 * <p>
 * A tensor consists of a set of <i>dimension</i> names and a set of <i>cells</i> containing scalar <i>values</i>.
 * Each cell is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines
 * the location of that cell. Both dimensions and labels are strings on the form of an identifier or integer.
 * Any dimension in an address may be assigned the special label "undefined", represented in string form as "-".
 * <p>
 * The size of the set of dimensions of a tensor is called its <i>order</i>.
 * <p>
 * In contrast to regular mathematical formulations of tensors, this definition of a tensor allows <i>sparseness</i>
 * as there is no built-in notion of a contiguous space, and even in cases where a space is implied (such as when
 * address labels are integers), there is no requirement that every implied cell has a defined value.
 * Undefined values have no defined representation as they are never observed.
 * <p>
 * Tensors can be read and serialized to and from a string form documented in the {@link #toString} method.
 *
 * @author bratseth
 */
@Beta
public interface Tensor {

    /**
     * Returns the immutable set of dimensions of this tensor.
     * The size of this set is the tensor's <i>order</i>.
     *
     * @return the dimension names of this tensor
     */
    Set<String> dimensions();

    /**
     * Returns an immutable map of the cells of this tensor.
     *
     * @return the cells of this, as a map from cell address to cell value
     */
    Map<TensorAddress, Double> cells();

    /**
     * Returns the value of a cell, or NaN if the cell does not exist or has no value.
     *
     * @param address the address of the cell to look up
     * @return the value of the cell at the given address, or NaN
     */
    double get(TensorAddress address);

    /**
     * Returns the <i>sparse tensor product</i> of this tensor and the argument tensor.
     * This is the all-to-all combinations of cells in the two tensors, except the combinations
     * which have conflicting labels for the same dimension. The value of each combination is the product
     * of the values of the two input cells. The dimensions of the tensor product are the set union of the
     * dimensions of the two tensors.
     * <p>
     * If there are no overlapping dimensions this is the regular tensor product.
     * If the two tensors have exactly the same dimensions this is the Hadamard product.
     * <p>
     * The sparse tensor product is associative and commutative.
     *
     * @param argument the tensor to multiply by this
     * @return the resulting tensor.
     */
    default Tensor multiply(Tensor argument) {
        return new TensorProduct(this, argument).result();
    }

    /**
     * Returns the <i>match product</i> of two tensors.
     * This returns a tensor which contains the <i>matching</i> cells in the two tensors, with their
     * values multiplied.
     * <p>
     * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
     * and have the value undefined for any non-shared dimension.
     * <p>
     * The dimensions of the resulting tensor are the set intersection of the dimensions of the two argument tensors.
     * <p>
     * If the two tensors have exactly the same dimensions, this is the Hadamard product.
     *
     * @param argument the tensor to match against this
     * @return the resulting tensor
     */
    default Tensor match(Tensor argument) {
        return new MatchProduct(this, argument).result();
    }

    /**
     * Returns a tensor which contains the cells of both argument tensors, where the value for
     * any <i>matching</i> cell is the min of the two possible values.
     * <p>
     * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
     * and have the value undefined for any non-shared dimension.
     *
     * @param argument the tensor to combine with this
     * @return the resulting tensor
     */
    default Tensor min(Tensor argument) {
        return new TensorMin(this, argument).result();
    }

    /**
     * Returns a tensor which contains the cells of both argument tensors, where the value for
     * any <i>matching</i> cell is the max of the two possible values.
     * <p>
     * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
     * and have the value undefined for any non-shared dimension.
     *
     * @param argument the tensor to combine with this
     * @return the resulting tensor
     */
    default Tensor max(Tensor argument) {
        return new TensorMax(this, argument).result();
    }

    /**
     * Returns a tensor which contains the cells of both argument tensors, where the value for
     * any <i>matching</i> cell is the sum of the two possible values.
     * <p>
     * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
     * and have the value undefined for any non-shared dimension.
     *
     * @param argument the tensor to add to this
     * @return the resulting tensor
     */
    default Tensor add(Tensor argument) {
        return new TensorSum(this, argument).result();
    }

    /**
     * Returns a tensor which contains the cells of both argument tensors, where the value for
     * any <i>matching</i> cell is the difference of the two possible values.
     * <p>
     * Two cells are matching if they have the same labels for all dimensions shared between the two argument tensors,
     * and have the value undefined for any non-shared dimension.
     *
     * @param argument the tensor to subtract from this
     * @return the resulting tensor
     */
    default Tensor subtract(Tensor argument) {
        return new TensorDifference(this, argument).result();
    }

    /**
     * Returns a tensor with the same cells as this and the given function is applied to all its cell values.
     *
     * @param function the function to apply to all cells
     * @return the tensor with the function applied to all the cells of this
     */
    default Tensor apply(UnaryOperator<Double> function) {
        return new TensorFunction(this, function).result();
    }

    /**
     * Returns a tensor with the given dimension removed and cells which contain the sum of the values
     * in the removed dimension.
     *
     * @param dimension the name of the dimension to sum over
     * @return the resulting tensor
     */
    default Tensor sum(String dimension) {
        return new TensorDimensionSum(dimension, this).result();
    }

    /**
     * Returns the sum of all the cells of this tensor.
     *
     * @return the sum of the values of all cells, 0 if this has no cells
     */
    default double sum() {
        double sum = 0;
        // Only the values matter here, so there is no need to walk the entries
        for (double value : cells().values())
            sum += value;
        return sum;
    }

    /**
     * Returns true if the given tensor is mathematically equal to this:
     * Both are of type Tensor and have the same content.
     */
    @Override
    boolean equals(Object o);

    /**
     * Returns true if the two given tensors are mathematically equivalent, that is whether both have the same
     * content: The same dimensions and the same cells.
     * <p>
     * Two null arguments are considered equal, while null is never equal to a tensor.
     *
     * @param a the first tensor to compare, may be null
     * @param b the second tensor to compare, may be null
     * @return true if both arguments are the same instance (or both null), or have equal dimensions and cells
     */
    static boolean equals(Tensor a, Tensor b) {
        if (a == b) return true;
        if (a == null || b == null) return false; // exactly one is null: not equal, and must not be dereferenced
        return a.dimensions().equals(b.dimensions()) && a.cells().equals(b.cells());
    }

    /**
     * Returns this tensor on the form
     * <code>{address1:value1,address2:value2,...}</code>
     * where each address is on the form <code>{dimension1:label1,dimension2:label2,...}</code>,
     * and values are numbers.
     * <p>
     * Cells are listed in the natural order of tensor addresses: Increasing size primarily
     * and by element lexical order secondarily.
     * <p>
     * Note that while this is suggestive of JSON, it is not JSON.
     */
    @Override
    String toString();

    /**
     * Returns a tensor instance containing the given data on the standard string format returned by toString
     *
     * @param tensorString the tensor on the standard tensor string format
     * @return the tensor created by {@link MapTensor#from} from the given string
     */
    static Tensor from(String tensorString) {
        return MapTensor.from(tensorString);
    }

    /**
     * Returns a tensor instance containing the given data on the standard string format returned by toString
     *
     * @param tensorType the type of the tensor to return, as a string on the tensor type format, given in
     *        {@link TensorType#fromSpec}
     * @param tensorString the tensor on the standard tensor string format
     * @return the tensor created by {@link MapTensor#from} from the given string
     */
    static Tensor from(String tensorType, String tensorString) {
        TensorType.fromSpec(tensorType); // Just validate type spec for now, as we only have one, generic implementation
        return MapTensor.from(tensorString);
    }

    /**
     * Call this from toString in implementations to return the standard string format.
     * (toString cannot be a default method because default methods cannot override super methods).
     * <p>
     * If the tensor has dimensions which are not present in any cell address, these are listed explicitly
     * by returning the string on the form <code>( unitTensor * content )</code>, where the unit tensor
     * carries the empty dimensions.
     *
     * @param tensor the tensor to return the standard string format of
     * @return the tensor on the standard string format
     */
    static String toStandardString(Tensor tensor) {
        Set<String> emptyDimensions = emptyDimensions(tensor);
        if (emptyDimensions.isEmpty())
            return contentToString(tensor);
        // explicitly list empty dimensions
        return "( " + unitTensorWithDimensions(emptyDimensions) + " * " + contentToString(tensor) + " )";
    }

    /**
     * Returns the cells of the given tensor on the form <code>{address1:value1,address2:value2,...}</code>,
     * with the cells sorted by the natural order of their addresses. A tensor with no cells becomes "{}".
     * Dimensions which are not present in any cell are not included in this string.
     *
     * @param tensor the tensor to return the cell content string of
     * @return the string form of the cells of the given tensor
     */
    static String contentToString(Tensor tensor) {
        // Copy the entries so they can be sorted without touching the (immutable) cell map
        List<Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
        cellEntries.sort(Map.Entry.<TensorAddress, Double>comparingByKey());

        StringBuilder b = new StringBuilder("{");
        String separator = ""; // nothing before the first cell, a comma before all the others
        for (Map.Entry<TensorAddress, Double> cell : cellEntries) {
            b.append(separator).append(cell.getKey()).append(":").append(cell.getValue());
            separator = ",";
        }
        return b.append("}").toString();
    }

    /**
     * Returns the dimensions of the given tensor which have no values, that is the dimensions which
     * are not among the dimensions of any cell address.
     * This is a possibly empty subset of the dimensions of the tensor.
     *
     * @param tensor the tensor to find the empty dimensions of
     * @return a new, mutable set containing the empty dimensions
     */
    static Set<String> emptyDimensions(Tensor tensor) {
        Set<String> emptyDimensions = new HashSet<>(tensor.dimensions());
        for (TensorAddress address : tensor.cells().keySet()) {
            if (emptyDimensions.isEmpty()) break; // every dimension is accounted for: no need to scan further
            emptyDimensions.removeAll(address.dimensions());
        }
        return emptyDimensions;
    }

    /**
     * Returns the string form of a tensor having a single cell with value 1.0, at the address returned by
     * {@link TensorAddress#emptyWithDimensions} for the given dimensions.
     *
     * @param dimensions the dimensions of the unit tensor
     * @return the string form of the unit tensor
     */
    static String unitTensorWithDimensions(Set<String> dimensions) {
        return new MapTensor(Collections.singletonMap(TensorAddress.emptyWithDimensions(dimensions), 1.0)).toString();
    }

}