aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/generic_join
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-09-28 10:30:00 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-09-28 10:45:36 +0000
commit905cbe0bee949379fd8573cbd298e96acfa25906 (patch)
tree60897ee989168aa848011893e2c2a93747c2862e /eval/src/tests/instruction/generic_join
parente9dfc402f1e900a37d90cd62bdad704ebe395703 (diff)
implement reference join without using SimpleTensorEngine
Diffstat (limited to 'eval/src/tests/instruction/generic_join')
-rw-r--r--eval/src/tests/instruction/generic_join/generic_join_test.cpp35
1 files changed, 27 insertions, 8 deletions
diff --git a/eval/src/tests/instruction/generic_join/generic_join_test.cpp b/eval/src/tests/instruction/generic_join/generic_join_test.cpp
index 53df23d77be..4821bf092da 100644
--- a/eval/src/tests/instruction/generic_join/generic_join_test.cpp
+++ b/eval/src/tests/instruction/generic_join/generic_join_test.cpp
@@ -41,13 +41,32 @@ std::vector<Layout> join_layouts = {
float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})})
};
-TensorSpec simple_tensor_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) {
- Stash stash;
- const auto &engine = SimpleTensorEngine::ref();
- auto lhs = engine.from_spec(a);
- auto rhs = engine.from_spec(b);
- const auto &result = engine.join(*lhs, *rhs, function, stash);
- return engine.to_spec(result);
+bool join_address(const TensorSpec::Address &a, const TensorSpec::Address &b, TensorSpec::Address &addr) {
+ for (const auto &dim_a: a) {
+ auto pos_b = b.find(dim_a.first);
+ if ((pos_b != b.end()) && !(pos_b->second == dim_a.second)) {
+ return false;
+ }
+ addr.insert_or_assign(dim_a.first, dim_a.second);
+ }
+ return true;
+}
+
+TensorSpec reference_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) {
+ ValueType res_type = ValueType::join(ValueType::from_spec(a.type()), ValueType::from_spec(b.type()));
+ EXPECT_FALSE(res_type.is_error());
+ TensorSpec result(res_type.to_spec());
+ for (const auto &cell_a: a.cells()) {
+ for (const auto &cell_b: b.cells()) {
+ TensorSpec::Address addr;
+ if (join_address(cell_a.first, cell_b.first, addr) &&
+ join_address(cell_b.first, cell_a.first, addr))
+ {
+ result.add(addr, function(cell_a.second, cell_b.second));
+ }
+ }
+ }
+ return result;
}
TensorSpec perform_generic_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) {
@@ -109,7 +128,7 @@ TEST(GenericJoinTest, generic_join_works_for_simple_values) {
TensorSpec rhs = spec(join_layouts[i + 1], Div16(N()));
for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Div::f}) {
SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str()));
- auto expect = simple_tensor_join(lhs, rhs, fun);
+ auto expect = reference_join(lhs, rhs, fun);
auto actual = perform_generic_join(lhs, rhs, fun);
EXPECT_EQ(actual, expect);
}