diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 77 |
1 files changed, 67 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 19edfc0269e..b43993be732 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -234,6 +234,18 @@ public abstract class IndexedTensor implements Tensor { else return new UnboundBuilder(type); } + public static Builder of(TensorType type, float [] values) { + if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + return of(type, BoundBuilder.dimensionSizesOf(type), values); + else + return new UnboundBuilder(type); + } + public static Builder of(TensorType type, double [] values) { + if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + return of(type, BoundBuilder.dimensionSizesOf(type), values); + else + return new UnboundBuilder(type); + } /** * Create a builder with dimension size information for this instance. Must be one size entry per dimension, @@ -241,24 +253,55 @@ public abstract class IndexedTensor implements Tensor { * If sizes are completely specified in the type this size information is redundant. */ public static Builder of(TensorType type, DimensionSizes sizes) { + validate(type, sizes); + + if (type.valueType() == TensorType.Value.FLOAT) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + else if (type.valueType() == TensorType.Value.DOUBLE) + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + else + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default + } + public static Builder of(TensorType type, DimensionSizes sizes, float [] values) { + validate(type, sizes); + validateSizes(sizes, values.length); + + if (type.valueType() == TensorType.Value.FLOAT) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + else if (type.valueType() == TensorType.Value.DOUBLE) + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); + else + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default + } + public static Builder of(TensorType type, DimensionSizes sizes, double [] values) { + validate(type, sizes); + validateSizes(sizes, values.length); + + if (type.valueType() == TensorType.Value.FLOAT) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.DOUBLE) + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); + else + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default + } + private static void validateSizes(DimensionSizes sizes, int length) { + if (sizes.totalSize() != length) { + throw new IllegalArgumentException("Invalid size(" + length + ") of supplied value vector." + + " Type specifies that size should be " + sizes.totalSize()); + } + } + private static void validate(TensorType type, DimensionSizes sizes) { // validate if (sizes.dimensions() != type.dimensions().size()) throw new IllegalArgumentException(sizes.dimensions() + - " is the wrong number of dimensions for " + type); + " is the wrong number of dimensions for " + type); for (int i = 0; i < sizes.dimensions(); i++ ) { Optional<Long> size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + - sizes.size(i) + - " but cannot be larger than " + size.get() + " in " + type); + sizes.size(i) + + " but cannot be larger than " + size.get() + " in " + type); } - - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default } public abstract Builder cell(double value, long ... indexes); @@ -290,6 +333,20 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; } + BoundBuilder fill(float [] values) { + long index = 0; + for (float value : values) { + cellByDirectIndex(index++, value); + } + return this; + } + BoundBuilder fill(double [] values) { + long index = 0; + for (double value : values) { + cellByDirectIndex(index++, value); + } + return this; + } DimensionSizes sizes() { return sizes; } |