diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-02-22 09:10:34 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-02-22 09:10:34 +0100 |
commit | 21a0951fb8906d60a6c2f565f6aac40087e986fe (patch) | |
tree | 30a5a80fa5a5d8046c12117be222ad91cbe10b10 /document | |
parent | f3e121c715d2cb60102b88494a4daccf1ec2ebc4 (diff) | |
parent | 21651a8420530f069d42f37ca4dd0381f043501a (diff) |
Merge pull request #8558 from vespa-engine/lesters/tensor-partial-update-mixed-tensors-java
Tensor partial update for mixed tensors - Java
Diffstat (limited to 'document')
13 files changed, 224 insertions, 120 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json index 61390af3523..d4db3026b27 100644 --- a/document/abi-spec.json +++ b/document/abi-spec.json @@ -5244,7 +5244,7 @@ ], "methods": [ "public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue)", - "public static com.yahoo.tensor.TensorType convertToCompatibleType(com.yahoo.tensor.TensorType)", + "public static com.yahoo.tensor.TensorType convertDimensionsToMapped(com.yahoo.tensor.TensorType)", "public com.yahoo.document.update.TensorModifyUpdate$Operation getOperation()", "public com.yahoo.document.datatypes.TensorFieldValue getValue()", "public void setValue(com.yahoo.document.datatypes.TensorFieldValue)", @@ -5278,6 +5278,7 @@ "public boolean equals(java.lang.Object)", "public int hashCode()", "public java.lang.String toString()", + "public static com.yahoo.tensor.TensorType extractSparseDimensions(com.yahoo.tensor.TensorType)", "public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)", "public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()" ], diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java index ffbfe49347c..6310fa62d15 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java @@ -23,22 +23,23 @@ public class TensorAddUpdateReader { public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) { expectObjectStart(buffer.currentToken()); - expectTensorTypeIsSparse(field); + expectTensorTypeHasSparseDimensions(field); TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType tensorType = tensorDataType.getTensorType(); TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType); fillTensor(buffer, tensorFieldValue); + expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get()); return new TensorAddUpdate(tensorFieldValue); } - private static void expectTensorTypeIsSparse(Field field) { + private static void expectTensorTypeHasSparseDimensions(Field field) { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - if (tensorType.dimensions().stream() - .anyMatch(dim -> dim.isIndexed())) { - throw new IllegalArgumentException("An add update can only be applied to sparse tensors. " - + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { + throw new IllegalArgumentException("An add update can only be applied to tensors " + + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } @@ -48,5 +49,4 @@ public class TensorAddUpdateReader { } } - } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index a9bbba519bd..66588debbca 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -29,10 +29,8 @@ public class TensorModifyUpdateReader { private static final String MODIFY_MULTIPLY = "multiply"; public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) { - expectFieldIsOfTypeTensor(field); expectTensorTypeHasNoneIndexedUnboundDimensions(field); - expectTensorTypeIsNotMixed(field); expectObjectStart(buffer.currentToken()); ModifyUpdateResult result = createModifyUpdateResult(buffer, field); @@ -58,16 +56,6 @@ public class TensorModifyUpdateReader { } } - private static void expectTensorTypeIsNotMixed(Field field) { - TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - long numMappedDimensions = tensorType.dimensions().stream().filter(dim -> dim.type().equals(TensorType.Dimension.Type.mapped)).count(); - long numIndexedDimensions = tensorType.dimensions().stream().filter(dim -> dim.isIndexed()).count(); - if (numMappedDimensions > 0 && numIndexedDimensions > 0) { - throw new IllegalArgumentException("A modify update cannot be applied to tensor types with mixed dimensions. " - + "Field '" + field.getName() + "' has mixed tensor type '" + tensorType + "'"); - } - } - private static void expectOperationSpecified(TensorModifyUpdate.Operation operation, String fieldName) { if (operation == null) { throw new IllegalArgumentException("Modify update for field '" + fieldName + "' does not contain an operation"); @@ -121,7 +109,7 @@ public class TensorModifyUpdateReader { private static TensorFieldValue createTensor(TokenBuffer buffer, Field field) { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType originalType = tensorDataType.getTensorType(); - TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType); + TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(originalType); Tensor.Builder tensorBuilder = Tensor.Builder.of(convertedType); readTensorCells(buffer, tensorBuilder); @@ -129,25 +117,26 @@ public class TensorModifyUpdateReader { validateBounds(tensor, originalType); - TensorFieldValue result = new TensorFieldValue(convertedType); - result.assign(tensor); - return result; + return new TensorFieldValue(tensor); } - /** Only validate if original type is indexed bound */ - private static void validateBounds(Tensor convertedTensor, TensorType originalType) { - if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { + /** Only validate if original type has indexed bound dimensions */ + static void validateBounds(Tensor convertedTensor, TensorType originalType) { + if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { return; } for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) { Tensor.Cell cell = iter.next(); TensorAddress address = cell.getKey(); for (int i = 0; i < address.size(); ++i) { - long label = address.numericLabel(i); - long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound - if (label >= bound) { - throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + - "' has label '" + label + "' but type is " + originalType.toString()); + TensorType.Dimension dim = originalType.dimensions().get(i); + if (dim instanceof TensorType.IndexedBoundDimension) { + long label = address.numericLabel(i); + long bound = dim.size().get(); // size is non-optional for indexed bound + if (label >= bound) { + throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + + "' has label '" + label + "' but type is " + originalType.toString()); + } } } } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java index 210a6a80ee5..3bb4b2e262f 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java @@ -24,23 +24,23 @@ public class TensorRemoveUpdateReader { static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) { expectObjectStart(buffer.currentToken()); - expectTensorTypeIsSparse(field); + expectTensorTypeHasSparseDimensions(field); TensorDataType tensorDataType = (TensorDataType)field.getDataType(); - TensorType tensorType = tensorDataType.getTensorType(); + TensorType originalType = tensorDataType.getTensorType(); + TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType); + Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType); - // TODO: for mixed case extract a new tensor type based only on mapped dimensions - - Tensor tensor = readRemoveUpdateTensor(buffer, tensorType); expectAddressesAreNonEmpty(field, tensor); return new TensorRemoveUpdate(new TensorFieldValue(tensor)); } - private static void expectTensorTypeIsSparse(Field field) { + private static void expectTensorTypeHasSparseDimensions(Field field) { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - if (tensorType.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed)) { - throw new IllegalArgumentException("A remove update can only be applied to sparse tensors. " - + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { + throw new IllegalArgumentException("A remove update can only be applied to tensors " + + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } @@ -53,7 +53,7 @@ public class TensorRemoveUpdateReader { /** * Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0 */ - private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type) { + private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) { Tensor.Builder builder = Tensor.Builder.of(type); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); @@ -62,7 +62,7 @@ public class TensorRemoveUpdateReader { expectArrayStart(buffer.currentToken()); int nesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) { - builder.cell(readTensorAddress(buffer, type), 1.0); + builder.cell(readTensorAddress(buffer, type, originalType), 1.0); } expectCompositeEnd(buffer.currentToken()); } @@ -71,12 +71,15 @@ public class TensorRemoveUpdateReader { return builder.build(); } - private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type) { + private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) { TensorAddress.Builder builder = new TensorAddress.Builder(type); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { String dimension = buffer.currentName(); + if ( ! type.dimension(dimension).isPresent() && originalType.dimension(dimension).isPresent()) { + throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update"); + } String label = buffer.currentText(); builder.add(dimension, label); } diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java index 2f22def9aa1..a763db33e7a 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java @@ -5,6 +5,7 @@ import com.yahoo.document.DataType; import com.yahoo.document.DocumentTypeManager; import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.document.json.readers.TensorRemoveUpdateReader; import com.yahoo.document.update.TensorAddUpdate; import com.yahoo.document.update.TensorModifyUpdate; import com.yahoo.document.update.TensorRemoveUpdate; @@ -35,7 +36,10 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { throw new DeserializationException("Expected tensor data type, got " + type); } TensorDataType tensorDataType = (TensorDataType)type; - TensorFieldValue tensor = new TensorFieldValue(TensorModifyUpdate.convertToCompatibleType(tensorDataType.getTensorType())); + TensorType tensorType = tensorDataType.getTensorType(); + TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType); + + TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); return new TensorModifyUpdate(operation, tensor); } @@ -46,7 +50,8 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { throw new DeserializationException("Expected tensor data type, got " + type); } TensorDataType tensorDataType = (TensorDataType)type; - TensorFieldValue tensor = new TensorFieldValue(tensorDataType.getTensorType()); + TensorType tensorType = tensorDataType.getTensorType(); + TensorFieldValue tensor = new TensorFieldValue(tensorType); tensor.deserialize(this); return new TensorAddUpdate(tensor); } @@ -58,10 +63,9 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { } TensorDataType tensorDataType = (TensorDataType)type; TensorType tensorType = tensorDataType.getTensorType(); + TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType); - // TODO: for mixed case extract a new tensor type based only on mapped dimensions - - TensorFieldValue tensor = new TensorFieldValue(tensorType); + TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); return new TensorRemoveUpdate(tensor); } diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java index cfc3ee0c742..f8d2464deb7 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java @@ -7,15 +7,11 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import java.util.Map; import java.util.Objects; /** - * An update used to add cells to a sparse tensor (has only mapped dimensions). - * - * The cells to add are contained in a sparse tensor as well. + * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension). */ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { @@ -50,22 +46,10 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } - Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); - Map<TensorAddress, Double> oldCells = oldTensor.cells(); - Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells(); - - // currently, underlying implementation disallows multiple entries with the same key - - Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); - for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) { - builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue())); - } - for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { - if ( ! oldCells.containsKey(addCell.getKey())) { - builder.cell(addCell.getKey(), addCell.getValue()); - } - } - return new TensorFieldValue(builder.build()); + Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); + Tensor update = tensor.getTensor().get(); + Tensor result = old.merge((left, right) -> right, update.cells()); // note this might be slow for large mixed tensor updates + return new TensorFieldValue(result); } @Override diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index 6111b51ca4e..2773f9d31da 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -37,7 +37,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { /** * Converts the given tensor type to a type that is compatible for being used in this update (has only mapped dimensions). */ - public static TensorType convertToCompatibleType(TensorType type) { + public static TensorType convertDimensionsToMapped(TensorType type) { TensorType.Builder builder = new TensorType.Builder(); type.dimensions().stream().forEach(dim -> builder.mapped(dim.name())); return builder.build(); diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java index e9fb1e3efd5..335cda8e133 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java @@ -7,10 +7,8 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; -import java.util.Iterator; -import java.util.Map; import java.util.Objects; /** @@ -25,6 +23,18 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { public TensorRemoveUpdate(TensorFieldValue value) { super(ValueUpdateClassID.TENSORREMOVE); this.tensor = value; + verifyCompatibleType(); + } + + private void verifyCompatibleType() { + if ( ! tensor.getTensor().isPresent()) { + throw new IllegalArgumentException("Tensor must be present in remove update"); + } + TensorType tensorType = tensor.getTensor().get().type(); + TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType()); + if ( ! tensorType.equals(expectedType)) { + throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'"); + } } @Override @@ -51,17 +61,10 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } - Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); - Map<TensorAddress, Double> cellsToRemove = tensor.getTensor().get().cells(); - Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); - for (Iterator<Tensor.Cell> i = oldTensor.cellIterator(); i.hasNext(); ) { - Tensor.Cell cell = i.next(); - TensorAddress address = cell.getKey(); - if ( ! cellsToRemove.containsKey(address)) { - builder.cell(address, cell.getValue()); - } - } - return new TensorFieldValue(builder.build()); + Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); + Tensor update = tensor.getTensor().get(); + Tensor result = old.remove(update.cells().keySet()); + return new TensorFieldValue(result); } @Override @@ -93,4 +96,11 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { return super.toString() + " " + tensor; } + public static TensorType extractSparseDimensions(TensorType type) { + TensorType.Builder builder = new TensorType.Builder(); + type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name())); + return builder.build(); + } + + } diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java index e2736dabd2b..454ad72f344 100644 --- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java +++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java @@ -40,6 +40,7 @@ public class DocumentUpdateJsonSerializerTest { final static TensorType sparseTensorType = new TensorType.Builder().mapped("x").mapped("y").build(); final static TensorType denseTensorType = new TensorType.Builder().indexed("x", 2).indexed("y", 3).build(); + final static TensorType mixedTensorType = new TensorType.Builder().mapped("x").indexed("y", 3).build(); final static DocumentTypeManager types = new DocumentTypeManager(); final static JsonFactory parserFactory = new JsonFactory(); final static DocumentType docType = new DocumentType("doctype"); @@ -60,6 +61,7 @@ public class DocumentUpdateJsonSerializerTest { docType.addField(new Field("byte_field", DataType.BYTE)); docType.addField(new Field("sparse_tensor", new TensorDataType(sparseTensorType))); docType.addField(new Field("dense_tensor", new TensorDataType(denseTensorType))); + docType.addField(new Field("mixed_tensor", new TensorDataType(mixedTensorType))); docType.addField(new Field("reference_field", new ReferenceDataType(refTargetDocType, 777))); docType.addField(new Field("predicate_field", DataType.PREDICATE)); docType.addField(new Field("raw_field", DataType.RAW)); @@ -336,6 +338,26 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_modify_update_on_mixed_tensor() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'modify': {", + " 'operation': 'multiply',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': 'c', 'y': '1' }, 'value': 3.0 }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void test_tensor_add_update() { roundtripSerializeJsonAndMatch(inputJson( "{", @@ -355,6 +377,29 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_add_update_mixed() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'add': {", + " 'cells': [", + " { 'address': { 'x': '1', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': '1', 'y': '1' }, 'value': 0.0 },", + " { 'address': { 'x': '1', 'y': '2' }, 'value': 0.0 },", + " { 'address': { 'x': '0', 'y': '0' }, 'value': 0.0 },", + " { 'address': { 'x': '0', 'y': '1' }, 'value': 0.0 },", + " { 'address': { 'x': '0', 'y': '2' }, 'value': 3.0 }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void test_tensor_remove_update() { roundtripSerializeJsonAndMatch(inputJson( "{", @@ -374,6 +419,24 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_remove_update_mixed() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'remove': {", + " 'addresses': [", + " {'x':'0' }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void reference_field_id_can_be_update_assigned_non_empty_id() { roundtripSerializeJsonAndMatch(inputJson( "{", diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index e58b26d500d..15d1e859f73 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1387,12 +1387,30 @@ public class JsonReaderTestCase { } @Test - public void tensor_modify_update_on_mixed_tensor_throws() { - exception.expect(IllegalArgumentException.class); - exception.expectMessage("A modify update cannot be applied to tensor types with mixed dimensions. Field 'mixed_tensor' has mixed tensor type 'tensor(x{},y[3])'"); - createTensorModifyUpdate(inputJson("{", - " 'operation': 'replace',", - " 'cells': [] }"), "mixed_tensor"); + public void tensor_modify_update_with_replace_operation_mixed() { + assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", + inputJson("{", + " 'operation': 'replace',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); + } + + @Test + public void tensor_modify_update_with_add_operation_mixed() { + assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.ADD, "mixed_tensor", + inputJson("{", + " 'operation': 'add',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); + } + + @Test + public void tensor_modify_update_with_multiply_operation_mixed() { + assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.MULTIPLY, "mixed_tensor", + inputJson("{", + " 'operation': 'multiply',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); } @Test @@ -1406,6 +1424,17 @@ public class JsonReaderTestCase { } @Test + public void tensor_modify_update_with_out_of_bound_cells_throws_mixed() { + exception.expect(IndexOutOfBoundsException.class); + exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x{},y[3])"); + createTensorModifyUpdate(inputJson("{", + " 'operation': 'replace',", + " 'cells': [", + " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor"); + } + + + @Test public void tensor_modify_update_with_unknown_operation_throws() { exception.expect(IllegalArgumentException.class); exception.expectMessage("Unknown operation 'unknown' in modify update for field 'sparse_tensor'"); @@ -1449,11 +1478,29 @@ public class JsonReaderTestCase { } @Test - public void tensor_add_update_on_non_sparse_tensor_throws() { + public void tensor_add_update_on_mixed_tensor() { + assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0, {x:a,y:2}:0.0}", "mixed_tensor", + inputJson("{", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': 'a', 'y': '1' }, 'value': 3.0 } ]}")); + } + + @Test + public void tensor_add_update_on_mixed_with_out_of_bound_dense_cells_throws() { + exception.expect(IndexOutOfBoundsException.class); + exception.expectMessage("Index 3 out of bounds for length 3"); + createTensorAddUpdate(inputJson("{", + " 'cells': [", + " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor"); + } + + @Test + public void tensor_add_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + exception.expectMessage("An add update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorAddUpdate(inputJson("{", - " 'cells': [] }"), "mixed_tensor"); + " 'cells': [] }"), "dense_tensor"); } @Test @@ -1470,6 +1517,7 @@ public class JsonReaderTestCase { exception.expect(IllegalArgumentException.class); exception.expectMessage("Add update for field 'sparse_tensor' does not contain tensor cells"); createTensorAddUpdate(inputJson("{}"), "sparse_tensor"); + createTensorAddUpdate(inputJson("{}"), "mixed_tensor"); } @Test @@ -1482,11 +1530,30 @@ public class JsonReaderTestCase { } @Test - public void tensor_remove_update_on_non_sparse_tensor_throws() { + public void tensor_remove_update_on_mixed_tensor() { + assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor", + inputJson("{", + " 'addresses': [", + " { 'x': '1' },", + " { 'x': '2' } ]}")); + } + + @Test + public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Indexed dimension address 'y' should not be specified in remove update"); + createTensorRemoveUpdate(inputJson("{", + " 'addresses': [", + " { 'x': '1', 'y': '0' },", + " { 'x': '2', 'y': '0' } ]}"), "mixed_tensor"); + } + + @Test + public void tensor_remove_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + exception.expectMessage("A remove update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorRemoveUpdate(inputJson("{", - " 'addresses': [] }"), "mixed_tensor"); + " 'addresses': [] }"), "dense_tensor"); } @Test @@ -1503,6 +1570,7 @@ public class JsonReaderTestCase { exception.expect(IllegalArgumentException.class); exception.expectMessage("Remove update for field 'sparse_tensor' does not contain tensor addresses"); createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "sparse_tensor"); + createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "mixed_tensor"); } @Test diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java index eb4001e6415..6935c54ba2a 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java @@ -12,18 +12,14 @@ public class TensorAddUpdateTest { @Test public void apply_add_update_operations() { assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); - assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}"); } private void assertApplyTo(String init, String update, String expected) { String spec = "tensor(x{},y{})"; TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update))); - TensorFieldValue updatedFieldValue = (TensorFieldValue) addUpdate.applyTo(initialFieldValue); - assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get()); + Tensor updated = ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get(); + assertEquals(Tensor.from(spec, expected), updated); } } diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java index 6e9444de2be..b885e6ddca0 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java @@ -1,12 +1,6 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.document.update; -import com.yahoo.document.Document; -import com.yahoo.document.DocumentId; -import com.yahoo.document.DocumentType; -import com.yahoo.document.DocumentTypeManager; -import com.yahoo.document.Field; -import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.update.TensorModifyUpdate.Operation; import com.yahoo.tensor.Tensor; @@ -28,10 +22,11 @@ public class TensorModifyUpdateTest { assertConvertToCompatible("tensor(x{})", "tensor(x[10])"); assertConvertToCompatible("tensor(x{})", "tensor(x{})"); assertConvertToCompatible("tensor(x{},y{},z{})", "tensor(x[],y[10],z{})"); + assertConvertToCompatible("tensor(x{},y{})", "tensor(x{},y[3])"); } private static void assertConvertToCompatible(String expectedType, String inputType) { - assertEquals(expectedType, TensorModifyUpdate.convertToCompatibleType(TensorType.fromSpec(inputType)).toString()); + assertEquals(expectedType, TensorModifyUpdate.convertDimensionsToMapped(TensorType.fromSpec(inputType)).toString()); } @Test @@ -46,15 +41,9 @@ public class TensorModifyUpdateTest { public void apply_modify_update_operations() { assertApplyTo("tensor(x{},y{})", Operation.REPLACE, "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); - assertApplyTo("tensor(x{},y{})", Operation.ADD, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); - assertApplyTo("tensor(x{},y{})", Operation.MULTIPLY, - "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); - assertApplyTo("tensor(x[1],y[2])", Operation.REPLACE, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); assertApplyTo("tensor(x[1],y[2])", Operation.ADD, "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); - assertApplyTo("tensor(x[1],y[2])", Operation.MULTIPLY, + assertApplyTo("tensor(x{},y[2])", Operation.MULTIPLY, "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); } diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java index 40ab00facdb..3a005e858c8 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java @@ -12,9 +12,6 @@ public class TensorRemoveUpdateTest { @Test public void apply_remove_update_operations() { assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}"); - assertApplyTo("{}", "{{x:0,y:0}:1}", "{}"); - assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}"); } private void assertApplyTo(String init, String update, String expected) { |