summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-10-08 12:44:24 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-10-08 12:44:24 +0200
commit75570e061520dcfeedfcd70de8a392df644bbc19 (patch)
tree5af98986f63becb328541e1bf86f95376537034e /vespajlib
parentf47861f1f38e644ad17d6acfd2872af8bcb7d090 (diff)
Support mixed tensor short form JSON
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java42
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java36
4 files changed, 65 insertions, 26 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 95f64cec0c1..1f3c373c1e8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -212,7 +212,6 @@ public class MixedTensor implements Tensor {
}
-
/**
* Builder for mixed tensors with bound indexed dimensions.
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index bafec70be59..95cc70804e2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -80,11 +80,20 @@ public class TensorType {
/** Sorted list of the dimensions of this */
private final ImmutableList<Dimension> dimensions;
+ private final TensorType mappedSubtype;
+
private TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
this.dimensions = ImmutableList.copyOf(dimensionList);
+
+ if (dimensionList.stream().allMatch(d -> d.isIndexed()))
+ mappedSubtype = empty;
+ else if (dimensionList.stream().noneMatch(d -> d.isIndexed()))
+ mappedSubtype = this;
+ else
+ mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).collect(Collectors.toList()));
}
static public Value combinedValueType(TensorType ... types) {
@@ -116,6 +125,9 @@ public class TensorType {
/** Returns the numeric type of the cell values of this */
public Value valueType() { return valueType; }
+ /** The type representing the mapped subset of dimensions of this. */
+ public TensorType mappedSubtype() { return mappedSubtype; }
+
/** Returns the number of dimensions of this: dimensions().size() */
public int rank() { return dimensions.size(); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index fa022e2bdd1..2233622db3e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -10,6 +10,7 @@ import com.yahoo.slime.ObjectTraverser;
import com.yahoo.slime.Slime;
import com.yahoo.slime.Type;
import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -75,21 +76,48 @@ public class JsonFormat {
private static void decodeCells(Inspector cells, Tensor.Builder builder) {
if ( cells.type() != Type.ARRAY)
throw new IllegalArgumentException("Excepted 'cells' to contain an array, not " + cells.type());
- cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder.cell()));
+ cells.traverse((ArrayTraverser) (__, cell) -> decodeCellOrCells(cell, builder));
}
- private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) {
- Inspector address = cell.field("address");
+ private static void decodeCellOrCells(Inspector cell, Tensor.Builder builder) {
+ Inspector value = cell.field("value");
+ if (value.type() == Type.LONG || value.type() == Type.DOUBLE) {
+ decodeCell(cell.field("address"), value, builder.cell());
+ }
+ else {
+ Inspector values = cell.field("values");
+ if (values.type() == Type.ARRAY)
+ decodeValueBlock(cell.field("address"), values, builder);
+ else
+ throw new IllegalArgumentException("Expected a cell to contain a numeric 'value' or an array 'values'");
+ }
+ }
+
+ private static void decodeCell(Inspector address, Inspector value, Tensor.Builder.CellBuilder cellBuilder) {
if ( address.type() != Type.OBJECT)
throw new IllegalArgumentException("Excepted a cell to contain an object called 'address'");
address.traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString()));
-
- Inspector value = cell.field("value");
- if (value.type() != Type.LONG && value.type() != Type.DOUBLE)
- throw new IllegalArgumentException("Excepted a cell to contain a numeric value called 'value'");
cellBuilder.value(value.asDouble());
}
+ private static void decodeValueBlock(Inspector address, Inspector valuesBlock, Tensor.Builder builder) {
+ if ( ! (builder instanceof MixedTensor.BoundBuilder))
+ throw new IllegalArgumentException("Sending 'values' in 'cells' is only permissible with a mixed tensor " +
+ "type with bound indexed dimensions, but the type is " +
+ builder.type());
+ MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder)builder;
+
+ if (address.type() != Type.OBJECT)
+ throw new IllegalArgumentException("Expected a cell to contain an object called 'address'");
+ TensorAddress.Builder sparseAddress = new TensorAddress.Builder(mixedBuilder.type().mappedSubtype());
+ address.traverse((ObjectTraverser) (dimension, label) -> sparseAddress.add(dimension, label.asString()));
+
+ double[] values = new double[(int)mixedBuilder.denseSubspaceSize()];
+ valuesBlock.traverse((ArrayTraverser) (index, value) -> values[index] = value.asDouble());
+
+ mixedBuilder.block(sparseAddress.build(), values);
+ }
+
private static void decodeValues(Inspector values, Tensor.Builder builder) {
if ( ! (builder instanceof IndexedTensor.BoundBuilder))
throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 32d62903af5..2878c82b7db 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -52,24 +52,6 @@ public class JsonFormatTestCase {
}
@Test
- public void testMixedTensor() {
- Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[2])"));
- builder.cell().label("x", "a").label("y", "0").value(1.0);
- builder.cell().label("x", "a").label("y", "1").value(2.0);
- builder.cell().label("x", "b").label("y", "0").value(3.0);
- builder.cell().label("x", "b").label("y", "1").value(4.0);
- Tensor tensor = builder.build();
- byte[] json = JsonFormat.encode(tensor);
- assertEquals("{\"cells\":[" +
- "{\"address\":{\"x\":\"a\"},\"values\":[1.0,2.0]}," +
- "{\"address\":{\"x\":\"b\"},\"values\":[3.0,4.0]}" +
- "]}",
- new String(json, StandardCharsets.UTF_8));
- Tensor decoded = JsonFormat.decode(tensor.type(), json);
- assertEquals(tensor, decoded);
- }
-
- @Test
public void testDenseTensorInDenseForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])"));
builder.cell().label("x", 0).label("y", 0).value(2.0);
@@ -85,6 +67,24 @@ public class JsonFormatTestCase {
}
@Test
+ public void testMixedTensorInMixedForm() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])"));
+ builder.cell().label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("x", 0).label("y", 1).value(3.0);
+ builder.cell().label("x", 0).label("y", 2).value(4.0);
+ builder.cell().label("x", 1).label("y", 0).value(5.0);
+ builder.cell().label("x", 1).label("y", 1).value(6.0);
+ builder.cell().label("x", 1).label("y", 2).value(7.0);
+ Tensor expected = builder.build();
+ String mixedJson = "{\"cells\":[" +
+ "{\"address\":{\"x\":\"0\"},\"values\":[2.0,3.0,4.0]}," +
+ "{\"address\":{\"x\":\"1\"},\"values\":[5.0,6.0,7.0]}" +
+ "]}";
+ Tensor decoded = JsonFormat.decode(expected.type(), mixedJson.getBytes(StandardCharsets.UTF_8));
+ assertEquals(expected, decoded);
+ }
+
+ @Test
public void testTooManyCells() {
TensorType x2 = TensorType.fromSpec("tensor(x[2])");
String json = "{\"cells\":[" +