summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-01-12 22:13:12 +0100
committerJon Bratseth <bratseth@gmail.com>2023-01-12 22:13:12 +0100
commit2229a4d2e3141010850fa23f5ad731c9038052a8 (patch)
treed35ed23f65ef1ee793367e28450eda483372f031 /vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
parent844eeeeebfd8cdffb28ee7d64e05a803aa2f0e5a (diff)
Parse tensor JSON values at root
Our current tensor JSON formats require a "blocks", "cells" or "values" key at the root, containing values in various forms. This adds support for skipping that extra level and adding content at the root, where the permissible content format depends on the tensor type, and matches the formats below "blocks", "cells" or "values" for the corresponding tensor types.
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java50
1 files changed, 48 insertions, 2 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 1c884186879..f71a68ec5ed 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -1,7 +1,6 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
-import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -16,6 +15,30 @@ import static org.junit.Assert.fail;
*/
public class JsonFormatTestCase {
+ /** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */
+ @Test
+ public void testDirectValue() {
+ assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}");
+ assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}");
+ assertDecoded("tensor(x[2]):[2, 3]]", "[2.0, 3.0]");
+ assertDecoded("tensor(x{},y[2]):{a:[2, 3], b:[4, 5]}", "{'a':[2, 3], 'b':[4, 5]}");
+ assertDecoded("tensor(x{},y{}):{{x:a,y:0}:2, {x:b,y:1}:3}",
+ "[{'address':{'x':'a','y':'0'},'value':2}, {'address':{'x':'b','y':'1'},'value':3}]");
+ }
+
+ @Test
+ public void testDirectValueReservedNameKeys() {
+ // Single-valued
+ assertDecoded("tensor(x{}):{cells:2, b:3}", "{'cells':2.0, 'b':3.0}");
+ assertDecoded("tensor(x{}):{values:2, b:3}", "{'values':2.0, 'b':3.0}");
+ assertDecoded("tensor(x{}):{block:2, b:3}", "{'block':2.0, 'b':3.0}");
+
+ // Multi-valued
+ assertDecoded("tensor(x{},y[2]):{cells:[2, 3], b:[4, 5]}", "{'cells':[2, 3], 'b':[4, 5]}");
+ assertDecoded("tensor(x{},y[2]):{values:[2, 3], b:[4, 5]}", "{'values':[2, 3], 'b':[4, 5]}");
+ assertDecoded("tensor(x{},y[2]):{block:[2, 3], b:[4, 5]}", "{'block':[2, 3], 'b':[4, 5]}");
+ }
+
@Test
public void testSparseTensor() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
@@ -33,6 +56,21 @@ public class JsonFormatTestCase {
}
@Test
+ public void testEmptySparseTensor() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
+ Tensor tensor = builder.build();
+ byte[] json = JsonFormat.encode(tensor);
+ assertEquals("{\"cells\":[]}",
+ new String(json, StandardCharsets.UTF_8));
+ Tensor decoded = JsonFormat.decode(tensor.type(), json);
+ assertEquals(tensor, decoded);
+
+ json = "{}".getBytes(); // short form variant of the above
+ decoded = JsonFormat.decode(tensor.type(), json);
+ assertEquals(tensor, decoded);
+ }
+
+ @Test
public void testSingleSparseDimensionShortForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})"));
builder.cell().label("x", "a").value(2.0);
@@ -327,7 +365,7 @@ public class JsonFormatTestCase {
"]}";
try {
JsonFormat.decode(x2, json.getBytes(StandardCharsets.UTF_8));
- fail("Excpected exception");
+ fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("cell address (2) is not within the bounds of tensor(x[2])", e.getMessage());
@@ -354,6 +392,14 @@ public class JsonFormatTestCase {
assertEquals(expected, new String(json, StandardCharsets.UTF_8));
}
+ private void assertDecoded(String expected, String jsonToDecode) {
+ assertDecoded(Tensor.from(expected), jsonToDecode);
+ }
+
+ private void assertDecoded(Tensor expected, String jsonToDecode) {
+ assertEquals(expected, JsonFormat.decode(expected.type(), jsonToDecode.getBytes(StandardCharsets.UTF_8)));
+ }
+
private void assertDecodeFails(TensorType type, String format, String msg) {
try {
Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8));