summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/generic_join
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-03-03 13:36:38 +0000
committerArne Juul <arnej@verizonmedia.com>2021-03-03 13:36:38 +0000
commit93f936c9e3531d2f2dc63f2262cd8069097238a8 (patch)
treed97eb060b69358a4fec2cda2bde32b6371647a7b /eval/src/tests/instruction/generic_join
parent2d5116489de2695acfe7bd7928e65f369ce068f5 (diff)
use CellTypeUtils::list_types to loop over possible cell types in tests
Diffstat (limited to 'eval/src/tests/instruction/generic_join')
-rw-r--r--eval/src/tests/instruction/generic_join/generic_join_test.cpp10
1 files changed, 4 insertions, 6 deletions
diff --git a/eval/src/tests/instruction/generic_join/generic_join_test.cpp b/eval/src/tests/instruction/generic_join/generic_join_test.cpp
index 181f44d0f2e..2f619bcaa54 100644
--- a/eval/src/tests/instruction/generic_join/generic_join_test.cpp
+++ b/eval/src/tests/instruction/generic_join/generic_join_test.cpp
@@ -107,12 +107,10 @@ TEST(GenericJoinTest, generic_join_works_for_simple_and_fast_values) {
for (size_t i = 0; i < join_layouts.size(); i += 2) {
const auto &l = join_layouts[i];
const auto &r = join_layouts[i+1];
- for (TensorSpec lhs : { l.cpy().cells_float(),
- l.cpy().cells_double() })
- {
- for (TensorSpec rhs : { r.cpy().cells_float(),
- r.cpy().cells_double() })
- {
+ for (CellType lct : CellTypeUtils::list_types()) {
+ TensorSpec lhs = l.cpy().cells(lct);
+ for (CellType rct : CellTypeUtils::list_types()) {
+ TensorSpec rhs = r.cpy().cells(rct);
for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Div::f}) {
SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str()));
auto expect = ReferenceOperations::join(lhs, rhs, fun);