summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-02-20 14:30:31 +0100
committerLester Solbakken <lesters@oath.com>2019-02-20 14:30:31 +0100
commitc85a3fee56c13f82d14d480e7569432e1f352316 (patch)
tree1ba19b8b498a7c4e0004939a8139fcfbd8d75875 /vespajlib/src/main/java/com/yahoo/tensor
parent085b6922c07f4626c61e2ed2e6dde6beec0855de (diff)
TensorRemoveUpdate support for mixed tensors
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java11
4 files changed, 47 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 704cead7c01..38d832d01c2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -196,6 +196,11 @@ public class IndexedTensor implements Tensor {
}
@Override
+ public Tensor remove(Set<TensorAddress> addresses) {
+ throw new IllegalArgumentException("Remove is not supported for indexed tensors");
+ }
+
+ @Override
public int hashCode() { return Arrays.hashCode(values); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index f44b3ce13b7..22ceed22d3e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.Map;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
/**
@@ -71,6 +72,19 @@ public class MappedTensor implements Tensor {
}
@Override
+ public Tensor remove(Set<TensorAddress> addresses) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) {
+ Tensor.Cell cell = i.next();
+ TensorAddress address = cell.getKey();
+ if ( ! addresses.contains(address)) {
+ builder.cell(address, cell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 3630a016691..00229c56171 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -6,9 +6,11 @@ import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.stream.Collectors;
@@ -127,6 +129,22 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Tensor remove(Set<TensorAddress> addresses) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) {
+ TensorAddress sparsePartialAddress = entry.getKey();
+ if ( ! addresses.contains(sparsePartialAddress)) {
+ long offset = entry.getValue();
+ for (int i = 0; i < index.denseSubspaceSize; ++i) {
+ Cell cell = cells.get((int)offset + i);
+ builder.cell(cell.getKey(), cell.getValue());
+ }
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 175e6b41daa..a2333f41135 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -25,6 +25,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
@@ -125,7 +126,15 @@ public interface Tensor {
*/
Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells);
-// Tensor remove(Tensor other);
+ /**
+ * Returns a new tensor where existing cells in this tensor have been
+ * removed according to the given set of addresses. Only valid for sparse
+ * or mixed tensors.
+ *
+ * @param addresses list of addresses to remove
+ * @return a new tensor where cells have been removed
+ */
+ Tensor remove(Set<TensorAddress> addresses);
// ----------------- Primitive tensor functions