aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-26 11:26:04 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-26 11:26:04 +0200
commit3873424bb18acd179441cdd914070c32e41699ee (patch)
treebd843dc77a84a7a30f65405f17334dd3907defee /vespajlib
parent7ef86b1fb25f2268d00fa3af87bc1e594de0b1b3 (diff)
Move bound builder double array into double subclass
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java65
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java44
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java79
3 files changed, 128 insertions, 60 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
index 27cecdab80c..80350d9e5f5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
@@ -34,13 +34,72 @@ class IndexedDoubleTensor extends IndexedTensor {
@Override
public IndexedTensor withType(TensorType type) {
- if ( ! this.type().isRenamableTo(type))
- throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type +
- ": Types are not compatible");
+ throwOnIncompatibleType(type);
return new IndexedDoubleTensor(type, dimensionSizes(), values);
}
@Override
public int hashCode() { return Arrays.hashCode(values); }
+ /** A bound builder can create the double array directly */
+ public static class BoundDoubleBuilder extends BoundBuilder {
+
+ private double[] values;
+
+ BoundDoubleBuilder(TensorType type) {
+ this(type, dimensionSizesOf(type));
+ }
+
+ BoundDoubleBuilder(TensorType type, DimensionSizes sizes) {
+ super(type, sizes);
+ values = new double[(int)sizes.totalSize()];
+ }
+
+ @Override
+ public IndexedTensor.BoundBuilder cell(double value, long ... indexes) {
+ values[(int)toValueIndex(indexes, sizes())] = value;
+ return this;
+ }
+
+ @Override
+ public CellBuilder cell() {
+ return new CellBuilder(type, this);
+ }
+
+ @Override
+ public Builder cell(TensorAddress address, double value) {
+ values[(int)toValueIndex(address, sizes())] = value;
+ return this;
+ }
+
+ @Override
+ public IndexedTensor build() {
+ IndexedTensor tensor = new IndexedDoubleTensor(type, sizes(), values);
+ // prevent further modification
+ values = null;
+ return tensor;
+ }
+
+ @Override
+ public Builder cell(Cell cell, double value) {
+ long directIndex = cell.getDirectIndex();
+ if (directIndex >= 0) // optimization
+ values[(int)directIndex] = value;
+ else
+ super.cell(cell, value);
+ return this;
+ }
+
+ /**
+ * Set a cell value by the index in the internal layout of this cell.
+ * This requires knowledge of the internal layout of cells in this implementation, and should therefore
+ * probably not be used (but when it can be used it is fast).
+ */
+ @Override
+ public void cellByDirectIndex(long index, double value) {
+ values[(int)index] = value;
+ }
+
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
new file mode 100644
index 00000000000..563d72137e7
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
@@ -0,0 +1,44 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import java.util.Arrays;
+
+/**
+ * An indexed tensor implementation holding values as floats
+ *
+ * @author bratseth
+ */
+class IndexedFloatTensor extends IndexedTensor {
+
+ private final float[] values;
+
+ IndexedFloatTensor(TensorType type, DimensionSizes dimensionSizes, float[] values) {
+ super(type, dimensionSizes);
+ this.values = values;
+ }
+
+ @Override
+ public long size() {
+ return values.length;
+ }
+
+ /**
+ * Returns the value at the given index by direct lookup. Only use
+ * if you know the underlying data layout.
+ *
+ * @param valueIndex the direct index into the underlying data.
+ * @throws IndexOutOfBoundsException if index is out of bounds
+ */
+ @Override
+ public double get(long valueIndex) { return values[(int)valueIndex]; }
+
+ @Override
+ public IndexedTensor withType(TensorType type) {
+ throwOnIncompatibleType(type);
+ return new IndexedFloatTensor(type, dimensionSizes(), values);
+ }
+
+ @Override
+ public int hashCode() { return Arrays.hashCode(values); }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 5f2c04bbd56..6e587b05460 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -118,7 +118,7 @@ public abstract class IndexedTensor implements Tensor {
*/
public abstract double get(long valueIndex);
- private static long toValueIndex(long[] indexes, DimensionSizes sizes) {
+ static long toValueIndex(long[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
if (indexes.length == 0) return 0; // for speed
@@ -132,7 +132,7 @@ public abstract class IndexedTensor implements Tensor {
return valueIndex;
}
- private static long toValueIndex(TensorAddress address, DimensionSizes sizes) {
+ static long toValueIndex(TensorAddress address, DimensionSizes sizes) {
if (address.isEmpty()) return 0;
long valueIndex = 0;
@@ -152,6 +152,12 @@ public abstract class IndexedTensor implements Tensor {
return product;
}
+ void throwOnIncompatibleType(TensorType type) {
+ if ( ! this.type().isRenamableTo(type))
+ throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type +
+ ": Types are not compatible");
+ }
+
@Override
public TensorType type() { return type; }
@@ -205,7 +211,7 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type) {
if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
- return new BoundBuilder(type);
+ return of(type, BoundBuilder.dimensionSizesOf(type));
else
return new UnboundBuilder(type);
}
@@ -218,8 +224,8 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes) {
// validate
if (sizes.dimensions() != type.dimensions().size())
- throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " +
- "for " + type);
+ throw new IllegalArgumentException(sizes.dimensions() +
+ " is the wrong number of dimensions for " + type);
for (int i = 0; i < sizes.dimensions(); i++ ) {
Optional<Long> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
@@ -228,7 +234,13 @@ public abstract class IndexedTensor implements Tensor {
" but cannot be larger than " + size.get() + " in " + type);
}
- return new BoundBuilder(type, sizes);
+ if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ // return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); TODO
+ else if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ else
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
}
public abstract Builder cell(double value, long ... indexes);
@@ -242,14 +254,9 @@ public abstract class IndexedTensor implements Tensor {
}
/** A bound builder can create the double array directly */
- public static class BoundBuilder extends Builder {
+ public static abstract class BoundBuilder extends Builder {
private DimensionSizes sizes;
- private double[] values;
-
- private BoundBuilder(TensorType type) {
- this(type, dimensionSizesOf(type));
- }
static DimensionSizes dimensionSizesOf(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
@@ -258,58 +265,16 @@ public abstract class IndexedTensor implements Tensor {
return b.build();
}
- private BoundBuilder(TensorType type, DimensionSizes sizes) {
+ BoundBuilder(TensorType type, DimensionSizes sizes) {
super(type);
if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
- values = new double[(int)sizes.totalSize()];
- }
-
- @Override
- public BoundBuilder cell(double value, long ... indexes) {
- values[(int)toValueIndex(indexes, sizes)] = value;
- return this;
- }
-
- @Override
- public CellBuilder cell() {
- return new CellBuilder(type, this);
- }
-
- @Override
- public Builder cell(TensorAddress address, double value) {
- values[(int)toValueIndex(address, sizes)] = value;
- return this;
}
- @Override
- public IndexedTensor build() {
- IndexedTensor tensor = new IndexedDoubleTensor(type, sizes, values); // TODO
- // prevent further modification
- sizes = null;
- values = null;
- return tensor;
- }
+ DimensionSizes sizes() { return sizes; }
- @Override
- public Builder cell(Cell cell, double value) {
- long directIndex = cell.getDirectIndex();
- if (directIndex >= 0) // optimization
- values[(int)directIndex] = value;
- else
- super.cell(cell, value);
- return this;
- }
-
- /**
- * Set a cell value by the index in the internal layout of this cell.
- * This requires knowledge of the internal layout of cells in this implementation, and should therefore
- * probably not be used (but when it can be used it is fast).
- */
- public void cellByDirectIndex(long index, double value) {
- values[(int)index] = value;
- }
+ public abstract void cellByDirectIndex(long index, double value);
}