summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-02-20 12:46:24 +0100
committerLester Solbakken <lesters@oath.com>2019-02-20 12:46:24 +0100
commit085b6922c07f4626c61e2ed2e6dde6beec0855de (patch)
tree597fc14c08199339c9ab9286c365af6e8d4cdcdb /vespajlib
parent85e394563c8b711a1a0307c8ac5953c1817f5629 (diff)
TensorAddUpdate support for mixed tensors
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java14
5 files changed, 64 insertions, 1 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 932513f8a57..480523982fa 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -808,6 +808,7 @@
"public com.yahoo.tensor.IndexedTensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.DimensionSizes dimensionSizes()",
"public java.util.Map cells()",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -852,6 +853,7 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)"
@@ -937,6 +939,7 @@
"public java.util.Iterator valueIterator()",
"public java.util.Map cells()",
"public com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public int hashCode()",
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
@@ -1039,6 +1042,7 @@
"public double asDouble()",
"public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)",
+ "public abstract com.yahoo.tensor.Tensor merge(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)",
"public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])",
"public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index fb55b2d5014..704cead7c01 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -13,6 +13,7 @@ import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Set;
+import java.util.function.DoubleBinaryOperator;
/**
* An indexed (dense) tensor backed by a double array.
@@ -190,6 +191,11 @@ public class IndexedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) {
+ throw new IllegalArgumentException("Merge is not supported for indexed tensors");
+ }
+
+ @Override
public int hashCode() { return Arrays.hashCode(values); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index ec3020a1a4e..f44b3ce13b7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.Map;
+import java.util.function.DoubleBinaryOperator;
/**
* A sparse implementation of a tensor backed by a Map of cells to values.
@@ -51,6 +52,25 @@ public class MappedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) {
+
+ // currently, underlying implementation disallows multiple entries with the same key
+
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) {
+ TensorAddress address = cell.getKey();
+ double value = cell.getValue();
+ builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value);
+ }
+ for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
+ if ( ! cells.containsKey(addCell.getKey())) {
+ builder.cell(addCell.getKey(), addCell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 17e33c58a13..3630a016691 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -9,6 +9,7 @@ import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.function.DoubleBinaryOperator;
import java.util.stream.Collectors;
/**
@@ -70,13 +71,17 @@ public class MixedTensor implements Tensor {
return cells.iterator();
}
+ private Iterable<Cell> cellIterable() {
+ return this::cellIterator;
+ }
+
/**
* Returns an iterator over the values of this tensor.
* The iteration order is the same as for cellIterator.
*/
@Override
public Iterator<Double> valueIterator() {
- return new Iterator<Double>() {
+ return new Iterator<>() {
Iterator<Cell> cellIterator = cellIterator();
@Override
public boolean hasNext() {
@@ -108,6 +113,20 @@ public class MixedTensor implements Tensor {
}
@Override
+ public Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> addCells) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Cell cell : cellIterable()) {
+ TensorAddress address = cell.getKey();
+ double value = cell.getValue();
+ builder.cell(address, addCells.containsKey(address) ? op.applyAsDouble(value, addCells.get(address)) : value);
+ }
+ for (Map.Entry<TensorAddress, Double> addCell : addCells.entrySet()) {
+ builder.cell(addCell.getKey(), addCell.getValue());
+ }
+ return builder.build();
+ }
+
+ @Override
public int hashCode() { return cells.hashCode(); }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 8002990e5c6..175e6b41daa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -113,6 +113,20 @@ public interface Tensor {
return builder.build();
}
+ /**
+ * Returns a new tensor where existing cells in this tensor have been
+ * modified according to the given operation and cells in the given map.
+ * In contrast to {@link #modify}, previously non-existing cells are added
+ * to this tensor. Only valid for sparse or mixed tensors.
+ *
+ * @param op how to update overlapping cells
+ * @param cells cells to merge with this tensor
+ * @return a new tensor where this tensor is merged with the other
+ */
+ Tensor merge(DoubleBinaryOperator op, Map<TensorAddress, Double> cells);
+
+// Tensor remove(Tensor other);
+
// ----------------- Primitive tensor functions
default Tensor map(DoubleUnaryOperator mapper) {