summaryrefslogtreecommitdiffstats
path: root/eval/src/apps/tensor_conformance/generate.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/apps/tensor_conformance/generate.cpp')
-rw-r--r--eval/src/apps/tensor_conformance/generate.cpp11
1 files changed, 11 insertions, 0 deletions
diff --git a/eval/src/apps/tensor_conformance/generate.cpp b/eval/src/apps/tensor_conformance/generate.cpp
index 0aba5276ace..f70c472cbcd 100644
--- a/eval/src/apps/tensor_conformance/generate.cpp
+++ b/eval/src/apps/tensor_conformance/generate.cpp
@@ -169,6 +169,16 @@ void generate_dot_product(TestBuilder &dst) {
//-----------------------------------------------------------------------------
+void generate_xw_product(TestBuilder &dst) {
+ auto matrix = spec({x(2),y(3)}, Seq({ 3, 5, 7, 11, 13, 17 }));
+ dst.add("reduce(a*b,sum,x)", {{"a", spec(x(2), Seq({ 1, 2 }))}, {"b", matrix}},
+ spec(y(3), Seq({(1*3+2*11),(1*5+2*13),(1*7+2*17)})));
+ dst.add("reduce(a*b,sum,y)", {{"a", spec(y(3), Seq({ 1, 2, 3 }))}, {"b", matrix}},
+ spec(x(2), Seq({(1*3+2*5+3*7),(1*11+2*13+3*17)})));
+}
+
+//-----------------------------------------------------------------------------
+
void generate_tensor_concat(TestBuilder &dst) {
dst.add("concat(a,b,x)", {{"a", spec(10.0)}, {"b", spec(20.0)}}, spec(x(2), Seq({10.0, 20.0})));
dst.add("concat(a,b,x)", {{"a", spec(x(1), Seq({10.0}))}, {"b", spec(20.0)}}, spec(x(2), Seq({10.0, 20.0})));
@@ -218,6 +228,7 @@ Generator::generate(TestBuilder &dst)
generate_tensor_map(dst);
generate_tensor_join(dst);
generate_dot_product(dst);
+ generate_xw_product(dst);
generate_tensor_concat(dst);
generate_tensor_rename(dst);
generate_tensor_lambda(dst);