diff options
4 files changed, 65 insertions, 26 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 95f64cec0c1..1f3c373c1e8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -212,7 +212,6 @@ public class MixedTensor implements Tensor { } - /** * Builder for mixed tensors with bound indexed dimensions. */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index bafec70be59..95cc70804e2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -80,11 +80,20 @@ public class TensorType { /** Sorted list of the dimensions of this */ private final ImmutableList<Dimension> dimensions; + private final TensorType mappedSubtype; + private TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); this.dimensions = ImmutableList.copyOf(dimensionList); + + if (dimensionList.stream().allMatch(d -> d.isIndexed())) + mappedSubtype = empty; + else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) + mappedSubtype = this; + else + mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).collect(Collectors.toList())); } static public Value combinedValueType(TensorType ... types) { @@ -116,6 +125,9 @@ public class TensorType { /** Returns the numeric type of the cell values of this */ public Value valueType() { return valueType; } + /** The type representing the mapped subset of dimensions of this. */ + public TensorType mappedSubtype() { return mappedSubtype; } + /** Returns the number of dimensions of this: dimensions().size() */ public int rank() { return dimensions.size(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index fa022e2bdd1..2233622db3e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -10,6 +10,7 @@ import com.yahoo.slime.ObjectTraverser; import com.yahoo.slime.Slime; import com.yahoo.slime.Type; import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -75,21 +76,48 @@ public class JsonFormat { private static void decodeCells(Inspector cells, Tensor.Builder builder) { if ( cells.type() != Type.ARRAY) throw new IllegalArgumentException("Excepted 'cells' to contain an array, not " + cells.type()); - cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder.cell())); + cells.traverse((ArrayTraverser) (__, cell) -> decodeCellOrCells(cell, builder)); } - private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) { - Inspector address = cell.field("address"); + private static void decodeCellOrCells(Inspector cell, Tensor.Builder builder) { + Inspector value = cell.field("value"); + if (value.type() == Type.LONG || value.type() == Type.DOUBLE) { + decodeCell(cell.field("address"), value, builder.cell()); + } + else { + Inspector values = cell.field("values"); + if (values.type() == Type.ARRAY) + decodeValueBlock(cell.field("address"), values, builder); + else + throw new IllegalArgumentException("Expected a cell to contain a numeric 'value' or an array 'values'"); + } + } + + private static void decodeCell(Inspector address, Inspector value, Tensor.Builder.CellBuilder cellBuilder) { if ( address.type() != Type.OBJECT) throw new IllegalArgumentException("Excepted a cell to contain an object called 'address'"); address.traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString())); - - Inspector value = cell.field("value"); - if (value.type() != Type.LONG && value.type() != Type.DOUBLE) - throw new IllegalArgumentException("Excepted a cell to contain a numeric value called 'value'"); cellBuilder.value(value.asDouble()); } + private static void decodeValueBlock(Inspector address, Inspector valuesBlock, Tensor.Builder builder) { + if ( ! (builder instanceof MixedTensor.BoundBuilder)) + throw new IllegalArgumentException("Sending 'values' in 'cells' is only permissible with a mixed tensor " + + "type with bound indexed dimensions, but the type is " + + builder.type()); + MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder)builder; + + if (address.type() != Type.OBJECT) + throw new IllegalArgumentException("Expected a cell to contain an object called 'address'"); + TensorAddress.Builder sparseAddress = new TensorAddress.Builder(mixedBuilder.type().mappedSubtype()); + address.traverse((ObjectTraverser) (dimension, label) -> sparseAddress.add(dimension, label.asString())); + + double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; + valuesBlock.traverse((ArrayTraverser) (index, value) -> values[index] = value.asDouble()); + + mixedBuilder.block(sparseAddress.build(), values); + } + private static void decodeValues(Inspector values, Tensor.Builder builder) { if ( ! (builder instanceof IndexedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 32d62903af5..2878c82b7db 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -52,24 +52,6 @@ public class JsonFormatTestCase { } @Test - public void testMixedTensor() { - Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[2])")); - builder.cell().label("x", "a").label("y", "0").value(1.0); - builder.cell().label("x", "a").label("y", "1").value(2.0); - builder.cell().label("x", "b").label("y", "0").value(3.0); - builder.cell().label("x", "b").label("y", "1").value(4.0); - Tensor tensor = builder.build(); - byte[] json = JsonFormat.encode(tensor); - assertEquals("{\"cells\":[" + - "{\"address\":{\"x\":\"a\"},\"values\":[1.0,2.0]}," + - "{\"address\":{\"x\":\"b\"},\"values\":[3.0,4.0]}" + - "]}", - new String(json, StandardCharsets.UTF_8)); - Tensor decoded = JsonFormat.decode(tensor.type(), json); - assertEquals(tensor, decoded); - } - - @Test public void testDenseTensorInDenseForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])")); builder.cell().label("x", 0).label("y", 0).value(2.0); @@ -85,6 +67,24 @@ public class JsonFormatTestCase { } @Test + public void testMixedTensorInMixedForm() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 0).label("y", 2).value(4.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(6.0); + builder.cell().label("x", 1).label("y", 2).value(7.0); + Tensor expected = builder.build(); + String mixedJson = "{\"cells\":[" + + "{\"address\":{\"x\":\"0\"},\"values\":[2.0,3.0,4.0]}," + + "{\"address\":{\"x\":\"1\"},\"values\":[5.0,6.0,7.0]}" + + "]}"; + Tensor decoded = JsonFormat.decode(expected.type(), mixedJson.getBytes(StandardCharsets.UTF_8)); + assertEquals(expected, decoded); + } + + @Test public void testTooManyCells() { TensorType x2 = TensorType.fromSpec("tensor(x[2])"); String json = "{\"cells\":[" + |