diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-10-08 14:12:00 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-10-08 14:12:00 +0200 |
commit | 4a720e1feb149158f5868b1739bc28821486a6e1 (patch) | |
tree | f303098d52047047e184c52ca54cad731af65ae8 /document | |
parent | ee1bf523a2cfb5e8d86a602323c337f6e55c202b (diff) |
Support mixed tensor short form JSON
Diffstat (limited to 'document')
-rw-r--r-- | document/src/main/java/com/yahoo/document/json/readers/TensorReader.java | 69 | ||||
-rw-r--r-- | document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java | 21 |
2 files changed, 87 insertions, 3 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index 6bdac611fdc..90714476ac2 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -8,7 +8,10 @@ import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.Type; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MappedTensor; +import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; import static com.yahoo.document.json.readers.JsonParserHelpers.*; @@ -25,10 +28,11 @@ public class TensorReader { public static final String TENSOR_DIMENSIONS = "dimensions"; public static final String TENSOR_CELLS = "cells"; public static final String TENSOR_VALUES = "values"; + public static final String TENSOR_BLOCKS = "blocks"; public static final String TENSOR_VALUE = "value"; + // MUST be kept in sync with com.yahoo.tensor.serialization.JsonFormat.decode in vespajlib static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) { - // TODO: Switch implementation to om.yahoo.tensor.serialization.JsonFormat.decode Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType()); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); @@ -37,8 +41,10 @@ public class TensorReader { readTensorCells(buffer, builder); else if (TENSOR_VALUES.equals(buffer.currentName())) readTensorValues(buffer, builder); + else if (TENSOR_BLOCKS.equals(buffer.currentName())) + readTensorBlocks(buffer, builder); else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'"); + throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks'"); } expectObjectEnd(buffer.currentToken()); tensorFieldValue.assign(builder.build()); @@ -83,7 +89,7 @@ public class TensorReader { private static void readTensorValues(TokenBuffer buffer, Tensor.Builder builder) { if ( ! (builder instanceof IndexedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + - "Use 'cells' instead"); + "Use 'cells' or 'blocks' instead"); expectArrayStart(buffer.currentToken()); IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; @@ -94,6 +100,63 @@ public class TensorReader { expectCompositeEnd(buffer.currentToken()); } + private static void readTensorBlocks(TokenBuffer buffer, Tensor.Builder builder) { + if ( ! (builder instanceof MixedTensor.BoundBuilder)) + throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + + "Use 'cells' or 'values' instead"); + expectArrayStart(buffer.currentToken()); + + MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder; + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + readTensorBlock(buffer, mixedBuilder); + expectCompositeEnd(buffer.currentToken()); + } + + private static void readTensorBlock(TokenBuffer buffer, MixedTensor.BoundBuilder mixedBuilder) { + expectObjectStart(buffer.currentToken()); + + TensorAddress address = null; + double[] values = null; + + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { + String currentName = buffer.currentName(); + if (TensorReader.TENSOR_ADDRESS.equals(currentName)) + address = readAddress(buffer, mixedBuilder.type().mappedSubtype()); + else if (TensorReader.TENSOR_VALUES.equals(currentName)) + values = readValues(buffer, (int)mixedBuilder.denseSubspaceSize()); + } + expectObjectEnd(buffer.currentToken()); + if (address == null) + throw new IllegalArgumentException("Expected a 'blocks' array object to contain an object 'address'"); + if (values == null) + throw new IllegalArgumentException("Expected a 'blocks' array object to contain an array 'values'"); + mixedBuilder.block(address, values); + } + + private static TensorAddress readAddress(TokenBuffer buffer, TensorType type) { + expectObjectStart(buffer.currentToken()); + int initNesting = buffer.nesting(); + TensorAddress.Builder builder = new TensorAddress.Builder(type); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + builder.add(buffer.currentName(), buffer.currentText()); + expectObjectEnd(buffer.currentToken()); + return builder.build(); + } + + private static double[] readValues(TokenBuffer buffer, int size) { + expectArrayStart(buffer.currentToken()); + + int index = 0; + int initNesting = buffer.nesting(); + double[] values = new double[size]; + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + values[index++] = readDouble(buffer); + expectCompositeEnd(buffer.currentToken()); + return values; + } + private static double readDouble(TokenBuffer buffer) { try { return Double.valueOf(buffer.currentText()); diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 6103dc5947f..91998dedbb8 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -50,6 +50,7 @@ import com.yahoo.document.update.ValueUpdate; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MappedTensor; +import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.text.Utf8; @@ -1314,6 +1315,26 @@ public class JsonReaderTestCase { } @Test + public void testParsingOfMixedTensorOnMixedForm() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 0).label("y", 2).value(4.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(6.0); + builder.cell().label("x", 1).label("y", 2).value(7.0); + Tensor expected = builder.build(); + + String mixedJson = "{\"blocks\":[" + + "{\"address\":{\"x\":\"0\"},\"values\":[2.0,3.0,4.0]}," + + "{\"address\":{\"x\":\"1\"},\"values\":[5.0,6.0,7.0]}" + + "]}"; + Tensor tensor = assertTensorField(expected, + createPutWithTensor(inputJson(mixedJson), "mixed_tensor"), "mixed_tensor"); + assertTrue(tensor instanceof MixedTensor); // this matters for performance + } + + @Test public void testParsingOfTensorWithSingleCellInDifferentJsonOrder() { assertSparseTensorField("{{x:a,y:b}:2.0}", createPutWithSparseTensor(inputJson("{", |