summaryrefslogtreecommitdiffstats
path: root/document/src/main
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2023-08-23 15:32:58 +0000
committerGeir Storli <geirst@yahooinc.com>2023-08-23 15:32:58 +0000
commitb9c6ad6890e2c571878982abc9cd3f1bc9426d83 (patch)
tree937fc2fa3545807f16ba1c3c3cd815da30499db6 /document/src/main
parent2dd6924585799a8d1bc5319093871e586b659add (diff)
Extend modify update operation in Java to match the new C++ behavior.
This creates non-existing sub-spaces with default cell value first.
Diffstat (limited to 'document/src/main')
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java71
1 files changed, 71 insertions, 0 deletions
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index d9521ee0e1c..835c056868a 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -6,10 +6,15 @@ import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
+import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import java.util.Arrays;
+import java.util.HashSet;
import java.util.Objects;
+import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
/*
@@ -21,6 +26,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
protected Operation operation;
protected TensorFieldValue tensor;
+ protected Optional<Double> defaultCellValue = Optional.empty();
public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) {
super(ValueUpdateClassID.TENSORMODIFY);
@@ -48,6 +54,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
public TensorFieldValue getValue() { return tensor; }
public void setValue(TensorFieldValue value) { tensor = value; }
+ public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); }
@Override
public FieldValue applyTo(FieldValue oldValue) {
@@ -63,6 +70,12 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
default:
throw new UnsupportedOperationException("Unknown operation: " + operation);
}
+ if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) {
+ var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get());
+ if (!subspaces.isEmpty()) {
+ oldTensor = insertSubspaces(oldTensor, subspaces, defaultCellValue.get());
+ }
+ }
Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells());
return new TensorFieldValue(modified);
}
@@ -72,6 +85,64 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
return oldValue;
}
+ private static boolean hasMappedSubtype(TensorType type) {
+ return !type.mappedSubtype().equals(TensorType.empty);
+ }
+
+ private static boolean hasIndexedSubtype(TensorType type) {
+ return !type.indexedSubtype().equals(TensorType.empty);
+ }
+
+ private static HashSet<TensorAddress> findSubspacesNotInInput(Tensor input, Tensor modifier) {
+ var subspaces = new HashSet<TensorAddress>();
+ var inputCells = input.cells();
+ var type = input.type();
+ for (var itr = modifier.cellIterator(); itr.hasNext(); ) {
+ Tensor.Cell cell = itr.next();
+ TensorAddress address = cell.getKey();
+ if (!inputCells.containsKey(address)) {
+ subspaces.add(createSparsePartAddress(address, type));
+ }
+ }
+ return subspaces;
+ }
+
+ private static TensorAddress createSparsePartAddress(TensorAddress address, TensorType type) {
+ var builder = new TensorAddress.Builder(type.mappedSubtype());
+ for (int i = 0; i < type.dimensions().size(); ++i) {
+ var dim = type.dimensions().get(i);
+ if (dim.isMapped()) {
+ builder.add(dim.name(), address.label(i));
+ }
+ }
+ return builder.build();
+ }
+
+ private static Tensor insertSubspaces(Tensor input, HashSet<TensorAddress> subspaces, double defaultCellValue) {
+ var type = input.type();
+ boolean mixed = hasMappedSubtype(type) && hasIndexedSubtype(type);
+ Tensor.Builder builder;
+ if (mixed) {
+ var boundBuilder = MixedTensor.BoundBuilder.of(type);
+ var values = new double[(int) boundBuilder.denseSubspaceSize()];
+ Arrays.fill(values, defaultCellValue);
+ for (var subspace : subspaces) {
+ boundBuilder.block(subspace, values);
+ }
+ builder = boundBuilder;
+ } else {
+ builder = Tensor.Builder.of(type);
+ for (var subspace : subspaces) {
+ builder.cell(subspace, defaultCellValue);
+ }
+ }
+ for (var itr = input.cellIterator(); itr.hasNext(); ) {
+ builder.cell(itr.next());
+ }
+ return builder.build();
+ }
+
+
@Override
protected void checkCompatibility(DataType fieldType) {
if (!(fieldType instanceof TensorDataType)) {