From b9c6ad6890e2c571878982abc9cd3f1bc9426d83 Mon Sep 17 00:00:00 2001 From: Geir Storli Date: Wed, 23 Aug 2023 15:32:58 +0000 Subject: Extend modify update operation in Java to match the new C++ behavior. This creates non-existing sub-spaces with default cell value first. --- .../yahoo/document/update/TensorModifyUpdate.java | 71 ++++++++++++++++++++++ 1 file changed, 71 insertions(+) (limited to 'document/src/main') diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index d9521ee0e1c..835c056868a 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -6,10 +6,15 @@ import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; +import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import java.util.Arrays; +import java.util.HashSet; import java.util.Objects; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; /* @@ -21,6 +26,7 @@ public class TensorModifyUpdate extends ValueUpdate { protected Operation operation; protected TensorFieldValue tensor; + protected Optional defaultCellValue = Optional.empty(); public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) { super(ValueUpdateClassID.TENSORMODIFY); @@ -48,6 +54,7 @@ public class TensorModifyUpdate extends ValueUpdate { public TensorFieldValue getValue() { return tensor; } public void setValue(TensorFieldValue value) { tensor = value; } + public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); } @Override public FieldValue applyTo(FieldValue oldValue) { @@ -63,6 +70,12 @@ public class TensorModifyUpdate extends ValueUpdate { default: throw new UnsupportedOperationException("Unknown operation: " + operation); } + if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) { + var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get()); + if (!subspaces.isEmpty()) { + oldTensor = insertSubspaces(oldTensor, subspaces, defaultCellValue.get()); + } + } Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells()); return new TensorFieldValue(modified); } @@ -72,6 +85,64 @@ public class TensorModifyUpdate extends ValueUpdate { return oldValue; } + private static boolean hasMappedSubtype(TensorType type) { + return !type.mappedSubtype().equals(TensorType.empty); + } + + private static boolean hasIndexedSubtype(TensorType type) { + return !type.indexedSubtype().equals(TensorType.empty); + } + + private static HashSet findSubspacesNotInInput(Tensor input, Tensor modifier) { + var subspaces = new HashSet(); + var inputCells = input.cells(); + var type = input.type(); + for (var itr = modifier.cellIterator(); itr.hasNext(); ) { + Tensor.Cell cell = itr.next(); + TensorAddress address = cell.getKey(); + if (!inputCells.containsKey(address)) { + subspaces.add(createSparsePartAddress(address, type)); + } + } + return subspaces; + } + + private static TensorAddress createSparsePartAddress(TensorAddress address, TensorType type) { + var builder = new TensorAddress.Builder(type.mappedSubtype()); + for (int i = 0; i < type.dimensions().size(); ++i) { + var dim = type.dimensions().get(i); + if (dim.isMapped()) { + builder.add(dim.name(), address.label(i)); + } + } + return builder.build(); + } + + private static Tensor insertSubspaces(Tensor input, HashSet subspaces, double defaultCellValue) { + var type = input.type(); + boolean mixed = hasMappedSubtype(type) && hasIndexedSubtype(type); + Tensor.Builder builder; + if (mixed) { + var boundBuilder = MixedTensor.BoundBuilder.of(type); + var values = new double[(int) boundBuilder.denseSubspaceSize()]; + Arrays.fill(values, defaultCellValue); + for (var subspace : subspaces) { + boundBuilder.block(subspace, values); + } + builder = boundBuilder; + } else { + builder = Tensor.Builder.of(type); + for (var subspace : subspaces) { + builder.cell(subspace, defaultCellValue); + } + } + for (var itr = input.cellIterator(); itr.hasNext(); ) { + builder.cell(itr.next()); + } + return builder.build(); + } + + @Override protected void checkCompatibility(DataType fieldType) { if (!(fieldType instanceof TensorDataType)) { -- cgit v1.2.3