diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-09-28 10:30:00 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-09-28 10:45:36 +0000 |
commit | 905cbe0bee949379fd8573cbd298e96acfa25906 (patch) | |
tree | 60897ee989168aa848011893e2c2a93747c2862e /eval/src/tests/instruction/generic_join | |
parent | e9dfc402f1e900a37d90cd62bdad704ebe395703 (diff) |
implement reference join without using SimpleTensorEngine
Diffstat (limited to 'eval/src/tests/instruction/generic_join')
-rw-r--r-- | eval/src/tests/instruction/generic_join/generic_join_test.cpp | 35 |
1 files changed, 27 insertions, 8 deletions
diff --git a/eval/src/tests/instruction/generic_join/generic_join_test.cpp b/eval/src/tests/instruction/generic_join/generic_join_test.cpp index 53df23d77be..4821bf092da 100644 --- a/eval/src/tests/instruction/generic_join/generic_join_test.cpp +++ b/eval/src/tests/instruction/generic_join/generic_join_test.cpp @@ -41,13 +41,32 @@ std::vector<Layout> join_layouts = { float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})}) }; -TensorSpec simple_tensor_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) { - Stash stash; - const auto &engine = SimpleTensorEngine::ref(); - auto lhs = engine.from_spec(a); - auto rhs = engine.from_spec(b); - const auto &result = engine.join(*lhs, *rhs, function, stash); - return engine.to_spec(result); +bool join_address(const TensorSpec::Address &a, const TensorSpec::Address &b, TensorSpec::Address &addr) { + for (const auto &dim_a: a) { + auto pos_b = b.find(dim_a.first); + if ((pos_b != b.end()) && !(pos_b->second == dim_a.second)) { + return false; + } + addr.insert_or_assign(dim_a.first, dim_a.second); + } + return true; +} + +TensorSpec reference_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) { + ValueType res_type = ValueType::join(ValueType::from_spec(a.type()), ValueType::from_spec(b.type())); + EXPECT_FALSE(res_type.is_error()); + TensorSpec result(res_type.to_spec()); + for (const auto &cell_a: a.cells()) { + for (const auto &cell_b: b.cells()) { + TensorSpec::Address addr; + if (join_address(cell_a.first, cell_b.first, addr) && + join_address(cell_b.first, cell_a.first, addr)) + { + result.add(addr, function(cell_a.second, cell_b.second)); + } + } + } + return result; } TensorSpec perform_generic_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) { @@ -109,7 +128,7 @@ TEST(GenericJoinTest, generic_join_works_for_simple_values) { TensorSpec rhs = spec(join_layouts[i + 1], Div16(N())); for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Div::f}) { SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); - auto expect = simple_tensor_join(lhs, rhs, fun); + auto expect = reference_join(lhs, rhs, fun); auto actual = perform_generic_join(lhs, rhs, fun); EXPECT_EQ(actual, expect); } |