From 4b85b4e3922407bbe370e04deacdbfa9aacaea80 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Wed, 20 Jan 2021 11:25:38 +0000 Subject: MixedInnerProductFunction must not trigger if extra dimensions are reduced --- .../mixed_inner_product_function_test.cpp | 3 +++ eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp | 9 ++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) (limited to 'eval') diff --git a/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp b/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp index 278bda888f4..0c91a7f9b70 100644 --- a/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp +++ b/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp @@ -49,6 +49,7 @@ EvalFixture::ParamRepo make_params() { .add("mix_y3zm", spec({y(3),z({"c","d"})}, MyVecSeq(3.5))) .add("mix_x3zm_f", spec(float_cells({x(3),z({"c","d"})}), MyVecSeq(0.5))) .add("mix_y3zm_f", spec(float_cells({y(3),z({"c","d"})}), MyVecSeq(3.5))) + .add("mix_x3y3zm", spec({x(3),y(3),z({"c","d"})}, MyVecSeq(0.0))) ; } @@ -131,6 +132,8 @@ TEST(MixedInnerProduct, should_not_trigger_optimizer_for_other_cases) { assert_not_mixed_optimized("reduce(x3y3z3 * x3y3,sum,x,y)"); assert_not_mixed_optimized("reduce(x3y3 * mix_y3zm,sum,y)"); assert_not_mixed_optimized("reduce(mix_y3zm * x3,sum,x,y)"); + assert_not_mixed_optimized("reduce(mix_x3y3zm * y3,sum,y,z)"); + assert_not_mixed_optimized("reduce(mix_x3y3zm * y3,sum,x,y)"); } TEST(MixedInnerProduct, check_compatibility_with_complex_types) { diff --git a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp index bb68166341c..1d5d446ec59 100644 --- a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp +++ b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp @@ -118,7 +118,14 @@ MixedInnerProductFunction::compatible_types(const ValueType &res, const ValueTyp dense_dims.pop_back(); mixed_dims.pop_back(); } - return true; + while (! mixed_dims.empty()) { + const auto &name = mixed_dims.back().name; + if (res.dimension_index(name) == ValueType::Dimension::npos) { + return false; + } + mixed_dims.pop_back(); + } + return (res.mapped_dimensions() == mixed.mapped_dimensions()); } return false; } -- cgit v1.2.3