diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java | 98 |
1 files changed, 91 insertions, 7 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index 9c41d5aad68..4eca9c47402 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import java.util.Arrays; + /** * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors * dimensions. @@ -13,10 +15,10 @@ package com.yahoo.tensor; // - We can add support for string labels later without breaking the API public class PartialAddress { - // Two arrays which contains corresponding dimension=label pairs. + // Two arrays which contains corresponding dimension:label pairs. // The sizes of these are always equal. private final String[] dimensionNames; - private final long[] labels; + private final Object[] labels; private PartialAddress(Builder builder) { this.dimensionNames = builder.dimensionNames; @@ -25,23 +27,99 @@ public class PartialAddress { builder.labels = null; } - /** Returns the int label of this dimension, or -1 if no label is specified for it */ - long numericLabel(String dimensionName) { + public String dimension(int i) { + return dimensionNames[i]; + } + + /** Returns the numeric label of this dimension, or -1 if no label is specified for it */ + public long numericLabel(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return labels[i]; + return asLong(labels[i]); return -1; } + /** Returns the label of this dimension, or null if no label is specified for it */ + public String label(String dimensionName) { + for (int i = 0; i < dimensionNames.length; i++) + if (dimensionNames[i].equals(dimensionName)) + return labels[i].toString(); + return null; + } + + /** + * Returns the label at position i + * + * @throws IllegalArgumentException if i is out of bounds + */ + public String label(int i) { + if (i >= size()) + throw new IllegalArgumentException("No label at position " + i + " in " + this); + return labels[i].toString(); + } + + public int size() { return dimensionNames.length; } + + /** Returns this as an address in the given tensor type */ + // We need the type here not just for validation but because this must map to the dimension order given by the type + public TensorAddress asAddress(TensorType type) { + if (type.rank() != size()) + throw new IllegalArgumentException(type + " has a different rank than " + this); + if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) { + long[] numericLabels = new long[labels.length]; + for (int i = 0; i < type.dimensions().size(); i++) { + long label = numericLabel(type.dimensions().get(i).name()); + if (label < 0) + throw new IllegalArgumentException(type + " dimension names does not match " + this); + numericLabels[i] = label; + } + return TensorAddress.of(numericLabels); + } + else { + String[] stringLabels = new String[labels.length]; + for (int i = 0; i < type.dimensions().size(); i++) { + String label = label(type.dimensions().get(i).name()); + if (label == null) + throw new IllegalArgumentException(type + " dimension names does not match " + this); + stringLabels[i] = label; + } + return TensorAddress.of(stringLabels); + } + } + + private long asLong(Object label) { + if (label instanceof Long) { + return (Long) label; + } + else { + try { + return Long.parseLong(label.toString()); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Label '" + label + "' is not numeric"); + } + } + } + + @Override + public String toString() { + StringBuilder b = new StringBuilder("Partial address {"); + for (int i = 0; i < dimensionNames.length; i++) + b.append(dimensionNames[i]).append(":").append(label(i)).append(", "); + if (size() > 0) + b.setLength(b.length() - 2); + return b.toString(); + } + public static class Builder { private String[] dimensionNames; - private long[] labels; + private Object[] labels; private int index = 0; public Builder(int size) { dimensionNames = new String[size]; - labels = new long[size]; + labels = new Object[size]; } public void add(String dimensionName, long label) { @@ -50,6 +128,12 @@ public class PartialAddress { index++; } + public void add(String dimensionName, String label) { + dimensionNames[index] = dimensionName; + labels[index] = label; + index++; + } + public PartialAddress build() { return new PartialAddress(this); } |