summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-02-21 10:55:55 +0100
committerLester Solbakken <lesters@oath.com>2019-02-21 10:55:55 +0100
commit21651a8420530f069d42f37ca4dd0381f043501a (patch)
tree41a1d392c21fd43ffcb617d337490890c254d1b2 /vespajlib
parent5d0bff5230d3d8a304f786cbcc3c486ee9f941bb (diff)
Cleanup of tensor updates - Java
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java3
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java94
3 files changed, 99 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index c51d3c32df9..08878edeb83 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -130,9 +130,11 @@ public class MixedTensor implements Tensor {
@Override
public Tensor remove(Set<TensorAddress> addresses) {
Tensor.Builder builder = Tensor.Builder.of(type());
+
+ // iterate through all sparse addresses referencing a dense subspace
for (Map.Entry<TensorAddress, Long> entry : index.sparseMap.entrySet()) {
TensorAddress sparsePartialAddress = entry.getKey();
- if ( ! addresses.contains(sparsePartialAddress)) {
+ if ( ! addresses.contains(sparsePartialAddress)) { // assumption: addresses only contain the sparse part
long offset = entry.getValue();
for (int i = 0; i < index.denseSubspaceSize; ++i) {
Cell cell = cells.get((int)offset + i);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index a2333f41135..eb16801c306 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -129,7 +129,8 @@ public interface Tensor {
/**
* Returns a new tensor where existing cells in this tensor have been
* removed according to the given set of addresses. Only valid for sparse
- * or mixed tensors.
+ * or mixed tensors. For mixed tensors, addresses are assumed to only
+ * contain the sparse dimensions, as the entire dense subspace is removed.
*
* @param addresses list of addresses to remove
* @return a new tensor where cells have been removed
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 2c9eefbd130..02d16e6f3e4 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -151,12 +151,106 @@ public class TensorTestCase {
Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
Tensor.from("tensor(x[1],y[3])", "{}"),
Tensor.from("tensor(x[1],y[2])", "{{x:0,y:0}:0,{x:0,y:1}:0}"));
+ assertTensorModify((left, right) -> left * right,
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1, {x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:6}"));
+ }
+
+ @Test
+ public void testTensorMerge() {
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:2}:3}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2,{x:0,y:2}:3}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:3,{x:0,y:2}:4}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:3,{x:0,y:2}:4}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:5}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:2}:3}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:0,{x:0,y:2}:3}")); // notice difference with sparse case - y is dense dimension here with default value 0.0
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:0}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:1}:3,{x:0,y:2}:4}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:0,{x:0,y:1}:3,{x:0,y:2}:4}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:5}"));
+ assertTensorMerge(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y[3])", "{}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:0,y:1}:2}"));
+ }
+
+ @Test
+ public void testTensorRemove() {
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:1}:1}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:2}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1,{x:0,y:1}:1}"),
+ Tensor.from("tensor(x{},y{})", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:1}"),
+ Tensor.from("tensor(x{},y{})", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{},y{})", "{}"),
+ Tensor.from("tensor(x{},y{})", "{{x:0,y:0}:2,{x:0,y:1}:3}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2, {x:0,y:1}:3}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"), // notice update is without dense dimension
+ Tensor.from("tensor(x{},y[3])", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:1,{x:1,y:0}:2}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:1,y:0}:2,{x:1,y:1}:0,{x:1,y:2}:0}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{}"),
+ Tensor.from("tensor(x{})", "{{x:0}:1}"),
+ Tensor.from("tensor(x{},y[3])", "{}"));
+ assertTensorRemove(
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"),
+ Tensor.from("tensor(x{})", "{}"),
+ Tensor.from("tensor(x{},y[3])", "{{x:0,y:0}:2,{x:0,y:1}:3}"));
}
private void assertTensorModify(DoubleBinaryOperator op, Tensor init, Tensor update, Tensor expected) {
assertEquals(expected, init.modify(op, update.cells()));
}
+ private void assertTensorMerge(Tensor init, Tensor update, Tensor expected) {
+ DoubleBinaryOperator op = (left, right) -> right;
+ assertEquals(expected, init.merge(op, update.cells()));
+ }
+
+ private void assertTensorRemove(Tensor init, Tensor update, Tensor expected) {
+ assertEquals(expected, init.remove(update.cells().keySet()));
+ }
+
+
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double sum = 0;
TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),