diff options
author | Håvard Pettersen <havardpe@oath.com> | 2018-03-13 14:29:06 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2018-03-13 14:37:42 +0000 |
commit | b594edf34af42499f55a83d63ab87d0f58fa1f0a (patch) | |
tree | fd2f7da548785fb7b18d9ab0292366a21ded942d /eval/src/tests | |
parent | 838e41a332f7d227cc22ccaa4b23f9276dea274e (diff) |
use ObjectVisitor to debug dump TensorFunction trees
Diffstat (limited to 'eval/src/tests')
3 files changed, 38 insertions, 0 deletions
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp index 41037fa06ef..23ea1e8c13a 100644 --- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp +++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp @@ -336,4 +336,26 @@ TEST("require that push_children works") { //------------------------------------------------------------------------- } +TEST("require that tensor function can be dumped for debugging") { + Stash stash; + auto my_value_1 = stash.create<DoubleValue>(5.0); + auto my_value_2 = stash.create<DoubleValue>(1.0); + //------------------------------------------------------------------------- + const auto &x5 = inject(ValueType::from_spec("tensor(x[5])"), 0, stash); + const auto &mapped_x5 = map(x5, operation::Relu::f, stash); + const auto &const_1 = const_value(my_value_1, stash); + const auto &joined_x5 = join(mapped_x5, const_1, operation::Mul::f, stash); + //------------------------------------------------------------------------- + const auto &x2 = inject(ValueType::from_spec("tensor(x[2])"), 1, stash); + const auto &a3y10 = inject(ValueType::from_spec("tensor(a[3],y[10])"), 2, stash); + const auto &a3 = reduce(a3y10, Aggr::SUM, {"y"}, stash); + const auto &x3 = rename(a3, {"a"}, {"x"}, stash); + const auto &concat_x5 = concat(x3, x2, "x", stash); + //------------------------------------------------------------------------- + const auto &const_2 = const_value(my_value_2, stash); + const auto &root = if_node(const_2, joined_x5, concat_x5, stash); + EXPECT_EQUAL(root.result_type(), ValueType::from_spec("tensor(x[5])")); + fprintf(stderr, "function dump -->[[%s]]<-- function dump\n", root.as_string().c_str()); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp index c794b81f573..3a5b27965d0 100644 --- a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp +++ b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp @@ -144,4 +144,12 @@ TEST("require that mapped tensors are not optimized") { TEST_DO(verify_not_optimized("mut_x_sparse+mut_x_sparse")); } +TEST("require that inplace join can be debug dumped") { + EvalFixture fixture(prod_engine, "con_x5_A-mut_x5_B", param_repo, true, true); + auto info = fixture.find_all<DenseInplaceJoinFunction>(); + ASSERT_EQUAL(info.size(), 1u); + EXPECT_TRUE(info[0]->result_is_mutable()); + fprintf(stderr, "%s\n", info[0]->as_string().c_str()); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp index 536dd95de9c..f18e72b0d07 100644 --- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp +++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp @@ -124,4 +124,12 @@ TEST("require that xw products with incompatible dimensions are not optimized") TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,z)")); } +TEST("require that xw product can be debug dumped") { + EvalFixture fixture(prod_engine, "reduce(y5*x8y5,sum,y)", param_repo, true); + auto info = fixture.find_all<DenseXWProductFunction>(); + ASSERT_EQUAL(info.size(), 1u); + EXPECT_TRUE(info[0]->result_is_mutable()); + fprintf(stderr, "%s\n", info[0]->as_string().c_str()); +} + TEST_MAIN() { TEST_RUN_ALL(); } |