From b1ba5c2995b76cbee5830a9959cd7b22364e5b65 Mon Sep 17 00:00:00 2001 From: HÃ¥vard Pettersen Date: Wed, 6 Sep 2023 12:59:04 +0000 Subject: handle expanding reduce more testing of corner cases --- .../universal_dot_product_test.cpp | 72 +++++++++++++++++++++- .../eval/instruction/universal_dot_product.cpp | 36 ++++++++--- 2 files changed, 97 insertions(+), 11 deletions(-) diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp index 6c0726dab37..a25dba64671 100644 --- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp +++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp @@ -134,6 +134,43 @@ Optimize universal_only() { return Optimize::specific("universal_only", my_optimizer); } +void verify(const vespalib::string &expr) { + auto fun = Function::parse(expr); + ASSERT_FALSE(fun->has_error()); + std::vector values; + for (size_t i = 0; i < fun->num_params(); ++i) { + auto value = value_from_spec(make_spec(fun->param_name(i), i), prod_factory); + values.push_back(std::move(value)); + } + SimpleObjectParams params({}); + std::vector param_types; + for (auto &&up: values) { + params.params.emplace_back(*up); + param_types.push_back(up->type()); + } + NodeTypes node_types(*fun, param_types); + const ValueType &expected_type = node_types.get_type(fun->root()); + ASSERT_FALSE(expected_type.is_error()); + Stash stash; + size_t count = 0; + const TensorFunction &plain_fun = make_tensor_function(prod_factory, fun->root(), node_types, stash); + const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash, &count); + ASSERT_GT(count, 0); + InterpretedFunction ifun(prod_factory, optimized); + InterpretedFunction::Context ctx(ifun); + const Value &actual = ifun.eval(ctx, params); + EXPECT_EQ(actual.type(), expected_type); + EXPECT_EQ(actual.cells().type, expected_type.cell_type()); + if (expected_type.count_mapped_dimensions() == 0) { + EXPECT_EQ(actual.index().size(), TrivialIndex::get().size()); + EXPECT_EQ(actual.cells().size, expected_type.dense_subspace_size()); + } else { + EXPECT_EQ(actual.cells().size, actual.index().size() * expected_type.dense_subspace_size()); + } + auto expected = eval_ref(*fun); + EXPECT_EQ(spec_from_value(actual), expected); +} + using cost_list_t = std::vector>; std::vector> benchmark_results; @@ -257,6 +294,39 @@ TEST(UniversalDotProductTest, generic_dot_product_works_for_various_cases) { fprintf(stderr, "total test cases run: %zu\n", test_cases); } +TEST(UniversalDotProductTest, forwarding_empty_result) { + verify("reduce(x0_0*y8_1,sum,y)"); + verify("reduce(x8_1*y0_0,sum,y)"); + verify("reduce(x0_0z16*y8_1z16,sum,y)"); + verify("reduce(x8_1z16*y0_0z16,sum,y)"); +} + +TEST(UniversalDotProductTest, nonforwarding_empty_result) { + verify("reduce(x0_0y8*x1_1y8,sum,y)"); + verify("reduce(x1_1y8*x0_0y8,sum,y)"); + verify("reduce(x1_7y8z2*x1_1y8z2,sum,y)"); +} + +TEST(UniversalDotProductTest, forwarding_expanding_reduce) { + verify("reduce(5.0*y0_0,sum,y)"); + verify("reduce(z16*y0_0,sum,y)"); + verify("reduce(x1_1*y0_0,sum,y)"); + verify("reduce(x0_0*y1_1,sum,y)"); + verify("reduce(x1_1z16*y0_0,sum,y)"); + verify("reduce(x0_0z16*y1_1,sum,y)"); +} + +TEST(UniversalDotProductTest, nonforwarding_expanding_reduce) { + verify("reduce(x0_0*y1_1,sum,x,y)"); + verify("reduce(x1_1*y0_0,sum,x,y)"); + verify("reduce(x0_0y16*x1_1y16,sum,x)"); + verify("reduce(x1_1y16*x0_0y16,sum,x)"); + verify("reduce(x1_7*y1_1,sum,x,y)"); + verify("reduce(x1_1*y1_7,sum,x,y)"); + verify("reduce(x1_7y16*x1_1y16,sum,x)"); + verify("reduce(x1_1y16*x1_7y16,sum,x)"); +} + TEST(UniversalDotProductTest, bench_vector_dot_product) { if (!bench) { fprintf(stderr, "benchmarking disabled, run with 'bench' parameter to enable\n"); @@ -284,8 +354,6 @@ TEST(UniversalDotProductTest, bench_vector_dot_product) { benchmark("reduce(b64_1x8y128*x8y128,sum,y)", optimize_list); benchmark("reduce(b64_1x128*x128,sum,b,x)", optimize_list); benchmark("reduce(a1_1x128*a2_1b64_1x128,sum,a,x)", optimize_list); - benchmark("reduce(x0_0*y8_1,sum,y)", optimize_list); - benchmark("reduce(x8_1*y0_0,sum,y)", optimize_list); size_t max_expr_size = 0; for (const auto &[expr, cost_list]: benchmark_results) { diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.cpp b/eval/src/vespa/eval/instruction/universal_dot_product.cpp index 3811508a543..209645743fa 100644 --- a/eval/src/vespa/eval/instruction/universal_dot_product.cpp +++ b/eval/src/vespa/eval/instruction/universal_dot_product.cpp @@ -56,26 +56,44 @@ void my_universal_dot_product_op(InterpretedFunction::State &state, uint64_t par std::make_unique>(param.res_type, param.sparse_plan.res_dims(), param.dense_plan.res_size, param.sparse_plan.estimate_result_size(lhs_index, rhs_index))); auto &result = *(stored_result.get()); - ArrayRef dst; + OCT *dst; auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) { dst[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size); }; auto sparse_fun = [&](size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef res_addr) { - bool first; - std::tie(dst, first) = result.insert_subspace(res_addr); + auto [space, first] = result.insert_subspace(res_addr); if (first) { - std::fill(dst.begin(), dst.end(), OCT{}); + std::fill(space.begin(), space.end(), OCT{}); } + dst = space.data(); param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size, rhs_subspace * param.dense_plan.rhs_size, 0, dense_fun); }; param.sparse_plan.execute(lhs_index, rhs_index, sparse_fun); + if (result.my_index.map.size() == 0 && param.sparse_plan.res_dims() == 0) { + auto [empty, ignore] = result.insert_subspace({}); + std::fill(empty.begin(), empty.end(), OCT{}); + } state.pop_pop_push(result); } +template +const Value &create_empty_result(const UniversalDotProductParam ¶m, Stash &stash) { + if (param.sparse_plan.res_dims() == 0) { + if (param.dense_plan.res_stride.empty()) { + return stash.create(0.0); + } else { + auto zero_cells = stash.create_array(param.dense_plan.res_size); + return stash.create(param.res_type, TrivialIndex::get(), TypedCells(zero_cells)); + } + } else { + return stash.create(param.res_type, EmptyIndex::get(), TypedCells(nullptr, get_cell_type(), 0)); + } +} + template -void my_universal_dense_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) { +void my_universal_forwarding_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) { using dot_product = DotProduct; const auto ¶m = unwrap_param(param_in); const auto &lhs = state.peek(1); @@ -83,15 +101,15 @@ void my_universal_dense_dot_product_op(InterpretedFunction::State &state, uint64 size_t lhs_index_size = lhs.index().size(); size_t rhs_index_size = rhs.index().size(); if (rhs_index_size == 0 || lhs_index_size == 0) { - const Value &empty = state.stash.create(param.res_type, EmptyIndex::get(), TypedCells(nullptr, get_cell_type(), 0)); - state.pop_pop_push(empty); + state.pop_pop_push(create_empty_result(param, state.stash)); return; } const auto lhs_cells = lhs.cells().typify(); const auto rhs_cells = rhs.cells().typify(); auto dst_cells = state.stash.create_array(lhs_index_size * param.dense_plan.res_size); + OCT *dst = dst_cells.data(); auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) { - dst_cells[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size); + dst[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size); }; for (size_t lhs_subspace = 0; lhs_subspace < lhs_index_size; ++lhs_subspace) { for (size_t rhs_subspace = 0; rhs_subspace < rhs_index_size; ++rhs_subspace) { @@ -111,7 +129,7 @@ struct SelectUniversalDotProduct { using RCT = CellValueType; using OCT = CellValueType; if (param.sparse_plan.maybe_forward_lhs_index()) { - return my_universal_dense_dot_product_op; + return my_universal_forwarding_dot_product_op; } return my_universal_dot_product_op; } -- cgit v1.2.3