summary refs log tree commit diff stats
path: root/eval
diff options
context:
space:
mode:
author Håvard Pettersen <havardpe@yahooinc.com> 2023-09-06 12:59:04 +0000
committer Håvard Pettersen <havardpe@yahooinc.com> 2023-09-07 07:57:52 +0000
commit b1ba5c2995b76cbee5830a9959cd7b22364e5b65 (patch)
tree adda7458010cb86d8537126cfb07c3705f8f233e /eval
parent daba552c567f1fcb9e300ae65825c1d97cedbb5e (diff)
handle expanding reduce
more testing of corner cases
Diffstat (limited to 'eval')
-rw-r--r-- eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp | 72
-rw-r--r-- eval/src/vespa/eval/instruction/universal_dot_product.cpp | 36
2 files changed, 97 insertions, 11 deletions
diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
index 6c0726dab37..a25dba64671 100644
--- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
+++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
@@ -134,6 +134,43 @@ Optimize universal_only() {
return Optimize::specific("universal_only", my_optimizer);
}
+void verify(const vespalib::string &expr) { // Evaluates 'expr' through the universal_only() optimizer and checks type, cell layout and content against eval_ref().
+ auto fun = Function::parse(expr);
+ ASSERT_FALSE(fun->has_error());
+ std::vector<Value::UP> values;
+ for (size_t i = 0; i < fun->num_params(); ++i) { // one input value per parameter; make_spec() is given the parameter name and its index
+ auto value = value_from_spec(make_spec(fun->param_name(i), i), prod_factory);
+ values.push_back(std::move(value));
+ }
+ SimpleObjectParams params({});
+ std::vector<ValueType> param_types;
+ for (auto &&up: values) { // expose the generated values as evaluation parameters and collect their types
+ params.params.emplace_back(*up);
+ param_types.push_back(up->type());
+ }
+ NodeTypes node_types(*fun, param_types);
+ const ValueType &expected_type = node_types.get_type(fun->root()); // type resolved for the root node is the type the result must have
+ ASSERT_FALSE(expected_type.is_error());
+ Stash stash;
+ size_t count = 0;
+ const TensorFunction &plain_fun = make_tensor_function(prod_factory, fun->root(), node_types, stash);
+ const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash, &count);
+ ASSERT_GT(count, 0); // NOTE(review): 'count' presumably reports how many rewrites the optimizer applied, so this requires at least one - confirm
+ InterpretedFunction ifun(prod_factory, optimized);
+ InterpretedFunction::Context ctx(ifun);
+ const Value &actual = ifun.eval(ctx, params);
+ EXPECT_EQ(actual.type(), expected_type);
+ EXPECT_EQ(actual.cells().type, expected_type.cell_type());
+ if (expected_type.count_mapped_dimensions() == 0) { // no mapped dimensions: index size must match TrivialIndex and cells must fill exactly one dense subspace
+ EXPECT_EQ(actual.index().size(), TrivialIndex::get().size());
+ EXPECT_EQ(actual.cells().size, expected_type.dense_subspace_size());
+ } else { // mapped dimensions present: one dense subspace worth of cells per index entry
+ EXPECT_EQ(actual.cells().size, actual.index().size() * expected_type.dense_subspace_size());
+ }
+ auto expected = eval_ref(*fun); // reference result that the optimized evaluation is compared with
+ EXPECT_EQ(spec_from_value(actual), expected);
+}
+
using cost_list_t = std::vector<std::pair<vespalib::string,double>>;
std::vector<std::pair<vespalib::string,cost_list_t>> benchmark_results;
@@ -257,6 +294,39 @@ TEST(UniversalDotProductTest, generic_dot_product_works_for_various_cases) {
fprintf(stderr, "total test cases run: %zu\n", test_cases);
}
+TEST(UniversalDotProductTest, forwarding_empty_result) { // Runs verify() on reduce-over-product expressions; NOTE(review): going by the test name these target the forwarding op with an empty result - confirm
+ verify("reduce(x0_0*y8_1,sum,y)");
+ verify("reduce(x8_1*y0_0,sum,y)");
+ verify("reduce(x0_0z16*y8_1z16,sum,y)");
+ verify("reduce(x8_1z16*y0_0z16,sum,y)");
+}
+
+TEST(UniversalDotProductTest, nonforwarding_empty_result) { // Runs verify() on expressions where both operands share dimension x; NOTE(review): going by the test name these target the non-forwarding op with an empty result - confirm
+ verify("reduce(x0_0y8*x1_1y8,sum,y)");
+ verify("reduce(x1_1y8*x0_0y8,sum,y)");
+ verify("reduce(x1_7y8z2*x1_1y8z2,sum,y)");
+}
+
+TEST(UniversalDotProductTest, forwarding_expanding_reduce) { // Runs verify() on expressions that reduce only dimension y; NOTE(review): going by the test name these target the forwarding op where the reduce must still yield a value - confirm
+ verify("reduce(5.0*y0_0,sum,y)");
+ verify("reduce(z16*y0_0,sum,y)");
+ verify("reduce(x1_1*y0_0,sum,y)");
+ verify("reduce(x0_0*y1_1,sum,y)");
+ verify("reduce(x1_1z16*y0_0,sum,y)");
+ verify("reduce(x0_0z16*y1_1,sum,y)");
+}
+
+TEST(UniversalDotProductTest, nonforwarding_expanding_reduce) { // Runs verify() on expressions that reduce dimension x (and in some cases y as well); NOTE(review): going by the test name these target the non-forwarding op - confirm
+ verify("reduce(x0_0*y1_1,sum,x,y)");
+ verify("reduce(x1_1*y0_0,sum,x,y)");
+ verify("reduce(x0_0y16*x1_1y16,sum,x)");
+ verify("reduce(x1_1y16*x0_0y16,sum,x)");
+ verify("reduce(x1_7*y1_1,sum,x,y)");
+ verify("reduce(x1_1*y1_7,sum,x,y)");
+ verify("reduce(x1_7y16*x1_1y16,sum,x)");
+ verify("reduce(x1_1y16*x1_7y16,sum,x)");
+}
+
TEST(UniversalDotProductTest, bench_vector_dot_product) {
if (!bench) {
fprintf(stderr, "benchmarking disabled, run with 'bench' parameter to enable\n");
@@ -284,8 +354,6 @@ TEST(UniversalDotProductTest, bench_vector_dot_product) {
benchmark("reduce(b64_1x8y128*x8y128,sum,y)", optimize_list);
benchmark("reduce(b64_1x128*x128,sum,b,x)", optimize_list);
benchmark("reduce(a1_1x128*a2_1b64_1x128,sum,a,x)", optimize_list);
- benchmark("reduce(x0_0*y8_1,sum,y)", optimize_list);
- benchmark("reduce(x8_1*y0_0,sum,y)", optimize_list);
size_t max_expr_size = 0;
for (const auto &[expr, cost_list]: benchmark_results) {
diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.cpp b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
index 3811508a543..209645743fa 100644
--- a/eval/src/vespa/eval/instruction/universal_dot_product.cpp
+++ b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
@@ -56,26 +56,44 @@ void my_universal_dot_product_op(InterpretedFunction::State &state, uint64_t par
std::make_unique<FastValue<OCT,true>>(param.res_type, param.sparse_plan.res_dims(), param.dense_plan.res_size,
param.sparse_plan.estimate_result_size(lhs_index, rhs_index)));
auto &result = *(stored_result.get());
- ArrayRef<OCT> dst;
+ OCT *dst;
auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) {
dst[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size);
};
auto sparse_fun = [&](size_t lhs_subspace, size_t rhs_subspace, ConstArrayRef<string_id> res_addr) {
- bool first;
- std::tie(dst, first) = result.insert_subspace(res_addr);
+ auto [space, first] = result.insert_subspace(res_addr);
if (first) {
- std::fill(dst.begin(), dst.end(), OCT{});
+ std::fill(space.begin(), space.end(), OCT{});
}
+ dst = space.data();
param.dense_plan.execute(lhs_subspace * param.dense_plan.lhs_size,
rhs_subspace * param.dense_plan.rhs_size,
0, dense_fun);
};
param.sparse_plan.execute(lhs_index, rhs_index, sparse_fun);
+ if (result.my_index.map.size() == 0 && param.sparse_plan.res_dims() == 0) {
+ auto [empty, ignore] = result.insert_subspace({});
+ std::fill(empty.begin(), empty.end(), OCT{});
+ }
state.pop_pop_push(result);
}
+template <typename OCT>
+const Value &create_empty_result(const UniversalDotProductParam &param, Stash &stash) { // Creates (in 'stash') the value to return when there is no input to combine; OCT is the output cell type and the shape follows the result plan in 'param'.
+ if (param.sparse_plan.res_dims() == 0) { // res_dims() == 0: presumably the result has no sparse (mapped) dimensions - confirm
+ if (param.dense_plan.res_stride.empty()) { // no dense result strides either: return a plain double 0.0
+ return stash.create<DoubleValue>(0.0);
+ } else { // dense result: res_size cells paired with the trivial index
+ auto zero_cells = stash.create_array<OCT>(param.dense_plan.res_size); // assumes create_array value-initializes the cells to zero - TODO confirm
+ return stash.create<ValueView>(param.res_type, TrivialIndex::get(), TypedCells(zero_cells));
+ }
+ } else { // result has sparse dimensions: empty index and zero cells
+ return stash.create<ValueView>(param.res_type, EmptyIndex::get(), TypedCells(nullptr, get_cell_type<OCT>(), 0));
+ }
+}
+
template <typename LCT, typename RCT, typename OCT>
-void my_universal_dense_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) {
+void my_universal_forwarding_dot_product_op(InterpretedFunction::State &state, uint64_t param_in) {
using dot_product = DotProduct<LCT,RCT>;
const auto &param = unwrap_param<UniversalDotProductParam>(param_in);
const auto &lhs = state.peek(1);
@@ -83,15 +101,15 @@ void my_universal_dense_dot_product_op(InterpretedFunction::State &state, uint64
size_t lhs_index_size = lhs.index().size();
size_t rhs_index_size = rhs.index().size();
if (rhs_index_size == 0 || lhs_index_size == 0) {
- const Value &empty = state.stash.create<ValueView>(param.res_type, EmptyIndex::get(), TypedCells(nullptr, get_cell_type<OCT>(), 0));
- state.pop_pop_push(empty);
+ state.pop_pop_push(create_empty_result<OCT>(param, state.stash));
return;
}
const auto lhs_cells = lhs.cells().typify<LCT>();
const auto rhs_cells = rhs.cells().typify<RCT>();
auto dst_cells = state.stash.create_array<OCT>(lhs_index_size * param.dense_plan.res_size);
+ OCT *dst = dst_cells.data();
auto dense_fun = [&](size_t lhs_idx, size_t rhs_idx, size_t dst_idx) {
- dst_cells[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size);
+ dst[dst_idx] += dot_product::apply(&lhs_cells[lhs_idx], &rhs_cells[rhs_idx], param.vector_size);
};
for (size_t lhs_subspace = 0; lhs_subspace < lhs_index_size; ++lhs_subspace) {
for (size_t rhs_subspace = 0; rhs_subspace < rhs_index_size; ++rhs_subspace) {
@@ -111,7 +129,7 @@ struct SelectUniversalDotProduct {
using RCT = CellValueType<RCM::value.cell_type>;
using OCT = CellValueType<ocm.cell_type>;
if (param.sparse_plan.maybe_forward_lhs_index()) {
- return my_universal_dense_dot_product_op<LCT,RCT,OCT>;
+ return my_universal_forwarding_dot_product_op<LCT,RCT,OCT>;
}
return my_universal_dot_product_op<LCT,RCT,OCT>;
}