summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-02-03 13:18:27 +0000
committerArne Juul <arnej@verizonmedia.com>2021-02-03 13:18:27 +0000
commit62a51d985bde9a25295e7058a50ce52a3be21bce (patch)
tree318238f3ffcce3d4bbcbaa20ad907d5f13d89ee4 /eval
parentd71935fcbaf1b0d69660d87002947c17fcf4e9a8 (diff)
use GenSpec in dense_multi_matmul_function_test
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp40
1 files changed, 24 insertions, 16 deletions
diff --git a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
index ac3abe4f05e..9138668b8c4 100644
--- a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
+++ b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
@@ -4,7 +4,7 @@
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/test/eval_fixture.h>
-#include <vespa/eval/eval/test/tensor_model.hpp>
+#include <vespa/eval/eval/test/gen_spec.h>
#include <vespa/eval/instruction/dense_multi_matmul_function.h>
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/vespalib/util/stash.h>
@@ -17,21 +17,29 @@ using namespace vespalib::eval::tensor_function;
const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
+GenSpec G(std::vector<std::pair<const char *, size_t>> dims) {
+ GenSpec result;
+ for (const auto & dim : dims) {
+ result.idx(dim.first, dim.second);
+ }
+ return result;
+}
+
EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
- .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"a", 2}, {"d", 3}}) // inner/inner
- .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"D", 1}, {"a", 2}, {"c", 1}, {"d", 3}, {"e", 1}}) // inner/inner, extra dims
- .add_dense({{"B", 1}, {"C", 3}, {"a", 2}, {"d", 3}}) // inner/inner, missing A
- .add_dense({{"A", 1}, {"a", 2}, {"d", 3}}) // inner/inner, single mat
- .add_dense({{"A", 2}, {"D", 3}, {"a", 2}, {"b", 1}, {"c", 3}}) // inner/inner, inverted
- .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"a", 2}, {"b", 5}}) // inner/outer
- .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"b", 5}, {"c", 2}}) // outer/outer
- .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"a", 2}, {"c", 3}}) // not matching
+ .add_variants("A2B1C3a2d3", G({{"A",2}, {"B",1}, {"C",3}, {"a",2}, {"d",3}})) // inner/inner
+ .add_variants("A2B1C3D1a2c1d3e1", G({{"A",2}, {"B",1}, {"C",3}, {"D",1}, {"a",2}, {"c",1}, {"d",3}, {"e",1}}))// inner/inner, extra dims
+ .add_variants("B1C3a2d3", G({{"B",1}, {"C",3}, {"a",2}, {"d",3}})) // inner/inner, missing A
+ .add_variants("A1a2d3", G({{"A",1}, {"a",2}, {"d",3}})) // inner/inner, single mat
+ .add_variants("A2D3a2b1c3", G({{"A",2}, {"D",3}, {"a",2}, {"b",1}, {"c",3}})) // inner/inner, inverted
+ .add_variants("A2B1C3a2b5", G({{"A",2}, {"B",1}, {"C",3}, {"a",2}, {"b",5}})) // inner/outer
+ .add_variants("A2B1C3b5c2", G({{"A",2}, {"B",1}, {"C",3}, {"b",5}, {"c",2}})) // outer/outer
+ .add_variants("A2B1C3a2c3", G({{"A",2}, {"B",1}, {"C",3}, {"a",2}, {"c",3}})) // not matching
//----------------------------------------------------------------------------------------
- .add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"b", 5}, {"d", 3}}) // fixed param
- .add_dense({{"B", 1}, {"C", 3}, {"b", 5}, {"d", 3}}) // fixed param, missing A
- .add_dense({{"A", 1}, {"b", 5}, {"d", 3}}) // fixed param, single mat
- .add_dense({{"B", 5}, {"D", 3}, {"a", 2}, {"b", 1}, {"c", 3}}); // fixed param, inverted
+ .add_variants("A2B1C3b5d3", G({{"A",2}, {"B",1}, {"C",3}, {"b",5}, {"d",3}})) // fixed param
+ .add_variants("B1C3b5d3", G({{"B",1}, {"C",3}, {"b",5}, {"d",3}})) // fixed param, missing A
+ .add_variants("A1b5d3", G({{"A",1}, {"b",5}, {"d",3}})) // fixed param, single mat
+ .add_variants("B5D3a2b1c3", G({{"B",5}, {"D",3}, {"a",2}, {"b",1}, {"c",3}})); // fixed param, inverted
}
EvalFixture::ParamRepo param_repo = make_params();
@@ -90,8 +98,8 @@ TEST("require that expressions similar to multi matmul are not optimized") {
}
TEST("require that multi matmul must have matching cell type") {
- TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3f*A2B1C3b5d3,sum,d)"));
- TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3f,sum,d)"));
+ TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3_f*A2B1C3b5d3,sum,d)"));
+ TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3_f,sum,d)"));
}
TEST("require that multi matmul must have matching dimension prefix") {
@@ -119,7 +127,7 @@ TEST("require that multi matmul function can be debug dumped") {
vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
bool float_cells)
{
- return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_cells ? "f" : "", b.c_str(), float_cells ? "f" : "", common.c_str());
+ return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_cells ? "_f" : "", b.c_str(), float_cells ? "_f" : "", common.c_str());
}
void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,