diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2023-01-13 10:49:21 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-13 10:49:21 +0100 |
commit | 504db6bf752b023b8051a7ddfb1a446152cee934 (patch) | |
tree | 8e1ef0d33235ecfb29070867e71b8f092a2f3985 /vespajlib | |
parent | f2c7e54941cb217586f099a2d47a14b42f58ba7d (diff) | |
parent | 9dd24f4376447165054f6c498e95a45aeb69549f (diff) |
Merge pull request #25549 from vespa-engine/bratseth/tensor-direct-values
Parse tensor JSON values at root
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 67 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 49 |
2 files changed, 86 insertions, 30 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index a7afc1efc6d..0e8fbc30bb6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -60,25 +60,21 @@ public class JsonFormat { Cursor root = slime.setObject(); root.setString("type", tensor.type().toString()); - // Encode as nested lists if indexed tensor - if (tensor instanceof IndexedTensor) { - IndexedTensor denseTensor = (IndexedTensor) tensor; + if (tensor instanceof IndexedTensor denseTensor) { + // Encode as nested lists if indexed tensor encodeValues(denseTensor, root.setArray("values"), new long[denseTensor.dimensionSizes().dimensions()], 0); } - - // Short form for a single mapped dimension else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) { + // Short form for a single mapped dimension encodeSingleDimensionCells((MappedTensor) tensor, root); } - - // Short form for a mixed tensor else if (tensor instanceof MixedTensor && tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() >= 1) { + // Short form for a mixed tensor encodeBlocks((MixedTensor) tensor, root); } - - // No other short forms exist: default to standard cell address output else { + // No other short forms exist: default to standard cell address output encodeCells(tensor, root); } @@ -177,17 +173,25 @@ public class JsonFormat { Tensor.Builder builder = Tensor.Builder.of(type); Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get(); - if (root.field("cells").valid()) + if (root.field("cells").valid() && ! primitiveContent(root.field("cells"))) decodeCells(root.field("cells"), builder); - else if (root.field("values").valid()) + else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) decodeValues(root.field("values"), builder); else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); - else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values' or 'blocks'"); + else + decodeDirectValue(root, builder); return builder.build(); } + private static boolean primitiveContent(Inspector cellsValue) { + if (cellsValue.type() == Type.DOUBLE) return true; + if (cellsValue.type() == Type.LONG) return true; + if (cellsValue.type() == Type.ARRAY && cellsValue.entries() > 0 && + ( cellsValue.entry(0).type() == Type.DOUBLE || cellsValue.entry(0).type() == Type.LONG)) return true; + return false; + } + private static void decodeCells(Inspector cells, Tensor.Builder builder) { if (cells.type() == Type.ARRAY) cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder)); @@ -212,10 +216,9 @@ public class JsonFormat { } private static void decodeValues(Inspector values, Tensor.Builder builder) { - if ( ! (builder instanceof IndexedTensor.BoundBuilder)) + if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); - IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; if (values.type() == Type.STRING) { double[] decoded = decodeHexString(values.asString(), builder.type().valueType()); if (decoded.length == 0) @@ -240,10 +243,9 @@ public class JsonFormat { } private static void decodeBlocks(Inspector values, Tensor.Builder builder) { - if ( ! (builder instanceof MixedTensor.BoundBuilder)) + if ( ! (builder instanceof MixedTensor.BoundBuilder mixedBuilder)) throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + "Use 'cells' or 'values' instead"); - MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder; if (values.type() == Type.ARRAY) values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder)); @@ -260,6 +262,19 @@ public class JsonFormat { decodeValues(block.field("values"), mixedBuilder)); } + /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */ + private static void decodeDirectValue(Inspector root, Tensor.Builder builder) { + boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); + boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); + + if ( ! hasMapped) + decodeValues(root, builder); + else if (hasMapped && hasIndexed) + decodeBlocks(root, builder); + else + decodeCells(root, builder); + } + private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) { if (value.type() != Type.ARRAY) throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type()); @@ -334,18 +349,12 @@ public class JsonFormat { } public static double[] decodeHexString(String input, TensorType.Value valueType) { - switch(valueType) { - case INT8: - return decodeHexStringAsBytes(input); - case BFLOAT16: - return decodeHexStringAsBFloat16s(input); - case FLOAT: - return decodeHexStringAsFloats(input); - case DOUBLE: - return decodeHexStringAsDoubles(input); - default: - throw new IllegalArgumentException("Cannot handle value type: "+valueType); - } + return switch (valueType) { + case INT8 -> decodeHexStringAsBytes(input); + case BFLOAT16 -> decodeHexStringAsBFloat16s(input); + case FLOAT -> decodeHexStringAsFloats(input); + case DOUBLE -> decodeHexStringAsDoubles(input); + }; } private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index abbd1d34885..6a6bb3c6781 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -15,6 +15,30 @@ import static org.junit.Assert.fail; */ public class JsonFormatTestCase { + /** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */ + @Test + public void testDirectValue() { + assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}"); + assertDecoded("tensor(x[2]):[2, 3]]", "[2.0, 3.0]"); + assertDecoded("tensor(x{},y[2]):{a:[2, 3], b:[4, 5]}", "{'a':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y{}):{{x:a,y:0}:2, {x:b,y:1}:3}", + "[{'address':{'x':'a','y':'0'},'value':2}, {'address':{'x':'b','y':'1'},'value':3}]"); + } + + @Test + public void testDirectValueReservedNameKeys() { + // Single-valued + assertDecoded("tensor(x{}):{cells:2, b:3}", "{'cells':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{values:2, b:3}", "{'values':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{block:2, b:3}", "{'block':2.0, 'b':3.0}"); + + // Multi-valued + assertDecoded("tensor(x{},y[2]):{cells:[2, 3], b:[4, 5]}", "{'cells':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y[2]):{values:[2, 3], b:[4, 5]}", "{'values':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y[2]):{block:[2, 3], b:[4, 5]}", "{'block':[2, 3], 'b':[4, 5]}"); + } + @Test public void testSparseTensor() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); @@ -32,6 +56,21 @@ public class JsonFormatTestCase { } @Test + public void testEmptySparseTensor() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); + Tensor tensor = builder.build(); + byte[] json = JsonFormat.encode(tensor); + assertEquals("{\"cells\":[]}", + new String(json, StandardCharsets.UTF_8)); + Tensor decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); + + json = "{}".getBytes(); // short form variant of the above + decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); + } + + @Test public void testSingleSparseDimensionShortForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")); builder.cell().label("x", "a").value(2.0); @@ -326,7 +365,7 @@ public class JsonFormatTestCase { "]}"; try { JsonFormat.decode(x2, json.getBytes(StandardCharsets.UTF_8)); - fail("Excpected exception"); + fail("Expected exception"); } catch (IllegalArgumentException e) { assertEquals("cell address (2) is not within the bounds of tensor(x[2])", e.getMessage()); @@ -366,6 +405,14 @@ public class JsonFormatTestCase { assertEquals(expected, new String(json, StandardCharsets.UTF_8)); } + private void assertDecoded(String expected, String jsonToDecode) { + assertDecoded(Tensor.from(expected), jsonToDecode); + } + + private void assertDecoded(Tensor expected, String jsonToDecode) { + assertEquals(expected, JsonFormat.decode(expected.type(), jsonToDecode.getBytes(StandardCharsets.UTF_8))); + } + private void assertDecodeFails(TensorType type, String format, String msg) { try { Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8)); |