diff options
author | Geir Storli <geirst@yahooinc.com> | 2023-08-23 15:32:58 +0000 |
---|---|---|
committer | Geir Storli <geirst@yahooinc.com> | 2023-08-23 15:32:58 +0000 |
commit | b9c6ad6890e2c571878982abc9cd3f1bc9426d83 (patch) | |
tree | 937fc2fa3545807f16ba1c3c3cd815da30499db6 /document/src | |
parent | 2dd6924585799a8d1bc5319093871e586b659add (diff) |
Extend modify update operation in Java to match the new C++ behavior.
Sub-spaces addressed by the update that do not exist in the input tensor are first created with the default cell value, before the modify operation is applied.
Diffstat (limited to 'document/src')
4 files changed, 121 insertions, 12 deletions
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index d9521ee0e1c..835c056868a 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -6,10 +6,15 @@ import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; +import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import java.util.Arrays; +import java.util.HashSet; import java.util.Objects; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; /* @@ -21,6 +26,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { protected Operation operation; protected TensorFieldValue tensor; + protected Optional<Double> defaultCellValue = Optional.empty(); public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) { super(ValueUpdateClassID.TENSORMODIFY); @@ -48,6 +54,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { public TensorFieldValue getValue() { return tensor; } public void setValue(TensorFieldValue value) { tensor = value; } + public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); } @Override public FieldValue applyTo(FieldValue oldValue) { @@ -63,6 +70,12 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { default: throw new UnsupportedOperationException("Unknown operation: " + operation); } + if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) { + var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get()); + if (!subspaces.isEmpty()) { + oldTensor = insertSubspaces(oldTensor, subspaces, 
defaultCellValue.get()); + } + } Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells()); return new TensorFieldValue(modified); } @@ -72,6 +85,64 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } + private static boolean hasMappedSubtype(TensorType type) { + return !type.mappedSubtype().equals(TensorType.empty); + } + + private static boolean hasIndexedSubtype(TensorType type) { + return !type.indexedSubtype().equals(TensorType.empty); + } + + private static HashSet<TensorAddress> findSubspacesNotInInput(Tensor input, Tensor modifier) { + var subspaces = new HashSet<TensorAddress>(); + var inputCells = input.cells(); + var type = input.type(); + for (var itr = modifier.cellIterator(); itr.hasNext(); ) { + Tensor.Cell cell = itr.next(); + TensorAddress address = cell.getKey(); + if (!inputCells.containsKey(address)) { + subspaces.add(createSparsePartAddress(address, type)); + } + } + return subspaces; + } + + private static TensorAddress createSparsePartAddress(TensorAddress address, TensorType type) { + var builder = new TensorAddress.Builder(type.mappedSubtype()); + for (int i = 0; i < type.dimensions().size(); ++i) { + var dim = type.dimensions().get(i); + if (dim.isMapped()) { + builder.add(dim.name(), address.label(i)); + } + } + return builder.build(); + } + + private static Tensor insertSubspaces(Tensor input, HashSet<TensorAddress> subspaces, double defaultCellValue) { + var type = input.type(); + boolean mixed = hasMappedSubtype(type) && hasIndexedSubtype(type); + Tensor.Builder builder; + if (mixed) { + var boundBuilder = MixedTensor.BoundBuilder.of(type); + var values = new double[(int) boundBuilder.denseSubspaceSize()]; + Arrays.fill(values, defaultCellValue); + for (var subspace : subspaces) { + boundBuilder.block(subspace, values); + } + builder = boundBuilder; + } else { + builder = Tensor.Builder.of(type); + for (var subspace : subspaces) { + builder.cell(subspace, defaultCellValue); 
+ } + } + for (var itr = input.cellIterator(); itr.hasNext(); ) { + builder.cell(itr.next()); + } + return builder.build(); + } + + @Override protected void checkCompatibility(DataType fieldType) { if (!(fieldType instanceof TensorDataType)) { diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java index 60dd5ad1d0d..55b9090cce8 100644 --- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java +++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java @@ -7,6 +7,8 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; +import java.util.Optional; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; @@ -47,11 +49,38 @@ public class TensorModifyUpdateTest { "{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}"); } - private void assertApplyTo(String spec, Operation op, String init, String update, String expected) { - TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init)); - TensorModifyUpdate modifyUpdate = new TensorModifyUpdate(op, new TensorFieldValue(Tensor.from("tensor(x{},y{})", update))); - TensorFieldValue updatedFieldValue = (TensorFieldValue) modifyUpdate.applyTo(initialFieldValue); - assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get()); + @Test + public void apply_modify_update_operations_with_default_cell_value() { + assertApplyTo("tensor(x{})", "tensor(x{})", Operation.ADD, Optional.of(0.0), + "{{x:a}:1,{x:b}:2}", "{{x:b}:3}", "{{x:a}:1,{x:b}:5}"); + + assertApplyTo("tensor(x{})", "tensor(x{})", Operation.ADD, Optional.of(0.0), + "{{x:a}:1,{x:b}:2}", "{{x:b}:3,{x:c}:4}", "{{x:a}:1,{x:b}:5,{x:c}:4}"); + + assertApplyTo("tensor(x{},y[3])", "tensor(x{},y{})", Operation.ADD, Optional.of(1.0), + "{{x:a,y:0}:3,{x:a,y:1}:4,{x:a,y:2}:5}", + 
"{{x:a,y:0}:6,{x:b,y:1}:7,{x:b,y:2}:8,{x:c,y:0}:9}", + "{{x:a,y:0}:9,{x:a,y:1}:4,{x:a,y:2}:5," + + "{x:b,y:0}:1,{x:b,y:1}:8,{x:b,y:2}:9," + + "{x:c,y:0}:10,{x:c,y:1}:1,{x:c,y:2}:1}"); + + // NOTE: The specified default cell value doesn't have any effect for tensors with only indexed dimensions, + // as the dense subspace is always represented (with default cell value 0.0). + assertApplyTo("tensor(x[3])", "tensor(x{})", Operation.ADD, Optional.of(2.0), + "{{x:0}:2}", "{{x:1}:3}", "{{x:0}:2,{x:1}:3,{x:2}:0}"); } + private void assertApplyTo(String spec, Operation op, String input, String update, String expected) { + assertApplyTo(spec, "tensor(x{},y{})", op, Optional.empty(), input, update, expected); + } + + private void assertApplyTo(String inputSpec, String updateSpec, Operation op, Optional<Double> defaultCellValue, String input, String update, String expected) { + TensorFieldValue inputFieldValue = new TensorFieldValue(Tensor.from(inputSpec, input)); + TensorModifyUpdate modifyUpdate = new TensorModifyUpdate(op, new TensorFieldValue(Tensor.from(updateSpec, update))); + if (defaultCellValue.isPresent()) { + modifyUpdate.setDefaultCellValue(defaultCellValue.get()); + } + TensorFieldValue updatedFieldValue = (TensorFieldValue) modifyUpdate.applyTo(inputFieldValue); + assertEquals(Tensor.from(inputSpec, expected), updatedFieldValue.getTensor().get()); + } } diff --git a/document/src/tests/tensor_fieldvalue/partial_modify/partial_modify_test.cpp b/document/src/tests/tensor_fieldvalue/partial_modify/partial_modify_test.cpp index bf0f893b901..bf780dba5d3 100644 --- a/document/src/tests/tensor_fieldvalue/partial_modify/partial_modify_test.cpp +++ b/document/src/tests/tensor_fieldvalue/partial_modify/partial_modify_test.cpp @@ -129,6 +129,12 @@ TEST(PartialModifyTest, partial_modify_with_defauls) { "tensor(x{},y[3]):{{x:\"a\",y:0}:9,{x:\"a\",y:1}:4,{x:\"a\",y:2}:5," "{x:\"b\",y:0}:1,{x:\"b\",y:1}:8,{x:\"b\",y:2}:9," "{x:\"c\",y:0}:10,{x:\"c\",y:1}:1,{x:\"c\",y:2}:1}"); + 
+ // NOTE: The specified default cell value doesn't have any effect for tensors with only indexed dimensions, + // as the dense subspace is always represented (with default cell value 0.0). + expect_modify_with_defaults("tensor(x[3]):{{x:0}:2}", "tensor(x{}):{{x:\"1\"}:3}", + operation::Add::f, 2.0, + "tensor(x[3]):{{x:0}:2,{x:1}:3,{x:2}:0}"); } std::vector<std::pair<vespalib::string,vespalib::string>> bad_layouts = { diff --git a/document/src/vespa/document/update/tensor_partial_update.cpp b/document/src/vespa/document/update/tensor_partial_update.cpp index 72bcc044977..e37e5750384 100644 --- a/document/src/vespa/document/update/tensor_partial_update.cpp +++ b/document/src/vespa/document/update/tensor_partial_update.cpp @@ -475,17 +475,20 @@ Value::UP TensorPartialUpdate::modify_with_defaults(const Value& input, join_fun_t function, const Value& modifier, double default_cell_value, const ValueBuilderFactory& factory) { - AddressHandler handler(input.type(), modifier.type()); + const auto& input_type = input.type(); + AddressHandler handler(input_type, modifier.type()); if (!handler.valid) { return {}; } - const size_t dsss = input.type().dense_subspace_size(); - ArrayArrayMap<string_id, double> sub_spaces(handler.for_output.addr.size(), dsss, modifier.index().size()); - find_sub_spaces_not_in_input(input, modifier, default_cell_value, handler, sub_spaces); Value::UP output; - if (sub_spaces.size() > 0) { - output = typify_invoke<1, TypifyCellType, PerformInsertSubspaces>( - input.cells().type, input, handler.for_output, sub_spaces, factory); + if (!input_type.is_dense()) { + const size_t dsss = input_type.dense_subspace_size(); + ArrayArrayMap<string_id, double> sub_spaces(handler.for_output.addr.size(), dsss, modifier.index().size()); + find_sub_spaces_not_in_input(input, modifier, default_cell_value, handler, sub_spaces); + if (sub_spaces.size() > 0) { + output = typify_invoke<1, TypifyCellType, PerformInsertSubspaces>( + input.cells().type, input, 
handler.for_output, sub_spaces, factory); + } } return typify_invoke<2, TypifyCellType, PerformModify>( input.cells().type, modifier.cells().type, |