summary | refs | log | tree | commit | diff | stats
path: root/eval
diff options
context:
space:
mode:
author: Arne H Juul <arnej27959@users.noreply.github.com> 2021-03-15 11:35:37 +0100
committer: GitHub <noreply@github.com> 2021-03-15 11:35:37 +0100
commit: 61338b390a89aa9d7c5f9d7c3767e847c5a8a1c8 (patch)
tree: 7adf037a290daf88322e025af23b03afbe7ff3f2 /eval
parent: ebbb3bc331924854ba1b6f2ee7c871ee0170f00a (diff)
parent: ee41dcf75bf47d2686f14c7e256e73b9e1cc09fb (diff)
Merge pull request #16946 from vespa-engine/arnej/unify-via-cell-meta
use CellMeta::join and reduce to compute result cell type
Diffstat (limited to 'eval')
-rw-r--r-- eval/src/vespa/eval/eval/cell_type.h | 9
-rw-r--r-- eval/src/vespa/eval/instruction/dense_matmul_function.cpp | 39
-rw-r--r-- eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp | 22
-rw-r--r-- eval/src/vespa/eval/instruction/dense_xw_product_function.cpp | 36
-rw-r--r-- eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp | 24
5 files changed, 73 insertions, 57 deletions
diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h
index 9036188bad3..4eec639478d 100644
--- a/eval/src/vespa/eval/eval/cell_type.h
+++ b/eval/src/vespa/eval/eval/cell_type.h
@@ -124,6 +124,9 @@ struct CellMeta {
static constexpr CellMeta reduce(CellType input_cell_type, bool output_is_scalar) {
return normalize(input_cell_type, output_is_scalar).decay();
}
+ constexpr CellMeta reduce(bool output_is_scalar) const {
+ return CellMeta::reduce(cell_type, output_is_scalar);
+ }
static constexpr CellMeta join(CellMeta a, CellMeta b) { return unify(a, b).decay(); }
static constexpr CellMeta merge(CellMeta a, CellMeta b) { return unify(a, b).decay(); }
static constexpr CellMeta concat(CellMeta a, CellMeta b) { return unify(a, b); }
@@ -133,12 +136,6 @@ struct CellMeta {
constexpr CellMeta rename() const { return self(); }
};
-template <typename A, typename B> constexpr auto unify_cell_types() {
- constexpr CellMeta a(get_cell_type<A>(), false);
- constexpr CellMeta b(get_cell_type<B>(), false);
- return get_cell_value<CellMeta::unify(a, b).cell_type>();
-}
-
struct TypifyCellType {
template <typename T> using Result = TypifyResultType<T>;
template <typename F> static decltype(auto) resolve(CellType value, F &&f) {
diff --git a/eval/src/vespa/eval/instruction/dense_matmul_function.cpp b/eval/src/vespa/eval/instruction/dense_matmul_function.cpp
index 11ad646d0f5..509a25d28a2 100644
--- a/eval/src/vespa/eval/instruction/dense_matmul_function.cpp
+++ b/eval/src/vespa/eval/instruction/dense_matmul_function.cpp
@@ -14,9 +14,9 @@ using namespace operation;
namespace {
-template <typename LCT, typename RCT, bool lhs_common_inner, bool rhs_common_inner>
-double my_dot_product(const LCT *lhs, const RCT *rhs, size_t lhs_size, size_t common_size, size_t rhs_size) {
- double result = 0.0;
+template <typename LCT, typename RCT, typename OCT, bool lhs_common_inner, bool rhs_common_inner>
+OCT my_dot_product(const LCT *lhs, const RCT *rhs, size_t lhs_size, size_t common_size, size_t rhs_size) {
+ OCT result = 0.0;
for (size_t i = 0; i < common_size; ++i) {
result += ((*lhs) * (*rhs));
lhs += (lhs_common_inner ? 1 : lhs_size);
@@ -25,10 +25,9 @@ double my_dot_product(const LCT *lhs, const RCT *rhs, size_t lhs_size, size_t co
return result;
}
-template <typename LCT, typename RCT, bool lhs_common_inner, bool rhs_common_inner>
+template <typename LCT, typename RCT, typename OCT, bool lhs_common_inner, bool rhs_common_inner>
void my_matmul_op(InterpretedFunction::State &state, uint64_t param) {
const DenseMatMulFunction::Self &self = unwrap_param<DenseMatMulFunction::Self>(param);
- using OCT = decltype(unify_cell_types<LCT,RCT>());
auto lhs_cells = state.peek(1).cells().typify<LCT>();
auto rhs_cells = state.peek(0).cells().typify<RCT>();
auto dst_cells = state.stash.create_uninitialized_array<OCT>(self.lhs_size * self.rhs_size);
@@ -37,7 +36,8 @@ void my_matmul_op(InterpretedFunction::State &state, uint64_t param) {
for (size_t i = 0; i < self.lhs_size; ++i) {
const RCT *rhs = rhs_cells.cbegin();
for (size_t j = 0; j < self.rhs_size; ++j) {
- *dst++ = my_dot_product<LCT,RCT,lhs_common_inner,rhs_common_inner>(lhs, rhs, self.lhs_size, self.common_size, self.rhs_size);
+ *dst++ = my_dot_product<LCT,RCT,OCT,lhs_common_inner,rhs_common_inner>(lhs, rhs,
+ self.lhs_size, self.common_size, self.rhs_size);
rhs += (rhs_common_inner ? self.common_size : 1);
}
lhs += (lhs_common_inner ? self.common_size : 1);
@@ -112,14 +112,18 @@ const TensorFunction &create_matmul(const TensorFunction &a, const TensorFunctio
}
}
-struct MyGetFun {
- template<typename R1, typename R2, typename R3, typename R4> static auto invoke() {
- if (std::is_same_v<R1,double> && std::is_same_v<R2,double>) {
- return my_cblas_double_matmul_op<R3::value, R4::value>;
- } else if (std::is_same_v<R1,float> && std::is_same_v<R2,float>) {
- return my_cblas_float_matmul_op<R3::value, R4::value>;
+struct SelectDenseMatmul {
+ template<typename LCM, typename RCM, typename LhsCommonInner, typename RhsCommonInner> static auto invoke() {
+ constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value).reduce(false);
+ using LCT = CellValueType<LCM::value.cell_type>;
+ using RCT = CellValueType<RCM::value.cell_type>;
+ using OCT = CellValueType<ocm.cell_type>;
+ if (std::is_same_v<LCT,double> && std::is_same_v<RCT,double>) {
+ return my_cblas_double_matmul_op<LhsCommonInner::value, RhsCommonInner::value>;
+ } else if (std::is_same_v<LCT,float> && std::is_same_v<RCT,float>) {
+ return my_cblas_float_matmul_op<LhsCommonInner::value, RhsCommonInner::value>;
} else {
- return my_matmul_op<R1, R2, R3::value, R4::value>;
+ return my_matmul_op<LCT, RCT, OCT, LhsCommonInner::value, RhsCommonInner::value>;
}
}
};
@@ -161,11 +165,12 @@ DenseMatMulFunction::~DenseMatMulFunction() = default;
InterpretedFunction::Instruction
DenseMatMulFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const
{
- using MyTypify = TypifyValue<TypifyCellType,TypifyBool>;
+ using MyTypify = TypifyValue<TypifyCellMeta,TypifyBool>;
Self &self = stash.create<Self>(result_type(), _lhs_size, _common_size, _rhs_size);
- auto op = typify_invoke<4,MyTypify,MyGetFun>(
- lhs().result_type().cell_type(), rhs().result_type().cell_type(),
- _lhs_common_inner, _rhs_common_inner);
+ auto op = typify_invoke<4,MyTypify,SelectDenseMatmul>(
+ lhs().result_type().cell_meta().not_scalar(),
+ rhs().result_type().cell_meta().not_scalar(),
+ _lhs_common_inner, _rhs_common_inner);
return InterpretedFunction::Instruction(op, wrap_param<DenseMatMulFunction::Self>(self));
}
diff --git a/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp b/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp
index 55c6760a391..2815da40d0a 100644
--- a/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp
+++ b/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp
@@ -32,11 +32,10 @@ struct ExpandParams {
: result_type(result_type_in), result_size(result_size_in), function(function_in) {}
};
-template <typename LCT, typename RCT, typename Fun, bool rhs_inner>
+template <typename LCT, typename RCT, typename DCT, typename Fun, bool rhs_inner>
void my_simple_expand_op(State &state, uint64_t param) {
using ICT = typename std::conditional<rhs_inner,RCT,LCT>::type;
using OCT = typename std::conditional<rhs_inner,LCT,RCT>::type;
- using DCT = decltype(unify_cell_types<LCT,RCT>());
using OP = typename std::conditional<rhs_inner,SwapArgs2<Fun>,Fun>::type;
const ExpandParams &params = unwrap_param<ExpandParams>(param);
OP my_op(params.function);
@@ -53,13 +52,18 @@ void my_simple_expand_op(State &state, uint64_t param) {
//-----------------------------------------------------------------------------
-struct MyGetFun {
- template <typename R1, typename R2, typename R3, typename R4> static auto invoke() {
- return my_simple_expand_op<R1, R2, R3, R4::value>;
+struct SelectDenseSimpleExpand {
+ template<typename LCM, typename RCM, typename Fun, typename RhsInner>
+ static auto invoke() {
+ constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value);
+ using LCT = CellValueType<LCM::value.cell_type>;
+ using RCT = CellValueType<RCM::value.cell_type>;
+ using OCT = CellValueType<ocm.cell_type>;
+ return my_simple_expand_op<LCT, RCT, OCT, Fun, RhsInner::value>;
}
};
-using MyTypify = TypifyValue<TypifyCellType,TypifyOp2,TypifyBool>;
+using MyTypify = TypifyValue<TypifyCellMeta,TypifyOp2,TypifyBool>;
//-----------------------------------------------------------------------------
@@ -98,9 +102,9 @@ DenseSimpleExpandFunction::compile_self(const ValueBuilderFactory &, Stash &stas
{
size_t result_size = result_type().dense_subspace_size();
const ExpandParams &params = stash.create<ExpandParams>(result_type(), result_size, function());
- auto op = typify_invoke<4,MyTypify,MyGetFun>(lhs().result_type().cell_type(),
- rhs().result_type().cell_type(),
- function(), (_inner == Inner::RHS));
+ auto op = typify_invoke<4,MyTypify,SelectDenseSimpleExpand>(lhs().result_type().cell_meta().not_scalar(),
+ rhs().result_type().cell_meta().not_scalar(),
+ function(), (_inner == Inner::RHS));
return Instruction(op, wrap_param<ExpandParams>(params));
}
diff --git a/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp b/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp
index b68a3a87ef1..371a5767382 100644
--- a/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp
+++ b/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp
@@ -15,9 +15,9 @@ using namespace operation;
namespace {
-template <typename LCT, typename RCT, bool common_inner>
-double my_dot_product(const LCT *lhs, const RCT *rhs, size_t vector_size, size_t result_size) {
- double result = 0.0;
+template <typename LCT, typename RCT, typename OCT, bool common_inner>
+OCT my_dot_product(const LCT *lhs, const RCT *rhs, size_t vector_size, size_t result_size) {
+ OCT result = 0.0;
for (size_t i = 0; i < vector_size; ++i) {
result += ((*lhs) * (*rhs));
++lhs;
@@ -26,17 +26,16 @@ double my_dot_product(const LCT *lhs, const RCT *rhs, size_t vector_size, size_t
return result;
}
-template <typename LCT, typename RCT, bool common_inner>
+template <typename LCT, typename RCT, typename OCT, bool common_inner>
void my_xw_product_op(InterpretedFunction::State &state, uint64_t param) {
const DenseXWProductFunction::Self &self = unwrap_param<DenseXWProductFunction::Self>(param);
- using OCT = decltype(unify_cell_types<LCT,RCT>());
auto vector_cells = state.peek(1).cells().typify<LCT>();
auto matrix_cells = state.peek(0).cells().typify<RCT>();
auto dst_cells = state.stash.create_uninitialized_array<OCT>(self.result_size);
OCT *dst = dst_cells.begin();
const RCT *matrix = matrix_cells.cbegin();
for (size_t i = 0; i < self.result_size; ++i) {
- *dst++ = my_dot_product<LCT,RCT,common_inner>(vector_cells.cbegin(), matrix, self.vector_size, self.result_size);
+ *dst++ = my_dot_product<LCT,RCT,OCT,common_inner>(vector_cells.cbegin(), matrix, self.vector_size, self.result_size);
matrix += (common_inner ? self.vector_size : 1);
}
state.pop_pop_push(state.stash.create<DenseValueView>(self.result_type, TypedCells(dst_cells)));
@@ -100,13 +99,19 @@ const TensorFunction &createDenseXWProduct(const ValueType &res, const TensorFun
}
struct MyXWProductOp {
- template<typename R1, typename R2, typename R3> static auto invoke() {
- if (std::is_same_v<R1,double> && std::is_same_v<R2,double>) {
- return my_cblas_double_xw_product_op<R3::value>;
- } else if (std::is_same_v<R1,float> && std::is_same_v<R2,float>) {
- return my_cblas_float_xw_product_op<R3::value>;
+ template<typename LCM, typename RCM, typename CommonInner> static auto invoke() {
+ constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value).reduce(false);
+ using LCT = CellValueType<LCM::value.cell_type>;
+ using RCT = CellValueType<RCM::value.cell_type>;
+ using OCT = CellValueType<ocm.cell_type>;
+ if (std::is_same_v<LCT,double> && std::is_same_v<RCT,double>) {
+ assert((std::is_same_v<OCT,double>));
+ return my_cblas_double_xw_product_op<CommonInner::value>;
+ } else if (std::is_same_v<LCT,float> && std::is_same_v<RCT,float>) {
+ assert((std::is_same_v<OCT,float>));
+ return my_cblas_float_xw_product_op<CommonInner::value>;
} else {
- return my_xw_product_op<R1, R2, R3::value>;
+ return my_xw_product_op<LCT, RCT, OCT, CommonInner::value>;
}
}
};
@@ -139,9 +144,10 @@ InterpretedFunction::Instruction
DenseXWProductFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const
{
Self &self = stash.create<Self>(result_type(), _vector_size, _result_size);
- using MyTypify = TypifyValue<TypifyCellType,vespalib::TypifyBool>;
- auto op = typify_invoke<3,MyTypify,MyXWProductOp>(lhs().result_type().cell_type(),
- rhs().result_type().cell_type(),
+ assert(self.result_type.cell_meta().is_scalar == false);
+ using MyTypify = TypifyValue<TypifyCellMeta,vespalib::TypifyBool>;
+ auto op = typify_invoke<3,MyTypify,MyXWProductOp>(lhs().result_type().cell_meta().not_scalar(),
+ rhs().result_type().cell_meta().not_scalar(),
_common_inner);
return InterpretedFunction::Instruction(op, wrap_param<DenseXWProductFunction::Self>(self));
}
diff --git a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
index d487ab42d26..21c6f945609 100644
--- a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
+++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
@@ -55,11 +55,10 @@ ArrayRef<OCT> make_dst_cells(ConstArrayRef<PCT> pri_cells, Stash &stash) {
}
}
-template <typename LCT, typename RCT, typename Fun, bool swap, Overlap overlap, bool pri_mut>
+template <typename LCT, typename RCT, typename OCT, typename Fun, bool swap, Overlap overlap, bool pri_mut>
void my_simple_join_op(State &state, uint64_t param) {
using PCT = typename std::conditional<swap,RCT,LCT>::type;
using SCT = typename std::conditional<swap,LCT,RCT>::type;
- using OCT = decltype(unify_cell_types<LCT,RCT>());
using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type;
const JoinParams &params = unwrap_param<JoinParams>(param);
OP my_op(params.function);
@@ -94,13 +93,18 @@ void my_simple_join_op(State &state, uint64_t param) {
//-----------------------------------------------------------------------------
-struct MyGetFun {
- template <typename R1, typename R2, typename R3, typename R4, typename R5, typename R6> static auto invoke() {
- return my_simple_join_op<R1, R2, R3, R4::value, R5::value, R6::value>;
+struct SelectMixedSimpleJoin {
+ template<typename LCM, typename RCM, typename R3, typename R4, typename R5, typename R6>
+ static auto invoke() {
+ constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value);
+ using LCT = CellValueType<LCM::value.cell_type>;
+ using RCT = CellValueType<RCM::value.cell_type>;
+ using OCT = CellValueType<ocm.cell_type>;
+ return my_simple_join_op<LCT, RCT, OCT, R3, R4::value, R5::value, R6::value>;
}
};
-using MyTypify = TypifyValue<TypifyCellType,TypifyOp2,TypifyBool,TypifyOverlap>;
+using MyTypify = TypifyValue<TypifyCellMeta,TypifyOp2,TypifyBool,TypifyOverlap>;
//-----------------------------------------------------------------------------
@@ -197,10 +201,10 @@ Instruction
MixedSimpleJoinFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const
{
const JoinParams &params = stash.create<JoinParams>(result_type(), factor(), function());
- auto op = typify_invoke<6,MyTypify,MyGetFun>(lhs().result_type().cell_type(),
- rhs().result_type().cell_type(),
- function(), (_primary == Primary::RHS),
- _overlap, primary_is_mutable());
+ auto op = typify_invoke<6,MyTypify,SelectMixedSimpleJoin>(lhs().result_type().cell_meta().not_scalar(),
+ rhs().result_type().cell_meta().not_scalar(),
+ function(), (_primary == Primary::RHS),
+ _overlap, primary_is_mutable());
return Instruction(op, wrap_param<JoinParams>(params));
}