diff options
Diffstat (limited to 'eval/src/tests/instruction/mixed_inner_product_function')
-rw-r--r-- | eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp | 121 |
1 files changed, 38 insertions, 83 deletions
diff --git a/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp b/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp index 11e7c41fad2..95d3d882f7b 100644 --- a/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp +++ b/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp @@ -1,6 +1,5 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/eval/eval/fast_value.h> #include <vespa/eval/eval/tensor_function.h> #include <vespa/eval/eval/test/eval_fixture.h> #include <vespa/eval/eval/test/gen_spec.h> @@ -20,69 +19,43 @@ using namespace vespalib; using namespace vespalib::eval; using namespace vespalib::eval::test; -const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); - - //----------------------------------------------------------------------------- -EvalFixture::ParamRepo make_params() { - return EvalFixture::ParamRepo() - .add_variants("x3", GenSpec(2.0).idx("x", 3)) - .add_variants("x3$2", GenSpec(13.25).idx("x", 3)) - .add_variants("y3", GenSpec(4.0).idx("y", 3)) - .add_variants("z3", GenSpec(0.25).idx("z", 3)) - .add_variants("x3y3", GenSpec(5.0).idx("x", 3).idx("y", 3)) - .add_variants("x1y3", GenSpec(6.0).idx("x", 1).idx("y", 3)) - .add_variants("x3y1", GenSpec(1.5).idx("x", 3).idx("y", 1)) - .add_variants("x3z3", GenSpec(2.5).idx("x", 3).idx("z", 3)) - .add_variants("x3y3z3", GenSpec(-4.0).idx("x", 3).idx("y", 3).idx("z", 3)) - .add("mix_x3zm", GenSpec(0.5).idx("x", 3).map("z", {"c","d"})) - .add("mix_y3zm", GenSpec(3.5).idx("y", 3).map("z", {"c","d"})) - .add("mix_x3zm_f", GenSpec(0.5).idx("x", 3).map("z", {"c","d"}).cells_float()) - .add("mix_y3zm_f", GenSpec(3.5).idx("y", 3).map("z", {"c","d"}).cells_float()) - .add("mix_x3y3zm", GenSpec(0.0).idx("x", 3).idx("y", 3).map("z", {"c","d"})) - ; - -} -EvalFixture::ParamRepo param_repo = make_params(); +template <typename T> +struct FunInfo { + using LookFor = T; + void verify(const LookFor &fun) const { + EXPECT_TRUE(fun.result_is_mutable()); + } +}; void assert_mixed_optimized(const vespalib::string &expr) { - EvalFixture slow_fixture(prod_factory, expr, param_repo, false); - EvalFixture fast_fixture(prod_factory, expr, param_repo, true); - EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(expr, param_repo)); - EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); - auto info = fast_fixture.find_all<MixedInnerProductFunction>(); - ASSERT_EQ(info.size(), 1u); - EXPECT_TRUE(info[0]->result_is_mutable()); + TEST_STATE(expr.c_str()); + CellTypeSpace all_types(CellTypeUtils::list_types(), 2); + using MIP = FunInfo<MixedInnerProductFunction>; + EvalFixture::verify<MIP>(expr, {MIP{}}, all_types); } void assert_not_mixed_optimized(const vespalib::string &expr) { - EvalFixture slow_fixture(prod_factory, expr, param_repo, false); - EvalFixture fast_fixture(prod_factory, expr, param_repo, true); - EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(expr, param_repo)); - EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); - auto info = fast_fixture.find_all<MixedInnerProductFunction>(); - ASSERT_EQ(info.size(), 0u); + TEST_STATE(expr.c_str()); + CellTypeSpace all_types(CellTypeUtils::list_types(), 2); + using MIP = FunInfo<MixedInnerProductFunction>; + EvalFixture::verify<MIP>(expr, {}, all_types); } void assert_dense_optimized(const vespalib::string &expr) { - EvalFixture slow_fixture(prod_factory, expr, param_repo, false); - EvalFixture fast_fixture(prod_factory, expr, param_repo, true); - EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(expr, param_repo)); - EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); - auto info = fast_fixture.find_all<MixedInnerProductFunction>(); - ASSERT_EQ(info.size(), 0u); - auto info2 = fast_fixture.find_all<DenseDotProductFunction>(); - auto info3 = fast_fixture.find_all<DenseMatMulFunction>(); - auto info4 = fast_fixture.find_all<DenseMultiMatMulFunction>(); - auto info5 = fast_fixture.find_all<DenseXWProductFunction>(); - ASSERT_EQ(info2.size() + info3.size() + info4.size() + info5.size(), 1u); + TEST_STATE(expr.c_str()); + CellTypeSpace all_types(CellTypeUtils::list_types(), 2); + using MIP = FunInfo<MixedInnerProductFunction>; + EvalFixture::verify<MIP>(expr, {}, all_types); + using XWP = FunInfo<DenseXWProductFunction>; + EvalFixture::verify<XWP>(expr, {XWP{}}, all_types); } //----------------------------------------------------------------------------- TEST(MixedInnerProduct, use_dense_optimizers_when_possible) { - // actually, all these trigger DenseXWProduct + // actually, all these trigger DenseXWProduct (prioritized before MixedInner) assert_dense_optimized("reduce(x3 * x3y1,sum,x)"); assert_dense_optimized("reduce(y3 * x1y3,sum,y)"); assert_dense_optimized("reduce(y3 * x3y3,sum,y)"); @@ -91,46 +64,28 @@ TEST(MixedInnerProduct, use_dense_optimizers_when_possible) { } TEST(MixedInnerProduct, trigger_optimizer_when_possible) { - assert_mixed_optimized("reduce(x3 * mix_x3zm,sum,x)"); - assert_mixed_optimized("reduce(x3_f * mix_x3zm,sum,x)"); - assert_mixed_optimized("reduce(x3 * mix_x3zm_f,sum,x)"); - assert_mixed_optimized("reduce(x3_f * mix_x3zm_f,sum,x)"); - assert_mixed_optimized("reduce(x3$2 * mix_x3zm,sum,x)"); - assert_mixed_optimized("reduce(x3$2_f * mix_x3zm,sum,x)"); - assert_mixed_optimized("reduce(y3 * mix_y3zm,sum,y)"); - assert_mixed_optimized("reduce(y3_f * mix_y3zm,sum,y)"); - assert_mixed_optimized("reduce(y3 * mix_y3zm_f,sum,y)"); - assert_mixed_optimized("reduce(y3_f * mix_y3zm_f,sum,y)"); - assert_mixed_optimized("reduce(x3y1 * mix_x3zm,sum,x)"); - assert_mixed_optimized("reduce(x3y1_f * mix_x3zm,sum,x)"); - assert_mixed_optimized("reduce(x3y1 * mix_x3zm,sum,x,y)"); - assert_mixed_optimized("reduce(x3y1_f * mix_x3zm,sum,x,y)"); - assert_mixed_optimized("reduce(x1y3 * mix_y3zm,sum,y)"); - assert_mixed_optimized("reduce(x1y3_f * mix_y3zm,sum,y)"); - assert_mixed_optimized("reduce(x1y3 * x1y3,sum,y)"); - assert_mixed_optimized("reduce(x1y3 * x1y3_f,sum,y)"); - assert_mixed_optimized("reduce(x1y3_f * x1y3,sum,y)"); - assert_mixed_optimized("reduce(x1y3_f * x1y3_f,sum,y)"); - assert_mixed_optimized("reduce(x1y3 * mix_y3zm,sum,y)"); - assert_mixed_optimized("reduce(x1y3_f * mix_y3zm,sum,y)"); - assert_mixed_optimized("reduce(mix_x3zm * x3,sum,x)"); - assert_mixed_optimized("reduce(mix_x3zm * x3_f,sum,x)"); - assert_mixed_optimized("reduce(mix_x3zm * x3y1,sum,x)"); - assert_mixed_optimized("reduce(mix_x3zm * x3y1_f,sum,x)"); - assert_mixed_optimized("reduce(mix_y3zm * y3,sum,y)"); - assert_mixed_optimized("reduce(mix_y3zm * y3_f,sum,y)"); - assert_mixed_optimized("reduce(mix_y3zm * x1y3,sum,y)"); - assert_mixed_optimized("reduce(mix_y3zm * x1y3_f,sum,y)"); + assert_mixed_optimized("reduce(x3 * x3z2_1,sum,x)"); + assert_mixed_optimized("reduce(x3$2 * x3z2_1,sum,x)"); + assert_mixed_optimized("reduce(y3 * y3z2_1,sum,y)"); + assert_mixed_optimized("reduce(x3y1 * x3z2_1,sum,x)"); + assert_mixed_optimized("reduce(x3y1 * x3z2_1,sum,x,y)"); + assert_mixed_optimized("reduce(x1y3 * y3z2_1,sum,y)"); + assert_mixed_optimized("reduce(x1y3 * x1y3$2,sum,y)"); + assert_mixed_optimized("reduce(x1y3 * y3z2_1,sum,y)"); + assert_mixed_optimized("reduce(x3z2_1 * x3,sum,x)"); + assert_mixed_optimized("reduce(x3z2_1 * x3y1,sum,x)"); + assert_mixed_optimized("reduce(y3z2_1 * y3,sum,y)"); + assert_mixed_optimized("reduce(y3z2_1 * x1y3,sum,y)"); } TEST(MixedInnerProduct, should_not_trigger_optimizer_for_other_cases) { assert_not_mixed_optimized("reduce(x3y3z3 * x3,sum,x)"); assert_not_mixed_optimized("reduce(x3y3z3 * y3,sum,y)"); assert_not_mixed_optimized("reduce(x3y3z3 * x3y3,sum,x,y)"); - assert_not_mixed_optimized("reduce(x3y3 * mix_y3zm,sum,y)"); - assert_not_mixed_optimized("reduce(mix_y3zm * x3,sum,x,y)"); - assert_not_mixed_optimized("reduce(mix_x3y3zm * y3,sum,y,z)"); - assert_not_mixed_optimized("reduce(mix_x3y3zm * y3,sum,x,y)"); + assert_not_mixed_optimized("reduce(x3y3 * y3z2_1,sum,y)"); + assert_not_mixed_optimized("reduce(y3z2_1 * x3,sum,x,y)"); + assert_not_mixed_optimized("reduce(x3y3z2_1 * y3,sum,y,z)"); + assert_not_mixed_optimized("reduce(x3y3z2_1 * y3,sum,x,y)"); } TEST(MixedInnerProduct, check_compatibility_with_complex_types) { |