summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp')
-rw-r--r--eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp15
1 files changed, 11 insertions, 4 deletions
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index 335aa4791a4..426281686d7 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -45,6 +45,7 @@ EvalFixture::ParamRepo make_params() {
.add("y1z1", spec({y(1),z(1)}, MyMatSeq()))
.add("x2y3", spec({x(2),y(3)}, MyMatSeq()))
.add("x2y3f", spec(float_cells({x(2),y(3)}), MyMatSeq()))
+ .add("y3z2f", spec(float_cells({y(3),z(2)}), MyMatSeq()))
.add("x2z3", spec({x(2),z(3)}, MyMatSeq()))
.add("y3z2", spec({y(3),z(2)}, MyMatSeq()))
.add("x8y5", spec({x(8),y(5)}, MyMatSeq()))
@@ -118,10 +119,16 @@ TEST("require that xw product can be debug dumped") {
fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}
-TEST("require that optimization is disabled for tensors with non-double cells") {
- TEST_DO(verify_not_optimized("reduce(y3f*x2y3,sum,y)"));
- TEST_DO(verify_not_optimized("reduce(y3*x2y3f,sum,y)"));
- TEST_DO(verify_not_optimized("reduce(y3f*x2y3f,sum,y)"));
+TEST("require that optimization works for float cells") {
+ TEST_DO(verify_optimized("reduce(y3f*x2y3,sum,y)", 3, 2, true));
+ TEST_DO(verify_optimized("reduce(y3*x2y3f,sum,y)", 3, 2, true));
+ TEST_DO(verify_optimized("reduce(y3f*x2y3f,sum,y)", 3, 2, true));
+}
+
+TEST("require that optimization works for float cells with inconvenient dimension nesting") {
+ TEST_DO(verify_optimized("reduce(y3f*y3z2,sum,y)", 3, 2, false));
+ TEST_DO(verify_optimized("reduce(y3*y3z2f,sum,y)", 3, 2, false));
+ TEST_DO(verify_optimized("reduce(y3f*y3z2f,sum,y)", 3, 2, false));
}
TEST_MAIN() { TEST_RUN_ALL(); }