summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java98
1 files changed, 91 insertions, 7 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
index 9c41d5aad68..4eca9c47402 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import java.util.Arrays;
+
/**
* An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors
* dimensions.
@@ -13,10 +15,10 @@ package com.yahoo.tensor;
// - We can add support for string labels later without breaking the API
public class PartialAddress {
- // Two arrays which contains corresponding dimension=label pairs.
+ // Two arrays which contains corresponding dimension:label pairs.
// The sizes of these are always equal.
private final String[] dimensionNames;
- private final long[] labels;
+ private final Object[] labels;
private PartialAddress(Builder builder) {
this.dimensionNames = builder.dimensionNames;
@@ -25,23 +27,99 @@ public class PartialAddress {
builder.labels = null;
}
- /** Returns the int label of this dimension, or -1 if no label is specified for it */
- long numericLabel(String dimensionName) {
+ public String dimension(int i) {
+ return dimensionNames[i];
+ }
+
+ /** Returns the numeric label of this dimension, or -1 if no label is specified for it */
+ public long numericLabel(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
- return labels[i];
+ return asLong(labels[i]);
return -1;
}
+ /** Returns the label of this dimension, or null if no label is specified for it */
+ public String label(String dimensionName) {
+ for (int i = 0; i < dimensionNames.length; i++)
+ if (dimensionNames[i].equals(dimensionName))
+ return labels[i].toString();
+ return null;
+ }
+
+ /**
+ * Returns the label at position i
+ *
+ * @throws IllegalArgumentException if i is out of bounds
+ */
+ public String label(int i) {
+ if (i >= size())
+ throw new IllegalArgumentException("No label at position " + i + " in " + this);
+ return labels[i].toString();
+ }
+
+ public int size() { return dimensionNames.length; }
+
+ /** Returns this as an address in the given tensor type */
+ // We need the type here not just for validation but because this must map to the dimension order given by the type
+ public TensorAddress asAddress(TensorType type) {
+ if (type.rank() != size())
+ throw new IllegalArgumentException(type + " has a different rank than " + this);
+ if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) {
+ long[] numericLabels = new long[labels.length];
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ long label = numericLabel(type.dimensions().get(i).name());
+ if (label < 0)
+ throw new IllegalArgumentException(type + " dimension names does not match " + this);
+ numericLabels[i] = label;
+ }
+ return TensorAddress.of(numericLabels);
+ }
+ else {
+ String[] stringLabels = new String[labels.length];
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ String label = label(type.dimensions().get(i).name());
+ if (label == null)
+ throw new IllegalArgumentException(type + " dimension names does not match " + this);
+ stringLabels[i] = label;
+ }
+ return TensorAddress.of(stringLabels);
+ }
+ }
+
+ private long asLong(Object label) {
+ if (label instanceof Long) {
+ return (Long) label;
+ }
+ else {
+ try {
+ return Long.parseLong(label.toString());
+ }
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("Label '" + label + "' is not numeric");
+ }
+ }
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder b = new StringBuilder("Partial address {");
+ for (int i = 0; i < dimensionNames.length; i++)
+ b.append(dimensionNames[i]).append(":").append(label(i)).append(", ");
+ if (size() > 0)
+ b.setLength(b.length() - 2);
+ return b.toString();
+ }
+
public static class Builder {
private String[] dimensionNames;
- private long[] labels;
+ private Object[] labels;
private int index = 0;
public Builder(int size) {
dimensionNames = new String[size];
- labels = new long[size];
+ labels = new Object[size];
}
public void add(String dimensionName, long label) {
@@ -50,6 +128,12 @@ public class PartialAddress {
index++;
}
+ public void add(String dimensionName, String label) {
+ dimensionNames[index] = dimensionName;
+ labels[index] = label;
+ index++;
+ }
+
public PartialAddress build() {
return new PartialAddress(this);
}