diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-01-12 22:13:12 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2023-01-12 22:13:12 +0100 |
commit | 2229a4d2e3141010850fa23f5ad731c9038052a8 (patch) | |
tree | d35ed23f65ef1ee793367e28450eda483372f031 /vespajlib | |
parent | 844eeeeebfd8cdffb28ee7d64e05a803aa2f0e5a (diff) |
Parse tensor JSON values at root
Our current tensor JSON formats require a "blocks", "cells" or "values" key
at the root, containing the values in various forms.
This adds support for skipping that extra level and placing the content directly
at the root. The permissible content format then depends on the tensor type, and
matches the format accepted under "blocks", "cells" or "values" for the
corresponding tensor type.
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 67 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 50 |
2 files changed, 86 insertions, 31 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index a7afc1efc6d..0e8fbc30bb6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -60,25 +60,21 @@ public class JsonFormat { Cursor root = slime.setObject(); root.setString("type", tensor.type().toString()); - // Encode as nested lists if indexed tensor - if (tensor instanceof IndexedTensor) { - IndexedTensor denseTensor = (IndexedTensor) tensor; + if (tensor instanceof IndexedTensor denseTensor) { + // Encode as nested lists if indexed tensor encodeValues(denseTensor, root.setArray("values"), new long[denseTensor.dimensionSizes().dimensions()], 0); } - - // Short form for a single mapped dimension else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) { + // Short form for a single mapped dimension encodeSingleDimensionCells((MappedTensor) tensor, root); } - - // Short form for a mixed tensor else if (tensor instanceof MixedTensor && tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() >= 1) { + // Short form for a mixed tensor encodeBlocks((MixedTensor) tensor, root); } - - // No other short forms exist: default to standard cell address output else { + // No other short forms exist: default to standard cell address output encodeCells(tensor, root); } @@ -177,17 +173,25 @@ public class JsonFormat { Tensor.Builder builder = Tensor.Builder.of(type); Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get(); - if (root.field("cells").valid()) + if (root.field("cells").valid() && ! 
primitiveContent(root.field("cells"))) decodeCells(root.field("cells"), builder); - else if (root.field("values").valid()) + else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) decodeValues(root.field("values"), builder); else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); - else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values' or 'blocks'"); + else + decodeDirectValue(root, builder); return builder.build(); } + private static boolean primitiveContent(Inspector cellsValue) { + if (cellsValue.type() == Type.DOUBLE) return true; + if (cellsValue.type() == Type.LONG) return true; + if (cellsValue.type() == Type.ARRAY && cellsValue.entries() > 0 && + ( cellsValue.entry(0).type() == Type.DOUBLE || cellsValue.entry(0).type() == Type.LONG)) return true; + return false; + } + private static void decodeCells(Inspector cells, Tensor.Builder builder) { if (cells.type() == Type.ARRAY) cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder)); @@ -212,10 +216,9 @@ public class JsonFormat { } private static void decodeValues(Inspector values, Tensor.Builder builder) { - if ( ! (builder instanceof IndexedTensor.BoundBuilder)) + if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); - IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; if (values.type() == Type.STRING) { double[] decoded = decodeHexString(values.asString(), builder.type().valueType()); if (decoded.length == 0) @@ -240,10 +243,9 @@ public class JsonFormat { } private static void decodeBlocks(Inspector values, Tensor.Builder builder) { - if ( ! (builder instanceof MixedTensor.BoundBuilder)) + if ( ! 
(builder instanceof MixedTensor.BoundBuilder mixedBuilder)) throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + "Use 'cells' or 'values' instead"); - MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder; if (values.type() == Type.ARRAY) values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder)); @@ -260,6 +262,19 @@ public class JsonFormat { decodeValues(block.field("values"), mixedBuilder)); } + /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */ + private static void decodeDirectValue(Inspector root, Tensor.Builder builder) { + boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); + boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); + + if ( ! hasMapped) + decodeValues(root, builder); + else if (hasMapped && hasIndexed) + decodeBlocks(root, builder); + else + decodeCells(root, builder); + } + private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) { if (value.type() != Type.ARRAY) throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type()); @@ -334,18 +349,12 @@ public class JsonFormat { } public static double[] decodeHexString(String input, TensorType.Value valueType) { - switch(valueType) { - case INT8: - return decodeHexStringAsBytes(input); - case BFLOAT16: - return decodeHexStringAsBFloat16s(input); - case FLOAT: - return decodeHexStringAsFloats(input); - case DOUBLE: - return decodeHexStringAsDoubles(input); - default: - throw new IllegalArgumentException("Cannot handle value type: "+valueType); - } + return switch (valueType) { + case INT8 -> decodeHexStringAsBytes(input); + case BFLOAT16 -> decodeHexStringAsBFloat16s(input); + case FLOAT -> decodeHexStringAsFloats(input); + case DOUBLE -> 
decodeHexStringAsDoubles(input); + }; } private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 1c884186879..f71a68ec5ed 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -1,7 +1,6 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; -import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -16,6 +15,30 @@ import static org.junit.Assert.fail; */ public class JsonFormatTestCase { + /** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */ + @Test + public void testDirectValue() { + assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}"); + assertDecoded("tensor(x[2]):[2, 3]]", "[2.0, 3.0]"); + assertDecoded("tensor(x{},y[2]):{a:[2, 3], b:[4, 5]}", "{'a':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y{}):{{x:a,y:0}:2, {x:b,y:1}:3}", + "[{'address':{'x':'a','y':'0'},'value':2}, {'address':{'x':'b','y':'1'},'value':3}]"); + } + + @Test + public void testDirectValueReservedNameKeys() { + // Single-valued + assertDecoded("tensor(x{}):{cells:2, b:3}", "{'cells':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{values:2, b:3}", "{'values':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{block:2, b:3}", "{'block':2.0, 'b':3.0}"); + + // Multi-valued + assertDecoded("tensor(x{},y[2]):{cells:[2, 3], b:[4, 5]}", "{'cells':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y[2]):{values:[2, 3], b:[4, 5]}", "{'values':[2, 3], 'b':[4, 5]}"); + 
assertDecoded("tensor(x{},y[2]):{block:[2, 3], b:[4, 5]}", "{'block':[2, 3], 'b':[4, 5]}"); + } + @Test public void testSparseTensor() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); @@ -33,6 +56,21 @@ public class JsonFormatTestCase { } @Test + public void testEmptySparseTensor() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); + Tensor tensor = builder.build(); + byte[] json = JsonFormat.encode(tensor); + assertEquals("{\"cells\":[]}", + new String(json, StandardCharsets.UTF_8)); + Tensor decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); + + json = "{}".getBytes(); // short form variant of the above + decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); + } + + @Test public void testSingleSparseDimensionShortForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")); builder.cell().label("x", "a").value(2.0); @@ -327,7 +365,7 @@ public class JsonFormatTestCase { "]}"; try { JsonFormat.decode(x2, json.getBytes(StandardCharsets.UTF_8)); - fail("Excpected exception"); + fail("Expected exception"); } catch (IllegalArgumentException e) { assertEquals("cell address (2) is not within the bounds of tensor(x[2])", e.getMessage()); @@ -354,6 +392,14 @@ public class JsonFormatTestCase { assertEquals(expected, new String(json, StandardCharsets.UTF_8)); } + private void assertDecoded(String expected, String jsonToDecode) { + assertDecoded(Tensor.from(expected), jsonToDecode); + } + + private void assertDecoded(Tensor expected, String jsonToDecode) { + assertEquals(expected, JsonFormat.decode(expected.type(), jsonToDecode.getBytes(StandardCharsets.UTF_8))); + } + private void assertDecodeFails(TensorType type, String format, String msg) { try { Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8)); |