author    Håvard Pettersen <havardpe@yahooinc.com>    2023-09-07 13:01:03 +0000
committer Håvard Pettersen <havardpe@yahooinc.com>    2023-09-08 14:00:39 +0000
commit    2f0877044912b3e39ae355189c780341cf117892 (patch)
tree      4a93b5c2d359f94a6357f2c67bfceeff5b57a2d8 /eval
parent    fa6c99ac39cba9fd07ae9229995ec5cdc614ddbb (diff)
handle 'distinct' and 'single' flags using templates
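
This commit replaces runtime branching on the 'forward', 'distinct' and
'single' properties with compile-time dispatch: each property becomes a bool
template parameter, per-flag behavior is selected with requires-constrained
operator() overloads, and state that only some variants need is collapsed
with [[no_unique_address]]. A minimal stand-alone sketch of the pattern
(C++20; the names are illustrative, not taken from this commit):

    #include <cstdio>
    #include <cstddef>

    // One stateful and one empty variant, selected by a compile-time flag;
    // [[no_unique_address]] lets the empty one occupy no space in its owner.
    template <bool single> struct Step;
    template <> struct Step<false> {
        std::size_t n;
        explicit Step(std::size_t n_in) : n(n_in) {}
        std::size_t get() const { return n; }
    };
    template <> struct Step<true> {
        explicit Step(std::size_t) {}
        std::size_t get() const { return 1; }
    };

    // Flag-specific behavior via requires-constrained operator() overloads,
    // the same device this commit uses in DenseFun and SparseFun.
    template <bool distinct, bool single>
    struct Kernel {
        [[no_unique_address]] Step<single> step;
        explicit Kernel(std::size_t n) : step(n) {}
        // distinct: every output cell is written exactly once
        void operator()(double *dst, double v) const requires distinct { *dst = v; }
        // !distinct: several contributions fold into the same output cell
        void operator()(double *dst, double v) const requires (!distinct) { *dst += v; }
    };

    int main() {
        double acc = 0.0;
        Kernel<false, true> k(8); // accumulating, single-element variant
        k(&acc, 2.0);
        k(&acc, 3.0);
        std::printf("acc=%g step=%zu\n", acc, k.step.get());
        return 0;
    }

In the commit itself, DenseFun plays the role of Kernel and MyDotProduct the
role of Step; see universal_dot_product.cpp below.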
Diffstat (limited to 'eval')
-rw-r--r--  eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp    8
-rw-r--r--  eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp    152
-rw-r--r--  eval/src/vespa/eval/instruction/dense_join_reduce_plan.cpp    2
-rw-r--r--  eval/src/vespa/eval/instruction/dense_join_reduce_plan.h    5
-rw-r--r--  eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h    2
-rw-r--r--  eval/src/vespa/eval/instruction/universal_dot_product.cpp    189
6 files changed, 205 insertions(+), 153 deletions(-)
diff --git a/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp b/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp
index 9851e209ba5..7faf7d57738 100644
--- a/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp
+++ b/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp
@@ -13,7 +13,7 @@ ValueType type(const vespalib::string &type_spec) {
TEST(DenseJoinReducePlanTest, make_trivial_plan) {
auto plan = DenseJoinReducePlan(type("double"), type("double"), type("double"));
- EXPECT_TRUE(plan.distinct_result());
+ EXPECT_TRUE(plan.is_distinct());
EXPECT_EQ(plan.lhs_size, 1);
EXPECT_EQ(plan.rhs_size, 1);
EXPECT_EQ(plan.res_size, 1);
@@ -39,7 +39,7 @@ TEST(DenseJoinReducePlanTest, make_simple_plan) {
SmallVector<size_t> expect_lhs_stride = {1,0};
SmallVector<size_t> expect_rhs_stride = {0,1};
SmallVector<size_t> expect_res_stride = {1,0};
- EXPECT_FALSE(plan.distinct_result());
+ EXPECT_FALSE(plan.is_distinct());
EXPECT_EQ(plan.lhs_size, 2);
EXPECT_EQ(plan.rhs_size, 3);
EXPECT_EQ(plan.res_size, 2);
@@ -69,7 +69,7 @@ TEST(DenseJoinReducePlanTest, make_distinct_plan) {
SmallVector<size_t> expect_lhs_stride = {1,0};
SmallVector<size_t> expect_rhs_stride = {0,1};
SmallVector<size_t> expect_res_stride = {3,1};
- EXPECT_TRUE(plan.distinct_result());
+ EXPECT_TRUE(plan.is_distinct());
EXPECT_EQ(plan.lhs_size, 2);
EXPECT_EQ(plan.rhs_size, 3);
EXPECT_EQ(plan.res_size, 6);
@@ -88,7 +88,7 @@ TEST(DenseJoinReducePlanTest, make_complex_plan) {
SmallVector<size_t> expect_lhs_stride = {6,0,2,1};
SmallVector<size_t> expect_rhs_stride = {4,1,0,0};
SmallVector<size_t> expect_res_stride = {12,3,1,0};
- EXPECT_FALSE(plan.distinct_result());
+ EXPECT_FALSE(plan.is_distinct());
EXPECT_EQ(plan.lhs_size, 180);
EXPECT_EQ(plan.rhs_size, 120);
EXPECT_EQ(plan.res_size, 360);
diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
index 4ca2a5ef79a..e1967f012cb 100644
--- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
+++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
@@ -27,31 +27,7 @@ using vespalib::make_string_short::fmt;
const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
bool bench = false;
double budget = 1.0;
-
-GenSpec::seq_t N_16ths = [] (size_t i) noexcept { return (i + 33.0) / 16.0; };
-
-GenSpec G() { return GenSpec().seq(N_16ths); }
-
-const std::vector<GenSpec> layouts = {
- G(), G(),
- G().idx("x", 5), G().idx("x", 5),
- G().idx("x", 5), G().idx("y", 5),
- G().idx("x", 5), G().idx("x", 5).idx("y", 5),
- G().idx("y", 3), G().idx("x", 2).idx("z", 3),
- G().idx("x", 3).idx("y", 5), G().idx("y", 5).idx("z", 7),
- G().map("x", {"a","b","c"}), G().map("x", {"a","b","c"}),
- G().map("x", {"a","b","c"}), G().map("x", {"a","b"}),
- G().map("x", {"a","b","c"}), G().map("y", {"foo","bar","baz"}),
- G().map("x", {"a","b","c"}), G().map("x", {"a","b","c"}).map("y", {"foo","bar","baz"}),
- G().map("x", {"a","b"}).map("y", {"foo","bar","baz"}), G().map("x", {"a","b","c"}).map("y", {"foo","bar"}),
- G().map("x", {"a","b"}).map("y", {"foo","bar","baz"}), G().map("y", {"foo","bar"}).map("z", {"i","j","k","l"}),
- G().idx("x", 3).map("y", {"foo", "bar"}), G().map("y", {"foo", "bar"}).idx("z", 7),
- G().map("x", {"a","b","c"}).idx("y", 5), G().idx("y", 5).map("z", {"i","j","k","l"})
-};
-
-const std::vector<std::vector<vespalib::string>> reductions = {
- {}, {"x"}, {"y"}, {"z"}, {"x", "y"}, {"x", "z"}, {"y", "z"}
-};
+size_t verify_cnt = 0;
std::vector<std::string> ns_list = {
{"vespalib::eval::instruction::(anonymous namespace)::"},
@@ -76,14 +52,19 @@ std::string strip_ns(const vespalib::string &str) {
return tmp;
}
-TensorSpec make_spec(const vespalib::string &param_name, size_t idx) {
- return GenSpec::from_desc(param_name).cells_double().seq(N(1 + idx));
+using select_cell_type_t = std::function<CellType(size_t idx)>;
+CellType always_double(size_t) { return CellType::DOUBLE; }
+select_cell_type_t select(CellType lct) { return [lct](size_t)noexcept{ return lct; }; }
+select_cell_type_t select(CellType lct, CellType rct) { return [lct,rct](size_t idx)noexcept{ return idx ? rct : lct; }; }
+
+TensorSpec make_spec(const vespalib::string &param_name, size_t idx, select_cell_type_t select_cell_type) {
+ return GenSpec::from_desc(param_name).cells(select_cell_type(idx)).seq(N(1 + idx));
}
-TensorSpec eval_ref(const Function &fun) {
+TensorSpec eval_ref(const Function &fun, select_cell_type_t select_cell_type) {
std::vector<TensorSpec> params;
for (size_t i = 0; i < fun.num_params(); ++i) {
- params.push_back(make_spec(fun.param_name(i), i));
+ params.push_back(make_spec(fun.param_name(i), i, select_cell_type));
}
return ReferenceEvaluation::eval(fun, params);
}
@@ -134,12 +115,13 @@ Optimize universal_only() {
return Optimize::specific("universal_only", my_optimizer);
}
-void verify(const vespalib::string &expr) {
+void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
+ ++verify_cnt;
auto fun = Function::parse(expr);
ASSERT_FALSE(fun->has_error());
std::vector<Value::UP> values;
for (size_t i = 0; i < fun->num_params(); ++i) {
- auto value = value_from_spec(make_spec(fun->param_name(i), i), prod_factory);
+ auto value = value_from_spec(make_spec(fun->param_name(i), i, select_cell_type), prod_factory);
values.push_back(std::move(value));
}
SimpleObjectParams params({});
@@ -167,23 +149,24 @@ void verify(const vespalib::string &expr) {
} else {
EXPECT_EQ(actual.cells().size, actual.index().size() * expected_type.dense_subspace_size());
}
- auto expected = eval_ref(*fun);
+ auto expected = eval_ref(*fun, select_cell_type);
EXPECT_EQ(spec_from_value(actual), expected);
}
+void verify(const vespalib::string &expr) { verify(expr, always_double); }
using cost_list_t = std::vector<std::pair<vespalib::string,double>>;
std::vector<std::pair<vespalib::string,cost_list_t>> benchmark_results;
void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
+ verify(expr);
auto fun = Function::parse(expr);
ASSERT_FALSE(fun->has_error());
- auto expected = eval_ref(*fun);
cost_list_t cost_list;
fprintf(stderr, "BENCH: %s\n", expr.c_str());
for (Optimize &optimize: list) {
std::vector<Value::UP> values;
for (size_t i = 0; i < fun->num_params(); ++i) {
- auto value = value_from_spec(make_spec(fun->param_name(i), i), prod_factory);
+ auto value = value_from_spec(make_spec(fun->param_name(i), i, always_double), prod_factory);
values.push_back(std::move(value));
}
SimpleObjectParams params({});
@@ -218,8 +201,6 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
InterpretedFunction ifun(prod_factory, *optimized, &ctf_meta);
InterpretedFunction::ProfiledContext pctx(ifun);
ASSERT_EQ(ctf_meta.steps.size(), ifun.program_size());
- EXPECT_EQ(spec_from_value(ifun.eval(pctx.context, params)), expected);
- EXPECT_EQ(spec_from_value(ifun.eval(pctx, params)), expected);
std::vector<duration> prev_time(ctf_meta.steps.size(), duration::zero());
std::vector<duration> min_time(ctf_meta.steps.size(), duration::max());
BenchmarkTimer timer(budget);
@@ -251,47 +232,63 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
benchmark_results.emplace_back(expr, std::move(cost_list));
}
-TensorSpec perform_dot_product(const TensorSpec &a, const TensorSpec &b, const std::vector<vespalib::string> &dims)
-{
- Stash stash;
- auto lhs = value_from_spec(a, prod_factory);
- auto rhs = value_from_spec(b, prod_factory);
- auto res_type = ValueType::join(lhs->type(), rhs->type()).reduce(dims);
- EXPECT_FALSE(res_type.is_error());
- UniversalDotProduct dot_product(res_type,
- tensor_function::inject(lhs->type(), 0, stash),
- tensor_function::inject(rhs->type(), 1, stash));
- auto my_op = dot_product.compile_self(prod_factory, stash);
- InterpretedFunction::EvalSingle single(prod_factory, my_op);
- return spec_from_value(single.eval(std::vector<Value::CREF>({*lhs,*rhs})));
+TEST(UniversalDotProductTest, test_select_cell_types) {
+ auto always = always_double;
+ EXPECT_EQ(always(0), CellType::DOUBLE);
+ EXPECT_EQ(always(1), CellType::DOUBLE);
+ EXPECT_EQ(always(0), CellType::DOUBLE);
+ EXPECT_EQ(always(1), CellType::DOUBLE);
+ for (CellType lct: CellTypeUtils::list_types()) {
+ auto sel1 = select(lct);
+ EXPECT_EQ(sel1(0), lct);
+ EXPECT_EQ(sel1(1), lct);
+ EXPECT_EQ(sel1(0), lct);
+ EXPECT_EQ(sel1(1), lct);
+ for (CellType rct: CellTypeUtils::list_types()) {
+ auto sel2 = select(lct, rct);
+ EXPECT_EQ(sel2(0), lct);
+ EXPECT_EQ(sel2(1), rct);
+ EXPECT_EQ(sel2(0), lct);
+ EXPECT_EQ(sel2(1), rct);
+ }
+ }
}
-TEST(UniversalDotProductTest, generic_dot_product_works_for_various_cases) {
- size_t test_cases = 0;
- ASSERT_TRUE((layouts.size() % 2) == 0);
- for (size_t i = 0; i < layouts.size(); i += 2) {
- const auto &l = layouts[i];
- const auto &r = layouts[i+1];
- for (CellType lct : CellTypeUtils::list_types()) {
- auto lhs = l.cpy().cells(lct);
- if (lhs.bad_scalar()) continue;
- for (CellType rct : CellTypeUtils::list_types()) {
- auto rhs = r.cpy().cells(rct);
- if (rhs.bad_scalar()) continue;
- for (const std::vector<vespalib::string> &dims: reductions) {
- if (ValueType::join(lhs.type(), rhs.type()).reduce(dims).is_error()) continue;
- ++test_cases;
- SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str()));
- auto expect = ReferenceOperations::reduce(ReferenceOperations::join(lhs, rhs, operation::Mul::f), Aggr::SUM, dims);
- auto actual = perform_dot_product(lhs, rhs, dims);
- // fprintf(stderr, "\n===\nLHS: %s\nRHS: %s\n===\nRESULT: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str(), actual.to_string().c_str());
- EXPECT_EQ(actual, expect);
- }
- }
+TEST(UniversalDotProductTest, universal_dot_product_works_for_various_cases) {
+ // forward, distinct, single
+ verify("reduce(2.0*3.0, sum)");
+
+ for (CellType lct: CellTypeUtils::list_types()) {
+ for (CellType rct: CellTypeUtils::list_types()) {
+ auto sel2 = select(lct, rct);
+ // !forward, !distinct, !single
+ verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2);
+
+ // !forward, !distinct, single
+ verify("reduce(a4_1x8*a2_1x8,sum,a)", sel2);
+
+ // !forward, distinct, !single
+ verify("reduce(a4_1x8*a2_1x8,sum,x)", sel2);
+
+ // forward, !distinct, !single
+ verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2);
+
+ // forward, !distinct, single
+ verify("reduce(a4_1x8*b2_1x8,sum,b)", sel2);
+
+ // forward, distinct, !single
+ verify("reduce(a4_1x8*x8,sum,x)", sel2);
}
}
- EXPECT_GT(test_cases, 500);
- fprintf(stderr, "total test cases run: %zu\n", test_cases);
+ // !forward, distinct, single
+
+ // This case is not possible since 'distinct' implies '!single' as
+ // long as we reduce anything. The only expression allowed to
+ // reduce nothing is the scalar case.
+}
+
+TEST(UniversalDotProductTest, universal_dot_product_works_with_complex_dimension_nesting) {
+ verify("reduce(a4_1b4_1c4_1x4y3z2w1*a2_1c1_1x4z2,sum,b,c,x)");
}
TEST(UniversalDotProductTest, forwarding_empty_result) {
@@ -336,8 +333,11 @@ TEST(UniversalDotProductTest, bench_vector_dot_product) {
}
auto optimize_list = std::vector<Optimize>({baseline(), with_universal(), universal_only()});
- benchmark("reduce(1.0*2.0,sum)", optimize_list);
+ benchmark("reduce(2.0*3.0,sum)", optimize_list);
benchmark("reduce(5.0*x128,sum,x)", optimize_list);
+ benchmark("reduce(a1*x128,sum,x)", optimize_list);
+ benchmark("reduce(a8*x128,sum,x)", optimize_list);
+ benchmark("reduce(a1_1b8*x128,sum,x)", optimize_list);
benchmark("reduce(x16*x16,sum,x)", optimize_list);
benchmark("reduce(x768*x768,sum,x)", optimize_list);
benchmark("reduce(y64*x8y64,sum,x,y)", optimize_list);
@@ -417,5 +417,7 @@ int main(int argc, char **argv) {
--argc;
}
::testing::InitGoogleTest(&argc, argv);
- return RUN_ALL_TESTS();
+ int result = RUN_ALL_TESTS();
+ fprintf(stderr, "verify called %zu times\n", verify_cnt);
+ return result;
}
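
Taken together, the comments in universal_dot_product_works_for_various_cases
above cover seven of the eight (forward, distinct, single) combinations, and
the eighth is argued impossible. A tiny stand-alone enumeration of that
coverage, derived from the test comments rather than from the planner itself:

    #include <cstdio>

    int main() {
        // Walk all 2^3 flag combinations; the single unreachable corner is
        // (!forward, distinct, single), since 'distinct' implies '!single'
        // whenever anything is reduced (see the comment in the test above).
        for (int forward = 0; forward <= 1; ++forward) {
            for (int distinct = 0; distinct <= 1; ++distinct) {
                for (int single = 0; single <= 1; ++single) {
                    bool reachable = !(!forward && distinct && single);
                    std::printf("forward=%d distinct=%d single=%d : %s\n",
                                forward, distinct, single,
                                reachable ? "tested" : "impossible");
                }
            }
        }
        return 0;
    }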
diff --git a/eval/src/vespa/eval/instruction/dense_join_reduce_plan.cpp b/eval/src/vespa/eval/instruction/dense_join_reduce_plan.cpp
index 20b7d3364a8..8d09abbfe15 100644
--- a/eval/src/vespa/eval/instruction/dense_join_reduce_plan.cpp
+++ b/eval/src/vespa/eval/instruction/dense_join_reduce_plan.cpp
@@ -82,7 +82,7 @@ DenseJoinReducePlan::DenseJoinReducePlan(const ValueType &lhs, const ValueType &
DenseJoinReducePlan::~DenseJoinReducePlan() = default;
bool
-DenseJoinReducePlan::distinct_result() const
+DenseJoinReducePlan::is_distinct() const
{
for (size_t stride: res_stride) {
if (stride == 0) {
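
The hunk above is cut off inside is_distinct(); its logic is that a result
stride of 0 means the nested loop maps several (lhs, rhs) cell pairs onto the
same output cell, so contributions must be accumulated, while all-nonzero
strides mean each output cell is produced exactly once (which is what lets
DenseFun write with *dst++). A hedged sketch of the complete predicate:

    #include <cstddef>
    #include <vector>

    // Hedged sketch of the full predicate; only the stride == 0 test is
    // visible in the hunk above, the rest follows from its meaning.
    bool is_distinct_sketch(const std::vector<std::size_t> &res_stride) {
        for (std::size_t stride : res_stride) {
            if (stride == 0) {
                return false; // overlapping writes -> must accumulate
            }
        }
        return true; // every output cell produced exactly once
    }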
diff --git a/eval/src/vespa/eval/instruction/dense_join_reduce_plan.h b/eval/src/vespa/eval/instruction/dense_join_reduce_plan.h
index 8f9d5218630..3cf55e9ace4 100644
--- a/eval/src/vespa/eval/instruction/dense_join_reduce_plan.h
+++ b/eval/src/vespa/eval/instruction/dense_join_reduce_plan.h
@@ -21,7 +21,10 @@ struct DenseJoinReducePlan {
template <typename F> void execute(size_t lhs, size_t rhs, size_t res, const F &f) const {
run_nested_loop(lhs, rhs, res, loop_cnt, lhs_stride, rhs_stride, res_stride, f);
}
- bool distinct_result() const;
+ template <typename F> void execute_distinct(size_t lhs, size_t rhs, const F &f) const {
+ run_nested_loop(lhs, rhs, loop_cnt, lhs_stride, rhs_stride, f);
+ }
+ bool is_distinct() const;
};
} // namespace
diff --git a/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h b/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h
index 75b8d329763..7176e6ea6e9 100644
--- a/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h
+++ b/eval/src/vespa/eval/instruction/sparse_join_reduce_plan.h
@@ -67,7 +67,7 @@ public:
SparseJoinReducePlan(const ValueType &lhs, const ValueType &rhs, const ValueType &res);
~SparseJoinReducePlan();
size_t res_dims() const { return _res_dims; }
- bool distinct_result() const { return _res_dims == _in_res.size(); }
+ bool is_distinct() const { return _res_dims == _in_res.size(); }
bool maybe_forward_lhs_index() const;
bool maybe_forward_rhs_index() const;
size_t estimate_result_size(const Value::Index &lhs, const Value::Index &rhs) const {
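
On the sparse side, is_distinct() holds when the result keeps every mapped
dimension produced by the join (_res_dims == _in_res.size()), so no two
address pairs collide on one result address. That is the property that lets
the distinct path call add_subspace (append, write once) where the general
path needs insert_subspace (look up, zero on first touch, then accumulate).
A toy model of that distinction, assuming a map-backed index (the real
FastValue API differs; only the add/insert split is the point here):

    #include <map>
    #include <string>
    #include <utility>

    // Toy stand-in for the result value, keyed by a string address.
    struct ToyResult {
        std::map<std::string, double> cells;
        // distinct path: address known to be fresh, append and write once
        double *add_subspace(const std::string &addr) { return &cells[addr]; }
        // general path: address may repeat, zero on first touch, accumulate
        std::pair<double *, bool> insert_subspace(const std::string &addr) {
            auto [it, first] = cells.emplace(addr, 0.0);
            return {&it->second, first};
        }
    };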
diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.cpp b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
index 19b839a4fd3..414a54f09a8 100644
--- a/eval/src/vespa/eval/instruction/universal_dot_product.cpp
+++ b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
@@ -42,42 +42,6 @@ struct UniversalDotProductParam {
}
};
-template <typename LCT, typename RCT, typename OCT>
-void my_universal_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) {
- using dot_product = DotProduct<LCT,RCT>;
- const auto &param = unwrap_param<UniversalDotProductParam>(param_in);
- const auto &lhs = state.peek(1);
- const auto &rhs = state.peek(0);
- const auto &lhs_index = lhs.index();
- const auto &rhs_index = rhs.index();
- const auto lhs_cells = lhs.cells().typify<LCT>();
- const auto rhs_cells = rhs.cells().typify<RCT>();
- auto &stored_result = state.stash.create<std::unique_ptr<FastValue<OCT,true>>>(
- std::make_unique<FastValue<OCT,true>>(param.res_type, param.sparse_plan.res_dims(), param.dense_plan.res_size,
- param.sparse_plan.estimate_result_size(lhs_index, rhs_index)));
- auto &result = *(stored_result.get());
- OCT *dst;
- auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) {
- dst[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size);
- };
- auto sparse_fun = [&](size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef<string_id> res_addr) {
- auto [space, first] = result.insert_subspace(res_addr);
- if (first) {
- std::fill(space.begin(), space.end(), OCT{});
- }
- dst = space.data();
- param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size,
- rhs_subspace * param.dense_plan.rhs_size,
- 0, dense_fun);
- };
- param.sparse_plan.execute(lhs_index, rhs_index, sparse_fun);
- if (result.my_index.map.size() == 0 && param.sparse_plan.res_dims() == 0) {
- auto [empty, ignore] = result.insert_subspace({});
- std::fill(empty.begin(), empty.end(), OCT{});
- }
- state.pop_pop_push(result);
-}
-
template <typename OCT>
const Value &create_empty_result(const UniversalDotProductParam &param, Stash &stash) {
if (param.sparse_plan.res_dims() == 0) {
@@ -88,55 +52,136 @@ const Value &create_empty_result(const UniversalDotProductParam &param, Stash &s
}
}
-template <typename LCT, typename RCT, typename OCT>
-void my_universal_forwarding_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) {
- using dot_product = DotProduct<LCT,RCT>;
- const auto &param = unwrap_param<UniversalDotProductParam>(param_in);
- const auto &lhs = state.peek(1);
- const auto &rhs = state.peek(0);
- size_t lhs_index_size = lhs.index().size();
- size_t rhs_index_size = rhs.index().size();
- if (rhs_index_size == 0 || lhs_index_size == 0) {
- state.pop_pop_push(create_empty_result<OCT>(param, state.stash));
- return;
+template <typename LCT, typename RCT, bool single> struct MyDotProduct;
+template <typename LCT, typename RCT> struct MyDotProduct<LCT, RCT, false> {
+ size_t vector_size;
+ MyDotProduct(size_t vector_size_in) : vector_size(vector_size_in) {}
+ auto operator()(const LCT *lhs, const RCT *rhs) const {
+ return DotProduct<LCT,RCT>::apply(lhs, rhs, vector_size);
+ }
+};
+template <typename LCT, typename RCT> struct MyDotProduct<LCT, RCT, true> {
+ MyDotProduct(size_t) {}
+ auto operator()(const LCT *lhs, const RCT *rhs) const {
+ return (*lhs) * (*rhs);
+ }
+};
+
+template <typename LCT, typename RCT, typename OCT, bool distinct, bool single>
+struct DenseFun {
+ [[no_unique_address]] MyDotProduct<LCT,RCT,single> dot_product;
+ const LCT *lhs;
+ const RCT *rhs;
+ mutable OCT *dst;
+ DenseFun(size_t vector_size_in, const Value &lhs_in, const Value &rhs_in)
+ : dot_product(vector_size_in),
+ lhs(lhs_in.cells().typify<LCT>().data()),
+ rhs(rhs_in.cells().typify<RCT>().data()) {}
+ void operator()(size_t lhs_idx, size_t rhs_idx) const requires distinct {
+ *dst++ = dot_product(lhs + lhs_idx, rhs + rhs_idx);
}
- const auto lhs_cells = lhs.cells().typify<LCT>();
- const auto rhs_cells = rhs.cells().typify<RCT>();
- auto dst_cells = state.stash.create_array<OCT>(lhs_index_size * param.dense_plan.res_size);
- OCT *dst = dst_cells.data();
- auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) {
- dst[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size);
- };
- for (size_t lhs_subspace = 0; lhs_subspace < lhs_index_size; ++lhs_subspace) {
- for (size_t rhs_subspace = 0; rhs_subspace < rhs_index_size; ++rhs_subspace) {
- param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size,
- rhs_subspace * param.dense_plan.rhs_size,
- lhs_subspace * param.dense_plan.res_size, dense_fun);
+ void operator()(size_t lhs_idx, size_t rhs_idx, size_t dst_idx) const requires (!distinct) {
+ dst[dst_idx] += dot_product(lhs + lhs_idx, rhs + rhs_idx);
+ }
+};
+
+template <typename OCT, bool forward> struct Result {};
+template <typename OCT> struct Result<OCT, false> {
+ mutable FastValue<OCT,true> *fast;
+};
+
+template <typename LCT, typename RCT, typename OCT, bool forward, bool distinct, bool single>
+struct SparseFun {
+ const UniversalDotProductParam &param;
+ DenseFun<LCT,RCT,OCT,distinct,single> dense_fun;
+ [[no_unique_address]] Result<OCT, forward> result;
+ SparseFun(uint64_t param_in, const Value &lhs_in, const Value &rhs_in)
+ : param(unwrap_param<UniversalDotProductParam>(param_in)),
+ dense_fun(param.vector_size, lhs_in, rhs_in),
+ result() {}
+ void operator()(size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef<string_id> res_addr) const requires (!forward && !distinct) {
+ auto [space, first] = result.fast->insert_subspace(res_addr);
+ if (first) {
+ std::fill(space.begin(), space.end(), OCT{});
+ }
+ dense_fun.dst = space.data();
+ param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size,
+ rhs_subspace * param.dense_plan.rhs_size,
+ 0, dense_fun);
+ };
+ void operator()(size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef<string_id> res_addr) const requires (!forward && distinct) {
+ dense_fun.dst = result.fast->add_subspace(res_addr).data();
+ param.dense_plan.execute_distinct(lhs_subspace * param.dense_plan.lhs_size,
+ rhs_subspace * param.dense_plan.rhs_size,
+ dense_fun);
+ };
+ void operator()(size_t lhs_subspace, size_t rhs_subspace) const requires (forward && !distinct) {
+ param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size,
+ rhs_subspace * param.dense_plan.rhs_size,
+ lhs_subspace * param.dense_plan.res_size, dense_fun);
+ };
+ void operator()(size_t lhs_subspace, size_t rhs_subspace) const requires (forward && distinct) {
+ param.dense_plan.execute_distinct(lhs_subspace * param.dense_plan.lhs_size,
+ rhs_subspace * param.dense_plan.rhs_size, dense_fun);
+ };
+ const Value &calculate_result(const Value::Index &lhs, const Value::Index &rhs, Stash &stash) const requires (!forward) {
+ auto &stored_result = stash.create<std::unique_ptr<FastValue<OCT,true>>>(
+ std::make_unique<FastValue<OCT,true>>(param.res_type, param.sparse_plan.res_dims(), param.dense_plan.res_size,
+ param.sparse_plan.estimate_result_size(lhs, rhs)));
+ result.fast = stored_result.get();
+ param.sparse_plan.execute(lhs, rhs, *this);
+ if (result.fast->my_index.map.size() == 0 && param.sparse_plan.res_dims() == 0) {
+ auto empty = result.fast->add_subspace(ConstArrayRef<string_id>());
+ std::fill(empty.begin(), empty.end(), OCT{});
}
+ return *(result.fast);
}
- const Value &result = state.stash.create<ValueView>(param.res_type, lhs.index(), TypedCells(dst_cells));
- state.pop_pop_push(result);
+ const Value &calculate_result(const Value::Index &lhs, const Value::Index &rhs, Stash &stash) const requires forward {
+ size_t lhs_size = lhs.size();
+ size_t rhs_size = rhs.size();
+ if (lhs_size == 0 || rhs_size == 0) {
+ return create_empty_result<OCT>(param, stash);
+ }
+ auto dst_cells = (distinct)
+ ? stash.create_uninitialized_array<OCT>(lhs_size * param.dense_plan.res_size)
+ : stash.create_array<OCT>(lhs_size * param.dense_plan.res_size);
+ dense_fun.dst = dst_cells.data();
+ for (size_t lhs_idx = 0; lhs_idx < lhs_size; ++lhs_idx) {
+ for (size_t rhs_idx = 0; rhs_idx < rhs_size; ++rhs_idx) {
+ (*this)(lhs_idx, rhs_idx);
+ }
+ }
+ return stash.create<ValueView>(param.res_type, lhs, TypedCells(dst_cells));
+ }
+};
+
+template <typename LCT, typename RCT, typename OCT, bool forward, bool distinct, bool single>
+void my_universal_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) {
+ SparseFun<LCT,RCT,OCT,forward,distinct,single> sparse_fun(param_in, state.peek(1), state.peek(0));
+ state.pop_pop_push(sparse_fun.calculate_result(state.peek(1).index(), state.peek(0).index(), state.stash));
}
struct SelectUniversalDotProduct {
- template <typename LCM, typename RCM, typename SCALAR> static auto invoke(const UniversalDotProductParam &param) {
+ template <typename LCM, typename RCM, typename SCALAR, typename FORWARD, typename DISTINCT, typename SINGLE>
+ static auto invoke() {
constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value).reduce(SCALAR::value);
using LCT = CellValueType<LCM::value.cell_type>;
using RCT = CellValueType<RCM::value.cell_type>;
using OCT = CellValueType<ocm.cell_type>;
- if (param.sparse_plan.maybe_forward_lhs_index()) {
- return my_universal_forwarding_dot_product_op<LCT,RCT,OCT>;
+ if constexpr ((std::same_as<LCT,float> && std::same_as<RCT,float>) ||
+ (std::same_as<LCT,double> && std::same_as<RCT,double>))
+ {
+ return my_universal_dot_product_op<LCT,RCT,OCT,FORWARD::value,DISTINCT::value,SINGLE::value>;
}
- return my_universal_dot_product_op<LCT,RCT,OCT>;
+ return my_universal_dot_product_op<LCT,RCT,OCT,FORWARD::value,false,false>;
}
};
-bool check_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) {
- (void) res;
+bool check_types(const ValueType &lhs, const ValueType &rhs) {
if (lhs.is_double() || rhs.is_double()) {
return false;
}
- if (lhs.count_mapped_dimensions() > 0 || rhs.count_mapped_dimensions() > 0) {
+ if (lhs.count_mapped_dimensions() > 0 && rhs.count_mapped_dimensions() > 0) {
return true;
}
return false;
@@ -156,10 +201,12 @@ UniversalDotProduct::compile_self(const ValueBuilderFactory &, Stash &stash) con
{
auto &param = stash.create<UniversalDotProductParam>(result_type(), lhs().result_type(), rhs().result_type());
using MyTypify = TypifyValue<TypifyCellMeta,TypifyBool>;
- auto op = typify_invoke<3,MyTypify,SelectUniversalDotProduct>(lhs().result_type().cell_meta(),
+ auto op = typify_invoke<6,MyTypify,SelectUniversalDotProduct>(lhs().result_type().cell_meta(),
rhs().result_type().cell_meta(),
result_type().cell_meta().is_scalar,
- param);
+ param.sparse_plan.maybe_forward_lhs_index(),
+ param.sparse_plan.is_distinct() && param.dense_plan.is_distinct(),
+ param.vector_size == 1);
return InterpretedFunction::Instruction(op, wrap_param<UniversalDotProductParam>(param));
}
@@ -171,7 +218,7 @@ UniversalDotProduct::optimize(const TensorFunction &expr, Stash &stash, bool for
const ValueType &res_type = expr.result_type();
const ValueType &lhs_type = join->lhs().result_type();
const ValueType &rhs_type = join->rhs().result_type();
- if (force || check_types(res_type, lhs_type, rhs_type)) {
+ if (force || check_types(lhs_type, rhs_type)) {
SparseJoinReducePlan sparse_plan(lhs_type, rhs_type, res_type);
if (sparse_plan.maybe_forward_rhs_index() && !sparse_plan.maybe_forward_lhs_index()) {
return stash.create<UniversalDotProduct>(res_type, join->rhs(), join->lhs());
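
The selection above expands typify_invoke over six compile-time dimensions
(two cell metas, scalar-ness, and the three new flags). To keep the number of
instantiations in check, SelectUniversalDotProduct only emits flag-specialized
kernels for homogeneous float or double inputs; every other cell-type mix
falls back to the generic <..., false, false> kernel. A small stand-alone
mirror of that guard:

    #include <concepts>

    // Mirrors the guard in SelectUniversalDotProduct: flag-specialized
    // kernels are only instantiated for homogeneous float/double inputs;
    // all other cell-type combinations share the generic kernel, keeping
    // the number of template instantiations bounded.
    template <typename LCT, typename RCT>
    constexpr bool specialize_flags =
        (std::same_as<LCT, float> && std::same_as<RCT, float>) ||
        (std::same_as<LCT, double> && std::same_as<RCT, double>);

    static_assert(specialize_flags<double, double>);
    static_assert(!specialize_flags<float, double>);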