diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-10 15:55:53 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-10 15:55:53 +0100 |
commit | 451e7cf03729b7a09c8e4f9457edf9ae1007ba8a (patch) | |
tree | 5c62016b68eeecf06cbb205cc349712ef36a93c5 /vespajlib | |
parent | 14a0470694ea7f24b8ef007783432a6f532e42ba (diff) |
Use MappedTensor to represent tensor with no dimensions or values
Diffstat (limited to 'vespajlib')
10 files changed, 44 insertions, 43 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 4654f53647f..deee4aa02b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -103,7 +103,6 @@ public class IndexedTensor implements Tensor { * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ public double get(int ... indexes) { - if (values.length == 0) return Double.NaN; return values[toValueIndex(indexes, dimensionSizes)]; } @@ -157,7 +156,7 @@ public class IndexedTensor implements Tensor { @Override public Map<TensorAddress, Double> cells() { if (dimensionSizes.dimensions() == 0) - return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]); + return Collections.singletonMap(TensorAddress.empty, values[0]); ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>(); Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); @@ -221,7 +220,7 @@ public class IndexedTensor implements Tensor { public TensorType type() { return type; } @Override - public abstract IndexedTensor build(); + public abstract Tensor build(); } @@ -269,11 +268,14 @@ public class IndexedTensor implements Tensor { } @Override - public IndexedTensor build() { + public Tensor build() { // Note that we do not check for no NaN's here for performance reasons. // NaN's don't get lost so leaving them in place should be quite benign - if (values.length == 1 && Double.isNaN(values[0])) - values = new double[0]; + + // An empty tensor with no dimensions is mapped + if (values.length == 1 && Double.isNaN(values[0]) && type.dimensions().isEmpty()) + return MappedTensor.Builder.of(type).build(); + IndexedTensor tensor = new IndexedTensor(type, sizes, values); // prevent further modification sizes = null; @@ -316,24 +318,28 @@ public class IndexedTensor implements Tensor { } @Override - public IndexedTensor build() { - if (firstDimension == null) // empty - return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {}); + public Tensor build() { + if (firstDimension == null && type.dimensions().isEmpty()) // empty + return MappedTensor.Builder.of(type).build(); if (type.dimensions().isEmpty()) // single number return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); double[] values = new double[dimensionSizes.totalSize()]; - fillValues(0, 0, firstDimension, dimensionSizes, values); + if (firstDimension != null) + fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } private DimensionSizes findDimensionSizes(List<Object> firstDimension) { List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size()); - findDimensionSizes(0, dimensionSizeList, firstDimension); + if (firstDimension != null) + findDimensionSizes(0, dimensionSizeList, firstDimension); DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct - for (int i = 0; i < b.dimensions(); i++) - b.set(i, dimensionSizeList.get(i)); + for (int i = 0; i < b.dimensions(); i++) { + if (i < dimensionSizeList.size()) + b.set(i, dimensionSizeList.get(i)); + } return b.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 51d40a89f3b..29c508ce12f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -213,10 +213,9 @@ public interface Tensor { static String contentToString(Tensor tensor) { List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet()); - if (tensor.type().dimensions().isEmpty()) { // TODO: Decide on one way to represent degeneration to number + if (tensor.type().dimensions().isEmpty()) { if (cellEntries.isEmpty()) return "{}"; - double value = cellEntries.get(0).getValue(); - return value == 0.0 ? "{}" : "{" + value +"}"; + return "{" + cellEntries.get(0).getValue() +"}"; } Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 82f36972a47..fbc469c1829 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -53,9 +53,6 @@ public class TensorType { return TensorTypeParser.fromSpec(specString); } - /** Returns true if all dimensions of this are indexed */ - public boolean isIndexed() { return dimensions().stream().allMatch(d -> d.isIndexed()); } - /** Returns an immutable list of the dimensions of this */ public List<Dimension> dimensions() { return dimensions; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index ceade39ce42..f295e129a0f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -113,7 +113,7 @@ public class Join extends PrimitiveTensorFunction { /** Join a tensor into a superspace */ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { - if (subspace.type().isIndexed() && superspace.type().isIndexed()) + if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor) return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder); else return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index c3284131be0..0a97576d5b7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -41,13 +41,8 @@ public class DenseBinaryFormat implements BinaryFormat { private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { Iterator<Double> i = tensor.valueIterator(); - if ( ! i.hasNext()) { // no values: Encode as NaN, as 0 dimensions may also mean 1 value - buffer.putDouble(Double.NaN); - } - else { - while (i.hasNext()) - buffer.putDouble(i.next()); - } + while (i.hasNext()) + buffer.putDouble(i.next()); } @Override diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index 3f7f02c6c00..01d1e6fc602 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -21,19 +21,6 @@ public class IndexedTensorTestCase { private final int zSize = 5; @Test - public void testEmpty() { - Tensor empty = Tensor.Builder.of(TensorType.empty).build(); - assertTrue(empty instanceof IndexedTensor); - assertTrue(empty.isEmpty()); - assertEquals("{}", empty.toString()); - Tensor emptyFromString = Tensor.from(TensorType.empty, "{}"); - assertEquals("{}", Tensor.from(TensorType.empty, "{}").toString()); - assertTrue(emptyFromString.isEmpty()); - assertTrue(emptyFromString instanceof IndexedTensor); - assertEquals(empty, emptyFromString); - } - - @Test public void testSingleValue() { Tensor singleValue = Tensor.Builder.of(TensorType.empty).cell(TensorAddress.empty, 3.5).build(); assertTrue(singleValue instanceof IndexedTensor); @@ -91,7 +78,7 @@ public class IndexedTensorTestCase { for (int z = 0; z < zSize; z++) builder.cell(value(v, w, x, y, z), v, w, x, y, z); - IndexedTensor tensor = builder.build(); + IndexedTensor tensor = (IndexedTensor)builder.build(); // Lookup by index arguments for (int v = 0; v < vSize; v++) diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java index 4c32a80dc11..a2df146c8e1 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.tensor; import com.google.common.collect.Sets; +import junit.framework.TestCase; import org.junit.Test; import java.util.Set; @@ -18,6 +19,19 @@ import static org.junit.Assert.fail; public class MappedTensorTestCase { @Test + public void testEmpty() { + Tensor empty = Tensor.Builder.of(TensorType.empty).build(); + TestCase.assertTrue(empty instanceof MappedTensor); + TestCase.assertTrue(empty.isEmpty()); + assertEquals("{}", empty.toString()); + Tensor emptyFromString = Tensor.from(TensorType.empty, "{}"); + assertEquals("{}", Tensor.from(TensorType.empty, "{}").toString()); + TestCase.assertTrue(emptyFromString.isEmpty()); + TestCase.assertTrue(emptyFromString instanceof MappedTensor); + assertEquals(empty, emptyFromString); + } + + @Test public void testOneDimensionalBuilding() { TensorType type = new TensorType.Builder().mapped("x").build(); Tensor tensor = Tensor.Builder.of(type). diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index feeba1a7a10..e649d3cde2a 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -21,7 +21,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; /** - * Tests Tensor functionality + * Tests tensor functionality * * @author bratseth */ @@ -30,6 +30,9 @@ public class TensorTestCase { @Test public void testStringForm() { assertEquals("{}", Tensor.from("{}").toString()); + assertTrue(Tensor.from("{}") instanceof MappedTensor); + assertEquals("{5.7}", Tensor.from("{5.7}").toString()); + assertTrue(Tensor.from("{5.7}") instanceof IndexedTensor); assertEquals("{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString()); assertEquals("{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java index 697eb2a7329..d2b2044f3ed 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java @@ -20,7 +20,6 @@ public class DenseBinaryFormatTestCase { @Test public void testSerialization() { - assertSerialization("{}"); assertSerialization("{-5.37}"); assertSerialization("tensor(x[]):{{x:0}:2.0}"); assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0}"); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index b314fe06f08..283aa90cf65 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -19,6 +19,7 @@ public class SparseBinaryFormatTestCase { @Test public void testSerialization() { + assertSerialization("tensor(x{}):{}"); assertSerialization("tensor(x{}):{{x:0}:2.0}"); assertSerialization("tensor(dimX{},dimY{}):{{dimX:labelA,dimY:labelB}:2.0,{dimY:labelC,dimX:labelD}:3.0}"); assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0}"); |