blob: da643d8c173acde2c6f222e3db66a675e2bac760 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
import com.yahoo.tensor.impl.Label;
/**
 * An address to a subset of a tensor's cells, specifying a label for some but not necessarily all of the tensor's
 * dimensions.
 *
 * @author bratseth
 */
// Implementation notes:
// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation.
//   We also avoid non-essential error checking.
// - We can add support for string labels later without breaking the API
public class PartialAddress {

    // Two arrays which contain corresponding dimension:label pairs.
    // The sizes of these are always equal.
    private final String[] dimensionNames;
    private final long[] labels;

    private PartialAddress(Builder builder) {
        // Take over the builder's arrays without copying (minimal allocation), then invalidate the builder
        // so it cannot mutate arrays now owned by this instance.
        this.dimensionNames = builder.dimensionNames;
        this.labels = builder.labels;
        builder.dimensionNames = null;
        builder.labels = null;
    }

    /** Returns the name of the dimension at position i in this address (no bounds checking beyond the array's) */
    public String dimension(int i) {
        return dimensionNames[i];
    }

    /**
     * Returns the numeric label of this dimension, or {@link Tensor#INVALID_INDEX} if no label is specified for it.
     * If the dimension occurs more than once, the label of the first occurrence is returned.
     */
    public long numericLabel(String dimensionName) {
        for (int i = 0; i < dimensionNames.length; i++)
            if (dimensionNames[i].equals(dimensionName))
                return labels[i];
        return Tensor.INVALID_INDEX;
    }

    /** Returns the label of this dimension, or null if no label is specified for it */
    public String label(String dimensionName) {
        for (int i = 0; i < dimensionNames.length; i++)
            if (dimensionNames[i].equals(dimensionName))
                return Label.fromNumber(labels[i]);
        return null;
    }

    /**
     * Returns the label at position i
     *
     * @throws IllegalArgumentException if i is out of bounds (negative, or not less than {@link #size()})
     */
    public String label(int i) {
        if (i < 0 || i >= size())
            throw new IllegalArgumentException("No label at position " + i + " in " + this);
        return Label.fromNumber(labels[i]);
    }

    /** Returns the number of dimension:label pairs in this address */
    public int size() { return dimensionNames.length; }

    /**
     * Returns this as an address in the given tensor type.
     *
     * @throws IllegalArgumentException if the rank of the type differs from the size of this, or if some
     *         dimension of the type has no label in this
     */
    // We need the type here not just for validation but because this must map to the dimension order given by the type
    public TensorAddress asAddress(TensorType type) {
        if (type.rank() != size())
            throw new IllegalArgumentException(type + " has a different rank than " + this);
        long[] numericLabels = new long[labels.length];
        for (int i = 0; i < type.dimensions().size(); i++) {
            long label = numericLabel(type.dimensions().get(i).name());
            if (label == Tensor.INVALID_INDEX)
                throw new IllegalArgumentException(type + " dimension names does not match " + this);
            numericLabels[i] = label; // position i follows the type's dimension order, not this address's order
        }
        return TensorAddress.of(numericLabels);
    }

    /** Returns this on the form "Partial address {dim1:label1, dim2:label2}" */
    @Override
    public String toString() {
        StringBuilder b = new StringBuilder("Partial address {");
        for (int i = 0; i < dimensionNames.length; i++) {
            if (i > 0)
                b.append(", ");
            b.append(dimensionNames[i]).append(":").append(label(i));
        }
        return b.append("}").toString();
    }

    /**
     * Builder of a partial address with a fixed number of dimension:label pairs.
     * Each call to add fills the next position. Exactly {@code size} pairs should be added before {@link #build()};
     * unfilled positions are left with a null dimension name.
     * A builder can only build once: {@link #build()} hands its arrays over to the built address and
     * invalidates this builder.
     */
    public static class Builder {

        private String[] dimensionNames;
        private long[] labels;
        private int index = 0; // the next position to fill

        /** Creates a builder with room for exactly {@code size} dimension:label pairs */
        public Builder(int size) {
            dimensionNames = new String[size];
            labels = new long[size];
        }

        /** Adds a dimension with a numeric label at the next position */
        public Builder add(String dimensionName, long label) {
            dimensionNames[index] = dimensionName;
            labels[index] = label;
            index++;
            return this;
        }

        /** Adds a dimension with a string label (converted by Label.toNumber) at the next position */
        public Builder add(String dimensionName, String label) {
            dimensionNames[index] = dimensionName;
            labels[index] = Label.toNumber(label);
            index++;
            return this;
        }

        /** Returns the partial address built by this. This builder cannot be used after this is called. */
        public PartialAddress build() {
            return new PartialAddress(this);
        }

    }

}
|