summaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
author Geir Storli <geirst@yahooinc.com> 2023-08-25 10:07:35 +0000
committer Geir Storli <geirst@yahooinc.com> 2023-08-25 10:07:35 +0000
commit 6fd175bcb3e1b61323214380d6f23324a3056043 (patch)
tree 6644e278e75daff81491a265ced5ccdc3cd4413a /document
parent 39420e6f2331825568605cfeb2975844de99de3a (diff)
Add "create non-existing cells" flag to TensorModifyUpdate.
When this is true, non-existing cells in the input tensor are created before applying the modify update. The default cell value is 0.0 for the REPLACE and ADD operations, and 1.0 for the MULTIPLY operation.
Diffstat (limited to 'document')
-rw-r--r--document/abi-spec.json6
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java18
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java19
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java14
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java25
-rw-r--r--document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java5
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java31
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java31
-rw-r--r--document/src/tests/data/serialize-tensor-update-cpp.datbin231 -> 348 bytes
-rw-r--r--document/src/tests/data/serialize-tensor-update-java.datbin231 -> 348 bytes
-rw-r--r--document/src/tests/documentupdatetestcase.cpp25
-rw-r--r--document/src/vespa/document/serialization/vespadocumentserializer.cpp20
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.cpp59
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.h5
14 files changed, 214 insertions, 44 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json
index 22c38337e90..899c107a242 100644
--- a/document/abi-spec.json
+++ b/document/abi-spec.json
@@ -3442,11 +3442,13 @@
],
"methods" : [
"public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue)",
+ "public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue, boolean)",
"public static com.yahoo.tensor.TensorType convertDimensionsToMapped(com.yahoo.tensor.TensorType)",
"public com.yahoo.document.update.TensorModifyUpdate$Operation getOperation()",
"public com.yahoo.document.datatypes.TensorFieldValue getValue()",
+ "public boolean getCreateNonExistingCells()",
+ "public double getDefaultCellValue()",
"public void setValue(com.yahoo.document.datatypes.TensorFieldValue)",
- "public void setDefaultCellValue(double)",
"public com.yahoo.document.datatypes.FieldValue applyTo(com.yahoo.document.datatypes.FieldValue)",
"protected void checkCompatibility(com.yahoo.document.DataType)",
"public void serialize(com.yahoo.document.serialization.DocumentUpdateWriter, com.yahoo.document.DataType)",
@@ -3459,7 +3461,7 @@
"fields" : [
"protected com.yahoo.document.update.TensorModifyUpdate$Operation operation",
"protected com.yahoo.document.datatypes.TensorFieldValue tensor",
- "protected java.util.Optional defaultCellValue"
+ "protected boolean createNonExistingCells"
]
},
"com.yahoo.document.update.TensorRemoveUpdate" : {
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index 21fa51d5b88..92ede0fbe99 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -2,6 +2,7 @@
package com.yahoo.document.json.readers;
+import com.fasterxml.jackson.core.JsonToken;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
@@ -29,6 +30,7 @@ public class TensorModifyUpdateReader {
private static final String MODIFY_REPLACE = "replace";
private static final String MODIFY_ADD = "add";
private static final String MODIFY_MULTIPLY = "multiply";
+ private static final String MODIFY_CREATE = "create";
public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
expectFieldIsOfTypeTensor(field);
@@ -39,7 +41,7 @@ public class TensorModifyUpdateReader {
expectOperationSpecified(result.operation, field.getName());
expectTensorSpecified(result.tensor, field.getName());
- return new TensorModifyUpdate(result.operation, result.tensor);
+ return new TensorModifyUpdate(result.operation, result.tensor, result.createNonExistingCells);
}
private static void expectFieldIsOfTypeTensor(Field field) {
@@ -73,6 +75,7 @@ public class TensorModifyUpdateReader {
private static class ModifyUpdateResult {
TensorModifyUpdate.Operation operation = null;
+ boolean createNonExistingCells = false;
TensorFieldValue tensor = null;
}
@@ -85,6 +88,9 @@ public class TensorModifyUpdateReader {
case MODIFY_OPERATION:
result.operation = createOperation(buffer, field.getName());
break;
+ case MODIFY_CREATE:
+ result.createNonExistingCells = createNonExistingCells(buffer, field.getName());
+ break;
case TENSOR_CELLS:
result.tensor = createTensorFromCells(buffer, field);
break;
@@ -112,6 +118,16 @@ public class TensorModifyUpdateReader {
}
}
+ private static Boolean createNonExistingCells(TokenBuffer buffer, String fieldName) {
+ if (buffer.current() == JsonToken.VALUE_TRUE) {
+ return true;
+ } else if (buffer.current() == JsonToken.VALUE_FALSE) {
+ return false;
+ } else {
+ throw new IllegalArgumentException("Unknown value '" + buffer.currentText() + "' for '" + MODIFY_CREATE + "' in modify update for field '" + fieldName + "'");
+ }
+ }
+
private static TensorFieldValue createTensorFromCells(TokenBuffer buffer, Field field) {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index c6fdc915401..765d999dbc9 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -26,20 +26,35 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
@Override
protected ValueUpdate readTensorModifyUpdate(DataType type) {
byte operationId = getByte(null);
- TensorModifyUpdate.Operation operation = TensorModifyUpdate.Operation.getOperation(operationId);
+ var operation = decodeOperation(operationId);
if (operation == null) {
throw new DeserializationException("Unknown operation id " + operationId + " for tensor modify update");
}
if (!(type instanceof TensorDataType)) {
throw new DeserializationException("Expected tensor data type, got " + type);
}
+ var createNonExistingCells = decodeCreateNonExistingCells(operationId);
+ if (createNonExistingCells) {
+ // Read the default cell value (but it is not used by TensorModifyUpdate).
+ getDouble(null);
+ }
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType);
TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
- return new TensorModifyUpdate(operation, tensor);
+ return new TensorModifyUpdate(operation, tensor, createNonExistingCells);
+ }
+
+ private TensorModifyUpdate.Operation decodeOperation(byte operationId) {
+ byte OP_MASK = 0b01111111;
+ return TensorModifyUpdate.Operation.getOperation(operationId & OP_MASK);
+ }
+
+ private boolean decodeCreateNonExistingCells(byte operationId) {
+ byte CREATE_FLAG = -0b10000000;
+ return (operationId & CREATE_FLAG) != 0;
}
@Override
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
index 66bc8cbb4d5..b2c3cdc09de 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
@@ -19,10 +19,22 @@ public class VespaDocumentSerializerHead extends VespaDocumentSerializer6 {
@Override
public void write(TensorModifyUpdate update) {
- putByte(null, (byte) update.getOperation().id);
+ putByte(null, (byte) encodeOperationId(update));
+ if (update.getCreateNonExistingCells()) {
+ putDouble(null, update.getDefaultCellValue());
+ }
update.getValue().serialize(this);
}
+ private int encodeOperationId(TensorModifyUpdate update) {
+ int operationId = update.getOperation().id;
+ byte CREATE_FLAG = -0b10000000;
+ if (update.getCreateNonExistingCells()) {
+ operationId |= CREATE_FLAG;
+ }
+ return operationId;
+ }
+
@Override
public void write(TensorAddUpdate update) {
update.getValue().serialize(this);
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 835c056868a..8a14bd21443 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -26,12 +26,17 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
protected Operation operation;
protected TensorFieldValue tensor;
- protected Optional<Double> defaultCellValue = Optional.empty();
+ protected boolean createNonExistingCells;
public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) {
+ this(operation, tensor, false);
+ }
+
+ public TensorModifyUpdate(Operation operation, TensorFieldValue tensor, boolean createNonExistingCells) {
super(ValueUpdateClassID.TENSORMODIFY);
this.operation = operation;
this.tensor = tensor;
+ this.createNonExistingCells = createNonExistingCells;
verifyCompatibleType(tensor.getDataType().getTensorType());
}
@@ -51,10 +56,12 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
}
public Operation getOperation() { return operation; }
-
public TensorFieldValue getValue() { return tensor; }
+ public boolean getCreateNonExistingCells() { return createNonExistingCells; }
+ public double getDefaultCellValue() {
+ return (operation == Operation.MULTIPLY) ? 1.0 : 0.0;
+ }
public void setValue(TensorFieldValue value) { tensor = value; }
- public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); }
@Override
public FieldValue applyTo(FieldValue oldValue) {
@@ -70,10 +77,10 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
default:
throw new UnsupportedOperationException("Unknown operation: " + operation);
}
- if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) {
+ if (createNonExistingCells && hasMappedSubtype(oldTensor.type())) {
var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get());
if (!subspaces.isEmpty()) {
- oldTensor = insertSubspaces(oldTensor, subspaces, defaultCellValue.get());
+ oldTensor = insertSubspaces(oldTensor, subspaces, getDefaultCellValue());
}
}
Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells());
@@ -142,7 +149,6 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
return builder.build();
}
-
@Override
protected void checkCompatibility(DataType fieldType) {
if (!(fieldType instanceof TensorDataType)) {
@@ -162,17 +168,18 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
if (!super.equals(o)) return false;
TensorModifyUpdate that = (TensorModifyUpdate) o;
return operation == that.operation &&
- tensor.equals(that.tensor);
+ tensor.equals(that.tensor) &&
+ createNonExistingCells == that.createNonExistingCells;
}
@Override
public int hashCode() {
- return Objects.hash(super.hashCode(), operation, tensor);
+ return Objects.hash(super.hashCode(), operation, tensor, createNonExistingCells);
}
@Override
public String toString() {
- return super.toString() + " " + operation.name + " " + tensor;
+ return super.toString() + " " + operation.name + " " + tensor + " " + createNonExistingCells;
}
/**
diff --git a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java
index 9733cd41a88..9d4d1e8f3aa 100644
--- a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java
+++ b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java
@@ -822,7 +822,10 @@ public class DocumentUpdateTestCase {
result.addFieldUpdate(FieldUpdate.create(getField("dense_tensor"))
.addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.REPLACE, createTensor()))
.addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.ADD, createTensor()))
- .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.MULTIPLY, createTensor())));
+ .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.MULTIPLY, createTensor()))
+ .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.REPLACE, createTensor(), true))
+ .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.ADD, createTensor(), true))
+ .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.MULTIPLY, createTensor(), true)));
return result;
}
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 96b5d2c1fb5..4140a9eee02 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1681,6 +1681,26 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_modify_update_with_create_non_existing_cells_true() {
+ assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.ADD, true, "sparse_tensor",
+ inputJson("{",
+ " 'operation': 'add',",
+ " 'create': true,",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_modify_update_with_create_non_existing_cells_false() {
+ assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.ADD, false, "sparse_tensor",
+ inputJson("{",
+ " 'operation': 'add',",
+ " 'create': false,",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}"));
+ }
+
+ @Test
public void tensor_modify_update_treats_the_input_tensor_as_sparse() {
// Note that the type of the tensor in the modify update is sparse (it only has mapped dimensions).
assertTensorModifyUpdate("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:1,y:2}:3.0}",
@@ -2155,16 +2175,25 @@ public class JsonReaderTestCase {
private void assertTensorModifyUpdate(String expectedTensor, TensorModifyUpdate.Operation expectedOperation,
String tensorFieldName, String modifyJson) {
- assertTensorModifyUpdate(expectedTensor, expectedOperation, tensorFieldName,
+ assertTensorModifyUpdate(expectedTensor, expectedOperation, false, tensorFieldName,
+ createTensorModifyUpdate(modifyJson, tensorFieldName));
+ }
+
+ private void assertTensorModifyUpdate(String expectedTensor, TensorModifyUpdate.Operation expectedOperation,
+ boolean expectedCreateNonExistingCells,
+ String tensorFieldName, String modifyJson) {
+ assertTensorModifyUpdate(expectedTensor, expectedOperation, expectedCreateNonExistingCells, tensorFieldName,
createTensorModifyUpdate(modifyJson, tensorFieldName));
}
private static void assertTensorModifyUpdate(String expectedTensor, TensorModifyUpdate.Operation expectedOperation,
+ boolean expectedCreateNonExistingCells,
String tensorFieldName, DocumentUpdate update) {
assertTensorFieldUpdate(update, tensorFieldName);
TensorModifyUpdate modifyUpdate = (TensorModifyUpdate) update.getFieldUpdate(tensorFieldName).getValueUpdate(0);
assertEquals(expectedOperation, modifyUpdate.getOperation());
assertEquals(Tensor.from(expectedTensor), modifyUpdate.getValue().getTensor().get());
+ assertEquals(expectedCreateNonExistingCells, modifyUpdate.getCreateNonExistingCells());
}
private DocumentUpdate createTensorModifyUpdate(String modifyJson, String tensorFieldName) {
diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
index 55b9090cce8..d0b04cf5449 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
@@ -51,35 +51,32 @@ public class TensorModifyUpdateTest {
@Test
public void apply_modify_update_operations_with_default_cell_value() {
- assertApplyTo("tensor(x{})", "tensor(x{})", Operation.ADD, Optional.of(0.0),
- "{{x:a}:1,{x:b}:2}", "{{x:b}:3}", "{{x:a}:1,{x:b}:5}");
+ assertApplyTo("tensor(x{})", "tensor(x{})", Operation.MULTIPLY, true,
+ "{{x:a}:1,{x:b}:2}", "{{x:b}:3}", "{{x:a}:1,{x:b}:6}");
- assertApplyTo("tensor(x{})", "tensor(x{})", Operation.ADD, Optional.of(0.0),
- "{{x:a}:1,{x:b}:2}", "{{x:b}:3,{x:c}:4}", "{{x:a}:1,{x:b}:5,{x:c}:4}");
+ assertApplyTo("tensor(x{})", "tensor(x{})", Operation.MULTIPLY, true,
+ "{{x:a}:1,{x:b}:2}", "{{x:b}:3,{x:c}:4}", "{{x:a}:1,{x:b}:6,{x:c}:4}");
- assertApplyTo("tensor(x{},y[3])", "tensor(x{},y{})", Operation.ADD, Optional.of(1.0),
+ assertApplyTo("tensor(x{},y[3])", "tensor(x{},y{})", Operation.ADD, true,
"{{x:a,y:0}:3,{x:a,y:1}:4,{x:a,y:2}:5}",
"{{x:a,y:0}:6,{x:b,y:1}:7,{x:b,y:2}:8,{x:c,y:0}:9}",
"{{x:a,y:0}:9,{x:a,y:1}:4,{x:a,y:2}:5," +
- "{x:b,y:0}:1,{x:b,y:1}:8,{x:b,y:2}:9," +
- "{x:c,y:0}:10,{x:c,y:1}:1,{x:c,y:2}:1}");
+ "{x:b,y:0}:0,{x:b,y:1}:7,{x:b,y:2}:8," +
+ "{x:c,y:0}:9,{x:c,y:1}:0,{x:c,y:2}:0}");
- // NOTE: The specified default cell value doesn't have any effect for tensors with only indexed dimensions,
- // as the dense subspace is always represented (with default cell value 0.0).
- assertApplyTo("tensor(x[3])", "tensor(x{})", Operation.ADD, Optional.of(2.0),
- "{{x:0}:2}", "{{x:1}:3}", "{{x:0}:2,{x:1}:3,{x:2}:0}");
+ // NOTE: The default cell value (1.0) used for MULTIPLY operation doesn't have any effect for tensors
+ // with only indexed dimensions, as the dense subspace is always represented (with default cell value 0.0).
+ assertApplyTo("tensor(x[3])", "tensor(x{})", Operation.MULTIPLY, true,
+ "{{x:0}:2}", "{{x:1}:3}", "{{x:0}:2,{x:1}:0,{x:2}:0}");
}
private void assertApplyTo(String spec, Operation op, String input, String update, String expected) {
- assertApplyTo(spec, "tensor(x{},y{})", op, Optional.empty(), input, update, expected);
+ assertApplyTo(spec, "tensor(x{},y{})", op, false, input, update, expected);
}
- private void assertApplyTo(String inputSpec, String updateSpec, Operation op, Optional<Double> defaultCellValue, String input, String update, String expected) {
+ private void assertApplyTo(String inputSpec, String updateSpec, Operation op, boolean createNonExistingCells, String input, String update, String expected) {
TensorFieldValue inputFieldValue = new TensorFieldValue(Tensor.from(inputSpec, input));
- TensorModifyUpdate modifyUpdate = new TensorModifyUpdate(op, new TensorFieldValue(Tensor.from(updateSpec, update)));
- if (defaultCellValue.isPresent()) {
- modifyUpdate.setDefaultCellValue(defaultCellValue.get());
- }
+ TensorModifyUpdate modifyUpdate = new TensorModifyUpdate(op, new TensorFieldValue(Tensor.from(updateSpec, update)), createNonExistingCells);
TensorFieldValue updatedFieldValue = (TensorFieldValue) modifyUpdate.applyTo(inputFieldValue);
assertEquals(Tensor.from(inputSpec, expected), updatedFieldValue.getTensor().get());
}
diff --git a/document/src/tests/data/serialize-tensor-update-cpp.dat b/document/src/tests/data/serialize-tensor-update-cpp.dat
index ad0e9d706b0..d6b6b5e2506 100644
--- a/document/src/tests/data/serialize-tensor-update-cpp.dat
+++ b/document/src/tests/data/serialize-tensor-update-cpp.dat
Binary files differ
diff --git a/document/src/tests/data/serialize-tensor-update-java.dat b/document/src/tests/data/serialize-tensor-update-java.dat
index ad0e9d706b0..d6b6b5e2506 100644
--- a/document/src/tests/data/serialize-tensor-update-java.dat
+++ b/document/src/tests/data/serialize-tensor-update-java.dat
Binary files differ
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index 3fbccaa155f..b225ca6677b 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -1024,6 +1024,20 @@ TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied)
.add({{"x", "b"}}, 15));
}
+TEST(DocumentUpdateTest, tensor_modify_update_with_create_non_existing_cells_can_be_applied)
+{
+ TensorUpdateFixture f;
+ auto baseLine = f.spec().add({{"x", "a"}}, 2)
+ .add({{"x", "b"}}, 3);
+
+ f.assertApplyUpdate(baseLine,
+ std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD,
+ f.makeTensor(f.spec().add({{"x", "b"}}, 5).add({{"x", "c"}}, 6)), 0.0),
+ f.spec().add({{"x", "a"}}, 2)
+ .add({{"x", "b"}}, 8)
+ .add({{"x", "c"}}, 6));
+}
+
TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied_to_nonexisting_tensor)
{
TensorUpdateFixture f;
@@ -1069,6 +1083,9 @@ TEST(DocumentUpdateTest, tensor_modify_update_can_be_roundtrip_serialized)
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor()));
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor()));
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor()));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor(), 0.0));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor(), 0.0));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor(), 1.0));
}
TEST(DocumentUpdateTest, tensor_modify_update_on_float_tensor_can_be_roundtrip_serialized)
@@ -1077,6 +1094,9 @@ TEST(DocumentUpdateTest, tensor_modify_update_on_float_tensor_can_be_roundtrip_s
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor()));
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor()));
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor()));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor(), 0.0));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor(), 0.0));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor(), 1.0));
}
TEST(DocumentUpdateTest, tensor_modify_update_on_dense_tensor_can_be_roundtrip_serialized)
@@ -1170,7 +1190,10 @@ struct TensorUpdateSerializeFixture {
result->addUpdate(FieldUpdate(getField("dense_tensor"))
.addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::REPLACE, makeTensor()))
.addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD, makeTensor()))
- .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::MULTIPLY, makeTensor())));
+ .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::MULTIPLY, makeTensor()))
+ .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::REPLACE, makeTensor(), 0.0))
+ .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD, makeTensor(), 0.0))
+ .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::MULTIPLY, makeTensor(), 1.0)));
return result;
}
diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.cpp b/document/src/vespa/document/serialization/vespadocumentserializer.cpp
index 2a2b642dd48..c0b56150c04 100644
--- a/document/src/vespa/document/serialization/vespadocumentserializer.cpp
+++ b/document/src/vespa/document/serialization/vespadocumentserializer.cpp
@@ -476,11 +476,29 @@ VespaDocumentSerializer::write(const RemoveFieldPathUpdate &value)
writeFieldPath(_stream, value);
}
+namespace {
+
+uint8_t
+encode_operation_id(const TensorModifyUpdate& update)
+{
+ uint8_t op = static_cast<uint8_t>(update.getOperation());
+ uint8_t CREATE_FLAG = 0b10000000;
+ if (update.get_default_cell_value().has_value()) {
+ op |= CREATE_FLAG;
+ }
+ return op;
+}
+
+}
+
void
VespaDocumentSerializer::write(const TensorModifyUpdate &value)
{
_stream << uint32_t(ValueUpdate::TensorModify);
- _stream << static_cast<uint8_t>(value.getOperation());
+ _stream << encode_operation_id(value);
+ if (value.get_default_cell_value().has_value()) {
+ _stream << value.get_default_cell_value().value();
+ }
write(value.getTensor());
}
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index 92b2a8672c3..ad1e3095269 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -87,7 +87,8 @@ TensorModifyUpdate::TensorModifyUpdate()
TensorUpdate(),
_operation(Operation::MAX_NUM_OPERATIONS),
_tensorType(),
- _tensor()
+ _tensor(),
+ _default_cell_value()
{
}
@@ -96,7 +97,19 @@ TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<Tens
TensorUpdate(),
_operation(operation),
_tensorType(std::make_unique<TensorDataType>(dynamic_cast<const TensorDataType &>(*tensor->getDataType()))),
- _tensor(static_cast<TensorFieldValue *>(_tensorType->createFieldValue().release()))
+ _tensor(static_cast<TensorFieldValue *>(_tensorType->createFieldValue().release())),
+ _default_cell_value()
+{
+ *_tensor = *tensor;
+}
+
+TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor, double default_cell_value)
+ : ValueUpdate(TensorModify),
+ TensorUpdate(),
+ _operation(operation),
+ _tensorType(std::make_unique<TensorDataType>(dynamic_cast<const TensorDataType &>(*tensor->getDataType()))),
+ _tensor(static_cast<TensorFieldValue *>(_tensorType->createFieldValue().release())),
+ _default_cell_value(default_cell_value)
{
*_tensor = *tensor;
}
@@ -116,6 +129,9 @@ TensorModifyUpdate::operator==(const ValueUpdate &other) const
if (*_tensor != *o._tensor) {
return false;
}
+ if (_default_cell_value != o._default_cell_value) {
+ return false;
+ }
return true;
}
@@ -141,7 +157,11 @@ TensorModifyUpdate::apply_to(const Value &old_tensor,
{
if (auto cellsTensor = _tensor->getAsTensorPtr()) {
auto op = getJoinFunction(_operation);
- return TensorPartialUpdate::modify(old_tensor, op, *cellsTensor, factory);
+ if (_default_cell_value.has_value()) {
+ return TensorPartialUpdate::modify_with_defaults(old_tensor, op, *cellsTensor, _default_cell_value.value(), factory);
+ } else {
+ return TensorPartialUpdate::modify(old_tensor, op, *cellsTensor, factory);
+ }
}
return {};
}
@@ -179,6 +199,9 @@ TensorModifyUpdate::print(std::ostream& out, bool verbose, const std::string& in
if (_tensor) {
_tensor->print(out, verbose, indent);
}
+ if (_default_cell_value.has_value()) {
+ out << "," << _default_cell_value.value();
+ }
out << ")";
}
@@ -198,6 +221,26 @@ verifyCellsTensorIsSparse(const vespalib::eval::Value *cellsTensor)
throw IllegalStateException(err, VESPA_STRLOC);
}
+TensorModifyUpdate::Operation
+decode_operation(uint8_t encoded_op)
+{
+ uint8_t OP_MASK = 0b01111111;
+ uint8_t op = encoded_op & OP_MASK;
+ if (op >= static_cast<uint8_t>(TensorModifyUpdate::Operation::MAX_NUM_OPERATIONS)) {
+ vespalib::asciistream msg;
+ msg << "Unrecognized tensor modify update operation " << static_cast<uint32_t>(op);
+ throw DeserializeException(msg.str(), VESPA_STRLOC);
+ }
+ return static_cast<TensorModifyUpdate::Operation>(op);
+}
+
+bool
+decode_create_non_existing_cells(uint8_t encoded_op)
+{
+ uint8_t CREATE_FLAG = 0b10000000;
+ return (encoded_op & CREATE_FLAG) != 0;
+}
+
}
void
@@ -205,12 +248,12 @@ TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty
{
uint8_t op;
stream >> op;
- if (op >= static_cast<uint8_t>(Operation::MAX_NUM_OPERATIONS)) {
- vespalib::asciistream msg;
- msg << "Unrecognized tensor modify update operation " << static_cast<uint32_t>(op);
- throw DeserializeException(msg.str(), VESPA_STRLOC);
+ _operation = decode_operation(op);
+ if (decode_create_non_existing_cells(op)) {
+ double value;
+ stream >> value;
+ _default_cell_value = value;
}
- _operation = static_cast<Operation>(op);
_tensorType = convertToCompatibleType(dynamic_cast<const TensorDataType &>(type));
auto tensor = _tensorType->createFieldValue();
if (tensor->isA(FieldValue::Type::TENSOR)) {
diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h
index 9386b4f8a1b..931d5102c4f 100644
--- a/document/src/vespa/document/update/tensor_modify_update.h
+++ b/document/src/vespa/document/update/tensor_modify_update.h
@@ -2,6 +2,7 @@
#include "tensor_update.h"
#include "valueupdate.h"
+#include <optional>
namespace vespalib::eval { struct Value; }
@@ -29,12 +30,15 @@ private:
Operation _operation;
std::unique_ptr<const TensorDataType> _tensorType;
std::unique_ptr<TensorFieldValue> _tensor;
+ // When this is set, non-existing cells are created in the input tensor before applying the update.
+ std::optional<double> _default_cell_value;
friend ValueUpdate;
TensorModifyUpdate();
ACCEPT_UPDATE_VISITOR;
public:
TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor);
+ TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor, double default_cell_value);
TensorModifyUpdate(const TensorModifyUpdate &rhs) = delete;
TensorModifyUpdate &operator=(const TensorModifyUpdate &rhs) = delete;
~TensorModifyUpdate() override;
@@ -42,6 +46,7 @@ public:
bool operator==(const ValueUpdate &other) const override;
Operation getOperation() const { return _operation; }
const TensorFieldValue &getTensor() const { return *_tensor; }
+ const std::optional<double>& get_default_cell_value() const { return _default_cell_value; }
void checkCompatibility(const Field &field) const override;
std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const;
std::unique_ptr<Value> apply_to(const Value &tensor,