diff options
author | Lester Solbakken <lesters@oath.com> | 2019-02-20 12:46:24 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-02-20 12:46:24 +0100 |
commit | 085b6922c07f4626c61e2ed2e6dde6beec0855de (patch) | |
tree | 597fc14c08199339c9ab9286c365af6e8d4cdcdb /vespajlib/src/main/java/com/yahoo/tensor | |
parent | 85e394563c8b711a1a0307c8ac5953c1817f5629 (diff) |
TensorAddUpdate support for mixed tensors
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
4 files changed, 60 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index fb55b2d5014..704cead7c01 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -13,6 +13,7 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; +import java.util.function.DoubleBinaryOperator; /** * An indexed (dense) tensor backed by a double array. @@ -190,6 +191,11 @@ public class IndexedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) { + throw new IllegalArgumentException("Merge is not supported for indexed tensors"); + } + + @Override public int hashCode() { return Arrays.hashCode(values); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index ec3020a1a4e..f44b3ce13b7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Iterator; import java.util.Map; +import java.util.function.DoubleBinaryOperator; /** * A sparse implementation of a tensor backed by a Map of cells to values. @@ -51,6 +52,25 @@ public class MappedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { + + // currently, underlying implementation disallows multiple entries with the same key + + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { + TensorAddress address = cell.getKey(); + double value = cell.getValue(); + builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); + } + for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { + if ( ! cells.containsKey(addCell.getKey())) { + builder.cell(addCell.getKey(), addCell.getValue()); + } + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 17e33c58a13..3630a016691 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -9,6 +9,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.function.DoubleBinaryOperator; import java.util.stream.Collectors; /** @@ -70,13 +71,17 @@ public class MixedTensor implements Tensor { return cells.iterator(); } + private Iterable<Cell> cellIterable() { + return this::cellIterator; + } + /** * Returns an iterator over the values of this tensor. * The iteration order is the same as for cellIterator. */ @Override public Iterator<Double> valueIterator() { - return new Iterator<Double>() { + return new Iterator<>() { Iterator<Cell> cellIterator = cellIterator(); @Override public boolean hasNext() { @@ -108,6 +113,20 @@ public class MixedTensor implements Tensor { } @Override + public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Cell cell : cellIterable()) { + TensorAddress address = cell.getKey(); + double value = cell.getValue(); + builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value); + } + for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) { + builder.cell(addCell.getKey(), addCell.getValue()); + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 8002990e5c6..175e6b41daa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -113,6 +113,20 @@ public interface Tensor { return builder.build(); } + /** + * Returns a new tensor where existing cells in this tensor have been + * modified according to the given operation and cells in the given map. + * In contrast to {@link #modify}, previously non-existing cells are added + * to this tensor. Only valid for sparse or mixed tensors. + * + * @param op how to update overlapping cells + * @param cells cells to merge with this tensor + * @return a new tensor where this tensor is merged with the other + */ + Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells); + +// Tensor remove(Tensor other); + // ----------------- Primitive tensor functions default Tensor map(DoubleUnaryOperator mapper) { |