summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@yahoo-inc.com>2018-03-02 13:24:59 +0000
committerArne Juul <arnej@yahoo-inc.com>2018-03-02 13:24:59 +0000
commitdeeceb6caa231e0045bed187f9267316d695f1bd (patch)
tree0241767cf243a6609095a08277d7426297492125 /eval
parentf729e9465f18924b05eb9c652bdf4bbc6052f08a (diff)
mark output from XW product as mutable
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp5
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h2
2 files changed, 7 insertions, 0 deletions
diff --git a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
index c794b81f573..daa84bd226f 100644
--- a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
+++ b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
@@ -38,6 +38,7 @@ EvalFixture::ParamRepo make_params() {
.add("con_x5_A", spec({x(5)}, seq))
.add("con_x5_B", spec({x(5)}, seq))
.add("con_x5_C", spec({x(5)}, seq))
+ .add("con_y3_A", spec({y(3)}, seq))
.add("con_x5y3_A", spec({x(5),y(3)}, seq))
.add("con_x5y3_B", spec({x(5),y(3)}, seq))
.add_mutable("mut_dbl_A", spec(1.5))
@@ -144,4 +145,8 @@ TEST("require that mapped tensors are not optimized") {
TEST_DO(verify_not_optimized("mut_x_sparse+mut_x_sparse"));
}
+TEST("require that output from xw product can be optimized") {
+ TEST_DO(verify_optimized("reduce(con_x5_A*con_x5y3_B,sum,x)+con_y3_A", 1, -1));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
index 221c3891775..100d5c4e247 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
@@ -43,6 +43,8 @@ public:
~DenseXWProductFunction() {}
+ bool result_is_mutable() const override { return true; }
+
size_t vectorSize() const { return _vectorSize; }
size_t resultSize() const { return _resultSize; }