summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-01-20 11:25:38 +0000
committerArne Juul <arnej@verizonmedia.com>2021-01-20 11:42:13 +0000
commit4b85b4e3922407bbe370e04deacdbfa9aacaea80 (patch)
tree97bec22a97e2b1f0481198a8fccdd6c192bfd38e /eval
parentd3b98e18334be55acf416080cbed723a29b338b7 (diff)
MixedInnerProductFunction must not trigger if extra dimensions are reduced
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp3
-rw-r--r--eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp9
2 files changed, 11 insertions, 1 deletions
diff --git a/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp b/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp
index 278bda888f4..0c91a7f9b70 100644
--- a/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp
+++ b/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp
@@ -49,6 +49,7 @@ EvalFixture::ParamRepo make_params() {
.add("mix_y3zm", spec({y(3),z({"c","d"})}, MyVecSeq(3.5)))
.add("mix_x3zm_f", spec(float_cells({x(3),z({"c","d"})}), MyVecSeq(0.5)))
.add("mix_y3zm_f", spec(float_cells({y(3),z({"c","d"})}), MyVecSeq(3.5)))
+ .add("mix_x3y3zm", spec({x(3),y(3),z({"c","d"})}, MyVecSeq(0.0)))
;
}
@@ -131,6 +132,8 @@ TEST(MixedInnerProduct, should_not_trigger_optimizer_for_other_cases) {
assert_not_mixed_optimized("reduce(x3y3z3 * x3y3,sum,x,y)");
assert_not_mixed_optimized("reduce(x3y3 * mix_y3zm,sum,y)");
assert_not_mixed_optimized("reduce(mix_y3zm * x3,sum,x,y)");
+ assert_not_mixed_optimized("reduce(mix_x3y3zm * y3,sum,y,z)");
+ assert_not_mixed_optimized("reduce(mix_x3y3zm * y3,sum,x,y)");
}
TEST(MixedInnerProduct, check_compatibility_with_complex_types) {
diff --git a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
index bb68166341c..1d5d446ec59 100644
--- a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
+++ b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp
@@ -118,7 +118,14 @@ MixedInnerProductFunction::compatible_types(const ValueType &res, const ValueTyp
dense_dims.pop_back();
mixed_dims.pop_back();
}
- return true;
+ while (! mixed_dims.empty()) {
+ const auto &name = mixed_dims.back().name;
+ if (res.dimension_index(name) == ValueType::Dimension::npos) {
+ return false;
+ }
+ mixed_dims.pop_back();
+ }
+ return (res.mapped_dimensions() == mixed.mapped_dimensions());
}
return false;
}