summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2019-06-01 15:53:41 +0200
committerHenning Baldersheim <balder@yahoo-inc.com>2019-06-01 15:53:41 +0200
commit443437a83cd1c3b4d55c732e8756d5c0b1595902 (patch)
treef8883230319e09c561059c89010e5d4f25a19063 /vespajlib
parent9538c19b84ffcea70e7254855bd05ada1402a56f (diff)
Allow passing your own vector without copy to the IndexedTensor.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java77
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java32
4 files changed, 111 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
index 285837a1bc6..e0cb3dca969 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
@@ -43,8 +43,11 @@ class IndexedDoubleTensor extends IndexedTensor {
private double[] values;
BoundDoubleBuilder(TensorType type, DimensionSizes sizes) {
+ this(type, sizes, new double[(int)sizes.totalSize()]);
+ }
+ BoundDoubleBuilder(TensorType type, DimensionSizes sizes, double [] values) {
super(type, sizes);
- values = new double[(int)sizes.totalSize()];
+ this.values = values;
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
index 8f8c24c8421..56cb22da7a5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
@@ -43,8 +43,15 @@ class IndexedFloatTensor extends IndexedTensor {
private float[] values;
BoundFloatBuilder(TensorType type, DimensionSizes sizes) {
+ this(type, sizes, new float[(int)sizes.totalSize()]);
+ }
+ BoundFloatBuilder(TensorType type, DimensionSizes sizes, float [] values) {
super(type, sizes);
- values = new float[(int)sizes.totalSize()];
+ if (sizes.totalSize() != values.length) {
+ throw new IllegalArgumentException("Invalid size(" + values.length + ") of supplied value vector." +
+ " Type specifies that size should be " + sizes.totalSize());
+ }
+ this.values = values;
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 19edfc0269e..b43993be732 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -234,6 +234,18 @@ public abstract class IndexedTensor implements Tensor {
else
return new UnboundBuilder(type);
}
+ public static Builder of(TensorType type, float [] values) {
+ if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ return of(type, BoundBuilder.dimensionSizesOf(type), values);
+ else
+ return new UnboundBuilder(type);
+ }
+ public static Builder of(TensorType type, double [] values) {
+ if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ return of(type, BoundBuilder.dimensionSizesOf(type), values);
+ else
+ return new UnboundBuilder(type);
+ }
/**
* Create a builder with dimension size information for this instance. Must be one size entry per dimension,
@@ -241,24 +253,55 @@ public abstract class IndexedTensor implements Tensor {
* If sizes are completely specified in the type this size information is redundant.
*/
public static Builder of(TensorType type, DimensionSizes sizes) {
+ validate(type, sizes);
+
+ if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ else if (type.valueType() == TensorType.Value.DOUBLE)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ else
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
+ }
+ public static Builder of(TensorType type, DimensionSizes sizes, float [] values) {
+ validate(type, sizes);
+ validateSizes(sizes, values.length);
+
+ if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ else if (type.valueType() == TensorType.Value.DOUBLE)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
+ else
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default
+ }
+ public static Builder of(TensorType type, DimensionSizes sizes, double [] values) {
+ validate(type, sizes);
+ validateSizes(sizes, values.length);
+
+ if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ else if (type.valueType() == TensorType.Value.DOUBLE)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
+ else
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default
+ }
+ private static void validateSizes(DimensionSizes sizes, int length) {
+ if (sizes.totalSize() != length) {
+ throw new IllegalArgumentException("Invalid size(" + length + ") of supplied value vector." +
+ " Type specifies that size should be " + sizes.totalSize());
+ }
+ }
+ private static void validate(TensorType type, DimensionSizes sizes) {
// validate
if (sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException(sizes.dimensions() +
- " is the wrong number of dimensions for " + type);
+ " is the wrong number of dimensions for " + type);
for (int i = 0; i < sizes.dimensions(); i++ ) {
Optional<Long> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
- sizes.size(i) +
- " but cannot be larger than " + size.get() + " in " + type);
+ sizes.size(i) +
+ " but cannot be larger than " + size.get() + " in " + type);
}
-
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
}
public abstract Builder cell(double value, long ... indexes);
@@ -290,6 +333,20 @@ public abstract class IndexedTensor implements Tensor {
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
}
+ BoundBuilder fill(float [] values) {
+ long index = 0;
+ for (float value : values) {
+ cellByDirectIndex(index++, value);
+ }
+ return this;
+ }
+ BoundBuilder fill(double [] values) {
+ long index = 0;
+ for (double value : values) {
+ cellByDirectIndex(index++, value);
+ }
+ return this;
+ }
DimensionSizes sizes() { return sizes; }
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
index a5fc3d5a5d8..4bfdb53e321 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
@@ -10,6 +10,7 @@ import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
/**
* @author bratseth
@@ -41,6 +42,37 @@ public class IndexedTensorTestCase {
assertTrue(singleValueFromString instanceof IndexedTensor);
assertEquals(singleValue, singleValueFromString);
}
+
+ private void verifyFloat(String spec) {
+ float [] floats = {1.0f, 2.0f, 3.0f};
+ Tensor tensor = IndexedTensor.Builder.of(TensorType.fromSpec(spec), floats).build();
+ int index = 0;
+ for (Double cell : tensor.cells().values()) {
+ assertEquals(cell, Double.valueOf(floats[index++]));
+ }
+ }
+ private void verifyDouble(String spec) {
+ double [] values = {1.0, 2.0, 3.0};
+ Tensor tensor = IndexedTensor.Builder.of(TensorType.fromSpec(spec), values).build();
+ int index = 0;
+ for (Double cell : tensor.cells().values()) {
+ assertEquals(cell, Double.valueOf(values[index++]));
+ }
+ }
+
+ @Test
+ public void testBoundHandoverBuilding() {
+ verifyFloat("tensor<float>(x[3])");
+ verifyDouble("tensor<float>(x[3])");
+ verifyFloat("tensor<double>(x[3])");
+ verifyDouble("tensor<double>(x[3])");
+ try {
+ verifyDouble("tensor<double>(x[4])");
+ fail("Expect IllegalArgumentException");
+ } catch (IllegalArgumentException e) {
+ assertEquals("Invalid size(3) of supplied value vector. Type specifies that size should be 4", e.getMessage());
+ }
+ }
@Test
public void testBoundBuilding() {