summary refs log tree commit diff stats
path: root/document
diff options
context:
space:
mode:
author Geir Storli <geirst@yahooinc.com> 2023-08-23 15:32:58 +0000
committer Geir Storli <geirst@yahooinc.com> 2023-08-23 15:32:58 +0000
commit b9c6ad6890e2c571878982abc9cd3f1bc9426d83 (patch)
tree 937fc2fa3545807f16ba1c3c3cd815da30499db6 /document
parent 2dd6924585799a8d1bc5319093871e586b659add (diff)
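Extend the tensor modify update operation in Java to match the new C++ behavior.
When a default cell value is set, any sub-spaces addressed by the update that do not exist in the input tensor are first created, filled with the default cell value, before the modify operation is applied.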
Diffstat (limited to 'document')
-rw-r--r-- document/abi-spec.json 4
-rw-r--r-- document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java 71
-rw-r--r-- document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java 39
-rw-r--r-- document/src/tests/tensor_fieldvalue/partial_modify/partial_modify_test.cpp 6
-rw-r--r-- document/src/vespa/document/update/tensor_partial_update.cpp 17
5 files changed, 124 insertions, 13 deletions
diff --git a/document/abi-spec.json b/document/abi-spec.json
index 6129ea991d5..22c38337e90 100644
--- a/document/abi-spec.json
+++ b/document/abi-spec.json
@@ -3446,6 +3446,7 @@
"public com.yahoo.document.update.TensorModifyUpdate$Operation getOperation()",
"public com.yahoo.document.datatypes.TensorFieldValue getValue()",
"public void setValue(com.yahoo.document.datatypes.TensorFieldValue)",
+ "public void setDefaultCellValue(double)",
"public com.yahoo.document.datatypes.FieldValue applyTo(com.yahoo.document.datatypes.FieldValue)",
"protected void checkCompatibility(com.yahoo.document.DataType)",
"public void serialize(com.yahoo.document.serialization.DocumentUpdateWriter, com.yahoo.document.DataType)",
@@ -3457,7 +3458,8 @@
],
"fields" : [
"protected com.yahoo.document.update.TensorModifyUpdate$Operation operation",
- "protected com.yahoo.document.datatypes.TensorFieldValue tensor"
+ "protected com.yahoo.document.datatypes.TensorFieldValue tensor",
+ "protected java.util.Optional defaultCellValue"
]
},
"com.yahoo.document.update.TensorRemoveUpdate" : {
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index d9521ee0e1c..835c056868a 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -6,10 +6,15 @@ import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
+import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import java.util.Arrays;
+import java.util.HashSet;
import java.util.Objects;
+import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
/*
@@ -21,6 +26,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
protected Operation operation;
protected TensorFieldValue tensor;
+ protected Optional<Double> defaultCellValue = Optional.empty();
public TensorModifyUpdate(Operation operation, TensorFieldValue tensor) {
super(ValueUpdateClassID.TENSORMODIFY);
@@ -48,6 +54,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
public TensorFieldValue getValue() { return tensor; }
public void setValue(TensorFieldValue value) { tensor = value; }
+ public void setDefaultCellValue(double value) { defaultCellValue = Optional.of(value); }
@Override
public FieldValue applyTo(FieldValue oldValue) {
@@ -63,6 +70,12 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
default:
throw new UnsupportedOperationException("Unknown operation: " + operation);
}
+ if (defaultCellValue.isPresent() && hasMappedSubtype(oldTensor.type())) {
+ var subspaces = findSubspacesNotInInput(oldTensor, tensor.getTensor().get());
+ if (!subspaces.isEmpty()) {
+ oldTensor = insertSubspaces(oldTensor, subspaces, defaultCellValue.get());
+ }
+ }
Tensor modified = oldTensor.modify(modifier, tensor.getTensor().get().cells());
return new TensorFieldValue(modified);
}
@@ -72,6 +85,64 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
return oldValue;
}
+ private static boolean hasMappedSubtype(TensorType type) {
+ return !type.mappedSubtype().equals(TensorType.empty);
+ }
+
+ private static boolean hasIndexedSubtype(TensorType type) {
+ return !type.indexedSubtype().equals(TensorType.empty);
+ }
+
+ private static HashSet<TensorAddress> findSubspacesNotInInput(Tensor input, Tensor modifier) {
+ var subspaces = new HashSet<TensorAddress>();
+ var inputCells = input.cells();
+ var type = input.type();
+ for (var itr = modifier.cellIterator(); itr.hasNext(); ) {
+ Tensor.Cell cell = itr.next();
+ TensorAddress address = cell.getKey();
+ if (!inputCells.containsKey(address)) {
+ subspaces.add(createSparsePartAddress(address, type));
+ }
+ }
+ return subspaces;
+ }
+
+ private static TensorAddress createSparsePartAddress(TensorAddress address, TensorType type) {
+ var builder = new TensorAddress.Builder(type.mappedSubtype());
+ for (int i = 0; i < type.dimensions().size(); ++i) {
+ var dim = type.dimensions().get(i);
+ if (dim.isMapped()) {
+ builder.add(dim.name(), address.label(i));
+ }
+ }
+ return builder.build();
+ }
+
+ private static Tensor insertSubspaces(Tensor input, HashSet<TensorAddress> subspaces, double defaultCellValue) {
+ var type = input.type();
+ boolean mixed = hasMappedSubtype(type) && hasIndexedSubtype(type);
+ Tensor.Builder builder;
+ if (mixed) {
+ var boundBuilder = MixedTensor.BoundBuilder.of(type);
+ var values = new double[(int) boundBuilder.denseSubspaceSize()];
+ Arrays.fill(values, defaultCellValue);
+ for (var subspace : subspaces) {
+ boundBuilder.block(subspace, values);
+ }
+ builder = boundBuilder;
+ } else {
+ builder = Tensor.Builder.of(type);
+ for (var subspace : subspaces) {
+ builder.cell(subspace, defaultCellValue);
+ }
+ }
+ for (var itr = input.cellIterator(); itr.hasNext(); ) {
+ builder.cell(itr.next());
+ }
+ return builder.build();
+ }
+
+
@Override
protected void checkCompatibility(DataType fieldType) {
if (!(fieldType instanceof TensorDataType)) {
diff --git a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
index 60dd5ad1d0d..55b9090cce8 100644
--- a/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
+++ b/document/src/test/java/com/yahoo/document/update/TensorModifyUpdateTest.java
@@ -7,6 +7,8 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
+import java.util.Optional;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@@ -47,11 +49,38 @@ public class TensorModifyUpdateTest {
"{{x:0,y:0}:3, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:3,{x:0,y:1}:6}");
}
- private void assertApplyTo(String spec, Operation op, String init, String update, String expected) {
- TensorFieldValue initialFieldValue = new TensorFieldValue(Tensor.from(spec, init));
- TensorModifyUpdate modifyUpdate = new TensorModifyUpdate(op, new TensorFieldValue(Tensor.from("tensor(x{},y{})", update)));
- TensorFieldValue updatedFieldValue = (TensorFieldValue) modifyUpdate.applyTo(initialFieldValue);
- assertEquals(Tensor.from(spec, expected), updatedFieldValue.getTensor().get());
+ @Test
+ public void apply_modify_update_operations_with_default_cell_value() {
+ assertApplyTo("tensor(x{})", "tensor(x{})", Operation.ADD, Optional.of(0.0),
+ "{{x:a}:1,{x:b}:2}", "{{x:b}:3}", "{{x:a}:1,{x:b}:5}");
+
+ assertApplyTo("tensor(x{})", "tensor(x{})", Operation.ADD, Optional.of(0.0),
+ "{{x:a}:1,{x:b}:2}", "{{x:b}:3,{x:c}:4}", "{{x:a}:1,{x:b}:5,{x:c}:4}");
+
+ assertApplyTo("tensor(x{},y[3])", "tensor(x{},y{})", Operation.ADD, Optional.of(1.0),
+ "{{x:a,y:0}:3,{x:a,y:1}:4,{x:a,y:2}:5}",
+ "{{x:a,y:0}:6,{x:b,y:1}:7,{x:b,y:2}:8,{x:c,y:0}:9}",
+ "{{x:a,y:0}:9,{x:a,y:1}:4,{x:a,y:2}:5," +
+ "{x:b,y:0}:1,{x:b,y:1}:8,{x:b,y:2}:9," +
+ "{x:c,y:0}:10,{x:c,y:1}:1,{x:c,y:2}:1}");
+
+ // NOTE: The specified default cell value doesn't have any effect for tensors with only indexed dimensions,
+ // as the dense subspace is always represented (with default cell value 0.0).
+ assertApplyTo("tensor(x[3])", "tensor(x{})", Operation.ADD, Optional.of(2.0),
+ "{{x:0}:2}", "{{x:1}:3}", "{{x:0}:2,{x:1}:3,{x:2}:0}");
}
+ private void assertApplyTo(String spec, Operation op, String input, String update, String expected) {
+ assertApplyTo(spec, "tensor(x{},y{})", op, Optional.empty(), input, update, expected);
+ }
+
+ private void assertApplyTo(String inputSpec, String updateSpec, Operation op, Optional<Double> defaultCellValue, String input, String update, String expected) {
+ TensorFieldValue inputFieldValue = new TensorFieldValue(Tensor.from(inputSpec, input));
+ TensorModifyUpdate modifyUpdate = new TensorModifyUpdate(op, new TensorFieldValue(Tensor.from(updateSpec, update)));
+ if (defaultCellValue.isPresent()) {
+ modifyUpdate.setDefaultCellValue(defaultCellValue.get());
+ }
+ TensorFieldValue updatedFieldValue = (TensorFieldValue) modifyUpdate.applyTo(inputFieldValue);
+ assertEquals(Tensor.from(inputSpec, expected), updatedFieldValue.getTensor().get());
+ }
}
diff --git a/document/src/tests/tensor_fieldvalue/partial_modify/partial_modify_test.cpp b/document/src/tests/tensor_fieldvalue/partial_modify/partial_modify_test.cpp
index bf0f893b901..bf780dba5d3 100644
--- a/document/src/tests/tensor_fieldvalue/partial_modify/partial_modify_test.cpp
+++ b/document/src/tests/tensor_fieldvalue/partial_modify/partial_modify_test.cpp
@@ -129,6 +129,12 @@ TEST(PartialModifyTest, partial_modify_with_defauls) {
"tensor(x{},y[3]):{{x:\"a\",y:0}:9,{x:\"a\",y:1}:4,{x:\"a\",y:2}:5,"
"{x:\"b\",y:0}:1,{x:\"b\",y:1}:8,{x:\"b\",y:2}:9,"
"{x:\"c\",y:0}:10,{x:\"c\",y:1}:1,{x:\"c\",y:2}:1}");
+
+ // NOTE: The specified default cell value doesn't have any effect for tensors with only indexed dimensions,
+ // as the dense subspace is always represented (with default cell value 0.0).
+ expect_modify_with_defaults("tensor(x[3]):{{x:0}:2}", "tensor(x{}):{{x:\"1\"}:3}",
+ operation::Add::f, 2.0,
+ "tensor(x[3]):{{x:0}:2,{x:1}:3,{x:2}:0}");
}
std::vector<std::pair<vespalib::string,vespalib::string>> bad_layouts = {
diff --git a/document/src/vespa/document/update/tensor_partial_update.cpp b/document/src/vespa/document/update/tensor_partial_update.cpp
index 72bcc044977..e37e5750384 100644
--- a/document/src/vespa/document/update/tensor_partial_update.cpp
+++ b/document/src/vespa/document/update/tensor_partial_update.cpp
@@ -475,17 +475,20 @@ Value::UP
TensorPartialUpdate::modify_with_defaults(const Value& input, join_fun_t function,
const Value& modifier, double default_cell_value, const ValueBuilderFactory& factory)
{
- AddressHandler handler(input.type(), modifier.type());
+ const auto& input_type = input.type();
+ AddressHandler handler(input_type, modifier.type());
if (!handler.valid) {
return {};
}
- const size_t dsss = input.type().dense_subspace_size();
- ArrayArrayMap<string_id, double> sub_spaces(handler.for_output.addr.size(), dsss, modifier.index().size());
- find_sub_spaces_not_in_input(input, modifier, default_cell_value, handler, sub_spaces);
Value::UP output;
- if (sub_spaces.size() > 0) {
- output = typify_invoke<1, TypifyCellType, PerformInsertSubspaces>(
- input.cells().type, input, handler.for_output, sub_spaces, factory);
+ if (!input_type.is_dense()) {
+ const size_t dsss = input_type.dense_subspace_size();
+ ArrayArrayMap<string_id, double> sub_spaces(handler.for_output.addr.size(), dsss, modifier.index().size());
+ find_sub_spaces_not_in_input(input, modifier, default_cell_value, handler, sub_spaces);
+ if (sub_spaces.size() > 0) {
+ output = typify_invoke<1, TypifyCellType, PerformInsertSubspaces>(
+ input.cells().type, input, handler.for_output, sub_spaces, factory);
+ }
}
return typify_invoke<2, TypifyCellType, PerformModify>(
input.cells().type, modifier.cells().type,