diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2020-01-13 14:38:24 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2020-01-13 14:38:24 +0100 |
commit | fdcf0682eb4ed0471431adaf4a6be70628b9c84d (patch) | |
tree | 929006dbc7398704f1ee496c3e9df020ef23c21d | |
parent | 7fad0f3d7b5dcd171655d101c05cf51f758bfc83 (diff) |
Convert tensor update to sparse
7 files changed, 55 insertions, 17 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index 5fd1c7bbab7..b8937d8b739 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -132,20 +132,27 @@ public class TensorModifyUpdateReader { Tensor.Builder tensorBuilder = Tensor.Builder.of(type); readTensorBlocks(buffer, tensorBuilder); - Tensor tensor = tensorBuilder.build(); - + Tensor tensor = convertToSparse(tensorBuilder.build()); validateBounds(tensor, type); return new TensorFieldValue(tensor); } + private static Tensor convertToSparse(Tensor tensor) { + if (tensor.type().dimensions().stream().noneMatch(dimension -> dimension.isIndexed())) return tensor; + Tensor.Builder b = Tensor.Builder.of(TensorModifyUpdate.convertDimensionsToMapped(tensor.type())); + for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) + b.cell(i.next()); + return b.build(); + } + /** Only validate if original type has indexed bound dimensions */ static void validateBounds(Tensor convertedTensor, TensorType originalType) { if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { return; } - for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) { - Tensor.Cell cell = iter.next(); + for (Iterator<Tensor.Cell> cellIterator = convertedTensor.cellIterator(); cellIterator.hasNext(); ) { + Tensor.Cell cell = cellIterator.next(); TensorAddress address = cell.getKey(); for (int i = 0; i < address.size(); ++i) { TensorType.Dimension dim = originalType.dimensions().get(i); diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index e5699d0e6b1..769e31818e6 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -111,12 +111,15 @@ public class TensorReader { } else if (buffer.currentToken() == JsonToken.START_OBJECT) { int initNesting = buffer.nesting(); - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) - mixedBuilder.block(asAddress(buffer.currentName(), builder.type().mappedSubtype()), - readValues(buffer, (int)mixedBuilder.denseSubspaceSize())); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { + TensorAddress mappedAddress = asAddress(buffer.currentName(), builder.type().mappedSubtype()); + mixedBuilder.block(mappedAddress, + readValues(buffer, (int) mixedBuilder.denseSubspaceSize(), mappedAddress, mixedBuilder.type())); + } } else { - throw new IllegalArgumentException("Expected 'blocks' to contain an array or an object, but got " + buffer.currentToken()); + throw new IllegalArgumentException("Expected 'blocks' to contain an array or an object, but got " + + buffer.currentToken()); } expectCompositeEnd(buffer.currentToken()); @@ -134,7 +137,7 @@ public class TensorReader { if (TensorReader.TENSOR_ADDRESS.equals(currentName)) address = readAddress(buffer, mixedBuilder.type().mappedSubtype()); else if (TensorReader.TENSOR_VALUES.equals(currentName)) - values = readValues(buffer, (int)mixedBuilder.denseSubspaceSize()); + values = readValues(buffer, (int)mixedBuilder.denseSubspaceSize(), address, mixedBuilder.type()); } expectObjectEnd(buffer.currentToken()); if (address == null) @@ -154,7 +157,16 @@ public class TensorReader { return builder.build(); } - private static double[] readValues(TokenBuffer buffer, int size) { + /** + * Reads values for a tensor subspace block + * + * @param buffer the buffer containing the values + * @param size the expected number of values + * @param address the address for the block for error reporting, or null if not known + * @param type the type of the tensor we are reading + * @return the values read + */ + private static double[] readValues(TokenBuffer buffer, int size, TensorAddress address, TensorType type) { expectArrayStart(buffer.currentToken()); int index = 0; @@ -162,6 +174,9 @@ public class TensorReader { double[] values = new double[size]; for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) values[index++] = readDouble(buffer); + if (index != size) + throw new IllegalArgumentException((address != null ? "At " + address.toString(type) + ": " : "") + + "Expected " + size + " values, but got " + index); expectCompositeEnd(buffer.currentToken()); return values; } diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index 0015b59e9a9..b6664464e0b 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -30,7 +30,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { } private void verifyCompatibleType(TensorType type) { - if (type.rank() > 0 && type.dimensions().stream().noneMatch(dim -> dim.isMapped()) ) { + if (type.dimensions().stream().anyMatch(dim -> dim.isIndexed()) ) { throw new IllegalArgumentException("Tensor type '" + type + "' is not compatible as it has no mapped dimensions"); } } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 5867ca5596c..54ae3d6d373 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1498,7 +1498,7 @@ public class JsonReaderTestCase { @Test public void tensor_modify_update_with_replace_operation_mixed_block_short_form_array() { - assertTensorModifyUpdate("tensor(x{},y[3]):{a:[1,2,3]}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", + assertTensorModifyUpdate("{{x:a,y:0}:1,{x:a,y:1}:2,{x:a,y:2}:3}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", inputJson("{", " 'operation': 'replace',", " 'blocks': [", @@ -1506,8 +1506,18 @@ public class JsonReaderTestCase { } @Test + public void tensor_modify_update_with_replace_operation_mixed_block_short_form_must_specify_full_subspace() { + illegalTensorModifyUpdate("Error in 'mixed_tensor': At {x:a}: Expected 3 values, but got 2", + "mixed_tensor", + inputJson("{", + " 'operation': 'replace',", + " 'blocks': {", + " 'a': [2,3] } }")); + } + + @Test public void tensor_modify_update_with_replace_operation_mixed_block_short_form_map() { - assertTensorModifyUpdate("tensor(x{},y[3]):{a:[1,2,3]}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", + assertTensorModifyUpdate("{{x:a,y:0}:1,{x:a,y:1}:2,{x:a,y:2}:3}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", inputJson("{", " 'operation': 'replace',", " 'blocks': {", diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index f631b3e1c58..66eb4b1f4e6 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1107,6 +1107,7 @@ "public varargs abstract com.yahoo.tensor.Tensor$Builder cell(float, long[])", "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, double)", "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell, float)", + "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.Tensor$Cell)", "public abstract com.yahoo.tensor.Tensor build()" ], "fields": [] diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 67c6930ce35..2b393d8a637 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -253,8 +253,12 @@ public class MixedTensor implements Tensor { } public Tensor.Builder block(TensorAddress sparsePart, double[] values) { + int denseSubspaceSize = (int)denseSubspaceSize(); + if (values.length < denseSubspaceSize) + throw new IllegalArgumentException("Block should have " + denseSubspaceSize + + " values, but has only " + values.length); double[] denseSubspace = denseSubspace(sparsePart); - System.arraycopy(values, 0, denseSubspace, 0, (int)denseSubspaceSize()); + System.arraycopy(values, 0, denseSubspace, 0, denseSubspaceSize); return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 08d4f1c08b7..71bdee36c27 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -516,9 +516,10 @@ public interface Tensor { default Builder cell(Cell cell, double value) { return cell(cell.getKey(), value); } - default Builder cell(Cell cell, float value) { - return cell(cell.getKey(), value); - } + default Builder cell(Cell cell, float value) { return cell(cell.getKey(), value); } + + /** Adds the given cell to this tensor */ + default Builder cell(Cell cell) { return cell(cell.getKey(), cell.getValue()); } Tensor build(); |