aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-01-12 12:29:09 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-01-15 09:56:56 +0000
commit98b2df819734b334461bc9d1c84243d808d722b7 (patch)
treed1796aa0d4f86847eb5457dc730010ed26940b44 /eval/src/tests/eval/tensor_function/tensor_function_test.cpp
parentada08ec0903230812e85f103d9d05e8a228054a2 (diff)
added test for push_children (tensor IR nodes)
Diffstat (limited to 'eval/src/tests/eval/tensor_function/tensor_function_test.cpp')
-rw-r--r--eval/src/tests/eval/tensor_function/tensor_function_test.cpp24
1 files changed, 24 insertions, 0 deletions
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index 681a4dabc19..641ebddfec2 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -162,4 +162,28 @@ TEST("require that tensor join works") {
TEST_DO(verify_equal(*expect, ctx.eval(prog)));
}
+TEST("require that push_children works") {
+ Stash stash;
+ std::vector<Node::Child::CREF> refs;
+ const Node &a = inject(ValueType::double_type(), 0, stash);
+ const Node &b = inject(ValueType::double_type(), 1, stash);
+ a.push_children(refs);
+ b.push_children(refs);
+ ASSERT_EQUAL(refs.size(), 0u);
+ //-------------------------------------------------------------------------
+ reduce(a, Aggr::SUM, {}, stash).push_children(refs);
+ ASSERT_EQUAL(refs.size(), 1u);
+ EXPECT_EQUAL(&refs[0].get().get(), &a);
+ //-------------------------------------------------------------------------
+ map(b, operation::Neg::f, stash).push_children(refs);
+ ASSERT_EQUAL(refs.size(), 2u);
+ EXPECT_EQUAL(&refs[1].get().get(), &b);
+ //-------------------------------------------------------------------------
+ join(a, b, operation::Add::f, stash).push_children(refs);
+ ASSERT_EQUAL(refs.size(), 4u);
+ EXPECT_EQUAL(&refs[2].get().get(), &a);
+ EXPECT_EQUAL(&refs[3].get().get(), &b);
+ //-------------------------------------------------------------------------
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }