summaryrefslogtreecommitdiffstats
path: root/vespalib
diff options
context:
space:
mode:
authorHaavard <havardpe@yahoo-inc.com>2017-01-12 12:41:32 +0000
committerHaavard <havardpe@yahoo-inc.com>2017-01-12 12:41:32 +0000
commit1549d04d8df60e40266ca1d78b0be888df39618e (patch)
tree1d72e19b45353e6778eb65a138537cdf66b8f29e /vespalib
parentaa047631c6fe95f11570e14555775ed3f426116f (diff)
first version of tensor concat
Diffstat (limited to 'vespalib')
-rw-r--r--vespalib/src/tests/eval/value_type/value_type_test.cpp12
-rw-r--r--vespalib/src/vespa/vespalib/eval/simple_tensor.cpp136
-rw-r--r--vespalib/src/vespa/vespalib/eval/simple_tensor.h1
-rw-r--r--vespalib/src/vespa/vespalib/eval/value_type.cpp28
-rw-r--r--vespalib/src/vespa/vespalib/eval/value_type.h1
5 files changed, 127 insertions, 51 deletions
diff --git a/vespalib/src/tests/eval/value_type/value_type_test.cpp b/vespalib/src/tests/eval/value_type/value_type_test.cpp
index f0f9871f45a..1a1f1ae6cca 100644
--- a/vespalib/src/tests/eval/value_type/value_type_test.cpp
+++ b/vespalib/src/tests/eval/value_type/value_type_test.cpp
@@ -59,6 +59,18 @@ TEST("require that dimension names can be obtained") {
std::vector<vespalib::string>({"x", "y", "z"}));
}
+TEST("require that dimension index can be obtained") { // checks ValueType::dimension_index for found and missing names
+ EXPECT_EQUAL(ValueType::error_type().dimension_index("x"), ValueType::Dimension::npos); // types created without dimensions report npos
+ EXPECT_EQUAL(ValueType::any_type().dimension_index("x"), ValueType::Dimension::npos);
+ EXPECT_EQUAL(ValueType::double_type().dimension_index("x"), ValueType::Dimension::npos);
+ EXPECT_EQUAL(ValueType::tensor_type({}).dimension_index("x"), ValueType::Dimension::npos);
+ auto my_type = ValueType::tensor_type({{"y", 10}, {"x"}, {"z", 0}}); // declared as y, x, z
+ EXPECT_EQUAL(my_type.dimension_index("x"), 0u); // expected indexes follow name order (x, y, z), not declaration order
+ EXPECT_EQUAL(my_type.dimension_index("y"), 1u); // unsigned literals: dimension_index returns size_t, avoid signed/unsigned comparison
+ EXPECT_EQUAL(my_type.dimension_index("z"), 2u);
+ EXPECT_EQUAL(my_type.dimension_index("w"), ValueType::Dimension::npos); // unknown name in a type that has dimensions
+}
+
void verify_equal(const ValueType &a, const ValueType &b) {
EXPECT_TRUE(a == b);
EXPECT_TRUE(b == a);
diff --git a/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp b/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp
index d0d3ab49f42..120c7207100 100644
--- a/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp
+++ b/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp
@@ -56,6 +56,20 @@ Address select(const Address &a, const Address &b, const IndexList &selector) {
return result;
}
+size_t get_dimension_size(const ValueType &type, size_t dim_idx) {
+ if (dim_idx != ValueType::Dimension::npos) {
+ return type.dimensions()[dim_idx].size; // size recorded for that dimension in the type
+ }
+ return 1; // npos means "dimension not present"; it is reported as having size 1
+}
+
+size_t get_dimension_index(const Address &addr, size_t dim_idx) {
+ if (dim_idx != ValueType::Dimension::npos) {
+ return addr[dim_idx].index; // index stored in the address label at that position
+ }
+ return 0; // npos means "dimension not present"; index 0 is reported instead
+}
+
/**
* Helper class used when building SimpleTensors. While a tensor
* in its final form simply contains a collection of cells, the
@@ -163,61 +177,64 @@ public:
};
/**
- * Helper class used to analyze the combination of types for binary
- * operations performed on SimpleTensors. The type of each tensor is
- * used as input. The constructor will calculate the result type of
- * the operation as well as which dimensions from each tensor is
- * overlapping with the other tensor and also how to build the final
- * address by indicating which labels to select from the concatenation
- * of the input addresses.
+ * Helper class used to calculate which dimensions are shared between
+ * types and which are not. Also calculates how address elements from
+ * cells with the different types should be combined into a single
+ * address.
**/
struct TypeAnalyzer {
- using DimensionList = std::vector<ValueType::Dimension>;
- ValueType result_type;
+ static constexpr size_t npos = -1; // sentinel (largest size_t value) used when the ignored dimension is not present
+ IndexList only_a; // indexes of lhs dimensions with no same-named rhs dimension (ignored dimension never listed)
IndexList overlap_a;
IndexList overlap_b;
- IndexList selector;
- TypeAnalyzer(const ValueType &lhs, const ValueType &rhs)
- : result_type(ValueType::any_type()), overlap_a(), overlap_b(), selector()
+ IndexList only_b; // indexes of rhs dimensions with no same-named lhs dimension (ignored dimension never listed)
+ IndexList combine; // one entry per non-ignored dimension: a_idx for an lhs label, a.size() + b_idx for an rhs label
+ size_t ignored_a; // index in lhs of the dimension named 'ignore', npos if lhs does not have it
+ size_t ignored_b; // index in rhs of the dimension named 'ignore', npos if rhs does not have it
+ TypeAnalyzer(const ValueType &lhs, const ValueType &rhs, const vespalib::string &ignore = "") // 'ignore' names a dimension kept out of all index lists
+ : only_a(), overlap_a(), overlap_b(), only_b(), combine(), ignored_a(npos), ignored_b(npos)
{
- DimensionList union_dims;
const auto &a = lhs.dimensions();
const auto &b = rhs.dimensions();
size_t b_idx = 0;
for (size_t a_idx = 0; a_idx < a.size(); ++a_idx) {
while ((b_idx < b.size()) && (b[b_idx].name < a[a_idx].name)) {
- selector.push_back(a.size() + b_idx);
- union_dims.push_back(b[b_idx++]);
+ if (b[b_idx].name != ignore) { // rhs dimension ordered before the current lhs one; NOTE(review): the merge presumably relies on name-sorted dimension lists - confirm
+ only_b.push_back(b_idx);
+ combine.push_back(a.size() + b_idx); // offset by a.size() so the entry refers to the rhs part
+ } else {
+ ignored_b = b_idx;
+ }
+ ++b_idx;
}
if ((b_idx < b.size()) && (b[b_idx].name == a[a_idx].name)) {
- assert(a[a_idx].is_mapped() == b[b_idx].is_mapped());
- overlap_a.push_back(a_idx);
- overlap_b.push_back(b_idx);
- if (b[b_idx].size < a[a_idx].size) {
- selector.push_back(a.size() + b_idx);
- union_dims.push_back(b[b_idx]);
+ if (a[a_idx].name != ignore) { // same name on both sides: a shared dimension
+ overlap_a.push_back(a_idx);
+ overlap_b.push_back(b_idx);
+ combine.push_back(a_idx); // shared dimensions always use the lhs index
} else {
- selector.push_back(a_idx);
- union_dims.push_back(a[a_idx]);
+ ignored_a = a_idx; // both types have the ignored dimension
+ ignored_b = b_idx;
}
++b_idx;
} else {
- selector.push_back(a_idx);
- union_dims.push_back(a[a_idx]);
+ if (a[a_idx].name != ignore) { // no rhs dimension with this name
+ only_a.push_back(a_idx);
+ combine.push_back(a_idx);
+ } else {
+ ignored_a = a_idx;
+ }
}
}
while (b_idx < b.size()) {
- selector.push_back(a.size() + b_idx);
- union_dims.push_back(b[b_idx++]);
- }
- if (union_dims.empty()) {
- result_type = ValueType::double_type();
- } else {
- result_type = ValueType::tensor_type(union_dims);
+ if (b[b_idx].name != ignore) { // rhs dimensions left over after all lhs dimensions were visited
+ only_b.push_back(b_idx);
+ combine.push_back(a.size() + b_idx);
+ } else {
+ ignored_b = b_idx;
+ }
+ ++b_idx;
}
- assert(selector.size() == result_type.dimensions().size());
- assert(overlap_a.size() == overlap_b.size());
- assert_type(result_type);
}
};
@@ -270,15 +287,22 @@ private:
}
public:
- View(const SimpleTensor &tensor, const IndexList &selector_in)
- : _less(selector_in), _refs()
+ View(const SimpleTensor &tensor, const IndexList &selector)
+ : _less(selector), _refs()
{
- _refs.reserve(tensor.cells().size());
for (const auto &cell: tensor.cells()) {
_refs.emplace_back(cell);
}
std::sort(_refs.begin(), _refs.end(), _less);
}
+ View(const EqualRange &range, const IndexList &selector) // builds a view over the entries of an existing range
+ : _less(selector), _refs()
+ {
+ for (const auto &cell: range) {
+ _refs.emplace_back(cell); // collect one reference per entry of the range
+ }
+ std::sort(_refs.begin(), _refs.end(), _less); // order the references using the comparator built from 'selector'
+ }
const IndexList &selector() const { return _less.selector; }
const CellRef *refs_begin() const { return &_refs[0]; }
const CellRef *refs_end() const { return (refs_begin() + _refs.size()); }
@@ -451,14 +475,15 @@ SimpleTensor::map(const UnaryOperation &op, const SimpleTensor &a)
std::unique_ptr<SimpleTensor>
SimpleTensor::join(const BinaryOperation &op, const SimpleTensor &a, const SimpleTensor &b)
{
+ ValueType result_type = ValueType::join(a.type(), b.type());
+ Builder builder(result_type);
TypeAnalyzer type_info(a.type(), b.type());
- Builder builder(type_info.result_type);
View view_a(a, type_info.overlap_a);
View view_b(b, type_info.overlap_b);
for (ViewMatcher matcher(view_a, view_b); matcher.valid(); matcher.next()) {
for (const auto &ref_a: matcher.get_a()) {
for (const auto &ref_b: matcher.get_b()) {
- builder.set(select(ref_a.get().address, ref_b.get().address, type_info.selector),
+ builder.set(select(ref_a.get().address, ref_b.get().address, type_info.combine),
op.eval(ref_a.get().value, ref_b.get().value));
}
}
@@ -466,5 +491,36 @@ SimpleTensor::join(const BinaryOperation &op, const SimpleTensor &a, const Simpl
return builder.build();
}
+std::unique_ptr<SimpleTensor> // result holds the cells of a followed by the cells of b along 'dimension'
+SimpleTensor::concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension)
+{
+ ValueType result_type = ValueType::concat(a.type(), b.type(), dimension); // result type is decided by ValueType::concat (not shown here)
+ Builder builder(result_type);
+ TypeAnalyzer type_info(a.type(), b.type(), dimension); // 'dimension' is kept out of the overlap/only/combine index lists
+ View view_a(a, type_info.overlap_a); // cells of a sorted on the dimensions shared with b
+ View view_b(b, type_info.overlap_b); // cells of b sorted on the dimensions shared with a
+ size_t cat_dim_idx = result_type.dimension_index(dimension); // NOTE(review): would be npos if the result type lacks 'dimension' - presumably ruled out by ValueType::concat/Builder, confirm
+ size_t cat_offset = get_dimension_size(a.type(), type_info.ignored_a); // size of 'dimension' in a, 1 if a does not have it
+ for (ViewMatcher matcher(view_a, view_b); matcher.valid(); matcher.next()) {
+ View subview_a(matcher.get_a(), type_info.only_a); // matched cells of a, re-sorted on the dimensions only a has
+ View subview_b(matcher.get_b(), type_info.only_b); // matched cells of b, re-sorted on the dimensions only b has
+ for (auto range_a = subview_a.first_range(); !range_a.empty(); range_a = subview_a.next_range(range_a)) {
+ for (auto range_b = subview_b.first_range(); !range_b.empty(); range_b = subview_b.next_range(range_b)) {
+ Address addr = select(range_a.begin()->get().address, range_b.begin()->get().address, type_info.combine); // address over all dimensions except 'dimension'
+ addr.insert(addr.begin() + cat_dim_idx, Label(size_t(0))); // placeholder label at the position of 'dimension', overwritten below
+ for (const auto &ref: range_a) {
+ addr[cat_dim_idx].index = get_dimension_index(ref.get().address, type_info.ignored_a); // cells of a keep their own index (0 if a lacks 'dimension')
+ builder.set(addr, ref.get().value);
+ }
+ for (const auto &ref: range_b) {
+ addr[cat_dim_idx].index = cat_offset + get_dimension_index(ref.get().address, type_info.ignored_b); // cells of b are shifted past the indexes used by a
+ builder.set(addr, ref.get().value);
+ }
+ }
+ }
+ }
+ return builder.build();
+}
+
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/vespalib/src/vespa/vespalib/eval/simple_tensor.h b/vespalib/src/vespa/vespalib/eval/simple_tensor.h
index 51af84a53ff..1485a79fd7a 100644
--- a/vespalib/src/vespa/vespalib/eval/simple_tensor.h
+++ b/vespalib/src/vespa/vespalib/eval/simple_tensor.h
@@ -78,6 +78,7 @@ public:
static bool equal(const SimpleTensor &a, const SimpleTensor &b);
static std::unique_ptr<SimpleTensor> map(const UnaryOperation &op, const SimpleTensor &a);
static std::unique_ptr<SimpleTensor> join(const BinaryOperation &op, const SimpleTensor &a, const SimpleTensor &b);
+ static std::unique_ptr<SimpleTensor> concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension);
};
} // namespace vespalib::eval
diff --git a/vespalib/src/vespa/vespalib/eval/value_type.cpp b/vespalib/src/vespa/vespalib/eval/value_type.cpp
index ec6bb969289..a038ee46583 100644
--- a/vespalib/src/vespa/vespalib/eval/value_type.cpp
+++ b/vespalib/src/vespa/vespalib/eval/value_type.cpp
@@ -12,22 +12,23 @@ namespace {
using Dimension = ValueType::Dimension;
using DimensionList = std::vector<Dimension>;
-const Dimension *find_dimension(const std::vector<Dimension> &list, const vespalib::string &name) {
- for (const auto &item: list) {
- if (item.name == name) {
- return &item;
+size_t my_dimension_index(const std::vector<Dimension> &list, const vespalib::string &name) { // linear search for the dimension called 'name'
+ for (size_t idx = 0; idx < list.size(); ++idx) {
+ if (list[idx].name == name) {
+ return idx; // position of the first dimension with a matching name
}
}
- return nullptr;
+ return ValueType::Dimension::npos; // no dimension with that name
}
Dimension *find_dimension(std::vector<Dimension> &list, const vespalib::string &name) {
- for (auto &item: list) {
- if (item.name == name) {
- return &item;
- }
- }
- return nullptr;
+ size_t idx = my_dimension_index(list, name); // mutable overload: share the lookup with my_dimension_index
+ return (idx != ValueType::Dimension::npos) ? &list[idx] : nullptr; // nullptr when the name is not found
+}
+
+const Dimension *find_dimension(const std::vector<Dimension> &list, const vespalib::string &name) { // const overload of the same lookup
+ size_t idx = my_dimension_index(list, name);
+ return (idx != ValueType::Dimension::npos) ? &list[idx] : nullptr; // nullptr when the name is not found
}
void sort_dimensions(DimensionList &dimensions) {
@@ -132,6 +133,11 @@ ValueType::is_dense() const
return true;
}
+size_t
+ValueType::dimension_index(const vespalib::string &name) const { // position of the dimension called 'name' within this type
+ return my_dimension_index(_dimensions, name); // yields ValueType::Dimension::npos when no dimension has that name
+}
+
std::vector<vespalib::string>
ValueType::dimension_names() const
{
diff --git a/vespalib/src/vespa/vespalib/eval/value_type.h b/vespalib/src/vespa/vespalib/eval/value_type.h
index 820e8bcf5dc..f6d02336daa 100644
--- a/vespalib/src/vespa/vespalib/eval/value_type.h
+++ b/vespalib/src/vespa/vespalib/eval/value_type.h
@@ -54,6 +54,7 @@ public:
bool is_sparse() const;
bool is_dense() const;
const std::vector<Dimension> &dimensions() const { return _dimensions; }
+ size_t dimension_index(const vespalib::string &name) const;
std::vector<vespalib::string> dimension_names() const;
bool maybe_tensor() const { return (is_any() || is_tensor()); }
bool unknown_dimensions() const { return (maybe_tensor() && _dimensions.empty()); }