diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-03-16 09:40:10 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-03-16 09:40:10 +0000 |
commit | 27ecb51b37162387ce5a061ca8d4c8c69472befc (patch) | |
tree | 00ba3a2e7148f7f71ecff7b96d84e489a2acce96 /vespajlib | |
parent | a929c7ad20c4d4e3087b2b495fea7e1545e72979 (diff) |
join dimensions must equal common dimensions for optimization
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java | 2 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceJoinTestCase.java | 40 |
2 files changed, 42 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index 95987d9b886..11996b6a23d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -106,6 +106,8 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N return false; if (b.type().dimensions().size() != commonDimensions.dimensions().size()) return false; + } else if (dimensions.size() != commonDimensions.dimensions().size()) { + return false; } else { for (TensorType.Dimension dimension : commonDimensions.dimensions()) { if (!dimensions.contains(dimension.name())) diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceJoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceJoinTestCase.java new file mode 100644 index 00000000000..d073c60d993 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ReduceJoinTestCase.java @@ -0,0 +1,40 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.evaluation.Name; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * @author arnej + */ +public class ReduceJoinTestCase { + + @Test + public void testReduceJoinTriggers() { + var a = Tensor.from("tensor(x[3])", "[1,2,3]"); + var b = Tensor.from("tensor(x[3])", "[4,5,6]"); + var fa = new ConstantTensor<Name>(a); + var fb = new ConstantTensor<Name>(a); + var j = new Join<Name>(fa, fb, ScalarFunctions.add()); + var r = new Reduce<Name>(j, Reduce.Aggregator.sum, "x"); + var rj = new ReduceJoin<Name>(r, j); + assertTrue(rj.canOptimize(a, b)); + } + + @Test + public void testReduceJoinUnoptimized() { + var a = Tensor.from("tensor(x[3])", "[1,2,3]"); + var b = Tensor.from("tensor(y[3])", "[4,5,6]"); + var fa = new ConstantTensor<Name>(a); + var fb = new ConstantTensor<Name>(a); + var j = new Join<Name>(fa, fb, ScalarFunctions.add()); + var r = new Reduce<Name>(j, Reduce.Aggregator.sum, "x"); + var rj = new ReduceJoin<Name>(r, j); + assertFalse(rj.canOptimize(a, b)); + } +} |