aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java136
1 files changed, 136 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java
new file mode 100644
index 00000000000..31863c99a74
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java
@@ -0,0 +1,136 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+
+import static com.yahoo.tensor.impl.Convert.safe2Int;
+import static com.yahoo.tensor.impl.Label.toNumber;
+import static com.yahoo.tensor.impl.Label.fromNumber;
+
+/**
+ * Parent of tensor address family centered around each dimension as int.
+ * A positive number represents a numeric index usable as a direect addressing.
+ * - 1 is representing an invalid/null address
+ * Other negative numbers are an enumeration maintained in {@link Label}
+ *
+ * @author baldersheim
+ */
+abstract public class TensorAddressAny extends TensorAddress {
+ @Override
+ public String label(int i) {
+ return fromNumber((int)numericLabel(i));
+ }
+
+ public static TensorAddress of() {
+ return TensorAddressEmpty.empty;
+ }
+ public static TensorAddress of(String label) {
+ return new TensorAddressAny1(toNumber(label));
+ }
+ public static TensorAddress of(String label0, String label1) {
+ return new TensorAddressAny2(toNumber(label0), toNumber(label1));
+ }
+ public static TensorAddress of(String label0, String label1, String label2) {
+ return new TensorAddressAny3(toNumber(label0), toNumber(label1), toNumber(label2));
+ }
+ public static TensorAddress of(String label0, String label1, String label2, String label3) {
+ return new TensorAddressAny4(toNumber(label0), toNumber(label1), toNumber(label2), toNumber(label3));
+ }
+ public static TensorAddress of(String [] labels) {
+ int [] labelsAsInt = new int[labels.length];
+ for (int i = 0; i < labels.length; i++) {
+ labelsAsInt[i] = toNumber(labels[i]);
+ }
+ return ofUnsafe(labelsAsInt);
+ }
+ public static TensorAddress of(int label) {
+ return new TensorAddressAny1(sanitize(label));
+ }
+ public static TensorAddress of(int label0, int label1) {
+ return new TensorAddressAny2(sanitize(label0), sanitize(label1));
+ }
+ public static TensorAddress of(int label0, int label1, int label2) {
+ return new TensorAddressAny3(sanitize(label0), sanitize(label1), sanitize(label2));
+ }
+ public static TensorAddress of(int label0, int label1, int label2, int label3) {
+ return new TensorAddressAny4(sanitize(label0), sanitize(label1), sanitize(label2), sanitize(label3));
+ }
+ public static TensorAddress of(int ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> new TensorAddressAny1(sanitize(labels[0]));
+ case 2 -> new TensorAddressAny2(sanitize(labels[0]), sanitize(labels[1]));
+ case 3 -> new TensorAddressAny3(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]));
+ case 4 -> new TensorAddressAny4(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]), sanitize(labels[3]));
+ default -> {
+ for (int i = 0; i < labels.length; i++) {
+ sanitize(labels[i]);
+ }
+ yield new TensorAddressAnyN(labels);
+ }
+ };
+ }
+ public static TensorAddress of(long label) {
+ return of(safe2Int(label));
+ }
+
+ public static TensorAddress of(long label0, long label1) {
+ return of(safe2Int(label0), safe2Int(label1));
+ }
+
+ public static TensorAddress of(long label0, long label1, long label2) {
+ return of(safe2Int(label0), safe2Int(label1), safe2Int(label2));
+ }
+
+ public static TensorAddress of(long label0, long label1, long label2, long label3) {
+ return of(safe2Int(label0), safe2Int(label1), safe2Int(label2), safe2Int(label3));
+ }
+
+ public static TensorAddress of(long ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> ofUnsafe(safe2Int(labels[0]));
+ case 2 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]));
+ case 3 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]));
+ case 4 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]), safe2Int(labels[3]));
+ default -> {
+ int [] labelsAsInt = new int[labels.length];
+ for (int i = 0; i < labels.length; i++) {
+ labelsAsInt[i] = safe2Int(labels[i]);
+ }
+ yield of(labelsAsInt);
+ }
+ };
+ }
+
+ private static TensorAddress ofUnsafe(int label) {
+ return new TensorAddressAny1(label);
+ }
+ private static TensorAddress ofUnsafe(int label0, int label1) {
+ return new TensorAddressAny2(label0, label1);
+ }
+ private static TensorAddress ofUnsafe(int label0, int label1, int label2) {
+ return new TensorAddressAny3(label0, label1, label2);
+ }
+ private static TensorAddress ofUnsafe(int label0, int label1, int label2, int label3) {
+ return new TensorAddressAny4(label0, label1, label2, label3);
+ }
+ public static TensorAddress ofUnsafe(int ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> ofUnsafe(labels[0]);
+ case 2 -> ofUnsafe(labels[0], labels[1]);
+ case 3 -> ofUnsafe(labels[0], labels[1], labels[2]);
+ case 4 -> ofUnsafe(labels[0], labels[1], labels[2], labels[3]);
+ default -> new TensorAddressAnyN(labels);
+ };
+ }
+ private static int sanitize(int label) {
+ if (label < Tensor.INVALID_INDEX) {
+ throw new IndexOutOfBoundsException("cell label " + label + " must be positive");
+ }
+ return label;
+ }
+}