diff options
44 files changed, 1128 insertions, 179 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json index 61390af3523..d4db3026b27 100644 --- a/document/abi-spec.json +++ b/document/abi-spec.json @@ -5244,7 +5244,7 @@ ], "methods": [ "public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue)", - "public static com.yahoo.tensor.TensorType convertToCompatibleType(com.yahoo.tensor.TensorType)", + "public static com.yahoo.tensor.TensorType convertDimensionsToMapped(com.yahoo.tensor.TensorType)", "public com.yahoo.document.update.TensorModifyUpdate$Operation getOperation()", "public com.yahoo.document.datatypes.TensorFieldValue getValue()", "public void setValue(com.yahoo.document.datatypes.TensorFieldValue)", @@ -5278,6 +5278,7 @@ "public boolean equals(java.lang.Object)", "public int hashCode()", "public java.lang.String toString()", + "public static com.yahoo.tensor.TensorType extractSparseDimensions(com.yahoo.tensor.TensorType)", "public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)", "public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()" ], diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java index ffbfe49347c..6310fa62d15 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java @@ -23,22 +23,23 @@ public class TensorAddUpdateReader { public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) { expectObjectStart(buffer.currentToken()); - expectTensorTypeIsSparse(field); + expectTensorTypeHasSparseDimensions(field); TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType tensorType = tensorDataType.getTensorType(); TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType); fillTensor(buffer, tensorFieldValue); + expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get()); return new TensorAddUpdate(tensorFieldValue); } - private static void expectTensorTypeIsSparse(Field field) { + private static void expectTensorTypeHasSparseDimensions(Field field) { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - if (tensorType.dimensions().stream() - .anyMatch(dim -> dim.isIndexed())) { - throw new IllegalArgumentException("An add update can only be applied to sparse tensors. " - + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { + throw new IllegalArgumentException("An add update can only be applied to tensors " + + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } @@ -48,5 +49,4 @@ public class TensorAddUpdateReader { } } - } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index a9bbba519bd..66588debbca 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -29,10 +29,8 @@ public class TensorModifyUpdateReader { private static final String MODIFY_MULTIPLY = "multiply"; public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) { - expectFieldIsOfTypeTensor(field); expectTensorTypeHasNoneIndexedUnboundDimensions(field); - expectTensorTypeIsNotMixed(field); expectObjectStart(buffer.currentToken()); ModifyUpdateResult result = createModifyUpdateResult(buffer, field); @@ -58,16 +56,6 @@ public class TensorModifyUpdateReader { } } - private static void expectTensorTypeIsNotMixed(Field field) { - TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - long numMappedDimensions = tensorType.dimensions().stream().filter(dim -> dim.type().equals(TensorType.Dimension.Type.mapped)).count(); - long numIndexedDimensions = tensorType.dimensions().stream().filter(dim -> dim.isIndexed()).count(); - if (numMappedDimensions > 0 && numIndexedDimensions > 0) { - throw new IllegalArgumentException("A modify update cannot be applied to tensor types with mixed dimensions. " - + "Field '" + field.getName() + "' has mixed tensor type '" + tensorType + "'"); - } - } - private static void expectOperationSpecified(TensorModifyUpdate.Operation operation, String fieldName) { if (operation == null) { throw new IllegalArgumentException("Modify update for field '" + fieldName + "' does not contain an operation"); @@ -121,7 +109,7 @@ public class TensorModifyUpdateReader { private static TensorFieldValue createTensor(TokenBuffer buffer, Field field) { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType originalType = tensorDataType.getTensorType(); - TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType); + TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(originalType); Tensor.Builder tensorBuilder = Tensor.Builder.of(convertedType); readTensorCells(buffer, tensorBuilder); @@ -129,25 +117,26 @@ public class TensorModifyUpdateReader { validateBounds(tensor, originalType); - TensorFieldValue result = new TensorFieldValue(convertedType); - result.assign(tensor); - return result; + return new TensorFieldValue(tensor); } - /** Only validate if original type is indexed bound */ - private static void validateBounds(Tensor convertedTensor, TensorType originalType) { - if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { + /** Only validate if original type has indexed bound dimensions */ + static void validateBounds(Tensor convertedTensor, TensorType originalType) { + if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) { return; } for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) { Tensor.Cell cell = iter.next(); TensorAddress address = cell.getKey(); for (int i = 0; i < address.size(); ++i) { - long label = address.numericLabel(i); - long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound - if (label >= bound) { - throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + - "' has label '" + label + "' but type is " + originalType.toString()); + TensorType.Dimension dim = originalType.dimensions().get(i); + if (dim instanceof TensorType.IndexedBoundDimension) { + long label = address.numericLabel(i); + long bound = dim.size().get(); // size is non-optional for indexed bound + if (label >= bound) { + throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() + + "' has label '" + label + "' but type is " + originalType.toString()); + } } } } diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java index 210a6a80ee5..3bb4b2e262f 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java @@ -24,23 +24,23 @@ public class TensorRemoveUpdateReader { static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) { expectObjectStart(buffer.currentToken()); - expectTensorTypeIsSparse(field); + expectTensorTypeHasSparseDimensions(field); TensorDataType tensorDataType = (TensorDataType)field.getDataType(); - TensorType tensorType = tensorDataType.getTensorType(); + TensorType originalType = tensorDataType.getTensorType(); + TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType); + Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType); - // TODO: for mixed case extract a new tensor type based only on mapped dimensions - - Tensor tensor = readRemoveUpdateTensor(buffer, tensorType); expectAddressesAreNonEmpty(field, tensor); return new TensorRemoveUpdate(new TensorFieldValue(tensor)); } - private static void expectTensorTypeIsSparse(Field field) { + private static void expectTensorTypeHasSparseDimensions(Field field) { TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType(); - if (tensorType.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed)) { - throw new IllegalArgumentException("A remove update can only be applied to sparse tensors. " - + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'"); + if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) { + throw new IllegalArgumentException("A remove update can only be applied to tensors " + + "with at least one sparse dimension. Field '" + field.getName() + + "' has unsupported tensor type '" + tensorType + "'"); } } @@ -53,7 +53,7 @@ public class TensorRemoveUpdateReader { /** * Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0 */ - private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type) { + private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) { Tensor.Builder builder = Tensor.Builder.of(type); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); @@ -62,7 +62,7 @@ public class TensorRemoveUpdateReader { expectArrayStart(buffer.currentToken()); int nesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) { - builder.cell(readTensorAddress(buffer, type), 1.0); + builder.cell(readTensorAddress(buffer, type, originalType), 1.0); } expectCompositeEnd(buffer.currentToken()); } @@ -71,12 +71,15 @@ public class TensorRemoveUpdateReader { return builder.build(); } - private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type) { + private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) { TensorAddress.Builder builder = new TensorAddress.Builder(type); expectObjectStart(buffer.currentToken()); int initNesting = buffer.nesting(); for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { String dimension = buffer.currentName(); + if ( ! type.dimension(dimension).isPresent() && originalType.dimension(dimension).isPresent()) { + throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update"); + } String label = buffer.currentText(); builder.add(dimension, label); } diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java index 2f22def9aa1..a763db33e7a 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java @@ -5,6 +5,7 @@ import com.yahoo.document.DataType; import com.yahoo.document.DocumentTypeManager; import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.document.json.readers.TensorRemoveUpdateReader; import com.yahoo.document.update.TensorAddUpdate; import com.yahoo.document.update.TensorModifyUpdate; import com.yahoo.document.update.TensorRemoveUpdate; @@ -35,7 +36,10 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { throw new DeserializationException("Expected tensor data type, got " + type); } TensorDataType tensorDataType = (TensorDataType)type; - TensorFieldValue tensor = new TensorFieldValue(TensorModifyUpdate.convertToCompatibleType(tensorDataType.getTensorType())); + TensorType tensorType = tensorDataType.getTensorType(); + TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType); + + TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); return new TensorModifyUpdate(operation, tensor); } @@ -46,7 +50,8 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { throw new DeserializationException("Expected tensor data type, got " + type); } TensorDataType tensorDataType = (TensorDataType)type; - TensorFieldValue tensor = new TensorFieldValue(tensorDataType.getTensorType()); + TensorType tensorType = tensorDataType.getTensorType(); + TensorFieldValue tensor = new TensorFieldValue(tensorType); tensor.deserialize(this); return new TensorAddUpdate(tensor); } @@ -58,10 +63,9 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { } TensorDataType tensorDataType = (TensorDataType)type; TensorType tensorType = tensorDataType.getTensorType(); + TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType); - // TODO: for mixed case extract a new tensor type based only on mapped dimensions - - TensorFieldValue tensor = new TensorFieldValue(tensorType); + TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); return new TensorRemoveUpdate(tensor); } diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java index cfc3ee0c742..f8d2464deb7 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java @@ -7,15 +7,11 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import java.util.Map; import java.util.Objects; /** - * An update used to add cells to a sparse tensor (has only mapped dimensions). - * - * The cells to add are contained in a sparse tensor as well. + * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension). */ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { @@ -50,22 +46,10 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } - Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); - Map<TensorAddress, Double> oldCells = oldTensor.cells(); - Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells(); - - // currently, underlying implementation disallows multiple entries with the same key - - Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); - for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) { - builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue())); - } - for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { - if ( ! oldCells.containsKey(addCell.getKey())) { - builder.cell(addCell.getKey(), addCell.getValue()); - } - } - return new TensorFieldValue(builder.build()); + Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); + Tensor update = tensor.getTensor().get(); + Tensor result = old.merge((left, right) -> right, update.cells()); // note this might be slow for large mixed tensor updates + return new TensorFieldValue(result); } @Override diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index 6111b51ca4e..2773f9d31da 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -37,7 +37,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { /** * Converts the given tensor type to a type that is compatible for being used in this update (has only mapped dimensions). */ - public static TensorType convertToCompatibleType(TensorType type) { + public static TensorType convertDimensionsToMapped(TensorType type) { TensorType.Builder builder = new TensorType.Builder(); type.dimensions().stream().forEach(dim -> builder.mapped(dim.name())); return builder.build(); diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java index e9fb1e3efd5..335cda8e133 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java @@ -7,10 +7,8 @@ import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; -import java.util.Iterator; -import java.util.Map; import java.util.Objects; /** @@ -25,6 +23,18 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { public TensorRemoveUpdate(TensorFieldValue value) { super(ValueUpdateClassID.TENSORREMOVE); this.tensor = value; + verifyCompatibleType(); + } + + private void verifyCompatibleType() { + if ( ! tensor.getTensor().isPresent()) { + throw new IllegalArgumentException("Tensor must be present in remove update"); + } + TensorType tensorType = tensor.getTensor().get().type(); + TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType()); + if ( ! tensorType.equals(expectedType)) { + throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'"); + } } @Override @@ -51,17 +61,10 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } - Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); - Map<TensorAddress, Double> cellsToRemove = tensor.getTensor().get().cells(); - Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); - for (Iterator<Tensor.Cell> i = oldTensor.cellIterator(); i.hasNext(); ) { - Tensor.Cell cell = i.next(); - TensorAddress address = cell.getKey(); - if ( ! cellsToRemove.containsKey(address)) { - builder.cell(address, cell.getValue()); - } - } - return new TensorFieldValue(builder.build()); + Tensor old = ((TensorFieldValue) oldValue).getTensor().get(); + Tensor update = tensor.getTensor().get(); + Tensor result = old.remove(update.cells().keySet()); + return new TensorFieldValue(result); } @Override @@ -93,4 +96,11 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { return super.toString() + " " + tensor; } + public static TensorType extractSparseDimensions(TensorType type) { + TensorType.Builder builder = new TensorType.Builder(); + type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name())); + return builder.build(); + } + + } diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java index e2736dabd2b..454ad72f344 100644 --- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java +++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java @@ -40,6 +40,7 @@ public class DocumentUpdateJsonSerializerTest { final static TensorType sparseTensorType = new TensorType.Builder().mapped("x").mapped("y").build(); final static TensorType denseTensorType = new TensorType.Builder().indexed("x", 2).indexed("y", 3).build(); + final static TensorType mixedTensorType = new TensorType.Builder().mapped("x").indexed("y", 3).build(); final static DocumentTypeManager types = new DocumentTypeManager(); final static JsonFactory parserFactory = new JsonFactory(); final static DocumentType docType = new DocumentType("doctype"); @@ -60,6 +61,7 @@ public class DocumentUpdateJsonSerializerTest { docType.addField(new Field("byte_field", DataType.BYTE)); docType.addField(new Field("sparse_tensor", new TensorDataType(sparseTensorType))); docType.addField(new Field("dense_tensor", new TensorDataType(denseTensorType))); + docType.addField(new Field("mixed_tensor", new TensorDataType(mixedTensorType))); docType.addField(new Field("reference_field", new ReferenceDataType(refTargetDocType, 777))); docType.addField(new Field("predicate_field", DataType.PREDICATE)); docType.addField(new Field("raw_field", DataType.RAW)); @@ -336,6 +338,26 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_modify_update_on_mixed_tensor() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'modify': {", + " 'operation': 'multiply',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': 'c', 'y': '1' }, 'value': 3.0 }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void test_tensor_add_update() { roundtripSerializeJsonAndMatch(inputJson( "{", @@ -355,6 +377,29 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_add_update_mixed() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'add': {", + " 'cells': [", + " { 'address': { 'x': '1', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': '1', 'y': '1' }, 'value': 0.0 },", + " { 'address': { 'x': '1', 'y': '2' }, 'value': 0.0 },", + " { 'address': { 'x': '0', 'y': '0' }, 'value': 0.0 },", + " { 'address': { 'x': '0', 'y': '1' }, 'value': 0.0 },", + " { 'address': { 'x': '0', 'y': '2' }, 'value': 3.0 }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void test_tensor_remove_update() { roundtripSerializeJsonAndMatch(inputJson( "{", @@ -374,6 +419,24 @@ public class DocumentUpdateJsonSerializerTest { } @Test + public void test_tensor_remove_update_mixed() { + roundtripSerializeJsonAndMatch(inputJson( + "{", + " 'update': 'DOCUMENT_ID',", + " 'fields': {", + " 'mixed_tensor': {", + " 'remove': {", + " 'addresses': [", + " {'x':'0' }", + " ]", + " }", + " }", + " }", + "}" + )); + } + + @Test public void reference_field_id_can_be_update_assigned_non_empty_id() { roundtripSerializeJsonAndMatch(inputJson( "{", diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index e58b26d500d..15d1e859f73 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1387,12 +1387,30 @@ public class JsonReaderTestCase { } @Test - public void tensor_modify_update_on_mixed_tensor_throws() { - exception.expect(IllegalArgumentException.class); - exception.expectMessage("A modify update cannot be applied to tensor types with mixed dimensions. Field 'mixed_tensor' has mixed tensor type 'tensor(x{},y[3])'"); - createTensorModifyUpdate(inputJson("{", - " 'operation': 'replace',", - " 'cells': [] }"), "mixed_tensor"); + public void tensor_modify_update_with_replace_operation_mixed() { + assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor", + inputJson("{", + " 'operation': 'replace',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); + } + + @Test + public void tensor_modify_update_with_add_operation_mixed() { + assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.ADD, "mixed_tensor", + inputJson("{", + " 'operation': 'add',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); + } + + @Test + public void tensor_modify_update_with_multiply_operation_mixed() { + assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.MULTIPLY, "mixed_tensor", + inputJson("{", + " 'operation': 'multiply',", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}")); } @Test @@ -1406,6 +1424,17 @@ public class JsonReaderTestCase { } @Test + public void tensor_modify_update_with_out_of_bound_cells_throws_mixed() { + exception.expect(IndexOutOfBoundsException.class); + exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x{},y[3])"); + createTensorModifyUpdate(inputJson("{", + " 'operation': 'replace',", + " 'cells': [", + " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor"); + } + + + @Test public void tensor_modify_update_with_unknown_operation_throws() { exception.expect(IllegalArgumentException.class); exception.expectMessage("Unknown operation 'unknown' in modify update for field 'sparse_tensor'"); @@ -1449,11 +1478,29 @@ public class JsonReaderTestCase { } @Test - public void tensor_add_update_on_non_sparse_tensor_throws() { + public void tensor_add_update_on_mixed_tensor() { + assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0, {x:a,y:2}:0.0}", "mixed_tensor", + inputJson("{", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },", + " { 'address': { 'x': 'a', 'y': '1' }, 'value': 3.0 } ]}")); + } + + @Test + public void tensor_add_update_on_mixed_with_out_of_bound_dense_cells_throws() { + exception.expect(IndexOutOfBoundsException.class); + exception.expectMessage("Index 3 out of bounds for length 3"); + createTensorAddUpdate(inputJson("{", + " 'cells': [", + " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor"); + } + + @Test + public void tensor_add_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + exception.expectMessage("An add update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorAddUpdate(inputJson("{", - " 'cells': [] }"), "mixed_tensor"); + " 'cells': [] }"), "dense_tensor"); } @Test @@ -1470,6 +1517,7 @@ public class JsonReaderTestCase { exception.expect(IllegalArgumentException.class); exception.expectMessage("Add update for field 'sparse_tensor' does not contain tensor cells"); createTensorAddUpdate(inputJson("{}"), "sparse_tensor"); + createTensorAddUpdate(inputJson("{}"), "mixed_tensor"); } @Test @@ -1482,11 +1530,30 @@ public class JsonReaderTestCase { } @Test - public void tensor_remove_update_on_non_sparse_tensor_throws() { + public void tensor_remove_update_on_mixed_tensor() { + assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor", + inputJson("{", + " 'addresses': [", + " { 'x': '1' },", + " { 'x': '2' } ]}")); + } + + @Test + public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() { + exception.expect(IllegalArgumentException.class); + exception.expectMessage("Indexed dimension address 'y' should not be specified in remove update"); + createTensorRemoveUpdate(inputJson("{", + " 'addresses': [", + " { 'x': '1', 'y': '0' },", + " { 'x': '2', 'y': '0' } ]}"), "mixed_tensor"); + } + + @Test + public void tensor_remove_update_on_dense_tensor_throws() { exception.expect(IllegalArgumentException.class); - exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'"); + exception.expectMessage("A remove update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'"); createTensorRemoveUpdate(inputJson("{", - " 'addresses': [] }"), "mixed_tensor"); + " 'addresses': [] }"), "dense_tensor"); } @Test @@ -1503,6 +1570,7 @@ public class JsonReaderTestCase { exception.expect(IllegalArgumentException.class); exception.expectMessage("Remove update for field 'sparse_tensor' does not contain tensor addresses"); createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "sparse_tensor"); + createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "mixed_tensor"); } @Test diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java index eb4001e6415..6935c54ba2a 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java @@ -12,18 +12,14 @@ public class TensorAddUpdateTest { @Test public void apply_add_update_operations() { assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); - assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}"); } private void assertApplyTo(String init, String update, String expected) { String spec = "tensor(x{},y{})"; TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update))); - TensorFieldValue updatedFieldValue = (TensorFieldValue) addUpdate.applyTo(initialFieldValue); - assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get()); + Tensor updated = ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get(); + assertEquals(Tensor.from(spec, expected), updated); } } diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java index 6e9444de2be..b885e6ddca0 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java @@ -1,12 +1,6 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.document.update; -import com.yahoo.document.Document; -import com.yahoo.document.DocumentId; -import com.yahoo.document.DocumentType; -import com.yahoo.document.DocumentTypeManager; -import com.yahoo.document.Field; -import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.update.TensorModifyUpdate.Operation; import com.yahoo.tensor.Tensor; @@ -28,10 +22,11 @@ public class TensorModifyUpdateTest { assertConvertToCompatible("tensor(x{})", "tensor(x[10])"); assertConvertToCompatible("tensor(x{})", "tensor(x{})"); assertConvertToCompatible("tensor(x{},y{},z{})", "tensor(x[],y[10],z{})"); + assertConvertToCompatible("tensor(x{},y{})", "tensor(x{},y[3])"); } private static void assertConvertToCompatible(String expectedType, String inputType) { - assertEquals(expectedType, TensorModifyUpdate.convertToCompatibleType(TensorType.fromSpec(inputType)).toString()); + assertEquals(expectedType, TensorModifyUpdate.convertDimensionsToMapped(TensorType.fromSpec(inputType)).toString()); } @Test @@ -46,15 +41,9 @@ public class TensorModifyUpdateTest { public void apply_modify_update_operations() { assertApplyTo("tensor(x{},y{})", Operation.REPLACE, "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); - assertApplyTo("tensor(x{},y{})", Operation.ADD, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); - assertApplyTo("tensor(x{},y{})", Operation.MULTIPLY, - "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); - assertApplyTo("tensor(x[1],y[2])", Operation.REPLACE, - "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}"); assertApplyTo("tensor(x[1],y[2])", Operation.ADD, "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}"); - assertApplyTo("tensor(x[1],y[2])", Operation.MULTIPLY, + assertApplyTo("tensor(x{},y[2])", Operation.MULTIPLY, "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); } diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java index 40ab00facdb..3a005e858c8 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java @@ -12,9 +12,6 @@ public class TensorRemoveUpdateTest { @Test public void apply_remove_update_operations() { assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}"); - assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}"); - assertApplyTo("{}", "{{x:0,y:0}:1}", "{}"); - assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}"); } private void assertApplyTo(String init, String update, String expected) { diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index c74c211756f..017d83893f0 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -922,6 +922,17 @@ TEST(DocumentUpdateTest, tensor_add_update_can_be_applied) .add({{"x", "c"}}, 7)); } +TEST(DocumentUpdateTest, tensor_remove_update_can_be_applied) +{ + TensorUpdateFixture f; + f.assertApplyUpdate(f.spec().add({{"x", "a"}}, 2) + .add({{"x", "b"}}, 3), + + TensorRemoveUpdate(f.makeTensor(f.spec().add({{"x", "b"}}, 1))), + + f.spec().add({{"x", "a"}}, 2)); +} + TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied) { TensorUpdateFixture f; diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index 3e2bb86c66b..671bf260629 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -6,6 +6,8 @@ #include <vespa/document/fieldvalue/document.h> #include <vespa/document/fieldvalue/tensorfieldvalue.h> #include <vespa/document/serialization/vespadocumentdeserializer.h> +#include <vespa/eval/tensor/cell_values.h> +#include <vespa/eval/tensor/sparse/sparse_tensor.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/util/xmlstream.h> @@ -77,17 +79,35 @@ TensorRemoveUpdate::checkCompatibility(const Field &field) const std::unique_ptr<Tensor> TensorRemoveUpdate::applyTo(const Tensor &tensor) const { - // TODO: implement - (void) tensor; + auto &addressTensor = _tensor->getAsTensorPtr(); + if (addressTensor) { + if (const auto *sparseTensor = dynamic_cast<const vespalib::tensor::SparseTensor *>(addressTensor.get())) { + vespalib::tensor::CellValues cellAddresses(*sparseTensor); + return tensor.remove(cellAddresses); + } else { + throw IllegalArgumentException(make_string("Expected address tensor to be sparse, but has type '%s'", + addressTensor->type().to_spec().c_str())); + } + } return std::unique_ptr<Tensor>(); } bool TensorRemoveUpdate::applyTo(FieldValue &value) const { - // TODO: implement - (void) value; - return false; + if (value.inherits(TensorFieldValue::classId)) { + TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value); + auto &oldTensor = tensorFieldValue.getAsTensorPtr(); + auto newTensor = applyTo(*oldTensor); + if (newTensor) { + tensorFieldValue = std::move(newTensor); + } + } else { + std::string err = make_string("Unable to perform a tensor remove update on a '%s' field value.", + value.getClass().name()); + throw IllegalStateException(err, VESPA_STRLOC); + } + return true; } void diff --git a/document/src/vespa/document/update/valueupdate.h b/document/src/vespa/document/update/valueupdate.h index 0e15943f8e4..6939d10ce2c 100644 --- a/document/src/vespa/document/update/valueupdate.h +++ b/document/src/vespa/document/update/valueupdate.h @@ -55,7 +55,8 @@ public: Map = IDENTIFIABLE_CLASSID(MapValueUpdate), Remove = IDENTIFIABLE_CLASSID(RemoveValueUpdate), TensorModifyUpdate = IDENTIFIABLE_CLASSID(TensorModifyUpdate), - TensorAddUpdate = IDENTIFIABLE_CLASSID(TensorAddUpdate) + TensorAddUpdate = IDENTIFIABLE_CLASSID(TensorAddUpdate), + TensorRemoveUpdate = IDENTIFIABLE_CLASSID(TensorRemoveUpdate) }; ValueUpdate() diff --git a/documentapi/CMakeLists.txt b/documentapi/CMakeLists.txt index b03dd66c817..86d29732399 100644 --- a/documentapi/CMakeLists.txt +++ b/documentapi/CMakeLists.txt @@ -14,7 +14,6 @@ vespa_define_module( vdslib LIBS - src/vespa/binref src/vespa/documentapi src/vespa/documentapi/loadtypes src/vespa/documentapi/messagebus diff --git a/documentapi/src/vespa/binref/.gitignore b/documentapi/src/vespa/binref/.gitignore deleted file mode 100644 index cfb0e619824..00000000000 --- a/documentapi/src/vespa/binref/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -.depend -Makefile -testrun.sh diff --git a/documentapi/src/vespa/binref/CMakeLists.txt b/documentapi/src/vespa/binref/CMakeLists.txt deleted file mode 100644 index adece6dd711..00000000000 --- a/documentapi/src/vespa/binref/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. diff --git a/jrt_test/src/binref/testrun.sh b/jrt_test/src/binref/testrun.sh deleted file mode 120000 index 56c3c1186d8..00000000000 --- a/jrt_test/src/binref/testrun.sh +++ /dev/null @@ -1 +0,0 @@ -../../../vespalib/src/vespa/vespalib/testkit/testrun.sh
\ No newline at end of file diff --git a/lowercasing_test/src/binref/testrun.sh b/lowercasing_test/src/binref/testrun.sh deleted file mode 120000 index 56c3c1186d8..00000000000 --- a/lowercasing_test/src/binref/testrun.sh +++ /dev/null @@ -1 +0,0 @@ -../../../vespalib/src/vespa/vespalib/testkit/testrun.sh
\ No newline at end of file diff --git a/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp b/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp index afbb1c30f17..78cd9ce44b9 100644 --- a/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp +++ b/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp @@ -20,6 +20,7 @@ #include <vespa/document/update/removevalueupdate.h> #include <vespa/document/update/tensor_add_update.h> #include <vespa/document/update/tensor_modify_update.h> +#include <vespa/document/update/tensor_remove_update.h> #include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/searchcore/proton/common/attribute_updater.h> @@ -28,8 +29,8 @@ #include <vespa/searchlib/attribute/reference_attribute.h> #include <vespa/searchlib/tensor/dense_tensor_attribute.h> #include <vespa/searchlib/tensor/generic_tensor_attribute.h> -#include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/stllike/hash_map.hpp> +#include <vespa/vespalib/testkit/testapp.h> #include <vespa/log/log.h> LOG_SETUP("attribute_updater_test"); @@ -76,7 +77,8 @@ makeDocumentTypeRepo() .addField("wsfloat", Wset(DataType::T_FLOAT)) .addField("wsstring", Wset(DataType::T_STRING)) .addField("ref", 333) - .addField("dense_tensor", DataType::T_TENSOR), + .addField("dense_tensor", DataType::T_TENSOR) + .addField("sparse_tensor", DataType::T_TENSOR), Struct("testdoc.body")) .referenceType(333, 222); return std::make_unique<DocumentTypeRepo>(builder.config()); @@ -416,35 +418,54 @@ makeTensorFieldValue(const TensorSpec &spec) return result; } -void -setTensor(TensorAttribute &attribute, uint32_t lid, const TensorSpec &spec) -{ - auto tensor = makeTensor(spec); - attribute.setTensor(lid, *tensor); - attribute.commit(); -} +template <typename TensorAttributeType> +struct TensorFixture : public Fixture { + vespalib::string type; + std::unique_ptr<TensorAttributeType> attribute; -TEST_F("require that tensor modify update is applied", Fixture) -{ - vespalib::string type = "tensor(x[2])"; - auto attribute = makeTensorAttribute<DenseTensorAttribute>("dense_tensor", type); - setTensor(*attribute, 1, TensorSpec(type).add({{"x", 0}}, 3).add({{"x", 1}}, 5)); + TensorFixture(const vespalib::string &type_, const vespalib::string &name) + : type(type_), + attribute(makeTensorAttribute<TensorAttributeType>(name, type)) + { + } - f.applyValueUpdate(*attribute, 1, + void setTensor(const TensorSpec &spec) { + auto tensor = makeTensor(spec); + attribute->setTensor(1, *tensor); + attribute->commit(); + } + + void assertTensor(const TensorSpec &expSpec) { + EXPECT_EQUAL(expSpec, attribute->getTensor(1)->toSpec()); + } +}; + +TEST_F("require that tensor modify update is applied", + TensorFixture<DenseTensorAttribute>("tensor(x[2])", "dense_tensor")) +{ + f.setTensor(TensorSpec(f.type).add({{"x", 0}}, 3).add({{"x", 1}}, 5)); + f.applyValueUpdate(*f.attribute, 1, TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, makeTensorFieldValue(TensorSpec("tensor(x{})").add({{"x", 0}}, 7)))); - EXPECT_EQUAL(TensorSpec(type).add({{"x", 0}}, 7).add({{"x", 1}}, 5), attribute->getTensor(1)->toSpec()); + f.assertTensor(TensorSpec(f.type).add({{"x", 0}}, 7).add({{"x", 1}}, 5)); } -TEST_F("require that tensor add update is applied", Fixture) +TEST_F("require that tensor add update is applied", + TensorFixture<GenericTensorAttribute>("tensor(x{})", "sparse_tensor")) { - vespalib::string type = "tensor(x{})"; - auto attribute = makeTensorAttribute<GenericTensorAttribute>("dense_tensor", type); - setTensor(*attribute, 1, TensorSpec(type).add({{"x", "a"}}, 2)); + f.setTensor(TensorSpec(f.type).add({{"x", "a"}}, 2)); + f.applyValueUpdate(*f.attribute, 1, + TensorAddUpdate(makeTensorFieldValue(TensorSpec(f.type).add({{"x", "a"}}, 3)))); + f.assertTensor(TensorSpec(f.type).add({{"x", "a"}}, 3)); +} - f.applyValueUpdate(*attribute, 1, - TensorAddUpdate(makeTensorFieldValue(TensorSpec(type).add({{"x", "a"}}, 3)))); - EXPECT_EQUAL(TensorSpec(type).add({{"x", "a"}}, 3), attribute->getTensor(1)->toSpec()); +TEST_F("require that tensor remove update is applied", + TensorFixture<GenericTensorAttribute>("tensor(x{})", "sparse_tensor")) +{ + f.setTensor(TensorSpec(f.type).add({{"x", "a"}}, 2).add({{"x", "b"}}, 3)); + f.applyValueUpdate(*f.attribute, 1, + TensorRemoveUpdate(makeTensorFieldValue(TensorSpec(f.type).add({{"x", "b"}}, 1)))); + f.assertTensor(TensorSpec(f.type).add({{"x", "a"}}, 2)); } } diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp index 933857cffed..fcca1c2a737 100644 --- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp +++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp @@ -16,6 +16,7 @@ #include <vespa/document/update/removevalueupdate.h> #include <vespa/document/update/tensor_add_update.h> #include <vespa/document/update/tensor_modify_update.h> +#include <vespa/document/update/tensor_remove_update.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/searchlib/attribute/attributevector.hpp> #include <vespa/searchlib/attribute/changevector.hpp> @@ -238,6 +239,8 @@ AttributeUpdater::handleUpdate(TensorAttribute &vec, uint32_t lid, const ValueUp applyTensorUpdate(vec, lid, static_cast<const TensorModifyUpdate &>(upd)); } else if (op == ValueUpdate::TensorAddUpdate) { applyTensorUpdate(vec, lid, static_cast<const TensorAddUpdate &>(upd)); + } else if (op == ValueUpdate::TensorRemoveUpdate) { + applyTensorUpdate(vec, lid, static_cast<const TensorRemoveUpdate &>(upd)); } else if (op == ValueUpdate::Clear) { vec.clearDoc(lid); } else { diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java new file mode 100644 index 00000000000..2911b77707a --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java @@ -0,0 +1,101 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.https; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.io.IOException; +import java.net.Authenticator; +import java.net.CookieHandler; +import java.net.ProxySelector; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executor; + +/** + * A {@link HttpClient} that uses either http or https based on the global Vespa TLS configuration. + * + * @author bjorncs + */ +class TlsAwareHttpClient extends HttpClient { + + private final HttpClient wrappedClient; + private final String userAgent; + + TlsAwareHttpClient(HttpClient wrappedClient, String userAgent) { + this.wrappedClient = wrappedClient; + this.userAgent = userAgent; + } + + @Override + public Optional<CookieHandler> cookieHandler() { + return wrappedClient.cookieHandler(); + } + + @Override + public Optional<Duration> connectTimeout() { + return wrappedClient.connectTimeout(); + } + + @Override + public Redirect followRedirects() { + return wrappedClient.followRedirects(); + } + + @Override + public Optional<ProxySelector> proxy() { + return wrappedClient.proxy(); + } + + @Override + public SSLContext sslContext() { + return wrappedClient.sslContext(); + } + + @Override + public SSLParameters sslParameters() { + return wrappedClient.sslParameters(); + } + + @Override + public Optional<Authenticator> authenticator() { + return wrappedClient.authenticator(); + } + + @Override + public Version version() { + return wrappedClient.version(); + } + + @Override + public Optional<Executor> executor() { + return wrappedClient.executor(); + } + + @Override + public <T> HttpResponse<T> send(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler) throws IOException, InterruptedException { + return wrappedClient.send(wrapRequest(request), responseBodyHandler); + } + + @Override + public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler) { + return wrappedClient.sendAsync(wrapRequest(request), responseBodyHandler); + } + + @Override + public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler, HttpResponse.PushPromiseHandler<T> pushPromiseHandler) { + return wrappedClient.sendAsync(wrapRequest(request), responseBodyHandler, pushPromiseHandler); + } + + @Override + public String toString() { + return wrappedClient.toString(); + } + + private HttpRequest wrapRequest(HttpRequest request) { + return new TlsAwareHttpRequest(request, userAgent); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java new file mode 100644 index 00000000000..7eca2463ba7 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java @@ -0,0 +1,97 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.https; + +import com.yahoo.security.tls.TlsContext; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.net.Authenticator; +import java.net.CookieHandler; +import java.net.ProxySelector; +import java.net.http.HttpClient; +import java.time.Duration; +import java.util.concurrent.Executor; + +/** + * A client builder for {@link HttpClient} which uses {@link TlsContext} for TLS configuration. + * Intended for internal Vespa communication only. + * + * @author bjorncs + */ +public class TlsAwareHttpClientBuilder implements HttpClient.Builder { + + private final HttpClient.Builder wrappedBuilder; + private final String userAgent; + + public TlsAwareHttpClientBuilder(TlsContext tlsContext) { + this(tlsContext, "vespa-tls-aware-client"); + } + + public TlsAwareHttpClientBuilder(TlsContext tlsContext, String userAgent) { + this.wrappedBuilder = HttpClient.newBuilder() + .sslContext(tlsContext.context()) + .sslParameters(tlsContext.parameters()); + this.userAgent = userAgent; + } + + @Override + public HttpClient.Builder cookieHandler(CookieHandler cookieHandler) { + throw new UnsupportedOperationException(); + } + + @Override + public HttpClient.Builder connectTimeout(Duration duration) { + wrappedBuilder.connectTimeout(duration); + return this; + } + + @Override + public HttpClient.Builder sslContext(SSLContext sslContext) { + throw new UnsupportedOperationException("SSLContext is given from tls context"); + } + + @Override + public HttpClient.Builder sslParameters(SSLParameters sslParameters) { + throw new UnsupportedOperationException("SSLParameters is given from tls context"); + } + + @Override + public HttpClient.Builder executor(Executor executor) { + wrappedBuilder.executor(executor); + return this; + } + + @Override + public HttpClient.Builder followRedirects(HttpClient.Redirect policy) { + wrappedBuilder.followRedirects(policy); + return this; + } + + @Override + public HttpClient.Builder version(HttpClient.Version version) { + wrappedBuilder.version(version); + return this; + } + + @Override + public HttpClient.Builder priority(int priority) { + wrappedBuilder.priority(priority); + return this; + } + + @Override + public HttpClient.Builder proxy(ProxySelector proxySelector) { + throw new UnsupportedOperationException(); + } + + @Override + public HttpClient.Builder authenticator(Authenticator authenticator) { + throw new UnsupportedOperationException(); + } + + @Override + public HttpClient build() { + // TODO Stop wrapping the client once TLS is mandatory + return new TlsAwareHttpClient(wrappedBuilder.build(), userAgent); + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java new file mode 100644 index 00000000000..bbdd8af791f --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java @@ -0,0 +1,103 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.security.tls.https; + +import com.yahoo.security.tls.MixedMode; +import com.yahoo.security.tls.TransportSecurityUtils; + +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpHeaders; +import java.net.http.HttpRequest; +import java.time.Duration; +import java.util.HashMap; +import java.util.List; +import java.util.Optional; + +/** + * A {@link HttpRequest} where the scheme is either http or https based on the global Vespa TLS configuration. + * + * @author bjorncs + */ +class TlsAwareHttpRequest extends HttpRequest { + + private final URI rewrittenUri; + private final HttpRequest wrappedRequest; + private final HttpHeaders rewrittenHeaders; + + TlsAwareHttpRequest(HttpRequest wrappedRequest, String userAgent) { + this.wrappedRequest = wrappedRequest; + this.rewrittenUri = rewriteUri(wrappedRequest.uri()); + this.rewrittenHeaders = rewriteHeaders(wrappedRequest, userAgent); + } + + @Override + public Optional<BodyPublisher> bodyPublisher() { + return wrappedRequest.bodyPublisher(); + } + + @Override + public String method() { + return wrappedRequest.method(); + } + + @Override + public Optional<Duration> timeout() { + return wrappedRequest.timeout(); + } + + @Override + public boolean expectContinue() { + return wrappedRequest.expectContinue(); + } + + @Override + public URI uri() { + return rewrittenUri; + } + + @Override + public Optional<HttpClient.Version> version() { + return wrappedRequest.version(); + } + + @Override + public HttpHeaders headers() { + return rewrittenHeaders; + } + + private static URI rewriteUri(URI uri) { + if (!uri.getScheme().equals("http")) { + return uri; + } + String rewrittenScheme = + TransportSecurityUtils.getConfigFile().isPresent() && TransportSecurityUtils.getInsecureMixedMode() != MixedMode.PLAINTEXT_CLIENT_MIXED_SERVER ? + "https" : + "http"; + int port = uri.getPort(); + int rewrittenPort = port != -1 ? port : (rewrittenScheme.equals("http") ? 80 : 443); + try { + return new URI(rewrittenScheme, uri.getUserInfo(), uri.getHost(), rewrittenPort, uri.getPath(), uri.getQuery(), uri.getFragment()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + private static HttpHeaders rewriteHeaders(HttpRequest request, String userAgent) { + HttpHeaders headers = request.headers(); + if (headers.firstValue("User-Agent").isPresent()) { + return headers; + } + HashMap<String, List<String>> rewrittenHeaders = new HashMap<>(headers.map()); + rewrittenHeaders.put("User-Agent", List.of(userAgent)); + return HttpHeaders.of(rewrittenHeaders, (ignored1, ignored2) -> true); + } + + @Override + public String toString() { + return "TlsAwareHttpRequest{" + + "rewrittenUri=" + rewrittenUri + + ", wrappedRequest=" + wrappedRequest + + '}'; + } +} diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java b/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java new file mode 100644 index 00000000000..43067705fa3 --- /dev/null +++ b/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java @@ -0,0 +1,8 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +/** + * @author bjorncs + */ +@ExportPackage +package com.yahoo.security.tls.https; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/storage/src/tests/bucketdb/bucketmanagertest.cpp b/storage/src/tests/bucketdb/bucketmanagertest.cpp index 54b3bf4b8d0..09fe310e97e 100644 --- a/storage/src/tests/bucketdb/bucketmanagertest.cpp +++ b/storage/src/tests/bucketdb/bucketmanagertest.cpp @@ -8,6 +8,7 @@ #include <vespa/document/update/documentupdate.h> #include <vespa/document/repo/documenttyperepo.h> #include <vespa/storage/bucketdb/bucketmanager.h> +#include <vespa/storage/common/global_bucket_space_distribution_converter.h> #include <vespa/storage/persistence/filestorage/filestormanager.h> #include <vespa/storageapi/message/persistence.h> #include <vespa/storageapi/message/state.h> @@ -84,6 +85,7 @@ public: CPPUNIT_TEST(testConflictSetOnlyClearedAfterAllBucketRequestsDone); CPPUNIT_TEST(testRejectRequestWithMismatchingDistributionHash); CPPUNIT_TEST(testDbNotIteratedWhenAllRequestsRejected); + CPPUNIT_TEST(fall_back_to_legacy_global_distribution_hash_on_mismatch); // FIXME(vekterli): test is not deterministic and enjoys failing // sporadically when running under Valgrind. See bug 5932891. @@ -154,6 +156,7 @@ public: void testConflictSetOnlyClearedAfterAllBucketRequestsDone(); void testRejectRequestWithMismatchingDistributionHash(); void testDbNotIteratedWhenAllRequestsRejected(); + void fall_back_to_legacy_global_distribution_hash_on_mismatch(); public: static constexpr uint32_t DIR_SPREAD = 3; @@ -785,6 +788,10 @@ public: return std::make_shared<api::RequestBucketInfoCommand>(makeBucketSpace(), 0, _state, hash); } + auto createFullFetchCommandWithHash(document::BucketSpace space, vespalib::stringref hash) const { + return std::make_shared<api::RequestBucketInfoCommand>(space, 0, _state, hash); + } + auto acquireBucketLockAndSendInfoRequest(const document::BucketId& bucket) { auto guard = acquireBucketLock(bucket); // Send down processing command which will block. @@ -850,6 +857,45 @@ public: _self._top->getRepliesOnce(); } + // TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475 + std::unique_ptr<lib::Distribution> default_grouped_distribution() { + return std::make_unique<lib::Distribution>( + GlobalBucketSpaceDistributionConverter::string_to_config(vespalib::string( +R"(redundancy 2 +group[3] +group[0].name "invalid" +group[0].index "invalid" +group[0].partitions 1|* +group[0].nodes[0] +group[1].name rack0 +group[1].index 0 +group[1].nodes[3] +group[1].nodes[0].index 0 +group[1].nodes[1].index 1 +group[1].nodes[2].index 2 +group[2].name rack1 +group[2].index 1 +group[2].nodes[3] +group[2].nodes[0].index 3 +group[2].nodes[1].index 4 +group[2].nodes[2].index 5 +)"))); + } + + std::shared_ptr<lib::Distribution> derived_global_grouped_distribution(bool use_legacy) { + auto default_distr = default_grouped_distribution(); + return GlobalBucketSpaceDistributionConverter::convert_to_global(*default_distr, use_legacy); + } + + void set_grouped_distribution_configs() { + auto default_distr = default_grouped_distribution(); + _self._node->getComponentRegister().getBucketSpaceRepo() + .get(document::FixedBucketSpaces::default_space()).setDistribution(std::move(default_distr)); + auto global_distr = derived_global_grouped_distribution(false); + _self._node->getComponentRegister().getBucketSpaceRepo() + .get(document::FixedBucketSpaces::global_space()).setDistribution(std::move(global_distr)); + } + private: BucketManagerTest& _self; lib::ClusterState _state; @@ -1358,4 +1404,19 @@ BucketManagerTest::testDbNotIteratedWhenAllRequestsRejected() auto replies = fixture.awaitAndGetReplies(1); } +// TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475 +void BucketManagerTest::fall_back_to_legacy_global_distribution_hash_on_mismatch() { + ConcurrentOperationFixture f(*this); + + f.set_grouped_distribution_configs(); + + auto legacy_hash = f.derived_global_grouped_distribution(true)->getNodeGraph().getDistributionConfigHash(); + + auto infoCmd = f.createFullFetchCommandWithHash(document::FixedBucketSpaces::global_space(), legacy_hash); + _top->sendDown(infoCmd); + auto replies = f.awaitAndGetReplies(1); + auto& reply = dynamic_cast<api::RequestBucketInfoReply&>(*replies[0]); + CPPUNIT_ASSERT_EQUAL(api::ReturnCode::OK, reply.getResult().getResult()); // _not_ REJECTED +} + } // storage diff --git a/storage/src/tests/common/global_bucket_space_distribution_converter_test.cpp b/storage/src/tests/common/global_bucket_space_distribution_converter_test.cpp index 5afea9cd3cd..d75f2ac6459 100644 --- a/storage/src/tests/common/global_bucket_space_distribution_converter_test.cpp +++ b/storage/src/tests/common/global_bucket_space_distribution_converter_test.cpp @@ -17,6 +17,7 @@ struct GlobalBucketSpaceDistributionConverterTest : public CppUnit::TestFixture CPPUNIT_TEST(config_retired_state_is_propagated); CPPUNIT_TEST(group_capacities_are_propagated); CPPUNIT_TEST(global_distribution_has_same_owner_distributors_as_default); + CPPUNIT_TEST(can_generate_config_with_legacy_partition_spec); CPPUNIT_TEST_SUITE_END(); void can_transform_flat_cluster_config(); @@ -27,6 +28,7 @@ struct GlobalBucketSpaceDistributionConverterTest : public CppUnit::TestFixture void config_retired_state_is_propagated(); void group_capacities_are_propagated(); void global_distribution_has_same_owner_distributors_as_default(); + void can_generate_config_with_legacy_partition_spec(); }; CPPUNIT_TEST_SUITE_REGISTRATION(GlobalBucketSpaceDistributionConverterTest); @@ -35,9 +37,9 @@ using DistributionConfig = vespa::config::content::StorDistributionConfig; namespace { -vespalib::string default_to_global_config(const vespalib::string& default_config) { +vespalib::string default_to_global_config(const vespalib::string& default_config, bool legacy_mode = false) { auto default_cfg = GlobalBucketSpaceDistributionConverter::string_to_config(default_config); - auto as_global = GlobalBucketSpaceDistributionConverter::convert_to_global(*default_cfg); + auto as_global = GlobalBucketSpaceDistributionConverter::convert_to_global(*default_cfg, legacy_mode); return GlobalBucketSpaceDistributionConverter::config_to_string(*as_global); } @@ -377,4 +379,64 @@ group[2].nodes[1].index 2 } } +// By "legacy" read "broken", but we need to be able to generate it to support rolling upgrades properly. +// TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475 +void GlobalBucketSpaceDistributionConverterTest::can_generate_config_with_legacy_partition_spec() { + vespalib::string default_config( +R"(redundancy 2 +group[3] +group[0].name "invalid" +group[0].index "invalid" +group[0].partitions 1|* +group[0].nodes[0] +group[1].name rack0 +group[1].index 0 +group[1].nodes[3] +group[1].nodes[0].index 0 +group[1].nodes[1].index 1 +group[1].nodes[2].index 2 +group[2].name rack1 +group[2].index 1 +group[2].nodes[3] +group[2].nodes[0].index 3 +group[2].nodes[1].index 4 +group[2].nodes[2].index 5 +)"); + + vespalib::string expected_global_config( +R"(redundancy 6 +initial_redundancy 0 +ensure_primary_persisted true +ready_copies 6 +active_per_leaf_group true +distributor_auto_ownership_transfer_on_whole_group_down true +group[0].index "invalid" +group[0].name "invalid" +group[0].capacity 1 +group[0].partitions "3|3|*" +group[1].index "0" +group[1].name "rack0" +group[1].capacity 1 +group[1].partitions "" +group[1].nodes[0].index 0 +group[1].nodes[0].retired false +group[1].nodes[1].index 1 +group[1].nodes[1].retired false +group[1].nodes[2].index 2 +group[1].nodes[2].retired false +group[2].index "1" +group[2].name "rack1" +group[2].capacity 1 +group[2].partitions "" +group[2].nodes[0].index 3 +group[2].nodes[0].retired false +group[2].nodes[1].index 4 +group[2].nodes[1].retired false +group[2].nodes[2].index 5 +group[2].nodes[2].retired false +disk_distribution MODULO_BID +)"); + CPPUNIT_ASSERT_EQUAL(expected_global_config, default_to_global_config(default_config, true)); +} + }
\ No newline at end of file diff --git a/storage/src/tests/distributor/bucketdbupdatertest.cpp b/storage/src/tests/distributor/bucketdbupdatertest.cpp index 53f80854bef..b2d554c1e42 100644 --- a/storage/src/tests/distributor/bucketdbupdatertest.cpp +++ b/storage/src/tests/distributor/bucketdbupdatertest.cpp @@ -111,6 +111,7 @@ class BucketDBUpdaterTest : public CppUnit::TestFixture, CPPUNIT_TEST(identity_update_of_diverging_untrusted_replicas_does_not_mark_any_as_trusted); CPPUNIT_TEST(adding_diverging_replica_to_existing_trusted_does_not_remove_trusted); CPPUNIT_TEST(batch_update_from_distributor_change_does_not_mark_diverging_replicas_as_trusted); + CPPUNIT_TEST(global_distribution_hash_falls_back_to_legacy_format_upon_request_rejection); CPPUNIT_TEST_SUITE_END(); public: @@ -175,6 +176,7 @@ protected: void identity_update_of_diverging_untrusted_replicas_does_not_mark_any_as_trusted(); void adding_diverging_replica_to_existing_trusted_does_not_remove_trusted(); void batch_update_from_distributor_change_does_not_mark_diverging_replicas_as_trusted(); + void global_distribution_hash_falls_back_to_legacy_format_upon_request_rejection(); auto &defaultDistributorBucketSpace() { return getBucketSpaceRepo().get(makeBucketSpace()); } @@ -505,7 +507,7 @@ public: std::make_shared<lib::Distribution>(distConfig)); } - std::string getDistConfig6Nodes3Groups() const { + std::string getDistConfig6Nodes2Groups() const { return ("redundancy 2\n" "group[3]\n" "group[0].name \"invalid\"\n" @@ -692,7 +694,7 @@ BucketDBUpdaterTest::testDistributorChange() void BucketDBUpdaterTest::testDistributorChangeWithGrouping() { - std::string distConfig(getDistConfig6Nodes3Groups()); + std::string distConfig(getDistConfig6Nodes2Groups()); setDistribution(distConfig); int numBuckets = 100; @@ -2073,7 +2075,7 @@ BucketDBUpdaterTest::testClusterStateAlwaysSendsFullFetchWhenDistributionChangeP setAndEnableClusterState(stateBefore, expectedMsgs, dummyBucketsToReturn); } _sender.clear(); - std::string distConfig(getDistConfig6Nodes3Groups()); + std::string distConfig(getDistConfig6Nodes2Groups()); setDistribution(distConfig); sortSentMessagesByIndex(_sender); CPPUNIT_ASSERT_EQUAL(messageCount(6), _sender.commands.size()); @@ -2549,4 +2551,52 @@ void BucketDBUpdaterTest::batch_update_from_distributor_change_does_not_mark_div "0:5/1/2/3|1:5/7/8/9", true)); } +// TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475 +void BucketDBUpdaterTest::global_distribution_hash_falls_back_to_legacy_format_upon_request_rejection() { + std::string distConfig(getDistConfig6Nodes2Groups()); + setDistribution(distConfig); + + const vespalib::string current_hash = "(0d*|*(0;0;1;2)(1;3;4;5))"; + const vespalib::string legacy_hash = "(0d3|3|*(0;0;1;2)(1;3;4;5))"; + + setSystemState(lib::ClusterState("distributor:6 storage:6")); + CPPUNIT_ASSERT_EQUAL(messageCount(6), _sender.commands.size()); + + api::RequestBucketInfoCommand* global_req = nullptr; + for (auto& cmd : _sender.commands) { + auto& req_cmd = dynamic_cast<api::RequestBucketInfoCommand&>(*cmd); + if (req_cmd.getBucketSpace() == document::FixedBucketSpaces::global_space()) { + global_req = &req_cmd; + break; + } + } + CPPUNIT_ASSERT(global_req != nullptr); + CPPUNIT_ASSERT_EQUAL(current_hash, global_req->getDistributionHash()); + + auto reply = std::make_shared<api::RequestBucketInfoReply>(*global_req); + reply->setResult(api::ReturnCode::REJECTED); + getBucketDBUpdater().onRequestBucketInfoReply(reply); + + getClock().addSecondsToTime(10); + getBucketDBUpdater().resendDelayedMessages(); + + // Should now be a resent request with legacy distribution hash + CPPUNIT_ASSERT_EQUAL(messageCount(6) + 1, _sender.commands.size()); + auto& legacy_req = dynamic_cast<api::RequestBucketInfoCommand&>(*_sender.commands.back()); + CPPUNIT_ASSERT_EQUAL(legacy_hash, legacy_req.getDistributionHash()); + + // Now if we reject it _again_ we should cycle back to the current hash + // in case it wasn't a hash-based rejection after all. And the circle of life continues. + reply = std::make_shared<api::RequestBucketInfoReply>(legacy_req); + reply->setResult(api::ReturnCode::REJECTED); + getBucketDBUpdater().onRequestBucketInfoReply(reply); + + getClock().addSecondsToTime(10); + getBucketDBUpdater().resendDelayedMessages(); + + CPPUNIT_ASSERT_EQUAL(messageCount(6) + 2, _sender.commands.size()); + auto& new_current_req = dynamic_cast<api::RequestBucketInfoCommand&>(*_sender.commands.back()); + CPPUNIT_ASSERT_EQUAL(current_hash, new_current_req.getDistributionHash()); +} + } diff --git a/storage/src/vespa/storage/bucketdb/bucketmanager.cpp b/storage/src/vespa/storage/bucketdb/bucketmanager.cpp index 41de215d877..a1c1742edb5 100644 --- a/storage/src/vespa/storage/bucketdb/bucketmanager.cpp +++ b/storage/src/vespa/storage/bucketdb/bucketmanager.cpp @@ -6,6 +6,7 @@ #include <iomanip> #include <vespa/storage/common/content_bucket_space_repo.h> #include <vespa/storage/common/nodestateupdater.h> +#include <vespa/storage/common/global_bucket_space_distribution_converter.h> #include <vespa/vdslib/state/cluster_state_bundle.h> #include <vespa/storage/storageutil/distributorstatecache.h> #include <vespa/storageframework/generic/status/htmlstatusreporter.h> @@ -577,7 +578,21 @@ BucketManager::processRequestBucketInfoCommands(document::BucketSpace bucketSpac << " differs from this state."; } else if (!their_hash.empty() && their_hash != our_hash) { // Empty hash indicates request from 4.2 protocol or earlier - error << "Distribution config has changed since request."; + // TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475 + bool matches_legacy_hash = false; + if (bucketSpace == document::FixedBucketSpaces::global_space()) { + const auto default_distr =_component.getBucketSpaceRepo() + .get(document::FixedBucketSpaces::default_space()).getDistribution(); + // Convert in legacy distribution mode, which will accept old 'hash' structure. + const auto legacy_global_distr = GlobalBucketSpaceDistributionConverter::convert_to_global( + *default_distr, true/*use legacy mode*/); + const auto legacy_hash = legacy_global_distr->getNodeGraph().getDistributionConfigHash(); + LOG(debug, "Falling back to comparing against legacy distribution hash: %s", legacy_hash.c_str()); + matches_legacy_hash = (their_hash == legacy_hash); + } + if (!matches_legacy_hash) { + error << "Distribution config has changed since request."; + } } if (error.str().empty()) { std::pair<std::set<uint16_t>::iterator, bool> result( diff --git a/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.cpp b/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.cpp index 534644458bc..cbcaeef8fdf 100644 --- a/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.cpp +++ b/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.cpp @@ -59,6 +59,21 @@ vespalib::string sub_groups_to_partition_spec(const Group& parent) { return spec.str(); } +// Allow generating legacy (broken) partition specs that may be used transiently +// during rolling upgrades from a pre-fix version to a post-fix version. +// TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475 +vespalib::string sub_groups_to_legacy_partition_spec(const Group& parent) { + vespalib::asciistream partitions; + // In case of a flat cluster config, this ends up with a partition spec of '*', + // which is fine. It basically means "put all replicas in this group", which + // happens to be exactly what we want. + for (auto& child : parent.sub_groups) { + partitions << child.second->nested_leaf_count << '|'; + } + partitions << '*'; + return partitions.str(); +} + bool is_leaf_group(const DistributionConfigBuilder::Group& g) noexcept { return !g.nodes.empty(); } @@ -87,19 +102,31 @@ void insert_new_group_into_tree( void build_transformed_root_group(DistributionConfigBuilder& builder, const DistributionConfigBuilder::Group& config_source_root, - const Group& parsed_root) { + const Group& parsed_root, + bool legacy_mode) { DistributionConfigBuilder::Group new_root(config_source_root); - new_root.partitions = sub_groups_to_partition_spec(parsed_root); + if (!legacy_mode) { + new_root.partitions = sub_groups_to_partition_spec(parsed_root); + } else { + // TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475 + new_root.partitions = sub_groups_to_legacy_partition_spec(parsed_root); + } builder.group.emplace_back(std::move(new_root)); } void build_transformed_non_root_group(DistributionConfigBuilder& builder, const DistributionConfigBuilder::Group& config_source_group, - const Group& parsed_root) { + const Group& parsed_root, + bool legacy_mode) { DistributionConfigBuilder::Group new_group(config_source_group); if (!is_leaf_group(config_source_group)) { // Partition specs only apply to inner nodes const auto& g = find_non_root_group_by_index(config_source_group.index, parsed_root); - new_group.partitions = sub_groups_to_partition_spec(g); + if (!legacy_mode) { + new_group.partitions = sub_groups_to_partition_spec(g); + } else { + // TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475 + new_group.partitions = sub_groups_to_legacy_partition_spec(g); + } } builder.group.emplace_back(std::move(new_group)); } @@ -135,16 +162,16 @@ std::unique_ptr<Group> create_group_tree_from_config(const DistributionConfig& s * transitively, its parents again etc) have already been processed. This directly * implies that the root group is always the first group present in the config. */ -void build_global_groups(DistributionConfigBuilder& builder, const DistributionConfig& source) { +void build_global_groups(DistributionConfigBuilder& builder, const DistributionConfig& source, bool legacy_mode) { assert(!source.group.empty()); // TODO gracefully handle empty config? auto root = create_group_tree_from_config(source); auto g_iter = source.group.begin(); const auto g_end = source.group.end(); - build_transformed_root_group(builder, *g_iter, *root); + build_transformed_root_group(builder, *g_iter, *root, legacy_mode); ++g_iter; for (; g_iter != g_end; ++g_iter) { - build_transformed_non_root_group(builder, *g_iter, *root); + build_transformed_non_root_group(builder, *g_iter, *root, legacy_mode); } builder.redundancy = root->nested_leaf_count; @@ -154,17 +181,17 @@ void build_global_groups(DistributionConfigBuilder& builder, const DistributionC } // anon ns std::shared_ptr<DistributionConfig> -GlobalBucketSpaceDistributionConverter::convert_to_global(const DistributionConfig& source) { +GlobalBucketSpaceDistributionConverter::convert_to_global(const DistributionConfig& source, bool legacy_mode) { DistributionConfigBuilder builder; set_distribution_invariant_config_fields(builder, source); - build_global_groups(builder, source); + build_global_groups(builder, source, legacy_mode); return std::make_shared<DistributionConfig>(builder); } std::shared_ptr<lib::Distribution> -GlobalBucketSpaceDistributionConverter::convert_to_global(const lib::Distribution& distr) { +GlobalBucketSpaceDistributionConverter::convert_to_global(const lib::Distribution& distr, bool legacy_mode) { const auto src_config = distr.serialize(); - auto global_config = convert_to_global(*string_to_config(src_config)); + auto global_config = convert_to_global(*string_to_config(src_config), legacy_mode); return std::make_shared<lib::Distribution>(*global_config); } diff --git a/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.h b/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.h index d135f56a5c1..b2be65dad42 100644 --- a/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.h +++ b/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.h @@ -10,8 +10,9 @@ namespace storage { struct GlobalBucketSpaceDistributionConverter { using DistributionConfig = vespa::config::content::StorDistributionConfig; - static std::shared_ptr<DistributionConfig> convert_to_global(const DistributionConfig&); - static std::shared_ptr<lib::Distribution> convert_to_global(const lib::Distribution&); + // TODO remove legacy_mode flags on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475 + static std::shared_ptr<DistributionConfig> convert_to_global(const DistributionConfig&, bool legacy_mode = false); + static std::shared_ptr<lib::Distribution> convert_to_global(const lib::Distribution&, bool legacy_mode = false); // Helper functions which may be of use outside this class static std::unique_ptr<DistributionConfig> string_to_config(const vespalib::string&); diff --git a/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.cpp b/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.cpp index 2071558628e..c295be19a0b 100644 --- a/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.cpp +++ b/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.cpp @@ -35,7 +35,8 @@ PendingBucketSpaceDbTransition::PendingBucketSpaceDbTransition(const PendingClus _pendingClusterState(pendingClusterState), _distributorBucketSpace(distributorBucketSpace), _distributorIndex(_clusterInfo->getDistributorIndex()), - _bucketOwnershipTransfer(distributionChanged) + _bucketOwnershipTransfer(distributionChanged), + _rejectedRequests() { if (distributorChanged()) { _bucketOwnershipTransfer = true; diff --git a/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.h b/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.h index 903f9b762fb..7eb2974eb52 100644 --- a/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.h +++ b/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.h @@ -4,6 +4,7 @@ #include "pending_bucket_space_db_transition_entry.h" #include "outdated_nodes.h" #include <vespa/storage/bucketdb/bucketdatabase.h> +#include <unordered_map> namespace storage::api { class RequestBucketInfoReply; } namespace storage::lib { class ClusterState; class State; } @@ -48,6 +49,7 @@ private: DistributorBucketSpace &_distributorBucketSpace; uint16_t _distributorIndex; bool _bucketOwnershipTransfer; + std::unordered_map<uint16_t, size_t> _rejectedRequests; // BucketDataBase::MutableEntryProcessor API bool process(BucketDatabase::Entry& e) override; @@ -111,6 +113,14 @@ public: // Methods used by unit tests. const EntryList& results() const { return _entries; } void addNodeInfo(const document::BucketId& id, const BucketCopy& copy); + + void incrementRequestRejections(uint16_t node) { + _rejectedRequests[node]++; + } + size_t rejectedRequests(uint16_t node) const { + auto iter = _rejectedRequests.find(node); + return ((iter != _rejectedRequests.end()) ? iter->second : 0); + } }; } diff --git a/storage/src/vespa/storage/distributor/pendingclusterstate.cpp b/storage/src/vespa/storage/distributor/pendingclusterstate.cpp index 1996ae9d2af..5f74a82c28a 100644 --- a/storage/src/vespa/storage/distributor/pendingclusterstate.cpp +++ b/storage/src/vespa/storage/distributor/pendingclusterstate.cpp @@ -7,6 +7,7 @@ #include "distributor_bucket_space.h" #include <vespa/storageframework/defaultimplementation/clock/realclock.h> #include <vespa/storage/common/bucketoperationlogger.h> +#include <vespa/storage/common/global_bucket_space_distribution_converter.h> #include <vespa/document/bucket/fixed_bucket_spaces.h> #include <vespa/vespalib/util/xmlstream.hpp> #include <climits> @@ -185,7 +186,30 @@ PendingClusterState::requestNode(BucketSpaceAndNode bucketSpaceAndNode) { const auto &distributorBucketSpace(_bucketSpaceRepo.get(bucketSpaceAndNode.bucketSpace)); const auto &distribution(distributorBucketSpace.getDistribution()); - vespalib::string distributionHash(distribution.getNodeGraph().getDistributionConfigHash()); + vespalib::string distributionHash; + // TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475 + bool sendLegacyHash = false; + if (bucketSpaceAndNode.bucketSpace == document::FixedBucketSpaces::global_space()) { + auto transitionIter = _pendingTransitions.find(bucketSpaceAndNode.bucketSpace); + assert(transitionIter != _pendingTransitions.end()); + // First request cannot have been rejected yet and will thus be sent with non-legacy hash. + // Subsequent requests will be sent 50-50. This is because a request may be rejected due to + // other reasons than just the hash mismatching, so if we don't cycle back to the non-legacy + // hash we risk never converging. + sendLegacyHash = ((transitionIter->second->rejectedRequests(bucketSpaceAndNode.node) % 2) == 1); + } + if (!sendLegacyHash) { + distributionHash = distribution.getNodeGraph().getDistributionConfigHash(); + } else { + const auto& defaultSpace = _bucketSpaceRepo.get(document::FixedBucketSpaces::default_space()); + // Generate legacy distribution hash explicitly. + auto legacyGlobalDistr = GlobalBucketSpaceDistributionConverter::convert_to_global( + defaultSpace.getDistribution(), true/*use legacy mode*/); + distributionHash = legacyGlobalDistr->getNodeGraph().getDistributionConfigHash(); + LOG(debug, "Falling back to sending legacy hash to node %u: %s", + bucketSpaceAndNode.node, distributionHash.c_str()); + } + LOG(debug, "Requesting bucket info for bucket space %" PRIu64 " node %d with cluster state '%s' " "and distribution hash '%s'", @@ -238,6 +262,11 @@ PendingClusterState::onRequestBucketInfoReply(const std::shared_ptr<api::Request resendTime += framework::MilliSecTime(100); _delayedRequests.emplace_back(resendTime, bucketSpaceAndNode); _sentMessages.erase(iter); + if (result.getResult() == api::ReturnCode::REJECTED) { + auto transitionIter = _pendingTransitions.find(bucketSpaceAndNode.bucketSpace); + assert(transitionIter != _pendingTransitions.end()); + transitionIter->second->incrementRequestRejections(bucketSpaceAndNode.node); + } return true; } diff --git a/vespa-athenz/pom.xml b/vespa-athenz/pom.xml index f30aed1af5f..0f23eaed964 100644 --- a/vespa-athenz/pom.xml +++ b/vespa-athenz/pom.xml @@ -117,7 +117,21 @@ <dependency> <groupId>com.amazonaws</groupId> <artifactId>aws-java-sdk-core</artifactId> - </dependency> + <exclusions> + <exclusion> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-core</artifactId> + </exclusion> + <exclusion> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-databind</artifactId> + </exclusion> + <exclusion> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-annotations</artifactId> + </exclusion> + </exclusions> + </dependency> </dependencies> <build> diff --git a/vespa-hadoop/abi-spec.json b/vespa-hadoop/abi-spec.json index 5bbac15f0e5..e3f4dcf272a 100644 --- a/vespa-hadoop/abi-spec.json +++ b/vespa-hadoop/abi-spec.json @@ -1201,6 +1201,8 @@ "public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.DimensionSizes dimensionSizes()", "public java.util.Map cells()", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -1245,6 +1247,8 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)" @@ -1330,6 +1334,8 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -1432,6 +1438,8 @@ "public double asDouble()", "public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)", "public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)", "public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])", "public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)", diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 932513f8a57..c3fe8c5c7ad 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -808,6 +808,8 @@ "public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.DimensionSizes dimensionSizes()", "public java.util.Map cells()", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -852,6 +854,8 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)" @@ -937,6 +941,8 @@ "public java.util.Iterator valueIterator()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -1039,6 +1045,8 @@ "public double asDouble()", "public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)", "public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)", "public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])", "public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index fb55b2d5014..38d832d01c2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -13,6 +13,7 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; +import java.util.function.DoubleBinaryOperator; /** * An indexed (dense) tensor backed by a double array. @@ -190,6 +191,16 @@ public class IndexedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) { + throw new IllegalArgumentException("Merge is not supported for indexed tensors"); + } + + @Override + public Tensor remove(Set<TensorAddress> addresses) { + throw new IllegalArgumentException("Remove is not supported for indexed tensors"); + } + + @Override public int hashCode() { return Arrays.hashCode(values); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index ec3020a1a4e..22ceed22d3e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -5,6 +5,8 @@ import com.google.common.collect.ImmutableMap; import java.util.Iterator; import java.util.Map; +import java.util.Set; +import java.util.function.DoubleBinaryOperator; /** * A sparse implementation of a tensor backed by a Map of cells to values. @@ -51,6 +53,38 @@ public class MappedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { + + // currently, underlying implementation disallows multiple entries with the same key + + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { + TensorAddress address = cell.getKey(); + double value = cell.getValue(); + builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); + } + for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { + if ( ! cells.containsKey(addCell.getKey())) { + builder.cell(addCell.getKey(), addCell.getValue()); + } + } + return builder.build(); + } + + @Override + public Tensor remove(Set<TensorAddress> addresses) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) { + Tensor.Cell cell = i.next(); + TensorAddress address = cell.getKey(); + if ( ! addresses.contains(address)) { + builder.cell(address, cell.getValue()); + } + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 17e33c58a13..08878edeb83 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -9,6 +9,8 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.function.DoubleBinaryOperator; import java.util.stream.Collectors; /** @@ -70,13 +72,17 @@ public class MixedTensor implements Tensor { return cells.iterator(); } + private Iterable<Cell> cellIterable() { + return this::cellIterator; + } + /** * Returns an iterator over the values of this tensor. * The iteration order is the same as for cellIterator. */ @Override public Iterator<Double> valueIterator() { - return new Iterator<Double>() { + return new Iterator<>() { Iterator<Cell> cellIterator = cellIterator(); @Override public boolean hasNext() { @@ -108,6 +114,38 @@ public class MixedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Cell cell : cellIterable()) { + TensorAddress address = cell.getKey(); + double value = cell.getValue(); + builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); + } + for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { + builder.cell(addCell.getKey(), addCell.getValue()); + } + return builder.build(); + } + + @Override + public Tensor remove(Set<TensorAddress> addresses) { + Tensor.Builder builder = Tensor.Builder.of(type()); + + // iterate through all sparse addresses referencing a dense subspace + for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) { + TensorAddress sparsePartialAddress = entry.getKey(); + if ( ! addresses.contains(sparsePartialAddress)) { // assumption: addresses only contain the sparse part + long offset = entry.getValue(); + for (int i = 0; i < index.denseSubspaceSize; ++i) { + Cell cell = cells.get((int)offset + i); + builder.cell(cell.getKey(), cell.getValue()); + } + } + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 8002990e5c6..eb16801c306 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -25,6 +25,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; @@ -113,6 +114,29 @@ public interface Tensor { return builder.build(); } + /** + * Returns a new tensor where existing cells in this tensor have been + * modified according to the given operation and cells in the given map. + * In contrast to {@link #modify}, previously non-existing cells are added + * to this tensor. Only valid for sparse or mixed tensors. + * + * @param op how to update overlapping cells + * @param cells cells to merge with this tensor + * @return a new tensor where this tensor is merged with the other + */ + Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells); + + /** + * Returns a new tensor where existing cells in this tensor have been + * removed according to the given set of addresses. Only valid for sparse + * or mixed tensors. For mixed tensors, addresses are assumed to only + * contain the sparse dimensions, as the entire dense subspace is removed. + * + * @param addresses list of addresses to remove + * @return a new tensor where cells have been removed + */ + Tensor remove(Set<TensorAddress> addresses); + // ----------------- Primitive tensor functions default Tensor map(DoubleUnaryOperator mapper) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 2c9eefbd130..02d16e6f3e4 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -151,12 +151,106 @@ public class TensorTestCase { Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"), Tensor.from("tensor(x[1],y[3])", "{}"), Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:0,{x:0,y:1}:0}")); + assertTensorModify((left, right) -> left * right, + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:6}")); + } + + @Test + public void testTensorMerge() { + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:2}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3,{x:0,y:2}:4}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:2}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}")); // notice difference with sparse case - y is dense dimension here with default value 0.0 + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3,{x:0,y:2}:4}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}")); + } + + @Test + public void testTensorRemove() { + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:1}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:1}"), + Tensor.from("tensor(x{},y{})", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1}"), + Tensor.from("tensor(x{},y{})", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2, {x:0,y:1}:3}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), // notice update is without dense dimension + Tensor.from("tensor(x{},y[3])", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:1,y:0}:2}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), + Tensor.from("tensor(x{},y[3])", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), + Tensor.from("tensor(x{},y[3])", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{})", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}")); } private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) { assertEquals(expected, init.modify(op, update.cells())); } + private void assertTensorMerge(Tensor init, Tensor update, Tensor expected) { + DoubleBinaryOperator op = (left, right) -> right; + assertEquals(expected, init.merge(op, update.cells())); + } + + private void assertTensorRemove(Tensor init, Tensor update, Tensor expected) { + assertEquals(expected, init.remove(update.cells().keySet())); + } + + private double dotProduct(Tensor tensor, List<Tensor> tensors) { double sum = 0; TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), |