summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-16 09:40:10 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-16 09:40:10 +0000
commit27ecb51b37162387ce5a061ca8d4c8c69472befc (patch)
tree00ba3a2e7148f7f71ecff7b96d84e489a2acce96 /vespajlib
parenta929c7ad20c4d4e3087b2b495fea7e1545e72979 (diff)
join dimensions must equal common dimensions for optimization
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceJoinTestCase.java40
2 files changed, 42 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index 95987d9b886..11996b6a23d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -106,6 +106,8 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
return false;
if (b.type().dimensions().size() != commonDimensions.dimensions().size())
return false;
+ } else if (dimensions.size() != commonDimensions.dimensions().size()) {
+ return false;
} else {
for (TensorType.Dimension dimension : commonDimensions.dimensions()) {
if (!dimensions.contains(dimension.name()))
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceJoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceJoinTestCase.java
new file mode 100644
index 00000000000..d073c60d993
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceJoinTestCase.java
@@ -0,0 +1,40 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.evaluation.Name;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author arnej
+ */
+public class ReduceJoinTestCase {
+
+ @Test
+ public void testReduceJoinTriggers() {
+ var a = Tensor.from("tensor(x[3])", "[1,2,3]");
+ var b = Tensor.from("tensor(x[3])", "[4,5,6]");
+ var fa = new ConstantTensor<Name>(a);
+ var fb = new ConstantTensor<Name>(a);
+ var j = new Join<Name>(fa, fb, ScalarFunctions.add());
+ var r = new Reduce<Name>(j, Reduce.Aggregator.sum, "x");
+ var rj = new ReduceJoin<Name>(r, j);
+ assertTrue(rj.canOptimize(a, b));
+ }
+
+ @Test
+ public void testReduceJoinUnoptimized() {
+ var a = Tensor.from("tensor(x[3])", "[1,2,3]");
+ var b = Tensor.from("tensor(y[3])", "[4,5,6]");
+ var fa = new ConstantTensor<Name>(a);
+ var fb = new ConstantTensor<Name>(a);
+ var j = new Join<Name>(fa, fb, ScalarFunctions.add());
+ var r = new Reduce<Name>(j, Reduce.Aggregator.sum, "x");
+ var rj = new ReduceJoin<Name>(r, j);
+ assertFalse(rj.canOptimize(a, b));
+ }
+}