summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-06-11 12:11:58 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-06-11 12:11:58 +0000
commita875850b3ad75b8077bd45d3fac64651dfad7ca3 (patch)
tree1fca1a2923f4c300cb1ecaaa2c6c03c9d77b2989 /eval
parent4851139ed8e0fabb264ff5e69deb5f364f237160 (diff)
use common code for simple map and number join
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_number_join_function.cpp73
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_simple_map_function.cpp35
2 files changed, 22 insertions, 86 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_number_join_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_number_join_function.cpp
index 3f48607cef4..a28c8150d59 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_number_join_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_number_join_function.cpp
@@ -2,8 +2,10 @@
#include "dense_number_join_function.h"
#include "dense_tensor_view.h"
+#include <vespa/vespalib/util/typify.h>
#include <vespa/eval/eval/value.h>
#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/eval/inline_operation.h>
namespace vespalib::tensor {
@@ -13,6 +15,7 @@ using eval::Value;
using eval::ValueType;
using eval::TensorFunction;
using eval::TensorEngine;
+using eval::TypifyCellType;
using eval::as;
using namespace eval::operation;
@@ -26,37 +29,10 @@ using State = eval::InterpretedFunction::State;
namespace {
-struct CallFun {
- join_fun_t function;
- CallFun(join_fun_t function_in) : function(function_in) {}
- double eval(double a, double b) const { return function(a, b); }
-};
-
-struct AddFun {
- AddFun(join_fun_t) {}
- template <typename A, typename B>
- auto eval(A a, B b) const { return (a + b); }
-};
-
-struct MulFun {
- MulFun(join_fun_t) {}
- template <typename A, typename B>
- auto eval(A a, B b) const { return (a * b); }
-};
-
-// needed for asymmetric operations like Sub and Div
-template <typename Fun>
-struct SwapFun {
- Fun fun;
- SwapFun(join_fun_t function_in) : fun(function_in) {}
- template <typename A, typename B>
- auto eval(A a, B b) const { return fun.eval(b, a); }
-};
-
template <typename CT, typename Fun>
void apply_fun_1_to_n(CT *dst, const CT *pri, CT sec, size_t n, const Fun &fun) {
for (size_t i = 0; i < n; ++i) {
- dst[i] = fun.eval(pri[i], sec);
+ dst[i] = fun(pri[i], sec);
}
}
@@ -71,7 +47,7 @@ ArrayRef<CT> make_dst_cells(ConstArrayRef<CT> src_cells, Stash &stash) {
template <typename CT, typename Fun, bool inplace, bool swap>
void my_number_join_op(State &state, uint64_t param) {
- using OP = typename std::conditional<swap,SwapFun<Fun>,Fun>::type;
+ using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type;
OP my_op((join_fun_t)param);
const Value &tensor = state.peek(swap ? 0 : 1);
CT number = state.peek(swap ? 1 : 0).as_double();
@@ -87,39 +63,13 @@ void my_number_join_op(State &state, uint64_t param) {
//-----------------------------------------------------------------------------
-template <typename Fun, bool inplace, bool swap>
-struct MyNumberJoinOp {
- template <typename CT>
- static auto get_fun() { return my_number_join_op<CT,Fun,inplace,swap>; }
-};
-
-template <typename Fun, bool inplace>
-op_function my_select_3(ValueType::CellType ct, Primary primary) {
- switch (primary) {
- case Primary::LHS: return select_1<MyNumberJoinOp<Fun,inplace,false>>(ct);
- case Primary::RHS: return select_1<MyNumberJoinOp<Fun,inplace,true>>(ct);
+struct MyGetFun {
+ template <typename R1, typename R2, typename R3, typename R4> static auto invoke() {
+ return my_number_join_op<R1, R2, R3::value, R4::value>;
}
- abort();
-}
-
-template <typename Fun>
-op_function my_select_2(ValueType::CellType ct, Primary primary, bool inplace) {
- if (inplace) {
- return my_select_3<Fun, true>(ct, primary);
- } else {
- return my_select_3<Fun, false>(ct, primary);
- }
-}
+};
-op_function my_select(ValueType::CellType ct, Primary primary, bool inplace, join_fun_t fun_hint) {
- if (fun_hint == Add::f) {
- return my_select_2<AddFun>(ct, primary, inplace);
- } else if (fun_hint == Mul::f) {
- return my_select_2<MulFun>(ct, primary, inplace);
- } else {
- return my_select_2<CallFun>(ct, primary, inplace);
- }
-}
+using MyTypify = TypifyValue<TypifyCellType,TypifyOp2,TypifyBool>;
bool is_dense(const TensorFunction &tf) { return tf.result_type().is_dense(); }
bool is_double(const TensorFunction &tf) { return tf.result_type().is_double(); }
@@ -154,7 +104,8 @@ DenseNumberJoinFunction::inplace() const
Instruction
DenseNumberJoinFunction::compile_self(const TensorEngine &, Stash &) const
{
- auto op = my_select(result_type().cell_type(), _primary, inplace(), function());
+ auto op = typify_invoke<4,MyTypify,MyGetFun>(result_type().cell_type(), function(),
+ inplace(), (_primary == Primary::RHS));
static_assert(sizeof(uint64_t) == sizeof(function()));
return Instruction(op, (uint64_t)(function()));
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_simple_map_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_simple_map_function.cpp
index 910e8296afe..784d356da39 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_simple_map_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_simple_map_function.cpp
@@ -2,8 +2,10 @@
#include "dense_simple_map_function.h"
#include "dense_tensor_view.h"
+#include <vespa/vespalib/util/typify.h>
#include <vespa/eval/eval/value.h>
#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/eval/inline_operation.h>
namespace vespalib::tensor {
@@ -13,6 +15,7 @@ using eval::Value;
using eval::ValueType;
using eval::TensorFunction;
using eval::TensorEngine;
+using eval::TypifyCellType;
using eval::as;
using namespace eval::operation;
@@ -24,16 +27,10 @@ using State = eval::InterpretedFunction::State;
namespace {
-struct CallFun {
- map_fun_t function;
- CallFun(map_fun_t function_in) : function(function_in) {}
- double eval(double a) const { return function(a); }
-};
-
template <typename CT, typename Fun>
void apply_fun_to_n(CT *dst, const CT *src, size_t n, const Fun &fun) {
for (size_t i = 0; i < n; ++i) {
- dst[i] = fun.eval(src[i]);
+ dst[i] = fun(src[i]);
}
}
@@ -60,25 +57,13 @@ void my_simple_map_op(State &state, uint64_t param) {
//-----------------------------------------------------------------------------
-template <typename Fun, bool inplace>
-struct MySimpleMapOp {
- template <typename CT>
- static auto get_fun() { return my_simple_map_op<CT,Fun,inplace>; }
-};
-
-template <typename Fun>
-op_function my_select_2(ValueType::CellType ct, bool inplace) {
- if (inplace) {
- return select_1<MySimpleMapOp<Fun,true>>(ct);
- } else {
- return select_1<MySimpleMapOp<Fun,false>>(ct);
+struct MyGetFun {
+ template <typename R1, typename R2, typename R3> static auto invoke() {
+ return my_simple_map_op<R1, R2, R3::value>;
}
-}
+};
-op_function my_select(ValueType::CellType ct, bool inplace, map_fun_t fun_hint) {
- (void) fun_hint; // ready for function inlining
- return my_select_2<CallFun>(ct, inplace);
-}
+using MyTypify = TypifyValue<TypifyCellType,TypifyOp1,TypifyBool>;
} // namespace vespalib::tensor::<unnamed>
@@ -96,7 +81,7 @@ DenseSimpleMapFunction::~DenseSimpleMapFunction() = default;
Instruction
DenseSimpleMapFunction::compile_self(const TensorEngine &, Stash &) const
{
- auto op = my_select(result_type().cell_type(), inplace(), function());
+ auto op = typify_invoke<3,MyTypify,MyGetFun>(result_type().cell_type(), function(), inplace());
static_assert(sizeof(uint64_t) == sizeof(function()));
return Instruction(op, (uint64_t)(function()));
}