summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-06-13 05:43:59 +0000
committerArne Juul <arnej@verizonmedia.com>2020-06-13 12:56:14 +0000
commit5cb5db5d44fcd4fb271e3b32a7f3f0384e68e497 (patch)
tree2b59b89d53b39437ef35f9b137d5b46b337e8e92 /eval
parentb88fe841de3fd25b28392d1bb1b0d17f45e209f7 (diff)
typify_invoke instead of multiple select levels
* follow pattern suggested in matmul PR
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp47
1 files changed, 16 insertions, 31 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
index a8a896a893a..968308d69c9 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
@@ -76,35 +76,6 @@ void my_cblas_float_xw_product_op(eval::InterpretedFunction::State &state, uint6
state.pop_pop_push(state.stash.create<DenseTensorView>(self.result_type, TypedCells(dst_cells)));
}
-template <bool common_inner>
-struct MyXWProductOp {
- template <typename LCT, typename RCT>
- static auto invoke() { return my_xw_product_op<LCT,RCT,common_inner>; }
-};
-
-template <bool common_inner>
-eval::InterpretedFunction::op_function my_select2(CellType lct, CellType rct) {
- if (lct == rct) {
- if (lct == ValueType::CellType::DOUBLE) {
- return my_cblas_double_xw_product_op<common_inner>;
- }
- if (lct == ValueType::CellType::FLOAT) {
- return my_cblas_float_xw_product_op<common_inner>;
- }
- }
- using Target = MyXWProductOp<common_inner>;
- using MyTypify = eval::TypifyCellType;
- return typify_invoke<2,MyTypify,Target>(lct, rct);
-}
-
-eval::InterpretedFunction::op_function my_select(CellType lct, CellType rct, bool common_inner) {
- if (common_inner) {
- return my_select2<true>(lct, rct);
- } else {
- return my_select2<false>(lct, rct);
- }
-}
-
bool isDenseTensor(const ValueType &type, size_t d) {
return (type.is_dense() && (type.dimensions().size() == d));
}
@@ -134,6 +105,18 @@ const TensorFunction &createDenseXWProduct(const ValueType &res, const TensorFun
common_inner);
}
+struct MyXWProductOp {
+ template<typename R1, typename R2, typename R3> static auto invoke() {
+ if (std::is_same_v<R1,double> && std::is_same_v<R2,double>) {
+ return my_cblas_double_xw_product_op<R3::value>;
+ } else if (std::is_same_v<R1,float> && std::is_same_v<R2,float>) {
+ return my_cblas_float_xw_product_op<R3::value>;
+ } else {
+ return my_xw_product_op<R1, R2, R3::value>;
+ }
+ }
+};
+
} // namespace vespalib::tensor::<unnamed>
DenseXWProductFunction::Self::Self(const eval::ValueType &result_type_in,
@@ -162,8 +145,10 @@ eval::InterpretedFunction::Instruction
DenseXWProductFunction::compile_self(const TensorEngine &, Stash &stash) const
{
Self &self = stash.create<Self>(result_type(), _vector_size, _result_size);
- auto op = my_select(lhs().result_type().cell_type(),
- rhs().result_type().cell_type(), _common_inner);
+ using MyTypify = TypifyValue<eval::TypifyCellType,vespalib::TypifyBool>;
+ auto op = typify_invoke<3,MyTypify,MyXWProductOp>(lhs().result_type().cell_type(),
+ rhs().result_type().cell_type(),
+ _common_inner);
return eval::InterpretedFunction::Instruction(op, (uint64_t)(&self));
}