diff options
author | Håvard Pettersen <havardpe@yahooinc.com> | 2023-09-06 12:59:04 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@yahooinc.com> | 2023-09-07 07:57:52 +0000 |
commit | b1ba5c2995b76cbee5830a9959cd7b22364e5b65 (patch) | |
tree | adda7458010cb86d8537126cfb07c3705f8f233e /eval | |
parent | daba552c567f1fcb9e300ae65825c1d97cedbb5e (diff) |
handle expanding reduce
more testing of corner cases
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp | 72 | ||||
-rw-r--r-- | eval/src/vespa/eval/instruction/universal_dot_product.cpp | 36 |
2 files changed, 97 insertions, 11 deletions
diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp index 6c0726dab37..a25dba64671 100644 --- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp +++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp @@ -134,6 +134,43 @@ Optimize universal_only() { return Optimize::specific("universal_only", my_optimizer); } +void verify(const vespalib::string &expr) { + auto fun = Function::parse(expr); + ASSERT_FALSE(fun->has_error()); + std::vector<Value::UP> values; + for (size_t i = 0; i < fun->num_params(); ++i) { + auto value = value_from_spec(make_spec(fun->param_name(i), i), prod_factory); + values.push_back(std::move(value)); + } + SimpleObjectParams params({}); + std::vector<ValueType> param_types; + for (auto &&up: values) { + params.params.emplace_back(*up); + param_types.push_back(up->type()); + } + NodeTypes node_types(*fun, param_types); + const ValueType &expected_type = node_types.get_type(fun->root()); + ASSERT_FALSE(expected_type.is_error()); + Stash stash; + size_t count = 0; + const TensorFunction &plain_fun = make_tensor_function(prod_factory, fun->root(), node_types, stash); + const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash, &count); + ASSERT_GT(count, 0); + InterpretedFunction ifun(prod_factory, optimized); + InterpretedFunction::Context ctx(ifun); + const Value &actual = ifun.eval(ctx, params); + EXPECT_EQ(actual.type(), expected_type); + EXPECT_EQ(actual.cells().type, expected_type.cell_type()); + if (expected_type.count_mapped_dimensions() == 0) { + EXPECT_EQ(actual.index().size(), TrivialIndex::get().size()); + EXPECT_EQ(actual.cells().size, expected_type.dense_subspace_size()); + } else { + EXPECT_EQ(actual.cells().size, actual.index().size() * expected_type.dense_subspace_size()); + } + auto expected = eval_ref(*fun); + EXPECT_EQ(spec_from_value(actual), expected); +} + using cost_list_t = std::vector<std::pair<vespalib::string,double>>; std::vector<std::pair<vespalib::string,cost_list_t>> benchmark_results; @@ -257,6 +294,39 @@ TEST(UniversalDotProductTest, generic_dot_product_works_for_various_cases) { fprintf(stderr, "total test cases run: %zu\n", test_cases); } +TEST(UniversalDotProductTest, forwarding_empty_result) { + verify("reduce(x0_0*y8_1,sum,y)"); + verify("reduce(x8_1*y0_0,sum,y)"); + verify("reduce(x0_0z16*y8_1z16,sum,y)"); + verify("reduce(x8_1z16*y0_0z16,sum,y)"); +} + +TEST(UniversalDotProductTest, nonforwarding_empty_result) { + verify("reduce(x0_0y8*x1_1y8,sum,y)"); + verify("reduce(x1_1y8*x0_0y8,sum,y)"); + verify("reduce(x1_7y8z2*x1_1y8z2,sum,y)"); +} + +TEST(UniversalDotProductTest, forwarding_expanding_reduce) { + verify("reduce(5.0*y0_0,sum,y)"); + verify("reduce(z16*y0_0,sum,y)"); + verify("reduce(x1_1*y0_0,sum,y)"); + verify("reduce(x0_0*y1_1,sum,y)"); + verify("reduce(x1_1z16*y0_0,sum,y)"); + verify("reduce(x0_0z16*y1_1,sum,y)"); +} + +TEST(UniversalDotProductTest, nonforwarding_expanding_reduce) { + verify("reduce(x0_0*y1_1,sum,x,y)"); + verify("reduce(x1_1*y0_0,sum,x,y)"); + verify("reduce(x0_0y16*x1_1y16,sum,x)"); + verify("reduce(x1_1y16*x0_0y16,sum,x)"); + verify("reduce(x1_7*y1_1,sum,x,y)"); + verify("reduce(x1_1*y1_7,sum,x,y)"); + verify("reduce(x1_7y16*x1_1y16,sum,x)"); + verify("reduce(x1_1y16*x1_7y16,sum,x)"); +} + TEST(UniversalDotProductTest, bench_vector_dot_product) { if (!bench) { fprintf(stderr, "benchmarking disabled, run with 'bench' parameter to enable\n"); @@ -284,8 +354,6 @@ TEST(UniversalDotProductTest, bench_vector_dot_product) { benchmark("reduce(b64_1x8y128*x8y128,sum,y)", optimize_list); benchmark("reduce(b64_1x128*x128,sum,b,x)", optimize_list); benchmark("reduce(a1_1x128*a2_1b64_1x128,sum,a,x)", optimize_list); - benchmark("reduce(x0_0*y8_1,sum,y)", optimize_list); - benchmark("reduce(x8_1*y0_0,sum,y)", optimize_list); size_t max_expr_size = 0; for (const auto &[expr, cost_list]: benchmark_results) { diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.cpp b/eval/src/vespa/eval/instruction/universal_dot_product.cpp index 3811508a543..209645743fa 100644 --- a/eval/src/vespa/eval/instruction/universal_dot_product.cpp +++ b/eval/src/vespa/eval/instruction/universal_dot_product.cpp @@ -56,26 +56,44 @@ void my_universal_dot_product_op(InterpretedFunction::State &state, uint64_t par std::make_unique<FastValue<OCT,true>>(param.res_type, param.sparse_plan.res_dims(), param.dense_plan.res_size, param.sparse_plan.estimate_result_size(lhs_index, rhs_index))); auto &result = *(stored_result.get()); - ArrayRef<OCT> dst; + OCT *dst; auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) { dst[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size); }; auto sparse_fun = [&](size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef<string_id> res_addr) { - bool first; - std::tie(dst, first) = result.insert_subspace(res_addr); + auto [space, first] = result.insert_subspace(res_addr); if (first) { - std::fill(dst.begin(), dst.end(), OCT{}); + std::fill(space.begin(), space.end(), OCT{}); } + dst = space.data(); param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size, rhs_subspace * param.dense_plan.rhs_size, 0, dense_fun); }; param.sparse_plan.execute(lhs_index, rhs_index, sparse_fun); + if (result.my_index.map.size() == 0 && param.sparse_plan.res_dims() == 0) { + auto [empty, ignore] = result.insert_subspace({}); + std::fill(empty.begin(), empty.end(), OCT{}); + } state.pop_pop_push(result); } +template <typename OCT> +const Value &create_empty_result(const UniversalDotProductParam ¶m, Stash &stash) { + if (param.sparse_plan.res_dims() == 0) { + if (param.dense_plan.res_stride.empty()) { + return stash.create<DoubleValue>(0.0); + } else { + auto zero_cells = stash.create_array<OCT>(param.dense_plan.res_size); + return stash.create<ValueView>(param.res_type, TrivialIndex::get(), TypedCells(zero_cells)); + } + } else { + return stash.create<ValueView>(param.res_type, EmptyIndex::get(), TypedCells(nullptr, get_cell_type<OCT>(), 0)); + } +} + template <typename LCT, typename RCT, typename OCT> -void my_universal_dense_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) { +void my_universal_forwarding_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) { using dot_product = DotProduct<LCT,RCT>; const auto ¶m = unwrap_param<UniversalDotProductParam>(param_in); const auto &lhs = state.peek(1); @@ -83,15 +101,15 @@ void my_universal_dense_dot_product_op(InterpretedFunction::State &state, uint64 size_t lhs_index_size = lhs.index().size(); size_t rhs_index_size = rhs.index().size(); if (rhs_index_size == 0 || lhs_index_size == 0) { - const Value &empty = state.stash.create<ValueView>(param.res_type, EmptyIndex::get(), TypedCells(nullptr, get_cell_type<OCT>(), 0)); - state.pop_pop_push(empty); + state.pop_pop_push(create_empty_result<OCT>(param, state.stash)); return; } const auto lhs_cells = lhs.cells().typify<LCT>(); const auto rhs_cells = rhs.cells().typify<RCT>(); auto dst_cells = state.stash.create_array<OCT>(lhs_index_size * param.dense_plan.res_size); + OCT *dst = dst_cells.data(); auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) { - dst_cells[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size); + dst[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size); }; for (size_t lhs_subspace = 0; lhs_subspace < lhs_index_size; ++lhs_subspace) { for (size_t rhs_subspace = 0; rhs_subspace < rhs_index_size; ++rhs_subspace) { @@ -111,7 +129,7 @@ struct SelectUniversalDotProduct { using RCT = CellValueType<RCM::value.cell_type>; using OCT = CellValueType<ocm.cell_type>; if (param.sparse_plan.maybe_forward_lhs_index()) { - return my_universal_dense_dot_product_op<LCT,RCT,OCT>; + return my_universal_forwarding_dot_product_op<LCT,RCT,OCT>; } return my_universal_dot_product_op<LCT,RCT,OCT>; } |