summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-10 15:55:53 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-10 15:55:53 +0100
commit451e7cf03729b7a09c8e4f9457edf9ae1007ba8a (patch)
tree5c62016b68eeecf06cbb205cc349712ef36a93c5 /vespajlib
parent14a0470694ea7f24b8ef007783432a6f532e42ba (diff)
Use MappedTensor to represent tensor with no dimensions or values
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java32
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java15
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java14
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java5
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java1
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java1
10 files changed, 44 insertions, 43 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 4654f53647f..deee4aa02b6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -103,7 +103,6 @@ public class IndexedTensor implements Tensor {
* @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given
*/
public double get(int ... indexes) {
- if (values.length == 0) return Double.NaN;
return values[toValueIndex(indexes, dimensionSizes)];
}
@@ -157,7 +156,7 @@ public class IndexedTensor implements Tensor {
@Override
public Map<TensorAddress, Double> cells() {
if (dimensionSizes.dimensions() == 0)
- return values.length == 0 ? Collections.emptyMap() : Collections.singletonMap(TensorAddress.empty, values[0]);
+ return Collections.singletonMap(TensorAddress.empty, values[0]);
ImmutableMap.Builder<TensorAddress, Double> builder = new ImmutableMap.Builder<>();
Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length);
@@ -221,7 +220,7 @@ public class IndexedTensor implements Tensor {
public TensorType type() { return type; }
@Override
- public abstract IndexedTensor build();
+ public abstract Tensor build();
}
@@ -269,11 +268,14 @@ public class IndexedTensor implements Tensor {
}
@Override
- public IndexedTensor build() {
+ public Tensor build() {
// Note that we do not check for no NaN's here for performance reasons.
// NaN's don't get lost so leaving them in place should be quite benign
- if (values.length == 1 && Double.isNaN(values[0]))
- values = new double[0];
+
+ // An empty tensor with no dimensions is mapped
+ if (values.length == 1 && Double.isNaN(values[0]) && type.dimensions().isEmpty())
+ return MappedTensor.Builder.of(type).build();
+
IndexedTensor tensor = new IndexedTensor(type, sizes, values);
// prevent further modification
sizes = null;
@@ -316,24 +318,28 @@ public class IndexedTensor implements Tensor {
}
@Override
- public IndexedTensor build() {
- if (firstDimension == null) // empty
- return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {});
+ public Tensor build() {
+ if (firstDimension == null && type.dimensions().isEmpty()) // empty
+ return MappedTensor.Builder.of(type).build();
if (type.dimensions().isEmpty()) // single number
return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) });
DimensionSizes dimensionSizes = findDimensionSizes(firstDimension);
double[] values = new double[dimensionSizes.totalSize()];
- fillValues(0, 0, firstDimension, dimensionSizes, values);
+ if (firstDimension != null)
+ fillValues(0, 0, firstDimension, dimensionSizes, values);
return new IndexedTensor(type, dimensionSizes, values);
}
private DimensionSizes findDimensionSizes(List<Object> firstDimension) {
List<Integer> dimensionSizeList = new ArrayList<>(type.dimensions().size());
- findDimensionSizes(0, dimensionSizeList, firstDimension);
+ if (firstDimension != null)
+ findDimensionSizes(0, dimensionSizeList, firstDimension);
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct
- for (int i = 0; i < b.dimensions(); i++)
- b.set(i, dimensionSizeList.get(i));
+ for (int i = 0; i < b.dimensions(); i++) {
+ if (i < dimensionSizeList.size())
+ b.set(i, dimensionSizeList.get(i));
+ }
return b.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 51d40a89f3b..29c508ce12f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -213,10 +213,9 @@ public interface Tensor {
static String contentToString(Tensor tensor) {
List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
- if (tensor.type().dimensions().isEmpty()) { // TODO: Decide on one way to represent degeneration to number
+ if (tensor.type().dimensions().isEmpty()) {
if (cellEntries.isEmpty()) return "{}";
- double value = cellEntries.get(0).getValue();
- return value == 0.0 ? "{}" : "{" + value +"}";
+ return "{" + cellEntries.get(0).getValue() +"}";
}
Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 82f36972a47..fbc469c1829 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -53,9 +53,6 @@ public class TensorType {
return TensorTypeParser.fromSpec(specString);
}
- /** Returns true if all dimensions of this are indexed */
- public boolean isIndexed() { return dimensions().stream().allMatch(d -> d.isIndexed()); }
-
/** Returns an immutable list of the dimensions of this */
public List<Dimension> dimensions() { return dimensions; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index ceade39ce42..f295e129a0f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -113,7 +113,7 @@ public class Join extends PrimitiveTensorFunction {
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
- if (subspace.type().isIndexed() && superspace.type().isIndexed())
+ if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder);
else
return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index c3284131be0..0a97576d5b7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -41,13 +41,8 @@ public class DenseBinaryFormat implements BinaryFormat {
private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
Iterator<Double> i = tensor.valueIterator();
- if ( ! i.hasNext()) { // no values: Encode as NaN, as 0 dimensions may also mean 1 value
- buffer.putDouble(Double.NaN);
- }
- else {
- while (i.hasNext())
- buffer.putDouble(i.next());
- }
+ while (i.hasNext())
+ buffer.putDouble(i.next());
}
@Override
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
index 3f7f02c6c00..01d1e6fc602 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
@@ -21,19 +21,6 @@ public class IndexedTensorTestCase {
private final int zSize = 5;
@Test
- public void testEmpty() {
- Tensor empty = Tensor.Builder.of(TensorType.empty).build();
- assertTrue(empty instanceof IndexedTensor);
- assertTrue(empty.isEmpty());
- assertEquals("{}", empty.toString());
- Tensor emptyFromString = Tensor.from(TensorType.empty, "{}");
- assertEquals("{}", Tensor.from(TensorType.empty, "{}").toString());
- assertTrue(emptyFromString.isEmpty());
- assertTrue(emptyFromString instanceof IndexedTensor);
- assertEquals(empty, emptyFromString);
- }
-
- @Test
public void testSingleValue() {
Tensor singleValue = Tensor.Builder.of(TensorType.empty).cell(TensorAddress.empty, 3.5).build();
assertTrue(singleValue instanceof IndexedTensor);
@@ -91,7 +78,7 @@ public class IndexedTensorTestCase {
for (int z = 0; z < zSize; z++)
builder.cell(value(v, w, x, y, z), v, w, x, y, z);
- IndexedTensor tensor = builder.build();
+ IndexedTensor tensor = (IndexedTensor)builder.build();
// Lookup by index arguments
for (int v = 0; v < vSize; v++)
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
index 4c32a80dc11..a2df146c8e1 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor;
import com.google.common.collect.Sets;
+import junit.framework.TestCase;
import org.junit.Test;
import java.util.Set;
@@ -18,6 +19,19 @@ import static org.junit.Assert.fail;
public class MappedTensorTestCase {
@Test
+ public void testEmpty() {
+ Tensor empty = Tensor.Builder.of(TensorType.empty).build();
+ TestCase.assertTrue(empty instanceof MappedTensor);
+ TestCase.assertTrue(empty.isEmpty());
+ assertEquals("{}", empty.toString());
+ Tensor emptyFromString = Tensor.from(TensorType.empty, "{}");
+ assertEquals("{}", Tensor.from(TensorType.empty, "{}").toString());
+ TestCase.assertTrue(emptyFromString.isEmpty());
+ TestCase.assertTrue(emptyFromString instanceof MappedTensor);
+ assertEquals(empty, emptyFromString);
+ }
+
+ @Test
public void testOneDimensionalBuilding() {
TensorType type = new TensorType.Builder().mapped("x").build();
Tensor tensor = Tensor.Builder.of(type).
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index feeba1a7a10..e649d3cde2a 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -21,7 +21,7 @@ import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
- * Tests Tensor functionality
+ * Tests tensor functionality
*
* @author bratseth
*/
@@ -30,6 +30,9 @@ public class TensorTestCase {
@Test
public void testStringForm() {
assertEquals("{}", Tensor.from("{}").toString());
+ assertTrue(Tensor.from("{}") instanceof MappedTensor);
+ assertEquals("{5.7}", Tensor.from("{5.7}").toString());
+ assertTrue(Tensor.from("{5.7}") instanceof IndexedTensor);
assertEquals("{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString());
assertEquals("{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString());
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 697eb2a7329..d2b2044f3ed 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -20,7 +20,6 @@ public class DenseBinaryFormatTestCase {
@Test
public void testSerialization() {
- assertSerialization("{}");
assertSerialization("{-5.37}");
assertSerialization("tensor(x[]):{{x:0}:2.0}");
assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0}");
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index b314fe06f08..283aa90cf65 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -19,6 +19,7 @@ public class SparseBinaryFormatTestCase {
@Test
public void testSerialization() {
+ assertSerialization("tensor(x{}):{}");
assertSerialization("tensor(x{}):{{x:0}:2.0}");
assertSerialization("tensor(dimX{},dimY{}):{{dimX:labelA,dimY:labelB}:2.0,{dimY:labelC,dimX:labelD}:3.0}");
assertSerialization("tensor(x{},y{}):{{x:0,y:1}:2.0}");