diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 69 |
1 files changed, 41 insertions, 28 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 68997c82d3e..b7e6e67ce73 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -46,13 +46,13 @@ public class JsonFormat { */ public static byte[] encode(Tensor tensor, boolean shortForm, boolean directValues) { Slime slime = new Slime(); - if (shortForm) { - Cursor root = null; - if ( ! directValues) { - root = slime.setObject(); - root.setString("type", tensor.type().toString()); - } + Cursor root = null; + if ( ! directValues) { + root = slime.setObject(); + root.setString("type", tensor.type().toString()); + } + if (shortForm) { if (tensor instanceof IndexedTensor denseTensor) { // Encode as nested lists if indexed tensor Cursor parent = root == null ? slime.setArray() : root.setArray("values"); @@ -77,9 +77,8 @@ public class JsonFormat { return com.yahoo.slime.JsonFormat.toJsonBytes(slime); } else { - Cursor root = slime.setObject(); - root.setString("type", tensor.type().toString()); - encodeCells(tensor, root.setArray("cells")); + Cursor parent = root == null ? slime.setArray() : root.setArray("cells"); + encodeCells(tensor, parent); } return com.yahoo.slime.JsonFormat.toJsonBytes(slime); } @@ -241,48 +240,52 @@ public class JsonFormat { } private static void decodeValues(Inspector values, Tensor.Builder builder) { + decodeValues(values, builder, new MutableInteger(0)); + } + + private static void decodeValues(Inspector values, Tensor.Builder builder, MutableInteger index) { if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder)) - throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + - "Use 'cells' or 'blocks' instead"); + throw new IllegalArgumentException("An array of values can only be used with a dense tensor. Use a map instead"); if (values.type() == Type.STRING) { double[] decoded = decodeHexString(values.asString(), builder.type().valueType()); if (decoded.length == 0) - throw new IllegalArgumentException("The 'values' string does not contain any values"); + throw new IllegalArgumentException("The values string does not contain any values"); for (int i = 0; i < decoded.length; i++) { indexedBuilder.cellByDirectIndex(i, decoded[i]); } return; } if (values.type() != Type.ARRAY) - throw new IllegalArgumentException("Excepted 'values' to contain an array, not " + values.type()); + throw new IllegalArgumentException("Excepted values to be an array, not " + values.type()); if (values.entries() == 0) - throw new IllegalArgumentException("The 'values' array does not contain any values"); + throw new IllegalArgumentException("The values array does not contain any values"); - MutableInteger index = new MutableInteger(0); values.traverse((ArrayTraverser) (__, value) -> { - if (value.type() != Type.LONG && value.type() != Type.DOUBLE) { - throw new IllegalArgumentException("Excepted the values array to contain numbers, not " + value.type()); - } - indexedBuilder.cellByDirectIndex(index.next(), value.asDouble()); + if (value.type() == Type.ARRAY) + decodeValues(value, builder, index); + else if (value.type() == Type.LONG || value.type() == Type.DOUBLE) + indexedBuilder.cellByDirectIndex(index.next(), value.asDouble()); + else + throw new IllegalArgumentException("Excepted the values array to contain numbers or nested arrays, not " + value.type()); }); } private static void decodeBlocks(Inspector values, Tensor.Builder builder) { if ( ! (builder instanceof MixedTensor.BoundBuilder mixedBuilder)) - throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + - "Use 'cells' or 'values' instead"); + throw new IllegalArgumentException("Blocks of values can only be used with mixed (sparse and dense) tensors." + + "Use an array of cell values instead."); if (values.type() == Type.ARRAY) values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder)); else if (values.type() == Type.OBJECT) values.traverse((ObjectTraverser) (key, value) -> decodeSingleDimensionBlock(key, value, mixedBuilder)); else - throw new IllegalArgumentException("Excepted 'blocks' to contain an array or object, not " + values.type()); + throw new IllegalArgumentException("Excepted the block to contain an array or object, not " + values.type()); } private static void decodeBlock(Inspector block, MixedTensor.BoundBuilder mixedBuilder) { if (block.type() != Type.OBJECT) - throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an object, not " + block.type()); + throw new IllegalArgumentException("Expected an item in a blocks array to be an object, not " + block.type()); mixedBuilder.block(decodeAddress(block.field("address"), mixedBuilder.type().mappedSubtype()), decodeValues(block.field("values"), mixedBuilder)); } @@ -292,7 +295,9 @@ public class JsonFormat { boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); - if ( ! hasMapped) + if (isArrayOfObjects(root)) + decodeCells(root, builder); + else if ( ! hasMapped) decodeValues(root, builder); else if (hasMapped && hasIndexed) decodeBlocks(root, builder); @@ -300,9 +305,17 @@ public class JsonFormat { decodeCells(root, builder); } + private static boolean isArrayOfObjects(Inspector inspector) { + if (inspector.type() != Type.ARRAY) return false; + if (inspector.entries() == 0) return false; + Inspector firstItem = inspector.entry(0); + if (firstItem.type() == Type.ARRAY) return isArrayOfObjects(firstItem); + return firstItem.type() == Type.OBJECT; + } + private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) { if (value.type() != Type.ARRAY) - throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type()); + throw new IllegalArgumentException("Expected an item in a blocks array to be an array, not " + value.type()); mixedBuilder.block(asAddress(key, mixedBuilder.type().mappedSubtype()), decodeValues(value, mixedBuilder)); } @@ -386,19 +399,19 @@ public class JsonFormat { double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; if (valuesField.type() == Type.ARRAY) { if (valuesField.entries() == 0) { - throw new IllegalArgumentException("The 'block' value array does not contain any values"); + throw new IllegalArgumentException("The block value array does not contain any values"); } valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value)); } else if (valuesField.type() == Type.STRING) { double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType()); if (decoded.length == 0) { - throw new IllegalArgumentException("The 'block' value string does not contain any values"); + throw new IllegalArgumentException("The block value string does not contain any values"); } for (int i = 0; i < decoded.length; i++) { values[i] = decoded[i]; } } else { - throw new IllegalArgumentException("Expected a block to contain a 'values' array"); + throw new IllegalArgumentException("Expected a block to contain an array of values"); } return values; } |