summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/generic_join/generic_join_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/instruction/generic_join/generic_join_test.cpp')
-rw-r--r--eval/src/tests/instruction/generic_join/generic_join_test.cpp22
1 files changed, 22 insertions, 0 deletions
diff --git a/eval/src/tests/instruction/generic_join/generic_join_test.cpp b/eval/src/tests/instruction/generic_join/generic_join_test.cpp
index 82696bbbd5e..cf04b2ca990 100644
--- a/eval/src/tests/instruction/generic_join/generic_join_test.cpp
+++ b/eval/src/tests/instruction/generic_join/generic_join_test.cpp
@@ -148,4 +148,26 @@ TEST(GenericJoinTest, generic_join_works_for_simple_and_fast_values) {
}
}
+TensorSpec immediate_generic_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) {
+ const auto &factory = SimpleValueBuilderFactory::get();
+ auto lhs = value_from_spec(a, factory);
+ auto rhs = value_from_spec(b, factory);
+ auto up = GenericJoin::perform_join(*lhs, *rhs, function, factory);
+ return spec_from_value(*up);
+}
+
+TEST(GenericJoinTest, immediate_generic_join_works) {
+ ASSERT_TRUE((join_layouts.size() % 2) == 0);
+ for (size_t i = 0; i < join_layouts.size(); i += 2) {
+ TensorSpec lhs = spec(join_layouts[i], Div16(N()));
+ TensorSpec rhs = spec(join_layouts[i + 1], Div16(N()));
+ for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Div::f}) {
+ SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str()));
+ auto expect = reference_join(lhs, rhs, fun);
+ auto actual = immediate_generic_join(lhs, rhs, fun);
+ EXPECT_EQ(actual, expect);
+ }
+ }
+}
+
GTEST_MAIN_RUN_ALL_TESTS()