diff options
Diffstat (limited to 'eval/src/tests/instruction')
-rw-r--r-- | eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp | 8 | ||||
-rw-r--r-- | eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp | 152 |
2 files changed, 81 insertions, 79 deletions
diff --git a/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp b/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp index 9851e209ba5..7faf7d57738 100644 --- a/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp +++ b/eval/src/tests/instruction/dense_join_reduce_plan/dense_join_reduce_plan_test.cpp @@ -13,7 +13,7 @@ ValueType type(const vespalib::string &type_spec) { TEST(DenseJoinReducePlanTest, make_trivial_plan) { auto plan = DenseJoinReducePlan(type("double"), type("double"), type("double")); - EXPECT_TRUE(plan.distinct_result()); + EXPECT_TRUE(plan.is_distinct()); EXPECT_EQ(plan.lhs_size, 1); EXPECT_EQ(plan.rhs_size, 1); EXPECT_EQ(plan.res_size, 1); @@ -39,7 +39,7 @@ TEST(DenseJoinReducePlanTest, make_simple_plan) { SmallVector<size_t> expect_lhs_stride = {1,0}; SmallVector<size_t> expect_rhs_stride = {0,1}; SmallVector<size_t> expect_res_stride = {1,0}; - EXPECT_FALSE(plan.distinct_result()); + EXPECT_FALSE(plan.is_distinct()); EXPECT_EQ(plan.lhs_size, 2); EXPECT_EQ(plan.rhs_size, 3); EXPECT_EQ(plan.res_size, 2); @@ -69,7 +69,7 @@ TEST(DenseJoinReducePlanTest, make_distinct_plan) { SmallVector<size_t> expect_lhs_stride = {1,0}; SmallVector<size_t> expect_rhs_stride = {0,1}; SmallVector<size_t> expect_res_stride = {3,1}; - EXPECT_TRUE(plan.distinct_result()); + EXPECT_TRUE(plan.is_distinct()); EXPECT_EQ(plan.lhs_size, 2); EXPECT_EQ(plan.rhs_size, 3); EXPECT_EQ(plan.res_size, 6); @@ -88,7 +88,7 @@ TEST(DenseJoinReducePlanTest, make_complex_plan) { SmallVector<size_t> expect_lhs_stride = {6,0,2,1}; SmallVector<size_t> expect_rhs_stride = {4,1,0,0}; SmallVector<size_t> expect_res_stride = {12,3,1,0}; - EXPECT_FALSE(plan.distinct_result()); + EXPECT_FALSE(plan.is_distinct()); EXPECT_EQ(plan.lhs_size, 180); EXPECT_EQ(plan.rhs_size, 120); EXPECT_EQ(plan.res_size, 360); diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp index 4ca2a5ef79a..e1967f012cb 100644 --- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp +++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp @@ -27,31 +27,7 @@ using vespalib::make_string_short::fmt; const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); bool bench = false; double budget = 1.0; - -GenSpec::seq_t N_16ths = [] (size_t i) noexcept { return (i + 33.0) / 16.0; }; - -GenSpec G() { return GenSpec().seq(N_16ths); } - -const std::vector<GenSpec> layouts = { - G(), G(), - G().idx("x", 5), G().idx("x", 5), - G().idx("x", 5), G().idx("y", 5), - G().idx("x", 5), G().idx("x", 5).idx("y", 5), - G().idx("y", 3), G().idx("x", 2).idx("z", 3), - G().idx("x", 3).idx("y", 5), G().idx("y", 5).idx("z", 7), - G().map("x", {"a","b","c"}), G().map("x", {"a","b","c"}), - G().map("x", {"a","b","c"}), G().map("x", {"a","b"}), - G().map("x", {"a","b","c"}), G().map("y", {"foo","bar","baz"}), - G().map("x", {"a","b","c"}), G().map("x", {"a","b","c"}).map("y", {"foo","bar","baz"}), - G().map("x", {"a","b"}).map("y", {"foo","bar","baz"}), G().map("x", {"a","b","c"}).map("y", {"foo","bar"}), - G().map("x", {"a","b"}).map("y", {"foo","bar","baz"}), G().map("y", {"foo","bar"}).map("z", {"i","j","k","l"}), - G().idx("x", 3).map("y", {"foo", "bar"}), G().map("y", {"foo", "bar"}).idx("z", 7), - G().map("x", {"a","b","c"}).idx("y", 5), G().idx("y", 5).map("z", {"i","j","k","l"}) -}; - -const std::vector<std::vector<vespalib::string>> reductions = { - {}, {"x"}, {"y"}, {"z"}, {"x", "y"}, {"x", "z"}, {"y", "z"} -}; +size_t verify_cnt = 0; std::vector<std::string> ns_list = { {"vespalib::eval::instruction::(anonymous namespace)::"}, @@ -76,14 +52,19 @@ std::string strip_ns(const vespalib::string &str) { return tmp; } -TensorSpec make_spec(const vespalib::string ¶m_name, size_t idx) { - return GenSpec::from_desc(param_name).cells_double().seq(N(1 + idx)); +using select_cell_type_t = std::function<CellType(size_t idx)>; +CellType always_double(size_t) { return CellType::DOUBLE; } +select_cell_type_t select(CellType lct) { return [lct](size_t)noexcept{ return lct; }; } +select_cell_type_t select(CellType lct, CellType rct) { return [lct,rct](size_t idx)noexcept{ return idx ? rct : lct; }; } + +TensorSpec make_spec(const vespalib::string ¶m_name, size_t idx, select_cell_type_t select_cell_type) { + return GenSpec::from_desc(param_name).cells(select_cell_type(idx)).seq(N(1 + idx)); } -TensorSpec eval_ref(const Function &fun) { +TensorSpec eval_ref(const Function &fun, select_cell_type_t select_cell_type) { std::vector<TensorSpec> params; for (size_t i = 0; i < fun.num_params(); ++i) { - params.push_back(make_spec(fun.param_name(i), i)); + params.push_back(make_spec(fun.param_name(i), i, select_cell_type)); } return ReferenceEvaluation::eval(fun, params); } @@ -134,12 +115,13 @@ Optimize universal_only() { return Optimize::specific("universal_only", my_optimizer); } -void verify(const vespalib::string &expr) { +void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) { + ++verify_cnt; auto fun = Function::parse(expr); ASSERT_FALSE(fun->has_error()); std::vector<Value::UP> values; for (size_t i = 0; i < fun->num_params(); ++i) { - auto value = value_from_spec(make_spec(fun->param_name(i), i), prod_factory); + auto value = value_from_spec(make_spec(fun->param_name(i), i, select_cell_type), prod_factory); values.push_back(std::move(value)); } SimpleObjectParams params({}); @@ -167,23 +149,24 @@ void verify(const vespalib::string &expr) { } else { EXPECT_EQ(actual.cells().size, actual.index().size() * expected_type.dense_subspace_size()); } - auto expected = eval_ref(*fun); + auto expected = eval_ref(*fun, select_cell_type); EXPECT_EQ(spec_from_value(actual), expected); } +void verify(const vespalib::string &expr) { verify(expr, always_double); } using cost_list_t = std::vector<std::pair<vespalib::string,double>>; std::vector<std::pair<vespalib::string,cost_list_t>> benchmark_results; void benchmark(const vespalib::string &expr, std::vector<Optimize> list) { + verify(expr); auto fun = Function::parse(expr); ASSERT_FALSE(fun->has_error()); - auto expected = eval_ref(*fun); cost_list_t cost_list; fprintf(stderr, "BENCH: %s\n", expr.c_str()); for (Optimize &optimize: list) { std::vector<Value::UP> values; for (size_t i = 0; i < fun->num_params(); ++i) { - auto value = value_from_spec(make_spec(fun->param_name(i), i), prod_factory); + auto value = value_from_spec(make_spec(fun->param_name(i), i, always_double), prod_factory); values.push_back(std::move(value)); } SimpleObjectParams params({}); @@ -218,8 +201,6 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) { InterpretedFunction ifun(prod_factory, *optimized, &ctf_meta); InterpretedFunction::ProfiledContext pctx(ifun); ASSERT_EQ(ctf_meta.steps.size(), ifun.program_size()); - EXPECT_EQ(spec_from_value(ifun.eval(pctx.context, params)), expected); - EXPECT_EQ(spec_from_value(ifun.eval(pctx, params)), expected); std::vector<duration> prev_time(ctf_meta.steps.size(), duration::zero()); std::vector<duration> min_time(ctf_meta.steps.size(), duration::max()); BenchmarkTimer timer(budget); @@ -251,47 +232,63 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) { benchmark_results.emplace_back(expr, std::move(cost_list)); } -TensorSpec perform_dot_product(const TensorSpec &a, const TensorSpec &b, const std::vector<vespalib::string> &dims) -{ - Stash stash; - auto lhs = value_from_spec(a, prod_factory); - auto rhs = value_from_spec(b, prod_factory); - auto res_type = ValueType::join(lhs->type(), rhs->type()).reduce(dims); - EXPECT_FALSE(res_type.is_error()); - UniversalDotProduct dot_product(res_type, - tensor_function::inject(lhs->type(), 0, stash), - tensor_function::inject(rhs->type(), 1, stash)); - auto my_op = dot_product.compile_self(prod_factory, stash); - InterpretedFunction::EvalSingle single(prod_factory, my_op); - return spec_from_value(single.eval(std::vector<Value::CREF>({*lhs,*rhs}))); +TEST(UniversalDotProductTest, test_select_cell_types) { + auto always = always_double; + EXPECT_EQ(always(0), CellType::DOUBLE); + EXPECT_EQ(always(1), CellType::DOUBLE); + EXPECT_EQ(always(0), CellType::DOUBLE); + EXPECT_EQ(always(1), CellType::DOUBLE); + for (CellType lct: CellTypeUtils::list_types()) { + auto sel1 = select(lct); + EXPECT_EQ(sel1(0), lct); + EXPECT_EQ(sel1(1), lct); + EXPECT_EQ(sel1(0), lct); + EXPECT_EQ(sel1(1), lct); + for (CellType rct: CellTypeUtils::list_types()) { + auto sel2 = select(lct, rct); + EXPECT_EQ(sel2(0), lct); + EXPECT_EQ(sel2(1), rct); + EXPECT_EQ(sel2(0), lct); + EXPECT_EQ(sel2(1), rct); + } + } } -TEST(UniversalDotProductTest, generic_dot_product_works_for_various_cases) { - size_t test_cases = 0; - ASSERT_TRUE((layouts.size() % 2) == 0); - for (size_t i = 0; i < layouts.size(); i += 2) { - const auto &l = layouts[i]; - const auto &r = layouts[i+1]; - for (CellType lct : CellTypeUtils::list_types()) { - auto lhs = l.cpy().cells(lct); - if (lhs.bad_scalar()) continue; - for (CellType rct : CellTypeUtils::list_types()) { - auto rhs = r.cpy().cells(rct); - if (rhs.bad_scalar()) continue; - for (const std::vector<vespalib::string> &dims: reductions) { - if (ValueType::join(lhs.type(), rhs.type()).reduce(dims).is_error()) continue; - ++test_cases; - SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str())); - auto expect = ReferenceOperations::reduce(ReferenceOperations::join(lhs, rhs, operation::Mul::f), Aggr::SUM, dims); - auto actual = perform_dot_product(lhs, rhs, dims); - // fprintf(stderr, "\n===\nLHS: %s\nRHS: %s\n===\nRESULT: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str(), actual.to_string().c_str()); - EXPECT_EQ(actual, expect); - } - } +TEST(UniversalDotProductTest, universal_dot_product_works_for_various_cases) { + // forward, distinct, single + verify("reduce(2.0*3.0, sum)"); + + for (CellType lct: CellTypeUtils::list_types()) { + for (CellType rct: CellTypeUtils::list_types()) { + auto sel2 = select(lct, rct); + // !forward, !distinct, !single + verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2); + + // !forward, !distinct, single + verify("reduce(a4_1x8*a2_1x8,sum,a)", sel2); + + // !forward, distinct, !single + verify("reduce(a4_1x8*a2_1x8,sum,x)", sel2); + + // forward, !distinct, !single + verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2); + + // forward, !distinct, single + verify("reduce(a4_1x8*b2_1x8,sum,b)", sel2); + + // forward, distinct, !single + verify("reduce(a4_1x8*x8,sum,x)", sel2); } } - EXPECT_GT(test_cases, 500); - fprintf(stderr, "total test cases run: %zu\n", test_cases); + // !forward, distinct, single + + // This case is not possible since 'distinct' implies '!single' as + // long as we reduce anything. The only expression allowed to + // reduce nothing is the scalar case. +} + +TEST(UniversalDotProductTest, universal_dot_product_works_with_complex_dimension_nesting) { + verify("reduce(a4_1b4_1c4_1x4y3z2w1*a2_1c1_1x4z2,sum,b,c,x)"); } TEST(UniversalDotProductTest, forwarding_empty_result) { @@ -336,8 +333,11 @@ TEST(UniversalDotProductTest, bench_vector_dot_product) { } auto optimize_list = std::vector<Optimize>({baseline(), with_universal(), universal_only()}); - benchmark("reduce(1.0*2.0,sum)", optimize_list); + benchmark("reduce(2.0*3.0,sum)", optimize_list); benchmark("reduce(5.0*x128,sum,x)", optimize_list); + benchmark("reduce(a1*x128,sum,x)", optimize_list); + benchmark("reduce(a8*x128,sum,x)", optimize_list); + benchmark("reduce(a1_1b8*x128,sum,x)", optimize_list); benchmark("reduce(x16*x16,sum,x)", optimize_list); benchmark("reduce(x768*x768,sum,x)", optimize_list); benchmark("reduce(y64*x8y64,sum,x,y)", optimize_list); @@ -417,5 +417,7 @@ int main(int argc, char **argv) { --argc; } ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); + int result = RUN_ALL_TESTS(); + fprintf(stderr, "verify called %zu times\n", verify_cnt); + return result; } |