diff options
author | Haavard <havardpe@yahoo-inc.com> | 2017-01-12 12:41:32 +0000 |
---|---|---|
committer | Haavard <havardpe@yahoo-inc.com> | 2017-01-12 12:41:32 +0000 |
commit | 1549d04d8df60e40266ca1d78b0be888df39618e (patch) | |
tree | 1d72e19b45353e6778eb65a138537cdf66b8f29e /vespalib | |
parent | aa047631c6fe95f11570e14555775ed3f426116f (diff) |
first version of tensor concat
Diffstat (limited to 'vespalib')
-rw-r--r-- | vespalib/src/tests/eval/value_type/value_type_test.cpp | 12 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/eval/simple_tensor.cpp | 136 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/eval/simple_tensor.h | 1 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/eval/value_type.cpp | 28 | ||||
-rw-r--r-- | vespalib/src/vespa/vespalib/eval/value_type.h | 1 |
5 files changed, 127 insertions, 51 deletions
diff --git a/vespalib/src/tests/eval/value_type/value_type_test.cpp b/vespalib/src/tests/eval/value_type/value_type_test.cpp index f0f9871f45a..1a1f1ae6cca 100644 --- a/vespalib/src/tests/eval/value_type/value_type_test.cpp +++ b/vespalib/src/tests/eval/value_type/value_type_test.cpp @@ -59,6 +59,18 @@ TEST("require that dimension names can be obtained") { std::vector<vespalib::string>({"x", "y", "z"})); } +TEST("require that dimension index can be obtained") { + EXPECT_EQUAL(ValueType::error_type().dimension_index("x"), ValueType::Dimension::npos); + EXPECT_EQUAL(ValueType::any_type().dimension_index("x"), ValueType::Dimension::npos); + EXPECT_EQUAL(ValueType::double_type().dimension_index("x"), ValueType::Dimension::npos); + EXPECT_EQUAL(ValueType::tensor_type({}).dimension_index("x"), ValueType::Dimension::npos); + auto my_type = ValueType::tensor_type({{"y", 10}, {"x"}, {"z", 0}}); + EXPECT_EQUAL(my_type.dimension_index("x"), 0); + EXPECT_EQUAL(my_type.dimension_index("y"), 1); + EXPECT_EQUAL(my_type.dimension_index("z"), 2); + EXPECT_EQUAL(my_type.dimension_index("w"), ValueType::Dimension::npos); +} + void verify_equal(const ValueType &a, const ValueType &b) { EXPECT_TRUE(a == b); EXPECT_TRUE(b == a); diff --git a/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp b/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp index d0d3ab49f42..120c7207100 100644 --- a/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp +++ b/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp @@ -56,6 +56,20 @@ Address select(const Address &a, const Address &b, const IndexList &selector) { return result; } +size_t get_dimension_size(const ValueType &type, size_t dim_idx) { + if (dim_idx == ValueType::Dimension::npos) { + return 1; + } + return type.dimensions()[dim_idx].size; +} + +size_t get_dimension_index(const Address &addr, size_t dim_idx) { + if (dim_idx == ValueType::Dimension::npos) { + return 0; + } + return addr[dim_idx].index; +} + /** * Helper class used when building SimpleTensors. While a tensor * in its final form simply contains a collection of cells, the @@ -163,61 +177,64 @@ public: }; /** - * Helper class used to analyze the combination of types for binary - * operations performed on SimpleTensors. The type of each tensor is - * used as input. The constructor will calculate the result type of - * the operation as well as which dimensions from each tensor is - * overlapping with the other tensor and also how to build the final - * address by indicating which labels to select from the concatenation - * of the input addresses. + * Helper class used to calculate which dimensions are shared between + * types and which are not. Also calculates how address elements from + * cells with the different types should be combined into a single + * address. **/ struct TypeAnalyzer { - using DimensionList = std::vector<ValueType::Dimension>; - ValueType result_type; + static constexpr size_t npos = -1; + IndexList only_a; IndexList overlap_a; IndexList overlap_b; - IndexList selector; - TypeAnalyzer(const ValueType &lhs, const ValueType &rhs) - : result_type(ValueType::any_type()), overlap_a(), overlap_b(), selector() + IndexList only_b; + IndexList combine; + size_t ignored_a; + size_t ignored_b; + TypeAnalyzer(const ValueType &lhs, const ValueType &rhs, const vespalib::string &ignore = "") + : only_a(), overlap_a(), overlap_b(), only_b(), combine(), ignored_a(npos), ignored_b(npos) { - DimensionList union_dims; const auto &a = lhs.dimensions(); const auto &b = rhs.dimensions(); size_t b_idx = 0; for (size_t a_idx = 0; a_idx < a.size(); ++a_idx) { while ((b_idx < b.size()) && (b[b_idx].name < a[a_idx].name)) { - selector.push_back(a.size() + b_idx); - union_dims.push_back(b[b_idx++]); + if (b[b_idx].name != ignore) { + only_b.push_back(b_idx); + combine.push_back(a.size() + b_idx); + } else { + ignored_b = b_idx; + } + ++b_idx; } if ((b_idx < b.size()) && (b[b_idx].name == a[a_idx].name)) { - assert(a[a_idx].is_mapped() == b[b_idx].is_mapped()); - overlap_a.push_back(a_idx); - overlap_b.push_back(b_idx); - if (b[b_idx].size < a[a_idx].size) { - selector.push_back(a.size() + b_idx); - union_dims.push_back(b[b_idx]); + if (a[a_idx].name != ignore) { + overlap_a.push_back(a_idx); + overlap_b.push_back(b_idx); + combine.push_back(a_idx); } else { - selector.push_back(a_idx); - union_dims.push_back(a[a_idx]); + ignored_a = a_idx; + ignored_b = b_idx; } ++b_idx; } else { - selector.push_back(a_idx); - union_dims.push_back(a[a_idx]); + if (a[a_idx].name != ignore) { + only_a.push_back(a_idx); + combine.push_back(a_idx); + } else { + ignored_a = a_idx; + } } } while (b_idx < b.size()) { - selector.push_back(a.size() + b_idx); - union_dims.push_back(b[b_idx++]); - } - if (union_dims.empty()) { - result_type = ValueType::double_type(); - } else { - result_type = ValueType::tensor_type(union_dims); + if (b[b_idx].name != ignore) { + only_b.push_back(b_idx); + combine.push_back(a.size() + b_idx); + } else { + ignored_b = b_idx; + } + ++b_idx; } - assert(selector.size() == result_type.dimensions().size()); - assert(overlap_a.size() == overlap_b.size()); - assert_type(result_type); } }; @@ -270,15 +287,22 @@ private: } public: - View(const SimpleTensor &tensor, const IndexList &selector_in) - : _less(selector_in), _refs() + View(const SimpleTensor &tensor, const IndexList &selector) + : _less(selector), _refs() { - _refs.reserve(tensor.cells().size()); for (const auto &cell: tensor.cells()) { _refs.emplace_back(cell); } std::sort(_refs.begin(), _refs.end(), _less); } + View(const EqualRange &range, const IndexList &selector) + : _less(selector), _refs() + { + for (const auto &cell: range) { + _refs.emplace_back(cell); + } + std::sort(_refs.begin(), _refs.end(), _less); + } const IndexList &selector() const { return _less.selector; } const CellRef *refs_begin() const { return &_refs[0]; } const CellRef *refs_end() const { return (refs_begin() + _refs.size()); } @@ -451,14 +475,15 @@ SimpleTensor::map(const UnaryOperation &op, const SimpleTensor &a) std::unique_ptr<SimpleTensor> SimpleTensor::join(const BinaryOperation &op, const SimpleTensor &a, const SimpleTensor &b) { + ValueType result_type = ValueType::join(a.type(), b.type()); + Builder builder(result_type); TypeAnalyzer type_info(a.type(), b.type()); - Builder builder(type_info.result_type); View view_a(a, type_info.overlap_a); View view_b(b, type_info.overlap_b); for (ViewMatcher matcher(view_a, view_b); matcher.valid(); matcher.next()) { for (const auto &ref_a: matcher.get_a()) { for (const auto &ref_b: matcher.get_b()) { - builder.set(select(ref_a.get().address, ref_b.get().address, type_info.selector), + builder.set(select(ref_a.get().address, ref_b.get().address, type_info.combine), op.eval(ref_a.get().value, ref_b.get().value)); } } @@ -466,5 +491,36 @@ SimpleTensor::join(const BinaryOperation &op, const SimpleTensor &a, const Simpl return builder.build(); } +std::unique_ptr<SimpleTensor> +SimpleTensor::concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension) +{ + ValueType result_type = ValueType::concat(a.type(), b.type(), dimension); + Builder builder(result_type); + TypeAnalyzer type_info(a.type(), b.type(), dimension); + View view_a(a, type_info.overlap_a); + View view_b(b, type_info.overlap_b); + size_t cat_dim_idx = result_type.dimension_index(dimension); + size_t cat_offset = get_dimension_size(a.type(), type_info.ignored_a); + for (ViewMatcher matcher(view_a, view_b); matcher.valid(); matcher.next()) { + View subview_a(matcher.get_a(), type_info.only_a); + View subview_b(matcher.get_b(), type_info.only_b); + for (auto range_a = subview_a.first_range(); !range_a.empty(); range_a = subview_a.next_range(range_a)) { + for (auto range_b = subview_b.first_range(); !range_b.empty(); range_b = subview_b.next_range(range_b)) { + Address addr = select(range_a.begin()->get().address, range_b.begin()->get().address, type_info.combine); + addr.insert(addr.begin() + cat_dim_idx, Label(size_t(0))); + for (const auto &ref: range_a) { + addr[cat_dim_idx].index = get_dimension_index(ref.get().address, type_info.ignored_a); + builder.set(addr, ref.get().value); + } + for (const auto &ref: range_b) { + addr[cat_dim_idx].index = cat_offset + get_dimension_index(ref.get().address, type_info.ignored_b); + builder.set(addr, ref.get().value); + } + } + } + } + return builder.build(); +} + } // namespace vespalib::eval } // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/eval/simple_tensor.h b/vespalib/src/vespa/vespalib/eval/simple_tensor.h index 51af84a53ff..1485a79fd7a 100644 --- a/vespalib/src/vespa/vespalib/eval/simple_tensor.h +++ b/vespalib/src/vespa/vespalib/eval/simple_tensor.h @@ -78,6 +78,7 @@ public: static bool equal(const SimpleTensor &a, const SimpleTensor &b); static std::unique_ptr<SimpleTensor> map(const UnaryOperation &op, const SimpleTensor &a); static std::unique_ptr<SimpleTensor> join(const BinaryOperation &op, const SimpleTensor &a, const SimpleTensor &b); + static std::unique_ptr<SimpleTensor> concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension); }; } // namespace vespalib::eval diff --git a/vespalib/src/vespa/vespalib/eval/value_type.cpp b/vespalib/src/vespa/vespalib/eval/value_type.cpp index ec6bb969289..a038ee46583 100644 --- a/vespalib/src/vespa/vespalib/eval/value_type.cpp +++ b/vespalib/src/vespa/vespalib/eval/value_type.cpp @@ -12,22 +12,23 @@ namespace { using Dimension = ValueType::Dimension; using DimensionList = std::vector<Dimension>; -const Dimension *find_dimension(const std::vector<Dimension> &list, const vespalib::string &name) { - for (const auto &item: list) { - if (item.name == name) { - return &item; +size_t my_dimension_index(const std::vector<Dimension> &list, const vespalib::string &name) { + for (size_t idx = 0; idx < list.size(); ++idx) { + if (list[idx].name == name) { + return idx; } } - return nullptr; + return ValueType::Dimension::npos; } Dimension *find_dimension(std::vector<Dimension> &list, const vespalib::string &name) { - for (auto &item: list) { - if (item.name == name) { - return &item; - } - } - return nullptr; + size_t idx = my_dimension_index(list, name); + return (idx != ValueType::Dimension::npos) ? &list[idx] : nullptr; +} + +const Dimension *find_dimension(const std::vector<Dimension> &list, const vespalib::string &name) { + size_t idx = my_dimension_index(list, name); + return (idx != ValueType::Dimension::npos) ? &list[idx] : nullptr; } void sort_dimensions(DimensionList &dimensions) { @@ -132,6 +133,11 @@ ValueType::is_dense() const return true; } +size_t +ValueType::dimension_index(const vespalib::string &name) const { + return my_dimension_index(_dimensions, name); +} + std::vector<vespalib::string> ValueType::dimension_names() const { diff --git a/vespalib/src/vespa/vespalib/eval/value_type.h b/vespalib/src/vespa/vespalib/eval/value_type.h index 820e8bcf5dc..f6d02336daa 100644 --- a/vespalib/src/vespa/vespalib/eval/value_type.h +++ b/vespalib/src/vespa/vespalib/eval/value_type.h @@ -54,6 +54,7 @@ public: bool is_sparse() const; bool is_dense() const; const std::vector<Dimension> &dimensions() const { return _dimensions; } + size_t dimension_index(const vespalib::string &name) const; std::vector<vespalib::string> dimension_names() const; bool maybe_tensor() const { return (is_any() || is_tensor()); } bool unknown_dimensions() const { return (maybe_tensor() && _dimensions.empty()); } |