diff options
-rw-r--r-- | eval/src/vespa/eval/eval/operation.cpp | 129 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/operation.h | 86 |
2 files changed, 134 insertions, 81 deletions
diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp index 2fb8d4e1ac8..688bcef1d0d 100644 --- a/eval/src/vespa/eval/eval/operation.cpp +++ b/eval/src/vespa/eval/eval/operation.cpp @@ -69,6 +69,10 @@ template <typename T> void Op1<T>::accept(OperationVisitor &visitor) const { visitor.visit(static_cast<const T&>(*this)); } +template <typename T> double Op1<T>::eval(double a) const { + return T::f(a); +} + template <typename T> void Op2<T>::accept(OperationVisitor &visitor) const { visitor.visit(static_cast<const T&>(*this)); } @@ -77,48 +81,93 @@ template <typename T> std::unique_ptr<BinaryOperation> Op2<T>::clone() const { return std::make_unique<T>(); } +template <typename T> double Op2<T>::eval(double a, double b) const { + return T::f(a, b); +} + namespace operation { -double Neg::eval(double a) const { return -a; } -double Not::eval(double a) const { return (a != 0.0) ? 0.0 : 1.0; } -double Add::eval(double a, double b) const { return (a + b); } -double Sub::eval(double a, double b) const { return (a - b); } -double Mul::eval(double a, double b) const { return (a * b); } -double Div::eval(double a, double b) const { return (a / b); } -double Mod::eval(double a, double b) const { return std::fmod(a, b); } -double Pow::eval(double a, double b) const { return std::pow(a, b); } -double Equal::eval(double a, double b) const { return (a == b) ? 1.0 : 0.0; } -double NotEqual::eval(double a, double b) const { return (a != b) ? 1.0 : 0.0; } -double Approx::eval(double a, double b) const { return approx_equal(a, b); } -double Less::eval(double a, double b) const { return (a < b) ? 1.0 : 0.0; } -double LessEqual::eval(double a, double b) const { return (a <= b) ? 1.0 : 0.0; } -double Greater::eval(double a, double b) const { return (a > b) ? 1.0 : 0.0; } -double GreaterEqual::eval(double a, double b) const { return (a >= b) ? 1.0 : 0.0; } -double And::eval(double a, double b) const { return ((a != 0.0) && (b != 0.0)) ? 1.0 : 0.0; } -double Or::eval(double a, double b) const { return ((a != 0.0) || (b != 0.0)) ? 1.0 : 0.0; } -double Cos::eval(double a) const { return std::cos(a); } -double Sin::eval(double a) const { return std::sin(a); } -double Tan::eval(double a) const { return std::tan(a); } -double Cosh::eval(double a) const { return std::cosh(a); } -double Sinh::eval(double a) const { return std::sinh(a); } -double Tanh::eval(double a) const { return std::tanh(a); } -double Acos::eval(double a) const { return std::acos(a); } -double Asin::eval(double a) const { return std::asin(a); } -double Atan::eval(double a) const { return std::atan(a); } -double Exp::eval(double a) const { return std::exp(a); } -double Log10::eval(double a) const { return std::log10(a); } -double Log::eval(double a) const { return std::log(a); } -double Sqrt::eval(double a) const { return std::sqrt(a); } -double Ceil::eval(double a) const { return std::ceil(a); } -double Fabs::eval(double a) const { return std::fabs(a); } -double Floor::eval(double a) const { return std::floor(a); } -double Atan2::eval(double a, double b) const { return std::atan2(a, b); } -double Ldexp::eval(double a, double b) const { return std::ldexp(a, b); } -double Min::eval(double a, double b) const { return std::min(a, b); } -double Max::eval(double a, double b) const { return std::max(a, b); } -double IsNan::eval(double a) const { return std::isnan(a) ? 1.0 : 0.0; } -double Relu::eval(double a) const { return std::max(a, 0.0); } -double Sigmoid::eval(double a) const { return 1.0 / (1.0 + std::exp(-1.0 * a)); } +double Neg::f(double a) { return -a; } +double Not::f(double a) { return (a != 0.0) ? 0.0 : 1.0; } +double Add::f(double a, double b) { return (a + b); } +double Sub::f(double a, double b) { return (a - b); } +double Mul::f(double a, double b) { return (a * b); } +double Div::f(double a, double b) { return (a / b); } +double Mod::f(double a, double b) { return std::fmod(a, b); } +double Pow::f(double a, double b) { return std::pow(a, b); } +double Equal::f(double a, double b) { return (a == b) ? 1.0 : 0.0; } +double NotEqual::f(double a, double b) { return (a != b) ? 1.0 : 0.0; } +double Approx::f(double a, double b) { return approx_equal(a, b); } +double Less::f(double a, double b) { return (a < b) ? 1.0 : 0.0; } +double LessEqual::f(double a, double b) { return (a <= b) ? 1.0 : 0.0; } +double Greater::f(double a, double b) { return (a > b) ? 1.0 : 0.0; } +double GreaterEqual::f(double a, double b) { return (a >= b) ? 1.0 : 0.0; } +double And::f(double a, double b) { return ((a != 0.0) && (b != 0.0)) ? 1.0 : 0.0; } +double Or::f(double a, double b) { return ((a != 0.0) || (b != 0.0)) ? 1.0 : 0.0; } +double Cos::f(double a) { return std::cos(a); } +double Sin::f(double a) { return std::sin(a); } +double Tan::f(double a) { return std::tan(a); } +double Cosh::f(double a) { return std::cosh(a); } +double Sinh::f(double a) { return std::sinh(a); } +double Tanh::f(double a) { return std::tanh(a); } +double Acos::f(double a) { return std::acos(a); } +double Asin::f(double a) { return std::asin(a); } +double Atan::f(double a) { return std::atan(a); } +double Exp::f(double a) { return std::exp(a); } +double Log10::f(double a) { return std::log10(a); } +double Log::f(double a) { return std::log(a); } +double Sqrt::f(double a) { return std::sqrt(a); } +double Ceil::f(double a) { return std::ceil(a); } +double Fabs::f(double a) { return std::fabs(a); } +double Floor::f(double a) { return std::floor(a); } +double Atan2::f(double a, double b) { return std::atan2(a, b); } +double Ldexp::f(double a, double b) { return std::ldexp(a, b); } +double Min::f(double a, double b) { return std::min(a, b); } +double Max::f(double a, double b) { return std::max(a, b); } +double IsNan::f(double a) { return std::isnan(a) ? 1.0 : 0.0; } +double Relu::f(double a) { return std::max(a, 0.0); } +double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } } // namespace vespalib::eval::operation +template struct Op1<operation::Neg>; +template struct Op1<operation::Not>; +template struct Op2<operation::Add>; +template struct Op2<operation::Sub>; +template struct Op2<operation::Mul>; +template struct Op2<operation::Div>; +template struct Op2<operation::Mod>; +template struct Op2<operation::Pow>; +template struct Op2<operation::Equal>; +template struct Op2<operation::NotEqual>; +template struct Op2<operation::Approx>; +template struct Op2<operation::Less>; +template struct Op2<operation::LessEqual>; +template struct Op2<operation::Greater>; +template struct Op2<operation::GreaterEqual>; +template struct Op2<operation::And>; +template struct Op2<operation::Or>; +template struct Op1<operation::Cos>; +template struct Op1<operation::Sin>; +template struct Op1<operation::Tan>; +template struct Op1<operation::Cosh>; +template struct Op1<operation::Sinh>; +template struct Op1<operation::Tanh>; +template struct Op1<operation::Acos>; +template struct Op1<operation::Asin>; +template struct Op1<operation::Atan>; +template struct Op1<operation::Exp>; +template struct Op1<operation::Log10>; +template struct Op1<operation::Log>; +template struct Op1<operation::Sqrt>; +template struct Op1<operation::Ceil>; +template struct Op1<operation::Fabs>; +template struct Op1<operation::Floor>; +template struct Op2<operation::Atan2>; +template struct Op2<operation::Ldexp>; +template struct Op2<operation::Min>; +template struct Op2<operation::Max>; +template struct Op1<operation::IsNan>; +template struct Op1<operation::Relu>; +template struct Op1<operation::Sigmoid>; + } // namespace vespalib::eval } // namespace vespalib diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h index 52ad9047dd6..456790f9180 100644 --- a/eval/src/vespa/eval/eval/operation.h +++ b/eval/src/vespa/eval/eval/operation.h @@ -91,12 +91,14 @@ public: template <typename T> struct Op1 : UnaryOperation { virtual void accept(OperationVisitor &visitor) const override; + virtual double eval(double a) const override; }; template <typename T> struct Op2 : BinaryOperation { virtual void accept(OperationVisitor &visitor) const override; virtual std::unique_ptr<BinaryOperation> clone() const override; + virtual double eval(double a, double b) const final override; }; //----------------------------------------------------------------------------- @@ -105,7 +107,9 @@ struct Op2 : BinaryOperation { * A non-trivial custom unary operation. Typically used for closures * and lambdas. **/ -struct CustomUnaryOperation : Op1<CustomUnaryOperation> {}; +struct CustomUnaryOperation : Op1<CustomUnaryOperation> { + static double f(double); +}; //----------------------------------------------------------------------------- @@ -140,46 +144,46 @@ public: //----------------------------------------------------------------------------- namespace operation { -struct Neg : Op1<Neg> { double eval(double a) const override; }; -struct Not : Op1<Not> { double eval(double a) const override; }; -struct Add : Op2<Add> { double eval(double a, double b) const override; }; -struct Sub : Op2<Sub> { double eval(double a, double b) const override; }; -struct Mul : Op2<Mul> { double eval(double a, double b) const override; }; -struct Div : Op2<Div> { double eval(double a, double b) const override; }; -struct Mod : Op2<Mod> { double eval(double a, double b) const override; }; -struct Pow : Op2<Pow> { double eval(double a, double b) const override; }; -struct Equal : Op2<Equal> { double eval(double a, double b) const override; }; -struct NotEqual : Op2<NotEqual> { double eval(double a, double b) const override; }; -struct Approx : Op2<Approx> { double eval(double a, double b) const override; }; -struct Less : Op2<Less> { double eval(double a, double b) const override; }; -struct LessEqual : Op2<LessEqual> { double eval(double a, double b) const override; }; -struct Greater : Op2<Greater> { double eval(double a, double b) const override; }; -struct GreaterEqual : Op2<GreaterEqual> { double eval(double a, double b) const override; }; -struct And : Op2<And> { double eval(double a, double b) const override; }; -struct Or : Op2<Or> { double eval(double a, double b) const override; }; -struct Cos : Op1<Cos> { double eval(double a) const override; }; -struct Sin : Op1<Sin> { double eval(double a) const override; }; -struct Tan : Op1<Tan> { double eval(double a) const override; }; -struct Cosh : Op1<Cosh> { double eval(double a) const override; }; -struct Sinh : Op1<Sinh> { double eval(double a) const override; }; -struct Tanh : Op1<Tanh> { double eval(double a) const override; }; -struct Acos : Op1<Acos> { double eval(double a) const override; }; -struct Asin : Op1<Asin> { double eval(double a) const override; }; -struct Atan : Op1<Atan> { double eval(double a) const override; }; -struct Exp : Op1<Exp> { double eval(double a) const override; }; -struct Log10 : Op1<Log10> { double eval(double a) const override; }; -struct Log : Op1<Log> { double eval(double a) const override; }; -struct Sqrt : Op1<Sqrt> { double eval(double a) const override; }; -struct Ceil : Op1<Ceil> { double eval(double a) const override; }; -struct Fabs : Op1<Fabs> { double eval(double a) const override; }; -struct Floor : Op1<Floor> { double eval(double a) const override; }; -struct Atan2 : Op2<Atan2> { double eval(double a, double b) const override; }; -struct Ldexp : Op2<Ldexp> { double eval(double a, double b) const override; }; -struct Min : Op2<Min> { double eval(double a, double b) const override; }; -struct Max : Op2<Max> { double eval(double a, double b) const override; }; -struct IsNan : Op1<IsNan> { double eval(double a) const override; }; -struct Relu : Op1<Relu> { double eval(double a) const override; }; -struct Sigmoid : Op1<Sigmoid> { double eval(double a) const override; }; +struct Neg : Op1<Neg> { static double f(double a); }; +struct Not : Op1<Not> { static double f(double a); }; +struct Add : Op2<Add> { static double f(double a, double b); }; +struct Sub : Op2<Sub> { static double f(double a, double b); }; +struct Mul : Op2<Mul> { static double f(double a, double b); }; +struct Div : Op2<Div> { static double f(double a, double b); }; +struct Mod : Op2<Mod> { static double f(double a, double b); }; +struct Pow : Op2<Pow> { static double f(double a, double b); }; +struct Equal : Op2<Equal> { static double f(double a, double b); }; +struct NotEqual : Op2<NotEqual> { static double f(double a, double b); }; +struct Approx : Op2<Approx> { static double f(double a, double b); }; +struct Less : Op2<Less> { static double f(double a, double b); }; +struct LessEqual : Op2<LessEqual> { static double f(double a, double b); }; +struct Greater : Op2<Greater> { static double f(double a, double b); }; +struct GreaterEqual : Op2<GreaterEqual> { static double f(double a, double b); }; +struct And : Op2<And> { static double f(double a, double b); }; +struct Or : Op2<Or> { static double f(double a, double b); }; +struct Cos : Op1<Cos> { static double f(double a); }; +struct Sin : Op1<Sin> { static double f(double a); }; +struct Tan : Op1<Tan> { static double f(double a); }; +struct Cosh : Op1<Cosh> { static double f(double a); }; +struct Sinh : Op1<Sinh> { static double f(double a); }; +struct Tanh : Op1<Tanh> { static double f(double a); }; +struct Acos : Op1<Acos> { static double f(double a); }; +struct Asin : Op1<Asin> { static double f(double a); }; +struct Atan : Op1<Atan> { static double f(double a); }; +struct Exp : Op1<Exp> { static double f(double a); }; +struct Log10 : Op1<Log10> { static double f(double a); }; +struct Log : Op1<Log> { static double f(double a); }; +struct Sqrt : Op1<Sqrt> { static double f(double a); }; +struct Ceil : Op1<Ceil> { static double f(double a); }; +struct Fabs : Op1<Fabs> { static double f(double a); }; +struct Floor : Op1<Floor> { static double f(double a); }; +struct Atan2 : Op2<Atan2> { static double f(double a, double b); }; +struct Ldexp : Op2<Ldexp> { static double f(double a, double b); }; +struct Min : Op2<Min> { static double f(double a, double b); }; +struct Max : Op2<Max> { static double f(double a, double b); }; +struct IsNan : Op1<IsNan> { static double f(double a); }; +struct Relu : Op1<Relu> { static double f(double a); }; +struct Sigmoid : Op1<Sigmoid> { static double f(double a); }; } // namespace vespalib::eval::operation } // namespace vespalib::eval |