summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java79
1 files changed, 22 insertions, 57 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 5f2c04bbd56..6e587b05460 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -118,7 +118,7 @@ public abstract class IndexedTensor implements Tensor {
*/
public abstract double get(long valueIndex);
- private static long toValueIndex(long[] indexes, DimensionSizes sizes) {
+ static long toValueIndex(long[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
if (indexes.length == 0) return 0; // for speed
@@ -132,7 +132,7 @@ public abstract class IndexedTensor implements Tensor {
return valueIndex;
}
- private static long toValueIndex(TensorAddress address, DimensionSizes sizes) {
+ static long toValueIndex(TensorAddress address, DimensionSizes sizes) {
if (address.isEmpty()) return 0;
long valueIndex = 0;
@@ -152,6 +152,12 @@ public abstract class IndexedTensor implements Tensor {
return product;
}
+ void throwOnIncompatibleType(TensorType type) {
+ if ( ! this.type().isRenamableTo(type))
+ throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type +
+ ": Types are not compatible");
+ }
+
@Override
public TensorType type() { return type; }
@@ -205,7 +211,7 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type) {
if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
- return new BoundBuilder(type);
+ return of(type, BoundBuilder.dimensionSizesOf(type));
else
return new UnboundBuilder(type);
}
@@ -218,8 +224,8 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes) {
// validate
if (sizes.dimensions() != type.dimensions().size())
- throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
- "for " + type);
+ throw new IllegalArgumentException(sizes.dimensions() +
+ " is the wrong number of dimensions for " + type);
for (int i = 0; i < sizes.dimensions(); i++ ) {
Optional<Long> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
@@ -228,7 +234,13 @@ public abstract class IndexedTensor implements Tensor {
" but cannot be larger than " + size.get() + " in " + type);
}
- return new BoundBuilder(type, sizes);
+ if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ // return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); TODO
+ else if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ else
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
}
public abstract Builder cell(double value, long ... indexes);
@@ -242,14 +254,9 @@ public abstract class IndexedTensor implements Tensor {
}
/** A bound builder can create the double array directly */
- public static class BoundBuilder extends Builder {
+ public static abstract class BoundBuilder extends Builder {
private DimensionSizes sizes;
- private double[] values;
-
- private BoundBuilder(TensorType type) {
- this(type, dimensionSizesOf(type));
- }
static DimensionSizes dimensionSizesOf(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
@@ -258,58 +265,16 @@ public abstract class IndexedTensor implements Tensor {
return b.build();
}
- private BoundBuilder(TensorType type, DimensionSizes sizes) {
+ BoundBuilder(TensorType type, DimensionSizes sizes) {
super(type);
if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
- values = new double[(int)sizes.totalSize()];
- }
-
- @Override
- public BoundBuilder cell(double value, long ... indexes) {
- values[(int)toValueIndex(indexes, sizes)] = value;
- return this;
- }
-
- @Override
- public CellBuilder cell() {
- return new CellBuilder(type, this);
- }
-
- @Override
- public Builder cell(TensorAddress address, double value) {
- values[(int)toValueIndex(address, sizes)] = value;
- return this;
}
- @Override
- public IndexedTensor build() {
- IndexedTensor tensor = new IndexedDoubleTensor(type, sizes, values); // TODO
- // prevent further modification
- sizes = null;
- values = null;
- return tensor;
- }
+ DimensionSizes sizes() { return sizes; }
- @Override
- public Builder cell(Cell cell, double value) {
- long directIndex = cell.getDirectIndex();
- if (directIndex >= 0) // optimization
- values[(int)directIndex] = value;
- else
- super.cell(cell, value);
- return this;
- }
-
- /**
- * Set a cell value by the index in the internal layout of this cell.
- * This requires knowledge of the internal layout of cells in this implementation, and should therefore
- * probably not be used (but when it can be used it is fast).
- */
- public void cellByDirectIndex(long index, double value) {
- values[(int)index] = value;
- }
+ public abstract void cellByDirectIndex(long index, double value);
}