summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--document/abi-spec.json6
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java18
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java19
-rw-r--r--document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java14
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java25
-rw-r--r--document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java5
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java31
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java31
-rw-r--r--document/src/tests/data/serialize-tensor-update-cpp.datbin231 -> 348 bytes
-rw-r--r--document/src/tests/data/serialize-tensor-update-java.datbin231 -> 348 bytes
-rw-r--r--document/src/tests/documentupdatetestcase.cpp25
-rw-r--r--document/src/vespa/document/serialization/vespadocumentserializer.cpp20
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.cpp59
-rw-r--r--document/src/vespa/document/update/tensor_modify_update.h5
14 files changed, 214 insertions, 44 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json
index 22c38337e90..899c107a242 100644
--- a/document/abi-spec.json
+++ b/document/abi-spec.json
@@ -3442,11 +3442,13 @@
],
"methods" : [
"public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue)",
+ "public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue, boolean)",
"public static com.yahoo.tensor.TensorType convertDimensionsToMapped(com.yahoo.tensor.TensorType)",
"public com.yahoo.document.update.TensorModifyUpdate$Operation getOperation()",
"public com.yahoo.document.datatypes.TensorFieldValue getValue()",
+ "public boolean getCreateNonExistingCells()",
+ "public double getDefaultCellValue()",
"public void setValue(com.yahoo.document.datatypes.TensorFieldValue)",
- "public void setDefaultCellValue(double)",
"public com.yahoo.document.datatypes.FieldValue applyTo(com.yahoo.document.datatypes.FieldValue)",
"protected void checkCompatibility(com.yahoo.document.DataType)",
"public void serialize(com.yahoo.document.serialization.DocumentUpdateWriter, com.yahoo.document.DataType)",
@@ -3459,7 +3461,7 @@
"fields" : [
"protected com.yahoo.document.update.TensorModifyUpdate$Operation operation",
"protected com.yahoo.document.datatypes.TensorFieldValue tensor",
- "protected java.util.Optional defaultCellValue"
+ "protected boolean createNonExistingCells"
]
},
"com.yahoo.document.update.TensorRemoveUpdate" : {
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
index 21fa51d5b88..92ede0fbe99 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
@@ -2,6 +2,7 @@
package com.yahoo.document.json.readers;
+import com.fasterxml.jackson.core.JsonToken;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
@@ -29,6 +30,7 @@ public class TensorModifyUpdateReader {
private static final String MODIFY_REPLACE = "replace";
private static final String MODIFY_ADD = "add";
private static final String MODIFY_MULTIPLY = "multiply";
+ private static final String MODIFY_CREATE = "create";
public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
expectFieldIsOfTypeTensor(field);
@@ -39,7 +41,7 @@ public class TensorModifyUpdateReader {
expectOperationSpecified(result.operation, field.getName());
expectTensorSpecified(result.tensor, field.getName());
- return new TensorModifyUpdate(result.operation, result.tensor);
+ return new TensorModifyUpdate(result.operation, result.tensor, result.createNonExistingCells);
}
private static void expectFieldIsOfTypeTensor(Field field) {
@@ -73,6 +75,7 @@ public class TensorModifyUpdateReader {
private static class ModifyUpdateResult {
TensorModifyUpdate.Operation operation = null;
+ boolean createNonExistingCells = false;
TensorFieldValue tensor = null;
}
@@ -85,6 +88,9 @@ public class TensorModifyUpdateReader {
case MODIFY_OPERATION:
result.operation = createOperation(buffer, field.getName());
break;
+ case MODIFY_CREATE:
+ result.createNonExistingCells = createNonExistingCells(buffer, field.getName());
+ break;
case TENSOR_CELLS:
result.tensor = createTensorFromCells(buffer, field);
break;
@@ -112,6 +118,16 @@ public class TensorModifyUpdateReader {
}
}
+ private static Boolean createNonExistingCells(TokenBuffer buffer, String fieldName) {
+ if (buffer.current() == JsonToken.VALUE_TRUE) {
+ return true;
+ } else if (buffer.current() == JsonToken.VALUE_FALSE) {
+ return false;
+ } else {
+ throw new IllegalArgumentException("Unknown value '" + buffer.currentText() + "' for '" + MODIFY_CREATE + "' in modify update for field '" + fieldName + "'");
+ }
+ }
+
private static TensorFieldValue createTensorFromCells(TokenBuffer buffer, Field field) {
TensorDataType tensorDataType = (TensorDataType)field.getDataType();
TensorType originalType = tensorDataType.getTensorType();
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
index c6fdc915401..765d999dbc9 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java
@@ -26,20 +26,35 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 {
@Override
protected ValueUpdate readTensorModifyUpdate(DataType type) {
byte operationId = getByte(null);
- TensorModifyUpdate.Operation operation = TensorModifyUpdate.Operation.getOperation(operationId);
+ var operation = decodeOperation(operationId);
if (operation == null) {
throw new DeserializationException("Unknown operation id " + operationId + " for tensor modify update");
}
if (!(type instanceof TensorDataType)) {
throw new DeserializationException("Expected tensor data type, got " + type);
}
+ var createNonExistingCells = decodeCreateNonExistingCells(operationId);
+ if (createNonExistingCells) {
+ // Read the default cell value (but it is not used by TensorModifyUpdate).
+ getDouble(null);
+ }
TensorDataType tensorDataType = (TensorDataType)type;
TensorType tensorType = tensorDataType.getTensorType();
TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType);
TensorFieldValue tensor = new TensorFieldValue(convertedType);
tensor.deserialize(this);
- return new TensorModifyUpdate(operation, tensor);
+ return new TensorModifyUpdate(operation, tensor, createNonExistingCells);
+ }
+
+ private TensorModifyUpdate.Operation decodeOperation(byte operationId) {
+ byte OP_MASK = 0b01111111;
+ return TensorModifyUpdate.Operation.getOperation(operationId & OP_MASK);
+ }
+
+ private boolean decodeCreateNonExistingCells(byte operationId) {
+ byte CREATE_FLAG = -0b10000000;
+ return (operationId & CREATE_FLAG) != 0;
}
@Override
diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
index 66bc8cbb4d5..b2c3cdc09de 100644
--- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
+++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java
@@ -19,10 +19,22 @@ public class VespaDocumentSerializerHead extends VespaDocumentSerializer6 {
@Override
public void write(TensorModifyUpdate update) {
- putByte(null, (byte) update.getOperation().id);
+ putByte(null, (byte) encodeOperationId(update));
+ if (update.getCreateNonExistingCells()) {
+ putDouble(null, update.getDefaultCellValue());
+ }
update.getValue().serialize(this);
}
+ private int encodeOperationId(TensorModifyUpdate update) {
+ int operationId = update.getOperation().id;
+ byte CREATE_FLAG = -0b10000000;
+ if (update.getCreateNonExistingCells()) {
+ operationId |= CREATE_FLAG;
+ }
+ return operationId;
+ }
+
@Override
public void write(TensorAddUpdate update) {
update.getValue().serialize(this);
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 835c056868a..8a14bd21443 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -26,12 +26,17 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
protected Operation operation;
protected TensorFieldValue tensor;
- protected Optional<Double> defaultCellValue = Optional.empty();
+ protected boolean createNonExistingCells;
public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) {
+ this(operation, tensor, false);
+ }
+
+ public TensorModifyUpdate(Operation operation, TensorFieldValue tensor, boolean createNonExistingCells) {
super(ValueUpdateClassID.TENSORMODIFY);
this.operation = operation;
this.tensor = tensor;
+ this.createNonExistingCells = createNonExistingCells;
verifyCompatibleType(tensor.getDataType().getTensorType());
}
@@ -51,10 +56,12 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
}
public Operation getOperation() { return operation; }
-
public TensorFieldValue getValue() { return tensor; }
+ public boolean getCreateNonExistingCells() { return createNonExistingCells; }
+ public double getDefaultCellValue() {
+ return (operation == Operation.MULTIPLY) ? 1.0 : 0.0;
+ }
public void setValue(TensorFieldValue value) { tensor = value; }
- public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); }
@Override
public FieldValue applyTo(FieldValue oldValue) {
@@ -70,10 +77,10 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
default:
throw new UnsupportedOperationException("Unknown operation: " + operation);
}
- if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) {
+ if (createNonExistingCells && hasMappedSubtype(oldTensor.type())) {
var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get());
if (!subspaces.isEmpty()) {
- oldTensor = insertSubspaces(oldTensor, subspaces, defaultCellValue.get());
+ oldTensor = insertSubspaces(oldTensor, subspaces, getDefaultCellValue());
}
}
Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells());
@@ -142,7 +149,6 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
return builder.build();
}
-
@Override
protected void checkCompatibility(DataType fieldType) {
if (!(fieldType instanceof TensorDataType)) {
@@ -162,17 +168,18 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
if (!super.equals(o)) return false;
TensorModifyUpdate that = (TensorModifyUpdate) o;
return operation == that.operation &&
- tensor.equals(that.tensor);
+ tensor.equals(that.tensor) &&
+ createNonExistingCells == that.createNonExistingCells;
}
@Override
public int hashCode() {
- return Objects.hash(super.hashCode(), operation, tensor);
+ return Objects.hash(super.hashCode(), operation, tensor, createNonExistingCells);
}
@Override
public String toString() {
- return super.toString() + " " + operation.name + " " + tensor;
+ return super.toString() + " " + operation.name + " " + tensor + " " + createNonExistingCells;
}
/**
diff --git a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java
index 9733cd41a88..9d4d1e8f3aa 100644
--- a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java
+++ b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java
@@ -822,7 +822,10 @@ public class DocumentUpdateTestCase {
result.addFieldUpdate(FieldUpdate.create(getField("dense_tensor"))
.addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.REPLACE, createTensor()))
.addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.ADD, createTensor()))
- .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.MULTIPLY, createTensor())));
+ .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.MULTIPLY, createTensor()))
+ .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.REPLACE, createTensor(), true))
+ .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.ADD, createTensor(), true))
+ .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.MULTIPLY, createTensor(), true)));
return result;
}
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 96b5d2c1fb5..4140a9eee02 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1681,6 +1681,26 @@ public class JsonReaderTestCase {
}
@Test
+ public void tensor_modify_update_with_create_non_existing_cells_true() {
+ assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.ADD, true, "sparse_tensor",
+ inputJson("{",
+ " 'operation': 'add',",
+ " 'create': true,",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}"));
+ }
+
+ @Test
+ public void tensor_modify_update_with_create_non_existing_cells_false() {
+ assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.ADD, false, "sparse_tensor",
+ inputJson("{",
+ " 'operation': 'add',",
+ " 'create': false,",
+ " 'cells': [",
+ " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}"));
+ }
+
+ @Test
public void tensor_modify_update_treats_the_input_tensor_as_sparse() {
// Note that the type of the tensor in the modify update is sparse (it only has mapped dimensions).
assertTensorModifyUpdate("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:1,y:2}:3.0}",
@@ -2155,16 +2175,25 @@ public class JsonReaderTestCase {
private void assertTensorModifyUpdate(String expectedTensor, TensorModifyUpdate.Operation expectedOperation,
String tensorFieldName, String modifyJson) {
- assertTensorModifyUpdate(expectedTensor, expectedOperation, tensorFieldName,
+ assertTensorModifyUpdate(expectedTensor, expectedOperation, false, tensorFieldName,
+ createTensorModifyUpdate(modifyJson, tensorFieldName));
+ }
+
+ private void assertTensorModifyUpdate(String expectedTensor, TensorModifyUpdate.Operation expectedOperation,
+ boolean expectedCreateNonExistingCells,
+ String tensorFieldName, String modifyJson) {
+ assertTensorModifyUpdate(expectedTensor, expectedOperation, expectedCreateNonExistingCells, tensorFieldName,
createTensorModifyUpdate(modifyJson, tensorFieldName));
}
private static void assertTensorModifyUpdate(String expectedTensor, TensorModifyUpdate.Operation expectedOperation,
+ boolean expectedCreateNonExistingCells,
String tensorFieldName, DocumentUpdate update) {
assertTensorFieldUpdate(update, tensorFieldName);
TensorModifyUpdate modifyUpdate = (TensorModifyUpdate) update.getFieldUpdate(tensorFieldName).getValueUpdate(0);
assertEquals(expectedOperation, modifyUpdate.getOperation());
assertEquals(Tensor.from(expectedTensor), modifyUpdate.getValue().getTensor().get());
+ assertEquals(expectedCreateNonExistingCells, modifyUpdate.getCreateNonExistingCells());
}
private DocumentUpdate createTensorModifyUpdate(String modifyJson, String tensorFieldName) {
diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
index 55b9090cce8..d0b04cf5449 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
@@ -51,35 +51,32 @@ public class TensorModifyUpdateTest {
@Test
public void apply_modify_update_operations_with_default_cell_value() {
- assertApplyTo("tensor(x{})", "tensor(x{})", Operation.ADD, Optional.of(0.0),
- "{{x:a}:1,{x:b}:2}", "{{x:b}:3}", "{{x:a}:1,{x:b}:5}");
+ assertApplyTo("tensor(x{})", "tensor(x{})", Operation.MULTIPLY, true,
+ "{{x:a}:1,{x:b}:2}", "{{x:b}:3}", "{{x:a}:1,{x:b}:6}");
- assertApplyTo("tensor(x{})", "tensor(x{})", Operation.ADD, Optional.of(0.0),
- "{{x:a}:1,{x:b}:2}", "{{x:b}:3,{x:c}:4}", "{{x:a}:1,{x:b}:5,{x:c}:4}");
+ assertApplyTo("tensor(x{})", "tensor(x{})", Operation.MULTIPLY, true,
+ "{{x:a}:1,{x:b}:2}", "{{x:b}:3,{x:c}:4}", "{{x:a}:1,{x:b}:6,{x:c}:4}");
- assertApplyTo("tensor(x{},y[3])", "tensor(x{},y{})", Operation.ADD, Optional.of(1.0),
+ assertApplyTo("tensor(x{},y[3])", "tensor(x{},y{})", Operation.ADD, true,
"{{x:a,y:0}:3,{x:a,y:1}:4,{x:a,y:2}:5}",
"{{x:a,y:0}:6,{x:b,y:1}:7,{x:b,y:2}:8,{x:c,y:0}:9}",
"{{x:a,y:0}:9,{x:a,y:1}:4,{x:a,y:2}:5," +
- "{x:b,y:0}:1,{x:b,y:1}:8,{x:b,y:2}:9," +
- "{x:c,y:0}:10,{x:c,y:1}:1,{x:c,y:2}:1}");
+ "{x:b,y:0}:0,{x:b,y:1}:7,{x:b,y:2}:8," +
+ "{x:c,y:0}:9,{x:c,y:1}:0,{x:c,y:2}:0}");
- // NOTE: The specified default cell value doesn't have any effect for tensors with only indexed dimensions,
- // as the dense subspace is always represented (with default cell value 0.0).
- assertApplyTo("tensor(x[3])", "tensor(x{})", Operation.ADD, Optional.of(2.0),
- "{{x:0}:2}", "{{x:1}:3}", "{{x:0}:2,{x:1}:3,{x:2}:0}");
+ // NOTE: The default cell value (1.0) used for MULTIPLY operation doesn't have any effect for tensors
+ // with only indexed dimensions, as the dense subspace is always represented (with default cell value 0.0).
+ assertApplyTo("tensor(x[3])", "tensor(x{})", Operation.MULTIPLY, true,
+ "{{x:0}:2}", "{{x:1}:3}", "{{x:0}:2,{x:1}:0,{x:2}:0}");
}
private void assertApplyTo(String spec, Operation op, String input, String update, String expected) {
- assertApplyTo(spec, "tensor(x{},y{})", op, Optional.empty(), input, update, expected);
+ assertApplyTo(spec, "tensor(x{},y{})", op, false, input, update, expected);
}
- private void assertApplyTo(String inputSpec, String updateSpec, Operation op, Optional<Double> defaultCellValue, String input, String update, String expected) {
+ private void assertApplyTo(String inputSpec, String updateSpec, Operation op, boolean createNonExistingCells, String input, String update, String expected) {
TensorFieldValue inputFieldValue = new TensorFieldValue(Tensor.from(inputSpec, input));
- TensorModifyUpdate modifyUpdate = new TensorModifyUpdate(op, new TensorFieldValue(Tensor.from(updateSpec, update)));
- if (defaultCellValue.isPresent()) {
- modifyUpdate.setDefaultCellValue(defaultCellValue.get());
- }
+ TensorModifyUpdate modifyUpdate = new TensorModifyUpdate(op, new TensorFieldValue(Tensor.from(updateSpec, update)), createNonExistingCells);
TensorFieldValue updatedFieldValue = (TensorFieldValue) modifyUpdate.applyTo(inputFieldValue);
assertEquals(Tensor.from(inputSpec, expected), updatedFieldValue.getTensor().get());
}
diff --git a/document/src/tests/data/serialize-tensor-update-cpp.dat b/document/src/tests/data/serialize-tensor-update-cpp.dat
index ad0e9d706b0..d6b6b5e2506 100644
--- a/document/src/tests/data/serialize-tensor-update-cpp.dat
+++ b/document/src/tests/data/serialize-tensor-update-cpp.dat
Binary files differ
diff --git a/document/src/tests/data/serialize-tensor-update-java.dat b/document/src/tests/data/serialize-tensor-update-java.dat
index ad0e9d706b0..d6b6b5e2506 100644
--- a/document/src/tests/data/serialize-tensor-update-java.dat
+++ b/document/src/tests/data/serialize-tensor-update-java.dat
Binary files differ
diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp
index 3fbccaa155f..b225ca6677b 100644
--- a/document/src/tests/documentupdatetestcase.cpp
+++ b/document/src/tests/documentupdatetestcase.cpp
@@ -1024,6 +1024,20 @@ TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied)
.add({{"x", "b"}}, 15));
}
+TEST(DocumentUpdateTest, tensor_modify_update_with_create_non_existing_cells_can_be_applied)
+{
+ TensorUpdateFixture f;
+ auto baseLine = f.spec().add({{"x", "a"}}, 2)
+ .add({{"x", "b"}}, 3);
+
+ f.assertApplyUpdate(baseLine,
+ std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD,
+ f.makeTensor(f.spec().add({{"x", "b"}}, 5).add({{"x", "c"}}, 6)), 0.0),
+ f.spec().add({{"x", "a"}}, 2)
+ .add({{"x", "b"}}, 8)
+ .add({{"x", "c"}}, 6));
+}
+
TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied_to_nonexisting_tensor)
{
TensorUpdateFixture f;
@@ -1069,6 +1083,9 @@ TEST(DocumentUpdateTest, tensor_modify_update_can_be_roundtrip_serialized)
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor()));
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor()));
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor()));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor(), 0.0));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor(), 0.0));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor(), 1.0));
}
TEST(DocumentUpdateTest, tensor_modify_update_on_float_tensor_can_be_roundtrip_serialized)
@@ -1077,6 +1094,9 @@ TEST(DocumentUpdateTest, tensor_modify_update_on_float_tensor_can_be_roundtrip_s
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor()));
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor()));
f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor()));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor(), 0.0));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor(), 0.0));
+ f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor(), 1.0));
}
TEST(DocumentUpdateTest, tensor_modify_update_on_dense_tensor_can_be_roundtrip_serialized)
@@ -1170,7 +1190,10 @@ struct TensorUpdateSerializeFixture {
result->addUpdate(FieldUpdate(getField("dense_tensor"))
.addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::REPLACE, makeTensor()))
.addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD, makeTensor()))
- .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::MULTIPLY, makeTensor())));
+ .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::MULTIPLY, makeTensor()))
+ .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::REPLACE, makeTensor(), 0.0))
+ .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD, makeTensor(), 0.0))
+ .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::MULTIPLY, makeTensor(), 1.0)));
return result;
}
diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.cpp b/document/src/vespa/document/serialization/vespadocumentserializer.cpp
index 2a2b642dd48..c0b56150c04 100644
--- a/document/src/vespa/document/serialization/vespadocumentserializer.cpp
+++ b/document/src/vespa/document/serialization/vespadocumentserializer.cpp
@@ -476,11 +476,29 @@ VespaDocumentSerializer::write(const RemoveFieldPathUpdate &value)
writeFieldPath(_stream, value);
}
+namespace {
+
+uint8_t
+encode_operation_id(const TensorModifyUpdate& update)
+{
+ uint8_t op = static_cast<uint8_t>(update.getOperation());
+ uint8_t CREATE_FLAG = 0b10000000;
+ if (update.get_default_cell_value().has_value()) {
+ op |= CREATE_FLAG;
+ }
+ return op;
+}
+
+}
+
void
VespaDocumentSerializer::write(const TensorModifyUpdate &value)
{
_stream << uint32_t(ValueUpdate::TensorModify);
- _stream << static_cast<uint8_t>(value.getOperation());
+ _stream << encode_operation_id(value);
+ if (value.get_default_cell_value().has_value()) {
+ _stream << value.get_default_cell_value().value();
+ }
write(value.getTensor());
}
diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp
index 92b2a8672c3..ad1e3095269 100644
--- a/document/src/vespa/document/update/tensor_modify_update.cpp
+++ b/document/src/vespa/document/update/tensor_modify_update.cpp
@@ -87,7 +87,8 @@ TensorModifyUpdate::TensorModifyUpdate()
TensorUpdate(),
_operation(Operation::MAX_NUM_OPERATIONS),
_tensorType(),
- _tensor()
+ _tensor(),
+ _default_cell_value()
{
}
@@ -96,7 +97,19 @@ TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<Tens
TensorUpdate(),
_operation(operation),
_tensorType(std::make_unique<TensorDataType>(dynamic_cast<const TensorDataType &>(*tensor->getDataType()))),
- _tensor(static_cast<TensorFieldValue *>(_tensorType->createFieldValue().release()))
+ _tensor(static_cast<TensorFieldValue *>(_tensorType->createFieldValue().release())),
+ _default_cell_value()
+{
+ *_tensor = *tensor;
+}
+
+TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor, double default_cell_value)
+ : ValueUpdate(TensorModify),
+ TensorUpdate(),
+ _operation(operation),
+ _tensorType(std::make_unique<TensorDataType>(dynamic_cast<const TensorDataType &>(*tensor->getDataType()))),
+ _tensor(static_cast<TensorFieldValue *>(_tensorType->createFieldValue().release())),
+ _default_cell_value(default_cell_value)
{
*_tensor = *tensor;
}
@@ -116,6 +129,9 @@ TensorModifyUpdate::operator==(const ValueUpdate &other) const
if (*_tensor != *o._tensor) {
return false;
}
+ if (_default_cell_value != o._default_cell_value) {
+ return false;
+ }
return true;
}
@@ -141,7 +157,11 @@ TensorModifyUpdate::apply_to(const Value &old_tensor,
{
if (auto cellsTensor = _tensor->getAsTensorPtr()) {
auto op = getJoinFunction(_operation);
- return TensorPartialUpdate::modify(old_tensor, op, *cellsTensor, factory);
+ if (_default_cell_value.has_value()) {
+ return TensorPartialUpdate::modify_with_defaults(old_tensor, op, *cellsTensor, _default_cell_value.value(), factory);
+ } else {
+ return TensorPartialUpdate::modify(old_tensor, op, *cellsTensor, factory);
+ }
}
return {};
}
@@ -179,6 +199,9 @@ TensorModifyUpdate::print(std::ostream& out, bool verbose, const std::string& in
if (_tensor) {
_tensor->print(out, verbose, indent);
}
+ if (_default_cell_value.has_value()) {
+ out << "," << _default_cell_value.value();
+ }
out << ")";
}
@@ -198,6 +221,26 @@ verifyCellsTensorIsSparse(const vespalib::eval::Value *cellsTensor)
throw IllegalStateException(err, VESPA_STRLOC);
}
+TensorModifyUpdate::Operation
+decode_operation(uint8_t encoded_op)
+{
+ uint8_t OP_MASK = 0b01111111;
+ uint8_t op = encoded_op & OP_MASK;
+ if (op >= static_cast<uint8_t>(TensorModifyUpdate::Operation::MAX_NUM_OPERATIONS)) {
+ vespalib::asciistream msg;
+ msg << "Unrecognized tensor modify update operation " << static_cast<uint32_t>(op);
+ throw DeserializeException(msg.str(), VESPA_STRLOC);
+ }
+ return static_cast<TensorModifyUpdate::Operation>(op);
+}
+
+bool
+decode_create_non_existing_cells(uint8_t encoded_op)
+{
+ uint8_t CREATE_FLAG = 0b10000000;
+ return (encoded_op & CREATE_FLAG) != 0;
+}
+
}
void
@@ -205,12 +248,12 @@ TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty
{
uint8_t op;
stream >> op;
- if (op >= static_cast<uint8_t>(Operation::MAX_NUM_OPERATIONS)) {
- vespalib::asciistream msg;
- msg << "Unrecognized tensor modify update operation " << static_cast<uint32_t>(op);
- throw DeserializeException(msg.str(), VESPA_STRLOC);
+ _operation = decode_operation(op);
+ if (decode_create_non_existing_cells(op)) {
+ double value;
+ stream >> value;
+ _default_cell_value = value;
}
- _operation = static_cast<Operation>(op);
_tensorType = convertToCompatibleType(dynamic_cast<const TensorDataType &>(type));
auto tensor = _tensorType->createFieldValue();
if (tensor->isA(FieldValue::Type::TENSOR)) {
diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h
index 9386b4f8a1b..931d5102c4f 100644
--- a/document/src/vespa/document/update/tensor_modify_update.h
+++ b/document/src/vespa/document/update/tensor_modify_update.h
@@ -2,6 +2,7 @@
#include "tensor_update.h"
#include "valueupdate.h"
+#include <optional>
namespace vespalib::eval { struct Value; }
@@ -29,12 +30,15 @@ private:
Operation _operation;
std::unique_ptr<const TensorDataType> _tensorType;
std::unique_ptr<TensorFieldValue> _tensor;
+ // When this is set, non-existing cells are created in the input tensor before applying the update.
+ std::optional<double> _default_cell_value;
friend ValueUpdate;
TensorModifyUpdate();
ACCEPT_UPDATE_VISITOR;
public:
TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor);
+ TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor, double default_cell_value);
TensorModifyUpdate(const TensorModifyUpdate &rhs) = delete;
TensorModifyUpdate &operator=(const TensorModifyUpdate &rhs) = delete;
~TensorModifyUpdate() override;
@@ -42,6 +46,7 @@ public:
bool operator==(const ValueUpdate &other) const override;
Operation getOperation() const { return _operation; }
const TensorFieldValue &getTensor() const { return *_tensor; }
+ const std::optional<double>& get_default_cell_value() const { return _default_cell_value; }
void checkCompatibility(const Field &field) const override;
std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const;
std::unique_ptr<Value> apply_to(const Value &tensor,