diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-06-11 12:11:58 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-06-11 12:11:58 +0000 |
commit | a875850b3ad75b8077bd45d3fac64651dfad7ca3 (patch) | |
tree | 1fca1a2923f4c300cb1ecaaa2c6c03c9d77b2989 /eval/src | |
parent | 4851139ed8e0fabb264ff5e69deb5f364f237160 (diff) |
use common code for simple map and number join
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_number_join_function.cpp | 73 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_simple_map_function.cpp | 35 |
2 files changed, 22 insertions, 86 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_number_join_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_number_join_function.cpp index 3f48607cef4..a28c8150d59 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_number_join_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_number_join_function.cpp @@ -2,8 +2,10 @@ #include "dense_number_join_function.h" #include "dense_tensor_view.h" +#include <vespa/vespalib/util/typify.h> #include <vespa/eval/eval/value.h> #include <vespa/eval/eval/operation.h> +#include <vespa/eval/eval/inline_operation.h> namespace vespalib::tensor { @@ -13,6 +15,7 @@ using eval::Value; using eval::ValueType; using eval::TensorFunction; using eval::TensorEngine; +using eval::TypifyCellType; using eval::as; using namespace eval::operation; @@ -26,37 +29,10 @@ using State = eval::InterpretedFunction::State; namespace { -struct CallFun { - join_fun_t function; - CallFun(join_fun_t function_in) : function(function_in) {} - double eval(double a, double b) const { return function(a, b); } -}; - -struct AddFun { - AddFun(join_fun_t) {} - template <typename A, typename B> - auto eval(A a, B b) const { return (a + b); } -}; - -struct MulFun { - MulFun(join_fun_t) {} - template <typename A, typename B> - auto eval(A a, B b) const { return (a * b); } -}; - -// needed for asymmetric operations like Sub and Div -template <typename Fun> -struct SwapFun { - Fun fun; - SwapFun(join_fun_t function_in) : fun(function_in) {} - template <typename A, typename B> - auto eval(A a, B b) const { return fun.eval(b, a); } -}; - template <typename CT, typename Fun> void apply_fun_1_to_n(CT *dst, const CT *pri, CT sec, size_t n, const Fun &fun) { for (size_t i = 0; i < n; ++i) { - dst[i] = fun.eval(pri[i], sec); + dst[i] = fun(pri[i], sec); } } @@ -71,7 +47,7 @@ ArrayRef<CT> make_dst_cells(ConstArrayRef<CT> src_cells, Stash &stash) { template <typename CT, typename Fun, bool inplace, bool swap> void my_number_join_op(State &state, uint64_t param) { - using OP = typename std::conditional<swap,SwapFun<Fun>,Fun>::type; + using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type; OP my_op((join_fun_t)param); const Value &tensor = state.peek(swap ? 0 : 1); CT number = state.peek(swap ? 1 : 0).as_double(); @@ -87,39 +63,13 @@ void my_number_join_op(State &state, uint64_t param) { //----------------------------------------------------------------------------- -template <typename Fun, bool inplace, bool swap> -struct MyNumberJoinOp { - template <typename CT> - static auto get_fun() { return my_number_join_op<CT,Fun,inplace,swap>; } -}; - -template <typename Fun, bool inplace> -op_function my_select_3(ValueType::CellType ct, Primary primary) { - switch (primary) { - case Primary::LHS: return select_1<MyNumberJoinOp<Fun,inplace,false>>(ct); - case Primary::RHS: return select_1<MyNumberJoinOp<Fun,inplace,true>>(ct); +struct MyGetFun { + template <typename R1, typename R2, typename R3, typename R4> static auto invoke() { + return my_number_join_op<R1, R2, R3::value, R4::value>; } - abort(); -} - -template <typename Fun> -op_function my_select_2(ValueType::CellType ct, Primary primary, bool inplace) { - if (inplace) { - return my_select_3<Fun, true>(ct, primary); - } else { - return my_select_3<Fun, false>(ct, primary); - } -} +}; -op_function my_select(ValueType::CellType ct, Primary primary, bool inplace, join_fun_t fun_hint) { - if (fun_hint == Add::f) { - return my_select_2<AddFun>(ct, primary, inplace); - } else if (fun_hint == Mul::f) { - return my_select_2<MulFun>(ct, primary, inplace); - } else { - return my_select_2<CallFun>(ct, primary, inplace); - } -} +using MyTypify = TypifyValue<TypifyCellType,TypifyOp2,TypifyBool>; bool is_dense(const TensorFunction &tf) { return tf.result_type().is_dense(); } bool is_double(const TensorFunction &tf) { return tf.result_type().is_double(); } @@ -154,7 +104,8 @@ DenseNumberJoinFunction::inplace() const Instruction DenseNumberJoinFunction::compile_self(const TensorEngine &, Stash &) const { - auto op = my_select(result_type().cell_type(), _primary, inplace(), function()); + auto op = typify_invoke<4,MyTypify,MyGetFun>(result_type().cell_type(), function(), + inplace(), (_primary == Primary::RHS)); static_assert(sizeof(uint64_t) == sizeof(function())); return Instruction(op, (uint64_t)(function())); } diff --git a/eval/src/vespa/eval/tensor/dense/dense_simple_map_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_simple_map_function.cpp index 910e8296afe..784d356da39 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_simple_map_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_simple_map_function.cpp @@ -2,8 +2,10 @@ #include "dense_simple_map_function.h" #include "dense_tensor_view.h" +#include <vespa/vespalib/util/typify.h> #include <vespa/eval/eval/value.h> #include <vespa/eval/eval/operation.h> +#include <vespa/eval/eval/inline_operation.h> namespace vespalib::tensor { @@ -13,6 +15,7 @@ using eval::Value; using eval::ValueType; using eval::TensorFunction; using eval::TensorEngine; +using eval::TypifyCellType; using eval::as; using namespace eval::operation; @@ -24,16 +27,10 @@ using State = eval::InterpretedFunction::State; namespace { -struct CallFun { - map_fun_t function; - CallFun(map_fun_t function_in) : function(function_in) {} - double eval(double a) const { return function(a); } -}; - template <typename CT, typename Fun> void apply_fun_to_n(CT *dst, const CT *src, size_t n, const Fun &fun) { for (size_t i = 0; i < n; ++i) { - dst[i] = fun.eval(src[i]); + dst[i] = fun(src[i]); } } @@ -60,25 +57,13 @@ void my_simple_map_op(State &state, uint64_t param) { //----------------------------------------------------------------------------- -template <typename Fun, bool inplace> -struct MySimpleMapOp { - template <typename CT> - static auto get_fun() { return my_simple_map_op<CT,Fun,inplace>; } -}; - -template <typename Fun> -op_function my_select_2(ValueType::CellType ct, bool inplace) { - if (inplace) { - return select_1<MySimpleMapOp<Fun,true>>(ct); - } else { - return select_1<MySimpleMapOp<Fun,false>>(ct); +struct MyGetFun { + template <typename R1, typename R2, typename R3> static auto invoke() { + return my_simple_map_op<R1, R2, R3::value>; } -} +}; -op_function my_select(ValueType::CellType ct, bool inplace, map_fun_t fun_hint) { - (void) fun_hint; // ready for function inlining - return my_select_2<CallFun>(ct, inplace); -} +using MyTypify = TypifyValue<TypifyCellType,TypifyOp1,TypifyBool>; } // namespace vespalib::tensor::<unnamed> @@ -96,7 +81,7 @@ DenseSimpleMapFunction::~DenseSimpleMapFunction() = default; Instruction DenseSimpleMapFunction::compile_self(const TensorEngine &, Stash &) const { - auto op = my_select(result_type().cell_type(), inplace(), function()); + auto op = typify_invoke<3,MyTypify,MyGetFun>(result_type().cell_type(), function(), inplace()); static_assert(sizeof(uint64_t) == sizeof(function())); return Instruction(op, (uint64_t)(function())); } |