summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-21 14:25:01 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-21 14:25:01 +0100
commited8ec5305f6838e31de94ef87ddd3a75390b59ed (patch)
tree6266387837bafdc29713b1a9605919b59fd86079 /vespajlib/src/main/java/com/yahoo/tensor
parentb56911f909e6ca68fa0a02cf5932d422a61a9f49 (diff)
- Tensor generate implementation
- Cross tensor implementation equals - Better iteration
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java236
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java1
6 files changed, 209 insertions, 86 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 1ebd6c4179d..c1a24abd878 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor;
import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
@@ -147,11 +148,10 @@ public class IndexedTensor implements Tensor {
return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]);
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
- Indexes indexes = new Indexes(dimensionSizes, values.length);
+ Indexes indexes = Indexes.of(dimensionSizes, values.length);
for (int i = 0; i < values.length; i++) {
+ indexes.next();
builder.put(indexes.toAddress(), values[i]);
- if (i < values.length -1)
- indexes.next();
}
return builder.build();
}
@@ -161,11 +161,11 @@ public class IndexedTensor implements Tensor {
@Override
public String toString() { return Tensor.toStandardString(this); }
-
+
@Override
- public boolean equals(Object o) {
- if ( ! (o instanceof Tensor)) return false;
- return Tensor.equals(this, (Tensor)o);
+ public boolean equals(Object other) {
+ if ( ! ( other instanceof Tensor)) return false;
+ return Tensor.equals(this, ((Tensor)other));
}
public abstract static class Builder implements Tensor.Builder {
@@ -401,7 +401,7 @@ public class IndexedTensor implements Tensor {
private final class CellIterator implements Iterator<Map.Entry<TensorAddress, Double>> {
private int count = 0;
- private final Indexes indexes = new Indexes(dimensionSizes, values.length);
+ private final Indexes indexes = Indexes.of(dimensionSizes, values.length);
@Override
public boolean hasNext() {
@@ -411,14 +411,9 @@ public class IndexedTensor implements Tensor {
@Override
public Map.Entry<TensorAddress, Double> next() {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes);
-
- Map.Entry<TensorAddress, Double> current = new Cell(indexes.toAddress(), get(indexes));
-
count++;
- if (hasNext())
- indexes.next();
-
- return current;
+ indexes.next();
+ return new Cell(indexes.toAddress(), get(indexes));
}
}
@@ -444,6 +439,21 @@ public class IndexedTensor implements Tensor {
throw new UnsupportedOperationException("A tensor cannot be modified");
}
+ @Override
+ public boolean equals(Object o) {
+ if (o == this) return true;
+ if ( ! ( o instanceof Map.Entry)) return false;
+ Map.Entry other = (Map.Entry)o;
+ if ( ! this.getValue().equals(other.getValue())) return false;
+ if ( ! this.getKey().equals(other.getKey())) return false;
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ return getKey().hashCode() ^ getValue().hashCode(); // by Map.Entry spec
+ }
+
}
private final class ValueIterator implements Iterator<Double> {
@@ -490,10 +500,10 @@ public class IndexedTensor implements Tensor {
for (int i = 0; i < type.dimensions().size(); i++ ) {
boolean superDimension = superdimensionNames.contains(type.dimensions().get(i).name());
superdimensionIndexes[i] = superDimension;
- subdimensionIndexes[i] = ! superDimension;
+ subdimensionIndexes[i] = ! superDimension;
}
- superindexes = new Indexes(dimensionSizes, superdimensionIndexes);
+ superindexes = Indexes.of(dimensionSizes, superdimensionIndexes);
}
@Override
@@ -504,11 +514,9 @@ public class IndexedTensor implements Tensor {
@Override
public SubspaceIterator next() {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + superindexes);
- SubspaceIterator subspace = new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), dimensionSizes);
count++;
- if (hasNext())
- superindexes.next();
- return subspace;
+ superindexes.next();
+ return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), dimensionSizes);
}
}
@@ -529,7 +537,7 @@ public class IndexedTensor implements Tensor {
* @param address the address of the first cell of this subspace.
*/
private SubspaceIterator(boolean[] dimensionIndexes, int[] address, int[] dimensionSizes) {
- this.indexes = new Indexes(dimensionSizes, dimensionIndexes, address);
+ this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address);
}
/** Returns the total number of cells in this subspace */
@@ -543,52 +551,55 @@ public class IndexedTensor implements Tensor {
@Override
public Map.Entry<TensorAddress, Double> next() {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes);
-
- Map.Entry<TensorAddress, Double> current = new Cell(indexes.toAddress(), get(indexes));
-
count++;
- if (hasNext())
- indexes.next();
-
- return current;
+ indexes.next();
+ return new Cell(indexes.toAddress(), get(indexes));
}
}
- /** An array of indexes into this tensor which are able to find the next index in the value order */
- private static class Indexes {
-
- private final int size;
- private final int[] indexes;
-
- private final int[] dimensionSizes;
-
- /** Only mutate (take next in) the dimension indexes which are true */
- private final boolean[] iteratingDimensions;
+ /**
+ * An array of indexes into this tensor which are able to find the next index in the value order.
+ * next() can be called once per element in the dimensions we iterate over. It must be called once
+ * before accessing the first position.
+ */
+ public abstract static class Indexes {
+
+ protected final int[] indexes;
- private Indexes(int[] dimensionSizes, int size) {
- this(dimensionSizes, trueArray(dimensionSizes.length), size);
+ public static Indexes of(int[] dimensionSizes) {
+ return of(dimensionSizes, trueArray(dimensionSizes.length));
}
- private Indexes(int[] dimensionSizes, boolean[] iteratingDimensions) {
- this(dimensionSizes, iteratingDimensions, computeSize(dimensionSizes, iteratingDimensions));
+ private static Indexes of(int[] dimensionSizes, int size) {
+ return of(dimensionSizes, trueArray(dimensionSizes.length), size);
}
-
- private Indexes(int[] dimensionSizes, boolean[] iteratingDimensionIndexes, int size) {
- this(dimensionSizes, iteratingDimensionIndexes, new int[dimensionSizes.length], size);
+
+ private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions) {
+ return of(dimensionSizes, iteratingDimensions, computeSize(dimensionSizes, iteratingDimensions));
}
- private Indexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes) {
- this(dimensionSizes, iteratingDimensions, initialIndexes, computeSize(dimensionSizes, iteratingDimensions));
+ private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensionIndexes, int size) {
+ return of(dimensionSizes, iteratingDimensionIndexes, new int[dimensionSizes.length], size);
}
- private Indexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) {
- this.dimensionSizes = dimensionSizes;
- this.iteratingDimensions = iteratingDimensions;
- this.indexes = initialIndexes;
- this.size = size;
+ private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes) {
+ return of(dimensionSizes, iteratingDimensions, initialIndexes, computeSize(dimensionSizes, iteratingDimensions));
+ }
+
+ private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) {
+ if (size == 0)
+ return new EmptyIndexes(initialIndexes); // we're told explicitly there are truly no values available
+ else if (size == 1)
+ return new SingleValueIndexes(initialIndexes); // with no (iterating) dimensions, we still return one value, not zero
+ else
+ return new MultivalueIndexes(dimensionSizes, iteratingDimensions, initialIndexes, size);
}
+ private Indexes(int[] indexes) {
+ this.indexes = indexes;
+ }
+
private static boolean[] trueArray(int size) {
boolean[] array = new boolean[size];
Arrays.fill(array, true);
@@ -602,19 +613,112 @@ public class IndexedTensor implements Tensor {
size *= dimensionSizes[dimensionIndex];
return size;
}
+
+ /** Returns the address of the current position of these indexes */
+ private TensorAddress toAddress() {
+ // TODO: We may avoid the array copy by issuing a one-time-use address?
+ return TensorAddress.of(indexes);
+ }
+
+ public int[] indexesCopy() {
+ return Arrays.copyOf(indexes, indexes.length);
+ }
+
+ /** Returns a copy of the indexes of this which must not be modified */
+ public int[] indexesForReading() { return indexes; }
+
+ /** Returns an immutable list containing a copy of the indexes in this */
+ public List<Integer> toList() {
+ ImmutableList.Builder<Integer> builder = new ImmutableList.Builder<>();
+ for (int index : indexes)
+ builder.add(index);
+ return builder.build();
+ }
+
+ @Override
+ public String toString() {
+ return "indexes " + Arrays.toString(indexes);
+ }
- private static boolean anyTrue(boolean[] values) {
- for (boolean value : values)
- if (value) return true;
+ public abstract int size();
+
+ public abstract void next();
+
+ }
+
+ private final static class EmptyIndexes extends Indexes {
+
+ private EmptyIndexes(int[] indexes) {
+ super(indexes);
+ }
+
+ @Override
+ public int size() {
+ return 0;
+ }
+
+ @Override
+ public void next() {}
+
+ }
+
+ private final static class SingleValueIndexes extends Indexes {
+
+ private SingleValueIndexes(int[] indexes) {
+ super(indexes);
+ }
+
+ @Override
+ public int size() {
+ return 1;
+ }
+
+ @Override
+ public void next() {}
+
+ }
+
+ private final static class MultivalueIndexes extends Indexes {
+
+ private final int size;
+
+ private final int[] dimensionSizes;
+
+ /** Only mutate (take next in) the dimension indexes which are true */
+ private final boolean[] iteratingDimensions;
+
+ private static boolean haveIteratingDimensions(boolean[] iteratingDimensions) {
+ for (boolean iterating : iteratingDimensions)
+ if (iterating)
+ return true;
return false;
}
+ private MultivalueIndexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) {
+ super(initialIndexes);
+ this.dimensionSizes = dimensionSizes;
+ this.iteratingDimensions = iteratingDimensions;
+ this.size = size;
+
+ // Initialize to the (virtual) position before the first cell
+ int currentDimension = indexes.length - 1;
+ while (! iteratingDimensions[currentDimension])
+ currentDimension--;
+ indexes[currentDimension]--;
+ }
+
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
+ @Override
public int size() {
return size;
}
- private void next() {
+ /**
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
+ */
+ @Override
+ public void next() {
int currentDimension = indexes.length - 1;
while ( ! iteratingDimensions[currentDimension] ||
indexes[currentDimension] + 1 == dimensionSizes[currentDimension]) {
@@ -626,24 +730,6 @@ public class IndexedTensor implements Tensor {
indexes[currentDimension]++;
}
- /** Returns the address of the current position of these indexes */
- private TensorAddress toAddress() {
- // TODO: We may avoid the array copy by issuing a one-time-use address?
- return TensorAddress.of(indexes);
- }
-
- private int[] indexesCopy() {
- return Arrays.copyOf(indexes, indexes.length);
- }
-
- /** Returns a copy of the indexes of this which must not be modified */
- private int[] indexesForReading() { return indexes; }
-
- @Override
- public String toString() {
- return "indexes " + Arrays.toString(indexes);
- }
-
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index 6e169b8347f..8d72e860473 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -50,9 +50,9 @@ public class MappedTensor implements Tensor {
public String toString() { return Tensor.toStandardString(this); }
@Override
- public boolean equals(Object o) {
- if ( ! (o instanceof Tensor)) return false;
- return Tensor.equals(this, (Tensor)o);
+ public boolean equals(Object other) {
+ if ( ! ( other instanceof Tensor)) return false;
+ return Tensor.equals(this, ((Tensor)other));
}
public static class Builder implements Tensor.Builder {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 6f655fd5860..808da3abad4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -57,11 +57,16 @@ public interface Tensor {
/** Returns the value of a cell, or NaN if this cell does not exist/have no value */
double get(TensorAddress address);
+ /** Returns the cell of this in some undefined order */
Iterator<Map.Entry<TensorAddress, Double>> cellIterator();
+ /** Returns the values of this in some undefined order */
Iterator<Double> valueIterator();
- /** Returns an immutable map of the cells of this. This may be expensive for some implementations - avoid when possible */
+ /**
+ * Returns an immutable map of the cells of this in no particular order.
+ * This may be expensive for some implementations - avoid when possible
+ */
Map<TensorAddress, Double> cells();
/**
@@ -203,15 +208,24 @@ public interface Tensor {
// ----------------- equality
/**
- * Returns true if the given tensor is mathematically equal to this:
- * Both are of type Tensor and have the same content.
+ * Returns whether this tensor and the given tensor is mathematically equal:
+ * That they have the same dimension *names* and the same content.
*/
- @Override
boolean equals(Object o);
- /** Returns true if the two given tensors are mathematically equivalent, that is whether both have the same content */
+ /**
+ * Implement here to make this work across implementations.
+ * Implementations must override equals and call this because this is an interface and cannot override equals.
+ */
static boolean equals(Tensor a, Tensor b) {
- return a == b || a.cells().equals(b.cells());
+ if (a == b) return true;
+ if ( ! a.type().mathematicallyEquals(b.type())) return false;
+ if ( a.size() != b.size()) return false;
+ for (Iterator<Map.Entry<TensorAddress, Double>> aIterator = a.cellIterator(); aIterator.hasNext(); ) {
+ Map.Entry<TensorAddress, Double> aCell = aIterator.next();
+ if ( ! aCell.getValue().equals(b.get(aCell.getKey()))) return false;
+ }
+ return true;
}
// ----------------- Factories
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index e829f4c909b..13ddf3c2e20 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -129,6 +129,14 @@ public class TensorType {
return dimensions.equals(((TensorType)other).dimensions);
}
+ /** Returns whether the given type has the same dimension names as this */
+ public boolean mathematicallyEquals(TensorType other) {
+ if (dimensions().size() != other.dimensions().size()) return false;
+ for (int i = 0; i < dimensions().size(); i++)
+ if (!dimensions().get(i).name().equals(other.dimensions().get(i).name())) return false;
+ return true;
+ }
+
@Override
public int hashCode() {
return dimensions.hashCode();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 508e322c3a1..9c92ca00eac 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -1,6 +1,7 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
@@ -58,9 +59,22 @@ public class Generate extends PrimitiveTensorFunction {
@Override
public Tensor evaluate(EvaluationContext context) {
- throw new UnsupportedOperationException("Not implemented"); // TODO
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
+ for (int i = 0; i < indexes.size(); i++) {
+ indexes.next();
+ builder.cell(generator.apply(indexes.toList()), indexes.indexesForReading());
+ }
+ return builder.build();
}
-
+
+ private int[] dimensionSizes(TensorType type) {
+ int dimensionSizes[] = new int[type.dimensions().size()];
+ for (int i = 0; i < dimensionSizes.length; i++)
+ dimensionSizes[i] = type.dimensions().get(i).size().get();
+ return dimensionSizes;
+ }
+
@Override
public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 6128611302f..ebec5efa436 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -143,6 +143,7 @@ public class Join extends PrimitiveTensorFunction {
subspaceInSuper, subspaceInSuper.size(),
reversedArgumentOrder, builder);
}
+
return builder.build();
}