summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java236
1 files changed, 161 insertions, 75 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 1ebd6c4179d..c1a24abd878 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor;
import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
@@ -147,11 +148,10 @@ public class IndexedTensor implements Tensor {
return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]);
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
- Indexes indexes = new Indexes(dimensionSizes, values.length);
+ Indexes indexes = Indexes.of(dimensionSizes, values.length);
for (int i = 0; i < values.length; i++) {
+ indexes.next();
builder.put(indexes.toAddress(), values[i]);
- if (i < values.length -1)
- indexes.next();
}
return builder.build();
}
@@ -161,11 +161,11 @@ public class IndexedTensor implements Tensor {
@Override
public String toString() { return Tensor.toStandardString(this); }
-
+
@Override
- public boolean equals(Object o) {
- if ( ! (o instanceof Tensor)) return false;
- return Tensor.equals(this, (Tensor)o);
+ public boolean equals(Object other) {
+ if ( ! ( other instanceof Tensor)) return false;
+ return Tensor.equals(this, ((Tensor)other));
}
public abstract static class Builder implements Tensor.Builder {
@@ -401,7 +401,7 @@ public class IndexedTensor implements Tensor {
private final class CellIterator implements Iterator<Map.Entry<TensorAddress, Double>> {
private int count = 0;
- private final Indexes indexes = new Indexes(dimensionSizes, values.length);
+ private final Indexes indexes = Indexes.of(dimensionSizes, values.length);
@Override
public boolean hasNext() {
@@ -411,14 +411,9 @@ public class IndexedTensor implements Tensor {
@Override
public Map.Entry<TensorAddress, Double> next() {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes);
-
- Map.Entry<TensorAddress, Double> current = new Cell(indexes.toAddress(), get(indexes));
-
count++;
- if (hasNext())
- indexes.next();
-
- return current;
+ indexes.next();
+ return new Cell(indexes.toAddress(), get(indexes));
}
}
@@ -444,6 +439,21 @@ public class IndexedTensor implements Tensor {
throw new UnsupportedOperationException("A tensor cannot be modified");
}
+ @Override
+ public boolean equals(Object o) {
+ if (o == this) return true;
+ if ( ! ( o instanceof Map.Entry)) return false;
+ Map.Entry other = (Map.Entry)o;
+ if ( ! this.getValue().equals(other.getValue())) return false;
+ if ( ! this.getKey().equals(other.getKey())) return false;
+ return true;
+ }
+
+ @Override
+ public int hashCode() {
+ return getKey().hashCode() ^ getValue().hashCode(); // by Map.Entry spec
+ }
+
}
private final class ValueIterator implements Iterator<Double> {
@@ -490,10 +500,10 @@ public class IndexedTensor implements Tensor {
for (int i = 0; i < type.dimensions().size(); i++ ) {
boolean superDimension = superdimensionNames.contains(type.dimensions().get(i).name());
superdimensionIndexes[i] = superDimension;
- subdimensionIndexes[i] = ! superDimension;
+ subdimensionIndexes[i] = ! superDimension;
}
- superindexes = new Indexes(dimensionSizes, superdimensionIndexes);
+ superindexes = Indexes.of(dimensionSizes, superdimensionIndexes);
}
@Override
@@ -504,11 +514,9 @@ public class IndexedTensor implements Tensor {
@Override
public SubspaceIterator next() {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + superindexes);
- SubspaceIterator subspace = new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), dimensionSizes);
count++;
- if (hasNext())
- superindexes.next();
- return subspace;
+ superindexes.next();
+ return new SubspaceIterator(subdimensionIndexes, superindexes.indexesCopy(), dimensionSizes);
}
}
@@ -529,7 +537,7 @@ public class IndexedTensor implements Tensor {
* @param address the address of the first cell of this subspace.
*/
private SubspaceIterator(boolean[] dimensionIndexes, int[] address, int[] dimensionSizes) {
- this.indexes = new Indexes(dimensionSizes, dimensionIndexes, address);
+ this.indexes = Indexes.of(dimensionSizes, dimensionIndexes, address);
}
/** Returns the total number of cells in this subspace */
@@ -543,52 +551,55 @@ public class IndexedTensor implements Tensor {
@Override
public Map.Entry<TensorAddress, Double> next() {
if ( ! hasNext()) throw new NoSuchElementException("No cell at " + indexes);
-
- Map.Entry<TensorAddress, Double> current = new Cell(indexes.toAddress(), get(indexes));
-
count++;
- if (hasNext())
- indexes.next();
-
- return current;
+ indexes.next();
+ return new Cell(indexes.toAddress(), get(indexes));
}
}
- /** An array of indexes into this tensor which are able to find the next index in the value order */
- private static class Indexes {
-
- private final int size;
- private final int[] indexes;
-
- private final int[] dimensionSizes;
-
- /** Only mutate (take next in) the dimension indexes which are true */
- private final boolean[] iteratingDimensions;
+ /**
+ * An array of indexes into this tensor which are able to find the next index in the value order.
+ * next() can be called once per element in the dimensions we iterate over. It must be called once
+ * before accessing the first position.
+ */
+ public abstract static class Indexes {
+
+ protected final int[] indexes;
- private Indexes(int[] dimensionSizes, int size) {
- this(dimensionSizes, trueArray(dimensionSizes.length), size);
+ public static Indexes of(int[] dimensionSizes) {
+ return of(dimensionSizes, trueArray(dimensionSizes.length));
}
- private Indexes(int[] dimensionSizes, boolean[] iteratingDimensions) {
- this(dimensionSizes, iteratingDimensions, computeSize(dimensionSizes, iteratingDimensions));
+ private static Indexes of(int[] dimensionSizes, int size) {
+ return of(dimensionSizes, trueArray(dimensionSizes.length), size);
}
-
- private Indexes(int[] dimensionSizes, boolean[] iteratingDimensionIndexes, int size) {
- this(dimensionSizes, iteratingDimensionIndexes, new int[dimensionSizes.length], size);
+
+ private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions) {
+ return of(dimensionSizes, iteratingDimensions, computeSize(dimensionSizes, iteratingDimensions));
}
- private Indexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes) {
- this(dimensionSizes, iteratingDimensions, initialIndexes, computeSize(dimensionSizes, iteratingDimensions));
+ private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensionIndexes, int size) {
+ return of(dimensionSizes, iteratingDimensionIndexes, new int[dimensionSizes.length], size);
}
- private Indexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) {
- this.dimensionSizes = dimensionSizes;
- this.iteratingDimensions = iteratingDimensions;
- this.indexes = initialIndexes;
- this.size = size;
+ private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes) {
+ return of(dimensionSizes, iteratingDimensions, initialIndexes, computeSize(dimensionSizes, iteratingDimensions));
+ }
+
+ private static Indexes of(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) {
+ if (size == 0)
+ return new EmptyIndexes(initialIndexes); // we're told explicitly there are truly no values available
+ else if (size == 1)
+ return new SingleValueIndexes(initialIndexes); // with no (iterating) dimensions, we still return one value, not zero
+ else
+ return new MultivalueIndexes(dimensionSizes, iteratingDimensions, initialIndexes, size);
}
+ private Indexes(int[] indexes) {
+ this.indexes = indexes;
+ }
+
private static boolean[] trueArray(int size) {
boolean[] array = new boolean[size];
Arrays.fill(array, true);
@@ -602,19 +613,112 @@ public class IndexedTensor implements Tensor {
size *= dimensionSizes[dimensionIndex];
return size;
}
+
+ /** Returns the address of the current position of these indexes */
+ private TensorAddress toAddress() {
+ // TODO: We may avoid the array copy by issuing a one-time-use address?
+ return TensorAddress.of(indexes);
+ }
+
+ public int[] indexesCopy() {
+ return Arrays.copyOf(indexes, indexes.length);
+ }
+
+ /** Returns a copy of the indexes of this which must not be modified */
+ public int[] indexesForReading() { return indexes; }
+
+ /** Returns an immutable list containing a copy of the indexes in this */
+ public List<Integer> toList() {
+ ImmutableList.Builder<Integer> builder = new ImmutableList.Builder<>();
+ for (int index : indexes)
+ builder.add(index);
+ return builder.build();
+ }
+
+ @Override
+ public String toString() {
+ return "indexes " + Arrays.toString(indexes);
+ }
- private static boolean anyTrue(boolean[] values) {
- for (boolean value : values)
- if (value) return true;
+ public abstract int size();
+
+ public abstract void next();
+
+ }
+
+ private final static class EmptyIndexes extends Indexes {
+
+ private EmptyIndexes(int[] indexes) {
+ super(indexes);
+ }
+
+ @Override
+ public int size() {
+ return 0;
+ }
+
+ @Override
+ public void next() {}
+
+ }
+
+ private final static class SingleValueIndexes extends Indexes {
+
+ private SingleValueIndexes(int[] indexes) {
+ super(indexes);
+ }
+
+ @Override
+ public int size() {
+ return 1;
+ }
+
+ @Override
+ public void next() {}
+
+ }
+
+ private final static class MultivalueIndexes extends Indexes {
+
+ private final int size;
+
+ private final int[] dimensionSizes;
+
+ /** Only mutate (take next in) the dimension indexes which are true */
+ private final boolean[] iteratingDimensions;
+
+ private static boolean haveIteratingDimensions(boolean[] iteratingDimensions) {
+ for (boolean iterating : iteratingDimensions)
+ if (iterating)
+ return true;
return false;
}
+ private MultivalueIndexes(int[] dimensionSizes, boolean[] iteratingDimensions, int[] initialIndexes, int size) {
+ super(initialIndexes);
+ this.dimensionSizes = dimensionSizes;
+ this.iteratingDimensions = iteratingDimensions;
+ this.size = size;
+
+ // Initialize to the (virtual) position before the first cell
+ int currentDimension = indexes.length - 1;
+ while (! iteratingDimensions[currentDimension])
+ currentDimension--;
+ indexes[currentDimension]--;
+ }
+
/** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */
+ @Override
public int size() {
return size;
}
- private void next() {
+ /**
+ * Advances this to the next cell in the standard indexed tensor cell order.
+ * The first call to this will put it at the first position.
+ */
+ @Override
+ public void next() {
int currentDimension = indexes.length - 1;
while ( ! iteratingDimensions[currentDimension] ||
indexes[currentDimension] + 1 == dimensionSizes[currentDimension]) {
@@ -626,24 +730,6 @@ public class IndexedTensor implements Tensor {
indexes[currentDimension]++;
}
- /** Returns the address of the current position of these indexes */
- private TensorAddress toAddress() {
- // TODO: We may avoid the array copy by issuing a one-time-use address?
- return TensorAddress.of(indexes);
- }
-
- private int[] indexesCopy() {
- return Arrays.copyOf(indexes, indexes.length);
- }
-
- /** Returns a copy of the indexes of this which must not be modified */
- private int[] indexesForReading() { return indexes; }
-
- @Override
- public String toString() {
- return "indexes " + Arrays.toString(indexes);
- }
-
}
}