diff options
9 files changed, 198 insertions, 86 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/FieldCollapsingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/FieldCollapsingSearcher.java index 54c46b34163..ead6ad53715 100644 --- a/container-search/src/main/java/com/yahoo/prelude/searcher/FieldCollapsingSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/searcher/FieldCollapsingSearcher.java @@ -89,7 +89,7 @@ public class FieldCollapsingSearcher extends Searcher { if (collapseField == null) return execution.search(query); - int collapseSize = query.properties().getInteger(collapsesize,defaultCollapseSize); + int collapseSize = query.properties().getInteger(collapsesize, defaultCollapseSize); query.properties().set(collapse, "0"); int hitsToRequest = query.getHits() != 0 ? (int) Math.ceil((query.getOffset() + query.getHits() + 1) * extraFactor) : 0; diff --git a/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java b/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java index d33cc8078dd..7f6ead528fe 100644 --- a/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java +++ b/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java @@ -48,6 +48,7 @@ import java.util.Set; * @author Vegard Sjonfjell */ public class JsonSerializationHelper { + private final static Base64.Encoder base64Encoder = Base64.getEncoder(); // Important: _basic_ format static class JsonSerializationException extends RuntimeException { @@ -99,14 +100,6 @@ public class JsonSerializationHelper { }); } - private static void serializeTensorDimensions(JsonGenerator generator, Set<String> dimensions) throws IOException { - generator.writeArrayFieldStart(TensorReader.TENSOR_DIMENSIONS); - for (String dimension : dimensions) { - generator.writeString(dimension); - } - generator.writeEndArray(); - } - static void serializeTensorCells(JsonGenerator generator, Tensor tensor) throws IOException { generator.writeArrayFieldStart(TensorReader.TENSOR_CELLS); for (Map.Entry<TensorAddress, Double> cell : tensor.cells().entrySet()) { diff --git a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java index 6b2bdbd53d8..e1214920296 100644 --- a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java +++ b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java @@ -18,13 +18,15 @@ import com.google.common.base.Preconditions; public class TokenBuffer { private final Deque<Token> buffer; + private int nesting = 0; + private Token previousToken; public TokenBuffer() { this(new ArrayDeque<>()); } - private TokenBuffer(Deque<Token> buffer) { + public TokenBuffer(Deque<Token> buffer) { this.buffer = buffer; if (buffer.size() > 0) { updateNesting(buffer.peekFirst().token); @@ -35,7 +37,7 @@ public class TokenBuffer { public boolean isEmpty() { return size() == 0; } public JsonToken next() { - buffer.removeFirst(); + previousToken = buffer.removeFirst(); Token t = buffer.peekFirst(); if (t == null) { return null; @@ -44,6 +46,16 @@ public class TokenBuffer { return t.token; } + /** Goes one token back. Repeated calls to this method will *not* go back further. */ + public JsonToken previous() { + if (previousToken == null) return null; + updateNestingGoingBackwards(currentToken()); + Token newCurrent = previousToken; + previousToken = null; + buffer.push(newCurrent); + return newCurrent.token; + } + /** Returns the current token without changing position, or null if none */ public JsonToken currentToken() { Token token = buffer.peekFirst(); @@ -130,6 +142,10 @@ public class TokenBuffer { nesting += nestingOffset(t); } + private void updateNestingGoingBackwards(JsonToken t) { + nesting -= nestingOffset(t); + } + public int nesting() { return nesting; } diff --git a/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java b/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java index 5e1c1eb6ac4..b63a39f51c5 100644 --- a/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java +++ b/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java @@ -76,18 +76,10 @@ public class DocumentParser { throw new IllegalArgumentException("Could not read document, no document?"); } switch (currentToken) { - case START_OBJECT: - indentLevel++; - break; - case END_OBJECT: - indentLevel--; - break; - case START_ARRAY: - indentLevel += 10000L; - break; - case END_ARRAY: - indentLevel -= 10000L; - break; + case START_OBJECT -> indentLevel++; + case END_OBJECT -> indentLevel--; + case START_ARRAY -> indentLevel += 10000L; + case END_ARRAY -> indentLevel -= 10000L; } } @@ -133,18 +125,12 @@ public class DocumentParser { } private static DocumentOperationType operationNameToOperationType(String operationName) { - switch (operationName) { - case PUT: - case ID: - return DocumentOperationType.PUT; - case REMOVE: - return DocumentOperationType.REMOVE; - case UPDATE: - return DocumentOperationType.UPDATE; - default: - throw new IllegalArgumentException( - "Got " + operationName + " as document operation, only \"put\", " + - "\"remove\" and \"update\" are supported."); - } + return switch (operationName) { + case PUT, ID -> DocumentOperationType.PUT; + case REMOVE -> DocumentOperationType.REMOVE; + case UPDATE -> DocumentOperationType.UPDATE; + default -> throw new IllegalArgumentException("Got " + operationName + " as document operation, only \"put\", " + + "\"remove\" and \"update\" are supported."); + }; } } diff --git a/document/src/main/java/com/yahoo/document/json/readers/JsonParserHelpers.java b/document/src/main/java/com/yahoo/document/json/readers/JsonParserHelpers.java index 1723df2bd54..594dfc5ab06 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/JsonParserHelpers.java +++ b/document/src/main/java/com/yahoo/document/json/readers/JsonParserHelpers.java @@ -5,6 +5,8 @@ package com.yahoo.document.json.readers; import com.fasterxml.jackson.core.JsonToken; import com.google.common.base.Preconditions; +import java.util.Arrays; + public class JsonParserHelpers { public static void expectArrayStart(JsonToken token) { @@ -61,4 +63,9 @@ public class JsonParserHelpers { } } + public static void expectOneOf(JsonToken token, JsonToken ... tokens) { + if (Arrays.stream(tokens).noneMatch(t -> t == token)) + throw new IllegalArgumentException("Expected one of " + tokens + " but got " + token); + } + } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index 193c9491e86..4b2da912c37 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -4,12 +4,18 @@ package com.yahoo.document.json.readers; import com.fasterxml.jackson.core.JsonToken; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.json.TokenBuffer; +import com.yahoo.document.select.parser.Token; +import com.yahoo.slime.Inspector; +import com.yahoo.slime.Type; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import java.util.ArrayDeque; +import java.util.Deque; + import static com.yahoo.document.json.readers.JsonParserHelpers.*; import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString; @@ -23,7 +29,6 @@ import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString; public class TensorReader { public static final String TENSOR_ADDRESS = "address"; - public static final String TENSOR_DIMENSIONS = "dimensions"; public static final String TENSOR_CELLS = "cells"; public static final String TENSOR_VALUES = "values"; public static final String TENSOR_BLOCKS = "blocks"; @@ -32,22 +37,39 @@ public class TensorReader { // MUST be kept in sync with com.yahoo.tensor.serialization.JsonFormat.decode in vespajlib static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) { Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType()); - expectObjectStart(buffer.currentToken()); + expectOneOf(buffer.currentToken(), JsonToken.START_OBJECT, JsonToken.START_ARRAY); int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { - if (TENSOR_CELLS.equals(buffer.currentName())) + if (TENSOR_CELLS.equals(buffer.currentName()) && ! primitiveContent(new TokenBuffer(new ArrayDeque<>(buffer.rest())))) { readTensorCells(buffer, builder); - else if (TENSOR_VALUES.equals(buffer.currentName())) + } + else if (TENSOR_VALUES.equals(buffer.currentName()) && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) { readTensorValues(buffer, builder); - else if (TENSOR_BLOCKS.equals(buffer.currentName())) + } + else if (TENSOR_BLOCKS.equals(buffer.currentName())) { readTensorBlocks(buffer, builder); - else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks', but got: "+buffer.currentName()); + } + else { + buffer.previous(); // Back up to the start of the enclosing block + readDirectTensorValue(buffer, builder); + buffer.previous(); // ... and back up to the end of the enclosing block + } } - expectObjectEnd(buffer.currentToken()); + expectOneOf(buffer.currentToken(), JsonToken.END_OBJECT, JsonToken.END_ARRAY); tensorFieldValue.assign(builder.build()); } + static boolean primitiveContent(TokenBuffer buffer) { + JsonToken cellsValue = buffer.currentToken(); + if (cellsValue.isScalarValue()) return true; + if (cellsValue == JsonToken.START_ARRAY) { + JsonToken firstArrayValue = buffer.next(); + if (firstArrayValue == JsonToken.END_ARRAY) return false; + if (firstArrayValue.isScalarValue()) return true; + } + return false; + } + static void readTensorCells(TokenBuffer buffer, Tensor.Builder builder) { if (buffer.currentToken() == JsonToken.START_ARRAY) { int initNesting = buffer.nesting(); @@ -88,10 +110,9 @@ public class TensorReader { } private static void readTensorValues(TokenBuffer buffer, Tensor.Builder builder) { - if ( ! (builder instanceof IndexedTensor.BoundBuilder)) + if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); - IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; if (buffer.currentToken() == JsonToken.VALUE_STRING) { double[] decoded = decodeHexString(buffer.currentText(), builder.type().valueType()); if (decoded.length == 0) @@ -112,11 +133,9 @@ public class TensorReader { } static void readTensorBlocks(TokenBuffer buffer, Tensor.Builder builder) { - if ( ! (builder instanceof MixedTensor.BoundBuilder)) + if ( ! (builder instanceof MixedTensor.BoundBuilder mixedBuilder)) throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + "Use 'cells' or 'values' instead"); - - MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder; if (buffer.currentToken() == JsonToken.START_ARRAY) { int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) @@ -160,6 +179,19 @@ public class TensorReader { mixedBuilder.block(address, values); } + /** Reads a tensor value directly at the root, where the format is decided by the tensor type. */ + private static void readDirectTensorValue(TokenBuffer buffer, Tensor.Builder builder) { + boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); + boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); + + if ( ! hasMapped) + readTensorValues(buffer, builder); + else if (hasMapped && hasIndexed) + readTensorBlocks(buffer, builder); + else + readTensorCells(buffer, builder); + } + private static TensorAddress readAddress(TokenBuffer buffer, TensorType type) { expectObjectStart(buffer.currentToken()); TensorAddress.Builder builder = new TensorAddress.Builder(type); diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 41d607b0d8e..c19094ff231 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1403,12 +1403,6 @@ public class JsonReaderTestCase { } @Test - public void testParsingOfTensorWithEmptyDimensions() { - assertSparseTensorField("tensor(x{},y{}):{}", - createPutWithSparseTensor(inputJson("{ 'dimensions': [] }"))); - } - - @Test public void testParsingOfTensorWithEmptyCells() { assertSparseTensorField("tensor(x{},y{}):{}", createPutWithSparseTensor(inputJson("{ 'cells': [] }"))); @@ -1511,6 +1505,32 @@ public class JsonReaderTestCase { Tensor tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor"); } + /** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */ + @Test + public void testDirectValue() { + assertTensorField("tensor(x{}):{a:2, b:3}", "sparse_single_dimension_tensor", "{'a':2.0, 'b':3.0}"); + assertTensorField("tensor(x[2],y[3]):[2, 3, 4, 5, 6, 7]]", "dense_tensor", "[2, 3, 4, 5, 6, 7]"); + assertTensorField("tensor(x{},y[3]):{a:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor", "{'a':[2, 3, 4], 'b':[4, 5, 6]}"); + assertTensorField("tensor(x{},y{}):{{x:a,y:0}:2, {x:b,y:1}:3}", "sparse_tensor", + "[{'address':{'x':'a','y':'0'},'value':2}, {'address':{'x':'b','y':'1'},'value':3}]"); + } + + @Test + public void testDirectValueReservedNameKeys() { + // Single-valued + assertTensorField("tensor(x{}):{cells:2, b:3}", "sparse_single_dimension_tensor", "{'cells':2.0, 'b':3.0}"); + assertTensorField("tensor(x{}):{values:2, b:3}", "sparse_single_dimension_tensor", "{'values':2.0, 'b':3.0}"); + assertTensorField("tensor(x{}):{block:2, b:3}", "sparse_single_dimension_tensor", "{'block':2.0, 'b':3.0}"); + + // Multi-valued + assertTensorField("tensor(x{},y[3]):{cells:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor", + "{'cells':[2, 3, 4], 'b':[4, 5, 6]}"); + assertTensorField("tensor(x{},y[3]):{values:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor", + "{'values':[2, 3, 4], 'b':[4, 5, 6]}"); + assertTensorField("tensor(x{},y[3]):{block:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor", + "{'block':[2, 3, 4], 'b':[4, 5, 6]}"); + } + @Test public void testParsingOfMixedTensorOnMixedForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])")); @@ -2070,6 +2090,9 @@ public class JsonReaderTestCase { private static Tensor assertSparseTensorField(String expectedTensor, DocumentPut put) { return assertTensorField(expectedTensor, put, "sparse_tensor"); } + private Tensor assertTensorField(String expectedTensor, String fieldName, String inputJson) { + return assertTensorField(expectedTensor, createPutWithTensor(inputJson, fieldName), fieldName); + } private static Tensor assertTensorField(String expectedTensor, DocumentPut put, String tensorFieldName) { return assertTensorField(Tensor.from(expectedTensor), put, tensorFieldName); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index a7afc1efc6d..0e8fbc30bb6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -60,25 +60,21 @@ public class JsonFormat { Cursor root = slime.setObject(); root.setString("type", tensor.type().toString()); - // Encode as nested lists if indexed tensor - if (tensor instanceof IndexedTensor) { - IndexedTensor denseTensor = (IndexedTensor) tensor; + if (tensor instanceof IndexedTensor denseTensor) { + // Encode as nested lists if indexed tensor encodeValues(denseTensor, root.setArray("values"), new long[denseTensor.dimensionSizes().dimensions()], 0); } - - // Short form for a single mapped dimension else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) { + // Short form for a single mapped dimension encodeSingleDimensionCells((MappedTensor) tensor, root); } - - // Short form for a mixed tensor else if (tensor instanceof MixedTensor && tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() >= 1) { + // Short form for a mixed tensor encodeBlocks((MixedTensor) tensor, root); } - - // No other short forms exist: default to standard cell address output else { + // No other short forms exist: default to standard cell address output encodeCells(tensor, root); } @@ -177,17 +173,25 @@ public class JsonFormat { Tensor.Builder builder = Tensor.Builder.of(type); Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get(); - if (root.field("cells").valid()) + if (root.field("cells").valid() && ! primitiveContent(root.field("cells"))) decodeCells(root.field("cells"), builder); - else if (root.field("values").valid()) + else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) decodeValues(root.field("values"), builder); else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); - else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values' or 'blocks'"); + else + decodeDirectValue(root, builder); return builder.build(); } + private static boolean primitiveContent(Inspector cellsValue) { + if (cellsValue.type() == Type.DOUBLE) return true; + if (cellsValue.type() == Type.LONG) return true; + if (cellsValue.type() == Type.ARRAY && cellsValue.entries() > 0 && + ( cellsValue.entry(0).type() == Type.DOUBLE || cellsValue.entry(0).type() == Type.LONG)) return true; + return false; + } + private static void decodeCells(Inspector cells, Tensor.Builder builder) { if (cells.type() == Type.ARRAY) cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder)); @@ -212,10 +216,9 @@ public class JsonFormat { } private static void decodeValues(Inspector values, Tensor.Builder builder) { - if ( ! (builder instanceof IndexedTensor.BoundBuilder)) + if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); - IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; if (values.type() == Type.STRING) { double[] decoded = decodeHexString(values.asString(), builder.type().valueType()); if (decoded.length == 0) @@ -240,10 +243,9 @@ public class JsonFormat { } private static void decodeBlocks(Inspector values, Tensor.Builder builder) { - if ( ! (builder instanceof MixedTensor.BoundBuilder)) + if ( ! (builder instanceof MixedTensor.BoundBuilder mixedBuilder)) throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + "Use 'cells' or 'values' instead"); - MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder; if (values.type() == Type.ARRAY) values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder)); @@ -260,6 +262,19 @@ public class JsonFormat { decodeValues(block.field("values"), mixedBuilder)); } + /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */ + private static void decodeDirectValue(Inspector root, Tensor.Builder builder) { + boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); + boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); + + if ( ! hasMapped) + decodeValues(root, builder); + else if (hasMapped && hasIndexed) + decodeBlocks(root, builder); + else + decodeCells(root, builder); + } + private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) { if (value.type() != Type.ARRAY) throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type()); @@ -334,18 +349,12 @@ public class JsonFormat { } public static double[] decodeHexString(String input, TensorType.Value valueType) { - switch(valueType) { - case INT8: - return decodeHexStringAsBytes(input); - case BFLOAT16: - return decodeHexStringAsBFloat16s(input); - case FLOAT: - return decodeHexStringAsFloats(input); - case DOUBLE: - return decodeHexStringAsDoubles(input); - default: - throw new IllegalArgumentException("Cannot handle value type: "+valueType); - } + return switch (valueType) { + case INT8 -> decodeHexStringAsBytes(input); + case BFLOAT16 -> decodeHexStringAsBFloat16s(input); + case FLOAT -> decodeHexStringAsFloats(input); + case DOUBLE -> decodeHexStringAsDoubles(input); + }; } private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 1c884186879..f71a68ec5ed 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -1,7 +1,6 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; -import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -16,6 +15,30 @@ import static org.junit.Assert.fail; */ public class JsonFormatTestCase { + /** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */ + @Test + public void testDirectValue() { + assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}"); + assertDecoded("tensor(x[2]):[2, 3]]", "[2.0, 3.0]"); + assertDecoded("tensor(x{},y[2]):{a:[2, 3], b:[4, 5]}", "{'a':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y{}):{{x:a,y:0}:2, {x:b,y:1}:3}", + "[{'address':{'x':'a','y':'0'},'value':2}, {'address':{'x':'b','y':'1'},'value':3}]"); + } + + @Test + public void testDirectValueReservedNameKeys() { + // Single-valued + assertDecoded("tensor(x{}):{cells:2, b:3}", "{'cells':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{values:2, b:3}", "{'values':2.0, 'b':3.0}"); + assertDecoded("tensor(x{}):{block:2, b:3}", "{'block':2.0, 'b':3.0}"); + + // Multi-valued + assertDecoded("tensor(x{},y[2]):{cells:[2, 3], b:[4, 5]}", "{'cells':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y[2]):{values:[2, 3], b:[4, 5]}", "{'values':[2, 3], 'b':[4, 5]}"); + assertDecoded("tensor(x{},y[2]):{block:[2, 3], b:[4, 5]}", "{'block':[2, 3], 'b':[4, 5]}"); + } + @Test public void testSparseTensor() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); @@ -33,6 +56,21 @@ public class JsonFormatTestCase { } @Test + public void testEmptySparseTensor() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); + Tensor tensor = builder.build(); + byte[] json = JsonFormat.encode(tensor); + assertEquals("{\"cells\":[]}", + new String(json, StandardCharsets.UTF_8)); + Tensor decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); + + json = "{}".getBytes(); // short form variant of the above + decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); + } + + @Test public void testSingleSparseDimensionShortForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})")); builder.cell().label("x", "a").value(2.0); @@ -327,7 +365,7 @@ public class JsonFormatTestCase { "]}"; try { JsonFormat.decode(x2, json.getBytes(StandardCharsets.UTF_8)); - fail("Excpected exception"); + fail("Expected exception"); } catch (IllegalArgumentException e) { assertEquals("cell address (2) is not within the bounds of tensor(x[2])", e.getMessage()); @@ -354,6 +392,14 @@ public class JsonFormatTestCase { assertEquals(expected, new String(json, StandardCharsets.UTF_8)); } + private void assertDecoded(String expected, String jsonToDecode) { + assertDecoded(Tensor.from(expected), jsonToDecode); + } + + private void assertDecoded(Tensor expected, String jsonToDecode) { + assertEquals(expected, JsonFormat.decode(expected.type(), jsonToDecode.getBytes(StandardCharsets.UTF_8))); + } + private void assertDecodeFails(TensorType type, String format, String msg) { try { Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8)); |