diff options
Diffstat (limited to 'eval/src/tests/streamed/value/streamed_value_test.cpp')
-rw-r--r-- | eval/src/tests/streamed/value/streamed_value_test.cpp | 20 |
1 files changed, 8 insertions, 12 deletions
diff --git a/eval/src/tests/streamed/value/streamed_value_test.cpp b/eval/src/tests/streamed/value/streamed_value_test.cpp index 44c30d226bd..e5f4088d908 100644 --- a/eval/src/tests/streamed/value/streamed_value_test.cpp +++ b/eval/src/tests/streamed/value/streamed_value_test.cpp @@ -67,9 +67,8 @@ TensorSpec streamed_value_join(const TensorSpec &a, const TensorSpec &b, join_fu TEST(StreamedValueTest, streamed_values_can_be_converted_from_and_to_tensor_spec) { for (const auto &layout: layouts) { - for (TensorSpec expect : { layout.cpy().cells_float(), - layout.cpy().cells_double() }) - { + for (CellType ct : CellTypeUtils::list_types()) { + TensorSpec expect = layout.cpy().cells(ct); std::unique_ptr<Value> value = value_from_spec(expect, StreamedValueBuilderFactory::get()); TensorSpec actual = spec_from_value(*value); EXPECT_EQ(actual, expect); @@ -79,9 +78,8 @@ TEST(StreamedValueTest, streamed_values_can_be_converted_from_and_to_tensor_spec TEST(StreamedValueTest, streamed_values_can_be_copied) { for (const auto &layout: layouts) { - for (TensorSpec expect : { layout.cpy().cells_float(), - layout.cpy().cells_double() }) - { + for (CellType ct : CellTypeUtils::list_types()) { + TensorSpec expect = layout.cpy().cells(ct); std::unique_ptr<Value> value = value_from_spec(expect, StreamedValueBuilderFactory::get()); std::unique_ptr<Value> copy = StreamedValueBuilderFactory::get().copy(*value); TensorSpec actual = spec_from_value(*copy); @@ -131,12 +129,10 @@ TEST(StreamedValueTest, new_generic_join_works_for_streamed_values) { for (size_t i = 0; i < join_layouts.size(); i += 2) { const auto l = join_layouts[i].cpy().seq(N_16ths); const auto r = join_layouts[i + 1].cpy().seq(N_16ths); - for (TensorSpec lhs : { l.cpy().cells_float(), - l.cpy().cells_double() }) - { - for (TensorSpec rhs : { r.cpy().cells_float(), - r.cpy().cells_double() }) - { + for (CellType lct : CellTypeUtils::list_types()) { + TensorSpec lhs = l.cpy().cells(lct); + for (CellType rct : CellTypeUtils::list_types()) { + TensorSpec rhs = r.cpy().cells(rct); for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Max::f}) { SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); auto expect = ReferenceOperations::join(lhs, rhs, fun); |