aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2017-10-21 09:21:58 +0200
committerGitHub <noreply@github.com>2017-10-21 09:21:58 +0200
commit535c1ae687415540b2d5e727957665f02f898de7 (patch)
tree6cd7d6c7ccf7f9ba712a781536be5aabda07175e
parent10aab525559a7d4649cd181e8adca0010c3041dc (diff)
parent81a24b587e6271fb9275a8333bf4b7ce3172f4a3 (diff)
Merge pull request #3836 from vespa-engine/havardpe/static-functions-for-operators
use static functions for low-level operation eval
-rw-r--r--eval/src/vespa/eval/eval/operation.cpp129
-rw-r--r--eval/src/vespa/eval/eval/operation.h86
2 files changed, 134 insertions, 81 deletions
diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp
index 2fb8d4e1ac8..688bcef1d0d 100644
--- a/eval/src/vespa/eval/eval/operation.cpp
+++ b/eval/src/vespa/eval/eval/operation.cpp
@@ -69,6 +69,10 @@ template <typename T> void Op1<T>::accept(OperationVisitor &visitor) const {
visitor.visit(static_cast<const T&>(*this));
}
+template <typename T> double Op1<T>::eval(double a) const {
+ return T::f(a);
+}
+
template <typename T> void Op2<T>::accept(OperationVisitor &visitor) const {
visitor.visit(static_cast<const T&>(*this));
}
@@ -77,48 +81,93 @@ template <typename T> std::unique_ptr<BinaryOperation> Op2<T>::clone() const {
return std::make_unique<T>();
}
+template <typename T> double Op2<T>::eval(double a, double b) const {
+ return T::f(a, b);
+}
+
namespace operation {
-double Neg::eval(double a) const { return -a; }
-double Not::eval(double a) const { return (a != 0.0) ? 0.0 : 1.0; }
-double Add::eval(double a, double b) const { return (a + b); }
-double Sub::eval(double a, double b) const { return (a - b); }
-double Mul::eval(double a, double b) const { return (a * b); }
-double Div::eval(double a, double b) const { return (a / b); }
-double Mod::eval(double a, double b) const { return std::fmod(a, b); }
-double Pow::eval(double a, double b) const { return std::pow(a, b); }
-double Equal::eval(double a, double b) const { return (a == b) ? 1.0 : 0.0; }
-double NotEqual::eval(double a, double b) const { return (a != b) ? 1.0 : 0.0; }
-double Approx::eval(double a, double b) const { return approx_equal(a, b); }
-double Less::eval(double a, double b) const { return (a < b) ? 1.0 : 0.0; }
-double LessEqual::eval(double a, double b) const { return (a <= b) ? 1.0 : 0.0; }
-double Greater::eval(double a, double b) const { return (a > b) ? 1.0 : 0.0; }
-double GreaterEqual::eval(double a, double b) const { return (a >= b) ? 1.0 : 0.0; }
-double And::eval(double a, double b) const { return ((a != 0.0) && (b != 0.0)) ? 1.0 : 0.0; }
-double Or::eval(double a, double b) const { return ((a != 0.0) || (b != 0.0)) ? 1.0 : 0.0; }
-double Cos::eval(double a) const { return std::cos(a); }
-double Sin::eval(double a) const { return std::sin(a); }
-double Tan::eval(double a) const { return std::tan(a); }
-double Cosh::eval(double a) const { return std::cosh(a); }
-double Sinh::eval(double a) const { return std::sinh(a); }
-double Tanh::eval(double a) const { return std::tanh(a); }
-double Acos::eval(double a) const { return std::acos(a); }
-double Asin::eval(double a) const { return std::asin(a); }
-double Atan::eval(double a) const { return std::atan(a); }
-double Exp::eval(double a) const { return std::exp(a); }
-double Log10::eval(double a) const { return std::log10(a); }
-double Log::eval(double a) const { return std::log(a); }
-double Sqrt::eval(double a) const { return std::sqrt(a); }
-double Ceil::eval(double a) const { return std::ceil(a); }
-double Fabs::eval(double a) const { return std::fabs(a); }
-double Floor::eval(double a) const { return std::floor(a); }
-double Atan2::eval(double a, double b) const { return std::atan2(a, b); }
-double Ldexp::eval(double a, double b) const { return std::ldexp(a, b); }
-double Min::eval(double a, double b) const { return std::min(a, b); }
-double Max::eval(double a, double b) const { return std::max(a, b); }
-double IsNan::eval(double a) const { return std::isnan(a) ? 1.0 : 0.0; }
-double Relu::eval(double a) const { return std::max(a, 0.0); }
-double Sigmoid::eval(double a) const { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
+double Neg::f(double a) { return -a; }
+double Not::f(double a) { return (a != 0.0) ? 0.0 : 1.0; }
+double Add::f(double a, double b) { return (a + b); }
+double Sub::f(double a, double b) { return (a - b); }
+double Mul::f(double a, double b) { return (a * b); }
+double Div::f(double a, double b) { return (a / b); }
+double Mod::f(double a, double b) { return std::fmod(a, b); }
+double Pow::f(double a, double b) { return std::pow(a, b); }
+double Equal::f(double a, double b) { return (a == b) ? 1.0 : 0.0; }
+double NotEqual::f(double a, double b) { return (a != b) ? 1.0 : 0.0; }
+double Approx::f(double a, double b) { return approx_equal(a, b); }
+double Less::f(double a, double b) { return (a < b) ? 1.0 : 0.0; }
+double LessEqual::f(double a, double b) { return (a <= b) ? 1.0 : 0.0; }
+double Greater::f(double a, double b) { return (a > b) ? 1.0 : 0.0; }
+double GreaterEqual::f(double a, double b) { return (a >= b) ? 1.0 : 0.0; }
+double And::f(double a, double b) { return ((a != 0.0) && (b != 0.0)) ? 1.0 : 0.0; }
+double Or::f(double a, double b) { return ((a != 0.0) || (b != 0.0)) ? 1.0 : 0.0; }
+double Cos::f(double a) { return std::cos(a); }
+double Sin::f(double a) { return std::sin(a); }
+double Tan::f(double a) { return std::tan(a); }
+double Cosh::f(double a) { return std::cosh(a); }
+double Sinh::f(double a) { return std::sinh(a); }
+double Tanh::f(double a) { return std::tanh(a); }
+double Acos::f(double a) { return std::acos(a); }
+double Asin::f(double a) { return std::asin(a); }
+double Atan::f(double a) { return std::atan(a); }
+double Exp::f(double a) { return std::exp(a); }
+double Log10::f(double a) { return std::log10(a); }
+double Log::f(double a) { return std::log(a); }
+double Sqrt::f(double a) { return std::sqrt(a); }
+double Ceil::f(double a) { return std::ceil(a); }
+double Fabs::f(double a) { return std::fabs(a); }
+double Floor::f(double a) { return std::floor(a); }
+double Atan2::f(double a, double b) { return std::atan2(a, b); }
+double Ldexp::f(double a, double b) { return std::ldexp(a, b); }
+double Min::f(double a, double b) { return std::min(a, b); }
+double Max::f(double a, double b) { return std::max(a, b); }
+double IsNan::f(double a) { return std::isnan(a) ? 1.0 : 0.0; }
+double Relu::f(double a) { return std::max(a, 0.0); }
+double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
} // namespace vespalib::eval::operation
+template struct Op1<operation::Neg>;
+template struct Op1<operation::Not>;
+template struct Op2<operation::Add>;
+template struct Op2<operation::Sub>;
+template struct Op2<operation::Mul>;
+template struct Op2<operation::Div>;
+template struct Op2<operation::Mod>;
+template struct Op2<operation::Pow>;
+template struct Op2<operation::Equal>;
+template struct Op2<operation::NotEqual>;
+template struct Op2<operation::Approx>;
+template struct Op2<operation::Less>;
+template struct Op2<operation::LessEqual>;
+template struct Op2<operation::Greater>;
+template struct Op2<operation::GreaterEqual>;
+template struct Op2<operation::And>;
+template struct Op2<operation::Or>;
+template struct Op1<operation::Cos>;
+template struct Op1<operation::Sin>;
+template struct Op1<operation::Tan>;
+template struct Op1<operation::Cosh>;
+template struct Op1<operation::Sinh>;
+template struct Op1<operation::Tanh>;
+template struct Op1<operation::Acos>;
+template struct Op1<operation::Asin>;
+template struct Op1<operation::Atan>;
+template struct Op1<operation::Exp>;
+template struct Op1<operation::Log10>;
+template struct Op1<operation::Log>;
+template struct Op1<operation::Sqrt>;
+template struct Op1<operation::Ceil>;
+template struct Op1<operation::Fabs>;
+template struct Op1<operation::Floor>;
+template struct Op2<operation::Atan2>;
+template struct Op2<operation::Ldexp>;
+template struct Op2<operation::Min>;
+template struct Op2<operation::Max>;
+template struct Op1<operation::IsNan>;
+template struct Op1<operation::Relu>;
+template struct Op1<operation::Sigmoid>;
+
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h
index 52ad9047dd6..456790f9180 100644
--- a/eval/src/vespa/eval/eval/operation.h
+++ b/eval/src/vespa/eval/eval/operation.h
@@ -91,12 +91,14 @@ public:
template <typename T>
struct Op1 : UnaryOperation {
virtual void accept(OperationVisitor &visitor) const override;
+ virtual double eval(double a) const override;
};
template <typename T>
struct Op2 : BinaryOperation {
virtual void accept(OperationVisitor &visitor) const override;
virtual std::unique_ptr<BinaryOperation> clone() const override;
+ virtual double eval(double a, double b) const final override;
};
//-----------------------------------------------------------------------------
@@ -105,7 +107,9 @@ struct Op2 : BinaryOperation {
* A non-trivial custom unary operation. Typically used for closures
* and lambdas.
**/
-struct CustomUnaryOperation : Op1<CustomUnaryOperation> {};
+struct CustomUnaryOperation : Op1<CustomUnaryOperation> {
+ static double f(double);
+};
//-----------------------------------------------------------------------------
@@ -140,46 +144,46 @@ public:
//-----------------------------------------------------------------------------
namespace operation {
-struct Neg : Op1<Neg> { double eval(double a) const override; };
-struct Not : Op1<Not> { double eval(double a) const override; };
-struct Add : Op2<Add> { double eval(double a, double b) const override; };
-struct Sub : Op2<Sub> { double eval(double a, double b) const override; };
-struct Mul : Op2<Mul> { double eval(double a, double b) const override; };
-struct Div : Op2<Div> { double eval(double a, double b) const override; };
-struct Mod : Op2<Mod> { double eval(double a, double b) const override; };
-struct Pow : Op2<Pow> { double eval(double a, double b) const override; };
-struct Equal : Op2<Equal> { double eval(double a, double b) const override; };
-struct NotEqual : Op2<NotEqual> { double eval(double a, double b) const override; };
-struct Approx : Op2<Approx> { double eval(double a, double b) const override; };
-struct Less : Op2<Less> { double eval(double a, double b) const override; };
-struct LessEqual : Op2<LessEqual> { double eval(double a, double b) const override; };
-struct Greater : Op2<Greater> { double eval(double a, double b) const override; };
-struct GreaterEqual : Op2<GreaterEqual> { double eval(double a, double b) const override; };
-struct And : Op2<And> { double eval(double a, double b) const override; };
-struct Or : Op2<Or> { double eval(double a, double b) const override; };
-struct Cos : Op1<Cos> { double eval(double a) const override; };
-struct Sin : Op1<Sin> { double eval(double a) const override; };
-struct Tan : Op1<Tan> { double eval(double a) const override; };
-struct Cosh : Op1<Cosh> { double eval(double a) const override; };
-struct Sinh : Op1<Sinh> { double eval(double a) const override; };
-struct Tanh : Op1<Tanh> { double eval(double a) const override; };
-struct Acos : Op1<Acos> { double eval(double a) const override; };
-struct Asin : Op1<Asin> { double eval(double a) const override; };
-struct Atan : Op1<Atan> { double eval(double a) const override; };
-struct Exp : Op1<Exp> { double eval(double a) const override; };
-struct Log10 : Op1<Log10> { double eval(double a) const override; };
-struct Log : Op1<Log> { double eval(double a) const override; };
-struct Sqrt : Op1<Sqrt> { double eval(double a) const override; };
-struct Ceil : Op1<Ceil> { double eval(double a) const override; };
-struct Fabs : Op1<Fabs> { double eval(double a) const override; };
-struct Floor : Op1<Floor> { double eval(double a) const override; };
-struct Atan2 : Op2<Atan2> { double eval(double a, double b) const override; };
-struct Ldexp : Op2<Ldexp> { double eval(double a, double b) const override; };
-struct Min : Op2<Min> { double eval(double a, double b) const override; };
-struct Max : Op2<Max> { double eval(double a, double b) const override; };
-struct IsNan : Op1<IsNan> { double eval(double a) const override; };
-struct Relu : Op1<Relu> { double eval(double a) const override; };
-struct Sigmoid : Op1<Sigmoid> { double eval(double a) const override; };
+struct Neg : Op1<Neg> { static double f(double a); };
+struct Not : Op1<Not> { static double f(double a); };
+struct Add : Op2<Add> { static double f(double a, double b); };
+struct Sub : Op2<Sub> { static double f(double a, double b); };
+struct Mul : Op2<Mul> { static double f(double a, double b); };
+struct Div : Op2<Div> { static double f(double a, double b); };
+struct Mod : Op2<Mod> { static double f(double a, double b); };
+struct Pow : Op2<Pow> { static double f(double a, double b); };
+struct Equal : Op2<Equal> { static double f(double a, double b); };
+struct NotEqual : Op2<NotEqual> { static double f(double a, double b); };
+struct Approx : Op2<Approx> { static double f(double a, double b); };
+struct Less : Op2<Less> { static double f(double a, double b); };
+struct LessEqual : Op2<LessEqual> { static double f(double a, double b); };
+struct Greater : Op2<Greater> { static double f(double a, double b); };
+struct GreaterEqual : Op2<GreaterEqual> { static double f(double a, double b); };
+struct And : Op2<And> { static double f(double a, double b); };
+struct Or : Op2<Or> { static double f(double a, double b); };
+struct Cos : Op1<Cos> { static double f(double a); };
+struct Sin : Op1<Sin> { static double f(double a); };
+struct Tan : Op1<Tan> { static double f(double a); };
+struct Cosh : Op1<Cosh> { static double f(double a); };
+struct Sinh : Op1<Sinh> { static double f(double a); };
+struct Tanh : Op1<Tanh> { static double f(double a); };
+struct Acos : Op1<Acos> { static double f(double a); };
+struct Asin : Op1<Asin> { static double f(double a); };
+struct Atan : Op1<Atan> { static double f(double a); };
+struct Exp : Op1<Exp> { static double f(double a); };
+struct Log10 : Op1<Log10> { static double f(double a); };
+struct Log : Op1<Log> { static double f(double a); };
+struct Sqrt : Op1<Sqrt> { static double f(double a); };
+struct Ceil : Op1<Ceil> { static double f(double a); };
+struct Fabs : Op1<Fabs> { static double f(double a); };
+struct Floor : Op1<Floor> { static double f(double a); };
+struct Atan2 : Op2<Atan2> { static double f(double a, double b); };
+struct Ldexp : Op2<Ldexp> { static double f(double a, double b); };
+struct Min : Op2<Min> { static double f(double a, double b); };
+struct Max : Op2<Max> { static double f(double a, double b); };
+struct IsNan : Op1<IsNan> { static double f(double a); };
+struct Relu : Op1<Relu> { static double f(double a); };
+struct Sigmoid : Op1<Sigmoid> { static double f(double a); };
} // namespace vespalib::eval::operation
} // namespace vespalib::eval