diff options
Diffstat (limited to 'document')
14 files changed, 214 insertions, 44 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json index 22c38337e90..899c107a242 100644 --- a/document/abi-spec.json +++ b/document/abi-spec.json @@ -3442,11 +3442,13 @@ ], "methods" : [ "public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue)", + "public void <init>(com.yahoo.document.update.TensorModifyUpdate$Operation, com.yahoo.document.datatypes.TensorFieldValue, boolean)", "public static com.yahoo.tensor.TensorType convertDimensionsToMapped(com.yahoo.tensor.TensorType)", "public com.yahoo.document.update.TensorModifyUpdate$Operation getOperation()", "public com.yahoo.document.datatypes.TensorFieldValue getValue()", + "public boolean getCreateNonExistingCells()", + "public double getDefaultCellValue()", "public void setValue(com.yahoo.document.datatypes.TensorFieldValue)", - "public void setDefaultCellValue(double)", "public com.yahoo.document.datatypes.FieldValue applyTo(com.yahoo.document.datatypes.FieldValue)", "protected void checkCompatibility(com.yahoo.document.DataType)", "public void serialize(com.yahoo.document.serialization.DocumentUpdateWriter, com.yahoo.document.DataType)", @@ -3459,7 +3461,7 @@ "fields" : [ "protected com.yahoo.document.update.TensorModifyUpdate$Operation operation", "protected com.yahoo.document.datatypes.TensorFieldValue tensor", - "protected java.util.Optional defaultCellValue" + "protected boolean createNonExistingCells" ] }, "com.yahoo.document.update.TensorRemoveUpdate" : { diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java index 21fa51d5b88..92ede0fbe99 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java @@ -2,6 +2,7 @@ package com.yahoo.document.json.readers; +import com.fasterxml.jackson.core.JsonToken; import com.yahoo.document.Field; import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; @@ -29,6 +30,7 @@ public class TensorModifyUpdateReader { private static final String MODIFY_REPLACE = "replace"; private static final String MODIFY_ADD = "add"; private static final String MODIFY_MULTIPLY = "multiply"; + private static final String MODIFY_CREATE = "create"; public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) { expectFieldIsOfTypeTensor(field); @@ -39,7 +41,7 @@ public class TensorModifyUpdateReader { expectOperationSpecified(result.operation, field.getName()); expectTensorSpecified(result.tensor, field.getName()); - return new TensorModifyUpdate(result.operation, result.tensor); + return new TensorModifyUpdate(result.operation, result.tensor, result.createNonExistingCells); } private static void expectFieldIsOfTypeTensor(Field field) { @@ -73,6 +75,7 @@ public class TensorModifyUpdateReader { private static class ModifyUpdateResult { TensorModifyUpdate.Operation operation = null; + boolean createNonExistingCells = false; TensorFieldValue tensor = null; } @@ -85,6 +88,9 @@ public class TensorModifyUpdateReader { case MODIFY_OPERATION: result.operation = createOperation(buffer, field.getName()); break; + case MODIFY_CREATE: + result.createNonExistingCells = createNonExistingCells(buffer, field.getName()); + break; case TENSOR_CELLS: result.tensor = createTensorFromCells(buffer, field); break; @@ -112,6 +118,16 @@ public class TensorModifyUpdateReader { } } + private static Boolean createNonExistingCells(TokenBuffer buffer, String fieldName) { + if (buffer.current() == JsonToken.VALUE_TRUE) { + return true; + } else if (buffer.current() == JsonToken.VALUE_FALSE) { + return false; + } else { + throw new IllegalArgumentException("Unknown value '" + buffer.currentText() + "' for '" + MODIFY_CREATE + "' in modify update for field '" + fieldName + "'"); + } + } + private static TensorFieldValue createTensorFromCells(TokenBuffer buffer, Field field) { TensorDataType tensorDataType = (TensorDataType)field.getDataType(); TensorType originalType = tensorDataType.getTensorType(); diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java index c6fdc915401..765d999dbc9 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializerHead.java @@ -26,20 +26,35 @@ public class VespaDocumentDeserializerHead extends VespaDocumentDeserializer6 { @Override protected ValueUpdate readTensorModifyUpdate(DataType type) { byte operationId = getByte(null); - TensorModifyUpdate.Operation operation = TensorModifyUpdate.Operation.getOperation(operationId); + var operation = decodeOperation(operationId); if (operation == null) { throw new DeserializationException("Unknown operation id " + operationId + " for tensor modify update"); } if (!(type instanceof TensorDataType)) { throw new DeserializationException("Expected tensor data type, got " + type); } + var createNonExistingCells = decodeCreateNonExistingCells(operationId); + if (createNonExistingCells) { + // Read the default cell value (but it is not used by TensorModifyUpdate). + getDouble(null); + } TensorDataType tensorDataType = (TensorDataType)type; TensorType tensorType = tensorDataType.getTensorType(); TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(tensorType); TensorFieldValue tensor = new TensorFieldValue(convertedType); tensor.deserialize(this); - return new TensorModifyUpdate(operation, tensor); + return new TensorModifyUpdate(operation, tensor, createNonExistingCells); + } + + private TensorModifyUpdate.Operation decodeOperation(byte operationId) { + byte OP_MASK = 0b01111111; + return TensorModifyUpdate.Operation.getOperation(operationId & OP_MASK); + } + + private boolean decodeCreateNonExistingCells(byte operationId) { + byte CREATE_FLAG = -0b10000000; + return (operationId & CREATE_FLAG) != 0; } @Override diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java index 66bc8cbb4d5..b2c3cdc09de 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentSerializerHead.java @@ -19,10 +19,22 @@ public class VespaDocumentSerializerHead extends VespaDocumentSerializer6 { @Override public void write(TensorModifyUpdate update) { - putByte(null, (byte) update.getOperation().id); + putByte(null, (byte) encodeOperationId(update)); + if (update.getCreateNonExistingCells()) { + putDouble(null, update.getDefaultCellValue()); + } update.getValue().serialize(this); } + private int encodeOperationId(TensorModifyUpdate update) { + int operationId = update.getOperation().id; + byte CREATE_FLAG = -0b10000000; + if (update.getCreateNonExistingCells()) { + operationId |= CREATE_FLAG; + } + return operationId; + } + @Override public void write(TensorAddUpdate update) { update.getValue().serialize(this); diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index 835c056868a..8a14bd21443 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -26,12 +26,17 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { protected Operation operation; protected TensorFieldValue tensor; - protected Optional<Double> defaultCellValue = Optional.empty(); + protected boolean createNonExistingCells; public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) { + this(operation, tensor, false); + } + + public TensorModifyUpdate(Operation operation, TensorFieldValue tensor, boolean createNonExistingCells) { super(ValueUpdateClassID.TENSORMODIFY); this.operation = operation; this.tensor = tensor; + this.createNonExistingCells = createNonExistingCells; verifyCompatibleType(tensor.getDataType().getTensorType()); } @@ -51,10 +56,12 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { } public Operation getOperation() { return operation; } - public TensorFieldValue getValue() { return tensor; } + public boolean getCreateNonExistingCells() { return createNonExistingCells; } + public double getDefaultCellValue() { + return (operation == Operation.MULTIPLY) ? 1.0 : 0.0; + } public void setValue(TensorFieldValue value) { tensor = value; } - public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); } @Override public FieldValue applyTo(FieldValue oldValue) { @@ -70,10 +77,10 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { default: throw new UnsupportedOperationException("Unknown operation: " + operation); } - if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) { + if (createNonExistingCells && hasMappedSubtype(oldTensor.type())) { var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get()); if (!subspaces.isEmpty()) { - oldTensor = insertSubspaces(oldTensor, subspaces, defaultCellValue.get()); + oldTensor = insertSubspaces(oldTensor, subspaces, getDefaultCellValue()); } } Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells()); @@ -142,7 +149,6 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { return builder.build(); } - @Override protected void checkCompatibility(DataType fieldType) { if (!(fieldType instanceof TensorDataType)) { @@ -162,17 +168,18 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { if (!super.equals(o)) return false; TensorModifyUpdate that = (TensorModifyUpdate) o; return operation == that.operation && - tensor.equals(that.tensor); + tensor.equals(that.tensor) && + createNonExistingCells == that.createNonExistingCells; } @Override public int hashCode() { - return Objects.hash(super.hashCode(), operation, tensor); + return Objects.hash(super.hashCode(), operation, tensor, createNonExistingCells); } @Override public String toString() { - return super.toString() + " " + operation.name + " " + tensor; + return super.toString() + " " + operation.name + " " + tensor + " " + createNonExistingCells; } /** diff --git a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java index 9733cd41a88..9d4d1e8f3aa 100644 --- a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java +++ b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java @@ -822,7 +822,10 @@ public class DocumentUpdateTestCase { result.addFieldUpdate(FieldUpdate.create(getField("dense_tensor")) .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.REPLACE, createTensor())) .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.ADD, createTensor())) - .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.MULTIPLY, createTensor()))); + .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.MULTIPLY, createTensor())) + .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.REPLACE, createTensor(), true)) + .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.ADD, createTensor(), true)) + .addValueUpdate(new TensorModifyUpdate(TensorModifyUpdate.Operation.MULTIPLY, createTensor(), true))); return result; } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 96b5d2c1fb5..4140a9eee02 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -1681,6 +1681,26 @@ public class JsonReaderTestCase { } @Test + public void tensor_modify_update_with_create_non_existing_cells_true() { + assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.ADD, true, "sparse_tensor", + inputJson("{", + " 'operation': 'add',", + " 'create': true,", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}")); + } + + @Test + public void tensor_modify_update_with_create_non_existing_cells_false() { + assertTensorModifyUpdate("{{x:a,y:b}:2.0}", TensorModifyUpdate.Operation.ADD, false, "sparse_tensor", + inputJson("{", + " 'operation': 'add',", + " 'create': false,", + " 'cells': [", + " { 'address': { 'x': 'a', 'y': 'b' }, 'value': 2.0 } ]}")); + } + + @Test public void tensor_modify_update_treats_the_input_tensor_as_sparse() { // Note that the type of the tensor in the modify update is sparse (it only has mapped dimensions). assertTensorModifyUpdate("tensor(x{},y{}):{{x:0,y:0}:2.0, {x:1,y:2}:3.0}", @@ -2155,16 +2175,25 @@ public class JsonReaderTestCase { private void assertTensorModifyUpdate(String expectedTensor, TensorModifyUpdate.Operation expectedOperation, String tensorFieldName, String modifyJson) { - assertTensorModifyUpdate(expectedTensor, expectedOperation, tensorFieldName, + assertTensorModifyUpdate(expectedTensor, expectedOperation, false, tensorFieldName, + createTensorModifyUpdate(modifyJson, tensorFieldName)); + } + + private void assertTensorModifyUpdate(String expectedTensor, TensorModifyUpdate.Operation expectedOperation, + boolean expectedCreateNonExistingCells, + String tensorFieldName, String modifyJson) { + assertTensorModifyUpdate(expectedTensor, expectedOperation, expectedCreateNonExistingCells, tensorFieldName, createTensorModifyUpdate(modifyJson, tensorFieldName)); } private static void assertTensorModifyUpdate(String expectedTensor, TensorModifyUpdate.Operation expectedOperation, + boolean expectedCreateNonExistingCells, String tensorFieldName, DocumentUpdate update) { assertTensorFieldUpdate(update, tensorFieldName); TensorModifyUpdate modifyUpdate = (TensorModifyUpdate) update.getFieldUpdate(tensorFieldName).getValueUpdate(0); assertEquals(expectedOperation, modifyUpdate.getOperation()); assertEquals(Tensor.from(expectedTensor), modifyUpdate.getValue().getTensor().get()); + assertEquals(expectedCreateNonExistingCells, modifyUpdate.getCreateNonExistingCells()); } private DocumentUpdate createTensorModifyUpdate(String modifyJson, String tensorFieldName) { diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java index 55b9090cce8..d0b04cf5449 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java @@ -51,35 +51,32 @@ public class TensorModifyUpdateTest { @Test public void apply_modify_update_operations_with_default_cell_value() { - assertApplyTo("tensor(x{})", "tensor(x{})", Operation.ADD, Optional.of(0.0), - "{{x:a}:1,{x:b}:2}", "{{x:b}:3}", "{{x:a}:1,{x:b}:5}"); + assertApplyTo("tensor(x{})", "tensor(x{})", Operation.MULTIPLY, true, + "{{x:a}:1,{x:b}:2}", "{{x:b}:3}", "{{x:a}:1,{x:b}:6}"); - assertApplyTo("tensor(x{})", "tensor(x{})", Operation.ADD, Optional.of(0.0), - "{{x:a}:1,{x:b}:2}", "{{x:b}:3,{x:c}:4}", "{{x:a}:1,{x:b}:5,{x:c}:4}"); + assertApplyTo("tensor(x{})", "tensor(x{})", Operation.MULTIPLY, true, + "{{x:a}:1,{x:b}:2}", "{{x:b}:3,{x:c}:4}", "{{x:a}:1,{x:b}:6,{x:c}:4}"); - assertApplyTo("tensor(x{},y[3])", "tensor(x{},y{})", Operation.ADD, Optional.of(1.0), + assertApplyTo("tensor(x{},y[3])", "tensor(x{},y{})", Operation.ADD, true, "{{x:a,y:0}:3,{x:a,y:1}:4,{x:a,y:2}:5}", "{{x:a,y:0}:6,{x:b,y:1}:7,{x:b,y:2}:8,{x:c,y:0}:9}", "{{x:a,y:0}:9,{x:a,y:1}:4,{x:a,y:2}:5," + - "{x:b,y:0}:1,{x:b,y:1}:8,{x:b,y:2}:9," + - "{x:c,y:0}:10,{x:c,y:1}:1,{x:c,y:2}:1}"); + "{x:b,y:0}:0,{x:b,y:1}:7,{x:b,y:2}:8," + + "{x:c,y:0}:9,{x:c,y:1}:0,{x:c,y:2}:0}"); - // NOTE: The specified default cell value doesn't have any effect for tensors with only indexed dimensions, - // as the dense subspace is always represented (with default cell value 0.0). - assertApplyTo("tensor(x[3])", "tensor(x{})", Operation.ADD, Optional.of(2.0), - "{{x:0}:2}", "{{x:1}:3}", "{{x:0}:2,{x:1}:3,{x:2}:0}"); + // NOTE: The default cell value (1.0) used for MULTIPLY operation doesn't have any effect for tensors + // with only indexed dimensions, as the dense subspace is always represented (with default cell value 0.0). + assertApplyTo("tensor(x[3])", "tensor(x{})", Operation.MULTIPLY, true, + "{{x:0}:2}", "{{x:1}:3}", "{{x:0}:2,{x:1}:0,{x:2}:0}"); } private void assertApplyTo(String spec, Operation op, String input, String update, String expected) { - assertApplyTo(spec, "tensor(x{},y{})", op, Optional.empty(), input, update, expected); + assertApplyTo(spec, "tensor(x{},y{})", op, false, input, update, expected); } - private void assertApplyTo(String inputSpec, String updateSpec, Operation op, Optional<Double> defaultCellValue, String input, String update, String expected) { + private void assertApplyTo(String inputSpec, String updateSpec, Operation op, boolean createNonExistingCells, String input, String update, String expected) { TensorFieldValue inputFieldValue = new TensorFieldValue(Tensor.from(inputSpec, input)); - TensorModifyUpdate modifyUpdate = new TensorModifyUpdate(op, new TensorFieldValue(Tensor.from(updateSpec, update))); - if (defaultCellValue.isPresent()) { - modifyUpdate.setDefaultCellValue(defaultCellValue.get()); - } + TensorModifyUpdate modifyUpdate = new TensorModifyUpdate(op, new TensorFieldValue(Tensor.from(updateSpec, update)), createNonExistingCells); TensorFieldValue updatedFieldValue = (TensorFieldValue) modifyUpdate.applyTo(inputFieldValue); assertEquals(Tensor.from(inputSpec, expected), updatedFieldValue.getTensor().get()); } diff --git a/document/src/tests/data/serialize-tensor-update-cpp.dat b/document/src/tests/data/serialize-tensor-update-cpp.dat Binary files differindex ad0e9d706b0..d6b6b5e2506 100644 --- a/document/src/tests/data/serialize-tensor-update-cpp.dat +++ b/document/src/tests/data/serialize-tensor-update-cpp.dat diff --git a/document/src/tests/data/serialize-tensor-update-java.dat b/document/src/tests/data/serialize-tensor-update-java.dat Binary files differindex ad0e9d706b0..d6b6b5e2506 100644 --- a/document/src/tests/data/serialize-tensor-update-java.dat +++ b/document/src/tests/data/serialize-tensor-update-java.dat diff --git a/document/src/tests/documentupdatetestcase.cpp b/document/src/tests/documentupdatetestcase.cpp index 3fbccaa155f..b225ca6677b 100644 --- a/document/src/tests/documentupdatetestcase.cpp +++ b/document/src/tests/documentupdatetestcase.cpp @@ -1024,6 +1024,20 @@ TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied) .add({{"x", "b"}}, 15)); } +TEST(DocumentUpdateTest, tensor_modify_update_with_create_non_existing_cells_can_be_applied) +{ + TensorUpdateFixture f; + auto baseLine = f.spec().add({{"x", "a"}}, 2) + .add({{"x", "b"}}, 3); + + f.assertApplyUpdate(baseLine, + std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD, + f.makeTensor(f.spec().add({{"x", "b"}}, 5).add({{"x", "c"}}, 6)), 0.0), + f.spec().add({{"x", "a"}}, 2) + .add({{"x", "b"}}, 8) + .add({{"x", "c"}}, 6)); +} + TEST(DocumentUpdateTest, tensor_modify_update_can_be_applied_to_nonexisting_tensor) { TensorUpdateFixture f; @@ -1069,6 +1083,9 @@ TEST(DocumentUpdateTest, tensor_modify_update_can_be_roundtrip_serialized) f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor())); f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor())); f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor())); + f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor(), 0.0)); + f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor(), 0.0)); + f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor(), 1.0)); } TEST(DocumentUpdateTest, tensor_modify_update_on_float_tensor_can_be_roundtrip_serialized) @@ -1077,6 +1094,9 @@ TEST(DocumentUpdateTest, tensor_modify_update_on_float_tensor_can_be_roundtrip_s f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor())); f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor())); f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor())); + f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::REPLACE, f.makeBaselineTensor(), 0.0)); + f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::ADD, f.makeBaselineTensor(), 0.0)); + f.assertRoundtripSerialize(TensorModifyUpdate(TensorModifyUpdate::Operation::MULTIPLY, f.makeBaselineTensor(), 1.0)); } TEST(DocumentUpdateTest, tensor_modify_update_on_dense_tensor_can_be_roundtrip_serialized) @@ -1170,7 +1190,10 @@ struct TensorUpdateSerializeFixture { result->addUpdate(FieldUpdate(getField("dense_tensor")) .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::REPLACE, makeTensor())) .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD, makeTensor())) - .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::MULTIPLY, makeTensor()))); + .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::MULTIPLY, makeTensor())) + .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::REPLACE, makeTensor(), 0.0)) + .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::ADD, makeTensor(), 0.0)) + .addUpdate(std::make_unique<TensorModifyUpdate>(TensorModifyUpdate::Operation::MULTIPLY, makeTensor(), 1.0))); return result; } diff --git a/document/src/vespa/document/serialization/vespadocumentserializer.cpp b/document/src/vespa/document/serialization/vespadocumentserializer.cpp index 2a2b642dd48..c0b56150c04 100644 --- a/document/src/vespa/document/serialization/vespadocumentserializer.cpp +++ b/document/src/vespa/document/serialization/vespadocumentserializer.cpp @@ -476,11 +476,29 @@ VespaDocumentSerializer::write(const RemoveFieldPathUpdate &value) writeFieldPath(_stream, value); } +namespace { + +uint8_t +encode_operation_id(const TensorModifyUpdate& update) +{ + uint8_t op = static_cast<uint8_t>(update.getOperation()); + uint8_t CREATE_FLAG = 0b10000000; + if (update.get_default_cell_value().has_value()) { + op |= CREATE_FLAG; + } + return op; +} + +} + void VespaDocumentSerializer::write(const TensorModifyUpdate &value) { _stream << uint32_t(ValueUpdate::TensorModify); - _stream << static_cast<uint8_t>(value.getOperation()); + _stream << encode_operation_id(value); + if (value.get_default_cell_value().has_value()) { + _stream << value.get_default_cell_value().value(); + } write(value.getTensor()); } diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index 92b2a8672c3..ad1e3095269 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -87,7 +87,8 @@ TensorModifyUpdate::TensorModifyUpdate() TensorUpdate(), _operation(Operation::MAX_NUM_OPERATIONS), _tensorType(), - _tensor() + _tensor(), + _default_cell_value() { } @@ -96,7 +97,19 @@ TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<Tens TensorUpdate(), _operation(operation), _tensorType(std::make_unique<TensorDataType>(dynamic_cast<const TensorDataType &>(*tensor->getDataType()))), - _tensor(static_cast<TensorFieldValue *>(_tensorType->createFieldValue().release())) + _tensor(static_cast<TensorFieldValue *>(_tensorType->createFieldValue().release())), + _default_cell_value() +{ + *_tensor = *tensor; +} + +TensorModifyUpdate::TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor, double default_cell_value) + : ValueUpdate(TensorModify), + TensorUpdate(), + _operation(operation), + _tensorType(std::make_unique<TensorDataType>(dynamic_cast<const TensorDataType &>(*tensor->getDataType()))), + _tensor(static_cast<TensorFieldValue *>(_tensorType->createFieldValue().release())), + _default_cell_value(default_cell_value) { *_tensor = *tensor; } @@ -116,6 +129,9 @@ TensorModifyUpdate::operator==(const ValueUpdate &other) const if (*_tensor != *o._tensor) { return false; } + if (_default_cell_value != o._default_cell_value) { + return false; + } return true; } @@ -141,7 +157,11 @@ TensorModifyUpdate::apply_to(const Value &old_tensor, { if (auto cellsTensor = _tensor->getAsTensorPtr()) { auto op = getJoinFunction(_operation); - return TensorPartialUpdate::modify(old_tensor, op, *cellsTensor, factory); + if (_default_cell_value.has_value()) { + return TensorPartialUpdate::modify_with_defaults(old_tensor, op, *cellsTensor, _default_cell_value.value(), factory); + } else { + return TensorPartialUpdate::modify(old_tensor, op, *cellsTensor, factory); + } } return {}; } @@ -179,6 +199,9 @@ TensorModifyUpdate::print(std::ostream& out, bool verbose, const std::string& in if (_tensor) { _tensor->print(out, verbose, indent); } + if (_default_cell_value.has_value()) { + out << "," << _default_cell_value.value(); + } out << ")"; } @@ -198,6 +221,26 @@ verifyCellsTensorIsSparse(const vespalib::eval::Value *cellsTensor) throw IllegalStateException(err, VESPA_STRLOC); } +TensorModifyUpdate::Operation +decode_operation(uint8_t encoded_op) +{ + uint8_t OP_MASK = 0b01111111; + uint8_t op = encoded_op & OP_MASK; + if (op >= static_cast<uint8_t>(TensorModifyUpdate::Operation::MAX_NUM_OPERATIONS)) { + vespalib::asciistream msg; + msg << "Unrecognized tensor modify update operation " << static_cast<uint32_t>(op); + throw DeserializeException(msg.str(), VESPA_STRLOC); + } + return static_cast<TensorModifyUpdate::Operation>(op); +} + +bool +decode_create_non_existing_cells(uint8_t encoded_op) +{ + uint8_t CREATE_FLAG = 0b10000000; + return (encoded_op & CREATE_FLAG) != 0; +} + } void @@ -205,12 +248,12 @@ TensorModifyUpdate::deserialize(const DocumentTypeRepo &repo, const DataType &ty { uint8_t op; stream >> op; - if (op >= static_cast<uint8_t>(Operation::MAX_NUM_OPERATIONS)) { - vespalib::asciistream msg; - msg << "Unrecognized tensor modify update operation " << static_cast<uint32_t>(op); - throw DeserializeException(msg.str(), VESPA_STRLOC); + _operation = decode_operation(op); + if (decode_create_non_existing_cells(op)) { + double value; + stream >> value; + _default_cell_value = value; } - _operation = static_cast<Operation>(op); _tensorType = convertToCompatibleType(dynamic_cast<const TensorDataType &>(type)); auto tensor = _tensorType->createFieldValue(); if (tensor->isA(FieldValue::Type::TENSOR)) { diff --git a/document/src/vespa/document/update/tensor_modify_update.h b/document/src/vespa/document/update/tensor_modify_update.h index 9386b4f8a1b..931d5102c4f 100644 --- a/document/src/vespa/document/update/tensor_modify_update.h +++ b/document/src/vespa/document/update/tensor_modify_update.h @@ -2,6 +2,7 @@ #include "tensor_update.h" #include "valueupdate.h" +#include <optional> namespace vespalib::eval { struct Value; } @@ -29,12 +30,15 @@ private: Operation _operation; std::unique_ptr<const TensorDataType> _tensorType; std::unique_ptr<TensorFieldValue> _tensor; + // When this is set, non-existing cells are created in the input tensor before applying the update. + std::optional<double> _default_cell_value; friend ValueUpdate; TensorModifyUpdate(); ACCEPT_UPDATE_VISITOR; public: TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor); + TensorModifyUpdate(Operation operation, std::unique_ptr<TensorFieldValue> tensor, double default_cell_value); TensorModifyUpdate(const TensorModifyUpdate &rhs) = delete; TensorModifyUpdate &operator=(const TensorModifyUpdate &rhs) = delete; ~TensorModifyUpdate() override; @@ -42,6 +46,7 @@ public: bool operator==(const ValueUpdate &other) const override; Operation getOperation() const { return _operation; } const TensorFieldValue &getTensor() const { return *_tensor; } + const std::optional<double>& get_default_cell_value() const { return _default_cell_value; } void checkCompatibility(const Field &field) const override; std::unique_ptr<vespalib::eval::Value> applyTo(const vespalib::eval::Value &tensor) const; std::unique_ptr<Value> apply_to(const Value &tensor, |