diff options
Diffstat (limited to 'eval/src/tests/instruction/generic_join/generic_join_test.cpp')
-rw-r--r-- | eval/src/tests/instruction/generic_join/generic_join_test.cpp | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/eval/src/tests/instruction/generic_join/generic_join_test.cpp b/eval/src/tests/instruction/generic_join/generic_join_test.cpp index 82696bbbd5e..cf04b2ca990 100644 --- a/eval/src/tests/instruction/generic_join/generic_join_test.cpp +++ b/eval/src/tests/instruction/generic_join/generic_join_test.cpp @@ -148,4 +148,26 @@ TEST(GenericJoinTest, generic_join_works_for_simple_and_fast_values) { } } +TensorSpec immediate_generic_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) { + const auto &factory = SimpleValueBuilderFactory::get(); + auto lhs = value_from_spec(a, factory); + auto rhs = value_from_spec(b, factory); + auto up = GenericJoin::perform_join(*lhs, *rhs, function, factory); + return spec_from_value(*up); +} + +TEST(GenericJoinTest, immediate_generic_join_works) { + ASSERT_TRUE((join_layouts.size() % 2) == 0); + for (size_t i = 0; i < join_layouts.size(); i += 2) { + TensorSpec lhs = spec(join_layouts[i], Div16(N())); + TensorSpec rhs = spec(join_layouts[i + 1], Div16(N())); + for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Div::f}) { + SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); + auto expect = reference_join(lhs, rhs, fun); + auto actual = immediate_generic_join(lhs, rhs, fun); + EXPECT_EQ(actual, expect); + } + } +} + GTEST_MAIN_RUN_ALL_TESTS() |