diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-26 11:34:13 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-26 11:34:13 +0200 |
commit | 94b4b3ad837f9d3f9d43b158c4de8475ff2c2a2d (patch) | |
tree | 583eb30a7c463699b414a527fd230e0fabc32fd0 /vespajlib | |
parent | 3873424bb18acd179441cdd914070c32e41699ee (diff) |
Make float builder when appropriate
Diffstat (limited to 'vespajlib')
3 files changed, 59 insertions, 7 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 80350d9e5f5..5d5c2be4576 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -46,10 +46,6 @@ class IndexedDoubleTensor extends IndexedTensor { private double[] values; - BoundDoubleBuilder(TensorType type) { - this(type, dimensionSizesOf(type)); - } - BoundDoubleBuilder(TensorType type, DimensionSizes sizes) { super(type, sizes); values = new double[(int)sizes.totalSize()]; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java index 563d72137e7..1e2aed1f5b4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -41,4 +41,61 @@ class IndexedFloatTensor extends IndexedTensor { @Override public int hashCode() { return Arrays.hashCode(values); } + /** A bound builder can create the float array directly */ + public static class BoundFloatBuilder extends BoundBuilder { + + private float[] values; + + BoundFloatBuilder(TensorType type, DimensionSizes sizes) { + super(type, sizes); + values = new float[(int)sizes.totalSize()]; + } + + @Override + public IndexedTensor.BoundBuilder cell(double value, long ... indexes) { + values[(int)toValueIndex(indexes, sizes())] = (float)value; + return this; + } + + @Override + public CellBuilder cell() { + return new CellBuilder(type, this); + } + + @Override + public Builder cell(TensorAddress address, double value) { + values[(int)toValueIndex(address, sizes())] = (float)value; + return this; + } + + @Override + public IndexedTensor build() { + IndexedTensor tensor = new IndexedFloatTensor(type, sizes(), values); + // prevent further modification + values = null; + return tensor; + } + + @Override + public Builder cell(Cell cell, double value) { + long directIndex = cell.getDirectIndex(); + if (directIndex >= 0) // optimization + values[(int)directIndex] = (float)value; + else + super.cell(cell, value); + return this; + } + + /** + * Set a cell value by the index in the internal layout of this cell. + * This requires knowledge of the internal layout of cells in this implementation, and should therefore + * probably not be used (but when it can be used it is fast). + */ + @Override + public void cellByDirectIndex(long index, double value) { + values[(int)index] = (float)value; + } + + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6e587b05460..8e2223def83 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -235,8 +235,7 @@ public abstract class IndexedTensor implements Tensor { } if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - // return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); TODO + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); else if (type.valueType() == TensorType.Value.FLOAT) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); else @@ -258,7 +257,7 @@ public abstract class IndexedTensor implements Tensor { private DimensionSizes sizes; - static DimensionSizes dimensionSizesOf(TensorType type) { + private static DimensionSizes dimensionSizesOf(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); for (int i = 0; i < type.dimensions().size(); i++) b.set(i, type.dimensions().get(i).size().get()); |