diff options
author | Håvard Pettersen <havardpe@yahooinc.com> | 2022-06-02 08:29:28 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@yahooinc.com> | 2022-06-07 08:48:55 +0000 |
commit | fba0288cbe69a0ea3644ce085ecb84eb86cd1e9f (patch) | |
tree | 7c82c65cfd7bfa71c0a24d784bb03518f588e0cf /eval/src/tests/instruction | |
parent | 38e71d4979792c42b0d163268ad1335cf3176b37 (diff) |
112 mixed dot product optimization
Diffstat (limited to 'eval/src/tests/instruction')
3 files changed, 101 insertions, 2 deletions
diff --git a/eval/src/tests/instruction/mixed_112_dot_product/CMakeLists.txt b/eval/src/tests/instruction/mixed_112_dot_product/CMakeLists.txt new file mode 100644 index 00000000000..fae2f185afb --- /dev/null +++ b/eval/src/tests/instruction/mixed_112_dot_product/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_mixed_112_dot_product_test_app TEST + SOURCES + mixed_112_dot_product_test.cpp + DEPENDS + vespaeval + GTest::GTest +) +vespa_add_test(NAME eval_mixed_112_dot_product_test_app COMMAND eval_mixed_112_dot_product_test_app) diff --git a/eval/src/tests/instruction/mixed_112_dot_product/mixed_112_dot_product_test.cpp b/eval/src/tests/instruction/mixed_112_dot_product/mixed_112_dot_product_test.cpp new file mode 100644 index 00000000000..d3c4d89cf47 --- /dev/null +++ b/eval/src/tests/instruction/mixed_112_dot_product/mixed_112_dot_product_test.cpp @@ -0,0 +1,92 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/eval/eval/fast_value.h> +#include <vespa/eval/eval/simple_value.h> +#include <vespa/eval/instruction/mixed_112_dot_product.h> +#include <vespa/eval/eval/test/eval_fixture.h> +#include <vespa/eval/eval/test/gen_spec.h> +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib::eval; +using namespace vespalib::eval::test; + +using vespalib::make_string_short::fmt; + +//----------------------------------------------------------------------------- + +struct FunInfo { + using LookFor = Mixed112DotProduct; + void verify(const LookFor &fun) const { + EXPECT_TRUE(fun.result_is_mutable()); + } +}; + +void verify_optimized_cell_types(const vespalib::string &expr) +{ + CellTypeSpace stable(CellTypeUtils::list_stable_types(), 3); + CellTypeSpace unstable(CellTypeUtils::list_unstable_types(), 3); + EvalFixture::verify<FunInfo>(expr, {FunInfo()}, CellTypeSpace(stable).same()); + EvalFixture::verify<FunInfo>(expr, {}, CellTypeSpace(stable).different()); + EvalFixture::verify<FunInfo>(expr, {}, unstable); +} + +void verify_optimized(const vespalib::string &expr, size_t num_params = 3) +{ + CellTypeSpace just_float({CellType::FLOAT}, num_params); + EvalFixture::verify<FunInfo>(expr, {FunInfo()}, just_float); +} + +void verify_not_optimized(const vespalib::string &expr) { + CellTypeSpace just_double({CellType::DOUBLE}, 3); + EvalFixture::verify<FunInfo>(expr, {}, just_double); +} + +//----------------------------------------------------------------------------- + +TEST(Mixed112DotProduct, expression_can_be_optimized) +{ + verify_optimized_cell_types("reduce(x5_2*y8*x7_1y8,sum)"); +} + +TEST(Mixed112DotProduct, inverse_dimension_matching_is_handled) { + verify_optimized("reduce(y5_2*x8*x8y7_1,sum)"); +} + +TEST(Mixed112DotProduct, different_input_placement_is_handled) +{ + std::array<vespalib::string,3> params = {"x3_1", "y3", "x3_1y3"}; + for (size_t p1 = 0; p1 < params.size(); ++p1) { + for (size_t p2 = 0; p2 < params.size(); ++p2) { + for (size_t p3 = 0; p3 < params.size(); ++p3) { + if ((p1 != p2) && (p1 != p3) && (p2 != p3)) { + verify_optimized(fmt("reduce((%s*%s)*%s,sum)", params[p1].c_str(), params[p2].c_str(), params[p3].c_str())); + verify_optimized(fmt("reduce(%s*(%s*%s),sum)", params[p1].c_str(), params[p2].c_str(), params[p3].c_str())); + } + } + } + } +} + +TEST(Mixed112DotProduct, expression_can_be_optimized_with_extra_tensors) +{ + verify_optimized("reduce((x5_2*y4)*(x5_1y4*x3_1),sum)", 4); + verify_optimized("reduce((x5_2*x3_1)*(y4*x5_1y4),sum)", 4); +} + +TEST(Mixed112DotProduct, similar_expressions_are_not_optimized) +{ + verify_not_optimized("reduce(x5_2*y4*x5_1y4,prod)"); + verify_not_optimized("reduce(x5_2+y4*x5_1y4,sum)"); + verify_not_optimized("reduce(x5_2*y4+x5_1y4,sum)"); + verify_not_optimized("reduce(x5_2*z4*x5_1y4,sum)"); + verify_not_optimized("reduce(x5_2*y4*x5_1z4,sum)"); + verify_not_optimized("reduce(x5_2*x1_1y4*x5_1y4,sum)"); + verify_not_optimized("reduce(x5_2*y4*x5_1,sum)"); + verify_not_optimized("reduce(x5*y4*x5y4,sum)"); + verify_not_optimized("reduce(x5_1*y4_1*x5_1y4_1,sum)"); +} + +//----------------------------------------------------------------------------- + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/tests/instruction/sparse_112_dot_product/sparse_112_dot_product_test.cpp b/eval/src/tests/instruction/sparse_112_dot_product/sparse_112_dot_product_test.cpp index 9325a203ff3..bab45afe114 100644 --- a/eval/src/tests/instruction/sparse_112_dot_product/sparse_112_dot_product_test.cpp +++ b/eval/src/tests/instruction/sparse_112_dot_product/sparse_112_dot_product_test.cpp @@ -13,8 +13,6 @@ using namespace vespalib::eval::test; using vespalib::make_string_short::fmt; -const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); - //----------------------------------------------------------------------------- struct FunInfo { |