diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-11-12 08:55:54 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-11-12 08:57:49 +0000 |
commit | 473cd0a65a9f0e845af7dfa800879b445bdfb4ec (patch) | |
tree | e74ae61939343185f7d7ae65ab951ef409868163 /eval/src/tests/instruction/dense_multi_matmul_function | |
parent | 43dacb6e1baa4a6805d92b1f537c541586702c17 (diff) |
test with FastValueBuilderFactory also
Diffstat (limited to 'eval/src/tests/instruction/dense_multi_matmul_function')
-rw-r--r-- | eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp | 35 |
1 files changed, 29 insertions, 6 deletions
diff --git a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp index ce9f599d27c..8f4a06b2335 100644 --- a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp +++ b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp @@ -1,5 +1,6 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include <vespa/eval/eval/fast_value.h> #include <vespa/eval/eval/operation.h> #include <vespa/eval/eval/tensor_function.h> #include <vespa/eval/eval/test/eval_fixture.h> @@ -15,7 +16,8 @@ using namespace vespalib::eval; using namespace vespalib::eval::test; using namespace vespalib::eval::tensor_function; -const TensorEngine &prod_engine = tensor::DefaultTensorEngine::ref(); +const TensorEngine &old_engine = tensor::DefaultTensorEngine::ref(); +const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); EvalFixture::ParamRepo make_params() { return EvalFixture::ParamRepo() @@ -39,8 +41,8 @@ void verify_optimized(const vespalib::string &expr, size_t lhs_size, size_t common_size, size_t rhs_size, size_t matmul_cnt, bool lhs_inner, bool rhs_inner) { - EvalFixture slow_fixture(prod_engine, expr, param_repo, false); - EvalFixture fixture(prod_engine, expr, param_repo, true); + EvalFixture slow_fixture(prod_factory, expr, param_repo, false); + EvalFixture fixture(prod_factory, expr, param_repo, true); EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); EXPECT_EQUAL(fixture.result(), slow_fixture.result()); auto info = fixture.find_all<DenseMultiMatMulFunction>(); @@ -52,15 +54,36 @@ void verify_optimized(const vespalib::string &expr, EXPECT_EQUAL(info[0]->matmul_cnt(), matmul_cnt); EXPECT_EQUAL(info[0]->lhs_common_inner(), lhs_inner); EXPECT_EQUAL(info[0]->rhs_common_inner(), rhs_inner); + + EvalFixture old_slow_fixture(old_engine, expr, param_repo, false); + EvalFixture old_fixture(old_engine, expr, param_repo, true); + EXPECT_EQUAL(old_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQUAL(old_fixture.result(), old_slow_fixture.result()); + info = old_fixture.find_all<DenseMultiMatMulFunction>(); + ASSERT_EQUAL(info.size(), 1u); + EXPECT_TRUE(info[0]->result_is_mutable()); + EXPECT_EQUAL(info[0]->lhs_size(), lhs_size); + EXPECT_EQUAL(info[0]->common_size(), common_size); + EXPECT_EQUAL(info[0]->rhs_size(), rhs_size); + EXPECT_EQUAL(info[0]->matmul_cnt(), matmul_cnt); + EXPECT_EQUAL(info[0]->lhs_common_inner(), lhs_inner); + EXPECT_EQUAL(info[0]->rhs_common_inner(), rhs_inner); } void verify_not_optimized(const vespalib::string &expr) { - EvalFixture slow_fixture(prod_engine, expr, param_repo, false); - EvalFixture fixture(prod_engine, expr, param_repo, true); + EvalFixture slow_fixture(prod_factory, expr, param_repo, false); + EvalFixture fixture(prod_factory, expr, param_repo, true); EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); EXPECT_EQUAL(fixture.result(), slow_fixture.result()); auto info = fixture.find_all<DenseMultiMatMulFunction>(); EXPECT_TRUE(info.empty()); + + EvalFixture old_slow_fixture(old_engine, expr, param_repo, false); + EvalFixture old_fixture(old_engine, expr, param_repo, true); + EXPECT_EQUAL(old_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQUAL(old_fixture.result(), old_slow_fixture.result()); + info = old_fixture.find_all<DenseMultiMatMulFunction>(); + EXPECT_TRUE(info.empty()); } TEST("require that multi matmul can be optimized") { @@ -110,7 +133,7 @@ TEST("require that multi matmul ignores trivial dimensions") { } TEST("require that multi matmul function can be debug dumped") { - EvalFixture fixture(prod_engine, "reduce(A2B1C3a2d3*A2B1C3b5d3,sum,d)", param_repo, true); + EvalFixture fixture(prod_factory, "reduce(A2B1C3a2d3*A2B1C3b5d3,sum,d)", param_repo, true); auto info = fixture.find_all<DenseMultiMatMulFunction>(); ASSERT_EQUAL(info.size(), 1u); fprintf(stderr, "%s\n", info[0]->as_string().c_str()); |