diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-10 15:55:53 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-10 15:55:53 +0100 |
commit | 451e7cf03729b7a09c8e4f9457edf9ae1007ba8a (patch) | |
tree | 5c62016b68eeecf06cbb205cc349712ef36a93c5 /vespajlib/src/main/java | |
parent | 14a0470694ea7f24b8ef007783432a6f532e42ba (diff) |
Use MappedTensor to represent tensor with no dimensions or values
Diffstat (limited to 'vespajlib/src/main/java')
5 files changed, 24 insertions, 27 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 4654f53647f..deee4aa02b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -103,7 +103,6 @@ public class IndexedTensor implements Tensor { * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ public double get(int ... indexes) { - if (values.length == 0) return Double.NaN; return values[toValueIndex(indexes, dimensionSizes)]; } @@ -157,7 +156,7 @@ public class IndexedTensor implements Tensor { @Override public Map<TensorAddress, Double> cells() { if (dimensionSizes.dimensions() == 0) - return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]); + return Collections.singletonMap(TensorAddress.empty, values[0]); ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); @@ -221,7 +220,7 @@ public class IndexedTensor implements Tensor { public TensorType type() { return type; } @Override - public abstract IndexedTensor build(); + public abstract Tensor build(); } @@ -269,11 +268,14 @@ public class IndexedTensor implements Tensor { } @Override - public IndexedTensor build() { + public Tensor build() { // Note that we do not check for no NaN's here for performance reasons. // NaN's don't get lost so leaving them in place should be quite benign - if (values.length == 1 && Double.isNaN(values[0])) - values = new double[0]; + + // An empty tensor with no dimensions is mapped + if (values.length == 1 && Double.isNaN(values[0]) && type.dimensions().isEmpty()) + return MappedTensor.Builder.of(type).build(); + IndexedTensor tensor = new IndexedTensor(type, sizes, values); // prevent further modification sizes = null; @@ -316,24 +318,28 @@ public class IndexedTensor implements Tensor { } @Override - public IndexedTensor build() { - if (firstDimension == null) // empty - return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {}); + public Tensor build() { + if (firstDimension == null && type.dimensions().isEmpty()) // empty + return MappedTensor.Builder.of(type).build(); if (type.dimensions().isEmpty()) // single number return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); double[] values = new double[dimensionSizes.totalSize()]; - fillValues(0, 0, firstDimension, dimensionSizes, values); + if (firstDimension != null) + fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } private DimensionSizes findDimensionSizes(List<Object> firstDimension) { List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size()); - findDimensionSizes(0, dimensionSizeList, firstDimension); + if (firstDimension != null) + findDimensionSizes(0, dimensionSizeList, firstDimension); DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct - for (int i = 0; i < b.dimensions(); i++) - b.set(i, dimensionSizeList.get(i)); + for (int i = 0; i < b.dimensions(); i++) { + if (i < dimensionSizeList.size()) + b.set(i, dimensionSizeList.get(i)); + } return b.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 51d40a89f3b..29c508ce12f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -213,10 +213,9 @@ public interface Tensor { static String contentToString(Tensor tensor) { List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet()); - if (tensor.type().dimensions().isEmpty()) { // TODO: Decide on one way to represent degeneration to number + if (tensor.type().dimensions().isEmpty()) { if (cellEntries.isEmpty()) return "{}"; - double value = cellEntries.get(0).getValue(); - return value == 0.0 ? "{}" : "{" + value +"}"; + return "{" + cellEntries.get(0).getValue() +"}"; } Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 82f36972a47..fbc469c1829 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -53,9 +53,6 @@ public class TensorType { return TensorTypeParser.fromSpec(specString); } - /** Returns true if all dimensions of this are indexed */ - public boolean isIndexed() { return dimensions().stream().allMatch(d -> d.isIndexed()); } - /** Returns an immutable list of the dimensions of this */ public List<Dimension> dimensions() { return dimensions; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index ceade39ce42..f295e129a0f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -113,7 +113,7 @@ public class Join extends PrimitiveTensorFunction { /** Join a tensor into a superspace */ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { - if (subspace.type().isIndexed() && superspace.type().isIndexed()) + if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder); else return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index c3284131be0..0a97576d5b7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -41,13 +41,8 @@ public class DenseBinaryFormat implements BinaryFormat { private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { Iterator<Double> i = tensor.valueIterator(); - if ( ! i.hasNext()) { // no values: Encode as NaN, as 0 dimensions may also mean 1 value - buffer.putDouble(Double.NaN); - } - else { - while (i.hasNext()) - buffer.putDouble(i.next()); - } + while (i.hasNext()) + buffer.putDouble(i.next()); } @Override |