diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-02-03 13:18:27 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-02-03 13:18:27 +0000 |
commit | 62a51d985bde9a25295e7058a50ce52a3be21bce (patch) | |
tree | 318238f3ffcce3d4bbcbaa20ad907d5f13d89ee4 /eval | |
parent | d71935fcbaf1b0d69660d87002947c17fcf4e9a8 (diff) |
use GenSpec in dense_multi_matmul_function_test
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp | 40 |
1 files changed, 24 insertions, 16 deletions
diff --git a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp index ac3abe4f05e..9138668b8c4 100644 --- a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp +++ b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp @@ -4,7 +4,7 @@ #include <vespa/eval/eval/operation.h> #include <vespa/eval/eval/tensor_function.h> #include <vespa/eval/eval/test/eval_fixture.h> -#include <vespa/eval/eval/test/tensor_model.hpp> +#include <vespa/eval/eval/test/gen_spec.h> #include <vespa/eval/instruction/dense_multi_matmul_function.h> #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/vespalib/util/stash.h> @@ -17,21 +17,29 @@ using namespace vespalib::eval::tensor_function; const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); +GenSpec G(std::vector<std::pair<const char *, size_t>> dims) { + GenSpec result; + for (const auto & dim : dims) { + result.idx(dim.first, dim.second); + } + return result; +} + EvalFixture::ParamRepo make_params() { return EvalFixture::ParamRepo() - .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"a", 2}, {"d", 3}}) // inner/inner - .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"D", 1}, {"a", 2}, {"c", 1}, {"d", 3}, {"e", 1}}) // inner/inner, extra dims - .add_dense({{"B", 1}, {"C", 3}, {"a", 2}, {"d", 3}}) // inner/inner, missing A - .add_dense({{"A", 1}, {"a", 2}, {"d", 3}}) // inner/inner, single mat - .add_dense({{"A", 2}, {"D", 3}, {"a", 2}, {"b", 1}, {"c", 3}}) // inner/inner, inverted - .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"a", 2}, {"b", 5}}) // inner/outer - .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"b", 5}, {"c", 2}}) // outer/outer - .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"a", 2}, {"c", 3}}) // not matching + .add_variants("A2B1C3a2d3", G({{"A",2}, {"B",1}, {"C",3}, {"a",2}, {"d",3}})) // inner/inner + .add_variants("A2B1C3D1a2c1d3e1", G({{"A",2}, {"B",1}, {"C",3}, {"D",1}, {"a",2}, {"c",1}, {"d",3}, {"e",1}}))// inner/inner, extra dims + .add_variants("B1C3a2d3", G({{"B",1}, {"C",3}, {"a",2}, {"d",3}})) // inner/inner, missing A + .add_variants("A1a2d3", G({{"A",1}, {"a",2}, {"d",3}})) // inner/inner, single mat + .add_variants("A2D3a2b1c3", G({{"A",2}, {"D",3}, {"a",2}, {"b",1}, {"c",3}})) // inner/inner, inverted + .add_variants("A2B1C3a2b5", G({{"A",2}, {"B",1}, {"C",3}, {"a",2}, {"b",5}})) // inner/outer + .add_variants("A2B1C3b5c2", G({{"A",2}, {"B",1}, {"C",3}, {"b",5}, {"c",2}})) // outer/outer + .add_variants("A2B1C3a2c3", G({{"A",2}, {"B",1}, {"C",3}, {"a",2}, {"c",3}})) // not matching //---------------------------------------------------------------------------------------- - .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"b", 5}, {"d", 3}}) // fixed param - .add_dense({{"B", 1}, {"C", 3}, {"b", 5}, {"d", 3}}) // fixed param, missing A - .add_dense({{"A", 1}, {"b", 5}, {"d", 3}}) // fixed param, single mat - .add_dense({{"B", 5}, {"D", 3}, {"a", 2}, {"b", 1}, {"c", 3}}); // fixed param, inverted + .add_variants("A2B1C3b5d3", G({{"A",2}, {"B",1}, {"C",3}, {"b",5}, {"d",3}})) // fixed param + .add_variants("B1C3b5d3", G({{"B",1}, {"C",3}, {"b",5}, {"d",3}})) // fixed param, missing A + .add_variants("A1b5d3", G({{"A",1}, {"b",5}, {"d",3}})) // fixed param, single mat + .add_variants("B5D3a2b1c3", G({{"B",5}, {"D",3}, {"a",2}, {"b",1}, {"c",3}})); // fixed param, inverted } EvalFixture::ParamRepo param_repo = make_params(); @@ -90,8 +98,8 @@ TEST("require that expressions similar to multi matmul are not optimized") { } TEST("require that multi matmul must have matching cell type") { - TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3f*A2B1C3b5d3,sum,d)")); - TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3f,sum,d)")); + TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3_f*A2B1C3b5d3,sum,d)")); + TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3_f,sum,d)")); } TEST("require that multi matmul must have matching dimension prefix") { @@ -119,7 +127,7 @@ TEST("require that multi matmul function can be debug dumped") { vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common, bool float_cells) { - return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_cells ? "f" : "", b.c_str(), float_cells ? "f" : "", common.c_str()); + return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_cells ? "_f" : "", b.c_str(), float_cells ? "_f" : "", common.c_str()); } void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common, |