summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp')
-rw-r--r--eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp16
1 files changed, 11 insertions, 5 deletions
diff --git a/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp b/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp
index eaf4623afea..274117ea693 100644
--- a/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp
+++ b/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp
@@ -25,6 +25,7 @@ const TensorEngine &prod_engine = DefaultTensorEngine::ref();
EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
.add("x5", spec({x(5)}, N()))
+ .add("x5f", spec(float_cells({x(5)}), N()))
.add("x5y1", spec({x(5),y(1)}, N()))
.add("y1z1", spec({y(1),z(1)}, N()))
.add("x_m", spec({x({"a"})}, N()));
@@ -78,9 +79,9 @@ TEST("require that non-canonical dimension addition is not optimized") {
TEST_DO(verify_not_optimized("tensor(y[1])(1)/x5"));
}
-TEST("require that dimension addition with overlapping dimensions is not optimized") {
- TEST_DO(verify_not_optimized("x5y1*tensor(y[1],z[1])(1)"));
- TEST_DO(verify_not_optimized("tensor(y[1],z[1])(1)*x5y1"));
+TEST("require that dimension addition with overlapping dimensions is optimized") {
+ TEST_DO(verify_optimized("x5y1*tensor(y[1],z[1])(1)"));
+ TEST_DO(verify_optimized("tensor(y[1],z[1])(1)*x5y1"));
}
TEST("require that dimension addition with inappropriate dimensions is not optimized") {
@@ -99,8 +100,13 @@ TEST("require that dimension addition optimization requires unit constant tensor
TEST_DO(verify_not_optimized("tensor(x[2])(1)*tensor(y[2])(1)"));
}
-TEST("require that optimization is disabled for tensors with non-double cells") {
- TEST_DO(verify_not_optimized("x5*tensor<float>(a[1],b[1],c[1])(1)"));
+TEST("require that optimization also works for float cells") {
+ TEST_DO(verify_optimized("x5*tensor<float>(a[1],b[1],c[1])(1)"));
+ TEST_DO(verify_optimized("x5f*tensor<float>(a[1],b[1],c[1])(1)"));
+}
+
+TEST("require that optimization is disabled if unit vector would promote tensor cell types") {
+ TEST_DO(verify_not_optimized("x5f*tensor(a[1],b[1],c[1])(1)"));
}
TEST_MAIN() { TEST_RUN_ALL(); }