diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java | 77 |
1 files changed, 29 insertions, 48 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java index 2e70811a67c..9003d263699 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java @@ -5,7 +5,8 @@ package com.yahoo.tensor.impl; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; -import static com.yahoo.tensor.impl.Convert.safe2Int; +import java.util.Arrays; + import static com.yahoo.tensor.impl.Label.toNumber; import static com.yahoo.tensor.impl.Label.fromNumber; @@ -21,7 +22,7 @@ abstract public class TensorAddressAny extends TensorAddress { @Override public String label(int i) { - return fromNumber((int)numericLabel(i)); + return fromNumber(numericLabel(i)); } public static TensorAddress of() { @@ -45,11 +46,11 @@ abstract public class TensorAddressAny extends TensorAddress { } public static TensorAddress of(String[] labels) { - int[] labelsAsInt = new int[labels.length]; + long[] labelsAsLong = new long[labels.length]; for (int i = 0; i < labels.length; i++) { - labelsAsInt[i] = toNumber(labels[i]); + labelsAsLong[i] = toNumber(labels[i]); } - return ofUnsafe(labelsAsInt); + return ofUnsafe(labelsAsLong); } public static TensorAddress of(int label) { @@ -71,80 +72,60 @@ abstract public class TensorAddressAny extends TensorAddress { public static TensorAddress of(int ... labels) { return switch (labels.length) { case 0 -> of(); - case 1 -> new TensorAddressAny1(sanitize(labels[0])); - case 2 -> new TensorAddressAny2(sanitize(labels[0]), sanitize(labels[1])); - case 3 -> new TensorAddressAny3(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2])); - case 4 -> new TensorAddressAny4(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]), sanitize(labels[3])); + case 1 -> of(labels[0]); + case 2 -> of(labels[0], labels[1]); + case 3 -> of(labels[0], labels[1], labels[2]); + case 4 -> of(labels[0], labels[1], labels[2], labels[3]); default -> { + long[] copy = new long[labels.length]; for (int i = 0; i < labels.length; i++) { - sanitize(labels[i]); + copy[i] = sanitize(labels[i]); } - yield new TensorAddressAnyN(labels); + yield new TensorAddressAnyN(copy); } }; } - public static TensorAddress of(long label) { - return of(safe2Int(label)); - } - - public static TensorAddress of(long label0, long label1) { - return of(safe2Int(label0), safe2Int(label1)); - } - - public static TensorAddress of(long label0, long label1, long label2) { - return of(safe2Int(label0), safe2Int(label1), safe2Int(label2)); - } - - public static TensorAddress of(long label0, long label1, long label2, long label3) { - return of(safe2Int(label0), safe2Int(label1), safe2Int(label2), safe2Int(label3)); - } - public static TensorAddress of(long ... labels) { return switch (labels.length) { case 0 -> of(); - case 1 -> ofUnsafe(safe2Int(labels[0])); - case 2 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1])); - case 3 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2])); - case 4 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]), safe2Int(labels[3])); - default -> { - int[] labelsAsInt = new int[labels.length]; - for (int i = 0; i < labels.length; i++) { - labelsAsInt[i] = safe2Int(labels[i]); - } - yield of(labelsAsInt); - } + case 1 -> of(labels[0]); + case 2 -> of(labels[0], labels[1]); + case 3 -> of(labels[0], labels[1], labels[2]); + case 4 -> of(labels[0], labels[1], labels[2], labels[3]); + default -> new TensorAddressAnyN(Arrays.copyOf(labels, labels.length)); + }; } - private static TensorAddress ofUnsafe(int label) { + private static TensorAddress of(long label) { return new TensorAddressAny1(label); } - private static TensorAddress ofUnsafe(int label0, int label1) { + private static TensorAddress of(long label0, long label1) { return new TensorAddressAny2(label0, label1); } - private static TensorAddress ofUnsafe(int label0, int label1, int label2) { + public static TensorAddress of(long label0, long label1, long label2) { return new TensorAddressAny3(label0, label1, label2); } - private static TensorAddress ofUnsafe(int label0, int label1, int label2, int label3) { + public static TensorAddress of(long label0, long label1, long label2, long label3) { return new TensorAddressAny4(label0, label1, label2, label3); } - public static TensorAddress ofUnsafe(int ... labels) { + public static TensorAddress ofUnsafe(long ... labels) { return switch (labels.length) { case 0 -> of(); - case 1 -> ofUnsafe(labels[0]); - case 2 -> ofUnsafe(labels[0], labels[1]); - case 3 -> ofUnsafe(labels[0], labels[1], labels[2]); - case 4 -> ofUnsafe(labels[0], labels[1], labels[2], labels[3]); + case 1 -> of(labels[0]); + case 2 -> of(labels[0], labels[1]); + case 3 -> of(labels[0], labels[1], labels[2]); + case 4 -> of(labels[0], labels[1], labels[2], labels[3]); default -> new TensorAddressAnyN(labels); }; } - private static int sanitize(int label) { + private static long sanitize(long label) { if (label < Tensor.invalidIndex) { throw new IndexOutOfBoundsException("cell label " + label + " must be positive"); } |