From 4a720e1feb149158f5868b1739bc28821486a6e1 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 8 Oct 2019 14:12:00 +0200 Subject: Support mixed tensor short form JSON --- .../yahoo/document/json/readers/TensorReader.java | 69 +++++++++++++++++++++- 1 file changed, 66 insertions(+), 3 deletions(-) (limited to 'document/src/main') diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index 6bdac611fdc..90714476ac2 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -8,7 +8,10 @@ import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.Type; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MappedTensor; +import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; import static com.yahoo.document.json.readers.JsonParserHelpers.*; @@ -25,10 +28,11 @@ public class TensorReader { public static final String TENSOR_DIMENSIONS = "dimensions"; public static final String TENSOR_CELLS = "cells"; public static final String TENSOR_VALUES = "values"; + public static final String TENSOR_BLOCKS = "blocks"; public static final String TENSOR_VALUE = "value"; + // MUST be kept in sync with com.yahoo.tensor.serialization.JsonFormat.decode in vespajlib static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) { - // TODO: Switch implementation to om.yahoo.tensor.serialization.JsonFormat.decode Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType()); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); @@ -37,8 +41,10 @@ public class TensorReader { readTensorCells(buffer, builder); else if (TENSOR_VALUES.equals(buffer.currentName())) readTensorValues(buffer, builder); + else if (TENSOR_BLOCKS.equals(buffer.currentName())) + readTensorBlocks(buffer, builder); else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values'"); + throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks'"); } expectObjectEnd(buffer.currentToken()); tensorFieldValue.assign(builder.build()); @@ -83,7 +89,7 @@ public class TensorReader { private static void readTensorValues(TokenBuffer buffer, Tensor.Builder builder) { if ( ! (builder instanceof IndexedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + - "Use 'cells' instead"); + "Use 'cells' or 'blocks' instead"); expectArrayStart(buffer.currentToken()); IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; @@ -94,6 +100,63 @@ public class TensorReader { expectCompositeEnd(buffer.currentToken()); } + private static void readTensorBlocks(TokenBuffer buffer, Tensor.Builder builder) { + if ( ! (builder instanceof MixedTensor.BoundBuilder)) + throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + + "Use 'cells' or 'values' instead"); + expectArrayStart(buffer.currentToken()); + + MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder; + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + readTensorBlock(buffer, mixedBuilder); + expectCompositeEnd(buffer.currentToken()); + } + + private static void readTensorBlock(TokenBuffer buffer, MixedTensor.BoundBuilder mixedBuilder) { + expectObjectStart(buffer.currentToken()); + + TensorAddress address = null; + double[] values = null; + + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { + String currentName = buffer.currentName(); + if (TensorReader.TENSOR_ADDRESS.equals(currentName)) + address = readAddress(buffer, mixedBuilder.type().mappedSubtype()); + else if (TensorReader.TENSOR_VALUES.equals(currentName)) + values = readValues(buffer, (int)mixedBuilder.denseSubspaceSize()); + } + expectObjectEnd(buffer.currentToken()); + if (address == null) + throw new IllegalArgumentException("Expected a 'blocks' array object to contain an object 'address'"); + if (values == null) + throw new IllegalArgumentException("Expected a 'blocks' array object to contain an array 'values'"); + mixedBuilder.block(address, values); + } + + private static TensorAddress readAddress(TokenBuffer buffer, TensorType type) { + expectObjectStart(buffer.currentToken()); + int initNesting = buffer.nesting(); + TensorAddress.Builder builder = new TensorAddress.Builder(type); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + builder.add(buffer.currentName(), buffer.currentText()); + expectObjectEnd(buffer.currentToken()); + return builder.build(); + } + + private static double[] readValues(TokenBuffer buffer, int size) { + expectArrayStart(buffer.currentToken()); + + int index = 0; + int initNesting = buffer.nesting(); + double[] values = new double[size]; + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + values[index++] = readDouble(buffer); + expectCompositeEnd(buffer.currentToken()); + return values; + } + private static double readDouble(TokenBuffer buffer) { try { return Double.valueOf(buffer.currentText()); -- cgit v1.2.3