summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
blob: 59a5e2a49b15ed04be51133d5c18d0a4626c6a09 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;

import com.yahoo.tensor.impl.Label;
import com.yahoo.tensor.impl.TensorAddressAny;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * An immutable address to a tensor cell. This simply supplies a value to each dimension
 * in a particular tensor type. By itself it is just a list of cell labels; its meaning depends on its accompanying type.
 *
 * @author bratseth
 */
public abstract class TensorAddress implements Comparable<TensorAddress> {

    /**
     * Creates an address from an array of string labels by delegating to TensorAddressAny.of.
     *
     * @param labels the labels of the address
     * @return the address created by TensorAddressAny for these labels
     */
    public static TensorAddress of(String[] labels) {
        return TensorAddressAny.of(labels);
    }

    /**
     * Creates an address from string labels given as varargs, by delegating to TensorAddressAny.of.
     *
     * @param labels the labels of the address
     * @return the address created by TensorAddressAny for these labels
     */
    public static TensorAddress ofLabels(String... labels) {
        return TensorAddressAny.of(labels);
    }

    /**
     * Creates an address from numeric labels given as longs, by delegating to TensorAddressAny.of.
     *
     * @param labels the numeric labels of the address
     * @return the address created by TensorAddressAny for these labels
     */
    public static TensorAddress of(long... labels) {
        return TensorAddressAny.of(labels);
    }

    /**
     * Creates an address from numeric labels given as ints, by delegating to TensorAddressAny.of.
     *
     * @param labels the numeric labels of the address
     * @return the address created by TensorAddressAny for these labels
     */
    public static TensorAddress of(int... labels) {
        return TensorAddressAny.of(labels);
    }

    /**
     * Returns the number of labels in this address.
     * Valid label indexes are 0 to size() - 1.
     */
    public abstract int size();

    /**
     * Returns the i'th label in this address as a string.
     *
     * @param i the index of the label to return
     * @return the label at index i
     * @throws IllegalArgumentException if there is no label at this index
     */
    public abstract String label(int i);

    /**
     * Returns the i'th label in this address as a long.
     * Prefer this if you know that this is a numeric address, but not otherwise.
     *
     * @param i the index of the label to return
     * @return the label at index i as a number
     * @throws IllegalArgumentException if there is no label at this index
     */
    public abstract long numericLabel(int i);

    /**
     * Returns an address having the given numeric label at the given label index.
     * NOTE(review): presumably the result equals this address except for that one label, and this
     * address is left unchanged (addresses are documented as immutable) - confirm against the implementations.
     *
     * @param labelIndex the index of the label to set
     * @param label the numeric label to use at that index
     */
    public abstract TensorAddress withLabel(int labelIndex, long label);

    /** Returns true if this address has no labels, i.e. if {@link #size()} is 0 */
    public final boolean isEmpty() { return size() == 0; }

    /**
     * Orders addresses by comparing their labels as strings, position by position, returning the
     * first non-zero label comparison. When all labels in the shared prefix are equal, the address
     * with fewer labels is ordered first, such that the ordering is well-defined also for addresses
     * of different sizes (no out-of-range label lookup, and a.compareTo(b) == -b.compareTo(a) in sign).
     *
     * @param other the address to compare this to
     * @return a negative number, zero, or a positive number as this is ordered before,
     *         the same as, or after the other address
     * @throws NullPointerException if other is null
     */
    @Override
    public int compareTo(TensorAddress other) {
        int thisSize = size();
        int otherSize = other.size();
        int sharedSize = Math.min(thisSize, otherSize);
        for (int i = 0; i < sharedSize; i++) {
            int elementComparison = this.label(i).compareTo(other.label(i));
            if (elementComparison != 0) return elementComparison;
        }
        // All shared positions are equal: one address is a prefix of the other, or they are the same
        return Integer.compare(thisSize, otherSize);
    }

    /** Returns a string on the form "cell address (label0,label1,...)", listing all labels in index order */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("cell address (");
        int count = size();
        for (int i = 0; i < count; i++) {
            if (i > 0) text.append(','); // separator goes between labels only
            text.append(label(i));
        }
        text.append(')');
        return text.toString();
    }

    /**
     * Returns this as a string on the form {name0:label0,name1:label1,...}, where the name at
     * each position is taken from the dimension at the same position in the given type, and each
     * label is quoted by {@link #labelToString} when necessary. An empty address becomes "{}".
     *
     * @param type the type supplying the dimension names of this address
     */
    public final String toString(TensorType type) {
        StringBuilder text = new StringBuilder("{");
        for (int i = 0; i < size(); i++) {
            if (i > 0) text.append(",");
            text.append(type.dimensions().get(i).name());
            text.append(":");
            text.append(labelToString(label(i)));
        }
        return text.append("}").toString();
    }

    /**
     * Returns a label as a string, quoted when necessary: Labels accepted by
     * TensorType.labelMatcher are returned as-is, labels containing a single quote are wrapped
     * in double quotes, and all other labels are wrapped in single quotes.
     * NOTE(review): a label containing both quote characters is wrapped in double quotes
     * without any escaping - confirm that consumers of this string handle that.
     *
     * @param label the label to convert
     * @return the label, wrapped in quotes if it does not match TensorType.labelMatcher
     */
    public static String labelToString(String label) {
        if (TensorType.labelMatcher.matches(label))
            return label; // no quoting
        char quote = label.contains("'") ? '"' : '\'';
        return quote + label + quote;
    }

    /**
     * Returns an address with only some of the labels of this: The label at position i of the
     * result is the numeric label of this at position indexMap[i].
     * Labels are narrowed from long to int when copied.
     *
     * @param indexMap for each position in the result, the index in this to take the label from
     * @return an address with indexMap.length labels
     */
    public TensorAddress partialCopy(int[] indexMap) {
        int[] selected = new int[indexMap.length];
        int target = 0;
        for (int sourceIndex : indexMap)
            selected[target++] = (int) numericLabel(sourceIndex);
        return TensorAddressAny.ofUnsafe(selected);
    }

    /**
     * Creates a complete address by taking the labels of the sparse (non-indexed) dimensions from
     * this and the labels of the indexed dimensions from densePart. Both sources are consumed in
     * order: the n'th non-indexed dimension gets the n'th label of this (narrowed from long to int),
     * and the n'th indexed dimension gets densePart[n].
     *
     * @param dimensions the dimensions of the complete address, in order
     * @param densePart the labels of the indexed dimensions, in order
     * @return an address with one label per given dimension
     */
    public TensorAddress fullAddressOf(List<TensorType.Dimension> dimensions, int [] densePart) {
        int[] combined = new int[dimensions.size()];
        int nextSparse = 0;
        int nextDense = 0;
        for (int i = 0; i < combined.length; i++) {
            boolean indexed = dimensions.get(i).isIndexed();
            combined[i] = indexed ? densePart[nextDense++]
                                  : (int) numericLabel(nextSparse++);
        }
        return TensorAddressAny.ofUnsafe(combined);
    }

    /**
     * Extracts the sparse (non-indexed) dimensions of this address: For each non-indexed dimension
     * at position i in the given dimensions, the label of this at position i (narrowed from long to int)
     * is added under that dimension's name to an address built for sparseType.
     *
     * @param sparseType the type of the address to build
     * @param dimensions the dimensions this address has labels for, in order
     * @return the address built for sparseType
     * @throws IllegalArgumentException if dimensions does not have the same size as this, or if the
     *         builder rejects a dimension or finds a dimension of sparseType without a label
     */
    public  TensorAddress sparsePartialAddress(TensorType sparseType, List<TensorType.Dimension> dimensions) {
        if (dimensions.size() != size())
            throw new IllegalArgumentException("Tensor type with dimensions " + dimensions +
                                               " is not the same size as " + this);
        TensorAddress.Builder builder = new TensorAddress.Builder(sparseType);
        for (int i = 0; i < dimensions.size(); ++i) {
            TensorType.Dimension dimension = dimensions.get(i);
            if ( ! dimension.isIndexed())
                builder.add(dimension.name(), (int)numericLabel(i));
        }
        return builder.build();
    }

    /**
     * Builder of a tensor address for a given tensor type. Holds one numeric label slot per dimension
     * of the type; slots start out as Tensor.INVALID_INDEX and are set by the add methods.
     */
    public static class Builder {

        final TensorType type;
        final int[] labels;

        /** Returns a new array of the given size where every slot is marked as not set */
        private static int [] createEmptyLabels(int size) {
            int[] unset = new int[size];
            Arrays.fill(unset, Tensor.INVALID_INDEX);
            return unset;
        }

        /** Creates a builder for the given type where no labels are set */
        public Builder(TensorType type) {
            this(type, createEmptyLabels(type.dimensions().size()));
        }

        /** Creates a builder using the given label array as-is (it is not copied) */
        private Builder(TensorType type, int[] labels) {
            this.type = type;
            this.labels = labels;
        }

        /**
         * Adds the label to the only mapped dimension of this.
         *
         * @return this for convenience
         * @throws IllegalArgumentException if the type of this does not have exactly one mapped dimension
         */
        public Builder add(String label) {
            var sparsePart = type.mappedSubtype();
            if (sparsePart.rank() != 1)
                throw new IllegalArgumentException("Cannot add a label without explicit dimension to a tensor of type " +
                                                   type + ": Must have exactly one sparse dimension");
            String onlyDimension = sparsePart.dimensions().get(0).name();
            add(onlyDimension, label);
            return this;
        }

        /**
         * Adds a label in a dimension to this. The label is converted to a number by Label.toNumber.
         *
         * @return this for convenience
         * @throws NullPointerException if the dimension or label is null
         * @throws IllegalArgumentException if the type of this does not contain the dimension
         */
        public Builder add(String dimension, String label) {
            Objects.requireNonNull(dimension, "dimension cannot be null");
            Objects.requireNonNull(label, "label cannot be null");
            int slot = requireIndexOf(dimension);
            labels[slot] = Label.toNumber(label);
            return this;
        }

        /**
         * Adds a numeric label in a dimension to this.
         *
         * @return this for convenience
         * @throws NullPointerException if the dimension is null
         * @throws IllegalArgumentException if the type of this does not contain the dimension
         */
        public Builder add(String dimension, int label) {
            Objects.requireNonNull(dimension, "dimension cannot be null");
            int slot = requireIndexOf(dimension);
            labels[slot] = label;
            return this;
        }

        /** Returns the index of the given dimension in the type of this, or throws if it is not present */
        private int requireIndexOf(String dimension) {
            int index = type.indexOfDimensionAsInt(dimension);
            if (index >= 0) return index;
            throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
        }

        /** Creates a copy of this which can be modified separately */
        public Builder copy() {
            return new Builder(type, labels.clone());
        }

        /** Returns the type of the tensor this address is being built for. */
        public TensorType type() { return type; }

        /**
         * Verifies that every dimension has been given a label.
         *
         * @throws IllegalArgumentException naming the first dimension whose label is still Tensor.INVALID_INDEX
         */
        void validate() {
            int position = 0;
            for (int label : labels) {
                if (label == Tensor.INVALID_INDEX)
                    throw new IllegalArgumentException("Missing a label for dimension '" +
                                                       type.dimensions().get(position).name() + "' for " + type);
                position++;
            }
        }

        /**
         * Validates this and returns the address holding the labels added so far.
         * NOTE(review): the internal label array is handed to TensorAddressAny.ofUnsafe without copying;
         * if that method retains the array, later add calls on this builder would alter the
         * returned address - confirm against TensorAddressAny.
         *
         * @throws IllegalArgumentException if validation fails
         */
        public TensorAddress build() {
            validate();
            return TensorAddressAny.ofUnsafe(labels);
        }

    }

    /**
     * Builder of an address to a subset of the dimensions of a tensor type.
     * Unlike Builder, this does not require that all dimensions have a label when building.
     */
    public static class PartialBuilder extends Builder {

        /** Creates a partial builder for the given type where no labels are set */
        public PartialBuilder(TensorType type) {
            super(type);
        }

        /** Creates a partial builder using the given label array as-is */
        private PartialBuilder(TensorType type, int[] labels) {
            super(type, labels);
        }

        /** Creates a copy of this which can be modified separately */
        @Override
        public Builder copy() {
            return new PartialBuilder(type, Arrays.copyOf(labels, labels.length));
        }

        /** Does nothing: A partial address is allowed to lack labels for some dimensions */
        @Override
        void validate() { }

    }

}