diff options
author | Lester Solbakken <lesters@oath.com> | 2019-02-21 10:55:55 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-02-21 10:55:55 +0100 |
commit | 21651a8420530f069d42f37ca4dd0381f043501a (patch) | |
tree | 41a1d392c21fd43ffcb617d337490890c254d1b2 /vespajlib | |
parent | 5d0bff5230d3d8a304f786cbcc3c486ee9f941bb (diff) |
Cleanup of tensor updates - Java
Diffstat (limited to 'vespajlib')
3 files changed, 99 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index c51d3c32df9..08878edeb83 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -130,9 +130,11 @@ public class MixedTensor implements Tensor { @Override public Tensor remove(Set<TensorAddress> addresses) { Tensor.Builder builder = Tensor.Builder.of(type()); + + // iterate through all sparse addresses referencing a dense subspace for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) { TensorAddress sparsePartialAddress = entry.getKey(); - if ( ! addresses.contains(sparsePartialAddress)) { + if ( ! addresses.contains(sparsePartialAddress)) { // assumption: addresses only contain the sparse part long offset = entry.getValue(); for (int i = 0; i < index.denseSubspaceSize; ++i) { Cell cell = cells.get((int)offset + i); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index a2333f41135..eb16801c306 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -129,7 +129,8 @@ public interface Tensor { /** * Returns a new tensor where existing cells in this tensor have been * removed according to the given set of addresses. Only valid for sparse - * or mixed tensors. + * or mixed tensors. For mixed tensors, addresses are assumed to only + * contain the sparse dimensions, as the entire dense subspace is removed. * * @param addresses list of addresses to remove * @return a new tensor where cells have been removed diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 2c9eefbd130..02d16e6f3e4 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -151,12 +151,106 @@ public class TensorTestCase { Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"), Tensor.from("tensor(x[1],y[3])", "{}"), Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:0,{x:0,y:1}:0}")); + assertTensorModify((left, right) -> left * right, + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:6}")); + } + + @Test + public void testTensorMerge() { + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:2}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3,{x:0,y:2}:4}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}")); + assertTensorMerge( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:2}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}")); // notice difference with sparse case - y is dense dimension here with default value 0.0 + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3,{x:0,y:2}:4}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}")); + assertTensorMerge( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}")); + } + + @Test + public void testTensorRemove() { + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:1}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:1}"), + Tensor.from("tensor(x{},y{})", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1}"), + Tensor.from("tensor(x{},y{})", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{},y{})", "{}"), + Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2, {x:0,y:1}:3}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), // notice update is without dense dimension + Tensor.from("tensor(x{},y[3])", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:1,y:0}:2}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), + Tensor.from("tensor(x{},y[3])", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{}"), + Tensor.from("tensor(x{})", "{{x:0}:1}"), + Tensor.from("tensor(x{},y[3])", "{}")); + assertTensorRemove( + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"), + Tensor.from("tensor(x{})", "{}"), + Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}")); } private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) { assertEquals(expected, init.modify(op, update.cells())); } + private void assertTensorMerge(Tensor init, Tensor update, Tensor expected) { + DoubleBinaryOperator op = (left, right) -> right; + assertEquals(expected, init.merge(op, update.cells())); + } + + private void assertTensorRemove(Tensor init, Tensor update, Tensor expected) { + assertEquals(expected, init.remove(update.cells().keySet())); + } + + private double dotProduct(Tensor tensor, List<Tensor> tensors) { double sum = 0; TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), |