diff options
author | Lester Solbakken <lesters@oath.com> | 2019-02-13 11:35:48 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-02-13 11:35:48 +0100 |
commit | 5d82e40fa6951002c8a732da04b6bc4a9ba82646 (patch) | |
tree | 8b987f8920f4aff4bf6183bb6fb53e0f7855024b /document/src/main/java/com/yahoo | |
parent | b9869d95dd4d80e23f15d610756924aaa12ea28b (diff) |
Add implementation of TensorAddUpdate applyTo
Diffstat (limited to 'document/src/main/java/com/yahoo')
-rw-r--r-- | document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java | 34 |
1 files changed, 32 insertions, 2 deletions
diff --git a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java index 7a8137ce0a3..73833933367 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorAddUpdate.java @@ -6,8 +6,14 @@ import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.FieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.serialization.DocumentUpdateWriter; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; import java.util.Objects; +import java.util.Set; /** * An update used to add cells to a sparse tensor (has only mapped dimensions). @@ -37,8 +43,32 @@ public class TensorAddUpdate extends ValueUpdate<TensorFieldValue> { @Override public FieldValue applyTo(FieldValue oldValue) { - // TODO: implement - return null; + if ( ! (oldValue instanceof TensorFieldValue)) { + throw new IllegalStateException("Cannot use tensor add update on non-tensor datatype " + oldValue.getClass().getName()); + } + if ( ! ((TensorFieldValue) oldValue).getTensor().isPresent()) { + throw new IllegalArgumentException("No existing tensor to apply update to"); + } + if ( ! tensor.getTensor().isPresent()) { + return oldValue; + } + + Tensor oldTensor = ((TensorFieldValue) oldValue).getTensor().get(); + Map<TensorAddress, Double> oldCells = oldTensor.cells(); + Map<TensorAddress, Double> addCells = tensor.getTensor().get().cells(); + + // currently, underlying implementation disallows multiple entries with the same key + + Tensor.Builder builder = Tensor.Builder.of(oldTensor.type()); + for (Map.Entry<TensorAddress, Double> oldCell : oldCells.entrySet()) { + builder.cell(oldCell.getKey(), addCells.getOrDefault(oldCell.getKey(), oldCell.getValue())); + } + for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { + if ( ! oldCells.containsKey(addCell.getKey())) { + builder.cell(addCell.getKey(), addCell.getValue()); + } + } + return new TensorFieldValue(builder.build()); } @Override |