diff options
author | Geir Storli <geirst@yahooinc.com> | 2023-08-23 15:32:58 +0000 |
---|---|---|
committer | Geir Storli <geirst@yahooinc.com> | 2023-08-23 15:32:58 +0000 |
commit | b9c6ad6890e2c571878982abc9cd3f1bc9426d83 (patch) | |
tree | 937fc2fa3545807f16ba1c3c3cd815da30499db6 /document/src/main | |
parent | 2dd6924585799a8d1bc5319093871e586b659add (diff) |
Extend modify update operation in Java to match the new C++ behavior.
This creates non-existing sub-spaces with default cell value first.
Diffstat (limited to 'document/src/main')
-rw-r--r-- | document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index d9521ee0e1c..835c056868a 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -6,10 +6,15 @@ import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; +import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import java.util.Arrays; +import java.util.HashSet; import java.util.Objects; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; /* @@ -21,6 +26,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { protected Operation operation; protected TensorFieldValue tensor; + protected Optional<Double> defaultCellValue = Optional.empty(); public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) { super(ValueUpdateClassID.TENSORMODIFY); @@ -48,6 +54,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { public TensorFieldValue getValue() { return tensor; } public void setValue(TensorFieldValue value) { tensor = value; } + public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); } @Override public FieldValue applyTo(FieldValue oldValue) { @@ -63,6 +70,12 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { default: throw new UnsupportedOperationException("Unknown operation: " + operation); } + if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) { + var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get()); + if (!subspaces.isEmpty()) { + oldTensor = insertSubspaces(oldTensor, subspaces, defaultCellValue.get()); + } + } Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells()); return new TensorFieldValue(modified); } @@ -72,6 +85,64 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { return oldValue; } + private static boolean hasMappedSubtype(TensorType type) { + return !type.mappedSubtype().equals(TensorType.empty); + } + + private static boolean hasIndexedSubtype(TensorType type) { + return !type.indexedSubtype().equals(TensorType.empty); + } + + private static HashSet<TensorAddress> findSubspacesNotInInput(Tensor input, Tensor modifier) { + var subspaces = new HashSet<TensorAddress>(); + var inputCells = input.cells(); + var type = input.type(); + for (var itr = modifier.cellIterator(); itr.hasNext(); ) { + Tensor.Cell cell = itr.next(); + TensorAddress address = cell.getKey(); + if (!inputCells.containsKey(address)) { + subspaces.add(createSparsePartAddress(address, type)); + } + } + return subspaces; + } + + private static TensorAddress createSparsePartAddress(TensorAddress address, TensorType type) { + var builder = new TensorAddress.Builder(type.mappedSubtype()); + for (int i = 0; i < type.dimensions().size(); ++i) { + var dim = type.dimensions().get(i); + if (dim.isMapped()) { + builder.add(dim.name(), address.label(i)); + } + } + return builder.build(); + } + + private static Tensor insertSubspaces(Tensor input, HashSet<TensorAddress> subspaces, double defaultCellValue) { + var type = input.type(); + boolean mixed = hasMappedSubtype(type) && hasIndexedSubtype(type); + Tensor.Builder builder; + if (mixed) { + var boundBuilder = MixedTensor.BoundBuilder.of(type); + var values = new double[(int) boundBuilder.denseSubspaceSize()]; + Arrays.fill(values, defaultCellValue); + for (var subspace : subspaces) { + boundBuilder.block(subspace, values); + } + builder = boundBuilder; + } else { + builder = Tensor.Builder.of(type); + for (var subspace : subspaces) { + builder.cell(subspace, defaultCellValue); + } + } + for (var itr = input.cellIterator(); itr.hasNext(); ) { + builder.cell(itr.next()); + } + return builder.build(); + } + + @Override protected void checkCompatibility(DataType fieldType) { if (!(fieldType instanceof TensorDataType)) { |