diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 45 |
1 files changed, 25 insertions, 20 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index bee93ddb4e0..9315922f57a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -103,6 +103,7 @@ public class IndexedTensor implements Tensor { * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ public double get(int ... indexes) { + if (values.length == 0) return Double.NaN; return values[toValueIndex(indexes, dimensionSizes)]; } @@ -156,7 +157,7 @@ public class IndexedTensor implements Tensor { @Override public Map<TensorAddress, Double> cells() { if (dimensionSizes.dimensions() == 0) - return Collections.singletonMap(TensorAddress.empty, values[0]); + return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]); ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); @@ -216,6 +217,13 @@ public class IndexedTensor implements Tensor { public abstract Builder cell(double value, int ... indexes); + protected double[] arrayFor(DimensionSizes sizes) { + int productSize = 1; + for (int i = 0; i < sizes.dimensions(); i++ ) + productSize *= sizes.size(i); + return new double[productSize]; + } + @Override public TensorType type() { return type; } @@ -225,7 +233,7 @@ public class IndexedTensor implements Tensor { } /** A bound builder can create the double array directly */ - public static class BoundBuilder extends Builder { + private static class BoundBuilder extends Builder { private DimensionSizes sizes; private double[] values; @@ -234,7 +242,7 @@ public class IndexedTensor implements Tensor { this(type, dimensionSizesOf(type)); } - static DimensionSizes dimensionSizesOf(TensorType type) { + public static DimensionSizes dimensionSizesOf(TensorType type) { DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); for (int i = 0; i < type.dimensions().size(); i++) b.set(i, type.dimensions().get(i).size().get()); @@ -246,7 +254,8 @@ public class IndexedTensor implements Tensor { if ( sizes.dimensions() != type.dimensions().size()) throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; - values = new double[sizes.totalSize()]; + values = arrayFor(sizes); + Arrays.fill(values, Double.NaN); } @Override @@ -268,6 +277,10 @@ public class IndexedTensor implements Tensor { @Override public IndexedTensor build() { + // Note that we do not check for no NaN's here for performance reasons. + // NaN's don't get lost so leaving them in place should be quite benign + if (values.length == 1 && Double.isNaN(values[0])) + values = new double[0]; IndexedTensor tensor = new IndexedTensor(type, sizes, values); // prevent further modification sizes = null; @@ -277,6 +290,9 @@ public class IndexedTensor implements Tensor { @Override public Builder cell(Cell cell, double value) { + // TODO: Use internal index if applicable + // values[internalIndex] = value; + // return this; int directIndex = cell.getDirectIndex(); if (directIndex >= 0) // optimization values[directIndex] = value; @@ -285,15 +301,6 @@ public class IndexedTensor implements Tensor { return this; } - /** - * Set a cell value by the index in the internal layout of this cell. - * This requires knowledge of the internal layout of cells in this implementation, and should therefore - * probably not be used (but when it can be used it is fast). - */ - public void cellByDirectIndex(int index, double value) { - values[index] = value; - } - } /** @@ -311,13 +318,13 @@ public class IndexedTensor implements Tensor { @Override public IndexedTensor build() { - if (firstDimension == null) throw new IllegalArgumentException("Tensor of type " + type() + " has no values"); - + if (firstDimension == null) // empty + return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {}); if (type.dimensions().isEmpty()) // single number return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); - double[] values = new double[dimensionSizes.totalSize()]; + double[] values = arrayFor(dimensionSizes); fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } @@ -326,10 +333,8 @@ public class IndexedTensor implements Tensor { List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size()); findDimensionSizes(0, dimensionSizeList, firstDimension); DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct - for (int i = 0; i < b.dimensions(); i++) { - if (i < dimensionSizeList.size()) - b.set(i, dimensionSizeList.get(i)); - } + for (int i = 0; i < b.dimensions(); i++) + b.set(i, dimensionSizeList.get(i)); return b.build(); } |