summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-10-06 07:10:09 +0000
committerArne Juul <arnej@verizonmedia.com>2020-10-06 07:10:09 +0000
commit392be2d2538cafb73ca1269003973db91306e3e2 (patch)
treec9dc600bb12f09af924760135c9d6a408a48a804
parent89069de6916e19bf2757dc5621ec4b57ce0f5e7f (diff)
use dimension_index instead of explicit loop
-rw-r--r--eval/src/tests/instruction/generic_concat/generic_concat_test.cpp16
1 files changed, 8 insertions, 8 deletions
diff --git a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp
index a051a345290..a2510fdd2fa 100644
--- a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp
+++ b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp
@@ -86,21 +86,21 @@ bool concat_addresses(const TensorSpec::Address &a, const TensorSpec::Address &b
TensorSpec reference_concat(const TensorSpec &a, const TensorSpec &b, const std::string &concat_dim) {
ValueType a_type = ValueType::from_spec(a.type());
ValueType b_type = ValueType::from_spec(b.type());
- size_t cc_dim_a_size = 1;
- for (const auto & dim : a_type.dimensions()) {
- if (dim.name == concat_dim) {
- EXPECT_TRUE(dim.is_indexed());
- cc_dim_a_size = dim.size;
- }
- }
ValueType res_type = ValueType::concat(a_type, b_type, concat_dim);
EXPECT_FALSE(res_type.is_error());
+ size_t b_offset = 1;
+ size_t concat_dim_index = a_type.dimension_index(concat_dim);
+ if (concat_dim_index != ValueType::Dimension::npos) {
+ const auto &dim = a_type.dimensions()[concat_dim_index];
+ EXPECT_TRUE(dim.is_indexed());
+ b_offset = dim.size;
+ }
TensorSpec result(res_type.to_spec());
for (const auto &cell_a: a.cells()) {
for (const auto &cell_b: b.cells()) {
TensorSpec::Address addr_a;
TensorSpec::Address addr_b;
- if (concat_addresses(cell_a.first, cell_b.first, concat_dim, cc_dim_a_size, addr_a, addr_b)) {
+ if (concat_addresses(cell_a.first, cell_b.first, concat_dim, b_offset, addr_a, addr_b)) {
result.set(addr_a, cell_a.second);
result.set(addr_b, cell_b.second);
}