diff options
Diffstat (limited to 'eval/src/tests/instruction/generic_merge/generic_merge_test.cpp')
-rw-r--r-- | eval/src/tests/instruction/generic_merge/generic_merge_test.cpp | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp b/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp index bb14d869440..025ab7f857e 100644 --- a/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp +++ b/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp @@ -50,12 +50,10 @@ void test_generic_merge_with(const ValueBuilderFactory &factory) { for (size_t i = 0; i < merge_layouts.size(); i += 2) { const auto l = merge_layouts[i]; const auto r = merge_layouts[i+1].cpy().seq(N_16ths); - for (TensorSpec lhs : { l.cpy().cells_float(), - l.cpy().cells_double() }) - { - for (TensorSpec rhs : { r.cpy().cells_float(), - r.cpy().cells_double() }) - { + for (CellType lct : CellTypeUtils::list_types()) { + TensorSpec lhs = l.cpy().cells(lct); + for (CellType rct : CellTypeUtils::list_types()) { + TensorSpec rhs = r.cpy().cells(rct); SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); for (auto fun: {operation::Add::f, operation::Mul::f, operation::Sub::f, operation::Max::f}) { auto expect = ReferenceOperations::merge(lhs, rhs, fun); |