summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@gmail.com> 2023-01-12 22:13:12 +0100
committer: Jon Bratseth <bratseth@gmail.com> 2023-01-12 22:13:12 +0100
commit: 2229a4d2e3141010850fa23f5ad731c9038052a8 (patch)
tree: d35ed23f65ef1ee793367e28450eda483372f031 /vespajlib
parent: 844eeeeebfd8cdffb28ee7d64e05a803aa2f0e5a (diff)
Parse tensor JSON values at root
Our current tensor JSON formats require a "blocks", "cells" or "values" key at the root, containing the values in various forms. This change adds support for skipping that extra level and placing the content directly at the root. The permissible content format then depends on the tensor type, and matches the format accepted below "blocks", "cells" or "values" for the corresponding tensor type.
Diffstat (limited to 'vespajlib')
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 67
-rw-r--r-- vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 50
2 files changed, 86 insertions, 31 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index a7afc1efc6d..0e8fbc30bb6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -60,25 +60,21 @@ public class JsonFormat {
Cursor root = slime.setObject();
root.setString("type", tensor.type().toString());
- // Encode as nested lists if indexed tensor
- if (tensor instanceof IndexedTensor) {
- IndexedTensor denseTensor = (IndexedTensor) tensor;
+ if (tensor instanceof IndexedTensor denseTensor) {
+ // Encode as nested lists if indexed tensor
encodeValues(denseTensor, root.setArray("values"), new long[denseTensor.dimensionSizes().dimensions()], 0);
}
-
- // Short form for a single mapped dimension
else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) {
+ // Short form for a single mapped dimension
encodeSingleDimensionCells((MappedTensor) tensor, root);
}
-
- // Short form for a mixed tensor
else if (tensor instanceof MixedTensor &&
tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() >= 1) {
+ // Short form for a mixed tensor
encodeBlocks((MixedTensor) tensor, root);
}
-
- // No other short forms exist: default to standard cell address output
else {
+ // No other short forms exist: default to standard cell address output
encodeCells(tensor, root);
}
@@ -177,17 +173,25 @@ public class JsonFormat {
Tensor.Builder builder = Tensor.Builder.of(type);
Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get();
- if (root.field("cells").valid())
+ if (root.field("cells").valid() && ! primitiveContent(root.field("cells")))
decodeCells(root.field("cells"), builder);
- else if (root.field("values").valid())
+ else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed()))
decodeValues(root.field("values"), builder);
else if (root.field("blocks").valid())
decodeBlocks(root.field("blocks"), builder);
- else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty
- throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values' or 'blocks'");
+ else
+ decodeDirectValue(root, builder);
return builder.build();
}
+ /**
+  * Returns whether the given value is primitive content: a number (double or long),
+  * or a non-empty array whose first entry is a number.
+  */
+ private static boolean primitiveContent(Inspector cellsValue) {
+     var valueType = cellsValue.type();
+     if (valueType == Type.DOUBLE || valueType == Type.LONG) return true;
+     if (valueType != Type.ARRAY || cellsValue.entries() <= 0) return false;
+     // Only the first entry is inspected to classify the array
+     var firstEntryType = cellsValue.entry(0).type();
+     return firstEntryType == Type.DOUBLE || firstEntryType == Type.LONG;
+ }
+
private static void decodeCells(Inspector cells, Tensor.Builder builder) {
if (cells.type() == Type.ARRAY)
cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder));
@@ -212,10 +216,9 @@ public class JsonFormat {
}
private static void decodeValues(Inspector values, Tensor.Builder builder) {
- if ( ! (builder instanceof IndexedTensor.BoundBuilder))
+ if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder))
throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
"Use 'cells' or 'blocks' instead");
- IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
if (values.type() == Type.STRING) {
double[] decoded = decodeHexString(values.asString(), builder.type().valueType());
if (decoded.length == 0)
@@ -240,10 +243,9 @@ public class JsonFormat {
}
private static void decodeBlocks(Inspector values, Tensor.Builder builder) {
- if ( ! (builder instanceof MixedTensor.BoundBuilder))
+ if ( ! (builder instanceof MixedTensor.BoundBuilder mixedBuilder))
throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " +
"Use 'cells' or 'values' instead");
- MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder;
if (values.type() == Type.ARRAY)
values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder));
@@ -260,6 +262,19 @@ public class JsonFormat {
decodeValues(block.field("values"), mixedBuilder));
}
+ /**
+  * Decodes a tensor value given directly at the root, where the tensor type decides the format:
+  * a type with no mapped dimensions is passed to decodeValues, a type with both mapped and
+  * indexed dimensions to decodeBlocks, and a type with only mapped dimensions to decodeCells.
+  */
+ private static void decodeDirectValue(Inspector root, Tensor.Builder builder) {
+     var dimensions = builder.type().dimensions();
+     boolean anyMapped = dimensions.stream().anyMatch(TensorType.Dimension::isMapped);
+     boolean anyIndexed = dimensions.stream().anyMatch(TensorType.Dimension::isIndexed);
+
+     if ( ! anyMapped) {
+         decodeValues(root, builder);
+         return;
+     }
+     // At least one mapped dimension from here on
+     if (anyIndexed)
+         decodeBlocks(root, builder);
+     else
+         decodeCells(root, builder);
+ }
+
private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) {
if (value.type() != Type.ARRAY)
throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type());
@@ -334,18 +349,12 @@ public class JsonFormat {
}
/**
 * Decodes the given string into doubles by delegating to the decoder selected by the cell value type.
 * The switch expression has no default branch, so it must cover every TensorType.Value constant to compile;
 * the removed statement form instead threw IllegalArgumentException for an unhandled value type.
 */
public static double[] decodeHexString(String input, TensorType.Value valueType) {
- switch(valueType) {
- case INT8:
- return decodeHexStringAsBytes(input);
- case BFLOAT16:
- return decodeHexStringAsBFloat16s(input);
- case FLOAT:
- return decodeHexStringAsFloats(input);
- case DOUBLE:
- return decodeHexStringAsDoubles(input);
- default:
- throw new IllegalArgumentException("Cannot handle value type: "+valueType);
- }
+ return switch (valueType) {
+ case INT8 -> decodeHexStringAsBytes(input);
+ case BFLOAT16 -> decodeHexStringAsBFloat16s(input);
+ case FLOAT -> decodeHexStringAsFloats(input);
+ case DOUBLE -> decodeHexStringAsDoubles(input);
+ };
}
private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 1c884186879..f71a68ec5ed 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -1,7 +1,6 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
-import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -16,6 +15,30 @@ import static org.junit.Assert.fail;
*/
public class JsonFormatTestCase {
+ /**
+  * Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values':
+  * a single mapped dimension given as an object, an indexed dimension given as an array,
+  * a mixed tensor given as an object of arrays, and two mapped dimensions given as an
+  * array of address/value cells.
+  */
+ @Test
+ public void testDirectValue() {
+     assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}");
+     // The expected literal previously ended in an unbalanced ']'
+     assertDecoded("tensor(x[2]):[2, 3]", "[2.0, 3.0]");
+     assertDecoded("tensor(x{},y[2]):{a:[2, 3], b:[4, 5]}", "{'a':[2, 3], 'b':[4, 5]}");
+     assertDecoded("tensor(x{},y{}):{{x:a,y:0}:2, {x:b,y:1}:3}",
+                   "[{'address':{'x':'a','y':'0'},'value':2}, {'address':{'x':'b','y':'1'},'value':3}]");
+ }
+
+ @Test
+ public void testDirectValueReservedNameKeys() {
+ // Single-valued
+ assertDecoded("tensor(x{}):{cells:2, b:3}", "{'cells':2.0, 'b':3.0}");
+ assertDecoded("tensor(x{}):{values:2, b:3}", "{'values':2.0, 'b':3.0}");
+ assertDecoded("tensor(x{}):{block:2, b:3}", "{'block':2.0, 'b':3.0}");
+
+ // Multi-valued
+ assertDecoded("tensor(x{},y[2]):{cells:[2, 3], b:[4, 5]}", "{'cells':[2, 3], 'b':[4, 5]}");
+ assertDecoded("tensor(x{},y[2]):{values:[2, 3], b:[4, 5]}", "{'values':[2, 3], 'b':[4, 5]}");
+ assertDecoded("tensor(x{},y[2]):{block:[2, 3], b:[4, 5]}", "{'block':[2, 3], 'b':[4, 5]}");
+ }
+
@Test
public void testSparseTensor() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
@@ -33,6 +56,21 @@ public class JsonFormatTestCase {
}
@Test
+ public void testEmptySparseTensor() {
+     // An empty sparse tensor must encode to an empty 'cells' array and decode back to an equal tensor
+     Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
+     Tensor tensor = builder.build();
+     byte[] json = JsonFormat.encode(tensor);
+     assertEquals("{\"cells\":[]}",
+                  new String(json, StandardCharsets.UTF_8));
+     Tensor decoded = JsonFormat.decode(tensor.type(), json);
+     assertEquals(tensor, decoded);
+
+     // Short form variant of the above: an empty root object must also decode to the empty tensor.
+     // The charset is passed explicitly so the bytes do not depend on the platform default.
+     json = "{}".getBytes(StandardCharsets.UTF_8);
+     decoded = JsonFormat.decode(tensor.type(), json);
+     assertEquals(tensor, decoded);
+ }
+
+ @Test
public void testSingleSparseDimensionShortForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})"));
builder.cell().label("x", "a").value(2.0);
@@ -327,7 +365,7 @@ public class JsonFormatTestCase {
"]}";
try {
JsonFormat.decode(x2, json.getBytes(StandardCharsets.UTF_8));
- fail("Excpected exception");
+ fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("cell address (2) is not within the bounds of tensor(x[2])", e.getMessage());
@@ -354,6 +392,14 @@ public class JsonFormatTestCase {
assertEquals(expected, new String(json, StandardCharsets.UTF_8));
}
+ /** Asserts that decoding the given JSON yields the tensor which Tensor.from creates from the expected string. */
+ private void assertDecoded(String expected, String jsonToDecode) {
+     Tensor expectedTensor = Tensor.from(expected);
+     assertDecoded(expectedTensor, jsonToDecode);
+ }
+
+ /** Asserts that decoding the given JSON, as UTF-8 bytes, with the type of the expected tensor yields a tensor equal to it. */
+ private void assertDecoded(Tensor expected, String jsonToDecode) {
+     byte[] json = jsonToDecode.getBytes(StandardCharsets.UTF_8);
+     Tensor decoded = JsonFormat.decode(expected.type(), json);
+     assertEquals(expected, decoded);
+ }
+
private void assertDecodeFails(TensorType type, String format, String msg) {
try {
Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8));