diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-02-14 09:32:54 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-02-14 09:32:54 +0100 |
commit | 3a83cb08d54ec3fe715f78333e7c26bff8564677 (patch) | |
tree | e0980eb57335b0a515e484c5116bad44890d7c09 /document | |
parent | ee58dcd78b0cd178732f53751e24074ca923edb3 (diff) | |
parent | d031de69513e6ac561480c7b2e4991264d8c947c (diff) |
Merge pull request #8483 from vespa-engine/lesters/tensor-add-update-java
Lesters/tensor add update java
Diffstat (limited to 'document')
-rw-r--r-- | document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java | 31 | ||||
-rw-r--r-- | document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java | 45 |
2 files changed, 74 insertions, 2 deletions
diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java index 7a8137ce0a3..cfc3ee0c742 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java @@ -6,7 +6,10 @@ import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import java.util.Map; import java.util.Objects; /** @@ -37,8 +40,32 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { @Override public FieldValue applyTo(FieldValue oldValue) { - // TODO: implement - return null; + if ( ! (oldValue instanceof TensorFieldValue)) { + throw new IllegalStateException("Cannot use tensor add update on non-tensor datatype " + oldValue.getClass().getName()); + } + if ( ! ((TensorFieldValue) oldValue).getTensor().isPresent()) { + throw new IllegalArgumentException("No existing tensor to apply update to"); + } + if ( ! tensor.getTensor().isPresent()) { + return oldValue; + } + + Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); + Map<TensorAddress, Double> oldCells = oldTensor.cells(); + Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells(); + + // currently, underlying implementation disallows multiple entries with the same key + + Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); + for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) { + builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue())); + } + for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { + if ( ! oldCells.containsKey(addCell.getKey())) { + builder.cell(addCell.getKey(), addCell.getValue()); + } + } + return new TensorFieldValue(builder.build()); } @Override diff --git a/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java new file mode 100644 index 00000000000..834056af59f --- /dev/null +++ b/document/src/test/java/com/yahoo/document/update/TensorAddUpdateTest.java @@ -0,0 +1,45 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.document.update; + +import com.yahoo.document.Document; +import com.yahoo.document.DocumentId; +import com.yahoo.document.DocumentType; +import com.yahoo.document.DocumentTypeManager; +import com.yahoo.document.Field; +import com.yahoo.document.TensorDataType; +import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class TensorAddUpdateTest { + + @Test + public void apply_add_update_operations() { + assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:2}:3}", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"); + assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3}", "{{x:0,y:0}:1,{x:0,y:1}:3}"); + assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{{x:0,y:1}:3,{x:0,y:2}:4}", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"); + assertApplyTo("{}", "{{x:0,y:0}:5}", "{{x:0,y:0}:5}"); + assertApplyTo("{{x:0,y:0}:1, {x:0,y:1}:2}", "{}", "{{x:0,y:0}:1, {x:0,y:1}:2}"); + } + + private void assertApplyTo(String init, String update, String expected) { + String spec = "tensor(x{},y{})"; + DocumentTypeManager types = new DocumentTypeManager(); + DocumentType x = new DocumentType("x"); + x.addField(new Field("f", new TensorDataType(TensorType.fromSpec(spec)))); + types.registerDocumentType(x); + + Document document = new Document(types.getDocumentType("x"), new DocumentId("doc:test:x")); + document.setFieldValue("f", new TensorFieldValue(Tensor.from(spec, init))); + + FieldUpdate.create(document.getField("f")) + .addValueUpdate(new TensorAddUpdate(new TensorFieldValue(Tensor.from(spec, update)))) + .applyTo(document); + Tensor result = ((TensorFieldValue) document.getFieldValue("f")).getTensor().get(); + assertEquals(Tensor.from(spec, expected), result); + } + +} |