From 617726bb3ebcce78a8fc98fe5a7fa2fd7372650e Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Fri, 1 Dec 2023 12:20:29 +0000 Subject: handle "blocks" syntax for mixed tensors with multiple indexed dimensions --- .../com/yahoo/tensor/serialization/JsonFormat.java | 18 +++++++++--- .../tensor/serialization/JsonFormatTestCase.java | 32 ++++++++++++++++++++++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 0ac9ea18029..28f14c8d7ca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -400,13 +400,23 @@ public class JsonFormat { }; } + private static void decodeMaybeNestedValuesInBlock(Inspector arrayField, double[] target, MutableInteger index) { + if (arrayField.entries() == 0) { + throw new IllegalArgumentException("The block value array does not contain any values"); + } + arrayField.traverse((ArrayTraverser) (__, value) -> { + if (value.type() == Type.ARRAY) { + decodeMaybeNestedValuesInBlock(value, target, index); + } else { + target[index.next()] = decodeNumeric(value); + } + }); + } + private static double[] decodeValuesInBlock(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; if (valuesField.type() == Type.ARRAY) { - if (valuesField.entries() == 0) { - throw new IllegalArgumentException("The block value array does not contain any values"); - } - valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value)); + decodeMaybeNestedValuesInBlock(valuesField, values, new MutableInteger(0)); } else if (valuesField.type() == Type.STRING) { double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType()); if (decoded.length == 0) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 363d08c1123..d95396aca50 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -613,6 +613,38 @@ public class JsonFormatTestCase { } } + @Test + public void testMultiMixedTensor() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(cat{},x[2],y[2])")); + // a: + builder.cell().label("cat", "a").label("x", 0).label("y", 0).value(2.0); + builder.cell().label("cat", "a").label("x", 0).label("y", 1).value(3.0); + builder.cell().label("cat", "a").label("x", 1).label("y", 0).value(4.0); + builder.cell().label("cat", "a").label("x", 1).label("y", 1).value(5.0); + // b: + builder.cell().label("cat", "b").label("x", 0).label("y", 0).value(6.0); + builder.cell().label("cat", "b").label("x", 0).label("y", 1).value(7.0); + builder.cell().label("cat", "b").label("x", 1).label("y", 0).value(8.0); + builder.cell().label("cat", "b").label("x", 1).label("y", 1).value(9.0); + Tensor tensor = builder.build(); + String shortJson = """ + { + "type": "tensor(cat{},x[2],y[2])", + "blocks": {"a":[[2.0,3.0],[4.0,5.0]],"b":[[6.0,7.0],[8.0,9.0]]} + } + """; + byte[] shortEncoded = JsonFormat.encode(tensor, true, false); + assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8)); + assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded)); + String oldJson = """ + { + "type": "tensor(cat{},x[2],y[2])", + "blocks": {"a":[2,3,4,5],"b":[6,7,8,9]} + } + """; + assertEquals(tensor, JsonFormat.decode(tensor.type(), oldJson.getBytes(StandardCharsets.UTF_8))); + } + private void assertEncodeDecode(Tensor tensor) { Tensor decoded = JsonFormat.decode(tensor.type(), JsonFormat.encode(tensor, false, false)); assertEquals(tensor, decoded); -- cgit v1.2.3