diff options
Diffstat (limited to 'vespajlib/src')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 44 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 33 |
2 files changed, 66 insertions, 11 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 75690e45e15..e1b38264661 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -76,9 +76,12 @@ public class JsonFormat { } private static void decodeCells(Inspector cells, Tensor.Builder builder) { - if ( cells.type() != Type.ARRAY) - throw new IllegalArgumentException("Excepted 'cells' to contain an array, not " + cells.type()); - cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder)); + if ( cells.type() == Type.ARRAY) + cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder)); + else if (cells.type() == Type.OBJECT) + cells.traverse((ObjectTraverser) (key, value) -> decodeSingleDimensionCell(key, value, builder)); + else + throw new IllegalArgumentException("Excepted 'cells' to contain an array or obejct, not " + cells.type()); } private static void decodeCell(Inspector cell, Tensor.Builder builder) { @@ -91,6 +94,10 @@ public class JsonFormat { builder.cell(address, value.asDouble()); } + private static void decodeSingleDimensionCell(String key, Inspector value, Tensor.Builder builder) { + builder.cell(asAddress(key, builder.type()), decodeNumeric(value)); + } + private static void decodeValues(Inspector values, Tensor.Builder builder) { if ( ! (builder instanceof IndexedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + @@ -111,27 +118,36 @@ public class JsonFormat { if ( ! (builder instanceof MixedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + "Use 'cells' or 'values' instead"); - if (values.type() != Type.ARRAY) - throw new IllegalArgumentException("Excepted 'blocks' to contain an array, not " + values.type()); - MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder; - values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder)); + if (values.type() == Type.ARRAY) + values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder)); + else if (values.type() == Type.OBJECT) + values.traverse((ObjectTraverser) (key, value) -> decodeSingleDimensionBlock(key, value, mixedBuilder)); + else + throw new IllegalArgumentException("Excepted 'blocks' to contain an array or object, not " + values.type()); } private static void decodeBlock(Inspector block, MixedTensor.BoundBuilder mixedBuilder) { if (block.type() != Type.OBJECT) throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an object, not " + block.type()); + mixedBuilder.block(decodeAddress(block.field("address"), mixedBuilder.type().mappedSubtype()), + decodeValues(block.field("values"), mixedBuilder)); + } - TensorAddress mappedAddress = decodeAddress(block.field("address"), mixedBuilder.type().mappedSubtype()); + private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) { + if (value.type() != Type.ARRAY) + throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an object, not " + value.type()); + mixedBuilder.block(asAddress(key, mixedBuilder.type().mappedSubtype()), + decodeValues(value, mixedBuilder)); + } - Inspector valuesField = block.field("values"); + private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { if (valuesField.type() != Type.ARRAY) throw new IllegalArgumentException("Expected a block to contain a 'values' array"); double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value)); - - mixedBuilder.block(mappedAddress, values); + return values; } private static TensorAddress decodeAddress(Inspector addressField, TensorType type) { @@ -142,6 +158,12 @@ public class JsonFormat { return builder.build(); } + private static TensorAddress asAddress(String label, TensorType type) { + if (type.dimensions().size() != 1) + throw new IllegalArgumentException("Expected a tensor with a single dimension but got " + type); + return new TensorAddress.Builder(type).add(type.dimensions().get(0).name(), label).build(); + } + private static double decodeNumeric(Inspector numericField) { if (numericField.type() != Type.LONG && numericField.type() != Type.DOUBLE) throw new IllegalArgumentException("Excepted a number, not " + numericField.type()); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 16f92289504..81de8a9db4c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -32,6 +32,21 @@ public class JsonFormatTestCase { } @Test + public void testSingleSparseDimensionShortForm() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")); + builder.cell().label("x", "a").value(2.0); + builder.cell().label("x", "c").value(3.0); + Tensor expected = builder.build(); + + String json= "{\"cells\":{" + + "\"a\":2.0," + + "\"c\":3.0" + + "}}"; + Tensor decoded = JsonFormat.decode(expected.type(), json.getBytes(StandardCharsets.UTF_8)); + assertEquals(expected, decoded); + } + + @Test public void testDenseTensor() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[2])")); builder.cell().label("x", 0).label("y", 0).value(2.0); @@ -85,6 +100,24 @@ public class JsonFormatTestCase { } @Test + public void testMixedTensorInMixedFormWithSingleSparseDimensionShortForm() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 0).label("y", 2).value(4.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(6.0); + builder.cell().label("x", 1).label("y", 2).value(7.0); + Tensor expected = builder.build(); + String mixedJson = "{\"blocks\":{" + + "\"0\":[2.0,3.0,4.0]," + + "\"1\":[5.0,6.0,7.0]" + + "}}"; + Tensor decoded = JsonFormat.decode(expected.type(), mixedJson.getBytes(StandardCharsets.UTF_8)); + assertEquals(expected, decoded); + } + + @Test public void testTooManyCells() { TensorType x2 = TensorType.fromSpec("tensor(x[2])"); String json = "{\"cells\":[" + |