summary refs log tree commit diff stats
path: root/eval
diff options
context:
space:
mode:
author Henning Baldersheim <balder@yahoo-inc.com> 2023-09-06 17:35:51 +0200
committer GitHub <noreply@github.com> 2023-09-06 17:35:51 +0200
commit 22ccfc422bf32f3f7c9d419340ae463bd869fe5e (patch)
tree 7eead3dfe2d009973571c8ea4934b3a1908cc298 /eval
parent d891ecd6b7e5e809c08e708fc45addcec7ad68ef (diff)
parent 68fdcdbcc671c5c2d68df2cecaac3ea50957ef0c (diff)
Merge pull request #28413 from vespa-engine/havardpe/avoid-making-new-value-index
detect not having to make a new value index
Diffstat (limited to 'eval')
-rw-r--r-- eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp 84
-rw-r--r-- eval/src/vespa/eval/instruction/sparse_join_reduce_plan.cpp 157
-rw-r--r-- eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h 57
-rw-r--r-- eval/src/vespa/eval/instruction/universal_dot_product.cpp 55
4 files changed, 191 insertions, 162 deletions
diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
index e3393dc2de7..6c0726dab37 100644
--- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
+++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
@@ -134,15 +134,15 @@ Optimize universal_only() {
return Optimize::specific("universal_only", my_optimizer);
}
-using cost_map_t = std::map<vespalib::string,double>;
-std::vector<std::pair<vespalib::string,cost_map_t>> benchmark_results;
+using cost_list_t = std::vector<std::pair<vespalib::string,double>>;
+std::vector<std::pair<vespalib::string,cost_list_t>> benchmark_results;
-void benchmark(const vespalib::string &desc, const vespalib::string &expr, std::vector<Optimize> list) {
+void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
auto fun = Function::parse(expr);
ASSERT_FALSE(fun->has_error());
auto expected = eval_ref(*fun);
- cost_map_t cost_map;
- fprintf(stderr, "BENCH: %s (%s)\n", desc.c_str(), expr.c_str());
+ cost_list_t cost_list;
+ fprintf(stderr, "BENCH: %s\n", expr.c_str());
for (Optimize &optimize: list) {
std::vector<Value::UP> values;
for (size_t i = 0; i < fun->num_params(); ++i) {
@@ -179,29 +179,27 @@ void benchmark(const vespalib::string &desc, const vespalib::string &expr, std::
ASSERT_NE(optimized, nullptr);
CTFMetaData ctf_meta;
InterpretedFunction ifun(prod_factory, *optimized, &ctf_meta);
+ InterpretedFunction::ProfiledContext pctx(ifun);
ASSERT_EQ(ctf_meta.steps.size(), ifun.program_size());
- BenchmarkTimer timer(budget);
+ EXPECT_EQ(spec_from_value(ifun.eval(pctx.context, params)), expected);
+ EXPECT_EQ(spec_from_value(ifun.eval(pctx, params)), expected);
std::vector<duration> prev_time(ctf_meta.steps.size(), duration::zero());
std::vector<duration> min_time(ctf_meta.steps.size(), duration::max());
- InterpretedFunction::ProfiledContext pctx(ifun);
- for (bool first = true; timer.has_budget(); first = false) {
- const Value &profiled_result = ifun.eval(pctx, params);
- if (first) {
- EXPECT_EQ(spec_from_value(profiled_result), expected);
- }
+ BenchmarkTimer timer(budget);
+ while (timer.has_budget()) {
timer.before();
const Value &result = ifun.eval(pctx.context, params);
+ (void) result;
timer.after();
- if (first) {
- EXPECT_EQ(spec_from_value(result), expected);
- }
+ const Value &profiled_result = ifun.eval(pctx, params);
+ (void) profiled_result;
for (size_t i = 0; i < ctf_meta.steps.size(); ++i) {
min_time[i] = std::min(min_time[i], pctx.cost[i].second - prev_time[i]);
prev_time[i] = pctx.cost[i].second;
}
}
double cost_us = timer.min_time() * 1000.0 * 1000.0;
- cost_map.emplace(optimize.name, cost_us);
+ cost_list.emplace_back(optimize.name, cost_us);
fprintf(stderr, " optimized with: %s: %g us {\n", optimize.name.c_str(), cost_us);
for (size_t i = 0; i < ctf_meta.steps.size(); ++i) {
auto name = strip_ns(ctf_meta.steps[i].class_name);
@@ -213,7 +211,7 @@ void benchmark(const vespalib::string &desc, const vespalib::string &expr, std::
fprintf(stderr, " }\n");
}
fprintf(stderr, "\n");
- benchmark_results.emplace_back(desc, std::move(cost_map));
+ benchmark_results.emplace_back(expr, std::move(cost_list));
}
TensorSpec perform_dot_product(const TensorSpec &a, const TensorSpec &b, const std::vector<vespalib::string> &dims)
@@ -266,35 +264,43 @@ TEST(UniversalDotProductTest, bench_vector_dot_product) {
}
auto optimize_list = std::vector<Optimize>({baseline(), with_universal(), universal_only()});
- benchmark("number number", "reduce(1.0*2.0,sum)", optimize_list);
- benchmark("number vector", "reduce(5.0*x128,sum,x)", optimize_list);
- benchmark("vector vector small", "reduce(x16*x16,sum,x)", optimize_list);
- benchmark("vector vector large", "reduce(x768*x768,sum,x)", optimize_list);
- benchmark("vector matrix full", "reduce(y64*x8y64,sum,x,y)", optimize_list);
- benchmark("vector matrix inner", "reduce(y64*x8y64,sum,y)", optimize_list);
- benchmark("vector matrix outer", "reduce(y64*x8y64,sum,x)", optimize_list);
- benchmark("matrix matrix same", "reduce(a8y64*a8y64,sum,y)", optimize_list);
- benchmark("matrix matrix different", "reduce(a8y64*b8y64,sum,y)", optimize_list);
- benchmark("matmul", "reduce(a8b64*b64c8,sum,b)", optimize_list);
- benchmark("sparse overlap", "reduce(x64_1*x64_1,sum,x)", optimize_list);
- benchmark("sparse no overlap", "reduce(a64_1*b64_1,sum,b)", optimize_list);
- benchmark("mixed dense", "reduce(a1_16x768*x768,sum,x)", optimize_list);
- benchmark("mixed mixed complex", "reduce(a1_1x128*a2_1b64_1x128,sum,a,x)", optimize_list);
+ benchmark("reduce(1.0*2.0,sum)", optimize_list);
+ benchmark("reduce(5.0*x128,sum,x)", optimize_list);
+ benchmark("reduce(x16*x16,sum,x)", optimize_list);
+ benchmark("reduce(x768*x768,sum,x)", optimize_list);
+ benchmark("reduce(y64*x8y64,sum,x,y)", optimize_list);
+ benchmark("reduce(y64*x8y64,sum,y)", optimize_list);
+ benchmark("reduce(y64*x8y64,sum,x)", optimize_list);
+ benchmark("reduce(a8y64*a8y64,sum,y)", optimize_list);
+ benchmark("reduce(a8y64*a8y64,sum,a)", optimize_list);
+ benchmark("reduce(a8y64*b8y64,sum,y)", optimize_list);
+ benchmark("reduce(a8b64*b64c8,sum,b)", optimize_list);
+ benchmark("reduce(x64_1*x64_1,sum,x)", optimize_list);
+ benchmark("reduce(a64_1*b64_1,sum,b)", optimize_list);
+ benchmark("reduce(a8_1b8_1*b8_1c8_1,sum,b)", optimize_list);
+ benchmark("reduce(a8_1b8_1*b8_1c8_1,sum,a,c)", optimize_list);
+ benchmark("reduce(a8_1b8_1*b8_1c8_1,sum,a,b,c)", optimize_list);
+ benchmark("reduce(b64_1x128*x128,sum,x)", optimize_list);
+ benchmark("reduce(b64_1x8y128*x8y128,sum,y)", optimize_list);
+ benchmark("reduce(b64_1x128*x128,sum,b,x)", optimize_list);
+ benchmark("reduce(a1_1x128*a2_1b64_1x128,sum,a,x)", optimize_list);
+ benchmark("reduce(x0_0*y8_1,sum,y)", optimize_list);
+ benchmark("reduce(x8_1*y0_0,sum,y)", optimize_list);
- size_t max_desc_size = 0;
- for (const auto &[desc, cost_map]: benchmark_results) {
- max_desc_size = std::max(max_desc_size, desc.size());
+ size_t max_expr_size = 0;
+ for (const auto &[expr, cost_list]: benchmark_results) {
+ max_expr_size = std::max(max_expr_size, expr.size());
}
- for (const auto &[desc, cost_map]: benchmark_results) {
- for (size_t i = 0; i < max_desc_size - desc.size(); ++i) {
+ for (const auto &[expr, cost_list]: benchmark_results) {
+ for (size_t i = 0; i < max_expr_size - expr.size(); ++i) {
fprintf(stderr, " ");
}
- fprintf(stderr, "%s: ", desc.c_str());
+ fprintf(stderr, "%s: ", expr.c_str());
size_t cnt = 0;
double baseline_cost = 0.0;
double with_universal_cost = 0.0;
double universal_only_cost = 0.0;
- for (const auto &[name, cost]: cost_map) {
+ for (const auto &[name, cost]: cost_list) {
if (++cnt > 1) {
fprintf(stderr, ", ");
}
@@ -336,7 +342,7 @@ int main(int argc, char **argv) {
--argc;
}
if ((argc > 1) && (slow_option == argv[1])) {
- budget = 5.0;
+ budget = 10.0;
++argv;
--argc;
}
diff --git a/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.cpp b/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.cpp
index 00499e7f997..fbef6ee5b7f 100644
--- a/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.cpp
+++ b/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.cpp
@@ -38,149 +38,80 @@ size_t count_only_in_second(const Dims &first, const Dims &second) {
return result;
}
-struct SparseJoinReduceState {
- SmallVector<string_id,4> addr_space;
- SmallVector<string_id*,4> a_addr;
- SmallVector<const string_id*,4> overlap;
- SmallVector<string_id*,4> b_only;
- SmallVector<size_t,4> b_view;
- size_t a_subspace;
- size_t b_subspace;
- uint32_t res_dims;
- SparseJoinReduceState(const bool *in_a, const bool *in_b, const bool *in_res, size_t dims)
- : addr_space(dims), a_addr(), overlap(), b_only(), b_view(), a_subspace(), b_subspace(), res_dims(0)
- {
- size_t b_idx = 0;
- uint32_t dims_end = addr_space.size();
- for (size_t i = 0; i < dims; ++i) {
- string_id *id = in_res[i] ? &addr_space[res_dims++] : &addr_space[--dims_end];
- if (in_a[i]) {
- a_addr.push_back(id);
- if (in_b[i]) {
- overlap.push_back(id);
- b_view.push_back(b_idx++);
- }
- } else if (in_b[i]) {
- b_only.push_back(id);
- ++b_idx;
- }
- }
- // Kept dimensions are allocated from the start and dropped
- // dimensions are allocated from the end. Make sure they
- // combine to exactly cover the complete address space.
- assert(res_dims == dims_end);
- }
- ~SparseJoinReduceState();
-};
-SparseJoinReduceState::~SparseJoinReduceState() = default;
-
-void execute_plan(const Value::Index &a, const Value::Index &b,
- const bool *in_a, const bool *in_b, const bool *in_res,
- size_t dims, auto &&f)
-{
- SparseJoinReduceState state(in_a, in_b, in_res, dims);
- auto outer = a.create_view({});
- auto inner = b.create_view(state.b_view);
- outer->lookup({});
- while (outer->next_result(state.a_addr, state.a_subspace)) {
- inner->lookup(state.overlap);
- while (inner->next_result(state.b_only, state.b_subspace)) {
- f(state.a_subspace, state.b_subspace, ConstArrayRef<string_id>{state.addr_space.begin(), state.res_dims});
- }
- }
-}
-
-using est_fun = SparseJoinReducePlan::est_fun_t;
-using est_filter = std::function<bool(bool, bool, bool)>;
-
-struct Est {
- est_filter filter;
- est_fun estimate;
- bool can_use;
- Est(est_filter filter_in, est_fun estimate_in)
- : filter(filter_in), estimate(estimate_in), can_use(true) {}
- ~Est();
-};
-Est::~Est() = default;
-
size_t est_1(size_t, size_t) noexcept { return 1; }
size_t est_a_or_0(size_t a, size_t b) noexcept { return (b == 0) ? 0 : a; }
size_t est_b_or_0(size_t a, size_t b) noexcept { return (a == 0) ? 0 : b; }
size_t est_min(size_t a, size_t b) noexcept { return std::min(a, b); }
size_t est_mul(size_t a, size_t b) noexcept { return (a * b); }
-bool no_dims(bool, bool, bool) noexcept { return false; }
bool reduce_all(bool, bool, bool keep) noexcept { return !keep; }
-bool keep_a_reduce_b(bool a, bool b, bool keep) noexcept {
- if (keep) {
- return (a && !b);
- } else {
- return (!a && b);
- }
-}
-bool keep_b_reduce_a(bool a, bool b, bool keep) noexcept { return keep_a_reduce_b(b, a, keep); }
-bool full_overlap(bool a, bool b, bool) noexcept { return (a == b); }
+bool keep_a_reduce_b(bool a, bool b, bool keep) noexcept { return (keep == a) && (keep != b); }
+bool keep_b_reduce_a(bool a, bool b, bool keep) noexcept { return (keep == b) && (keep != a); }
bool no_overlap_keep_all(bool a, bool b, bool keep) noexcept { return keep && (a != b); }
-std::vector<Est> make_est_list() {
- return {
- { no_dims, est_1 },
- { reduce_all, est_1 },
- { keep_a_reduce_b, est_a_or_0 },
- { keep_b_reduce_a, est_b_or_0 },
- { full_overlap, est_min },
- { no_overlap_keep_all, est_mul }
- };
-}
+} // <unnamed>
-void update_est_list(std::vector<Est> &est_list, bool in_lhs, bool in_rhs, bool in_res) {
- for (Est &est: est_list) {
- if (est.can_use && !est.filter(in_lhs, in_rhs, in_res)) {
- est.can_use = false;
- }
- }
+SparseJoinReducePlan::est_fun_t
+SparseJoinReducePlan::select_estimate() const
+{
+ if (check(reduce_all)) return est_1;
+ if (check(no_overlap_keep_all)) return est_mul;
+ if (check(keep_a_reduce_b)) return est_a_or_0;
+ if (check(keep_b_reduce_a)) return est_b_or_0;
+ return est_min;
}
-est_fun select_estimate(const std::vector<Est> &est_list) {
- for (const Est &est: est_list) {
- if (est.can_use) {
- return est.estimate;
+SparseJoinReducePlan::State::State(const bool *in_a, const bool *in_b, const bool *in_res, size_t dims)
+ : addr_space(dims), a_addr(), overlap(), b_only(), b_view(), a_subspace(), b_subspace(), res_dims(0)
+{
+ size_t b_idx = 0;
+ uint32_t dims_end = addr_space.size();
+ for (size_t i = 0; i < dims; ++i) {
+ string_id *id = in_res[i] ? &addr_space[res_dims++] : &addr_space[--dims_end];
+ if (in_a[i]) {
+ a_addr.push_back(id);
+ if (in_b[i]) {
+ overlap.push_back(id);
+ b_view.push_back(b_idx++);
+ }
+ } else if (in_b[i]) {
+ b_only.push_back(id);
+ ++b_idx;
}
}
- return est_min;
+ // Kept dimensions are allocated from the start and dropped
+ // dimensions are allocated from the end. Make sure they
+ // combine to exactly cover the complete address space.
+ assert(res_dims == dims_end);
}
-} // <unnamed>
+SparseJoinReducePlan::State::~State() = default;
SparseJoinReducePlan::SparseJoinReducePlan(const ValueType &lhs, const ValueType &rhs, const ValueType &res)
- : _in_lhs(), _in_rhs(), _in_res(), _res_dims(0), _estimate()
+ : _in_lhs(), _in_rhs(), _in_res(), _res_dims(res.count_mapped_dimensions()), _estimate()
{
auto dims = merge(lhs.mapped_dimensions(), rhs.mapped_dimensions());
assert(count_only_in_second(dims, res.mapped_dimensions()) == 0);
- auto est_list = make_est_list();
for (const auto &dim: dims) {
_in_lhs.push_back(lhs.has_dimension(dim.name));
_in_rhs.push_back(rhs.has_dimension(dim.name));
_in_res.push_back(res.has_dimension(dim.name));
- if (_in_res.back()) {
- ++_res_dims;
- }
- update_est_list(est_list, _in_lhs.back(), _in_rhs.back(), _in_res.back());
}
- _estimate = select_estimate(est_list);
- assert(bool(_estimate));
+ _estimate = select_estimate();
}
SparseJoinReducePlan::~SparseJoinReducePlan() = default;
-void
-SparseJoinReducePlan::execute(const Value::Index &lhs, const Value::Index &rhs, F f) const {
- if (rhs.size() < lhs.size()) {
- auto swap = [&](auto a, auto b, auto addr) { f(b, a, addr); };
- execute_plan(rhs, lhs, _in_rhs.data(), _in_lhs.data(), _in_res.data(), _in_res.size(), swap);
- } else {
- execute_plan(lhs, rhs, _in_lhs.data(), _in_rhs.data(), _in_res.data(), _in_res.size(), f);
- }
+bool
+SparseJoinReducePlan::maybe_forward_lhs_index() const
+{
+ return check(keep_a_reduce_b);
+}
+
+bool
+SparseJoinReducePlan::maybe_forward_rhs_index() const
+{
+ return check(keep_b_reduce_a);
}
} // namespace
diff --git a/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h b/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h
index c93bf46e2dc..75b8d329763 100644
--- a/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h
+++ b/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h
@@ -14,8 +14,7 @@ public:
friend class SparseJoinReducePlanTest;
using BitList = SmallVector<bool,8>;
- using est_fun_t = std::function<size_t(size_t lhs_size, size_t rhs_size)>;
- using F = std::function<void(size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef<string_id> res_addr)>;
+ using est_fun_t = size_t (*)(size_t lhs_size, size_t rhs_size) noexcept;
private:
BitList _in_lhs;
@@ -24,15 +23,65 @@ private:
size_t _res_dims;
est_fun_t _estimate;
+ struct State {
+ SmallVector<string_id,4> addr_space;
+ SmallVector<string_id*,4> a_addr;
+ SmallVector<const string_id*,4> overlap;
+ SmallVector<string_id*,4> b_only;
+ SmallVector<size_t,4> b_view;
+ size_t a_subspace;
+ size_t b_subspace;
+ uint32_t res_dims;
+ State(const bool *in_a, const bool *in_b, const bool *in_res, size_t dims);
+ ~State();
+ };
+
+ static void execute_plan(const Value::Index &a, const Value::Index &b,
+ const bool *in_a, const bool *in_b, const bool *in_res,
+ size_t dims, auto &&f)
+ {
+ State state(in_a, in_b, in_res, dims);
+ auto outer = a.create_view({});
+ auto inner = b.create_view(state.b_view);
+ outer->lookup({});
+ while (outer->next_result(state.a_addr, state.a_subspace)) {
+ inner->lookup(state.overlap);
+ while (inner->next_result(state.b_only, state.b_subspace)) {
+ f(state.a_subspace, state.b_subspace, ConstArrayRef<string_id>{state.addr_space.begin(), state.res_dims});
+ }
+ }
+ }
+
+ bool check(auto &&pred) const {
+ for (size_t i = 0; i < _in_lhs.size(); ++i) {
+ if (!pred(_in_lhs[i], _in_rhs[i], _in_res[i])) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ est_fun_t select_estimate() const;
+
public:
SparseJoinReducePlan(const ValueType &lhs, const ValueType &rhs, const ValueType &res);
~SparseJoinReducePlan();
size_t res_dims() const { return _res_dims; }
- bool distinct_result() const { return (_res_dims == _in_res.size()); }
+ bool distinct_result() const { return _res_dims == _in_res.size(); }
+ bool maybe_forward_lhs_index() const;
+ bool maybe_forward_rhs_index() const;
size_t estimate_result_size(const Value::Index &lhs, const Value::Index &rhs) const {
return _estimate(lhs.size(), rhs.size());
}
- void execute(const Value::Index &lhs, const Value::Index &rhs, F f) const;
+ // f ~= std::function<void(size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef<string_id> res_addr)>;
+ void execute(const Value::Index &lhs, const Value::Index &rhs, auto &&f) const {
+ if (rhs.size() < lhs.size()) {
+ auto swap = [&f](auto a, auto b, auto addr) { f(b, a, addr); };
+ execute_plan(rhs, lhs, _in_rhs.data(), _in_lhs.data(), _in_res.data(), _in_res.size(), swap);
+ } else {
+ execute_plan(lhs, rhs, _in_lhs.data(), _in_rhs.data(), _in_res.data(), _in_res.size(), f);
+ }
+ }
};
} // namespace
diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.cpp b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
index 86e6be52de4..3811508a543 100644
--- a/eval/src/vespa/eval/instruction/universal_dot_product.cpp
+++ b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
@@ -74,22 +74,58 @@ void my_universal_dot_product_op(InterpretedFunction::State &state, uint64_t par
state.pop_pop_push(result);
}
+template <typename LCT, typename RCT, typename OCT>
+void my_universal_dense_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) {
+ using dot_product = DotProduct<LCT,RCT>;
+ const auto &param = unwrap_param<UniversalDotProductParam>(param_in);
+ const auto &lhs = state.peek(1);
+ const auto &rhs = state.peek(0);
+ size_t lhs_index_size = lhs.index().size();
+ size_t rhs_index_size = rhs.index().size();
+ if (rhs_index_size == 0 || lhs_index_size == 0) {
+ const Value &empty = state.stash.create<ValueView>(param.res_type, EmptyIndex::get(), TypedCells(nullptr, get_cell_type<OCT>(), 0));
+ state.pop_pop_push(empty);
+ return;
+ }
+ const auto lhs_cells = lhs.cells().typify<LCT>();
+ const auto rhs_cells = rhs.cells().typify<RCT>();
+ auto dst_cells = state.stash.create_array<OCT>(lhs_index_size * param.dense_plan.res_size);
+ auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) {
+ dst_cells[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size);
+ };
+ for (size_t lhs_subspace = 0; lhs_subspace < lhs_index_size; ++lhs_subspace) {
+ for (size_t rhs_subspace = 0; rhs_subspace < rhs_index_size; ++rhs_subspace) {
+ param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size,
+ rhs_subspace * param.dense_plan.rhs_size,
+ lhs_subspace * param.dense_plan.res_size, dense_fun);
+ }
+ }
+ const Value &result = state.stash.create<ValueView>(param.res_type, lhs.index(), TypedCells(dst_cells));
+ state.pop_pop_push(result);
+}
+
struct SelectUniversalDotProduct {
- template <typename LCM, typename RCM, typename SCALAR> static auto invoke(const UniversalDotProductParam &) {
+ template <typename LCM, typename RCM, typename SCALAR> static auto invoke(const UniversalDotProductParam &param) {
constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value).reduce(SCALAR::value);
using LCT = CellValueType<LCM::value.cell_type>;
using RCT = CellValueType<RCM::value.cell_type>;
using OCT = CellValueType<ocm.cell_type>;
+ if (param.sparse_plan.maybe_forward_lhs_index()) {
+ return my_universal_dense_dot_product_op<LCT,RCT,OCT>;
+ }
return my_universal_dot_product_op<LCT,RCT,OCT>;
}
};
bool check_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) {
- UniversalDotProductParam param(res, lhs, rhs);
- if (param.vector_size < 8) {
+ (void) res;
+ if (lhs.is_double() || rhs.is_double()) {
return false;
}
- return true;
+ if (lhs.count_mapped_dimensions() > 0 || rhs.count_mapped_dimensions() > 0) {
+ return true;
+ }
+ return false;
}
} // namespace <unnamed>
@@ -118,8 +154,15 @@ UniversalDotProduct::optimize(const TensorFunction &expr, Stash &stash, bool for
{
if (auto reduce = as<Reduce>(expr); reduce && (reduce->aggr() == Aggr::SUM)) {
if (auto join = as<Join>(reduce->child()); join && (join->function() == Mul::f)) {
- if (force || check_types(expr.result_type(), join->lhs().result_type(), join->rhs().result_type())) {
- return stash.create<UniversalDotProduct>(expr.result_type(), join->lhs(), join->rhs());
+ const ValueType &res_type = expr.result_type();
+ const ValueType &lhs_type = join->lhs().result_type();
+ const ValueType &rhs_type = join->rhs().result_type();
+ if (force || check_types(res_type, lhs_type, rhs_type)) {
+ SparseJoinReducePlan sparse_plan(lhs_type, rhs_type, res_type);
+ if (sparse_plan.maybe_forward_rhs_index() && !sparse_plan.maybe_forward_lhs_index()) {
+ return stash.create<UniversalDotProduct>(res_type, join->rhs(), join->lhs());
+ }
+ return stash.create<UniversalDotProduct>(res_type, join->lhs(), join->rhs());
}
}
}