// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once #include "operation.h" #include #include #include namespace vespalib::eval::operation { //----------------------------------------------------------------------------- struct CallOp1 { op1_t my_op1; CallOp1(op1_t op1) : my_op1(op1) {} double operator()(double a) const { return my_op1(a); } }; template struct InlineOp1; template <> struct InlineOp1 { InlineOp1(op1_t) {} template constexpr auto operator()(A a) const { return (a * a * a); } }; template <> struct InlineOp1 { InlineOp1(op1_t) {} template constexpr auto operator()(A a) const { return exp(a); } }; template <> struct InlineOp1 { InlineOp1(op1_t) {} template constexpr auto operator()(A a) const { return (A{1}/a); } }; template <> struct InlineOp1 { InlineOp1(op1_t) {} template constexpr auto operator()(A a) const { return std::sqrt(a); } }; template <> struct InlineOp1 { InlineOp1(op1_t) {} template constexpr auto operator()(A a) const { return (a * a); } }; template <> struct InlineOp1 { InlineOp1(op1_t) {} template constexpr auto operator()(A a) const { return std::tanh(a); } }; struct TypifyOp1 { template using Result = TypifyResultType; template static decltype(auto) resolve(op1_t value, F &&f) { if (value == Cube::f) { return f(Result>()); } else if (value == Exp::f) { return f(Result>()); } else if (value == Inv::f) { return f(Result>()); } else if (value == Sqrt::f) { return f(Result>()); } else if (value == Square::f) { return f(Result>()); } else if (value == Tanh::f) { return f(Result>()); } else { return f(Result()); } } }; //----------------------------------------------------------------------------- struct CallOp2 { op2_t my_op2; CallOp2(op2_t op2) : my_op2(op2) {} op2_t get() const { return my_op2; } double operator()(double a, double b) const { return my_op2(a, b); } }; template struct SwapArgs2 { Op2 op2; SwapArgs2(op2_t op2_in) : op2(op2_in) {} template constexpr auto operator()(A a, B b) const { return op2(b, a); } }; template struct InlineOp2; template <> struct InlineOp2 { InlineOp2(op2_t) {} template constexpr auto operator()(A a, B b) const { return (a+b); } }; template <> struct InlineOp2
{ InlineOp2(op2_t) {} template constexpr auto operator()(A a, B b) const { return (a/b); } }; template <> struct InlineOp2 { InlineOp2(op2_t) {} template constexpr auto operator()(A a, B b) const { return (a*b); } }; template <> struct InlineOp2 { InlineOp2(op2_t) {} constexpr float operator()(float a, float b) const { return std::pow(a,b); } constexpr double operator()(float a, double b) const { return std::pow(a,b); } constexpr double operator()(double a, float b) const { return std::pow(a,b); } constexpr double operator()(double a, double b) const { return std::pow(a,b); } }; template <> struct InlineOp2 { InlineOp2(op2_t) {} template constexpr auto operator()(A a, B b) const { return (a-b); } }; struct TypifyOp2 { template using Result = TypifyResultType; template static decltype(auto) resolve(op2_t value, F &&f) { if (value == Add::f) { return f(Result>()); } else if (value == Div::f) { return f(Result>()); } else if (value == Mul::f) { return f(Result>()); } else if (value == Pow::f) { return f(Result>()); } else if (value == Sub::f) { return f(Result>()); } else { return f(Result()); } } }; //----------------------------------------------------------------------------- template void apply_op1_vec(A *dst, const A *src, size_t n, OP1 &&f) { for (size_t i = 0; i < n; ++i) { dst[i] = f(src[i]); } } template void apply_op2_vec_num(D *dst, const A *a, B b, size_t n, OP2 &&f) { for (size_t i = 0; i < n; ++i) { dst[i] = f(a[i], b); } } template void apply_op2_vec_vec(D *dst, const A *a, const B *b, size_t n, OP2 &&f) { for (size_t i = 0; i < n; ++i) { dst[i] = f(a[i], b[i]); } } //----------------------------------------------------------------------------- template struct DotProduct { static double apply(const LCT * lhs, const RCT * rhs, size_t count) { double result = 0.0; for (size_t i = 0; i < count; ++i) { result += lhs[i] * rhs[i]; } return result; } }; template <> struct DotProduct { static float apply(const float * lhs, const float * rhs, size_t count) { return cblas_sdot(count, lhs, 1, rhs, 1); } }; template <> struct DotProduct { static double apply(const double * lhs, const double * rhs, size_t count) { return cblas_ddot(count, lhs, 1, rhs, 1); } }; //----------------------------------------------------------------------------- }