diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2020-01-13 11:19:24 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2020-01-13 11:19:24 +0100 |
commit | a3ec97de26572acdbfe4b1801744061decb84d38 (patch) | |
tree | c11e4a00f7595da100ea6a03f5eebc9771164800 /document | |
parent | 4976b922193b1071db4711328caf31bc54e1a0d1 (diff) |
Support modify of mixed tensors
Diffstat (limited to 'document')
7 files changed, 91 insertions, 30 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java index 6310fa62d15..e98a262b661 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java @@ -13,7 +13,7 @@ import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStar import static com.yahoo.document.json.readers.TensorReader.fillTensor; /** - * Class used to read an add update for a tensor field. + * Reader of an "add" update of a tensor field. */ public class TensorAddUpdateReader { @@ -38,8 +38,8 @@ public class TensorAddUpdateReader { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { throw new IllegalArgumentException("An add update can only be applied to tensors " + - "with at least one sparse dimension. Field '" + field.getName() + - "' has unsupported tensor type '" + tensorType + "'"); + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index 66588debbca..5fd1c7bbab7 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -14,11 +14,13 @@ import com.yahoo.tensor.TensorType; import java.util.Iterator; import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart; +import static com.yahoo.document.json.readers.TensorReader.TENSOR_BLOCKS; import static com.yahoo.document.json.readers.TensorReader.TENSOR_CELLS; +import static com.yahoo.document.json.readers.TensorReader.readTensorBlocks; import static com.yahoo.document.json.readers.TensorReader.readTensorCells; /** - * Class used to read a modify update for a tensor field. + * Reader of a "modify" update of a tensor field. */ public class TensorModifyUpdateReader { @@ -30,7 +32,7 @@ public class TensorModifyUpdateReader { public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) { expectFieldIsOfTypeTensor(field); - expectTensorTypeHasNoneIndexedUnboundDimensions(field); + expectTensorTypeHasNoIndexedUnboundDimensions(field); expectObjectStart(buffer.currentToken()); ModifyUpdateResult result = createModifyUpdateResult(buffer, field); @@ -41,18 +43,19 @@ public class TensorModifyUpdateReader { } private static void expectFieldIsOfTypeTensor(Field field) { - if (!(field.getDataType() instanceof TensorDataType)) { + if ( ! (field.getDataType() instanceof TensorDataType)) { throw new IllegalArgumentException("A modify update can only be applied to tensor fields. " + - "Field '" + field.getName() + "' is of type '" + field.getDataType().getName() + "'"); + "Field '" + field.getName() + "' is of type '" + + field.getDataType().getName() + "'"); } } - private static void expectTensorTypeHasNoneIndexedUnboundDimensions(Field field) { + private static void expectTensorTypeHasNoIndexedUnboundDimensions(Field field) { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); if (tensorType.dimensions().stream() .anyMatch(dim -> dim.type().equals(TensorType.Dimension.Type.indexedUnbound))) { - throw new IllegalArgumentException("A modify update cannot be applied to tensor types with indexed unbound dimensions. " - + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + throw new IllegalArgumentException("A modify update cannot be applied to tensor types with indexed unbound dimensions. " + + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); } } @@ -83,7 +86,10 @@ public class TensorModifyUpdateReader { result.operation = createOperation(buffer, field.getName()); break; case TENSOR_CELLS: - result.tensor = createTensor(buffer, field); + result.tensor = createTensorFromCells(buffer, field); + break; + case TENSOR_BLOCKS: + result.tensor = createTensorFromBlocks(buffer, field); break; default: throw new IllegalArgumentException("Unknown JSON string '" + buffer.currentName() + "' in modify update for field '" + field.getName() + "'"); @@ -106,7 +112,7 @@ public class TensorModifyUpdateReader { } } - private static TensorFieldValue createTensor(TokenBuffer buffer, Field field) { + private static TensorFieldValue createTensorFromCells(TokenBuffer buffer, Field field) { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType originalType = tensorDataType.getTensorType(); TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(originalType); @@ -120,6 +126,19 @@ public class TensorModifyUpdateReader { return new TensorFieldValue(tensor); } + private static TensorFieldValue createTensorFromBlocks(TokenBuffer buffer, Field field) { + TensorDataType tensorDataType = (TensorDataType)field.getDataType(); + TensorType type = tensorDataType.getTensorType(); + + Tensor.Builder tensorBuilder = Tensor.Builder.of(type); + readTensorBlocks(buffer, tensorBuilder); + Tensor tensor = tensorBuilder.build(); + + validateBounds(tensor, type); + + return new TensorFieldValue(tensor); + } + /** Only validate if original type has indexed bound dimensions */ static void validateBounds(Tensor convertedTensor, TensorType originalType) { if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { @@ -135,7 +154,7 @@ public class TensorModifyUpdateReader { long bound = dim.size().get(); // size is non-optional for indexed bound if (label >= bound) { throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + - "' has label '" + label + "' but type is " + originalType.toString()); + "' has label '" + label + "' but type is " + originalType.toString()); } } } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index 5516e9523a1..e5699d0e6b1 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -98,7 +98,7 @@ public class TensorReader { expectCompositeEnd(buffer.currentToken()); } - private static void readTensorBlocks(TokenBuffer buffer, Tensor.Builder builder) { + static void readTensorBlocks(TokenBuffer buffer, Tensor.Builder builder) { if ( ! (builder instanceof MixedTensor.BoundBuilder)) throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " + "Use 'cells' or 'values' instead"); diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java index 3bb4b2e262f..91c275b6da0 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java @@ -16,7 +16,7 @@ import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectEnd; import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart; /** - * Class used to read a remove update for a tensor field. + * Reader of a "remove" update of a tensor field. */ public class TensorRemoveUpdateReader { @@ -39,14 +39,15 @@ public class TensorRemoveUpdateReader { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { throw new IllegalArgumentException("A remove update can only be applied to tensors " + - "with at least one sparse dimension. Field '" + field.getName() + - "' has unsupported tensor type '" + tensorType + "'"); + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } private static void expectAddressesAreNonEmpty(Field field, Tensor tensor) { if (tensor.isEmpty()) { - throw new IllegalArgumentException("Remove update for field '" + field.getName() + "' does not contain tensor addresses"); + throw new IllegalArgumentException("Remove update for field '" + field.getName() + + "' does not contain tensor addresses"); } } @@ -77,8 +78,9 @@ public class TensorRemoveUpdateReader { int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { String dimension = buffer.currentName(); - if ( ! type.dimension(dimension).isPresent() && originalType.dimension(dimension).isPresent()) { - throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update"); + if ( type.dimension(dimension).isEmpty() && originalType.dimension(dimension).isPresent()) { + throw new IllegalArgumentException("Indexed dimension address '" + dimension + + "' should not be specified in remove update"); } String label = buffer.currentText(); builder.add(dimension, label); diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index 435c8fcdc65..cc59ff65f1f 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -18,6 +18,7 @@ import java.util.function.DoubleBinaryOperator; * The cells to update are contained in a sparse tensor (has only mapped dimensions). */ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { + protected Operation operation; protected TensorFieldValue tensor; @@ -29,8 +30,8 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { } private void verifyCompatibleType(TensorType type) { - if (type.dimensions().stream().anyMatch(dim -> dim.isIndexed()) ) { - throw new IllegalArgumentException("Tensor type '" + type + "' is not compatible as it contains some indexed dimensions"); + if (type.rank() > 0 && type.dimensions().stream().noneMatch(dim -> dim.isMapped()) ) { + throw new IllegalArgumentException("Tensor type '" + type + "' is not compatible as it has no mapped dimensions"); } } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 511ad081c8c..5867ca5596c 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1378,13 +1378,13 @@ public class JsonReaderTestCase { @Test public void testAssignUpdateOfEmptySparseTensor() { - assertTensorAssignUpdate("tensor(x{},y{}):{}", createAssignUpdateWithSparseTensor("{}")); + assertTensorAssignUpdateSparseField("tensor(x{},y{}):{}", createAssignUpdateWithSparseTensor("{}")); } @Test public void testAssignUpdateOfEmptyDenseTensor() { try { - assertTensorAssignUpdate("tensor(x{},y{}):{}", createAssignUpdateWithTensor("{}", "dense_unbound_tensor")); + assertTensorAssignUpdateSparseField("tensor(x{},y{}):{}", createAssignUpdateWithTensor("{}", "dense_unbound_tensor")); } catch (IllegalArgumentException e) { assertEquals("An indexed tensor must have a value", @@ -1402,8 +1402,8 @@ public class JsonReaderTestCase { @Test public void testAssignUpdateOfTensorWithCells() { - assertTensorAssignUpdate("{{x:a,y:b}:2.0,{x:c,y:b}:3.0}}", - createAssignUpdateWithSparseTensor(inputJson("{", + assertTensorAssignUpdateSparseField("{{x:a,y:b}:2.0,{x:c,y:b}:3.0}}", + createAssignUpdateWithSparseTensor(inputJson("{", " 'cells': [", " { 'address': { 'x': 'a', 'y': 'b' },", " 'value': 2.0 },", @@ -1414,6 +1414,15 @@ public class JsonReaderTestCase { } @Test + public void testAssignUpdateOfTensorDenseShortForm() { + assertTensorAssignUpdateDenseField("tensor(x[2],y[3]):[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]", + createAssignUpdateWithTensor(inputJson("{", + " 'values': [1,2,3,4,5,6]", + "}"), + "dense_tensor")); + } + + @Test public void tensor_modify_update_with_replace_operation() { assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.REPLACE, "sparse_tensor", inputJson("{", @@ -1488,6 +1497,24 @@ public class JsonReaderTestCase { } @Test + public void tensor_modify_update_with_replace_operation_mixed_block_short_form_array() { + assertTensorModifyUpdate("tensor(x{},y[3]):{a:[1,2,3]}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", + inputJson("{", + " 'operation': 'replace',", + " 'blocks': [", + " { 'address': { 'x': 'a' }, 'values': [1,2,3] } ]}")); + } + + @Test + public void tensor_modify_update_with_replace_operation_mixed_block_short_form_map() { + assertTensorModifyUpdate("tensor(x{},y[3]):{a:[1,2,3]}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", + inputJson("{", + " 'operation': 'replace',", + " 'blocks': {", + " 'a': [1,2,3] } }")); + } + + @Test public void tensor_modify_update_with_add_operation_mixed() { assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.ADD, "mixed_tensor", inputJson("{", @@ -1830,10 +1857,18 @@ public class JsonReaderTestCase { assertEquals(1, update.getFieldUpdate(tensorFieldName).size()); } - private static void assertTensorAssignUpdate(String expectedTensor, DocumentUpdate update) { + private static void assertTensorAssignUpdateSparseField(String expectedTensor, DocumentUpdate update) { assertEquals("testtensor", update.getId().getDocType()); assertEquals(TENSOR_DOC_ID, update.getId().toString()); - AssignValueUpdate assignUpdate = (AssignValueUpdate) getTensorField(update).getValueUpdate(0); + AssignValueUpdate assignUpdate = (AssignValueUpdate) getTensorField(update, "sparse_tensor").getValueUpdate(0); + TensorFieldValue fieldValue = (TensorFieldValue) assignUpdate.getValue(); + assertEquals(Tensor.from(expectedTensor), fieldValue.getTensor().get()); + } + + private static void assertTensorAssignUpdateDenseField(String expectedTensor, DocumentUpdate update) { + assertEquals("testtensor", update.getId().getDocType()); + assertEquals(TENSOR_DOC_ID, update.getId().toString()); + AssignValueUpdate assignUpdate = (AssignValueUpdate) getTensorField(update, "dense_tensor").getValueUpdate(0); TensorFieldValue fieldValue = (TensorFieldValue) assignUpdate.getValue(); assertEquals(Tensor.from(expectedTensor), fieldValue.getTensor().get()); } @@ -1895,7 +1930,11 @@ public class JsonReaderTestCase { } private static FieldUpdate getTensorField(DocumentUpdate update) { - FieldUpdate fieldUpdate = update.getFieldUpdate("sparse_tensor"); + return getTensorField(update, "sparse_tensor"); + } + + private static FieldUpdate getTensorField(DocumentUpdate update, String fieldName) { + FieldUpdate fieldUpdate = update.getFieldUpdate(fieldName); assertEquals(1, fieldUpdate.size()); return fieldUpdate; } diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java index b885e6ddca0..4c8d2e69855 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java @@ -32,7 +32,7 @@ public class TensorModifyUpdateTest { @Test public void use_of_incompatible_tensor_type_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("Tensor type 'tensor(x[3])' is not compatible as it contains some indexed dimensions"); + exception.expectMessage("Tensor type 'tensor(x[3])' is not compatible as it has no mapped dimensions"); new TensorModifyUpdate(TensorModifyUpdate.Operation.REPLACE, new TensorFieldValue(Tensor.from("tensor(x[3])", "{{x:1}:3}"))); } |