diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java | 130 |
1 files changed, 88 insertions, 42 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 1b88a5d1b2f..59a5e2a49b1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -1,13 +1,11 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; -import com.yahoo.tensor.impl.NumericTensorAddress; -import com.yahoo.tensor.impl.StringTensorAddress; -import net.jpountz.xxhash.XXHash32; -import net.jpountz.xxhash.XXHashFactory; +import com.yahoo.tensor.impl.Label; +import com.yahoo.tensor.impl.TensorAddressAny; -import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.List; import java.util.Objects; /** @@ -18,23 +16,25 @@ import java.util.Objects; */ public abstract class TensorAddress implements Comparable<TensorAddress> { - private static final XXHash32 hasher = XXHashFactory.fastestJavaInstance().hash32(); - public static TensorAddress of(String[] labels) { - return StringTensorAddress.of(labels); + return TensorAddressAny.of(labels); } - public static TensorAddress ofLabels(String ... labels) { - return StringTensorAddress.of(labels); + public static TensorAddress ofLabels(String... labels) { + return TensorAddressAny.of(labels); } - public static TensorAddress of(long ... labels) { - return NumericTensorAddress.of(labels); + public static TensorAddress of(long... labels) { + return TensorAddressAny.of(labels); } - private int cached_hash = 0; + public static TensorAddress of(int... labels) { + return TensorAddressAny.of(labels); + } - /** Returns the number of labels in this */ + /** + * Returns the number of labels in this + */ public abstract int size(); /** @@ -67,32 +67,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { } @Override - public int hashCode() { - if (cached_hash != 0) return cached_hash; - - int hash = 0; - for (int i = 0; i < size(); i++) { - String label = label(i); - if (label != null) { - byte [] buf = label.getBytes(StandardCharsets.UTF_8); - hash = hasher.hash(buf, 0, buf.length, hash); + public String toString() { + StringBuilder sb = new StringBuilder("cell address ("); + int sz = size(); + if (sz > 0) { + sb.append(label(0)); + for (int i = 1; i < sz; i++) { + sb.append(',').append(label(i)); } } - return cached_hash = hash; - } - @Override - public boolean equals(Object o) { - if (o == this) return true; - if ( ! (o instanceof TensorAddress other)) return false; - if (other.size() != this.size()) return false; - for (int i = 0; i < this.size(); i++) - if ( ! Objects.equals(this.label(i), other.label(i))) - return false; - return true; + return sb.append(')').toString(); } - /** Returns this as a string on the appropriate form given the type */ + /** + * Returns this as a string on the appropriate form given the type + */ public final String toString(TensorType type) { StringBuilder b = new StringBuilder("{"); for (int i = 0; i < size(); i++) { @@ -105,24 +95,72 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return b.toString(); } - /** Returns a label as a string with appropriate quoting/escaping when necessary */ + /** + * Returns a label as a string with appropriate quoting/escaping when necessary + */ public static String labelToString(String label) { if (TensorType.labelMatcher.matches(label)) return label; // no quoting if (label.contains("'")) return "\"" + label + "\""; return "'" + label + "'"; } + /** Returns an address with only some of the dimension */ + public TensorAddress partialCopy(int[] indexMap) { + int[] labels = new int[indexMap.length]; + for (int i = 0; i < labels.length; ++i) { + labels[i] = (int)numericLabel(indexMap[i]); + } + return TensorAddressAny.ofUnsafe(labels); + } + + /** Creates a complete address by taking the sparse dimmensions from this and the indexed from the densePart */ + public TensorAddress fullAddressOf(List<TensorType.Dimension> dimensions, int [] densePart) { + int [] labels = new int[dimensions.size()]; + int mappedIndex = 0; + int indexedIndex = 0; + for (int i = 0; i < labels.length; i++) { + TensorType.Dimension d = dimensions.get(i); + if (d.isIndexed()) { + labels[i] = densePart[indexedIndex]; + indexedIndex++; + } else { + labels[i] = (int)numericLabel(mappedIndex); + mappedIndex++; + } + } + return TensorAddressAny.ofUnsafe(labels); + } + + /** Extracts the sparse(non-indexed) dimensions of the address */ + public TensorAddress sparsePartialAddress(TensorType sparseType, List<TensorType.Dimension> dimensions) { + if (dimensions.size() != size()) + throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + this); + TensorAddress.Builder builder = new TensorAddress.Builder(sparseType); + for (int i = 0; i < dimensions.size(); ++i) { + TensorType.Dimension dimension = dimensions.get(i); + if ( ! dimension.isIndexed()) + builder.add(dimension.name(), (int)numericLabel(i)); + } + return builder.build(); + } + /** Builder of a tensor address */ public static class Builder { final TensorType type; - final String[] labels; + final int[] labels; + + private static int [] createEmptyLabels(int size) { + int [] labels = new int[size]; + Arrays.fill(labels, Tensor.INVALID_INDEX); + return labels; + } public Builder(TensorType type) { - this(type, new String[type.dimensions().size()]); + this(type, createEmptyLabels(type.dimensions().size())); } - private Builder(TensorType type, String[] labels) { + private Builder(TensorType type, int[] labels) { this.type = type; this.labels = labels; } @@ -152,6 +190,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { int labelIndex = type.indexOfDimensionAsInt(dimension); if ( labelIndex < 0) throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); + labels[labelIndex] = Label.toNumber(label); + return this; + } + public Builder add(String dimension, int label) { + Objects.requireNonNull(dimension, "dimension cannot be null"); + int labelIndex = type.indexOfDimensionAsInt(dimension); + if ( labelIndex < 0) + throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); labels[labelIndex] = label; return this; } @@ -166,14 +212,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { void validate() { for (int i = 0; i < labels.length; i++) - if (labels[i] == null) + if (labels[i] == Tensor.INVALID_INDEX) throw new IllegalArgumentException("Missing a label for dimension '" + type.dimensions().get(i).name() + "' for " + type); } public TensorAddress build() { validate(); - return TensorAddress.of(labels); + return TensorAddressAny.ofUnsafe(labels); } } @@ -185,7 +231,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { super(type); } - private PartialBuilder(TensorType type, String[] labels) { + private PartialBuilder(TensorType type, int[] labels) { super(type, labels); } |