summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-02-12 14:56:36 +0100
committerLester Solbakken <lesters@oath.com>2019-02-12 14:56:36 +0100
commit08bf643ca8de76265205e45878b060f30aa5d187 (patch)
tree159a292a3f62ae3d7ca330437a7d84bc17979746 /vespajlib
parent6cd73b95dcdcf95a07a726aab88147c2aa19a029 (diff)
Implement tensor modify applyTo in Java
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java20
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java21
3 files changed, 42 insertions, 0 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 70383e8aabf..932513f8a57 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1038,6 +1038,7 @@
"public abstract java.util.Map cells()",
"public double asDouble()",
"public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)",
+ "public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)",
"public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)",
"public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])",
"public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 58ae508ea7c..8002990e5c6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -93,6 +93,26 @@ public interface Tensor {
*/
Tensor withType(TensorType type);
+ /**
+ * Returns a new tensor where existing cells in this tensor have been
+ * modified according to the given operation and cells in the given map.
+ * Cells in the map outside of existing cells are thus ignored.
+ *
+ * @param op the modifying function
+ * @param cells the cells to modify
+ * @return a new tensor with modified cells
+ */
+ default Tensor modify(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) {
+ Cell cell = i.next();
+ TensorAddress address = cell.getKey();
+ double value = cell.getValue();
+ builder.cell(address, cells.containsKey(address) ? op.applyAsDouble(value, cells.get(address)) : value);
+ }
+ return builder.build();
+ }
+
// ----------------- Primitive tensor functions
default Tensor map(DoubleUnaryOperator mapper) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 05fbb0dbdd9..2c9eefbd130 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -14,6 +14,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
+import java.util.function.DoubleBinaryOperator;
import static com.yahoo.tensor.TensorType.Dimension.Type;
import static org.junit.Assert.assertEquals;
@@ -136,6 +137,26 @@ public class TensorTestCase {
assertEquals("Generic computation implementation", 42, (int)dotProduct(vectorInJSpace, Collections.singletonList(matrixInKSpace)));
}
+ @Test
+ public void testTensorModify() {
+ assertTensorModify((left, right) -> right,
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:0}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:0}"));
+ assertTensorModify((left, right) -> left + right,
+ Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1,{x:0,y:1}:5}"));
+ assertTensorModify((left, right) -> left * right,
+ Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
+ Tensor.from("tensor(x[1],y[3])", "{}"),
+ Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:0,{x:0,y:1}:0}"));
+ }
+
+ private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) {
+ assertEquals(expected, init.modify(op, update.cells()));
+ }
+
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double sum = 0;
TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),