diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-25 16:40:00 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-25 16:40:00 +0200 |
commit | 7ef86b1fb25f2268d00fa3af87bc1e594de0b1b3 (patch) | |
tree | 5af4bc2b63e291b7e80d2ffc3ea85b5dfdf2b044 /vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | |
parent | a8949c869c613d671886b87ab684b2dfef9d9ca5 (diff) |
Split values into IndexedDoubleTensor subclass
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 51 |
1 files changed, 17 insertions, 34 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 38d832d01c2..5f2c04bbd56 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -20,7 +20,7 @@ import java.util.function.DoubleBinaryOperator; * * @author bratseth */ -public class IndexedTensor implements Tensor { +public abstract class IndexedTensor implements Tensor { /** The prescribed and possibly abstract type this is an instance of */ private final TensorType type; @@ -28,17 +28,9 @@ public class IndexedTensor implements Tensor { /** The sizes of the dimensions of this in the order of the dimensions of the type */ private final DimensionSizes dimensionSizes; - private final double[] values; - - private IndexedTensor(TensorType type, DimensionSizes dimensionSizes, double[] values) { + IndexedTensor(TensorType type, DimensionSizes dimensionSizes) { this.type = type; this.dimensionSizes = dimensionSizes; - this.values = values; - } - - @Override - public long size() { - return values.length; } /** @@ -96,13 +88,13 @@ public class IndexedTensor implements Tensor { } /** - * Returns the value at the given indexes + * Returns the value at the given indexes as a double * * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ public double get(long ... indexes) { - return values[(int)toValueIndex(indexes, dimensionSizes)]; + return get((int)toValueIndex(indexes, dimensionSizes)); } /** Returns the value at this address, or NaN if there is no value at this address */ @@ -110,7 +102,7 @@ public class IndexedTensor implements Tensor { public double get(TensorAddress address) { // optimize for fast lookup within bounds: try { - return values[(int)toValueIndex(address, dimensionSizes)]; + return get((int)toValueIndex(address, dimensionSizes)); } catch (IndexOutOfBoundsException e) { return Double.NaN; @@ -124,7 +116,7 @@ public class IndexedTensor implements Tensor { * @param valueIndex the direct index into the underlying data. * @throws IndexOutOfBoundsException if index is out of bounds */ - public double get(long valueIndex) { return values[(int)valueIndex]; } + public abstract double get(long valueIndex); private static long toValueIndex(long[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed @@ -164,13 +156,7 @@ public class IndexedTensor implements Tensor { public TensorType type() { return type; } @Override - public IndexedTensor withType(TensorType type) { - if (!this.type.isRenamableTo(type)) { - throw new IllegalArgumentException("IndexedTensor.withType: types are not compatible. Current type: '" + - this.type.toString() + "', requested type: '" + type.toString() + "'"); - } - return new IndexedTensor(type, dimensionSizes, values); - } + public abstract IndexedTensor withType(TensorType type); public DimensionSizes dimensionSizes() { return dimensionSizes; @@ -179,13 +165,13 @@ public class IndexedTensor implements Tensor { @Override public Map<TensorAddress, Double> cells() { if (dimensionSizes.dimensions() == 0) - return Collections.singletonMap(TensorAddress.of(), values[0]); + return Collections.singletonMap(TensorAddress.of(), get(0)); ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); - Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); - for (long i = 0; i < values.length; i++) { + Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, size()); + for (long i = 0; i < size(); i++) { indexes.next(); - builder.put(indexes.toAddress(), values[(int)i]); + builder.put(indexes.toAddress(), get(i)); } return builder.build(); } @@ -201,9 +187,6 @@ public class IndexedTensor implements Tensor { } @Override - public int hashCode() { return Arrays.hashCode(values); } - - @Override public String toString() { return Tensor.toStandardString(this); } @Override @@ -302,7 +285,7 @@ public class IndexedTensor implements Tensor { @Override public IndexedTensor build() { - IndexedTensor tensor = new IndexedTensor(type, sizes, values); + IndexedTensor tensor = new IndexedDoubleTensor(type, sizes, values); // TODO // prevent further modification sizes = null; values = null; @@ -348,12 +331,12 @@ public class IndexedTensor implements Tensor { if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values"); if (type.dimensions().isEmpty()) // single number - return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); + return new IndexedDoubleTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); double[] values = new double[(int)dimensionSizes.totalSize()]; fillValues(0, 0, firstDimension, dimensionSizes, values); - return new IndexedTensor(type, dimensionSizes, values); + return new IndexedDoubleTensor(type, dimensionSizes, values); } private DimensionSizes findDimensionSizes(List<Object> firstDimension) { @@ -460,7 +443,7 @@ public class IndexedTensor implements Tensor { private final class CellIterator implements Iterator<Cell> { private long count = 0; - private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); + private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, size()); private final LazyCell reusedCell = new LazyCell(indexes, Double.NaN); @Override @@ -485,13 +468,13 @@ public class IndexedTensor implements Tensor { @Override public boolean hasNext() { - return count < values.length; + return count < size(); } @Override public Double next() { try { - return values[(int)count++]; + return get(count++); } catch (IndexOutOfBoundsException e) { throw new NoSuchElementException("No element at position " + count); |