aboutsummaryrefslogtreecommitdiffstats
path: root/document/src/test/java/com/yahoo
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-01-12 22:13:12 +0100
committerJon Bratseth <bratseth@gmail.com>2023-01-12 22:13:12 +0100
commit2229a4d2e3141010850fa23f5ad731c9038052a8 (patch)
treed35ed23f65ef1ee793367e28450eda483372f031 /document/src/test/java/com/yahoo
parent844eeeeebfd8cdffb28ee7d64e05a803aa2f0e5a (diff)
Parse tensor JSON values at root
Our current tensor JSON formats require a "blocks", "cells" or "values" key at the root, containing values in various forms. This adds support for skipping that extra level and adding content at the root, where the permissible content format depends on the tensor type, and matches the formats below "blocks", "cells" or "values" for the corresponding tensor types.
Diffstat (limited to 'document/src/test/java/com/yahoo')
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java35
1 files changed, 29 insertions, 6 deletions
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 41d607b0d8e..c19094ff231 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1403,12 +1403,6 @@ public class JsonReaderTestCase {
}
@Test
- public void testParsingOfTensorWithEmptyDimensions() {
- assertSparseTensorField("tensor(x{},y{}):{}",
- createPutWithSparseTensor(inputJson("{ 'dimensions': [] }")));
- }
-
- @Test
public void testParsingOfTensorWithEmptyCells() {
assertSparseTensorField("tensor(x{},y{}):{}",
createPutWithSparseTensor(inputJson("{ 'cells': [] }")));
@@ -1511,6 +1505,32 @@ public class JsonReaderTestCase {
Tensor tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor");
}
+ /** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */
+ @Test
+ public void testDirectValue() {
+ assertTensorField("tensor(x{}):{a:2, b:3}", "sparse_single_dimension_tensor", "{'a':2.0, 'b':3.0}");
+ assertTensorField("tensor(x[2],y[3]):[2, 3, 4, 5, 6, 7]]", "dense_tensor", "[2, 3, 4, 5, 6, 7]");
+ assertTensorField("tensor(x{},y[3]):{a:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor", "{'a':[2, 3, 4], 'b':[4, 5, 6]}");
+ assertTensorField("tensor(x{},y{}):{{x:a,y:0}:2, {x:b,y:1}:3}", "sparse_tensor",
+ "[{'address':{'x':'a','y':'0'},'value':2}, {'address':{'x':'b','y':'1'},'value':3}]");
+ }
+
+ @Test
+ public void testDirectValueReservedNameKeys() {
+ // Single-valued
+ assertTensorField("tensor(x{}):{cells:2, b:3}", "sparse_single_dimension_tensor", "{'cells':2.0, 'b':3.0}");
+ assertTensorField("tensor(x{}):{values:2, b:3}", "sparse_single_dimension_tensor", "{'values':2.0, 'b':3.0}");
+ assertTensorField("tensor(x{}):{block:2, b:3}", "sparse_single_dimension_tensor", "{'block':2.0, 'b':3.0}");
+
+ // Multi-valued
+ assertTensorField("tensor(x{},y[3]):{cells:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor",
+ "{'cells':[2, 3, 4], 'b':[4, 5, 6]}");
+ assertTensorField("tensor(x{},y[3]):{values:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor",
+ "{'values':[2, 3, 4], 'b':[4, 5, 6]}");
+ assertTensorField("tensor(x{},y[3]):{block:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor",
+ "{'block':[2, 3, 4], 'b':[4, 5, 6]}");
+ }
+
@Test
public void testParsingOfMixedTensorOnMixedForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])"));
@@ -2070,6 +2090,9 @@ public class JsonReaderTestCase {
private static Tensor assertSparseTensorField(String expectedTensor, DocumentPut put) {
return assertTensorField(expectedTensor, put, "sparse_tensor");
}
+ private Tensor assertTensorField(String expectedTensor, String fieldName, String inputJson) {
+ return assertTensorField(expectedTensor, createPutWithTensor(inputJson, fieldName), fieldName);
+ }
private static Tensor assertTensorField(String expectedTensor, DocumentPut put, String tensorFieldName) {
return assertTensorField(Tensor.from(expectedTensor), put, tensorFieldName);
}