From fdcf0682eb4ed0471431adaf4a6be70628b9c84d Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Mon, 13 Jan 2020 14:38:24 +0100 Subject: Convert tensor update to sparse --- .../json/readers/TensorModifyUpdateReader.java | 15 ++++++++---- .../yahoo/document/json/readers/TensorReader.java | 27 +++++++++++++++++----- .../yahoo/document/update/TensorModifyUpdate.java | 2 +- .../yahoo/document/json/JsonReaderTestCase.java | 14 +++++++++-- 4 files changed, 45 insertions(+), 13 deletions(-) (limited to 'document') diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index 5fd1c7bbab7..b8937d8b739 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -132,20 +132,27 @@ public class TensorModifyUpdateReader { Tensor.Builder tensorBuilder = Tensor.Builder.of(type); readTensorBlocks(buffer, tensorBuilder); - Tensor tensor = tensorBuilder.build(); - + Tensor tensor = convertToSparse(tensorBuilder.build()); validateBounds(tensor, type); return new TensorFieldValue(tensor); } + private static Tensor convertToSparse(Tensor tensor) { + if (tensor.type().dimensions().stream().noneMatch(dimension -> dimension.isIndexed())) return tensor; + Tensor.Builder b = Tensor.Builder.of(TensorModifyUpdate.convertDimensionsToMapped(tensor.type())); + for (Iterator i = tensor.cellIterator(); i.hasNext(); ) + b.cell(i.next()); + return b.build(); + } + /** Only validate if original type has indexed bound dimensions */ static void validateBounds(Tensor convertedTensor, TensorType originalType) { if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { return; } - for (Iterator iter = convertedTensor.cellIterator(); iter.hasNext(); ) { - Tensor.Cell cell = iter.next(); + for (Iterator cellIterator = convertedTensor.cellIterator(); cellIterator.hasNext(); ) { + Tensor.Cell cell = cellIterator.next(); TensorAddress address = cell.getKey(); for (int i = 0; i < address.size(); ++i) { TensorType.Dimension dim = originalType.dimensions().get(i); diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index e5699d0e6b1..769e31818e6 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -111,12 +111,15 @@ public class TensorReader { } else if (buffer.currentToken() == JsonToken.START_OBJECT) { int initNesting = buffer.nesting(); - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) - mixedBuilder.block(asAddress(buffer.currentName(), builder.type().mappedSubtype()), - readValues(buffer, (int)mixedBuilder.denseSubspaceSize())); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { + TensorAddress mappedAddress = asAddress(buffer.currentName(), builder.type().mappedSubtype()); + mixedBuilder.block(mappedAddress, + readValues(buffer, (int) mixedBuilder.denseSubspaceSize(), mappedAddress, mixedBuilder.type())); + } } else { - throw new IllegalArgumentException("Expected 'blocks' to contain an array or an object, but got " + buffer.currentToken()); + throw new IllegalArgumentException("Expected 'blocks' to contain an array or an object, but got " + + buffer.currentToken()); } expectCompositeEnd(buffer.currentToken()); @@ -134,7 +137,7 @@ public class TensorReader { if (TensorReader.TENSOR_ADDRESS.equals(currentName)) address = readAddress(buffer, mixedBuilder.type().mappedSubtype()); else if (TensorReader.TENSOR_VALUES.equals(currentName)) - values = readValues(buffer, (int)mixedBuilder.denseSubspaceSize()); + values = readValues(buffer, (int)mixedBuilder.denseSubspaceSize(), address, mixedBuilder.type()); } expectObjectEnd(buffer.currentToken()); if (address == null) @@ -154,7 +157,16 @@ public class TensorReader { return builder.build(); } - private static double[] readValues(TokenBuffer buffer, int size) { + /** + * Reads values for a tensor subspace block + * + * @param buffer the buffer containing the values + * @param size the expected number of values + * @param address the address for the block for error reporting, or null if not known + * @param type the type of the tensor we are reading + * @return the values read + */ + private static double[] readValues(TokenBuffer buffer, int size, TensorAddress address, TensorType type) { expectArrayStart(buffer.currentToken()); int index = 0; @@ -162,6 +174,9 @@ public class TensorReader { double[] values = new double[size]; for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) values[index++] = readDouble(buffer); + if (index != size) + throw new IllegalArgumentException((address != null ? "At " + address.toString(type) + ": " : "") + + "Expected " + size + " values, but got " + index); expectCompositeEnd(buffer.currentToken()); return values; } diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index 0015b59e9a9..b6664464e0b 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -30,7 +30,7 @@ public class TensorModifyUpdate extends ValueUpdate { } private void verifyCompatibleType(TensorType type) { - if (type.rank() > 0 && type.dimensions().stream().noneMatch(dim -> dim.isMapped()) ) { + if (type.dimensions().stream().anyMatch(dim -> dim.isIndexed()) ) { throw new IllegalArgumentException("Tensor type '" + type + "' is not compatible as it has no mapped dimensions"); } } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 5867ca5596c..54ae3d6d373 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1498,16 +1498,26 @@ public class JsonReaderTestCase { @Test public void tensor_modify_update_with_replace_operation_mixed_block_short_form_array() { - assertTensorModifyUpdate("tensor(x{},y[3]):{a:[1,2,3]}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", + assertTensorModifyUpdate("{{x:a,y:0}:1,{x:a,y:1}:2,{x:a,y:2}:3}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", inputJson("{", " 'operation': 'replace',", " 'blocks': [", " { 'address': { 'x': 'a' }, 'values': [1,2,3] } ]}")); } + @Test + public void tensor_modify_update_with_replace_operation_mixed_block_short_form_must_specify_full_subspace() { + illegalTensorModifyUpdate("Error in 'mixed_tensor': At {x:a}: Expected 3 values, but got 2", + "mixed_tensor", + inputJson("{", + " 'operation': 'replace',", + " 'blocks': {", + " 'a': [2,3] } }")); + } + @Test public void tensor_modify_update_with_replace_operation_mixed_block_short_form_map() { - assertTensorModifyUpdate("tensor(x{},y[3]):{a:[1,2,3]}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", + assertTensorModifyUpdate("{{x:a,y:0}:1,{x:a,y:1}:2,{x:a,y:2}:3}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", inputJson("{", " 'operation': 'replace',", " 'blocks': {", -- cgit v1.2.3