aboutsummaryrefslogtreecommitdiffstats
path: root/document
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-02-14 09:32:54 +0100
committerGitHub <noreply@github.com>2019-02-14 09:32:54 +0100
commit3a83cb08d54ec3fe715f78333e7c26bff8564677 (patch)
treee0980eb57335b0a515e484c5116bad44890d7c09 /document
parentee58dcd78b0cd178732f53751e24074ca923edb3 (diff)
parentd031de69513e6ac561480c7b2e4991264d8c947c (diff)
Merge pull request #8483 from vespa-engine/lesters/tensor-add-update-java
Lesters/tensor add update java
Diffstat (limited to 'document')
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java31
-rw-r--r--document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java45
2 files changed, 74 insertions, 2 deletions
diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
index 7a8137ce0a3..cfc3ee0c742 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java
@@ -6,7 +6,10 @@ import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.serialization.DocumentUpdateWriter;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import java.util.Map;
import java.util.Objects;
/**
@@ -37,8 +40,32 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> {
@Override
public FieldValue applyTo(FieldValue oldValue) {
- // TODO: implement
- return null;
+ if ( ! (oldValue instanceof TensorFieldValue)) {
+ throw new IllegalStateException("Cannot use tensor add update on non-tensor datatype " + oldValue.getClass().getName());
+ }
+ if ( ! ((TensorFieldValue) oldValue).getTensor().isPresent()) {
+ throw new IllegalArgumentException("No existing tensor to apply update to");
+ }
+ if ( ! tensor.getTensor().isPresent()) {
+ return oldValue;
+ }
+
+ Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get();
+ Map<TensorAddress, Double> oldCells = oldTensor.cells();
+ Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells();
+
+ // currently, underlying implementation disallows multiple entries with the same key
+
+ Tensor.Builder builder = Tensor.Builder.of(oldTensor.type());
+ for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) {
+ builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue()));
+ }
+ for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
+ if ( ! oldCells.containsKey(addCell.getKey())) {
+ builder.cell(addCell.getKey(), addCell.getValue());
+ }
+ }
+ return new TensorFieldValue(builder.build());
}
@Override
diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
new file mode 100644
index 00000000000..834056af59f
--- /dev/null
+++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java
@@ -0,0 +1,45 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.document.update;
+
+import com.yahoo.document.Document;
+import com.yahoo.document.DocumentId;
+import com.yahoo.document.DocumentType;
+import com.yahoo.document.DocumentTypeManager;
+import com.yahoo.document.Field;
+import com.yahoo.document.TensorDataType;
+import com.yahoo.document.datatypes.TensorFieldValue;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+public class TensorAddUpdateTest {
+
+ @Test
+ public void apply_add_update_operations() {
+ assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}");
+ assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}");
+ assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}");
+ assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}");
+ assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}");
+ }
+
+ private void assertApplyTo(String init, String update, String expected) {
+ String spec = "tensor(x{},y{})";
+ DocumentTypeManager types = new DocumentTypeManager();
+ DocumentType x = new DocumentType("x");
+ x.addField(new Field("f", new TensorDataType(TensorType.fromSpec(spec))));
+ types.registerDocumentType(x);
+
+ Document document = new Document(types.getDocumentType("x"), new DocumentId("doc:test:x"));
+ document.setFieldValue("f", new TensorFieldValue(Tensor.from(spec, init)));
+
+ FieldUpdate.create(document.getField("f"))
+ .addValueUpdate(new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update))))
+ .applyTo(document);
+ Tensor result = ((TensorFieldValue) document.getFieldValue("f")).getTensor().get();
+ assertEquals(Tensor.from(spec, expected), result);
+ }
+
+}