summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java45
1 files changed, 25 insertions, 20 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index bee93ddb4e0..9315922f57a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -103,6 +103,7 @@ public class IndexedTensor implements Tensor {
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
public double get(int ... indexes) {
+ if (values.length == 0) return Double.NaN;
return values[toValueIndex(indexes, dimensionSizes)];
}
@@ -156,7 +157,7 @@ public class IndexedTensor implements Tensor {
@Override
public Map<TensorAddress, Double> cells() {
if (dimensionSizes.dimensions() == 0)
- return Collections.singletonMap(TensorAddress.empty, values[0]);
+ return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]);
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
@@ -216,6 +217,13 @@ public class IndexedTensor implements Tensor {
public abstract Builder cell(double value, int ... indexes);
+ protected double[] arrayFor(DimensionSizes sizes) {
+ int productSize = 1;
+ for (int i = 0; i < sizes.dimensions(); i++ )
+ productSize *= sizes.size(i);
+ return new double[productSize];
+ }
+
@Override
public TensorType type() { return type; }
@@ -225,7 +233,7 @@ public class IndexedTensor implements Tensor {
}
/** A bound builder can create the double array directly */
- public static class BoundBuilder extends Builder {
+ private static class BoundBuilder extends Builder {
private DimensionSizes sizes;
private double[] values;
@@ -234,7 +242,7 @@ public class IndexedTensor implements Tensor {
this(type, dimensionSizesOf(type));
}
- static DimensionSizes dimensionSizesOf(TensorType type) {
+ public static DimensionSizes dimensionSizesOf(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < type.dimensions().size(); i++)
b.set(i, type.dimensions().get(i).size().get());
@@ -246,7 +254,8 @@ public class IndexedTensor implements Tensor {
if ( sizes.dimensions() != type.dimensions().size())
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
- values = new double[sizes.totalSize()];
+ values = arrayFor(sizes);
+ Arrays.fill(values, Double.NaN);
}
@Override
@@ -268,6 +277,10 @@ public class IndexedTensor implements Tensor {
@Override
public IndexedTensor build() {
+ // Note that we do not check for no NaN's here for performance reasons.
+ // NaN's don't get lost so leaving them in place should be quite benign
+ if (values.length == 1 && Double.isNaN(values[0]))
+ values = new double[0];
IndexedTensor tensor = new IndexedTensor(type, sizes, values);
// prevent further modification
sizes = null;
@@ -277,6 +290,9 @@ public class IndexedTensor implements Tensor {
@Override
public Builder cell(Cell cell, double value) {
+ // TODO: Use internal index if applicable
+ // values[internalIndex] = value;
+ // return this;
int directIndex = cell.getDirectIndex();
if (directIndex >= 0) // optimization
values[directIndex] = value;
@@ -285,15 +301,6 @@ public class IndexedTensor implements Tensor {
return this;
}
- /**
- * Set a cell value by the index in the internal layout of this cell.
- * This requires knowledge of the internal layout of cells in this implementation, and should therefore
- * probably not be used (but when it can be used it is fast).
- */
- public void cellByDirectIndex(int index, double value) {
- values[index] = value;
- }
-
}
/**
@@ -311,13 +318,13 @@ public class IndexedTensor implements Tensor {
@Override
public IndexedTensor build() {
- if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values");
-
+ if (firstDimension == null) // empty
+ return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {});
if (type.dimensions().isEmpty()) // single number
return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
- double[] values = new double[dimensionSizes.totalSize()];
+ double[] values = arrayFor(dimensionSizes);
fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
@@ -326,10 +333,8 @@ public class IndexedTensor implements Tensor {
List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
findDimensionSizes(0, dimensionSizeList, firstDimension);
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct
- for (int i = 0; i < b.dimensions(); i++) {
- if (i < dimensionSizeList.size())
- b.set(i, dimensionSizeList.get(i));
- }
+ for (int i = 0; i < b.dimensions(); i++)
+ b.set(i, dimensionSizeList.get(i));
return b.build();
}