summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java130
1 files changed, 88 insertions, 42 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 1b88a5d1b2f..59a5e2a49b1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -1,13 +1,11 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
-import com.yahoo.tensor.impl.NumericTensorAddress;
-import com.yahoo.tensor.impl.StringTensorAddress;
-import net.jpountz.xxhash.XXHash32;
-import net.jpountz.xxhash.XXHashFactory;
+import com.yahoo.tensor.impl.Label;
+import com.yahoo.tensor.impl.TensorAddressAny;
-import java.nio.charset.StandardCharsets;
import java.util.Arrays;
+import java.util.List;
import java.util.Objects;
/**
@@ -18,23 +16,25 @@ import java.util.Objects;
*/
public abstract class TensorAddress implements Comparable<TensorAddress> {
- private static final XXHash32 hasher = XXHashFactory.fastestJavaInstance().hash32();
-
public static TensorAddress of(String[] labels) {
- return StringTensorAddress.of(labels);
+ return TensorAddressAny.of(labels);
}
- public static TensorAddress ofLabels(String ... labels) {
- return StringTensorAddress.of(labels);
+ public static TensorAddress ofLabels(String... labels) {
+ return TensorAddressAny.of(labels);
}
- public static TensorAddress of(long ... labels) {
- return NumericTensorAddress.of(labels);
+ public static TensorAddress of(long... labels) {
+ return TensorAddressAny.of(labels);
}
- private int cached_hash = 0;
+ public static TensorAddress of(int... labels) {
+ return TensorAddressAny.of(labels);
+ }
- /** Returns the number of labels in this */
+ /**
+ * Returns the number of labels in this
+ */
public abstract int size();
/**
@@ -67,32 +67,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
}
@Override
- public int hashCode() {
- if (cached_hash != 0) return cached_hash;
-
- int hash = 0;
- for (int i = 0; i < size(); i++) {
- String label = label(i);
- if (label != null) {
- byte [] buf = label.getBytes(StandardCharsets.UTF_8);
- hash = hasher.hash(buf, 0, buf.length, hash);
+ public String toString() {
+ StringBuilder sb = new StringBuilder("cell address (");
+ int sz = size();
+ if (sz > 0) {
+ sb.append(label(0));
+ for (int i = 1; i < sz; i++) {
+ sb.append(',').append(label(i));
}
}
- return cached_hash = hash;
- }
- @Override
- public boolean equals(Object o) {
- if (o == this) return true;
- if ( ! (o instanceof TensorAddress other)) return false;
- if (other.size() != this.size()) return false;
- for (int i = 0; i < this.size(); i++)
- if ( ! Objects.equals(this.label(i), other.label(i)))
- return false;
- return true;
+ return sb.append(')').toString();
}
- /** Returns this as a string on the appropriate form given the type */
+ /**
+ * Returns this as a string on the appropriate form given the type
+ */
public final String toString(TensorType type) {
StringBuilder b = new StringBuilder("{");
for (int i = 0; i < size(); i++) {
@@ -105,24 +95,72 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return b.toString();
}
- /** Returns a label as a string with appropriate quoting/escaping when necessary */
+ /**
+ * Returns a label as a string with appropriate quoting/escaping when necessary
+ */
public static String labelToString(String label) {
if (TensorType.labelMatcher.matches(label)) return label; // no quoting
if (label.contains("'")) return "\"" + label + "\"";
return "'" + label + "'";
}
+ /** Returns an address with only some of the dimension */
+ public TensorAddress partialCopy(int[] indexMap) {
+ int[] labels = new int[indexMap.length];
+ for (int i = 0; i < labels.length; ++i) {
+ labels[i] = (int)numericLabel(indexMap[i]);
+ }
+ return TensorAddressAny.ofUnsafe(labels);
+ }
+
+ /** Creates a complete address by taking the sparse dimmensions from this and the indexed from the densePart */
+ public TensorAddress fullAddressOf(List<TensorType.Dimension> dimensions, int [] densePart) {
+ int [] labels = new int[dimensions.size()];
+ int mappedIndex = 0;
+ int indexedIndex = 0;
+ for (int i = 0; i < labels.length; i++) {
+ TensorType.Dimension d = dimensions.get(i);
+ if (d.isIndexed()) {
+ labels[i] = densePart[indexedIndex];
+ indexedIndex++;
+ } else {
+ labels[i] = (int)numericLabel(mappedIndex);
+ mappedIndex++;
+ }
+ }
+ return TensorAddressAny.ofUnsafe(labels);
+ }
+
+ /** Extracts the sparse(non-indexed) dimensions of the address */
+ public TensorAddress sparsePartialAddress(TensorType sparseType, List<TensorType.Dimension> dimensions) {
+ if (dimensions.size() != size())
+ throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + this);
+ TensorAddress.Builder builder = new TensorAddress.Builder(sparseType);
+ for (int i = 0; i < dimensions.size(); ++i) {
+ TensorType.Dimension dimension = dimensions.get(i);
+ if ( ! dimension.isIndexed())
+ builder.add(dimension.name(), (int)numericLabel(i));
+ }
+ return builder.build();
+ }
+
/** Builder of a tensor address */
public static class Builder {
final TensorType type;
- final String[] labels;
+ final int[] labels;
+
+ private static int [] createEmptyLabels(int size) {
+ int [] labels = new int[size];
+ Arrays.fill(labels, Tensor.INVALID_INDEX);
+ return labels;
+ }
public Builder(TensorType type) {
- this(type, new String[type.dimensions().size()]);
+ this(type, createEmptyLabels(type.dimensions().size()));
}
- private Builder(TensorType type, String[] labels) {
+ private Builder(TensorType type, int[] labels) {
this.type = type;
this.labels = labels;
}
@@ -152,6 +190,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
int labelIndex = type.indexOfDimensionAsInt(dimension);
if ( labelIndex < 0)
throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
+ labels[labelIndex] = Label.toNumber(label);
+ return this;
+ }
+ public Builder add(String dimension, int label) {
+ Objects.requireNonNull(dimension, "dimension cannot be null");
+ int labelIndex = type.indexOfDimensionAsInt(dimension);
+ if ( labelIndex < 0)
+ throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
labels[labelIndex] = label;
return this;
}
@@ -166,14 +212,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
void validate() {
for (int i = 0; i < labels.length; i++)
- if (labels[i] == null)
+ if (labels[i] == Tensor.INVALID_INDEX)
throw new IllegalArgumentException("Missing a label for dimension '" +
type.dimensions().get(i).name() + "' for " + type);
}
public TensorAddress build() {
validate();
- return TensorAddress.of(labels);
+ return TensorAddressAny.ofUnsafe(labels);
}
}
@@ -185,7 +231,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
super(type);
}
- private PartialBuilder(TensorType type, String[] labels) {
+ private PartialBuilder(TensorType type, int[] labels) {
super(type, labels);
}