diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java | 60 |
1 files changed, 16 insertions, 44 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index 3e41e6d94eb..da643d8c173 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -1,9 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; -import com.yahoo.tensor.impl.StringTensorAddress; - -import java.util.Arrays; +import com.yahoo.tensor.impl.Label; /** * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors @@ -20,7 +18,7 @@ public class PartialAddress { // Two arrays which contains corresponding dimension:label pairs. // The sizes of these are always equal. private final String[] dimensionNames; - private final Object[] labels; + private final long[] labels; private PartialAddress(Builder builder) { this.dimensionNames = builder.dimensionNames; @@ -37,15 +35,15 @@ public class PartialAddress { public long numericLabel(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return asLong(labels[i]); - return -1; + return labels[i]; + return Tensor.INVALID_INDEX; } /** Returns the label of this dimension, or null if no label is specified for it */ public String label(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return labels[i].toString(); + return Label.fromNumber(labels[i]); return null; } @@ -57,7 +55,7 @@ public class PartialAddress { public String label(int i) { if (i >= size()) throw new IllegalArgumentException("No label at position " + i + " in " + this); - return labels[i].toString(); + return Label.fromNumber(labels[i]); } public int size() { return dimensionNames.length; } @@ -67,40 +65,14 @@ public class PartialAddress { public TensorAddress asAddress(TensorType type) { if (type.rank() != size()) throw new IllegalArgumentException(type + " has a different rank than " + this); - if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) { - long[] numericLabels = new long[labels.length]; - for (int i = 0; i < type.dimensions().size(); i++) { - long label = numericLabel(type.dimensions().get(i).name()); - if (label < 0) - throw new IllegalArgumentException(type + " dimension names does not match " + this); - numericLabels[i] = label; - } - return TensorAddress.of(numericLabels); - } - else { - String[] stringLabels = new String[labels.length]; - for (int i = 0; i < type.dimensions().size(); i++) { - String label = label(type.dimensions().get(i).name()); - if (label == null) - throw new IllegalArgumentException(type + " dimension names does not match " + this); - stringLabels[i] = label; - } - return StringTensorAddress.unsafeOf(stringLabels); - } - } - - private long asLong(Object label) { - if (label instanceof Long) { - return (Long) label; - } - else { - try { - return Long.parseLong(label.toString()); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Label '" + label + "' is not numeric"); - } + long[] numericLabels = new long[labels.length]; + for (int i = 0; i < type.dimensions().size(); i++) { + long label = numericLabel(type.dimensions().get(i).name()); + if (label == Tensor.INVALID_INDEX) + throw new IllegalArgumentException(type + " dimension names does not match " + this); + numericLabels[i] = label; } + return TensorAddress.of(numericLabels); } @Override @@ -116,12 +88,12 @@ public class PartialAddress { public static class Builder { private String[] dimensionNames; - private Object[] labels; + private long[] labels; private int index = 0; public Builder(int size) { dimensionNames = new String[size]; - labels = new Object[size]; + labels = new long[size]; } public Builder add(String dimensionName, long label) { @@ -133,7 +105,7 @@ public class PartialAddress { public Builder add(String dimensionName, String label) { dimensionNames[index] = dimensionName; - labels[index] = label; + labels[index] = Label.toNumber(label); index++; return this; } |