aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-12-01 21:06:23 +0100
committerGitHub <noreply@github.com>2023-12-01 21:06:23 +0100
commit4f211b1eaa0c75d879eb04a2abce8f07dac926d5 (patch)
tree6be19da02fa3ec330ecab29f27dcc7f32ea6b80a
parent044c2cd1401cc37009f40552340c372f2c4e0d25 (diff)
parent617726bb3ebcce78a8fc98fe5a7fa2fd7372650e (diff)
Merge pull request #29525 from vespa-engine/arnej/handle-multi-dim-blocksv8.267.29
handle "blocks" syntax for mixed tensors with multiple indexed dimensions
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java18
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java32
2 files changed, 46 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 0ac9ea18029..28f14c8d7ca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -400,13 +400,23 @@ public class JsonFormat {
};
}
+ private static void decodeMaybeNestedValuesInBlock(Inspector arrayField, double[] target, MutableInteger index) {
+ if (arrayField.entries() == 0) {
+ throw new IllegalArgumentException("The block value array does not contain any values");
+ }
+ arrayField.traverse((ArrayTraverser) (__, value) -> {
+ if (value.type() == Type.ARRAY) {
+ decodeMaybeNestedValuesInBlock(value, target, index);
+ } else {
+ target[index.next()] = decodeNumeric(value);
+ }
+ });
+ }
+
private static double[] decodeValuesInBlock(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {
double[] values = new double[(int)mixedBuilder.denseSubspaceSize()];
if (valuesField.type() == Type.ARRAY) {
- if (valuesField.entries() == 0) {
- throw new IllegalArgumentException("The block value array does not contain any values");
- }
- valuesField.traverse((ArrayTraverser) (index, value) -> values[index] = decodeNumeric(value));
+ decodeMaybeNestedValuesInBlock(valuesField, values, new MutableInteger(0));
} else if (valuesField.type() == Type.STRING) {
double[] decoded = decodeHexString(valuesField.asString(), mixedBuilder.type().valueType());
if (decoded.length == 0) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 363d08c1123..d95396aca50 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -613,6 +613,38 @@ public class JsonFormatTestCase {
}
}
+ @Test
+ public void testMultiMixedTensor() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(cat{},x[2],y[2])"));
+ // a:
+ builder.cell().label("cat", "a").label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("cat", "a").label("x", 0).label("y", 1).value(3.0);
+ builder.cell().label("cat", "a").label("x", 1).label("y", 0).value(4.0);
+ builder.cell().label("cat", "a").label("x", 1).label("y", 1).value(5.0);
+ // b:
+ builder.cell().label("cat", "b").label("x", 0).label("y", 0).value(6.0);
+ builder.cell().label("cat", "b").label("x", 0).label("y", 1).value(7.0);
+ builder.cell().label("cat", "b").label("x", 1).label("y", 0).value(8.0);
+ builder.cell().label("cat", "b").label("x", 1).label("y", 1).value(9.0);
+ Tensor tensor = builder.build();
+ String shortJson = """
+ {
+ "type": "tensor(cat{},x[2],y[2])",
+ "blocks": {"a":[[2.0,3.0],[4.0,5.0]],"b":[[6.0,7.0],[8.0,9.0]]}
+ }
+ """;
+ byte[] shortEncoded = JsonFormat.encode(tensor, true, false);
+ assertEqualJson(shortJson, new String(shortEncoded, StandardCharsets.UTF_8));
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), shortEncoded));
+ String oldJson = """
+ {
+ "type": "tensor(cat{},x[2],y[2])",
+ "blocks": {"a":[2,3,4,5],"b":[6,7,8,9]}
+ }
+ """;
+ assertEquals(tensor, JsonFormat.decode(tensor.type(), oldJson.getBytes(StandardCharsets.UTF_8)));
+ }
+
private void assertEncodeDecode(Tensor tensor) {
Tensor decoded = JsonFormat.decode(tensor.type(), JsonFormat.encode(tensor, false, false));
assertEquals(tensor, decoded);