diff options
Diffstat (limited to 'eval/src/apps/tensor_conformance/generate.cpp')
-rw-r--r-- | eval/src/apps/tensor_conformance/generate.cpp | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/eval/src/apps/tensor_conformance/generate.cpp b/eval/src/apps/tensor_conformance/generate.cpp index 0aba5276ace..f70c472cbcd 100644 --- a/eval/src/apps/tensor_conformance/generate.cpp +++ b/eval/src/apps/tensor_conformance/generate.cpp @@ -169,6 +169,16 @@ void generate_dot_product(TestBuilder &dst) { //----------------------------------------------------------------------------- +void generate_xw_product(TestBuilder &dst) { + auto matrix = spec({x(2),y(3)}, Seq({ 3, 5, 7, 11, 13, 17 })); + dst.add("reduce(a*b,sum,x)", {{"a", spec(x(2), Seq({ 1, 2 }))}, {"b", matrix}}, + spec(y(3), Seq({(1*3+2*11),(1*5+2*13),(1*7+2*17)}))); + dst.add("reduce(a*b,sum,y)", {{"a", spec(y(3), Seq({ 1, 2, 3 }))}, {"b", matrix}}, + spec(x(2), Seq({(1*3+2*5+3*7),(1*11+2*13+3*17)}))); +} + +//----------------------------------------------------------------------------- + void generate_tensor_concat(TestBuilder &dst) { dst.add("concat(a,b,x)", {{"a", spec(10.0)}, {"b", spec(20.0)}}, spec(x(2), Seq({10.0, 20.0}))); dst.add("concat(a,b,x)", {{"a", spec(x(1), Seq({10.0}))}, {"b", spec(20.0)}}, spec(x(2), Seq({10.0, 20.0}))); @@ -218,6 +228,7 @@ Generator::generate(TestBuilder &dst) generate_tensor_map(dst); generate_tensor_join(dst); generate_dot_product(dst); + generate_xw_product(dst); generate_tensor_concat(dst); generate_tensor_rename(dst); generate_tensor_lambda(dst); |