summaryrefslogtreecommitdiffstats
path: root/eval/src/tests
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-03-13 14:29:06 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-03-13 14:37:42 +0000
commitb594edf34af42499f55a83d63ab87d0f58fa1f0a (patch)
treefd2f7da548785fb7b18d9ab0292366a21ded942d /eval/src/tests
parent838e41a332f7d227cc22ccaa4b23f9276dea274e (diff)
use ObjectVisitor to debug dump TensorFunction trees
Diffstat (limited to 'eval/src/tests')
-rw-r--r--eval/src/tests/eval/tensor_function/tensor_function_test.cpp22
-rw-r--r--eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp8
-rw-r--r--eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp8
3 files changed, 38 insertions, 0 deletions
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index 41037fa06ef..23ea1e8c13a 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -336,4 +336,26 @@ TEST("require that push_children works") {
//-------------------------------------------------------------------------
}
+TEST("require that tensor function can be dumped for debugging") {
+ Stash stash;
+ auto my_value_1 = stash.create<DoubleValue>(5.0);
+ auto my_value_2 = stash.create<DoubleValue>(1.0);
+ //-------------------------------------------------------------------------
+ const auto &x5 = inject(ValueType::from_spec("tensor(x[5])"), 0, stash);
+ const auto &mapped_x5 = map(x5, operation::Relu::f, stash);
+ const auto &const_1 = const_value(my_value_1, stash);
+ const auto &joined_x5 = join(mapped_x5, const_1, operation::Mul::f, stash);
+ //-------------------------------------------------------------------------
+ const auto &x2 = inject(ValueType::from_spec("tensor(x[2])"), 1, stash);
+ const auto &a3y10 = inject(ValueType::from_spec("tensor(a[3],y[10])"), 2, stash);
+ const auto &a3 = reduce(a3y10, Aggr::SUM, {"y"}, stash);
+ const auto &x3 = rename(a3, {"a"}, {"x"}, stash);
+ const auto &concat_x5 = concat(x3, x2, "x", stash);
+ //-------------------------------------------------------------------------
+ const auto &const_2 = const_value(my_value_2, stash);
+ const auto &root = if_node(const_2, joined_x5, concat_x5, stash);
+ EXPECT_EQUAL(root.result_type(), ValueType::from_spec("tensor(x[5])"));
+ fprintf(stderr, "function dump -->[[%s]]<-- function dump\n", root.as_string().c_str());
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
index c794b81f573..3a5b27965d0 100644
--- a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
+++ b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
@@ -144,4 +144,12 @@ TEST("require that mapped tensors are not optimized") {
TEST_DO(verify_not_optimized("mut_x_sparse+mut_x_sparse"));
}
+TEST("require that inplace join can be debug dumped") {
+ EvalFixture fixture(prod_engine, "con_x5_A-mut_x5_B", param_repo, true, true);
+ auto info = fixture.find_all<DenseInplaceJoinFunction>();
+ ASSERT_EQUAL(info.size(), 1u);
+ EXPECT_TRUE(info[0]->result_is_mutable());
+ fprintf(stderr, "%s\n", info[0]->as_string().c_str());
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index 536dd95de9c..f18e72b0d07 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -124,4 +124,12 @@ TEST("require that xw products with incompatible dimensions are not optimized")
TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,z)"));
}
+TEST("require that xw product can be debug dumped") {
+ EvalFixture fixture(prod_engine, "reduce(y5*x8y5,sum,y)", param_repo, true);
+ auto info = fixture.find_all<DenseXWProductFunction>();
+ ASSERT_EQUAL(info.size(), 1u);
+ EXPECT_TRUE(info[0]->result_is_mutable());
+ fprintf(stderr, "%s\n", info[0]->as_string().c_str());
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }