diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 79 |
1 files changed, 22 insertions, 57 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 5f2c04bbd56..6e587b05460 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -118,7 +118,7 @@ public abstract class IndexedTensor implements Tensor { */ public abstract double get(long valueIndex); - private static long toValueIndex(long[] indexes, DimensionSizes sizes) { + static long toValueIndex(long[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed if (indexes.length == 0) return 0; // for speed @@ -132,7 +132,7 @@ public abstract class IndexedTensor implements Tensor { return valueIndex; } - private static long toValueIndex(TensorAddress address, DimensionSizes sizes) { + static long toValueIndex(TensorAddress address, DimensionSizes sizes) { if (address.isEmpty()) return 0; long valueIndex = 0; @@ -152,6 +152,12 @@ public abstract class IndexedTensor implements Tensor { return product; } + void throwOnIncompatibleType(TensorType type) { + if ( ! this.type().isRenamableTo(type)) + throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type + + ": Types are not compatible"); + } + @Override public TensorType type() { return type; } @@ -205,7 +211,7 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type) { if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) - return new BoundBuilder(type); + return of(type, BoundBuilder.dimensionSizesOf(type)); else return new UnboundBuilder(type); } @@ -218,8 +224,8 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes) { // validate if (sizes.dimensions() != type.dimensions().size()) - throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + - "for " + type); + throw new IllegalArgumentException(sizes.dimensions() + + " is the wrong number of dimensions for " + type); for (int i = 0; i < sizes.dimensions(); i++ ) { Optional<Long> size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) @@ -228,7 +234,13 @@ public abstract class IndexedTensor implements Tensor { " but cannot be larger than " + size.get() + " in " + type); } - return new BoundBuilder(type, sizes); + if (type.valueType() == TensorType.Value.FLOAT) + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + // return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); TODO + else if (type.valueType() == TensorType.Value.FLOAT) + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + else + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default } public abstract Builder cell(double value, long ... indexes); @@ -242,14 +254,9 @@ public abstract class IndexedTensor implements Tensor { } /** A bound builder can create the double array directly */ - public static class BoundBuilder extends Builder { + public static abstract class BoundBuilder extends Builder { private DimensionSizes sizes; - private double[] values; - - private BoundBuilder(TensorType type) { - this(type, dimensionSizesOf(type)); - } static DimensionSizes dimensionSizesOf(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); @@ -258,58 +265,16 @@ public abstract class IndexedTensor implements Tensor { return b.build(); } - private BoundBuilder(TensorType type, DimensionSizes sizes) { + BoundBuilder(TensorType type, DimensionSizes sizes) { super(type); if ( sizes.dimensions() != type.dimensions().size()) throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; - values = new double[(int)sizes.totalSize()]; - } - - @Override - public BoundBuilder cell(double value, long ... indexes) { - values[(int)toValueIndex(indexes, sizes)] = value; - return this; - } - - @Override - public CellBuilder cell() { - return new CellBuilder(type, this); - } - - @Override - public Builder cell(TensorAddress address, double value) { - values[(int)toValueIndex(address, sizes)] = value; - return this; } - @Override - public IndexedTensor build() { - IndexedTensor tensor = new IndexedDoubleTensor(type, sizes, values); // TODO - // prevent further modification - sizes = null; - values = null; - return tensor; - } + DimensionSizes sizes() { return sizes; } - @Override - public Builder cell(Cell cell, double value) { - long directIndex = cell.getDirectIndex(); - if (directIndex >= 0) // optimization - values[(int)directIndex] = value; - else - super.cell(cell, value); - return this; - } - - /** - * Set a cell value by the index in the internal layout of this cell. - * This requires knowledge of the internal layout of cells in this implementation, and should therefore - * probably not be used (but when it can be used it is fast). - */ - public void cellByDirectIndex(long index, double value) { - values[(int)index] = value; - } + public abstract void cellByDirectIndex(long index, double value); } |