From 443437a83cd1c3b4d55c732e8756d5c0b1595902 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Sat, 1 Jun 2019 15:53:41 +0200 Subject: Allow passing your own vector without copy to the IndexedTensor. --- .../java/com/yahoo/tensor/IndexedDoubleTensor.java | 5 +- .../java/com/yahoo/tensor/IndexedFloatTensor.java | 9 ++- .../main/java/com/yahoo/tensor/IndexedTensor.java | 77 +++++++++++++++++++--- .../com/yahoo/tensor/IndexedTensorTestCase.java | 32 +++++++++ 4 files changed, 111 insertions(+), 12 deletions(-) (limited to 'vespajlib') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 285837a1bc6..e0cb3dca969 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -43,8 +43,11 @@ class IndexedDoubleTensor extends IndexedTensor { private double[] values; BoundDoubleBuilder(TensorType type, DimensionSizes sizes) { + this(type, sizes, new double[(int)sizes.totalSize()]); + } + BoundDoubleBuilder(TensorType type, DimensionSizes sizes, double [] values) { super(type, sizes); - values = new double[(int)sizes.totalSize()]; + this.values = values; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java index 8f8c24c8421..56cb22da7a5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -43,8 +43,15 @@ class IndexedFloatTensor extends IndexedTensor { private float[] values; BoundFloatBuilder(TensorType type, DimensionSizes sizes) { + this(type, sizes, new float[(int)sizes.totalSize()]); + } + BoundFloatBuilder(TensorType type, DimensionSizes sizes, float [] values) { super(type, sizes); - values = new float[(int)sizes.totalSize()]; + if (sizes.totalSize() != values.length) { + throw new IllegalArgumentException("Invalid size(" + values.length + ") of supplied value vector." + + " Type specifies that size should be " + sizes.totalSize()); + } + this.values = values; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 19edfc0269e..b43993be732 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -234,6 +234,18 @@ public abstract class IndexedTensor implements Tensor { else return new UnboundBuilder(type); } + public static Builder of(TensorType type, float [] values) { + if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + return of(type, BoundBuilder.dimensionSizesOf(type), values); + else + return new UnboundBuilder(type); + } + public static Builder of(TensorType type, double [] values) { + if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + return of(type, BoundBuilder.dimensionSizesOf(type), values); + else + return new UnboundBuilder(type); + } /** * Create a builder with dimension size information for this instance. Must be one size entry per dimension, @@ -241,24 +253,55 @@ public abstract class IndexedTensor implements Tensor { * If sizes are completely specified in the type this size information is redundant. */ public static Builder of(TensorType type, DimensionSizes sizes) { + validate(type, sizes); + + if (type.valueType() == TensorType.Value.FLOAT) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + else if (type.valueType() == TensorType.Value.DOUBLE) + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + else + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default + } + public static Builder of(TensorType type, DimensionSizes sizes, float [] values) { + validate(type, sizes); + validateSizes(sizes, values.length); + + if (type.valueType() == TensorType.Value.FLOAT) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + else if (type.valueType() == TensorType.Value.DOUBLE) + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); + else + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default + } + public static Builder of(TensorType type, DimensionSizes sizes, double [] values) { + validate(type, sizes); + validateSizes(sizes, values.length); + + if (type.valueType() == TensorType.Value.FLOAT) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.DOUBLE) + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); + else + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default + } + private static void validateSizes(DimensionSizes sizes, int length) { + if (sizes.totalSize() != length) { + throw new IllegalArgumentException("Invalid size(" + length + ") of supplied value vector." + + " Type specifies that size should be " + sizes.totalSize()); + } + } + private static void validate(TensorType type, DimensionSizes sizes) { // validate if (sizes.dimensions() != type.dimensions().size()) throw new IllegalArgumentException(sizes.dimensions() + - " is the wrong number of dimensions for " + type); + " is the wrong number of dimensions for " + type); for (int i = 0; i < sizes.dimensions(); i++ ) { Optional size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + - sizes.size(i) + - " but cannot be larger than " + size.get() + " in " + type); + sizes.size(i) + + " but cannot be larger than " + size.get() + " in " + type); } - - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default } public abstract Builder cell(double value, long ... indexes); @@ -290,6 +333,20 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; } + BoundBuilder fill(float [] values) { + long index = 0; + for (float value : values) { + cellByDirectIndex(index++, value); + } + return this; + } + BoundBuilder fill(double [] values) { + long index = 0; + for (double value : values) { + cellByDirectIndex(index++, value); + } + return this; + } DimensionSizes sizes() { return sizes; } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index a5fc3d5a5d8..4bfdb53e321 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -10,6 +10,7 @@ import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; /** * @author bratseth @@ -41,6 +42,37 @@ public class IndexedTensorTestCase { assertTrue(singleValueFromString instanceof IndexedTensor); assertEquals(singleValue, singleValueFromString); } + + private void verifyFloat(String spec) { + float [] floats = {1.0f, 2.0f, 3.0f}; + Tensor tensor = IndexedTensor.Builder.of(TensorType.fromSpec(spec), floats).build(); + int index = 0; + for (Double cell : tensor.cells().values()) { + assertEquals(cell, Double.valueOf(floats[index++])); + } + } + private void verifyDouble(String spec) { + double [] values = {1.0, 2.0, 3.0}; + Tensor tensor = IndexedTensor.Builder.of(TensorType.fromSpec(spec), values).build(); + int index = 0; + for (Double cell : tensor.cells().values()) { + assertEquals(cell, Double.valueOf(values[index++])); + } + } + + @Test + public void testBoundHandoverBuilding() { + verifyFloat("tensor(x[3])"); + verifyDouble("tensor(x[3])"); + verifyFloat("tensor(x[3])"); + verifyDouble("tensor(x[3])"); + try { + verifyDouble("tensor(x[4])"); + fail("Expect IllegalArgumentException"); + } catch (IllegalArgumentException e) { + assertEquals("Invalid size(3) of supplied value vector. Type specifies that size should be 4", e.getMessage()); + } + } @Test public void testBoundBuilding() { -- cgit v1.2.3