diff options
Diffstat (limited to 'vespajlib/src')
6 files changed, 209 insertions, 86 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 1ebd6c4179d..c1a24abd878 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -2,6 +2,7 @@ package com.yahoo.tensor; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.ArrayList; @@ -147,11 +148,10 @@ public class IndexedTensor implements Tensor { return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]); ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); - Indexes indexes = new Indexes(dimensionSizes, values.length); + Indexes indexes = Indexes.of(dimensionSizes, values.length); for (int i = 0; i < values.length; i++) { + indexes.next(); builder.put(indexes.toAddress(), values[i]); - if (i < values.length -1) - indexes.next(); } return builder.build(); } @@ -161,11 +161,11 @@ public class IndexedTensor implements Tensor { @Override public String toString() { return Tensor.toStandardString(this); } - + @Override - public boolean equals(Object o) { - if ( ! (o instanceof Tensor)) return false; - return Tensor.equals(this, (Tensor)o); + public boolean equals(Object other) { + if ( ! ( other instanceof Tensor)) return false; + return Tensor.equals(this, ((Tensor)other)); } public abstract static class Builder implements Tensor.Builder { @@ -401,7 +401,7 @@ public class IndexedTensor implements Tensor { private final class CellIterator implements Iterator<Map.Entry<TensorAddress, Double>> { private int count = 0; - private final Indexes indexes = new Indexes(dimensionSizes, values.length); + private final Indexes indexes = Indexes.of(dimensionSizes, values.length); @Override public boolean hasNext() { @@ -411,14 +411,9 @@ public class IndexedTensor implements Tensor { @Override public Map.Entry<TensorAddress, Double> next() { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes); - - Map.Entry<TensorAddress, Double> current = new Cell(indexes.toAddress(), get(indexes)); - count++; - if (hasNext()) - indexes.next(); - - return current; + indexes.next(); + return new Cell(indexes.toAddress(), get(indexes)); } } @@ -444,6 +439,21 @@ public class IndexedTensor implements Tensor { throw new UnsupportedOperationException("A tensor cannot be modified"); } + @Override + public boolean equals(Object o) { + if (o == this) return true; + if ( ! ( o instanceof Map.Entry)) return false; + Map.Entry other = (Map.Entry)o; + if ( ! this.getValue().equals(other.getValue())) return false; + if ( ! this.getKey().equals(other.getKey())) return false; + return true; + } + + @Override + public int hashCode() { + return getKey().hashCode() ^ getValue().hashCode(); // by Map.Entry spec + } + } private final class ValueIterator implements Iterator<Double> { @@ -490,10 +500,10 @@ public class IndexedTensor implements Tensor { for (int i = 0; i < type.dimensions().size(); i++ ) { boolean superDimension = superdimensionNames.contains(type.dimensions().get(i).name()); superdimensionIndexes[i] = superDimension; - subdimensionIndexes[i] = ! superDimension; + subdimensionIndexes[i] = ! superDimension; } - superindexes = new Indexes(dimensionSizes, superdimensionIndexes); + superindexes = Indexes.of(dimensionSizes, superdimensionIndexes); } @Override @@ -504,11 +514,9 @@ public class IndexedTensor implements Tensor { @Override public SubspaceIterator next() { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + superindexes); - SubspaceIterator subspace = new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), dimensionSizes); count++; - if (hasNext()) - superindexes.next(); - return subspace; + superindexes.next(); + return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), dimensionSizes); } } @@ -529,7 +537,7 @@ public class IndexedTensor implements Tensor { * @param address the address of the first cell of this subspace. */ private SubspaceIterator(boolean[] dimensionIndexes, int[] address, int[] dimensionSizes) { - this.indexes = new Indexes(dimensionSizes, dimensionIndexes, address); + this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address); } /** Returns the total number of cells in this subspace */ @@ -543,52 +551,55 @@ public class IndexedTensor implements Tensor { @Override public Map.Entry<TensorAddress, Double> next() { if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes); - - Map.Entry<TensorAddress, Double> current = new Cell(indexes.toAddress(), get(indexes)); - count++; - if (hasNext()) - indexes.next(); - - return current; + indexes.next(); + return new Cell(indexes.toAddress(), get(indexes)); } } - /** An array of indexes into this tensor which are able to find the next index in the value order */ - private static class Indexes { - - private final int size; - private final int[] indexes; - - private final int[] dimensionSizes; - - /** Only mutate (take next in) the dimension indexes which are true */ - private final boolean[] iteratingDimensions; + /** + * An array of indexes into this tensor which are able to find the next index in the value order. + * next() can be called once per element in the dimensions we iterate over. It must be called once + * before accessing the first position. + */ + public abstract static class Indexes { + + protected final int[] indexes; - private Indexes(int[] dimensionSizes, int size) { - this(dimensionSizes, trueArray(dimensionSizes.length), size); + public static Indexes of(int[] dimensionSizes) { + return of(dimensionSizes, trueArray(dimensionSizes.length)); } - private Indexes(int[] dimensionSizes, boolean[] iteratingDimensions) { - this(dimensionSizes, iteratingDimensions, computeSize(dimensionSizes, iteratingDimensions)); + private static Indexes of(int[] dimensionSizes, int size) { + return of(dimensionSizes, trueArray(dimensionSizes.length), size); } - - private Indexes(int[] dimensionSizes, boolean[] iteratingDimensionIndexes, int size) { - this(dimensionSizes, iteratingDimensionIndexes, new int[dimensionSizes.length], size); + + private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions) { + return of(dimensionSizes, iteratingDimensions, computeSize(dimensionSizes, iteratingDimensions)); } - private Indexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes) { - this(dimensionSizes, iteratingDimensions, initialIndexes, computeSize(dimensionSizes, iteratingDimensions)); + private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensionIndexes, int size) { + return of(dimensionSizes, iteratingDimensionIndexes, new int[dimensionSizes.length], size); } - private Indexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) { - this.dimensionSizes = dimensionSizes; - this.iteratingDimensions = iteratingDimensions; - this.indexes = initialIndexes; - this.size = size; + private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes) { + return of(dimensionSizes, iteratingDimensions, initialIndexes, computeSize(dimensionSizes, iteratingDimensions)); + } + + private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) { + if (size == 0) + return new EmptyIndexes(initialIndexes); // we're told explicitly there are truly no values available + else if (size == 1) + return new SingleValueIndexes(initialIndexes); // with no (iterating) dimensions, we still return one value, not zero + else + return new MultivalueIndexes(dimensionSizes, iteratingDimensions, initialIndexes, size); } + private Indexes(int[] indexes) { + this.indexes = indexes; + } + private static boolean[] trueArray(int size) { boolean[] array = new boolean[size]; Arrays.fill(array, true); @@ -602,19 +613,112 @@ public class IndexedTensor implements Tensor { size *= dimensionSizes[dimensionIndex]; return size; } + + /** Returns the address of the current position of these indexes */ + private TensorAddress toAddress() { + // TODO: We may avoid the array copy by issuing a one-time-use address? + return TensorAddress.of(indexes); + } + + public int[] indexesCopy() { + return Arrays.copyOf(indexes, indexes.length); + } + + /** Returns a copy of the indexes of this which must not be modified */ + public int[] indexesForReading() { return indexes; } + + /** Returns an immutable list containing a copy of the indexes in this */ + public List<Integer> toList() { + ImmutableList.Builder<Integer> builder = new ImmutableList.Builder<>(); + for (int index : indexes) + builder.add(index); + return builder.build(); + } + + @Override + public String toString() { + return "indexes " + Arrays.toString(indexes); + } - private static boolean anyTrue(boolean[] values) { - for (boolean value : values) - if (value) return true; + public abstract int size(); + + public abstract void next(); + + } + + private final static class EmptyIndexes extends Indexes { + + private EmptyIndexes(int[] indexes) { + super(indexes); + } + + @Override + public int size() { + return 0; + } + + @Override + public void next() {} + + } + + private final static class SingleValueIndexes extends Indexes { + + private SingleValueIndexes(int[] indexes) { + super(indexes); + } + + @Override + public int size() { + return 1; + } + + @Override + public void next() {} + + } + + private final static class MultivalueIndexes extends Indexes { + + private final int size; + + private final int[] dimensionSizes; + + /** Only mutate (take next in) the dimension indexes which are true */ + private final boolean[] iteratingDimensions; + + private static boolean haveIteratingDimensions(boolean[] iteratingDimensions) { + for (boolean iterating : iteratingDimensions) + if (iterating) + return true; return false; } + private MultivalueIndexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) { + super(initialIndexes); + this.dimensionSizes = dimensionSizes; + this.iteratingDimensions = iteratingDimensions; + this.size = size; + + // Initialize to the (virtual) position before the first cell + int currentDimension = indexes.length - 1; + while (! iteratingDimensions[currentDimension]) + currentDimension--; + indexes[currentDimension]--; + } + /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ + @Override public int size() { return size; } - private void next() { + /** + * Advances this to the next cell in the standard indexed tensor cell order. + * The first call to this will put it at the first position. + */ + @Override + public void next() { int currentDimension = indexes.length - 1; while ( ! iteratingDimensions[currentDimension] || indexes[currentDimension] + 1 == dimensionSizes[currentDimension]) { @@ -626,24 +730,6 @@ public class IndexedTensor implements Tensor { indexes[currentDimension]++; } - /** Returns the address of the current position of these indexes */ - private TensorAddress toAddress() { - // TODO: We may avoid the array copy by issuing a one-time-use address? - return TensorAddress.of(indexes); - } - - private int[] indexesCopy() { - return Arrays.copyOf(indexes, indexes.length); - } - - /** Returns a copy of the indexes of this which must not be modified */ - private int[] indexesForReading() { return indexes; } - - @Override - public String toString() { - return "indexes " + Arrays.toString(indexes); - } - } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 6e169b8347f..8d72e860473 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -50,9 +50,9 @@ public class MappedTensor implements Tensor { public String toString() { return Tensor.toStandardString(this); } @Override - public boolean equals(Object o) { - if ( ! (o instanceof Tensor)) return false; - return Tensor.equals(this, (Tensor)o); + public boolean equals(Object other) { + if ( ! ( other instanceof Tensor)) return false; + return Tensor.equals(this, ((Tensor)other)); } public static class Builder implements Tensor.Builder { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 6f655fd5860..808da3abad4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -57,11 +57,16 @@ public interface Tensor { /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); + /** Returns the cell of this in some undefined order */ Iterator<Map.Entry<TensorAddress, Double>> cellIterator(); + /** Returns the values of this in some undefined order */ Iterator<Double> valueIterator(); - /** Returns an immutable map of the cells of this. This may be expensive for some implementations - avoid when possible */ + /** + * Returns an immutable map of the cells of this in no particular order. + * This may be expensive for some implementations - avoid when possible + */ Map<TensorAddress, Double> cells(); /** @@ -203,15 +208,24 @@ public interface Tensor { // ----------------- equality /** - * Returns true if the given tensor is mathematically equal to this: - * Both are of type Tensor and have the same content. + * Returns whether this tensor and the given tensor is mathematically equal: + * That they have the same dimension *names* and the same content. */ - @Override boolean equals(Object o); - /** Returns true if the two given tensors are mathematically equivalent, that is whether both have the same content */ + /** + * Implement here to make this work across implementations. + * Implementations must override equals and call this because this is an interface and cannot override equals. + */ static boolean equals(Tensor a, Tensor b) { - return a == b || a.cells().equals(b.cells()); + if (a == b) return true; + if ( ! a.type().mathematicallyEquals(b.type())) return false; + if ( a.size() != b.size()) return false; + for (Iterator<Map.Entry<TensorAddress, Double>> aIterator = a.cellIterator(); aIterator.hasNext(); ) { + Map.Entry<TensorAddress, Double> aCell = aIterator.next(); + if ( ! aCell.getValue().equals(b.get(aCell.getKey()))) return false; + } + return true; } // ----------------- Factories diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index e829f4c909b..13ddf3c2e20 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -129,6 +129,14 @@ public class TensorType { return dimensions.equals(((TensorType)other).dimensions); } + /** Returns whether the given type has the same dimension names as this */ + public boolean mathematicallyEquals(TensorType other) { + if (dimensions().size() != other.dimensions().size()) return false; + for (int i = 0; i < dimensions().size(); i++) + if (!dimensions().get(i).name().equals(other.dimensions().get(i).name())) return false; + return true; + } + @Override public int hashCode() { return dimensions.hashCode(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 508e322c3a1..9c92ca00eac 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -1,6 +1,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; @@ -58,9 +59,22 @@ public class Generate extends PrimitiveTensorFunction { @Override public Tensor evaluate(EvaluationContext context) { - throw new UnsupportedOperationException("Not implemented"); // TODO + Tensor.Builder builder = Tensor.Builder.of(type); + IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); + for (int i = 0; i < indexes.size(); i++) { + indexes.next(); + builder.cell(generator.apply(indexes.toList()), indexes.indexesForReading()); + } + return builder.build(); } - + + private int[] dimensionSizes(TensorType type) { + int dimensionSizes[] = new int[type.dimensions().size()]; + for (int i = 0; i < dimensionSizes.length; i++) + dimensionSizes[i] = type.dimensions().get(i).size().get(); + return dimensionSizes; + } + @Override public String toString(ToStringContext context) { return type + "(" + generator + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 6128611302f..ebec5efa436 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -143,6 +143,7 @@ public class Join extends PrimitiveTensorFunction { subspaceInSuper, subspaceInSuper.size(), reversedArgumentOrder, builder); } + return builder.build(); } |