diff options
Diffstat (limited to 'eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp')
-rw-r--r-- | eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp | 16 |
1 files changed, 11 insertions, 5 deletions
diff --git a/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp b/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp index eaf4623afea..274117ea693 100644 --- a/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp +++ b/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp @@ -25,6 +25,7 @@ const TensorEngine &prod_engine = DefaultTensorEngine::ref(); EvalFixture::ParamRepo make_params() { return EvalFixture::ParamRepo() .add("x5", spec({x(5)}, N())) + .add("x5f", spec(float_cells({x(5)}), N())) .add("x5y1", spec({x(5),y(1)}, N())) .add("y1z1", spec({y(1),z(1)}, N())) .add("x_m", spec({x({"a"})}, N())); @@ -78,9 +79,9 @@ TEST("require that non-canonical dimension addition is not optimized") { TEST_DO(verify_not_optimized("tensor(y[1])(1)/x5")); } -TEST("require that dimension addition with overlapping dimensions is not optimized") { - TEST_DO(verify_not_optimized("x5y1*tensor(y[1],z[1])(1)")); - TEST_DO(verify_not_optimized("tensor(y[1],z[1])(1)*x5y1")); +TEST("require that dimension addition with overlapping dimensions is optimized") { + TEST_DO(verify_optimized("x5y1*tensor(y[1],z[1])(1)")); + TEST_DO(verify_optimized("tensor(y[1],z[1])(1)*x5y1")); } TEST("require that dimension addition with inappropriate dimensions is not optimized") { @@ -99,8 +100,13 @@ TEST("require that dimension addition optimization requires unit constant tensor TEST_DO(verify_not_optimized("tensor(x[2])(1)*tensor(y[2])(1)")); } -TEST("require that optimization is disabled for tensors with non-double cells") { - TEST_DO(verify_not_optimized("x5*tensor<float>(a[1],b[1],c[1])(1)")); +TEST("require that optimization also works for float cells") { + TEST_DO(verify_optimized("x5*tensor<float>(a[1],b[1],c[1])(1)")); + TEST_DO(verify_optimized("x5f*tensor<float>(a[1],b[1],c[1])(1)")); +} + +TEST("require that optimization is disabled if unit vector would promote tensor cell types") { + TEST_DO(verify_not_optimized("x5f*tensor(a[1],b[1],c[1])(1)")); } TEST_MAIN() { TEST_RUN_ALL(); } |