diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-01-12 22:13:12 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2023-01-12 22:13:12 +0100 |
commit | 2229a4d2e3141010850fa23f5ad731c9038052a8 (patch) | |
tree | d35ed23f65ef1ee793367e28450eda483372f031 /vespajlib/src/test | |
parent | 844eeeeebfd8cdffb28ee7d64e05a803aa2f0e5a (diff) |
Parse tensor JSON values at root
Our current tensor JSON formats require a "blocks", "cells" or "values" key
at the root, containing values in various forms.
This adds support for skipping that extra level and adding content at the root,
where the permissible content format depends on the tensor type, and matches
the formats below "blocks", "cells" or "values" for the corresponding tensor
types.
Diffstat (limited to 'vespajlib/src/test')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 50 |
1 files changed, 48 insertions, 2 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 1c884186879..f71a68ec5ed 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -1,7 +1,6 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; -import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -16,6 +15,30 @@ import static org.junit.Assert.fail; */ public class JsonFormatTestCase { + /** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */ + @Test + public void testDirectValue() { + assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}"); + assertDecoded("tensor(x[2]):[2, 3]]", "[2.0, 3.0]"); + assertDecoded("tensor(x{},y[2]):{a:[2, 3], b:[4, 5]}", "{'a':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y{}):{{x:a,y:0}:2, {x:b,y:1}:3}", + "[{'address':{'x':'a','y':'0'},'value':2}, {'address':{'x':'b','y':'1'},'value':3}]"); + } + + @Test + public void testDirectValueReservedNameKeys() { + // Single-valued + assertDecoded("tensor(x{}):{cells:2, b:3}", "{'cells':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{values:2, b:3}", "{'values':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{block:2, b:3}", "{'block':2.0, 'b':3.0}"); + + // Multi-valued + assertDecoded("tensor(x{},y[2]):{cells:[2, 3], b:[4, 5]}", "{'cells':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y[2]):{values:[2, 3], b:[4, 5]}", "{'values':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y[2]):{block:[2, 3], b:[4, 5]}", "{'block':[2, 3], 'b':[4, 5]}"); + } + @Test public void testSparseTensor() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); @@ -33,6 +56,21 @@ public class JsonFormatTestCase { } @Test + public void testEmptySparseTensor() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); + Tensor tensor = builder.build(); + byte[] json = JsonFormat.encode(tensor); + assertEquals("{\"cells\":[]}", + new String(json, StandardCharsets.UTF_8)); + Tensor decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); + + json = "{}".getBytes(); // short form variant of the above + decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); + } + + @Test public void testSingleSparseDimensionShortForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")); builder.cell().label("x", "a").value(2.0); @@ -327,7 +365,7 @@ public class JsonFormatTestCase { "]}"; try { JsonFormat.decode(x2, json.getBytes(StandardCharsets.UTF_8)); - fail("Excpected exception"); + fail("Expected exception"); } catch (IllegalArgumentException e) { assertEquals("cell address (2) is not within the bounds of tensor(x[2])", e.getMessage()); @@ -354,6 +392,14 @@ public class JsonFormatTestCase { assertEquals(expected, new String(json, StandardCharsets.UTF_8)); } + private void assertDecoded(String expected, String jsonToDecode) { + assertDecoded(Tensor.from(expected), jsonToDecode); + } + + private void assertDecoded(Tensor expected, String jsonToDecode) { + assertEquals(expected, JsonFormat.decode(expected.type(), jsonToDecode.getBytes(StandardCharsets.UTF_8))); + } + private void assertDecodeFails(TensorType type, String format, String msg) { try { Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8)); |