diff options
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 86 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 2 |
2 files changed, 52 insertions, 36 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 2233622db3e..664f1ec23a4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -68,6 +68,8 @@ public class JsonFormat { decodeCells(root.field("cells"), builder); else if (root.field("values").valid()) decodeValues(root.field("values"), builder); + else if (root.field("blocks").valid()) + decodeBlocks(root.field("blocks"), builder); else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'"); return builder.build(); @@ -76,52 +78,25 @@ public class JsonFormat { private static void decodeCells(Inspector cells, Tensor.Builder builder) { if ( cells.type() != Type.ARRAY) throw new IllegalArgumentException("Excepted 'cells' to contain an array, not " + cells.type()); - cells.traverse((ArrayTraverser) (__, cell) -> decodeCellOrCells(cell, builder)); + cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder.cell())); } - private static void decodeCellOrCells(Inspector cell, Tensor.Builder builder) { - Inspector value = cell.field("value"); - if (value.type() == Type.LONG || value.type() == Type.DOUBLE) { - decodeCell(cell.field("address"), value, builder.cell()); - } - else { - Inspector values = cell.field("values"); - if (values.type() == Type.ARRAY) - decodeValueBlock(cell.field("address"), values, builder); - else - throw new IllegalArgumentException("Expected a cell to contain a numeric 'value' or an array 'values'"); - } - } - - private static void decodeCell(Inspector address, Inspector value, Tensor.Builder.CellBuilder cellBuilder) { + private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) { + Inspector address = cell.field("address"); if ( address.type() != Type.OBJECT) throw new IllegalArgumentException("Excepted a cell to contain an object called 'address'"); address.traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString())); - cellBuilder.value(value.asDouble()); - } - private static void decodeValueBlock(Inspector address, Inspector valuesBlock, Tensor.Builder builder) { - if ( ! (builder instanceof MixedTensor.BoundBuilder)) - throw new IllegalArgumentException("Sending 'values' in 'cells' is only permissible with a mixed tensor " + - "type with bound indexed dimensions, but the type is " + - builder.type()); - MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder)builder; - - if (address.type() != Type.OBJECT) - throw new IllegalArgumentException("Expected a cell to contain an object called 'address'"); - TensorAddress.Builder sparseAddress = new TensorAddress.Builder(mixedBuilder.type().mappedSubtype()); - address.traverse((ObjectTraverser) (dimension, label) -> sparseAddress.add(dimension, label.asString())); - - double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; - valuesBlock.traverse((ArrayTraverser) (index, value) -> values[index] = value.asDouble()); - - mixedBuilder.block(sparseAddress.build(), values); + Inspector value = cell.field("value"); + if (value.type() != Type.LONG && value.type() != Type.DOUBLE) + throw new IllegalArgumentException("Excepted a cell to contain a numeric value called 'value'"); + cellBuilder.value(value.asDouble()); } private static void decodeValues(Inspector values, Tensor.Builder builder) { if ( ! (builder instanceof IndexedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + - "Use 'cells' instead"); + "Use 'cells' or 'blocks' instead"); if ( values.type() != Type.ARRAY) throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type()); @@ -134,4 +109,45 @@ public class JsonFormat { }); } + private static void decodeBlocks(Inspector values, Tensor.Builder builder) { + if ( ! (builder instanceof MixedTensor.BoundBuilder)) + throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + + "Use 'cells' or 'values' instead"); + if (values.type() != Type.ARRAY) + throw new IllegalArgumentException("Excepted 'blocks' to contain an array, not " + values.type()); + + MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder; + + values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder)); + } + + private static void decodeBlock(Inspector block, MixedTensor.BoundBuilder mixedBuilder) { + if (block.type() != Type.OBJECT) + throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an object, not " + block.type()); + + TensorAddress mappedAddress = decodeAddress(block.field("address"), mixedBuilder.type().mappedSubtype()); + + Inspector valuesField = block.field("values"); + if (valuesField.type() != Type.ARRAY) + throw new IllegalArgumentException("Expected a block to contain a 'values' array"); + double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; + valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value)); + + mixedBuilder.block(mappedAddress, values); + } + + private static TensorAddress decodeAddress(Inspector addressField, TensorType type) { + if (addressField.type() != Type.OBJECT) + throw new IllegalArgumentException("Expected an 'address' object, not " + addressField.type()); + TensorAddress.Builder builder = new TensorAddress.Builder(type); + addressField.traverse((ObjectTraverser) (dimension, label) -> builder.add(dimension, label.asString())); + return builder.build(); + } + + private static double decodeNumeric(Inspector numericField) { + if (numericField.type() != Type.LONG && numericField.type() != Type.DOUBLE) + throw new IllegalArgumentException("Excepted a number, not " + numericField.type()); + return numericField.asDouble(); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 2878c82b7db..16f92289504 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -76,7 +76,7 @@ public class JsonFormatTestCase { builder.cell().label("x", 1).label("y", 1).value(6.0); builder.cell().label("x", 1).label("y", 2).value(7.0); Tensor expected = builder.build(); - String mixedJson = "{\"cells\":[" + + String mixedJson = "{\"blocks\":[" + "{\"address\":{\"x\":\"0\"},\"values\":[2.0,3.0,4.0]}," + "{\"address\":{\"x\":\"1\"},\"values\":[5.0,6.0,7.0]}" + "]}"; |