diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-10-06 07:10:09 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-10-06 07:10:09 +0000 |
commit | 392be2d2538cafb73ca1269003973db91306e3e2 (patch) | |
tree | c9dc600bb12f09af924760135c9d6a408a48a804 | |
parent | 89069de6916e19bf2757dc5621ec4b57ce0f5e7f (diff) |
use dimension_index instead of explicit loop
-rw-r--r-- | eval/src/tests/instruction/generic_concat/generic_concat_test.cpp | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp index a051a345290..a2510fdd2fa 100644 --- a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp +++ b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp @@ -86,21 +86,21 @@ bool concat_addresses(const TensorSpec::Address &a, const TensorSpec::Address &b TensorSpec reference_concat(const TensorSpec &a, const TensorSpec &b, const std::string &concat_dim) { ValueType a_type = ValueType::from_spec(a.type()); ValueType b_type = ValueType::from_spec(b.type()); - size_t cc_dim_a_size = 1; - for (const auto & dim : a_type.dimensions()) { - if (dim.name == concat_dim) { - EXPECT_TRUE(dim.is_indexed()); - cc_dim_a_size = dim.size; - } - } ValueType res_type = ValueType::concat(a_type, b_type, concat_dim); EXPECT_FALSE(res_type.is_error()); + size_t b_offset = 1; + size_t concat_dim_index = a_type.dimension_index(concat_dim); + if (concat_dim_index != ValueType::Dimension::npos) { + const auto &dim = a_type.dimensions()[concat_dim_index]; + EXPECT_TRUE(dim.is_indexed()); + b_offset = dim.size; + } TensorSpec result(res_type.to_spec()); for (const auto &cell_a: a.cells()) { for (const auto &cell_b: b.cells()) { TensorSpec::Address addr_a; TensorSpec::Address addr_b; - if (concat_addresses(cell_a.first, cell_b.first, concat_dim, cc_dim_a_size, addr_a, addr_b)) { + if (concat_addresses(cell_a.first, cell_b.first, concat_dim, b_offset, addr_a, addr_b)) { result.set(addr_a, cell_a.second); result.set(addr_b, cell_b.second); } |