summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--document/abi-spec.json3
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java14
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java37
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java27
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java14
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java26
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java2
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java38
-rw-r--r--document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java63
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java92
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java8
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java17
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java3
-rw-r--r--document/src/tests/documentupdatetestcase.cpp11
-rw-r--r--document/src/vespa/document/update/tensor_remove_update.cpp30
-rw-r--r--document/src/vespa/document/update/valueupdate.h3
-rw-r--r--documentapi/CMakeLists.txt1
-rw-r--r--documentapi/src/vespa/binref/.gitignore3
-rw-r--r--documentapi/src/vespa/binref/CMakeLists.txt1
l---------jrt_test/src/binref/testrun.sh1
l---------lowercasing_test/src/binref/testrun.sh1
-rw-r--r--searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp67
-rw-r--r--searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp3
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java101
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java97
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java103
-rw-r--r--security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java8
-rw-r--r--storage/src/tests/bucketdb/bucketmanagertest.cpp61
-rw-r--r--storage/src/tests/common/global_bucket_space_distribution_converter_test.cpp66
-rw-r--r--storage/src/tests/distributor/bucketdbupdatertest.cpp56
-rw-r--r--storage/src/vespa/storage/bucketdb/bucketmanager.cpp17
-rw-r--r--storage/src/vespa/storage/common/global_bucket_space_distribution_converter.cpp49
-rw-r--r--storage/src/vespa/storage/common/global_bucket_space_distribution_converter.h5
-rw-r--r--storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.cpp3
-rw-r--r--storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.h10
-rw-r--r--storage/src/vespa/storage/distributor/pendingclusterstate.cpp31
-rw-r--r--vespa-athenz/pom.xml16
-rw-r--r--vespa-hadoop/abi-spec.json8
-rw-r--r--vespajlib/abi-spec.json8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java34
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java40
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java24
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java94
44 files changed, 1128 insertions, 179 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json
index 61390af3523..d4db3026b27 100644
--- a/document/abi-spec.json
+++ b/document/abi-spec.json
@@ -5244,7 +5244,7 @@
],
"methods": [
"public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue)",
- "public static com.yahoo.tensor.TensorType convertToCompatibleType(com.yahoo.tensor.TensorType)",
+ "public static com.yahoo.tensor.TensorType convertDimensionsToMapped(com.yahoo.tensor.TensorType)",
"public com.yahoo.document.update.TensorModifyUpdate$Operation getOperation()",
"public com.yahoo.document.datatypes.TensorFieldValue getValue()",
"public void setValue(com.yahoo.document.datatypes.TensorFieldValue)",
@@ -5278,6 +5278,7 @@
"public boolean equals(java.lang.Object)",
"public int hashCode()",
"public java.lang.String toString()",
+ "public static com.yahoo.tensor.TensorType extractSparseDimensions(com.yahoo.tensor.TensorType)",
"public bridge synthetic void setValue(com.yahoo.document.datatypes.FieldValue)",
"public bridge synthetic com.yahoo.document.datatypes.FieldValue getValue()"
],
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
index ffbfe49347c..6310fa62d15 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorAddUpdateReader.java
@@ -23,22 +23,23 @@ public class TensorAddUpdateReader {
public static TensorAddUpdate createTensorAddUpdate(TokenBuffer buffer, Field field) {
expectObjectStart(buffer.currentToken());
- expectTensorTypeIsSparse(field);
+ expectTensorTypeHasSparseDimensions(field);
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType tensorType = tensorDataType.getTensorType();
TensorFieldValue tensorFieldValue = new TensorFieldValue(tensorType);
fillTensor(buffer, tensorFieldValue);
+
expectTensorIsNonEmpty(field, tensorFieldValue.getTensor().get());
return new TensorAddUpdate(tensorFieldValue);
}
- private static void expectTensorTypeIsSparse(Field field) {
+ private static void expectTensorTypeHasSparseDimensions(Field field) {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- if (tensorType.dimensions().stream()
- .anyMatch(dim -> dim.isIndexed())) {
- throw new IllegalArgumentException("An add update can only be applied to sparse tensors. "
- + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
+ if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) {
+ throw new IllegalArgumentException("An add update can only be applied to tensors " +
+ "with at least one sparse dimension. Field '" + field.getName() +
+ "' has unsupported tensor type '" + tensorType + "'");
}
}
@@ -48,5 +49,4 @@ public class TensorAddUpdateReader {
}
}
-
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index a9bbba519bd..66588debbca 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -29,10 +29,8 @@ public class TensorModifyUpdateReader {
private static final String MODIFY_MULTIPLY = "multiply";
public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
-
expectFieldIsOfTypeTensor(field);
expectTensorTypeHasNoneIndexedUnboundDimensions(field);
- expectTensorTypeIsNotMixed(field);
expectObjectStart(buffer.currentToken());
ModifyUpdateResult result = createModifyUpdateResult(buffer, field);
@@ -58,16 +56,6 @@ public class TensorModifyUpdateReader {
}
}
- private static void expectTensorTypeIsNotMixed(Field field) {
- TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- long numMappedDimensions = tensorType.dimensions().stream().filter(dim -> dim.type().equals(TensorType.Dimension.Type.mapped)).count();
- long numIndexedDimensions = tensorType.dimensions().stream().filter(dim -> dim.isIndexed()).count();
- if (numMappedDimensions > 0 && numIndexedDimensions > 0) {
- throw new IllegalArgumentException("A modify update cannot be applied to tensor types with mixed dimensions. "
- + "Field '" + field.getName() + "' has mixed tensor type '" + tensorType + "'");
- }
- }
-
private static void expectOperationSpecified(TensorModifyUpdate.Operation operation, String fieldName) {
if (operation == null) {
throw new IllegalArgumentException("Modify update for field '" + fieldName + "' does not contain an operation");
@@ -121,7 +109,7 @@ public class TensorModifyUpdateReader {
private static TensorFieldValue createTensor(TokenBuffer buffer, Field field) {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
- TensorType convertedType = TensorModifyUpdate.convertToCompatibleType(originalType);
+ TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(originalType);
Tensor.Builder tensorBuilder = Tensor.Builder.of(convertedType);
readTensorCells(buffer, tensorBuilder);
@@ -129,25 +117,26 @@ public class TensorModifyUpdateReader {
validateBounds(tensor, originalType);
- TensorFieldValue result = new TensorFieldValue(convertedType);
- result.assign(tensor);
- return result;
+ return new TensorFieldValue(tensor);
}
- /** Only validate if original type is indexed bound */
- private static void validateBounds(Tensor convertedTensor, TensorType originalType) {
- if ( ! originalType.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
+ /** Only validate if original type has indexed bound dimensions */
+ static void validateBounds(Tensor convertedTensor, TensorType originalType) {
+ if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
return;
}
for (Iterator<Tensor.Cell> iter = convertedTensor.cellIterator(); iter.hasNext(); ) {
Tensor.Cell cell = iter.next();
TensorAddress address = cell.getKey();
for (int i = 0; i < address.size(); ++i) {
- long label = address.numericLabel(i);
- long bound = originalType.dimensions().get(i).size().get(); // size is non-optional for indexed bound
- if (label >= bound) {
- throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
- "' has label '" + label + "' but type is " + originalType.toString());
+ TensorType.Dimension dim = originalType.dimensions().get(i);
+ if (dim instanceof TensorType.IndexedBoundDimension) {
+ long label = address.numericLabel(i);
+ long bound = dim.size().get(); // size is non-optional for indexed bound
+ if (label >= bound) {
+ throw new IndexOutOfBoundsException("Dimension '" + originalType.dimensions().get(i).name() +
+ "' has label '" + label + "' but type is " + originalType.toString());
+ }
}
}
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
index 210a6a80ee5..3bb4b2e262f 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
@@ -24,23 +24,23 @@ public class TensorRemoveUpdateReader {
static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) {
expectObjectStart(buffer.currentToken());
- expectTensorTypeIsSparse(field);
+ expectTensorTypeHasSparseDimensions(field);
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
- TensorType tensorType = tensorDataType.getTensorType();
+ TensorType originalType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(originalType);
+ Tensor tensor = readRemoveUpdateTensor(buffer, convertedType, originalType);
- // TODO: for mixed case extract a new tensor type based only on mapped dimensions
-
- Tensor tensor = readRemoveUpdateTensor(buffer, tensorType);
expectAddressesAreNonEmpty(field, tensor);
return new TensorRemoveUpdate(new TensorFieldValue(tensor));
}
- private static void expectTensorTypeIsSparse(Field field) {
+ private static void expectTensorTypeHasSparseDimensions(Field field) {
TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
- if (tensorType.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed)) {
- throw new IllegalArgumentException("A remove update can only be applied to sparse tensors. "
- + "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
+ if (tensorType.dimensions().stream().allMatch(TensorType.Dimension::isIndexed)) {
+ throw new IllegalArgumentException("A remove update can only be applied to tensors " +
+ "with at least one sparse dimension. Field '" + field.getName() +
+ "' has unsupported tensor type '" + tensorType + "'");
}
}
@@ -53,7 +53,7 @@ public class TensorRemoveUpdateReader {
/**
* Reads all addresses in buffer and returns a tensor where addresses have cell value 1.0
*/
- private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type) {
+ private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType type, TensorType originalType) {
Tensor.Builder builder = Tensor.Builder.of(type);
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
@@ -62,7 +62,7 @@ public class TensorRemoveUpdateReader {
expectArrayStart(buffer.currentToken());
int nesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= nesting; buffer.next()) {
- builder.cell(readTensorAddress(buffer, type), 1.0);
+ builder.cell(readTensorAddress(buffer, type, originalType), 1.0);
}
expectCompositeEnd(buffer.currentToken());
}
@@ -71,12 +71,15 @@ public class TensorRemoveUpdateReader {
return builder.build();
}
- private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type) {
+ private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType type, TensorType originalType) {
TensorAddress.Builder builder = new TensorAddress.Builder(type);
expectObjectStart(buffer.currentToken());
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
String dimension = buffer.currentName();
+ if ( ! type.dimension(dimension).isPresent() && originalType.dimension(dimension).isPresent()) {
+ throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update");
+ }
String label = buffer.currentText();
builder.add(dimension, label);
}
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index 2f22def9aa1..a763db33e7a 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -5,6 +5,7 @@ import com.yahoo.document.DataType;
import com.yahoo.document.DocumentTypeManager;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
+import com.yahoo.document.json.readers.TensorRemoveUpdateReader;
import com.yahoo.document.update.TensorAddUpdate;
import com.yahoo.document.update.TensorModifyUpdate;
import com.yahoo.document.update.TensorRemoveUpdate;
@@ -35,7 +36,10 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
throw new DeserializationException("Expected tensor data type, got " + type);
}
TensorDataType tensorDataType = (TensorDataType)type;
- TensorFieldValue tensor = new TensorFieldValue(TensorModifyUpdate.convertToCompatibleType(tensorDataType.getTensorType()));
+ TensorType tensorType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType);
+
+ TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
return new TensorModifyUpdate(operation, tensor);
}
@@ -46,7 +50,8 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
throw new DeserializationException("Expected tensor data type, got " + type);
}
TensorDataType tensorDataType = (TensorDataType)type;
- TensorFieldValue tensor = new TensorFieldValue(tensorDataType.getTensorType());
+ TensorType tensorType = tensorDataType.getTensorType();
+ TensorFieldValue tensor = new TensorFieldValue(tensorType);
tensor.deserialize(this);
return new TensorAddUpdate(tensor);
}
@@ -58,10 +63,9 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
}
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
+ TensorType convertedType = TensorRemoveUpdate.extractSparseDimensions(tensorType);
- // TODO: for mixed case extract a new tensor type based only on mapped dimensions
-
- TensorFieldValue tensor = new TensorFieldValue(tensorType);
+ TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
return new TensorRemoveUpdate(tensor);
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
index cfc3ee0c742..f8d2464deb7 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
@@ -7,15 +7,11 @@ import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-import java.util.Map;
import java.util.Objects;
/**
- * An update used to add cells to a sparse tensor (has only mapped dimensions).
- *
- * The cells to add are contained in a sparse tensor as well.
+ * An update used to add cells to a sparse or mixed tensor (has at least one mapped dimension).
*/
public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
@@ -50,22 +46,10 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
return oldValue;
}
- Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get();
- Map<TensorAddress, Double> oldCells = oldTensor.cells();
- Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells();
-
- // currently, underlying implementation disallows multiple entries with the same key
-
- Tensor.Builder builder = Tensor.Builder.of(oldTensor.type());
- for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) {
- builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue()));
- }
- for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
- if ( ! oldCells.containsKey(addCell.getKey())) {
- builder.cell(addCell.getKey(), addCell.getValue());
- }
- }
- return new TensorFieldValue(builder.build());
+ Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
+ Tensor update = tensor.getTensor().get();
+ Tensor result = old.merge((left, right) -> right, update.cells()); // note this might be slow for large mixed tensor updates
+ return new TensorFieldValue(result);
}
@Override
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 6111b51ca4e..2773f9d31da 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -37,7 +37,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
/**
* Converts the given tensor type to a type that is compatible for being used in this update (has only mapped dimensions).
*/
- public static TensorType convertToCompatibleType(TensorType type) {
+ public static TensorType convertDimensionsToMapped(TensorType type) {
TensorType.Builder builder = new TensorType.Builder();
type.dimensions().stream().forEach(dim -> builder.mapped(dim.name()));
return builder.build();
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index e9fb1e3efd5..335cda8e133 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -7,10 +7,8 @@ import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
-import java.util.Iterator;
-import java.util.Map;
import java.util.Objects;
/**
@@ -25,6 +23,18 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
public TensorRemoveUpdate(TensorFieldValue value) {
super(ValueUpdateClassID.TENSORREMOVE);
this.tensor = value;
+ verifyCompatibleType();
+ }
+
+ private void verifyCompatibleType() {
+ if ( ! tensor.getTensor().isPresent()) {
+ throw new IllegalArgumentException("Tensor must be present in remove update");
+ }
+ TensorType tensorType = tensor.getTensor().get().type();
+ TensorType expectedType = extractSparseDimensions(tensor.getDataType().getTensorType());
+ if ( ! tensorType.equals(expectedType)) {
+ throw new IllegalArgumentException("Unexpected type '" + tensorType + "' in remove update. Expected is '" + expectedType + "'");
+ }
}
@Override
@@ -51,17 +61,10 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return oldValue;
}
- Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get();
- Map<TensorAddress, Double> cellsToRemove = tensor.getTensor().get().cells();
- Tensor.Builder builder = Tensor.Builder.of(oldTensor.type());
- for (Iterator<Tensor.Cell> i = oldTensor.cellIterator(); i.hasNext(); ) {
- Tensor.Cell cell = i.next();
- TensorAddress address = cell.getKey();
- if ( ! cellsToRemove.containsKey(address)) {
- builder.cell(address, cell.getValue());
- }
- }
- return new TensorFieldValue(builder.build());
+ Tensor old = ((TensorFieldValue) oldValue).getTensor().get();
+ Tensor update = tensor.getTensor().get();
+ Tensor result = old.remove(update.cells().keySet());
+ return new TensorFieldValue(result);
}
@Override
@@ -93,4 +96,11 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
return super.toString() + " " + tensor;
}
+ public static TensorType extractSparseDimensions(TensorType type) {
+ TensorType.Builder builder = new TensorType.Builder();
+ type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name()));
+ return builder.build();
+ }
+
+
}
diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
index e2736dabd2b..454ad72f344 100644
--- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
+++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java
@@ -40,6 +40,7 @@ public class DocumentUpdateJsonSerializerTest {
final static TensorType sparseTensorType = new TensorType.Builder().mapped("x").mapped("y").build();
final static TensorType denseTensorType = new TensorType.Builder().indexed("x", 2).indexed("y", 3).build();
+ final static TensorType mixedTensorType = new TensorType.Builder().mapped("x").indexed("y", 3).build();
final static DocumentTypeManager types = new DocumentTypeManager();
final static JsonFactory parserFactory = new JsonFactory();
final static DocumentType docType = new DocumentType("doctype");
@@ -60,6 +61,7 @@ public class DocumentUpdateJsonSerializerTest {
docType.addField(new Field("byte_field", DataType.BYTE));
docType.addField(new Field("sparse_tensor", new TensorDataType(sparseTensorType)));
docType.addField(new Field("dense_tensor", new TensorDataType(denseTensorType)));
+ docType.addField(new Field("mixed_tensor", new TensorDataType(mixedTensorType)));
docType.addField(new Field("reference_field", new ReferenceDataType(refTargetDocType, 777)));
docType.addField(new Field("predicate_field", DataType.PREDICATE));
docType.addField(new Field("raw_field", DataType.RAW));
@@ -336,6 +338,26 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_modify_update_on_mixed_tensor() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'modify': {",
+ " 'operation': 'multiply',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': 'c', 'y': '1' }, 'value': 3.0 }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void test_tensor_add_update() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
@@ -355,6 +377,29 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_add_update_mixed() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'add': {",
+ " 'cells': [",
+ " { 'address': { 'x': '1', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': '1', 'y': '1' }, 'value': 0.0 },",
+ " { 'address': { 'x': '1', 'y': '2' }, 'value': 0.0 },",
+ " { 'address': { 'x': '0', 'y': '0' }, 'value': 0.0 },",
+ " { 'address': { 'x': '0', 'y': '1' }, 'value': 0.0 },",
+ " { 'address': { 'x': '0', 'y': '2' }, 'value': 3.0 }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void test_tensor_remove_update() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
@@ -374,6 +419,24 @@ public class DocumentUpdateJsonSerializerTest {
}
@Test
+ public void test_tensor_remove_update_mixed() {
+ roundtripSerializeJsonAndMatch(inputJson(
+ "{",
+ " 'update': 'DOCUMENT_ID',",
+ " 'fields': {",
+ " 'mixed_tensor': {",
+ " 'remove': {",
+ " 'addresses': [",
+ " {'x':'0' }",
+ " ]",
+ " }",
+ " }",
+ " }",
+ "}"
+ ));
+ }
+
+ @Test
public void reference_field_id_can_be_update_assigned_non_empty_id() {
roundtripSerializeJsonAndMatch(inputJson(
"{",
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index e58b26d500d..15d1e859f73 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1387,12 +1387,30 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_modify_update_on_mixed_tensor_throws() {
- exception.expect(IllegalArgumentException.class);
- exception.expectMessage("A modify update cannot be applied to tensor types with mixed dimensions. Field 'mixed_tensor' has mixed tensor type 'tensor(x{},y[3])'");
- createTensorModifyUpdate(inputJson("{",
- " 'operation': 'replace',",
- " 'cells': [] }"), "mixed_tensor");
+ public void tensor_modify_update_with_replace_operation_mixed() {
+ assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.REPLACE, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'replace',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_modify_update_with_add_operation_mixed() {
+ assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.ADD, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'add',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_modify_update_with_multiply_operation_mixed() {
+ assertTensorModifyUpdate("{{x:a,y:0}:2.0}", TensorModifyUpdate.Operation.MULTIPLY, "mixed_tensor",
+ inputJson("{",
+ " 'operation': 'multiply',",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 } ]}"));
}
@Test
@@ -1406,6 +1424,17 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_modify_update_with_out_of_bound_cells_throws_mixed() {
+ exception.expect(IndexOutOfBoundsException.class);
+ exception.expectMessage("Dimension 'y' has label '3' but type is tensor(x{},y[3])");
+ createTensorModifyUpdate(inputJson("{",
+ " 'operation': 'replace',",
+ " 'cells': [",
+ " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor");
+ }
+
+
+ @Test
public void tensor_modify_update_with_unknown_operation_throws() {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Unknown operation 'unknown' in modify update for field 'sparse_tensor'");
@@ -1449,11 +1478,29 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_add_update_on_non_sparse_tensor_throws() {
+ public void tensor_add_update_on_mixed_tensor() {
+ assertTensorAddUpdate("{{x:a,y:0}:2.0, {x:a,y:1}:3.0, {x:a,y:2}:0.0}", "mixed_tensor",
+ inputJson("{",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': '0' }, 'value': 2.0 },",
+ " { 'address': { 'x': 'a', 'y': '1' }, 'value': 3.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_add_update_on_mixed_with_out_of_bound_dense_cells_throws() {
+ exception.expect(IndexOutOfBoundsException.class);
+ exception.expectMessage("Index 3 out of bounds for length 3");
+ createTensorAddUpdate(inputJson("{",
+ " 'cells': [",
+ " { 'address': { 'x': '0', 'y': '3' }, 'value': 2.0 } ]}"), "mixed_tensor");
+ }
+
+ @Test
+ public void tensor_add_update_on_dense_tensor_throws() {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("An add update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'");
+ exception.expectMessage("An add update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'");
createTensorAddUpdate(inputJson("{",
- " 'cells': [] }"), "mixed_tensor");
+ " 'cells': [] }"), "dense_tensor");
}
@Test
@@ -1470,6 +1517,7 @@ public class JsonReaderTestCase {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Add update for field 'sparse_tensor' does not contain tensor cells");
createTensorAddUpdate(inputJson("{}"), "sparse_tensor");
+ createTensorAddUpdate(inputJson("{}"), "mixed_tensor");
}
@Test
@@ -1482,11 +1530,30 @@ public class JsonReaderTestCase {
}
@Test
- public void tensor_remove_update_on_non_sparse_tensor_throws() {
+ public void tensor_remove_update_on_mixed_tensor() {
+ assertTensorRemoveUpdate("{{x:1}:1.0,{x:2}:1.0}", "mixed_tensor",
+ inputJson("{",
+ " 'addresses': [",
+ " { 'x': '1' },",
+ " { 'x': '2' } ]}"));
+ }
+
+ @Test
+ public void tensor_remove_update_on_mixed_tensor_with_dense_addresses_throws() {
+ exception.expect(IllegalArgumentException.class);
+ exception.expectMessage("Indexed dimension address 'y' should not be specified in remove update");
+ createTensorRemoveUpdate(inputJson("{",
+ " 'addresses': [",
+ " { 'x': '1', 'y': '0' },",
+ " { 'x': '2', 'y': '0' } ]}"), "mixed_tensor");
+ }
+
+ @Test
+ public void tensor_remove_update_on_dense_tensor_throws() {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("A remove update can only be applied to sparse tensors. Field 'mixed_tensor' has unsupported tensor type 'tensor(x{},y[3])'");
+ exception.expectMessage("A remove update can only be applied to tensors with at least one sparse dimension. Field 'dense_tensor' has unsupported tensor type 'tensor(x[2],y[3])'");
createTensorRemoveUpdate(inputJson("{",
- " 'addresses': [] }"), "mixed_tensor");
+ " 'addresses': [] }"), "dense_tensor");
}
@Test
@@ -1503,6 +1570,7 @@ public class JsonReaderTestCase {
exception.expect(IllegalArgumentException.class);
exception.expectMessage("Remove update for field 'sparse_tensor' does not contain tensor addresses");
createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "sparse_tensor");
+ createTensorRemoveUpdate(inputJson("{'addresses': [] }"), "mixed_tensor");
}
@Test
diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
index eb4001e6415..6935c54ba2a 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
@@ -12,18 +12,14 @@ public class TensorAddUpdateTest {
@Test
public void apply_add_update_operations() {
assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}");
- assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}");
}
+ // Builds tensors of type tensor(x{},y{}) from the given literals, applies an add update made from
+ // 'update' to the field value made from 'init', and asserts that the result equals 'expected'.
private void assertApplyTo(String init, String update, String expected) {
String spec = "tensor(x{},y{})";
TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init));
TensorAddUpdate addUpdate = new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update)));
- TensorFieldValue updatedFieldValue = (TensorFieldValue) addUpdate.applyTo(initialFieldValue);
- assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get());
+ // applyTo returns the updated field value; its tensor is unwrapped before comparing.
+ Tensor updated = ((TensorFieldValue) addUpdate.applyTo(initialFieldValue)).getTensor().get();
+ assertEquals(Tensor.from(spec, expected), updated);
}
}
diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
index 6e9444de2be..b885e6ddca0 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
@@ -1,12 +1,6 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.document.update;
-import com.yahoo.document.Document;
-import com.yahoo.document.DocumentId;
-import com.yahoo.document.DocumentType;
-import com.yahoo.document.DocumentTypeManager;
-import com.yahoo.document.Field;
-import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.update.TensorModifyUpdate.Operation;
import com.yahoo.tensor.Tensor;
@@ -28,10 +22,11 @@ public class TensorModifyUpdateTest {
assertConvertToCompatible("tensor(x{})", "tensor(x[10])");
assertConvertToCompatible("tensor(x{})", "tensor(x{})");
assertConvertToCompatible("tensor(x{},y{},z{})", "tensor(x[],y[10],z{})");
+ assertConvertToCompatible("tensor(x{},y{})", "tensor(x{},y[3])");
}
+ // Asserts that converting the type parsed from inputType yields a type whose string form equals expectedType.
private static void assertConvertToCompatible(String expectedType, String inputType) {
- assertEquals(expectedType, TensorModifyUpdate.convertToCompatibleType(TensorType.fromSpec(inputType)).toString());
+ assertEquals(expectedType, TensorModifyUpdate.convertDimensionsToMapped(TensorType.fromSpec(inputType)).toString());
}
@Test
@@ -46,15 +41,9 @@ public class TensorModifyUpdateTest {
public void apply_modify_update_operations() {
assertApplyTo("tensor(x{},y{})", Operation.REPLACE,
"{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}");
- assertApplyTo("tensor(x{},y{})", Operation.ADD,
- "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}");
- assertApplyTo("tensor(x{},y{})", Operation.MULTIPLY,
- "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}");
- assertApplyTo("tensor(x[1],y[2])", Operation.REPLACE,
- "{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:0}", "{{x:0,y:0}:1,{x:0,y:1}:0}");
assertApplyTo("tensor(x[1],y[2])", Operation.ADD,
"{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:5}");
- assertApplyTo("tensor(x[1],y[2])", Operation.MULTIPLY,
+ assertApplyTo("tensor(x{},y[2])", Operation.MULTIPLY,
"{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}");
}
diff --git a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
index 40ab00facdb..3a005e858c8 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorRemoveUpdateTest.java
@@ -12,9 +12,6 @@ public class TensorRemoveUpdateTest {
@Test
public void apply_remove_update_operations() {
assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{{x:0,y:1}:1}", "{{x:0,y:0}:2}");
- assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:0}:1,{x:0,y:1}:1}", "{}");
- assertApplyTo("{}", "{{x:0,y:0}:1}", "{}");
- assertApplyTo("{{x:0,y:0}:2, {x:0,y:1}:3}", "{}", "{{x:0,y:0}:2, {x:0,y:1}:3}");
}
private void assertApplyTo(String init, String update, String expected) {
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index c74c211756f..017d83893f0 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -922,6 +922,17 @@ TEST(DocumentUpdateTest, tensor_add_update_can_be_applied)
.add({{"x", "c"}}, 7));
}
+// Applies a TensorRemoveUpdate whose tensor holds the single address {x:b} to a tensor with cells
+// {x:a} and {x:b}. The arguments to assertApplyUpdate are presumably (initial tensor, update,
+// expected tensor) - confirm against TensorUpdateFixture.
+TEST(DocumentUpdateTest, tensor_remove_update_can_be_applied)
+{
+ TensorUpdateFixture f;
+ f.assertApplyUpdate(f.spec().add({{"x", "a"}}, 2)
+ .add({{"x", "b"}}, 3),
+
+ TensorRemoveUpdate(f.makeTensor(f.spec().add({{"x", "b"}}, 1))),
+
+ f.spec().add({{"x", "a"}}, 2));
+}
+
TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied)
{
TensorUpdateFixture f;
diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp
index 3e2bb86c66b..671bf260629 100644
--- a/document/src/vespa/document/update/tensor_remove_update.cpp
+++ b/document/src/vespa/document/update/tensor_remove_update.cpp
@@ -6,6 +6,8 @@
#include <vespa/document/fieldvalue/document.h>
#include <vespa/document/fieldvalue/tensorfieldvalue.h>
#include <vespa/document/serialization/vespadocumentdeserializer.h>
+#include <vespa/eval/tensor/cell_values.h>
+#include <vespa/eval/tensor/sparse/sparse_tensor.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/util/xmlstream.h>
@@ -77,17 +79,35 @@ TensorRemoveUpdate::checkCompatibility(const Field &field) const
std::unique_ptr<Tensor>
TensorRemoveUpdate::applyTo(const Tensor &tensor) const
{
- // TODO: implement
- (void) tensor;
+    // Without an address tensor the function falls through and returns an empty pointer.
+    const auto &addresses = _tensor->getAsTensorPtr();
+    if (addresses) {
+        const auto *sparseAddresses = dynamic_cast<const vespalib::tensor::SparseTensor *>(addresses.get());
+        if (sparseAddresses == nullptr) {
+            // Only a sparse address tensor can be wrapped as CellValues below.
+            throw IllegalArgumentException(make_string("Expected address tensor to be sparse, but has type '%s'",
+                                                       addresses->type().to_spec().c_str()));
+        }
+        // Wrap the sparse address tensor as CellValues and hand it to Tensor::remove.
+        vespalib::tensor::CellValues cellAddresses(*sparseAddresses);
+        return tensor.remove(cellAddresses);
+    }
return std::unique_ptr<Tensor>();
}
bool
TensorRemoveUpdate::applyTo(FieldValue &value) const
{
- // TODO: implement
- (void) value;
- return false;
+    // Applies this remove update in place to a tensor field value.
+    // Throws IllegalStateException when 'value' is not a TensorFieldValue; otherwise returns true.
+    if (value.inherits(TensorFieldValue::classId)) {
+        TensorFieldValue &tensorFieldValue = static_cast<TensorFieldValue &>(value);
+        auto &oldTensor = tensorFieldValue.getAsTensorPtr();
+        // getAsTensorPtr() may hand back an empty pointer (the Tensor overload checks for that on the
+        // address tensor as well). Dereferencing it would be undefined behaviour, and with no tensor
+        // present there is nothing to remove cells from.
+        if (oldTensor) {
+            auto newTensor = applyTo(*oldTensor);
+            // An empty result means the update produced no tensor; the field value is then left untouched.
+            if (newTensor) {
+                tensorFieldValue = std::move(newTensor);
+            }
+        }
+    } else {
+        std::string err = make_string("Unable to perform a tensor remove update on a '%s' field value.",
+                                      value.getClass().name());
+        throw IllegalStateException(err, VESPA_STRLOC);
+    }
+    return true;
}
void
diff --git a/document/src/vespa/document/update/valueupdate.h b/document/src/vespa/document/update/valueupdate.h
index 0e15943f8e4..6939d10ce2c 100644
--- a/document/src/vespa/document/update/valueupdate.h
+++ b/document/src/vespa/document/update/valueupdate.h
@@ -55,7 +55,8 @@ public:
Map = IDENTIFIABLE_CLASSID(MapValueUpdate),
Remove = IDENTIFIABLE_CLASSID(RemoveValueUpdate),
TensorModifyUpdate = IDENTIFIABLE_CLASSID(TensorModifyUpdate),
- TensorAddUpdate = IDENTIFIABLE_CLASSID(TensorAddUpdate)
+ TensorAddUpdate = IDENTIFIABLE_CLASSID(TensorAddUpdate),
+ TensorRemoveUpdate = IDENTIFIABLE_CLASSID(TensorRemoveUpdate)
};
ValueUpdate()
diff --git a/documentapi/CMakeLists.txt b/documentapi/CMakeLists.txt
index b03dd66c817..86d29732399 100644
--- a/documentapi/CMakeLists.txt
+++ b/documentapi/CMakeLists.txt
@@ -14,7 +14,6 @@ vespa_define_module(
vdslib
LIBS
- src/vespa/binref
src/vespa/documentapi
src/vespa/documentapi/loadtypes
src/vespa/documentapi/messagebus
diff --git a/documentapi/src/vespa/binref/.gitignore b/documentapi/src/vespa/binref/.gitignore
deleted file mode 100644
index cfb0e619824..00000000000
--- a/documentapi/src/vespa/binref/.gitignore
+++ /dev/null
@@ -1,3 +0,0 @@
-.depend
-Makefile
-testrun.sh
diff --git a/documentapi/src/vespa/binref/CMakeLists.txt b/documentapi/src/vespa/binref/CMakeLists.txt
deleted file mode 100644
index adece6dd711..00000000000
--- a/documentapi/src/vespa/binref/CMakeLists.txt
+++ /dev/null
@@ -1 +0,0 @@
-# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
diff --git a/jrt_test/src/binref/testrun.sh b/jrt_test/src/binref/testrun.sh
deleted file mode 120000
index 56c3c1186d8..00000000000
--- a/jrt_test/src/binref/testrun.sh
+++ /dev/null
@@ -1 +0,0 @@
-../../../vespalib/src/vespa/vespalib/testkit/testrun.sh \ No newline at end of file
diff --git a/lowercasing_test/src/binref/testrun.sh b/lowercasing_test/src/binref/testrun.sh
deleted file mode 120000
index 56c3c1186d8..00000000000
--- a/lowercasing_test/src/binref/testrun.sh
+++ /dev/null
@@ -1 +0,0 @@
-../../../vespalib/src/vespa/vespalib/testkit/testrun.sh \ No newline at end of file
diff --git a/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp b/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp
index afbb1c30f17..78cd9ce44b9 100644
--- a/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp
+++ b/searchcore/src/tests/proton/common/attribute_updater/attribute_updater_test.cpp
@@ -20,6 +20,7 @@
#include <vespa/document/update/removevalueupdate.h>
#include <vespa/document/update/tensor_add_update.h>
#include <vespa/document/update/tensor_modify_update.h>
+#include <vespa/document/update/tensor_remove_update.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/searchcore/proton/common/attribute_updater.h>
@@ -28,8 +29,8 @@
#include <vespa/searchlib/attribute/reference_attribute.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
#include <vespa/searchlib/tensor/generic_tensor_attribute.h>
-#include <vespa/vespalib/testkit/testapp.h>
#include <vespa/vespalib/stllike/hash_map.hpp>
+#include <vespa/vespalib/testkit/testapp.h>
#include <vespa/log/log.h>
LOG_SETUP("attribute_updater_test");
@@ -76,7 +77,8 @@ makeDocumentTypeRepo()
.addField("wsfloat", Wset(DataType::T_FLOAT))
.addField("wsstring", Wset(DataType::T_STRING))
.addField("ref", 333)
- .addField("dense_tensor", DataType::T_TENSOR),
+ .addField("dense_tensor", DataType::T_TENSOR)
+ .addField("sparse_tensor", DataType::T_TENSOR),
Struct("testdoc.body"))
.referenceType(333, 222);
return std::make_unique<DocumentTypeRepo>(builder.config());
@@ -416,35 +418,54 @@ makeTensorFieldValue(const TensorSpec &spec)
return result;
}
-void
-setTensor(TensorAttribute &attribute, uint32_t lid, const TensorSpec &spec)
-{
- auto tensor = makeTensor(spec);
- attribute.setTensor(lid, *tensor);
- attribute.commit();
-}
+template <typename TensorAttributeType>
+struct TensorFixture : public Fixture {
+ vespalib::string type;
+ std::unique_ptr<TensorAttributeType> attribute;
-TEST_F("require that tensor modify update is applied", Fixture)
-{
- vespalib::string type = "tensor(x[2])";
- auto attribute = makeTensorAttribute<DenseTensorAttribute>("dense_tensor", type);
- setTensor(*attribute, 1, TensorSpec(type).add({{"x", 0}}, 3).add({{"x", 1}}, 5));
+ TensorFixture(const vespalib::string &type_, const vespalib::string &name)
+ : type(type_),
+ attribute(makeTensorAttribute<TensorAttributeType>(name, type))
+ {
+ }
- f.applyValueUpdate(*attribute, 1,
+ void setTensor(const TensorSpec &spec) {
+ auto tensor = makeTensor(spec);
+ attribute->setTensor(1, *tensor);
+ attribute->commit();
+ }
+
+ void assertTensor(const TensorSpec &expSpec) {
+ EXPECT_EQUAL(expSpec, attribute->getTensor(1)->toSpec());
+ }
+};
+
+TEST_F("require that tensor modify update is applied",
+ TensorFixture<DenseTensorAttribute>("tensor(x[2])", "dense_tensor"))
+{
+ f.setTensor(TensorSpec(f.type).add({{"x", 0}}, 3).add({{"x", 1}}, 5));
+ f.applyValueUpdate(*f.attribute, 1,
TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE,
makeTensorFieldValue(TensorSpec("tensor(x{})").add({{"x", 0}}, 7))));
- EXPECT_EQUAL(TensorSpec(type).add({{"x", 0}}, 7).add({{"x", 1}}, 5), attribute->getTensor(1)->toSpec());
+ f.assertTensor(TensorSpec(f.type).add({{"x", 0}}, 7).add({{"x", 1}}, 5));
}
-TEST_F("require that tensor add update is applied", Fixture)
+TEST_F("require that tensor add update is applied",
+ TensorFixture<GenericTensorAttribute>("tensor(x{})", "sparse_tensor"))
{
- vespalib::string type = "tensor(x{})";
- auto attribute = makeTensorAttribute<GenericTensorAttribute>("dense_tensor", type);
- setTensor(*attribute, 1, TensorSpec(type).add({{"x", "a"}}, 2));
+ f.setTensor(TensorSpec(f.type).add({{"x", "a"}}, 2));
+ f.applyValueUpdate(*f.attribute, 1,
+ TensorAddUpdate(makeTensorFieldValue(TensorSpec(f.type).add({{"x", "a"}}, 3))));
+ f.assertTensor(TensorSpec(f.type).add({{"x", "a"}}, 3));
+}
- f.applyValueUpdate(*attribute, 1,
- TensorAddUpdate(makeTensorFieldValue(TensorSpec(type).add({{"x", "a"}}, 3))));
- EXPECT_EQUAL(TensorSpec(type).add({{"x", "a"}}, 3), attribute->getTensor(1)->toSpec());
+TEST_F("require that tensor remove update is applied",
+ TensorFixture<GenericTensorAttribute>("tensor(x{})", "sparse_tensor"))
+{
+ f.setTensor(TensorSpec(f.type).add({{"x", "a"}}, 2).add({{"x", "b"}}, 3));
+ f.applyValueUpdate(*f.attribute, 1,
+ TensorRemoveUpdate(makeTensorFieldValue(TensorSpec(f.type).add({{"x", "b"}}, 1))));
+ f.assertTensor(TensorSpec(f.type).add({{"x", "a"}}, 2));
}
}
diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
index 933857cffed..fcca1c2a737 100644
--- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
+++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp
@@ -16,6 +16,7 @@
#include <vespa/document/update/removevalueupdate.h>
#include <vespa/document/update/tensor_add_update.h>
#include <vespa/document/update/tensor_modify_update.h>
+#include <vespa/document/update/tensor_remove_update.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/searchlib/attribute/attributevector.hpp>
#include <vespa/searchlib/attribute/changevector.hpp>
@@ -238,6 +239,8 @@ AttributeUpdater::handleUpdate(TensorAttribute &vec, uint32_t lid, const ValueUp
applyTensorUpdate(vec, lid, static_cast<const TensorModifyUpdate &>(upd));
} else if (op == ValueUpdate::TensorAddUpdate) {
applyTensorUpdate(vec, lid, static_cast<const TensorAddUpdate &>(upd));
+ } else if (op == ValueUpdate::TensorRemoveUpdate) {
+ applyTensorUpdate(vec, lid, static_cast<const TensorRemoveUpdate &>(upd));
} else if (op == ValueUpdate::Clear) {
vec.clearDoc(lid);
} else {
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java
new file mode 100644
index 00000000000..2911b77707a
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClient.java
@@ -0,0 +1,101 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls.https;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLParameters;
+import java.io.IOException;
+import java.net.Authenticator;
+import java.net.CookieHandler;
+import java.net.ProxySelector;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.time.Duration;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executor;
+
+/**
+ * A {@link HttpClient} decorator: every outgoing request is wrapped in a {@link TlsAwareHttpRequest}
+ * (so that http or https is used based on the global Vespa TLS configuration) before being handed
+ * to the wrapped client. All configuration accessors simply answer with the wrapped client's values.
+ *
+ * @author bjorncs
+ */
+class TlsAwareHttpClient extends HttpClient {
+
+    private final HttpClient delegate;
+    private final String defaultUserAgent;
+
+    /**
+     * @param wrappedClient the client performing the actual HTTP exchange
+     * @param userAgent the user agent passed on to each {@link TlsAwareHttpRequest}
+     */
+    TlsAwareHttpClient(HttpClient wrappedClient, String userAgent) {
+        this.delegate = wrappedClient;
+        this.defaultUserAgent = userAgent;
+    }
+
+    // ---- Request dispatch: the request is wrapped first, then forwarded ----
+
+    @Override
+    public <T> HttpResponse<T> send(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler) throws IOException, InterruptedException {
+        HttpRequest tlsAwareRequest = toTlsAware(request);
+        return delegate.send(tlsAwareRequest, responseBodyHandler);
+    }
+
+    @Override
+    public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler) {
+        HttpRequest tlsAwareRequest = toTlsAware(request);
+        return delegate.sendAsync(tlsAwareRequest, responseBodyHandler);
+    }
+
+    @Override
+    public <T> CompletableFuture<HttpResponse<T>> sendAsync(HttpRequest request, HttpResponse.BodyHandler<T> responseBodyHandler, HttpResponse.PushPromiseHandler<T> pushPromiseHandler) {
+        HttpRequest tlsAwareRequest = toTlsAware(request);
+        return delegate.sendAsync(tlsAwareRequest, responseBodyHandler, pushPromiseHandler);
+    }
+
+    // ---- Plain delegation of client properties ----
+
+    @Override
+    public Optional<CookieHandler> cookieHandler() { return delegate.cookieHandler(); }
+
+    @Override
+    public Optional<Duration> connectTimeout() { return delegate.connectTimeout(); }
+
+    @Override
+    public Redirect followRedirects() { return delegate.followRedirects(); }
+
+    @Override
+    public Optional<ProxySelector> proxy() { return delegate.proxy(); }
+
+    @Override
+    public SSLContext sslContext() { return delegate.sslContext(); }
+
+    @Override
+    public SSLParameters sslParameters() { return delegate.sslParameters(); }
+
+    @Override
+    public Optional<Authenticator> authenticator() { return delegate.authenticator(); }
+
+    @Override
+    public Version version() { return delegate.version(); }
+
+    @Override
+    public Optional<Executor> executor() { return delegate.executor(); }
+
+    @Override
+    public String toString() { return delegate.toString(); }
+
+    /** Wraps the given request together with this client's user agent. */
+    private HttpRequest toTlsAware(HttpRequest request) {
+        return new TlsAwareHttpRequest(request, defaultUserAgent);
+    }
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java
new file mode 100644
index 00000000000..7eca2463ba7
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpClientBuilder.java
@@ -0,0 +1,97 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls.https;
+
+import com.yahoo.security.tls.TlsContext;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLParameters;
+import java.net.Authenticator;
+import java.net.CookieHandler;
+import java.net.ProxySelector;
+import java.net.http.HttpClient;
+import java.time.Duration;
+import java.util.concurrent.Executor;
+
+/**
+ * {@link HttpClient.Builder} whose SSL context and SSL parameters are taken from a {@link TlsContext}.
+ * Setters for cookie handler, proxy, authenticator, SSL context and SSL parameters are not supported
+ * and throw {@link UnsupportedOperationException}; the remaining setters are forwarded to a standard builder.
+ * Intended for internal Vespa communication only.
+ *
+ * @author bjorncs
+ */
+public class TlsAwareHttpClientBuilder implements HttpClient.Builder {
+
+    private final HttpClient.Builder delegate;
+    private final String userAgent;
+
+    /** Creates a builder using the user agent "vespa-tls-aware-client". */
+    public TlsAwareHttpClientBuilder(TlsContext tlsContext) {
+        this(tlsContext, "vespa-tls-aware-client");
+    }
+
+    /**
+     * @param tlsContext source of the SSL context and SSL parameters
+     * @param userAgent user agent handed to the client produced by {@link #build()}
+     */
+    public TlsAwareHttpClientBuilder(TlsContext tlsContext, String userAgent) {
+        this.delegate = HttpClient.newBuilder()
+                .sslContext(tlsContext.context())
+                .sslParameters(tlsContext.parameters());
+        this.userAgent = userAgent;
+    }
+
+    // ---- Settings forwarded to the wrapped builder ----
+
+    @Override
+    public HttpClient.Builder connectTimeout(Duration duration) {
+        delegate.connectTimeout(duration);
+        return this;
+    }
+
+    @Override
+    public HttpClient.Builder executor(Executor executor) {
+        delegate.executor(executor);
+        return this;
+    }
+
+    @Override
+    public HttpClient.Builder followRedirects(HttpClient.Redirect policy) {
+        delegate.followRedirects(policy);
+        return this;
+    }
+
+    @Override
+    public HttpClient.Builder version(HttpClient.Version version) {
+        delegate.version(version);
+        return this;
+    }
+
+    @Override
+    public HttpClient.Builder priority(int priority) {
+        delegate.priority(priority);
+        return this;
+    }
+
+    // ---- Settings that cannot be changed on this builder ----
+
+    @Override
+    public HttpClient.Builder sslContext(SSLContext sslContext) {
+        throw new UnsupportedOperationException("SSLContext is given from tls context");
+    }
+
+    @Override
+    public HttpClient.Builder sslParameters(SSLParameters sslParameters) {
+        throw new UnsupportedOperationException("SSLParameters is given from tls context");
+    }
+
+    @Override
+    public HttpClient.Builder cookieHandler(CookieHandler cookieHandler) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public HttpClient.Builder proxy(ProxySelector proxySelector) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public HttpClient.Builder authenticator(Authenticator authenticator) {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public HttpClient build() {
+        // TODO Stop wrapping the client once TLS is mandatory
+        HttpClient plainClient = delegate.build();
+        return new TlsAwareHttpClient(plainClient, userAgent);
+    }
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java
new file mode 100644
index 00000000000..bbdd8af791f
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/https/TlsAwareHttpRequest.java
@@ -0,0 +1,103 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.security.tls.https;
+
+import com.yahoo.security.tls.MixedMode;
+import com.yahoo.security.tls.TransportSecurityUtils;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.http.HttpClient;
+import java.net.http.HttpHeaders;
+import java.net.http.HttpRequest;
+import java.time.Duration;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * A {@link HttpRequest} where the scheme is either http or https based on the global Vespa TLS configuration.
+ * A User-Agent header is added when the wrapped request does not carry one; everything else is
+ * answered by the wrapped request.
+ *
+ * @author bjorncs
+ */
+class TlsAwareHttpRequest extends HttpRequest {
+
+    private final URI rewrittenUri;
+    private final HttpRequest wrappedRequest;
+    private final HttpHeaders rewrittenHeaders;
+
+    /**
+     * @param wrappedRequest the request to decorate
+     * @param userAgent User-Agent value used when the wrapped request has no such header
+     * @throws IllegalArgumentException if the request uses the http scheme but its URI has no host
+     */
+    TlsAwareHttpRequest(HttpRequest wrappedRequest, String userAgent) {
+        this.wrappedRequest = wrappedRequest;
+        this.rewrittenUri = rewriteUri(wrappedRequest.uri());
+        this.rewrittenHeaders = rewriteHeaders(wrappedRequest, userAgent);
+    }
+
+    @Override
+    public Optional<BodyPublisher> bodyPublisher() {
+        return wrappedRequest.bodyPublisher();
+    }
+
+    @Override
+    public String method() {
+        return wrappedRequest.method();
+    }
+
+    @Override
+    public Optional<Duration> timeout() {
+        return wrappedRequest.timeout();
+    }
+
+    @Override
+    public boolean expectContinue() {
+        return wrappedRequest.expectContinue();
+    }
+
+    @Override
+    public URI uri() {
+        return rewrittenUri;
+    }
+
+    @Override
+    public Optional<HttpClient.Version> version() {
+        return wrappedRequest.version();
+    }
+
+    @Override
+    public HttpHeaders headers() {
+        return rewrittenHeaders;
+    }
+
+    /**
+     * Returns the URI to actually use for the request. URIs with a scheme other than http are returned
+     * unchanged. For http URIs the scheme becomes https when a TLS config file is present and the mixed
+     * mode is not PLAINTEXT_CLIENT_MIXED_SERVER, and the port is always made explicit (the given port,
+     * otherwise 80 for http and 443 for https).
+     */
+    private static URI rewriteUri(URI uri) {
+        // URI schemes are case-insensitive, so "HTTP://..." must be rewritten too. Calling
+        // equalsIgnoreCase on the literal also tolerates a URI without a scheme.
+        if (!"http".equalsIgnoreCase(uri.getScheme())) {
+            return uri;
+        }
+        boolean useHttps = TransportSecurityUtils.getConfigFile().isPresent()
+                && TransportSecurityUtils.getInsecureMixedMode() != MixedMode.PLAINTEXT_CLIENT_MIXED_SERVER;
+        String rewrittenScheme = useHttps ? "https" : "http";
+        int port = uri.getPort();
+        int rewrittenPort = port != -1 ? port : (useHttps ? 443 : 80);
+        String host = uri.getHost();
+        if (host == null) {
+            throw new IllegalArgumentException("Cannot rewrite URI without a host: " + uri);
+        }
+        // Assemble from the raw (still percent-encoded) components. The multi-argument URI constructors
+        // take decoded components and re-quote only illegal characters, so an encoded delimiter such as
+        // %2F in the path or %26 in the query would come back as a literal '/' or '&'.
+        StringBuilder rewritten = new StringBuilder(rewrittenScheme).append("://");
+        if (uri.getRawUserInfo() != null) {
+            rewritten.append(uri.getRawUserInfo()).append('@');
+        }
+        rewritten.append(host).append(':').append(rewrittenPort);
+        if (uri.getRawPath() != null) {
+            rewritten.append(uri.getRawPath());
+        }
+        if (uri.getRawQuery() != null) {
+            rewritten.append('?').append(uri.getRawQuery());
+        }
+        if (uri.getRawFragment() != null) {
+            rewritten.append('#').append(uri.getRawFragment());
+        }
+        try {
+            return new URI(rewritten.toString());
+        } catch (URISyntaxException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Returns the request's own headers when it already has a User-Agent; otherwise a copy of them
+     * with User-Agent set to the given value.
+     */
+    private static HttpHeaders rewriteHeaders(HttpRequest request, String userAgent) {
+        HttpHeaders headers = request.headers();
+        if (headers.firstValue("User-Agent").isPresent()) {
+            return headers;
+        }
+        HashMap<String, List<String>> rewrittenHeaders = new HashMap<>(headers.map());
+        rewrittenHeaders.put("User-Agent", List.of(userAgent));
+        // The filter accepts every header name/value pair.
+        return HttpHeaders.of(rewrittenHeaders, (ignored1, ignored2) -> true);
+    }
+
+    @Override
+    public String toString() {
+        return "TlsAwareHttpRequest{" +
+                "rewrittenUri=" + rewrittenUri +
+                ", wrappedRequest=" + wrappedRequest +
+                '}';
+    }
+}
diff --git a/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java b/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java
new file mode 100644
index 00000000000..43067705fa3
--- /dev/null
+++ b/security-utils/src/main/java/com/yahoo/security/tls/https/package-info.java
@@ -0,0 +1,8 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+/**
+ * @author bjorncs
+ */
+@ExportPackage
+package com.yahoo.security.tls.https;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/storage/src/tests/bucketdb/bucketmanagertest.cpp b/storage/src/tests/bucketdb/bucketmanagertest.cpp
index 54b3bf4b8d0..09fe310e97e 100644
--- a/storage/src/tests/bucketdb/bucketmanagertest.cpp
+++ b/storage/src/tests/bucketdb/bucketmanagertest.cpp
@@ -8,6 +8,7 @@
#include <vespa/document/update/documentupdate.h>
#include <vespa/document/repo/documenttyperepo.h>
#include <vespa/storage/bucketdb/bucketmanager.h>
+#include <vespa/storage/common/global_bucket_space_distribution_converter.h>
#include <vespa/storage/persistence/filestorage/filestormanager.h>
#include <vespa/storageapi/message/persistence.h>
#include <vespa/storageapi/message/state.h>
@@ -84,6 +85,7 @@ public:
CPPUNIT_TEST(testConflictSetOnlyClearedAfterAllBucketRequestsDone);
CPPUNIT_TEST(testRejectRequestWithMismatchingDistributionHash);
CPPUNIT_TEST(testDbNotIteratedWhenAllRequestsRejected);
+ CPPUNIT_TEST(fall_back_to_legacy_global_distribution_hash_on_mismatch);
// FIXME(vekterli): test is not deterministic and enjoys failing
// sporadically when running under Valgrind. See bug 5932891.
@@ -154,6 +156,7 @@ public:
void testConflictSetOnlyClearedAfterAllBucketRequestsDone();
void testRejectRequestWithMismatchingDistributionHash();
void testDbNotIteratedWhenAllRequestsRejected();
+ void fall_back_to_legacy_global_distribution_hash_on_mismatch();
public:
static constexpr uint32_t DIR_SPREAD = 3;
@@ -785,6 +788,10 @@ public:
return std::make_shared<api::RequestBucketInfoCommand>(makeBucketSpace(), 0, _state, hash);
}
+    auto createFullFetchCommandWithHash(document::BucketSpace space, vespalib::stringref hash) const {
+        // Same as the single-space variant, but for a caller-chosen bucket space.
+        auto command = std::make_shared<api::RequestBucketInfoCommand>(space, 0, _state, hash);
+        return command;
+    }
+
auto acquireBucketLockAndSendInfoRequest(const document::BucketId& bucket) {
auto guard = acquireBucketLock(bucket);
// Send down processing command which will block.
@@ -850,6 +857,45 @@ public:
_self._top->getRepliesOnce();
}
+ // TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475
+ // Builds a distribution from the inline config below: redundancy 2, one top-level group and two
+ // sub-groups (rack0 with node indexes 0-2, rack1 with node indexes 3-5).
+ std::unique_ptr<lib::Distribution> default_grouped_distribution() {
+ return std::make_unique<lib::Distribution>(
+ GlobalBucketSpaceDistributionConverter::string_to_config(vespalib::string(
+R"(redundancy 2
+group[3]
+group[0].name "invalid"
+group[0].index "invalid"
+group[0].partitions 1|*
+group[0].nodes[0]
+group[1].name rack0
+group[1].index 0
+group[1].nodes[3]
+group[1].nodes[0].index 0
+group[1].nodes[1].index 1
+group[1].nodes[2].index 2
+group[2].name rack1
+group[2].index 1
+group[2].nodes[3]
+group[2].nodes[0].index 3
+group[2].nodes[1].index 4
+group[2].nodes[2].index 5
+)")));
+ }
+
+    std::shared_ptr<lib::Distribution> derived_global_grouped_distribution(bool use_legacy) {
+        // Convert the default grouped distribution to its global form; use_legacy is passed straight to the converter.
+        const auto source_distribution = default_grouped_distribution();
+        return GlobalBucketSpaceDistributionConverter::convert_to_global(*source_distribution, use_legacy);
+    }
+
+    void set_grouped_distribution_configs() {
+        // Install the grouped distribution on the default space and its non-legacy derivation on the global space.
+        auto&& space_repo = _self._node->getComponentRegister().getBucketSpaceRepo();
+
+        auto default_distr = default_grouped_distribution();
+        space_repo.get(document::FixedBucketSpaces::default_space()).setDistribution(std::move(default_distr));
+
+        auto global_distr = derived_global_grouped_distribution(false);
+        space_repo.get(document::FixedBucketSpaces::global_space()).setDistribution(std::move(global_distr));
+    }
+
private:
BucketManagerTest& _self;
lib::ClusterState _state;
@@ -1358,4 +1404,19 @@ BucketManagerTest::testDbNotIteratedWhenAllRequestsRejected()
auto replies = fixture.awaitAndGetReplies(1);
}
+// TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475
+void BucketManagerTest::fall_back_to_legacy_global_distribution_hash_on_mismatch() {
+    ConcurrentOperationFixture f(*this);
+    f.set_grouped_distribution_configs();
+
+    // The request carries the hash of the legacy-derived global distribution.
+    const auto legacy_distribution = f.derived_global_grouped_distribution(true);
+    auto legacy_hash = legacy_distribution->getNodeGraph().getDistributionConfigHash();
+    auto infoCmd = f.createFullFetchCommandWithHash(document::FixedBucketSpaces::global_space(), legacy_hash);
+    _top->sendDown(infoCmd);
+
+    auto replies = f.awaitAndGetReplies(1);
+    auto& reply = dynamic_cast<api::RequestBucketInfoReply&>(*replies[0]);
+    CPPUNIT_ASSERT_EQUAL(api::ReturnCode::OK, reply.getResult().getResult()); // _not_ REJECTED
+}
+
} // storage
diff --git a/storage/src/tests/common/global_bucket_space_distribution_converter_test.cpp b/storage/src/tests/common/global_bucket_space_distribution_converter_test.cpp
index 5afea9cd3cd..d75f2ac6459 100644
--- a/storage/src/tests/common/global_bucket_space_distribution_converter_test.cpp
+++ b/storage/src/tests/common/global_bucket_space_distribution_converter_test.cpp
@@ -17,6 +17,7 @@ struct GlobalBucketSpaceDistributionConverterTest : public CppUnit::TestFixture
CPPUNIT_TEST(config_retired_state_is_propagated);
CPPUNIT_TEST(group_capacities_are_propagated);
CPPUNIT_TEST(global_distribution_has_same_owner_distributors_as_default);
+ CPPUNIT_TEST(can_generate_config_with_legacy_partition_spec);
CPPUNIT_TEST_SUITE_END();
void can_transform_flat_cluster_config();
@@ -27,6 +28,7 @@ struct GlobalBucketSpaceDistributionConverterTest : public CppUnit::TestFixture
void config_retired_state_is_propagated();
void group_capacities_are_propagated();
void global_distribution_has_same_owner_distributors_as_default();
+ void can_generate_config_with_legacy_partition_spec();
};
CPPUNIT_TEST_SUITE_REGISTRATION(GlobalBucketSpaceDistributionConverterTest);
@@ -35,9 +37,9 @@ using DistributionConfig = vespa::config::content::StorDistributionConfig;
namespace {
-vespalib::string default_to_global_config(const vespalib::string& default_config) {
+// Parses default_config, converts it with convert_to_global (legacy_mode is passed through unchanged)
+// and returns the converted config rendered back to a string.
+vespalib::string default_to_global_config(const vespalib::string& default_config, bool legacy_mode = false) {
auto default_cfg = GlobalBucketSpaceDistributionConverter::string_to_config(default_config);
- auto as_global = GlobalBucketSpaceDistributionConverter::convert_to_global(*default_cfg);
+ auto as_global = GlobalBucketSpaceDistributionConverter::convert_to_global(*default_cfg, legacy_mode);
return GlobalBucketSpaceDistributionConverter::config_to_string(*as_global);
}
@@ -377,4 +379,64 @@ group[2].nodes[1].index 2
}
}
+// By "legacy" read "broken", but we need to be able to generate it to support rolling upgrades properly.
+// TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475
+void GlobalBucketSpaceDistributionConverterTest::can_generate_config_with_legacy_partition_spec() {
+ vespalib::string default_config( // input: root group "invalid" with two leaf groups (rack0, rack1) of three nodes each
+R"(redundancy 2
+group[3]
+group[0].name "invalid"
+group[0].index "invalid"
+group[0].partitions 1|*
+group[0].nodes[0]
+group[1].name rack0
+group[1].index 0
+group[1].nodes[3]
+group[1].nodes[0].index 0
+group[1].nodes[1].index 1
+group[1].nodes[2].index 2
+group[2].name rack1
+group[2].index 1
+group[2].nodes[3]
+group[2].nodes[0].index 3
+group[2].nodes[1].index 4
+group[2].nodes[2].index 5
+)");
+
+ vespalib::string expected_global_config( // expected output: the root group's partitions must be the legacy form "3|3|*"
+R"(redundancy 6
+initial_redundancy 0
+ensure_primary_persisted true
+ready_copies 6
+active_per_leaf_group true
+distributor_auto_ownership_transfer_on_whole_group_down true
+group[0].index "invalid"
+group[0].name "invalid"
+group[0].capacity 1
+group[0].partitions "3|3|*"
+group[1].index "0"
+group[1].name "rack0"
+group[1].capacity 1
+group[1].partitions ""
+group[1].nodes[0].index 0
+group[1].nodes[0].retired false
+group[1].nodes[1].index 1
+group[1].nodes[1].retired false
+group[1].nodes[2].index 2
+group[1].nodes[2].retired false
+group[2].index "1"
+group[2].name "rack1"
+group[2].capacity 1
+group[2].partitions ""
+group[2].nodes[0].index 3
+group[2].nodes[0].retired false
+group[2].nodes[1].index 4
+group[2].nodes[1].retired false
+group[2].nodes[2].index 5
+group[2].nodes[2].retired false
+disk_distribution MODULO_BID
+)");
+ CPPUNIT_ASSERT_EQUAL(expected_global_config, default_to_global_config(default_config, true)); // second argument true selects legacy_mode
+}
+
} \ No newline at end of file
diff --git a/storage/src/tests/distributor/bucketdbupdatertest.cpp b/storage/src/tests/distributor/bucketdbupdatertest.cpp
index 53f80854bef..b2d554c1e42 100644
--- a/storage/src/tests/distributor/bucketdbupdatertest.cpp
+++ b/storage/src/tests/distributor/bucketdbupdatertest.cpp
@@ -111,6 +111,7 @@ class BucketDBUpdaterTest : public CppUnit::TestFixture,
CPPUNIT_TEST(identity_update_of_diverging_untrusted_replicas_does_not_mark_any_as_trusted);
CPPUNIT_TEST(adding_diverging_replica_to_existing_trusted_does_not_remove_trusted);
CPPUNIT_TEST(batch_update_from_distributor_change_does_not_mark_diverging_replicas_as_trusted);
+ CPPUNIT_TEST(global_distribution_hash_falls_back_to_legacy_format_upon_request_rejection);
CPPUNIT_TEST_SUITE_END();
public:
@@ -175,6 +176,7 @@ protected:
void identity_update_of_diverging_untrusted_replicas_does_not_mark_any_as_trusted();
void adding_diverging_replica_to_existing_trusted_does_not_remove_trusted();
void batch_update_from_distributor_change_does_not_mark_diverging_replicas_as_trusted();
+ void global_distribution_hash_falls_back_to_legacy_format_upon_request_rejection();
auto &defaultDistributorBucketSpace() { return getBucketSpaceRepo().get(makeBucketSpace()); }
@@ -505,7 +507,7 @@ public:
std::make_shared<lib::Distribution>(distConfig));
}
- std::string getDistConfig6Nodes3Groups() const {
+ std::string getDistConfig6Nodes2Groups() const {
return ("redundancy 2\n"
"group[3]\n"
"group[0].name \"invalid\"\n"
@@ -692,7 +694,7 @@ BucketDBUpdaterTest::testDistributorChange()
void
BucketDBUpdaterTest::testDistributorChangeWithGrouping()
{
- std::string distConfig(getDistConfig6Nodes3Groups());
+ std::string distConfig(getDistConfig6Nodes2Groups());
setDistribution(distConfig);
int numBuckets = 100;
@@ -2073,7 +2075,7 @@ BucketDBUpdaterTest::testClusterStateAlwaysSendsFullFetchWhenDistributionChangeP
setAndEnableClusterState(stateBefore, expectedMsgs, dummyBucketsToReturn);
}
_sender.clear();
- std::string distConfig(getDistConfig6Nodes3Groups());
+ std::string distConfig(getDistConfig6Nodes2Groups());
setDistribution(distConfig);
sortSentMessagesByIndex(_sender);
CPPUNIT_ASSERT_EQUAL(messageCount(6), _sender.commands.size());
@@ -2549,4 +2551,52 @@ void BucketDBUpdaterTest::batch_update_from_distributor_change_does_not_mark_div
"0:5/1/2/3|1:5/7/8/9", true));
}
+// TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475
+void BucketDBUpdaterTest::global_distribution_hash_falls_back_to_legacy_format_upon_request_rejection() {
+ std::string distConfig(getDistConfig6Nodes2Groups());
+ setDistribution(distConfig);
+
+ const vespalib::string current_hash = "(0d*|*(0;0;1;2)(1;3;4;5))"; // hash expected on the first (non-rejected) request
+ const vespalib::string legacy_hash = "(0d3|3|*(0;0;1;2)(1;3;4;5))"; // differs from current_hash only in "3|3|*" vs "*|*"
+
+ setSystemState(lib::ClusterState("distributor:6 storage:6"));
+ CPPUNIT_ASSERT_EQUAL(messageCount(6), _sender.commands.size());
+
+ api::RequestBucketInfoCommand* global_req = nullptr; // will point at the first request targeting the global bucket space
+ for (auto& cmd : _sender.commands) {
+ auto& req_cmd = dynamic_cast<api::RequestBucketInfoCommand&>(*cmd); // throws std::bad_cast if a sent command has another type
+ if (req_cmd.getBucketSpace() == document::FixedBucketSpaces::global_space()) {
+ global_req = &req_cmd;
+ break; // one global-space request is enough for this test
+ }
+ }
+ CPPUNIT_ASSERT(global_req != nullptr);
+ CPPUNIT_ASSERT_EQUAL(current_hash, global_req->getDistributionHash());
+
+ auto reply = std::make_shared<api::RequestBucketInfoReply>(*global_req);
+ reply->setResult(api::ReturnCode::REJECTED); // first rejection for this node
+ getBucketDBUpdater().onRequestBucketInfoReply(reply);
+
+ getClock().addSecondsToTime(10); // NOTE(review): presumably moves the clock past the resend delay so the delayed request is due - confirm
+ getBucketDBUpdater().resendDelayedMessages();
+
+ // Should now be a resent request with legacy distribution hash
+ CPPUNIT_ASSERT_EQUAL(messageCount(6) + 1, _sender.commands.size()); // exactly one additional command since the initial batch
+ auto& legacy_req = dynamic_cast<api::RequestBucketInfoCommand&>(*_sender.commands.back());
+ CPPUNIT_ASSERT_EQUAL(legacy_hash, legacy_req.getDistributionHash());
+
+ // Now if we reject it _again_ we should cycle back to the current hash
+ // in case it wasn't a hash-based rejection after all. And the circle of life continues.
+ reply = std::make_shared<api::RequestBucketInfoReply>(legacy_req);
+ reply->setResult(api::ReturnCode::REJECTED); // second rejection for the same node
+ getBucketDBUpdater().onRequestBucketInfoReply(reply);
+
+ getClock().addSecondsToTime(10);
+ getBucketDBUpdater().resendDelayedMessages();
+
+ CPPUNIT_ASSERT_EQUAL(messageCount(6) + 2, _sender.commands.size());
+ auto& new_current_req = dynamic_cast<api::RequestBucketInfoCommand&>(*_sender.commands.back());
+ CPPUNIT_ASSERT_EQUAL(current_hash, new_current_req.getDistributionHash());
+}
+
}
diff --git a/storage/src/vespa/storage/bucketdb/bucketmanager.cpp b/storage/src/vespa/storage/bucketdb/bucketmanager.cpp
index 41de215d877..a1c1742edb5 100644
--- a/storage/src/vespa/storage/bucketdb/bucketmanager.cpp
+++ b/storage/src/vespa/storage/bucketdb/bucketmanager.cpp
@@ -6,6 +6,7 @@
#include <iomanip>
#include <vespa/storage/common/content_bucket_space_repo.h>
#include <vespa/storage/common/nodestateupdater.h>
+#include <vespa/storage/common/global_bucket_space_distribution_converter.h>
#include <vespa/vdslib/state/cluster_state_bundle.h>
#include <vespa/storage/storageutil/distributorstatecache.h>
#include <vespa/storageframework/generic/status/htmlstatusreporter.h>
@@ -577,7 +578,21 @@ BucketManager::processRequestBucketInfoCommands(document::BucketSpace bucketSpac
<< " differs from this state.";
} else if (!their_hash.empty() && their_hash != our_hash) {
// Empty hash indicates request from 4.2 protocol or earlier
- error << "Distribution config has changed since request.";
+ // TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475
+ bool matches_legacy_hash = false;
+ if (bucketSpace == document::FixedBucketSpaces::global_space()) {
+ const auto default_distr =_component.getBucketSpaceRepo()
+ .get(document::FixedBucketSpaces::default_space()).getDistribution();
+ // Convert in legacy distribution mode, which will accept old 'hash' structure.
+ const auto legacy_global_distr = GlobalBucketSpaceDistributionConverter::convert_to_global(
+ *default_distr, true/*use legacy mode*/);
+ const auto legacy_hash = legacy_global_distr->getNodeGraph().getDistributionConfigHash();
+ LOG(debug, "Falling back to comparing against legacy distribution hash: %s", legacy_hash.c_str());
+ matches_legacy_hash = (their_hash == legacy_hash);
+ }
+ if (!matches_legacy_hash) {
+ error << "Distribution config has changed since request.";
+ }
}
if (error.str().empty()) {
std::pair<std::set<uint16_t>::iterator, bool> result(
diff --git a/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.cpp b/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.cpp
index 534644458bc..cbcaeef8fdf 100644
--- a/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.cpp
+++ b/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.cpp
@@ -59,6 +59,21 @@ vespalib::string sub_groups_to_partition_spec(const Group& parent) {
return spec.str();
}
+// Allow generating legacy (broken) partition specs that may be used transiently
+// during rolling upgrades from a pre-fix version to a post-fix version.
+// TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475
+vespalib::string sub_groups_to_legacy_partition_spec(const Group& parent) {
+ vespalib::asciistream partitions;
+ // In case of a flat cluster config, this ends up with a partition spec of '*',
+ // which is fine. It basically means "put all replicas in this group", which
+ // happens to be exactly what we want.
+ for (auto& child : parent.sub_groups) { // one "<nested_leaf_count>|" entry per direct sub-group of parent
+ partitions << child.second->nested_leaf_count << '|';
+ }
+ partitions << '*'; // the spec always ends with '*', e.g. "3|3|*" for two sub-groups with nested_leaf_count 3
+ return partitions.str();
+}
+
bool is_leaf_group(const DistributionConfigBuilder::Group& g) noexcept {
return !g.nodes.empty();
}
@@ -87,19 +102,31 @@ void insert_new_group_into_tree(
void build_transformed_root_group(DistributionConfigBuilder& builder,
const DistributionConfigBuilder::Group& config_source_root,
- const Group& parsed_root) {
+ const Group& parsed_root,
+ bool legacy_mode) {
DistributionConfigBuilder::Group new_root(config_source_root);
- new_root.partitions = sub_groups_to_partition_spec(parsed_root);
+ if (!legacy_mode) {
+ new_root.partitions = sub_groups_to_partition_spec(parsed_root);
+ } else {
+ // TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475
+ new_root.partitions = sub_groups_to_legacy_partition_spec(parsed_root);
+ }
builder.group.emplace_back(std::move(new_root));
}
void build_transformed_non_root_group(DistributionConfigBuilder& builder,
const DistributionConfigBuilder::Group& config_source_group,
- const Group& parsed_root) {
+ const Group& parsed_root,
+ bool legacy_mode) {
DistributionConfigBuilder::Group new_group(config_source_group);
if (!is_leaf_group(config_source_group)) { // Partition specs only apply to inner nodes
const auto& g = find_non_root_group_by_index(config_source_group.index, parsed_root);
- new_group.partitions = sub_groups_to_partition_spec(g);
+ if (!legacy_mode) {
+ new_group.partitions = sub_groups_to_partition_spec(g);
+ } else {
+ // TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475
+ new_group.partitions = sub_groups_to_legacy_partition_spec(g);
+ }
}
builder.group.emplace_back(std::move(new_group));
}
@@ -135,16 +162,16 @@ std::unique_ptr<Group> create_group_tree_from_config(const DistributionConfig& s
* transitively, its parents again etc) have already been processed. This directly
* implies that the root group is always the first group present in the config.
*/
-void build_global_groups(DistributionConfigBuilder& builder, const DistributionConfig& source) {
+void build_global_groups(DistributionConfigBuilder& builder, const DistributionConfig& source, bool legacy_mode) {
assert(!source.group.empty()); // TODO gracefully handle empty config?
auto root = create_group_tree_from_config(source);
auto g_iter = source.group.begin();
const auto g_end = source.group.end();
- build_transformed_root_group(builder, *g_iter, *root);
+ build_transformed_root_group(builder, *g_iter, *root, legacy_mode);
++g_iter;
for (; g_iter != g_end; ++g_iter) {
- build_transformed_non_root_group(builder, *g_iter, *root);
+ build_transformed_non_root_group(builder, *g_iter, *root, legacy_mode);
}
builder.redundancy = root->nested_leaf_count;
@@ -154,17 +181,17 @@ void build_global_groups(DistributionConfigBuilder& builder, const DistributionC
} // anon ns
std::shared_ptr<DistributionConfig>
-GlobalBucketSpaceDistributionConverter::convert_to_global(const DistributionConfig& source) {
+GlobalBucketSpaceDistributionConverter::convert_to_global(const DistributionConfig& source, bool legacy_mode) {
DistributionConfigBuilder builder;
set_distribution_invariant_config_fields(builder, source);
- build_global_groups(builder, source);
+ build_global_groups(builder, source, legacy_mode);
return std::make_shared<DistributionConfig>(builder);
}
std::shared_ptr<lib::Distribution>
-GlobalBucketSpaceDistributionConverter::convert_to_global(const lib::Distribution& distr) {
+GlobalBucketSpaceDistributionConverter::convert_to_global(const lib::Distribution& distr, bool legacy_mode) {
const auto src_config = distr.serialize();
- auto global_config = convert_to_global(*string_to_config(src_config));
+ auto global_config = convert_to_global(*string_to_config(src_config), legacy_mode);
return std::make_shared<lib::Distribution>(*global_config);
}
diff --git a/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.h b/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.h
index d135f56a5c1..b2be65dad42 100644
--- a/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.h
+++ b/storage/src/vespa/storage/common/global_bucket_space_distribution_converter.h
@@ -10,8 +10,9 @@ namespace storage {
struct GlobalBucketSpaceDistributionConverter {
using DistributionConfig = vespa::config::content::StorDistributionConfig;
- static std::shared_ptr<DistributionConfig> convert_to_global(const DistributionConfig&);
- static std::shared_ptr<lib::Distribution> convert_to_global(const lib::Distribution&);
+ // TODO remove legacy_mode flags on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475
+ static std::shared_ptr<DistributionConfig> convert_to_global(const DistributionConfig&, bool legacy_mode = false);
+ static std::shared_ptr<lib::Distribution> convert_to_global(const lib::Distribution&, bool legacy_mode = false);
// Helper functions which may be of use outside this class
static std::unique_ptr<DistributionConfig> string_to_config(const vespalib::string&);
diff --git a/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.cpp b/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.cpp
index 2071558628e..c295be19a0b 100644
--- a/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.cpp
+++ b/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.cpp
@@ -35,7 +35,8 @@ PendingBucketSpaceDbTransition::PendingBucketSpaceDbTransition(const PendingClus
_pendingClusterState(pendingClusterState),
_distributorBucketSpace(distributorBucketSpace),
_distributorIndex(_clusterInfo->getDistributorIndex()),
- _bucketOwnershipTransfer(distributionChanged)
+ _bucketOwnershipTransfer(distributionChanged),
+ _rejectedRequests()
{
if (distributorChanged()) {
_bucketOwnershipTransfer = true;
diff --git a/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.h b/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.h
index 903f9b762fb..7eb2974eb52 100644
--- a/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.h
+++ b/storage/src/vespa/storage/distributor/pending_bucket_space_db_transition.h
@@ -4,6 +4,7 @@
#include "pending_bucket_space_db_transition_entry.h"
#include "outdated_nodes.h"
#include <vespa/storage/bucketdb/bucketdatabase.h>
+#include <unordered_map>
namespace storage::api { class RequestBucketInfoReply; }
namespace storage::lib { class ClusterState; class State; }
@@ -48,6 +49,7 @@ private:
DistributorBucketSpace &_distributorBucketSpace;
uint16_t _distributorIndex;
bool _bucketOwnershipTransfer;
+ std::unordered_map<uint16_t, size_t> _rejectedRequests;
// BucketDataBase::MutableEntryProcessor API
bool process(BucketDatabase::Entry& e) override;
@@ -111,6 +113,14 @@ public:
// Methods used by unit tests.
const EntryList& results() const { return _entries; }
void addNodeInfo(const document::BucketId& id, const BucketCopy& copy);
+
+ void incrementRequestRejections(uint16_t node) { // records one more rejected request for the given node
+ _rejectedRequests[node]++; // operator[] value-initializes a missing entry to 0 before the increment
+ }
+ size_t rejectedRequests(uint16_t node) const { // number of rejections recorded for the given node
+ auto iter = _rejectedRequests.find(node); // find() rather than operator[]: a const method must not insert
+ return ((iter != _rejectedRequests.end()) ? iter->second : 0); // nodes without a recorded rejection report 0
+ }
};
}
diff --git a/storage/src/vespa/storage/distributor/pendingclusterstate.cpp b/storage/src/vespa/storage/distributor/pendingclusterstate.cpp
index 1996ae9d2af..5f74a82c28a 100644
--- a/storage/src/vespa/storage/distributor/pendingclusterstate.cpp
+++ b/storage/src/vespa/storage/distributor/pendingclusterstate.cpp
@@ -7,6 +7,7 @@
#include "distributor_bucket_space.h"
#include <vespa/storageframework/defaultimplementation/clock/realclock.h>
#include <vespa/storage/common/bucketoperationlogger.h>
+#include <vespa/storage/common/global_bucket_space_distribution_converter.h>
#include <vespa/document/bucket/fixed_bucket_spaces.h>
#include <vespa/vespalib/util/xmlstream.hpp>
#include <climits>
@@ -185,7 +186,30 @@ PendingClusterState::requestNode(BucketSpaceAndNode bucketSpaceAndNode)
{
const auto &distributorBucketSpace(_bucketSpaceRepo.get(bucketSpaceAndNode.bucketSpace));
const auto &distribution(distributorBucketSpace.getDistribution());
- vespalib::string distributionHash(distribution.getNodeGraph().getDistributionConfigHash());
+ vespalib::string distributionHash;
+ // TODO remove on Vespa 8 - this is a workaround for https://github.com/vespa-engine/vespa/issues/8475
+ bool sendLegacyHash = false;
+ if (bucketSpaceAndNode.bucketSpace == document::FixedBucketSpaces::global_space()) {
+ auto transitionIter = _pendingTransitions.find(bucketSpaceAndNode.bucketSpace);
+ assert(transitionIter != _pendingTransitions.end());
+ // First request cannot have been rejected yet and will thus be sent with non-legacy hash.
+ // Subsequent requests will be sent 50-50. This is because a request may be rejected due to
+ // other reasons than just the hash mismatching, so if we don't cycle back to the non-legacy
+ // hash we risk never converging.
+ sendLegacyHash = ((transitionIter->second->rejectedRequests(bucketSpaceAndNode.node) % 2) == 1);
+ }
+ if (!sendLegacyHash) {
+ distributionHash = distribution.getNodeGraph().getDistributionConfigHash();
+ } else {
+ const auto& defaultSpace = _bucketSpaceRepo.get(document::FixedBucketSpaces::default_space());
+ // Generate legacy distribution hash explicitly.
+ auto legacyGlobalDistr = GlobalBucketSpaceDistributionConverter::convert_to_global(
+ defaultSpace.getDistribution(), true/*use legacy mode*/);
+ distributionHash = legacyGlobalDistr->getNodeGraph().getDistributionConfigHash();
+ LOG(debug, "Falling back to sending legacy hash to node %u: %s",
+ bucketSpaceAndNode.node, distributionHash.c_str());
+ }
+
LOG(debug,
"Requesting bucket info for bucket space %" PRIu64 " node %d with cluster state '%s' "
"and distribution hash '%s'",
@@ -238,6 +262,11 @@ PendingClusterState::onRequestBucketInfoReply(const std::shared_ptr<api::Request
resendTime += framework::MilliSecTime(100);
_delayedRequests.emplace_back(resendTime, bucketSpaceAndNode);
_sentMessages.erase(iter);
+ if (result.getResult() == api::ReturnCode::REJECTED) {
+ auto transitionIter = _pendingTransitions.find(bucketSpaceAndNode.bucketSpace);
+ assert(transitionIter != _pendingTransitions.end());
+ transitionIter->second->incrementRequestRejections(bucketSpaceAndNode.node);
+ }
return true;
}
diff --git a/vespa-athenz/pom.xml b/vespa-athenz/pom.xml
index f30aed1af5f..0f23eaed964 100644
--- a/vespa-athenz/pom.xml
+++ b/vespa-athenz/pom.xml
@@ -117,7 +117,21 @@
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk-core</artifactId>
- </dependency>
+ <exclusions>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-core</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-annotations</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
</dependencies>
<build>
diff --git a/vespa-hadoop/abi-spec.json b/vespa-hadoop/abi-spec.json
index 5bbac15f0e5..e3f4dcf272a 100644
--- a/vespa-hadoop/abi-spec.json
+++ b/vespa-hadoop/abi-spec.json
@@ -1201,6 +1201,8 @@
"public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1245,6 +1247,8 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)"
@@ -1330,6 +1334,8 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1432,6 +1438,8 @@
"public double asDouble()",
"public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)",
"public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)",
"public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])",
"public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)",
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 932513f8a57..c3fe8c5c7ad 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -808,6 +808,8 @@
"public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -852,6 +854,8 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)"
@@ -937,6 +941,8 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public com.yahoo.tensor.Tensor remove(java.util.Set)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1039,6 +1045,8 @@
"public double asDouble()",
"public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)",
"public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)",
"public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])",
"public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index fb55b2d5014..38d832d01c2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -13,6 +13,7 @@ import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;
+import java.util.function.DoubleBinaryOperator;
/**
* An indexed (dense) tensor backed by a double array.
@@ -190,6 +191,16 @@ public class IndexedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) { // not supported for indexed tensors: always throws, arguments are ignored
+ throw new IllegalArgumentException("Merge is not supported for indexed tensors"); // NOTE(review): UnsupportedOperationException is the conventional type for an unsupported operation; left unchanged because callers may catch this type
+ }
+
+ @Override
+ public Tensor remove(Set<TensorAddress> addresses) { // not supported for indexed tensors: always throws, argument is ignored
+ throw new IllegalArgumentException("Remove is not supported for indexed tensors"); // NOTE(review): UnsupportedOperationException is the conventional type for an unsupported operation; left unchanged because callers may catch this type
+ }
+
+ @Override
public int hashCode() { return Arrays.hashCode(values); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index ec3020a1a4e..22ceed22d3e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -5,6 +5,8 @@ import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.Map;
+import java.util.Set;
+import java.util.function.DoubleBinaryOperator;
/**
* A sparse implementation of a tensor backed by a Map of cells to values.
@@ -51,6 +53,38 @@ public class MappedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) {
+
+ // Each address must be given to the builder exactly once: currently, the underlying implementation disallows multiple entries with the same key
+
+ Tensor.Builder builder = Tensor.Builder.of(type()); // result has the same type as this tensor
+ for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { // pass 1: every cell already in this tensor
+ TensorAddress address = cell.getKey();
+ double value = cell.getValue();
+ builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); // overlapping cell: op(existing value, added value); otherwise keep the existing value
+ }
+ for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { // pass 2: cells that exist only in addCells
+ if ( ! cells.containsKey(addCell.getKey())) { // overlapping addresses were already written in pass 1
+ builder.cell(addCell.getKey(), addCell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
+ public Tensor remove(Set<TensorAddress> addresses) { // builds a new tensor; this tensor is not modified
+ Tensor.Builder builder = Tensor.Builder.of(type()); // result has the same type as this tensor
+ for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) {
+ Tensor.Cell cell = i.next();
+ TensorAddress address = cell.getKey();
+ if ( ! addresses.contains(address)) { // copy only the cells whose address is not listed for removal
+ builder.cell(address, cell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 17e33c58a13..08878edeb83 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -9,6 +9,8 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Set;
+import java.util.function.DoubleBinaryOperator;
import java.util.stream.Collectors;
/**
@@ -70,13 +72,17 @@ public class MixedTensor implements Tensor {
return cells.iterator();
}
+ private Iterable<Cell> cellIterable() { // adapts cellIterator() to Iterable so the cells can be used in a for-each loop
+ return this::cellIterator; // every iterator() call on the result obtains a new iterator from cellIterator()
+ }
+
/**
* Returns an iterator over the values of this tensor.
* The iteration order is the same as for cellIterator.
*/
@Override
public Iterator<Double> valueIterator() {
- return new Iterator<Double>() {
+ return new Iterator<>() {
Iterator<Cell> cellIterator = cellIterator();
@Override
public boolean hasNext() {
@@ -108,6 +114,38 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) {
+     // Builds a new tensor of the same type from this tensor and addCells:
+     //  - a cell present in both gets op(value in this tensor, value in addCells),
+     //  - a cell present only in this tensor is copied unchanged,
+     //  - a cell present only in addCells is added with its given value.
+     Tensor.Builder builder = Tensor.Builder.of(type());
+     // Addresses already written to the builder in the first pass. The class name is
+     // fully qualified so this method does not require an additional import.
+     Set<TensorAddress> existing = new java.util.HashSet<>();
+     for (Cell cell : cellIterable()) {
+         TensorAddress address = cell.getKey();
+         double value = cell.getValue();
+         builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value);
+         existing.add(address);
+     }
+     for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
+         // Skip overlapping addresses: they already hold the op result from the first
+         // pass, and writing them again would hand the builder the raw added value
+         // for the same address.
+         if ( ! existing.contains(addCell.getKey())) {
+             builder.cell(addCell.getKey(), addCell.getValue());
+         }
+     }
+     return builder.build();
+ }
+
+ @Override
+ public Tensor remove(Set<TensorAddress> addresses) { // builds a new tensor; this tensor is not modified
+ Tensor.Builder builder = Tensor.Builder.of(type());
+
+ // iterate through all sparse addresses referencing a dense subspace
+ for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) {
+ TensorAddress sparsePartialAddress = entry.getKey();
+ if ( ! addresses.contains(sparsePartialAddress)) { // assumption: addresses only contain the sparse part
+ long offset = entry.getValue(); // base position in cells used for this subspace below
+ for (int i = 0; i < index.denseSubspaceSize; ++i) { // copy denseSubspaceSize consecutive cells starting at offset
+ Cell cell = cells.get((int)offset + i); // NOTE(review): narrowing long -> int cast; assumes the offset fits in an int
+ builder.cell(cell.getKey(), cell.getValue());
+ }
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 8002990e5c6..eb16801c306 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -25,6 +25,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
@@ -113,6 +114,29 @@ public interface Tensor {
return builder.build();
}
+ /**
+ * Returns a new tensor where existing cells in this tensor have been
+ * modified according to the given operation and cells in the given map.
+ * In contrast to {@link #modify}, previously non-existing cells are added
+ * to the returned tensor. Only valid for sparse or mixed tensors.
+ *
+ * @param op how to update overlapping cells; applied as op(value in this tensor, value in cells)
+ * @param cells cells to merge with this tensor
+ * @return a new tensor where this tensor is merged with the given cells
+ */
+ Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells);
+
+ /**
+ * Returns a new tensor where existing cells in this tensor have been
+ * removed according to the given set of addresses. Only valid for sparse
+ * or mixed tensors. For mixed tensors, addresses are assumed to only
+ * contain the sparse dimensions, as the entire dense subspace is removed.
+ *
+ * @param addresses set of addresses to remove
+ * @return a new tensor where cells have been removed
+ */
+ Tensor remove(Set<TensorAddress> addresses);
+
// ----------------- Primitive tensor functions
default Tensor map(DoubleUnaryOperator mapper) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 2c9eefbd130..02d16e6f3e4 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -151,12 +151,106 @@ public class TensorTestCase {
Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
Tensor.from("tensor(x[1],y[3])", "{}"),
Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:0,{x:0,y:1}:0}"));
+ assertTensorModify((left, right) -> left * right,
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:6}"));
+ }
+
+ @Test
+ public void testTensorMerge() {
+ // Every case goes through assertTensorMerge, whose operator returns the update's value.
+ final String sparse = "tensor(x{},y{})";
+ final String mixed = "tensor(x{},y[3])";
+ final String twoCells = "{{x:0,y:0}:1,{x:0,y:1}:2}";
+ final String noCells = "{}";
+ final String singleCell = "{{x:0,y:0}:5}";
+
+ // Sparse: the update only contains a new cell
+ assertTensorMerge(Tensor.from(sparse, twoCells),
+ Tensor.from(sparse, "{{x:0,y:2}:3}"),
+ Tensor.from(sparse, "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"));
+ // Sparse: the update only contains an overlapping cell
+ assertTensorMerge(Tensor.from(sparse, twoCells),
+ Tensor.from(sparse, "{{x:0,y:1}:3}"),
+ Tensor.from(sparse, "{{x:0,y:0}:1,{x:0,y:1}:3}"));
+ // Sparse: the update contains both an overlapping and a new cell
+ assertTensorMerge(Tensor.from(sparse, twoCells),
+ Tensor.from(sparse, "{{x:0,y:1}:3,{x:0,y:2}:4}"),
+ Tensor.from(sparse, "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"));
+ // Sparse: empty initial tensor
+ assertTensorMerge(Tensor.from(sparse, noCells),
+ Tensor.from(sparse, singleCell),
+ Tensor.from(sparse, singleCell));
+ // Sparse: empty update
+ assertTensorMerge(Tensor.from(sparse, twoCells),
+ Tensor.from(sparse, noCells),
+ Tensor.from(sparse, twoCells));
+
+ // Mixed: unlike the sparse case, y is a dense dimension here with default value 0.0
+ assertTensorMerge(Tensor.from(mixed, twoCells),
+ Tensor.from(mixed, "{{x:0,y:2}:3}"),
+ Tensor.from(mixed, "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}"));
+ assertTensorMerge(Tensor.from(mixed, twoCells),
+ Tensor.from(mixed, "{{x:0,y:1}:3}"),
+ Tensor.from(mixed, "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}"));
+ assertTensorMerge(Tensor.from(mixed, twoCells),
+ Tensor.from(mixed, "{{x:0,y:1}:3,{x:0,y:2}:4}"),
+ Tensor.from(mixed, "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}"));
+ // Mixed: empty initial tensor
+ assertTensorMerge(Tensor.from(mixed, noCells),
+ Tensor.from(mixed, singleCell),
+ Tensor.from(mixed, singleCell));
+ // Mixed: empty update
+ assertTensorMerge(Tensor.from(mixed, twoCells),
+ Tensor.from(mixed, noCells),
+ Tensor.from(mixed, twoCells));
+ }
+
+ @Test
+ public void testTensorRemove() {
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:1}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:1}"),
+ Tensor.from("tensor(x{},y{})", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1}"),
+ Tensor.from("tensor(x{},y{})", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2, {x:0,y:1}:3}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"), // notice update is without dense dimension
+ Tensor.from("tensor(x{},y[3])", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:1,y:0}:2}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"),
+ Tensor.from("tensor(x{},y[3])", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{})", "{}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"));
}
private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) { // checks that init.modify(op, cells of update) equals expected
assertEquals(expected, init.modify(op, update.cells()));
}
+ private void assertTensorMerge(Tensor init, Tensor update, Tensor expected) {
+ // The operator always returns its second argument.
+ DoubleBinaryOperator pickSecond = (first, second) -> second;
+ Tensor merged = init.merge(pickSecond, update.cells());
+ assertEquals(expected, merged);
+ }
+
+ private void assertTensorRemove(Tensor init, Tensor update, Tensor expected) {
+ // Only the addresses of the update's cells are passed on; its cell values are ignored.
+ Tensor remaining = init.remove(update.cells().keySet());
+ assertEquals(expected, remaining);
+ }
+
+
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double sum = 0;
TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),