| author    | Håvard Pettersen <havardpe@yahooinc.com>           | 2023-09-07 13:01:03 +0000 |
|-----------|----------------------------------------------------|---------------------------|
| committer | Håvard Pettersen <havardpe@yahooinc.com>           | 2023-09-08 14:00:39 +0000 |
| commit    | 2f0877044912b3e39ae355189c780341cf117892 (patch)   |                           |
| tree      | 4a93b5c2d359f94a6357f2c67bfceeff5b57a2d8 /eval/src |                           |
| parent    | fa6c99ac39cba9fd07ae9229995ec5cdc614ddbb (diff)    |                           |
handle 'distinct' and 'single' flags using templates
Diffstat (limited to 'eval/src')
6 files changed, 205 insertions, 153 deletions
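The idea behind the commit is to decide the special-case properties of a dot-product operation once, when the instruction is compiled, and to encode them as template parameters so that the per-cell kernels contain no runtime branching. The sketch below is a minimal, self-contained illustration of that pattern and is not the Vespa implementation; the names `Dot`, `combine`, `select_combine` and the sample data are made up for the example. In the actual change the flags are `forward`, `distinct` and `single`, they are derived from `SparseJoinReducePlan`/`DenseJoinReducePlan` and `vector_size`, and the dispatch goes through `typify_invoke`; here plain `if` dispatch and C++20 `requires`-constrained overloads stand in for that machinery.

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Kernel selected at compile time: 'single' means the reduced vector has size one,
// so the dot product collapses to a single multiply.
template <bool single> struct Dot {
    size_t vector_size;
    double operator()(const double *a, const double *b) const requires (!single) {
        double sum = 0.0;
        for (size_t i = 0; i < vector_size; ++i) { sum += a[i] * b[i]; }
        return sum;
    }
    double operator()(const double *a, const double *b) const requires (single) {
        return (*a) * (*b);
    }
};

// 'distinct' means every result cell is produced exactly once, so results can be
// written in order instead of being accumulated into already written cells.
template <bool distinct, bool single>
void combine(const std::vector<double> &lhs, const std::vector<double> &rhs,
             std::vector<double> &dst, size_t vector_size) {
    Dot<single> dot{vector_size};
    size_t lhs_cnt = lhs.size() / vector_size;
    size_t rhs_cnt = rhs.size() / vector_size;
    size_t out = 0;
    for (size_t i = 0; i < lhs_cnt; ++i) {
        for (size_t j = 0; j < rhs_cnt; ++j) {
            double cell = dot(&lhs[i * vector_size], &rhs[j * vector_size]);
            if constexpr (distinct) {
                dst[out++] = cell;   // write each result cell exactly once
            } else {
                dst[i] += cell;      // accumulate when result cells are revisited
            }
        }
    }
}

// Runtime flags are resolved once, outside the hot loop, into a template instantiation.
using combine_fun = void (*)(const std::vector<double> &, const std::vector<double> &,
                             std::vector<double> &, size_t);
combine_fun select_combine(bool distinct, bool single) {
    if (distinct) { return single ? combine<true, true> : combine<true, false>; }
    return single ? combine<false, true> : combine<false, false>;
}

int main() {
    std::vector<double> lhs = {1, 2, 3, 4}; // two subspaces of size 2
    std::vector<double> rhs = {5, 6};       // one subspace of size 2
    std::vector<double> dst(2, 0.0);
    select_combine(false, false)(lhs, rhs, dst, 2);
    printf("%g %g\n", dst[0], dst[1]);      // prints: 17 39
    return 0;
}
```

Once instantiated this way, checking for 'distinct' or 'single' costs nothing per cell, which is what the templated `DenseFun`/`SparseFun` kernels in the diff below achieve for the real operation.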
diff --git a/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp b/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp
index 9851e209ba5..7faf7d57738 100644
--- a/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp
+++ b/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp
@@ -13,7 +13,7 @@ ValueType type(const vespalib::string &type_spec) {
 
 TEST(DenseJoinReducePlanTest, make_trivial_plan) {
     auto plan = DenseJoinReducePlan(type("double"), type("double"), type("double"));
-    EXPECT_TRUE(plan.distinct_result());
+    EXPECT_TRUE(plan.is_distinct());
     EXPECT_EQ(plan.lhs_size, 1);
     EXPECT_EQ(plan.rhs_size, 1);
     EXPECT_EQ(plan.res_size, 1);
@@ -39,7 +39,7 @@ TEST(DenseJoinReducePlanTest, make_simple_plan) {
     SmallVector<size_t> expect_lhs_stride = {1,0};
     SmallVector<size_t> expect_rhs_stride = {0,1};
     SmallVector<size_t> expect_res_stride = {1,0};
-    EXPECT_FALSE(plan.distinct_result());
+    EXPECT_FALSE(plan.is_distinct());
     EXPECT_EQ(plan.lhs_size, 2);
     EXPECT_EQ(plan.rhs_size, 3);
     EXPECT_EQ(plan.res_size, 2);
@@ -69,7 +69,7 @@ TEST(DenseJoinReducePlanTest, make_distinct_plan) {
     SmallVector<size_t> expect_lhs_stride = {1,0};
     SmallVector<size_t> expect_rhs_stride = {0,1};
     SmallVector<size_t> expect_res_stride = {3,1};
-    EXPECT_TRUE(plan.distinct_result());
+    EXPECT_TRUE(plan.is_distinct());
     EXPECT_EQ(plan.lhs_size, 2);
     EXPECT_EQ(plan.rhs_size, 3);
     EXPECT_EQ(plan.res_size, 6);
@@ -88,7 +88,7 @@ TEST(DenseJoinReducePlanTest, make_complex_plan) {
     SmallVector<size_t> expect_lhs_stride = {6,0,2,1};
     SmallVector<size_t> expect_rhs_stride = {4,1,0,0};
     SmallVector<size_t> expect_res_stride = {12,3,1,0};
-    EXPECT_FALSE(plan.distinct_result());
+    EXPECT_FALSE(plan.is_distinct());
     EXPECT_EQ(plan.lhs_size, 180);
     EXPECT_EQ(plan.rhs_size, 120);
     EXPECT_EQ(plan.res_size, 360);
diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
index 4ca2a5ef79a..e1967f012cb 100644
--- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
+++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
@@ -27,31 +27,7 @@ using vespalib::make_string_short::fmt;
 const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
 bool bench = false;
 double budget = 1.0;
-
-GenSpec::seq_t N_16ths = [] (size_t i) noexcept { return (i + 33.0) / 16.0; };
-
-GenSpec G() { return GenSpec().seq(N_16ths); }
-
-const std::vector<GenSpec> layouts = {
-    G(), G(),
-    G().idx("x", 5), G().idx("x", 5),
-    G().idx("x", 5), G().idx("y", 5),
-    G().idx("x", 5), G().idx("x", 5).idx("y", 5),
-    G().idx("y", 3), G().idx("x", 2).idx("z", 3),
-    G().idx("x", 3).idx("y", 5), G().idx("y", 5).idx("z", 7),
-    G().map("x", {"a","b","c"}), G().map("x", {"a","b","c"}),
-    G().map("x", {"a","b","c"}), G().map("x", {"a","b"}),
-    G().map("x", {"a","b","c"}), G().map("y", {"foo","bar","baz"}),
-    G().map("x", {"a","b","c"}), G().map("x", {"a","b","c"}).map("y", {"foo","bar","baz"}),
-    G().map("x", {"a","b"}).map("y", {"foo","bar","baz"}), G().map("x", {"a","b","c"}).map("y", {"foo","bar"}),
-    G().map("x", {"a","b"}).map("y", {"foo","bar","baz"}), G().map("y", {"foo","bar"}).map("z", {"i","j","k","l"}),
-    G().idx("x", 3).map("y", {"foo", "bar"}), G().map("y", {"foo", "bar"}).idx("z", 7),
-    G().map("x", {"a","b","c"}).idx("y", 5), G().idx("y", 5).map("z", {"i","j","k","l"})
-};
-
-const std::vector<std::vector<vespalib::string>> reductions = {
-    {}, {"x"}, {"y"}, {"z"}, {"x", "y"}, {"x", "z"}, {"y", "z"}
-};
+size_t verify_cnt = 0;
 
 std::vector<std::string> ns_list = {
     {"vespalib::eval::instruction::(anonymous namespace)::"},
@@ -76,14 +52,19 @@ std::string strip_ns(const vespalib::string &str) {
     return tmp;
 }
 
-TensorSpec make_spec(const vespalib::string &param_name, size_t idx) {
-    return GenSpec::from_desc(param_name).cells_double().seq(N(1 + idx));
+using select_cell_type_t = std::function<CellType(size_t idx)>;
+CellType always_double(size_t) { return CellType::DOUBLE; }
+select_cell_type_t select(CellType lct) { return [lct](size_t)noexcept{ return lct; }; }
+select_cell_type_t select(CellType lct, CellType rct) { return [lct,rct](size_t idx)noexcept{ return idx ? rct : lct; }; }
+
+TensorSpec make_spec(const vespalib::string &param_name, size_t idx, select_cell_type_t select_cell_type) {
+    return GenSpec::from_desc(param_name).cells(select_cell_type(idx)).seq(N(1 + idx));
 }
 
-TensorSpec eval_ref(const Function &fun) {
+TensorSpec eval_ref(const Function &fun, select_cell_type_t select_cell_type) {
     std::vector<TensorSpec> params;
     for (size_t i = 0; i < fun.num_params(); ++i) {
-        params.push_back(make_spec(fun.param_name(i), i));
+        params.push_back(make_spec(fun.param_name(i), i, select_cell_type));
     }
     return ReferenceEvaluation::eval(fun, params);
 }
@@ -134,12 +115,13 @@ Optimize universal_only() {
     return Optimize::specific("universal_only", my_optimizer);
 }
 
-void verify(const vespalib::string &expr) {
+void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
+    ++verify_cnt;
     auto fun = Function::parse(expr);
     ASSERT_FALSE(fun->has_error());
     std::vector<Value::UP> values;
     for (size_t i = 0; i < fun->num_params(); ++i) {
-        auto value = value_from_spec(make_spec(fun->param_name(i), i), prod_factory);
+        auto value = value_from_spec(make_spec(fun->param_name(i), i, select_cell_type), prod_factory);
         values.push_back(std::move(value));
     }
     SimpleObjectParams params({});
@@ -167,23 +149,24 @@
     } else {
         EXPECT_EQ(actual.cells().size, actual.index().size() * expected_type.dense_subspace_size());
     }
-    auto expected = eval_ref(*fun);
+    auto expected = eval_ref(*fun, select_cell_type);
     EXPECT_EQ(spec_from_value(actual), expected);
 }
+void verify(const vespalib::string &expr) { verify(expr, always_double); }
 
 using cost_list_t = std::vector<std::pair<vespalib::string,double>>;
 std::vector<std::pair<vespalib::string,cost_list_t>> benchmark_results;
 
 void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
+    verify(expr);
     auto fun = Function::parse(expr);
     ASSERT_FALSE(fun->has_error());
-    auto expected = eval_ref(*fun);
     cost_list_t cost_list;
     fprintf(stderr, "BENCH: %s\n", expr.c_str());
     for (Optimize &optimize: list) {
         std::vector<Value::UP> values;
         for (size_t i = 0; i < fun->num_params(); ++i) {
-            auto value = value_from_spec(make_spec(fun->param_name(i), i), prod_factory);
+            auto value = value_from_spec(make_spec(fun->param_name(i), i, always_double), prod_factory);
             values.push_back(std::move(value));
         }
         SimpleObjectParams params({});
@@ -218,8 +201,6 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
         InterpretedFunction ifun(prod_factory, *optimized, &ctf_meta);
         InterpretedFunction::ProfiledContext pctx(ifun);
         ASSERT_EQ(ctf_meta.steps.size(), ifun.program_size());
-        EXPECT_EQ(spec_from_value(ifun.eval(pctx.context, params)), expected);
-        EXPECT_EQ(spec_from_value(ifun.eval(pctx, params)), expected);
         std::vector<duration> prev_time(ctf_meta.steps.size(), duration::zero());
         std::vector<duration> min_time(ctf_meta.steps.size(), duration::max());
         BenchmarkTimer timer(budget);
@@ -251,47 +232,63 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
     benchmark_results.emplace_back(expr, std::move(cost_list));
 }
 
-TensorSpec perform_dot_product(const TensorSpec &a, const TensorSpec &b, const std::vector<vespalib::string> &dims)
-{
-    Stash stash;
-    auto lhs = value_from_spec(a, prod_factory);
-    auto rhs = value_from_spec(b, prod_factory);
-    auto res_type = ValueType::join(lhs->type(), rhs->type()).reduce(dims);
-    EXPECT_FALSE(res_type.is_error());
-    UniversalDotProduct dot_product(res_type,
-                                    tensor_function::inject(lhs->type(), 0, stash),
-                                    tensor_function::inject(rhs->type(), 1, stash));
-    auto my_op = dot_product.compile_self(prod_factory, stash);
-    InterpretedFunction::EvalSingle single(prod_factory, my_op);
-    return spec_from_value(single.eval(std::vector<Value::CREF>({*lhs,*rhs})));
+TEST(UniversalDotProductTest, test_select_cell_types) {
+    auto always = always_double;
+    EXPECT_EQ(always(0), CellType::DOUBLE);
+    EXPECT_EQ(always(1), CellType::DOUBLE);
+    EXPECT_EQ(always(0), CellType::DOUBLE);
+    EXPECT_EQ(always(1), CellType::DOUBLE);
+    for (CellType lct: CellTypeUtils::list_types()) {
+        auto sel1 = select(lct);
+        EXPECT_EQ(sel1(0), lct);
+        EXPECT_EQ(sel1(1), lct);
+        EXPECT_EQ(sel1(0), lct);
+        EXPECT_EQ(sel1(1), lct);
+        for (CellType rct: CellTypeUtils::list_types()) {
+            auto sel2 = select(lct, rct);
+            EXPECT_EQ(sel2(0), lct);
+            EXPECT_EQ(sel2(1), rct);
+            EXPECT_EQ(sel2(0), lct);
+            EXPECT_EQ(sel2(1), rct);
+        }
+    }
 }
 
-TEST(UniversalDotProductTest, generic_dot_product_works_for_various_cases) {
-    size_t test_cases = 0;
-    ASSERT_TRUE((layouts.size() % 2) == 0);
-    for (size_t i = 0; i < layouts.size(); i += 2) {
-        const auto &l = layouts[i];
-        const auto &r = layouts[i+1];
-        for (CellType lct : CellTypeUtils::list_types()) {
-            auto lhs = l.cpy().cells(lct);
-            if (lhs.bad_scalar()) continue;
-            for (CellType rct : CellTypeUtils::list_types()) {
-                auto rhs = r.cpy().cells(rct);
-                if (rhs.bad_scalar()) continue;
-                for (const std::vector<vespalib::string> &dims: reductions) {
-                    if (ValueType::join(lhs.type(), rhs.type()).reduce(dims).is_error()) continue;
-                    ++test_cases;
-                    SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str()));
-                    auto expect = ReferenceOperations::reduce(ReferenceOperations::join(lhs, rhs, operation::Mul::f), Aggr::SUM, dims);
-                    auto actual = perform_dot_product(lhs, rhs, dims);
-                    // fprintf(stderr, "\n===\nLHS: %s\nRHS: %s\n===\nRESULT: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str(), actual.to_string().c_str());
-                    EXPECT_EQ(actual, expect);
-                }
-            }
+TEST(UniversalDotProductTest, universal_dot_product_works_for_various_cases) {
+    // forward, distinct, single
+    verify("reduce(2.0*3.0, sum)");
+
+    for (CellType lct: CellTypeUtils::list_types()) {
+        for (CellType rct: CellTypeUtils::list_types()) {
+            auto sel2 = select(lct, rct);
+            // !forward, !distinct, !single
+            verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2);
+
+            // !forward, !distinct, single
+            verify("reduce(a4_1x8*a2_1x8,sum,a)", sel2);
+
+            // !forward, distinct, !single
+            verify("reduce(a4_1x8*a2_1x8,sum,x)", sel2);
+
+            // forward, !distinct, !single
+            verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2);
+
+            // forward, !distinct, single
+            verify("reduce(a4_1x8*b2_1x8,sum,b)", sel2);
+
+            // forward, distinct, !single
+            verify("reduce(a4_1x8*x8,sum,x)", sel2);
         }
     }
-    EXPECT_GT(test_cases, 500);
-    fprintf(stderr, "total test cases run: %zu\n", test_cases);
+    // !forward, distinct, single
+
+    // This case is not possible since 'distinct' implies '!single' as
+    // long as we reduce anything. The only expression allowed to
+    // reduce nothing is the scalar case.
+}
+
+TEST(UniversalDotProductTest, universal_dot_product_works_with_complex_dimension_nesting) {
+    verify("reduce(a4_1b4_1c4_1x4y3z2w1*a2_1c1_1x4z2,sum,b,c,x)");
 }
 
 TEST(UniversalDotProductTest, forwarding_empty_result) {
@@ -336,8 +333,11 @@ TEST(UniversalDotProductTest, bench_vector_dot_product) {
     }
     auto optimize_list = std::vector<Optimize>({baseline(), with_universal(), universal_only()});
 
-    benchmark("reduce(1.0*2.0,sum)", optimize_list);
+    benchmark("reduce(2.0*3.0,sum)", optimize_list);
     benchmark("reduce(5.0*x128,sum,x)", optimize_list);
+    benchmark("reduce(a1*x128,sum,x)", optimize_list);
+    benchmark("reduce(a8*x128,sum,x)", optimize_list);
+    benchmark("reduce(a1_1b8*x128,sum,x)", optimize_list);
     benchmark("reduce(x16*x16,sum,x)", optimize_list);
     benchmark("reduce(x768*x768,sum,x)", optimize_list);
     benchmark("reduce(y64*x8y64,sum,x,y)", optimize_list);
@@ -417,5 +417,7 @@ int main(int argc, char **argv) {
         --argc;
     }
     ::testing::InitGoogleTest(&argc, argv);
-    return RUN_ALL_TESTS();
+    int result = RUN_ALL_TESTS();
+    fprintf(stderr, "verify called %zu times\n", verify_cnt);
+    return result;
 }
diff --git a/eval/src/vespa/eval/instruction/dense_join_reduce_plan.cpp b/eval/src/vespa/eval/instruction/dense_join_reduce_plan.cpp
index 20b7d3364a8..8d09abbfe15 100644
--- a/eval/src/vespa/eval/instruction/dense_join_reduce_plan.cpp
+++ b/eval/src/vespa/eval/instruction/dense_join_reduce_plan.cpp
@@ -82,7 +82,7 @@ DenseJoinReducePlan::DenseJoinReducePlan(const ValueType &lhs, const ValueType &
 DenseJoinReducePlan::~DenseJoinReducePlan() = default;
 
 bool
-DenseJoinReducePlan::distinct_result() const
+DenseJoinReducePlan::is_distinct() const
 {
     for (size_t stride: res_stride) {
         if (stride == 0) {
diff --git a/eval/src/vespa/eval/instruction/dense_join_reduce_plan.h b/eval/src/vespa/eval/instruction/dense_join_reduce_plan.h
index 8f9d5218630..3cf55e9ace4 100644
--- a/eval/src/vespa/eval/instruction/dense_join_reduce_plan.h
+++ b/eval/src/vespa/eval/instruction/dense_join_reduce_plan.h
@@ -21,7 +21,10 @@ struct DenseJoinReducePlan {
     template <typename F> void execute(size_t lhs, size_t rhs, size_t res, const F &f) const {
         run_nested_loop(lhs, rhs, res, loop_cnt, lhs_stride, rhs_stride, res_stride, f);
     }
-    bool distinct_result() const;
+    template <typename F> void execute_distinct(size_t lhs, size_t rhs, const F &f) const {
+        run_nested_loop(lhs, rhs, loop_cnt, lhs_stride, rhs_stride, f);
+    }
+    bool is_distinct() const;
 };
 
 } // namespace
diff --git a/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h b/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h
index 75b8d329763..7176e6ea6e9 100644
--- a/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h
+++ b/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h
@@ -67,7 +67,7 @@ public:
     SparseJoinReducePlan(const ValueType &lhs, const ValueType &rhs, const ValueType &res);
     ~SparseJoinReducePlan();
     size_t res_dims() const { return _res_dims; }
-    bool distinct_result() const { return _res_dims == _in_res.size(); }
+    bool is_distinct() const { return _res_dims == _in_res.size(); }
    bool maybe_forward_lhs_index() const;
     bool maybe_forward_rhs_index() const;
     size_t estimate_result_size(const Value::Index &lhs, const Value::Index &rhs) const {
diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.cpp b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
index 19b839a4fd3..414a54f09a8 100644
--- a/eval/src/vespa/eval/instruction/universal_dot_product.cpp
+++ b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
@@ -42,42 +42,6 @@ struct UniversalDotProductParam {
     }
 };
 
-template <typename LCT, typename RCT, typename OCT>
-void my_universal_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) {
-    using dot_product = DotProduct<LCT,RCT>;
-    const auto &param = unwrap_param<UniversalDotProductParam>(param_in);
-    const auto &lhs = state.peek(1);
-    const auto &rhs = state.peek(0);
-    const auto &lhs_index = lhs.index();
-    const auto &rhs_index = rhs.index();
-    const auto lhs_cells = lhs.cells().typify<LCT>();
-    const auto rhs_cells = rhs.cells().typify<RCT>();
-    auto &stored_result = state.stash.create<std::unique_ptr<FastValue<OCT,true>>>(
-            std::make_unique<FastValue<OCT,true>>(param.res_type, param.sparse_plan.res_dims(), param.dense_plan.res_size,
-                                                  param.sparse_plan.estimate_result_size(lhs_index, rhs_index)));
-    auto &result = *(stored_result.get());
-    OCT *dst;
-    auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) {
-        dst[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size);
-    };
-    auto sparse_fun = [&](size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef<string_id> res_addr) {
-        auto [space, first] = result.insert_subspace(res_addr);
-        if (first) {
-            std::fill(space.begin(), space.end(), OCT{});
-        }
-        dst = space.data();
-        param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size,
-                                 rhs_subspace * param.dense_plan.rhs_size,
-                                 0, dense_fun);
-    };
-    param.sparse_plan.execute(lhs_index, rhs_index, sparse_fun);
-    if (result.my_index.map.size() == 0 && param.sparse_plan.res_dims() == 0) {
-        auto [empty, ignore] = result.insert_subspace({});
-        std::fill(empty.begin(), empty.end(), OCT{});
-    }
-    state.pop_pop_push(result);
-}
-
 template <typename OCT>
 const Value &create_empty_result(const UniversalDotProductParam &param, Stash &stash) {
     if (param.sparse_plan.res_dims() == 0) {
@@ -88,55 +52,136 @@ const Value &create_empty_result(const UniversalDotProductParam &param, Stash &s
     }
 }
 
-template <typename LCT, typename RCT, typename OCT>
-void my_universal_forwarding_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) {
-    using dot_product = DotProduct<LCT,RCT>;
-    const auto &param = unwrap_param<UniversalDotProductParam>(param_in);
-    const auto &lhs = state.peek(1);
-    const auto &rhs = state.peek(0);
-    size_t lhs_index_size = lhs.index().size();
-    size_t rhs_index_size = rhs.index().size();
-    if (rhs_index_size == 0 || lhs_index_size == 0) {
-        state.pop_pop_push(create_empty_result<OCT>(param, state.stash));
-        return;
+template <typename LCT, typename RCT, bool single> struct MyDotProduct;
+template <typename LCT, typename RCT> struct MyDotProduct<LCT, RCT, false> {
+    size_t vector_size;
+    MyDotProduct(size_t vector_size_in) : vector_size(vector_size_in) {}
+    auto operator()(const LCT *lhs, const RCT *rhs) const {
+        return DotProduct<LCT,RCT>::apply(lhs, rhs, vector_size);
+    }
+};
+template <typename LCT, typename RCT> struct MyDotProduct<LCT, RCT, true> {
+    MyDotProduct(size_t) {}
+    auto operator()(const LCT *lhs, const RCT *rhs) const {
+        return (*lhs) * (*rhs);
+    }
+};
+
+template <typename LCT, typename RCT, typename OCT, bool distinct, bool single>
+struct DenseFun {
+    [[no_unique_address]] MyDotProduct<LCT,RCT,single> dot_product;
+    const LCT *lhs;
+    const RCT *rhs;
+    mutable OCT *dst;
+    DenseFun(size_t vector_size_in, const Value &lhs_in, const Value &rhs_in)
+      : dot_product(vector_size_in),
+        lhs(lhs_in.cells().typify<LCT>().data()),
+        rhs(rhs_in.cells().typify<RCT>().data()) {}
+    void operator()(size_t lhs_idx, size_t rhs_idx) const requires distinct {
+        *dst++ = dot_product(lhs + lhs_idx, rhs + rhs_idx);
     }
-    const auto lhs_cells = lhs.cells().typify<LCT>();
-    const auto rhs_cells = rhs.cells().typify<RCT>();
-    auto dst_cells = state.stash.create_array<OCT>(lhs_index_size * param.dense_plan.res_size);
-    OCT *dst = dst_cells.data();
-    auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) {
-        dst[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size);
-    };
-    for (size_t lhs_subspace = 0; lhs_subspace < lhs_index_size; ++lhs_subspace) {
-        for (size_t rhs_subspace = 0; rhs_subspace < rhs_index_size; ++rhs_subspace) {
-            param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size,
-                                     rhs_subspace * param.dense_plan.rhs_size,
-                                     lhs_subspace * param.dense_plan.res_size, dense_fun);
+    void operator()(size_t lhs_idx, size_t rhs_idx, size_t dst_idx) const requires (!distinct) {
+        dst[dst_idx] += dot_product(lhs + lhs_idx, rhs + rhs_idx);
+    }
+};
+
+template <typename OCT, bool forward> struct Result {};
+template <typename OCT> struct Result<OCT, false> {
+    mutable FastValue<OCT,true> *fast;
+};
+
+template <typename LCT, typename RCT, typename OCT, bool forward, bool distinct, bool single>
+struct SparseFun {
+    const UniversalDotProductParam &param;
+    DenseFun<LCT,RCT,OCT,distinct,single> dense_fun;
+    [[no_unique_address]] Result<OCT, forward> result;
+    SparseFun(uint64_t param_in, const Value &lhs_in, const Value &rhs_in)
+      : param(unwrap_param<UniversalDotProductParam>(param_in)),
+        dense_fun(param.vector_size, lhs_in, rhs_in),
+        result() {}
+    void operator()(size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef<string_id> res_addr) const requires (!forward && !distinct) {
+        auto [space, first] = result.fast->insert_subspace(res_addr);
+        if (first) {
+            std::fill(space.begin(), space.end(), OCT{});
+        }
+        dense_fun.dst = space.data();
+        param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size,
+                                 rhs_subspace * param.dense_plan.rhs_size,
+                                 0, dense_fun);
+    };
+    void operator()(size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef<string_id> res_addr) const requires (!forward && distinct) {
+        dense_fun.dst = result.fast->add_subspace(res_addr).data();
+        param.dense_plan.execute_distinct(lhs_subspace * param.dense_plan.lhs_size,
+                                          rhs_subspace * param.dense_plan.rhs_size,
+                                          dense_fun);
+    };
+    void operator()(size_t lhs_subspace, size_t rhs_subspace) const requires (forward && !distinct) {
+        param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size,
+                                 rhs_subspace * param.dense_plan.rhs_size,
+                                 lhs_subspace * param.dense_plan.res_size, dense_fun);
+    };
+    void operator()(size_t lhs_subspace, size_t rhs_subspace) const requires (forward && distinct) {
+        param.dense_plan.execute_distinct(lhs_subspace * param.dense_plan.lhs_size,
+                                          rhs_subspace * param.dense_plan.rhs_size, dense_fun);
+    };
+    const Value &calculate_result(const Value::Index &lhs, const Value::Index &rhs, Stash &stash) const requires (!forward) {
+        auto &stored_result = stash.create<std::unique_ptr<FastValue<OCT,true>>>(
+                std::make_unique<FastValue<OCT,true>>(param.res_type, param.sparse_plan.res_dims(), param.dense_plan.res_size,
+                                                      param.sparse_plan.estimate_result_size(lhs, rhs)));
+        result.fast = stored_result.get();
+        param.sparse_plan.execute(lhs, rhs, *this);
+        if (result.fast->my_index.map.size() == 0 && param.sparse_plan.res_dims() == 0) {
+            auto empty = result.fast->add_subspace(ConstArrayRef<string_id>());
+            std::fill(empty.begin(), empty.end(), OCT{});
         }
+        return *(result.fast);
     }
-    const Value &result = state.stash.create<ValueView>(param.res_type, lhs.index(), TypedCells(dst_cells));
-    state.pop_pop_push(result);
+    const Value &calculate_result(const Value::Index &lhs, const Value::Index &rhs, Stash &stash) const requires forward {
+        size_t lhs_size = lhs.size();
+        size_t rhs_size = rhs.size();
+        if (lhs_size == 0 || rhs_size == 0) {
+            return create_empty_result<OCT>(param, stash);
+        }
+        auto dst_cells = (distinct)
+                         ? stash.create_uninitialized_array<OCT>(lhs_size * param.dense_plan.res_size)
+                         : stash.create_array<OCT>(lhs_size * param.dense_plan.res_size);
+        dense_fun.dst = dst_cells.data();
+        for (size_t lhs_idx = 0; lhs_idx < lhs_size; ++lhs_idx) {
+            for (size_t rhs_idx = 0; rhs_idx < rhs_size; ++rhs_idx) {
+                (*this)(lhs_idx, rhs_idx);
+            }
+        }
+        return stash.create<ValueView>(param.res_type, lhs, TypedCells(dst_cells));
+    }
+};
+
+template <typename LCT, typename RCT, typename OCT, bool forward, bool distinct, bool single>
+void my_universal_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) {
+    SparseFun<LCT,RCT,OCT,forward,distinct,single> sparse_fun(param_in, state.peek(1), state.peek(0));
+    state.pop_pop_push(sparse_fun.calculate_result(state.peek(1).index(), state.peek(0).index(), state.stash));
 }
 
 struct SelectUniversalDotProduct {
-    template <typename LCM, typename RCM, typename SCALAR> static auto invoke(const UniversalDotProductParam &param) {
+    template <typename LCM, typename RCM, typename SCALAR, typename FORWARD, typename DISTINCT, typename SINGLE>
+    static auto invoke() {
         constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value).reduce(SCALAR::value);
         using LCT = CellValueType<LCM::value.cell_type>;
         using RCT = CellValueType<RCM::value.cell_type>;
         using OCT = CellValueType<ocm.cell_type>;
-        if (param.sparse_plan.maybe_forward_lhs_index()) {
-            return my_universal_forwarding_dot_product_op<LCT,RCT,OCT>;
+        if constexpr ((std::same_as<LCT,float> && std::same_as<RCT,float>) ||
+                      (std::same_as<LCT,double> && std::same_as<RCT,double>))
+        {
+            return my_universal_dot_product_op<LCT,RCT,OCT,FORWARD::value,DISTINCT::value,SINGLE::value>;
         }
-        return my_universal_dot_product_op<LCT,RCT,OCT>;
+        return my_universal_dot_product_op<LCT,RCT,OCT,FORWARD::value,false,false>;
     }
 };
 
-bool check_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) {
-    (void) res;
+bool check_types(const ValueType &lhs, const ValueType &rhs) {
     if (lhs.is_double() || rhs.is_double()) {
         return false;
     }
-    if (lhs.count_mapped_dimensions() > 0 || rhs.count_mapped_dimensions() > 0) {
+    if (lhs.count_mapped_dimensions() > 0 && rhs.count_mapped_dimensions() > 0) {
         return true;
     }
     return false;
 }
@@ -156,10 +201,12 @@ UniversalDotProduct::compile_self(const ValueBuilderFactory &, Stash &stash) con
 {
     auto &param = stash.create<UniversalDotProductParam>(result_type(), lhs().result_type(), rhs().result_type());
     using MyTypify = TypifyValue<TypifyCellMeta,TypifyBool>;
-    auto op = typify_invoke<3,MyTypify,SelectUniversalDotProduct>(lhs().result_type().cell_meta(),
+    auto op = typify_invoke<6,MyTypify,SelectUniversalDotProduct>(lhs().result_type().cell_meta(),
                                                                   rhs().result_type().cell_meta(),
                                                                   result_type().cell_meta().is_scalar,
-                                                                  param);
+                                                                  param.sparse_plan.maybe_forward_lhs_index(),
+                                                                  param.sparse_plan.is_distinct() && param.dense_plan.is_distinct(),
+                                                                  param.vector_size == 1);
     return InterpretedFunction::Instruction(op, wrap_param<UniversalDotProductParam>(param));
 }
 
@@ -171,7 +218,7 @@ UniversalDotProduct::optimize(const TensorFunction &expr, Stash &stash, bool for
         const ValueType &res_type = expr.result_type();
         const ValueType &lhs_type = join->lhs().result_type();
         const ValueType &rhs_type = join->rhs().result_type();
-        if (force || check_types(res_type, lhs_type, rhs_type)) {
+        if (force || check_types(lhs_type, rhs_type)) {
            SparseJoinReducePlan sparse_plan(lhs_type, rhs_type, res_type);
            if (sparse_plan.maybe_forward_rhs_index() && !sparse_plan.maybe_forward_lhs_index()) {
                return stash.create<UniversalDotProduct>(res_type, join->rhs(), join->lhs());