summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java111
1 files changed, 101 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 19edfc0269e..a03131f3ec9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -236,29 +236,106 @@ public abstract class IndexedTensor implements Tensor {
}
/**
+ * Creates a builder initialized with the given values
+ *
+ * @param type the type of the tensor to build
+ * @param values the initial values of the tensor. This <b>transfers ownership</b> of the value array - it
+ * must not be further mutated by the caller
+ */
+ public static Builder of(TensorType type, float[] values) {
+ if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ return of(type, BoundBuilder.dimensionSizesOf(type), values);
+ else
+ return new UnboundBuilder(type);
+ }
+
+ /**
+ * Creates a builder initialized with the given values
+ *
+ * @param type the type of the tensor to build
+ * @param values the initial values of the tensor. This <b>transfers ownership</b> of the value array - it
+ * must not be further mutated by the caller
+ */
+ public static Builder of(TensorType type, double[] values) {
+ if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ return of(type, BoundBuilder.dimensionSizesOf(type), values);
+ else
+ return new UnboundBuilder(type);
+ }
+
+ /**
* Create a builder with dimension size information for this instance. Must be one size entry per dimension,
* and, agree with the type size information when specified in the type.
* If sizes are completely specified in the type this size information is redundant.
*/
public static Builder of(TensorType type, DimensionSizes sizes) {
+ validate(type, sizes);
+
+ if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ else if (type.valueType() == TensorType.Value.DOUBLE)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ else
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
+ }
+
+ /**
+ * Creates a builder initialized with the given values
+ *
+ * @param type the type of the tensor to build
+ * @param values the initial values of the tensor. This <b>transfers ownership</b> of the value array - it
+ * must not be further mutated by the caller
+ */
+ public static Builder of(TensorType type, DimensionSizes sizes, float[] values) {
+ validate(type, sizes);
+ validateSizes(sizes, values.length);
+
+ if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ else if (type.valueType() == TensorType.Value.DOUBLE)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
+ else
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default
+ }
+
+ /**
+ * Creates a builder initialized with the given values
+ *
+ * @param type the type of the tensor to build
+ * @param values the initial values of the tensor. This <b>transfers ownership</b> of the value array - it
+ * must not be further mutated by the caller
+ */
+ public static Builder of(TensorType type, DimensionSizes sizes, double[] values) {
+ validate(type, sizes);
+ validateSizes(sizes, values.length);
+
+ if (type.valueType() == TensorType.Value.FLOAT)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ else if (type.valueType() == TensorType.Value.DOUBLE)
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
+ else
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default
+ }
+
+ private static void validateSizes(DimensionSizes sizes, int length) {
+ if (sizes.totalSize() != length) {
+ throw new IllegalArgumentException("Invalid size(" + length + ") of supplied value vector." +
+ " Type specifies that size should be " + sizes.totalSize());
+ }
+ }
+
+ private static void validate(TensorType type, DimensionSizes sizes) {
// validate
if (sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException(sizes.dimensions() +
- " is the wrong number of dimensions for " + type);
+ " is the wrong number of dimensions for " + type);
for (int i = 0; i < sizes.dimensions(); i++ ) {
Optional<Long> size = type.dimensions().get(i).size();
if (size.isPresent() && size.get() < sizes.size(i))
throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " +
- sizes.size(i) +
- " but cannot be larger than " + size.get() + " in " + type);
+ sizes.size(i) +
+ " but cannot be larger than " + size.get() + " in " + type);
}
-
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
}
public abstract Builder cell(double value, long ... indexes);
@@ -290,6 +367,20 @@ public abstract class IndexedTensor implements Tensor {
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
}
+ BoundBuilder fill(float [] values) {
+ long index = 0;
+ for (float value : values) {
+ cellByDirectIndex(index++, value);
+ }
+ return this;
+ }
+ BoundBuilder fill(double [] values) {
+ long index = 0;
+ for (double value : values) {
+ cellByDirectIndex(index++, value);
+ }
+ return this;
+ }
DimensionSizes sizes() { return sizes; }