diff options
author | Lester Solbakken <lesters@oath.com> | 2019-02-20 14:30:31 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-02-20 14:30:31 +0100 |
commit | c85a3fee56c13f82d14d480e7569432e1f352316 (patch) | |
tree | 1ba19b8b498a7c4e0004939a8139fcfbd8d75875 /vespajlib/src/main/java/com/yahoo/tensor | |
parent | 085b6922c07f4626c61e2ed2e6dde6beec0855de (diff) |
TensorRemoveUpdate support for mixed tensors
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
4 files changed, 47 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 704cead7c01..38d832d01c2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -196,6 +196,11 @@ public class IndexedTensor implements Tensor { } @Override + public Tensor remove(Set<TensorAddress> addresses) { + throw new IllegalArgumentException("Remove is not supported for indexed tensors"); + } + + @Override public int hashCode() { return Arrays.hashCode(values); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index f44b3ce13b7..22ceed22d3e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Iterator; import java.util.Map; +import java.util.Set; import java.util.function.DoubleBinaryOperator; /** @@ -71,6 +72,19 @@ public class MappedTensor implements Tensor { } @Override + public Tensor remove(Set<TensorAddress> addresses) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) { + Tensor.Cell cell = i.next(); + TensorAddress address = cell.getKey(); + if ( ! addresses.contains(address)) { + builder.cell(address, cell.getValue()); + } + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 3630a016691..00229c56171 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -6,9 +6,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.stream.Collectors; @@ -127,6 +129,22 @@ public class MixedTensor implements Tensor { } @Override + public Tensor remove(Set<TensorAddress> addresses) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) { + TensorAddress sparsePartialAddress = entry.getKey(); + if ( ! addresses.contains(sparsePartialAddress)) { + long offset = entry.getValue(); + for (int i = 0; i < index.denseSubspaceSize; ++i) { + Cell cell = cells.get((int)offset + i); + builder.cell(cell.getKey(), cell.getValue()); + } + } + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 175e6b41daa..a2333f41135 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -25,6 +25,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; @@ -125,7 +126,15 @@ public interface Tensor { */ Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells); -// Tensor remove(Tensor other); + /** + * Returns a new tensor where existing cells in this tensor have been + * removed according to the given set of addresses. Only valid for sparse + * or mixed tensors. + * + * @param addresses list of addresses to remove + * @return a new tensor where cells have been removed + */ + Tensor remove(Set<TensorAddress> addresses); // ----------------- Primitive tensor functions |