summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/eval/tensor_function/tensor_function_test.cpp')
-rw-r--r--eval/src/tests/eval/tensor_function/tensor_function_test.cpp14
1 files changed, 14 insertions, 0 deletions
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index c457f68a614..3d4e2d41cb5 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -510,4 +510,18 @@ TEST("require that tensor function can be dumped for debugging") {
fprintf(stderr, "function dump -->[[%s]]<-- function dump\n", root.as_string().c_str());
}
+TEST("require that full tensor reduce expands dimension list") {
+ Stash stash;
+ const auto &num = inject(ValueType::from_spec("double"), 0, stash);
+ const auto &mat = inject(ValueType::from_spec("tensor(x[5],y[5])"), 1, stash);
+ const auto *reduce_num = as<Reduce>(reduce(num, Aggr::SUM, {}, stash));
+ const auto *reduce_mat = as<Reduce>(reduce(mat, Aggr::SUM, {}, stash));
+ ASSERT_TRUE(reduce_num);
+ ASSERT_TRUE(reduce_mat);
+ EXPECT_EQUAL(reduce_num->dimensions().size(), 0u);
+ ASSERT_EQUAL(reduce_mat->dimensions().size(), 2u);
+ EXPECT_EQUAL(reduce_mat->dimensions()[0], "x");
+ EXPECT_EQUAL(reduce_mat->dimensions()[1], "y");
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }