summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-16 12:14:24 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-16 12:14:24 +0100
commit218590ca3eaed89e96a05edbf8a84f69cf300b22 (patch)
tree4d94de5e584f0c095b74f8c7c78afb5c1e1e922b
parentd8d1f9173a6d25e16f687ad19c2c3ed920299fb0 (diff)
More uniform API
-rw-r--r--document/src/main/java/com/yahoo/document/json/JsonReader.java15
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MappedTensorBuilderTestCase.java42
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java44
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java40
10 files changed, 74 insertions, 91 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/JsonReader.java b/document/src/main/java/com/yahoo/document/json/JsonReader.java
index 59e88b2cce0..b1c635d7641 100644
--- a/document/src/main/java/com/yahoo/document/json/JsonReader.java
+++ b/document/src/main/java/com/yahoo/document/json/JsonReader.java
@@ -35,6 +35,7 @@ import com.yahoo.document.update.FieldUpdate;
import com.yahoo.document.update.MapValueUpdate;
import com.yahoo.document.update.ValueUpdate;
import com.yahoo.tensor.MappedTensor;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.apache.commons.codec.binary.Base64;
@@ -584,7 +585,7 @@ public class JsonReader {
private void fillTensor(TensorFieldValue tensorFieldValue) {
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
- MappedTensor.Builder tensorBuilder = null;
+ Tensor.Builder tensorBuilder = null;
// read tensor cell fields and ignore everything else
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
if (TENSOR_CELLS.equals(buffer.currentName()))
@@ -592,11 +593,11 @@ public class JsonReader {
}
expectObjectEnd(buffer.currentToken());
if (tensorBuilder == null) // no cells + no type: empty tensor type
- tensorBuilder = new MappedTensor.Builder(TensorType.empty);
+ tensorBuilder = Tensor.Builder.of(TensorType.empty);
tensorFieldValue.assign(tensorBuilder.build());
}
- private MappedTensor.Builder readTensorCells(MappedTensor.Builder tensorBuilder) {
+ private Tensor.Builder readTensorCells(Tensor.Builder tensorBuilder) {
expectArrayStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
@@ -606,11 +607,11 @@ public class JsonReader {
return tensorBuilder;
}
- private MappedTensor.Builder readTensorCell(MappedTensor.Builder tensorBuilder) {
+ private Tensor.Builder readTensorCell(Tensor.Builder tensorBuilder) {
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
double cellValue = 0.0;
- MappedTensor.Builder.CellBuilder cellBuilder = null;
+ Tensor.Builder.CellBuilder cellBuilder = null;
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
String currentName = buffer.currentName();
if (TENSOR_ADDRESS.equals(currentName)) {
@@ -630,7 +631,7 @@ public class JsonReader {
TensorType.Builder typeBuilder = new TensorType.Builder();
for (Pair<String,String> entry : entries)
typeBuilder.mapped(entry.getFirst());
- tensorBuilder = new MappedTensor.Builder(typeBuilder.build());
+ tensorBuilder = Tensor.Builder.of(typeBuilder.build());
cellBuilder = tensorBuilder.cell();
for (Pair<String,String> entry : entries)
cellBuilder.label(entry.getFirst(), entry.getSecond());
@@ -642,7 +643,7 @@ public class JsonReader {
}
expectObjectEnd(buffer.currentToken());
if (tensorBuilder == null) { // no content TODO; This will go away with the above
- tensorBuilder = new MappedTensor.Builder(TensorType.empty);
+ tensorBuilder = Tensor.Builder.of(TensorType.empty);
cellBuilder = tensorBuilder.cell();
}
cellBuilder.value(cellValue);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 67ed2180201..c9ea45e59d7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -140,7 +140,6 @@ public class IndexedTensor implements Tensor {
this.type = type;
}
- // TODO: Let other tensor builders be created by this method as well (and update system tests)
public static Builder of(TensorType type) {
if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
return new BoundBuilder(type);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index 9e09a482070..243a8408f20 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -50,10 +50,12 @@ public class MappedTensor implements Tensor {
private final TensorType type;
private final ImmutableMap.Builder<TensorAddress, Double> cells = new ImmutableMap.Builder<>();
- public Builder(TensorType type) {
+ public static Builder of(TensorType type) { return new Builder(type); }
+
+ private Builder(TensorType type) {
this.type = type;
}
-
+
public MappedCellBuilder cell() {
return new MappedCellBuilder();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 34e36d9fb8b..1fab9939c1a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -246,7 +246,7 @@ public interface Tensor {
if (containsIndexed && containsMapped)
throw new IllegalArgumentException("Combining indexed and mapped dimensions is not supported yet");
if (containsMapped)
- return new MappedTensor.Builder(type);
+ return MappedTensor.Builder.of(type);
else // indexed or empty
return IndexedTensor.Builder.of(type);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 5990e633266..b86d6e0729a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -30,7 +30,7 @@ class TensorParser {
if (type.isPresent() && ! type.get().equals(TensorType.empty))
throw new IllegalArgumentException("Got zero-dimensional tensor '" + tensorString +
"but type is not empty but " + type.get());
- return IndexedTensor.Builder.of(TensorType.empty).cell(Double.parseDouble(tensorString)).build();
+ return Tensor.Builder.of(TensorType.empty).cell(Double.parseDouble(tensorString)).build();
}
}
catch (NumberFormatException e) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 22ddcc33c92..57b862534a1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -141,14 +141,14 @@ public class Reduce extends PrimitiveTensorFunction {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Double cellValue : argument.cells().values())
valueAggregator.aggregate(cellValue);
- return IndexedTensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
+ return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
}
private Tensor reduceIndexedVector(IndexedTensor argument) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (int i = 0; i < argument.length(0); i++)
valueAggregator.aggregate(argument.get(i));
- return IndexedTensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
+ return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
}
private static abstract class ValueAggregator {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index a6c2462f577..1dc35f20057 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -61,7 +61,7 @@ class SparseBinaryFormat implements BinaryFormat {
@Override
public Tensor decode(GrowableByteBuffer buffer) {
TensorType type = decodeDimensions(buffer);
- MappedTensor.Builder builder = new MappedTensor.Builder(type);
+ Tensor.Builder builder = Tensor.Builder.of(type);
decodeCells(buffer, builder, type);
return builder.build();
}
@@ -75,17 +75,16 @@ class SparseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private static void decodeCells(GrowableByteBuffer buffer, MappedTensor.Builder builder, TensorType type) {
+ private static void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) {
int numCells = buffer.getInt1_4Bytes();
for (int i = 0; i < numCells; ++i) {
- MappedTensor.Builder.CellBuilder cellBuilder = builder.cell();
+ Tensor.Builder.CellBuilder cellBuilder = builder.cell();
decodeAddress(buffer, cellBuilder, type);
cellBuilder.value(buffer.getDouble());
}
}
- private static void decodeAddress(GrowableByteBuffer buffer, MappedTensor.Builder.CellBuilder builder,
- TensorType type) {
+ private static void decodeAddress(GrowableByteBuffer buffer, Tensor.Builder.CellBuilder builder, TensorType type) {
for (TensorType.Dimension dimension : type.dimensions()) {
String label = decodeString(buffer);
if ( ! label.isEmpty()) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorBuilderTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorBuilderTestCase.java
deleted file mode 100644
index cbbccd71f56..00000000000
--- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorBuilderTestCase.java
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor;
-
-import com.google.common.collect.Sets;
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertTrue;
-
-/**
- * @author geirst
- */
-public class MappedTensorBuilderTestCase {
-
- @Test
- public void requireThatEmptyTensorCanBeBuilt() {
- Tensor tensor = new MappedTensor.Builder(TensorType.empty).build();
- assertEquals(0, tensor.type().dimensions().size());
- assertEquals("{}", tensor.toString());
- }
-
- @Test
- public void requireThatOneDimensionalTensorCanBeBuilt() {
- TensorType type = new TensorType.Builder().mapped("x").build();
- Tensor tensor = new MappedTensor.Builder(type).
- cell().label("x", "0").value(1).
- cell().label("x", "1").value(2).build();
- assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames());
- assertEquals("{{x:0}:1.0,{x:1}:2.0}", tensor.toString());
- }
-
- @Test
- public void requireThatTwoDimensionalTensorCanBeBuilt() {
- TensorType type = new TensorType.Builder().mapped("x").mapped("y").build();
- Tensor tensor = new MappedTensor.Builder(type).
- cell().label("x", "0").label("y", "0").value(1).
- cell().label("x", "1").label("y", "0").value(2).build();
- assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames());
- assertEquals("{{x:0,y:0}:1.0,{x:1,y:0}:2.0}", tensor.toString());
- }
-
-}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
index b9bc1292e91..4c32a80dc11 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.google.common.collect.Sets;
import org.junit.Test;
import java.util.Set;
@@ -17,38 +18,23 @@ import static org.junit.Assert.fail;
public class MappedTensorTestCase {
@Test
- public void testStringForm() {
- assertEquals("{}", Tensor.from("{}").toString());
- assertEquals("{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString());
- assertEquals("{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString());
+ public void testOneDimensionalBuilding() {
+ TensorType type = new TensorType.Builder().mapped("x").build();
+ Tensor tensor = Tensor.Builder.of(type).
+ cell().label("x", "0").value(1).
+ cell().label("x", "1").value(2).build();
+ assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames());
+ assertEquals("{{x:0}:1.0,{x:1}:2.0}", tensor.toString());
}
@Test
- public void testParseError() {
- try {
- Tensor.from("--");
- fail("Expected parse error");
- }
- catch (IllegalArgumentException expected) {
- assertEquals("Excepted a number or a string starting by { or tensor(, got '--'", expected.getMessage());
- }
- }
-
- @Test
- public void testDimensions() {
- Set<String> dimensions1 = Tensor.from("{} ").type().dimensionNames();
- assertEquals(0, dimensions1.size());
-
- Set<String> dimensions2 = Tensor.from("{ {d1:l1, d2:l2}:5, {d1:l2, d2:l2}:6.0} ").type().dimensionNames();
- assertEquals(2, dimensions2.size());
- assertTrue(dimensions2.contains("d1"));
- assertTrue(dimensions2.contains("d2"));
-
- Set<String> dimensions3 = Tensor.from("{ {d1:l1, d2:l1, d3:l1}:5, {d1:l1, d2:l2, d3:l1}:6.0} ").type().dimensionNames();
- assertEquals(3, dimensions3.size());
- assertTrue(dimensions3.contains("d1"));
- assertTrue(dimensions3.contains("d2"));
- assertTrue(dimensions3.contains("d3"));
+ public void testTwoDimensionalBuilding() {
+ TensorType type = new TensorType.Builder().mapped("x").mapped("y").build();
+ Tensor tensor = Tensor.Builder.of(type).
+ cell().label("x", "0").label("y", "0").value(1).
+ cell().label("x", "1").label("y", "0").value(2).build();
+ assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames());
+ assertEquals("{{x:0,y:0}:1.0,{x:1,y:0}:2.0}", tensor.toString());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index da472d102ff..3f3e0b6e66b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -12,17 +12,55 @@ import org.junit.Test;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
+import java.util.Set;
import static org.junit.Assert.assertEquals;
import static com.yahoo.tensor.TensorType.Dimension.Type;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
/**
- * Tests functionality on Tensor
+ * Tests Tensor functionality
*
* @author bratseth
*/
public class TensorTestCase {
+ @Test
+ public void testStringForm() {
+ assertEquals("{}", Tensor.from("{}").toString());
+ assertEquals("{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString());
+ assertEquals("{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString());
+ }
+
+ @Test
+ public void testParseError() {
+ try {
+ Tensor.from("--");
+ fail("Expected parse error");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Excepted a number or a string starting by { or tensor(, got '--'", expected.getMessage());
+ }
+ }
+
+ @Test
+ public void testDimensions() {
+ Set<String> dimensions1 = Tensor.from("{} ").type().dimensionNames();
+ assertEquals(0, dimensions1.size());
+
+ Set<String> dimensions2 = Tensor.from("{ {d1:l1, d2:l2}:5, {d1:l2, d2:l2}:6.0} ").type().dimensionNames();
+ assertEquals(2, dimensions2.size());
+ assertTrue(dimensions2.contains("d1"));
+ assertTrue(dimensions2.contains("d2"));
+
+ Set<String> dimensions3 = Tensor.from("{ {d1:l1, d2:l1, d3:l1}:5, {d1:l1, d2:l2, d3:l1}:6.0} ").type().dimensionNames();
+ assertEquals(3, dimensions3.size());
+ assertTrue(dimensions3.contains("d1"));
+ assertTrue(dimensions3.contains("d2"));
+ assertTrue(dimensions3.contains("d3"));
+ }
+
/** All functions are more throughly tested in searchlib EvaluationTestCase */
@Test
public void testTensorComputation() {