diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-04-08 11:14:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-08 11:14:55 +0200 |
commit | 907638cc59eba8f6e80efcc980ca3f35f30c8f20 (patch) | |
tree | 8a2482f5fd71072a4a017fd11137cc5ff686d96e /eval | |
parent | 1edf0b9cba9d0f36dc7f4a615c486f3b5ca735d6 (diff) | |
parent | 71dc1922dd23599b56cab8fafc10c75f00138e0c (diff) |
Merge pull request #17282 from vespa-engine/arnej/optimize-unstable-join-with-number
add optimization for join between bfloat16 tensor and number
Diffstat (limited to 'eval')
3 files changed, 52 insertions, 36 deletions
diff --git a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp index a6486de6858..4a0eb00e7e7 100644 --- a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp +++ b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp @@ -5,8 +5,8 @@ #include <vespa/eval/eval/test/eval_fixture.h> #include <vespa/eval/eval/test/gen_spec.h> #include <vespa/eval/instruction/join_with_number_function.h> - #include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/unwind_message.h> using namespace vespalib; using namespace vespalib::eval; @@ -33,11 +33,12 @@ std::ostream &operator<<(std::ostream &os, Primary primary) struct FunInfo { using LookFor = JoinWithNumberFunction; Primary primary; + bool pri_mut; bool inplace; void verify(const EvalFixture &fixture, const LookFor &fun) const { EXPECT_TRUE(fun.result_is_mutable()); EXPECT_EQUAL(fun.primary(), primary); - EXPECT_EQUAL(fun.inplace(), inplace); + EXPECT_EQUAL(fun.primary_is_mutable(), pri_mut); if (inplace) { size_t idx = (fun.primary() == Primary::LHS) ? 0 : 1; EXPECT_EQUAL(fixture.result_value().cells().data, @@ -46,17 +47,18 @@ struct FunInfo { } }; -void verify_optimized(const vespalib::string &expr, Primary primary, bool inplace) { - // fprintf(stderr, "%s\n", expr.c_str()); +void verify_optimized(const vespalib::string &expr, Primary primary, bool pri_mut) { + UNWIND_MSG("optimize %s", expr.c_str()); const CellTypeSpace stable_types(CellTypeUtils::list_stable_types(), 2); - FunInfo stable_details{primary, inplace}; + FunInfo stable_details{primary, pri_mut, pri_mut}; TEST_DO(EvalFixture::verify<FunInfo>(expr, {stable_details}, stable_types)); const CellTypeSpace unstable_types(CellTypeUtils::list_unstable_types(), 2); - TEST_DO(EvalFixture::verify<FunInfo>(expr, {}, unstable_types)); + FunInfo unstable_details{primary, pri_mut, false}; + TEST_DO(EvalFixture::verify<FunInfo>(expr, {unstable_details}, unstable_types)); } void verify_not_optimized(const vespalib::string &expr) { - // fprintf(stderr, "%s\n", expr.c_str()); + UNWIND_MSG("not: %s", expr.c_str()); CellTypeSpace all_types(CellTypeUtils::list_types(), 2); TEST_DO(EvalFixture::verify<FunInfo>(expr, {}, all_types)); } diff --git a/eval/src/vespa/eval/instruction/join_with_number_function.cpp b/eval/src/vespa/eval/instruction/join_with_number_function.cpp index c574e3f8ad9..592076f23ce 100644 --- a/eval/src/vespa/eval/instruction/join_with_number_function.cpp +++ b/eval/src/vespa/eval/instruction/join_with_number_function.cpp @@ -14,39 +14,55 @@ namespace vespalib::eval { using Instruction = InterpretedFunction::Instruction; using State = InterpretedFunction::State; +using vespalib::eval::tensor_function::unwrap_param; +using vespalib::eval::tensor_function::wrap_param; namespace { -template <typename CT, bool inplace> -ArrayRef<CT> make_dst_cells(ConstArrayRef<CT> src_cells, Stash &stash) { - if (inplace) { +struct JoinWithNumberParam { + const ValueType res_type; + const join_fun_t function; + JoinWithNumberParam(const ValueType &r, join_fun_t f) : res_type(r), function(f) {} +}; + +template <typename ICT, typename OCT, bool inplace> +ArrayRef<OCT> make_dst_cells(ConstArrayRef<ICT> src_cells, Stash &stash) { + if constexpr (inplace) { + static_assert(std::is_same_v<ICT,OCT>); return unconstify(src_cells); } else { - return stash.create_uninitialized_array<CT>(src_cells.size()); + return stash.create_uninitialized_array<OCT>(src_cells.size()); } } -template <typename CT, typename Fun, bool inplace, bool swap> -void my_number_join_op(State &state, uint64_t param) { +template <typename ICT, typename OCT, typename Fun, bool inplace, bool swap> +void my_number_join_op(State &state, uint64_t param_in) { + const auto ¶m = unwrap_param<JoinWithNumberParam>(param_in); using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type; - OP my_op((join_fun_t)param); + OP my_op(param.function); const Value &tensor = state.peek(swap ? 0 : 1); - CT number = state.peek(swap ? 1 : 0).as_double(); - auto src_cells = tensor.cells().typify<CT>(); - auto dst_cells = make_dst_cells<CT, inplace>(src_cells, state.stash); + OCT number = state.peek(swap ? 1 : 0).as_double(); + auto src_cells = tensor.cells().typify<ICT>(); + auto dst_cells = make_dst_cells<ICT, OCT, inplace>(src_cells, state.stash); apply_op2_vec_num(dst_cells.begin(), src_cells.begin(), number, dst_cells.size(), my_op); if (inplace) { state.pop_pop_push(tensor); } else { - state.pop_pop_push(state.stash.create<ValueView>(tensor.type(), tensor.index(), TypedCells(dst_cells))); + state.pop_pop_push(state.stash.create<ValueView>(param.res_type, tensor.index(), TypedCells(dst_cells))); } } struct SelectJoinWithNumberOp { - template<typename CT, typename Fun, - typename InputIsMutable, typename NumberWasLeft> + template<typename CM, typename Fun, + typename PrimaryMutable, typename NumberWasLeft> static auto invoke() { - return my_number_join_op<CT, Fun, InputIsMutable::value, NumberWasLeft::value>; + constexpr CellMeta icm = CM::value; + constexpr CellMeta num(CellType::DOUBLE, true); + constexpr CellMeta ocm = CellMeta::join(icm, num); + using ICT = CellValueType<icm.cell_type>; + using OCT = CellValueType<ocm.cell_type>; + constexpr bool inplace = (PrimaryMutable::value && std::is_same_v<ICT,OCT>); + return my_number_join_op<ICT, OCT, Fun, inplace, NumberWasLeft::value>; } }; @@ -62,7 +78,7 @@ JoinWithNumberFunction::JoinWithNumberFunction(const Join &original, bool tensor JoinWithNumberFunction::~JoinWithNumberFunction() = default; bool -JoinWithNumberFunction::inplace() const { +JoinWithNumberFunction::primary_is_mutable() const { if (_primary == Primary::LHS) { return lhs().result_is_mutable(); } else { @@ -70,16 +86,19 @@ JoinWithNumberFunction::inplace() const { } } -using MyTypify = TypifyValue<TypifyCellType,vespalib::TypifyBool,operation::TypifyOp2>; +using MyTypify = TypifyValue<TypifyCellMeta,vespalib::TypifyBool,operation::TypifyOp2>; InterpretedFunction::Instruction -JoinWithNumberFunction::compile_self(const ValueBuilderFactory &, Stash &) const +JoinWithNumberFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const { - auto op = typify_invoke<4,MyTypify,SelectJoinWithNumberOp>(result_type().cell_type(), + const auto ¶m = stash.create<JoinWithNumberParam>(result_type(), _function); + auto input_type = (_primary == Primary::LHS) ? lhs().result_type() : rhs().result_type(); + assert(result_type() == ValueType::join(input_type, ValueType::double_type())); + auto op = typify_invoke<4,MyTypify,SelectJoinWithNumberOp>(input_type.cell_meta(), _function, - inplace(), + primary_is_mutable(), (_primary == Primary::RHS)); - return Instruction(op, (uint64_t)(_function)); + return Instruction(op, wrap_param<JoinWithNumberParam>(param)); } void @@ -87,7 +106,7 @@ JoinWithNumberFunction::visit_self(vespalib::ObjectVisitor &visitor) const { Super::visit_self(visitor); visitor.visitBool("tensor_was_right", (_primary == Primary::RHS)); - visitor.visitBool("is_inplace", inplace()); + visitor.visitBool("primary_is_mutable", primary_is_mutable()); } const TensorFunction & @@ -95,17 +114,12 @@ JoinWithNumberFunction::optimize(const TensorFunction &expr, Stash &stash) { if (! expr.result_type().is_double()) { if (const auto *join = as<Join>(expr)) { - const ValueType &result_type = join->result_type(); const TensorFunction &lhs = join->lhs(); const TensorFunction &rhs = join->rhs(); - if (lhs.result_type().is_double() && - (result_type == rhs.result_type())) - { + if (lhs.result_type().is_double()) { return stash.create<JoinWithNumberFunction>(*join, true); } - if (rhs.result_type().is_double() && - (result_type == lhs.result_type())) - { + if (rhs.result_type().is_double()) { return stash.create<JoinWithNumberFunction>(*join, false); } } diff --git a/eval/src/vespa/eval/instruction/join_with_number_function.h b/eval/src/vespa/eval/instruction/join_with_number_function.h index 546ff75b175..9d29ad5eb5d 100644 --- a/eval/src/vespa/eval/instruction/join_with_number_function.h +++ b/eval/src/vespa/eval/instruction/join_with_number_function.h @@ -23,7 +23,7 @@ public: JoinWithNumberFunction(const tensor_function::Join &original_join, bool number_on_left); ~JoinWithNumberFunction(); Primary primary() const { return _primary; } - bool inplace() const; + bool primary_is_mutable() const; bool result_is_mutable() const override { return true; } InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override; |