diff options
Diffstat (limited to 'vespajlib')
5 files changed, 51 insertions, 1 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 480523982fa..c3fe8c5c7ad 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -809,6 +809,7 @@ "public com.yahoo.tensor.DimensionSizes dimensionSizes()", "public java.util.Map cells()", "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -854,6 +855,7 @@ "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)" @@ -940,6 +942,7 @@ "public java.util.Map cells()", "public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", "public boolean equals(java.lang.Object)", @@ -1043,6 +1046,7 @@ "public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", "public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)", "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)", + "public abstract com.yahoo.tensor.Tensor remove(java.util.Set)", "public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)", "public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])", "public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 704cead7c01..38d832d01c2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -196,6 +196,11 @@ public class IndexedTensor implements Tensor { } @Override + public Tensor remove(Set<TensorAddress> addresses) { + throw new IllegalArgumentException("Remove is not supported for indexed tensors"); + } + + @Override public int hashCode() { return Arrays.hashCode(values); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index f44b3ce13b7..22ceed22d3e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Iterator; import java.util.Map; +import java.util.Set; import java.util.function.DoubleBinaryOperator; /** @@ -71,6 +72,19 @@ public class MappedTensor implements Tensor { } @Override + public Tensor remove(Set<TensorAddress> addresses) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Iterator<Tensor.Cell> i = cellIterator(); i.hasNext(); ) { + Tensor.Cell cell = i.next(); + TensorAddress address = cell.getKey(); + if ( ! addresses.contains(address)) { + builder.cell(address, cell.getValue()); + } + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 3630a016691..00229c56171 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -6,9 +6,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.stream.Collectors; @@ -127,6 +129,22 @@ public class MixedTensor implements Tensor { } @Override + public Tensor remove(Set<TensorAddress> addresses) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) { + TensorAddress sparsePartialAddress = entry.getKey(); + if ( ! addresses.contains(sparsePartialAddress)) { + long offset = entry.getValue(); + for (int i = 0; i < index.denseSubspaceSize; ++i) { + Cell cell = cells.get((int)offset + i); + builder.cell(cell.getKey(), cell.getValue()); + } + } + } + return builder.build(); + } + + @Override public int hashCode() { return cells.hashCode(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 175e6b41daa..a2333f41135 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -25,6 +25,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; @@ -125,7 +126,15 @@ public interface Tensor { */ Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells); -// Tensor remove(Tensor other); + /** + * Returns a new tensor where existing cells in this tensor have been + * removed according to the given set of addresses. Only valid for sparse + * or mixed tensors. + * + * @param addresses list of addresses to remove + * @return a new tensor where cells have been removed + */ + Tensor remove(Set<TensorAddress> addresses); // ----------------- Primitive tensor functions |