summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
blob: da643d8c173acde2c6f222e3db66a675e2bac760 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;

import com.yahoo.tensor.impl.Label;

/**
 * An address to a subset of a tensor's cells, specifying a label for some but not necessarily all of the tensor's
 * dimensions.
 *
 * @author bratseth
 */
// Implementation notes:
// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation.
//   We also avoid non-essential error checking.
// - We can add support for string labels later without breaking the API
public class PartialAddress {

    // Two arrays which contain corresponding dimension:label pairs.
    // The sizes of these are always equal.
    private final String[] dimensionNames;
    private final long[] labels;

    private PartialAddress(Builder builder) {
        this.dimensionNames = builder.dimensionNames;
        this.labels = builder.labels;
        builder.dimensionNames = null; // invalidate builder to safely take over array ownership
        builder.labels = null;
    }

    /**
     * Returns the name of the dimension at position i.
     * No explicit bounds check is done here: an out-of-range i fails with the array's own
     * ArrayIndexOutOfBoundsException.
     */
    public String dimension(int i) {
        return dimensionNames[i];
    }

    /**
     * Returns the numeric label of this dimension (the first match if the name occurs more than once),
     * or {@link Tensor#INVALID_INDEX} if no label is specified for it.
     */
    public long numericLabel(String dimensionName) {
        for (int i = 0; i < dimensionNames.length; i++)
            if (dimensionNames[i].equals(dimensionName))
                return labels[i];
        return Tensor.INVALID_INDEX;
    }

    /**
     * Returns the label of this dimension (the first match if the name occurs more than once),
     * converted from its numeric form by {@link Label#fromNumber}, or null if no label is specified for it.
     */
    public String label(String dimensionName) {
        for (int i = 0; i < dimensionNames.length; i++)
            if (dimensionNames[i].equals(dimensionName))
                return Label.fromNumber(labels[i]);
        return null;
    }

    /**
     * Returns the label at position i
     *
     * @throws IllegalArgumentException if i is out of bounds, i.e. negative or not less than {@link #size()}
     */
    public String label(int i) {
        // Check both bounds so negative positions also produce the documented exception
        // rather than leaking an ArrayIndexOutOfBoundsException.
        if (i < 0 || i >= size())
            throw new IllegalArgumentException("No label at position " + i + " in " + this);
        return Label.fromNumber(labels[i]);
    }

    /** Returns the number of dimension:label pairs in this address */
    public int size() { return dimensionNames.length; }

    /**
     * Returns this as an address in the given tensor type
     *
     * @throws IllegalArgumentException if the type's rank differs from the size of this, or if some dimension
     *                                  of the type has no label in this
     */
    // We need the type here not just for validation but because this must map to the dimension order given by the type
    public TensorAddress asAddress(TensorType type) {
        if (type.rank() != size())
            throw new IllegalArgumentException(type + " has a different rank than " + this);
        long[] numericLabels = new long[labels.length];
        // Index i follows the type's dimension order, not the order pairs were added in this address
        for (int i = 0; i < type.dimensions().size(); i++) {
            long label = numericLabel(type.dimensions().get(i).name());
            if (label == Tensor.INVALID_INDEX)
                throw new IllegalArgumentException(type + " dimension names does not match " + this);
            numericLabels[i] = label;
        }
        return TensorAddress.of(numericLabels);
    }

    /** Returns this on the form "Partial address {dimension:label, ...}" */
    @Override
    public String toString() {
        StringBuilder b = new StringBuilder("Partial address {");
        for (int i = 0; i < dimensionNames.length; i++)
            b.append(dimensionNames[i]).append(":").append(label(i)).append(", ");
        if (size() > 0)
            b.setLength(b.length() - 2); // drop the trailing ", "
        b.append("}"); // close the brace opened above
        return b.toString();
    }

    /**
     * Builder of partial addresses.
     * Create it with the number of dimension:label pairs, add that many pairs, then call {@link #build()}.
     * Building hands the builder's arrays over to the new address and invalidates the builder, so a builder
     * must not be used after build() has been called.
     */
    public static class Builder {

        private String[] dimensionNames;
        private long[] labels;
        private int index = 0; // position the next added pair is written to

        /**
         * Creates a builder for the given number of dimension:label pairs.
         * NOTE(review): callers presumably add exactly this many pairs; positions never added are left as
         * null names and 0 labels in the built address.
         */
        public Builder(int size) {
            dimensionNames = new String[size];
            labels = new long[size];
        }

        /**
         * Adds a dimension with a numeric label.
         * Adding more pairs than the size given at construction throws ArrayIndexOutOfBoundsException.
         */
        public Builder add(String dimensionName, long label) {
            dimensionNames[index] = dimensionName;
            labels[index] = label;
            index++;
            return this;
        }

        /**
         * Adds a dimension with a string label, stored in numeric form as given by {@link Label#toNumber}.
         * Adding more pairs than the size given at construction throws ArrayIndexOutOfBoundsException.
         */
        public Builder add(String dimensionName, String label) {
            dimensionNames[index] = dimensionName;
            labels[index] = Label.toNumber(label);
            index++;
            return this;
        }

        /** Returns the address built from the pairs added so far, taking over (and invalidating) this builder's state */
        public PartialAddress build() {
            return new PartialAddress(this);
        }

    }

}