diff options
author | Lester Solbakken <lesters@oath.com> | 2019-02-12 14:56:36 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-02-12 14:56:36 +0100 |
commit | 08bf643ca8de76265205e45878b060f30aa5d187 (patch) | |
tree | 159a292a3f62ae3d7ca330437a7d84bc17979746 /vespajlib | |
parent | 6cd73b95dcdcf95a07a726aab88147c2aa19a029 (diff) |
Implement tensor modify applyTo in Java
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/abi-spec.json | 1 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 20 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java | 21 |
3 files changed, 42 insertions, 0 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 70383e8aabf..932513f8a57 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1038,6 +1038,7 @@ "public abstract java.util.Map cells()", "public double asDouble()", "public abstract com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)", + "public com.yahoo.tensor.Tensor modify(java.util.function.DoubleBinaryOperator, java.util.Map)", "public com.yahoo.tensor.Tensor map(java.util.function.DoubleUnaryOperator)", "public varargs com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.lang.String[])", "public com.yahoo.tensor.Tensor reduce(com.yahoo.tensor.functions.Reduce$Aggregator, java.util.List)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 58ae508ea7c..8002990e5c6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -93,6 +93,26 @@ public interface Tensor { */ Tensor withType(TensorType type); + /** + * Returns a new tensor where existing cells in this tensor have been + * modified according to the given operation and cells in the given map. + * Cells in the map outside of existing cells are thus ignored. + * + * @param op the modifying function + * @param cells the cells to modify + * @return a new tensor with modified cells + */ + default Tensor modify(DoubleBinaryOperator op, Map<TensorAddress, Double> cells) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (Iterator<Cell> i = cellIterator(); i.hasNext(); ) { + Cell cell = i.next(); + TensorAddress address = cell.getKey(); + double value = cell.getValue(); + builder.cell(address, cells.containsKey(address) ? op.applyAsDouble(value, cells.get(address)) : value); + } + return builder.build(); + } + // ----------------- Primitive tensor functions default Tensor map(DoubleUnaryOperator mapper) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 05fbb0dbdd9..2c9eefbd130 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -14,6 +14,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Set; +import java.util.function.DoubleBinaryOperator; import static com.yahoo.tensor.TensorType.Dimension.Type; import static org.junit.Assert.assertEquals; @@ -136,6 +137,26 @@ public class TensorTestCase { assertEquals("Generic computation implementation", 42, (int)dotProduct(vectorInJSpace, Collections.singletonList(matrixInKSpace))); } + @Test + public void testTensorModify() { + assertTensorModify((left, right) -> right, + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1, {x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:0}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:0}")); + assertTensorModify((left, right) -> left + right, + Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1,{x:0,y:1}:5}")); + assertTensorModify((left, right) -> left * right, + Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"), + Tensor.from("tensor(x[1],y[3])", "{}"), + Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:0,{x:0,y:1}:0}")); + } + + private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) { + assertEquals(expected, init.modify(op, update.cells())); + } + private double dotProduct(Tensor tensor, List<Tensor> tensors) { double sum = 0; TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), |