diff options
author | Lester Solbakken <lesters@yahoo-inc.com> | 2017-11-22 12:52:17 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@yahoo-inc.com> | 2017-11-22 12:52:17 +0100 |
commit | a8b35f83c93ef6a524c52ff551254db0a12516ae (patch) | |
tree | e8f120178195a84ce006f13744e6550be54f0bd3 /eval | |
parent | 59ccae15b25d21b50ca310793b1e91d3542b749d (diff) |
Add Elu as a backend ranking function
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/apps/tensor_conformance/generate.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/call_nodes.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/call_nodes.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/interpreted_function.cpp | 5 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/key_gen.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp | 6 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/llvm_wrapper.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_types.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_visitor.h | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/operation.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/operation.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/eval_spec.cpp | 1 |
12 files changed, 20 insertions, 2 deletions
diff --git a/eval/src/apps/tensor_conformance/generate.cpp b/eval/src/apps/tensor_conformance/generate.cpp index 993d226c3c6..0aba5276ace 100644 --- a/eval/src/apps/tensor_conformance/generate.cpp +++ b/eval/src/apps/tensor_conformance/generate.cpp @@ -93,6 +93,7 @@ void generate_tensor_map(TestBuilder &dst) { generate_op1_map("isNan(a)", operation::IsNan::f, Mask2Seq(SkipNth(3), 1.0, my_nan), dst); generate_op1_map("relu(a)", operation::Relu::f, Sub2(Div10(N())), dst); generate_op1_map("sigmoid(a)", operation::Sigmoid::f, Sub2(Div10(N())), dst); + generate_op1_map("elu(a)", operation::Elu::f, Sub2(Div10(N())), dst); generate_op1_map("a in [1,5,7,13,42]", MyIn::f, N(), dst); generate_map_expr("map(a,f(a)((a+1)*2))", MyOp::f, Div10(N()), dst); } diff --git a/eval/src/vespa/eval/eval/call_nodes.cpp b/eval/src/vespa/eval/eval/call_nodes.cpp index 0e54ed183f4..69a9151a2bb 100644 --- a/eval/src/vespa/eval/eval/call_nodes.cpp +++ b/eval/src/vespa/eval/eval/call_nodes.cpp @@ -41,6 +41,7 @@ CallRepo::CallRepo() : _map() { add(nodes::IsNan()); add(nodes::Relu()); add(nodes::Sigmoid()); + add(nodes::Elu()); } } // namespace vespalib::eval::nodes diff --git a/eval/src/vespa/eval/eval/call_nodes.h b/eval/src/vespa/eval/eval/call_nodes.h index 4c5611a863a..8210616750e 100644 --- a/eval/src/vespa/eval/eval/call_nodes.h +++ b/eval/src/vespa/eval/eval/call_nodes.h @@ -137,6 +137,7 @@ struct Max : CallHelper<Max> { Max() : Helper("max", 2) {} }; struct IsNan : CallHelper<IsNan> { IsNan() : Helper("isNan", 1) {} }; struct Relu : CallHelper<Relu> { Relu() : Helper("relu", 1) {} }; struct Sigmoid : CallHelper<Sigmoid> { Sigmoid() : Helper("sigmoid", 1) {} }; +struct Elu : CallHelper<Elu> { Elu() : Helper("elu", 1) {} }; //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp index 59d6753e142..cfe989e95f8 100644 --- a/eval/src/vespa/eval/eval/interpreted_function.cpp +++ b/eval/src/vespa/eval/eval/interpreted_function.cpp @@ -112,7 +112,7 @@ void op_tensor_concat(State &state, uint64_t param) { //----------------------------------------------------------------------------- template <typename T> -const T &undef_cref() { +const T &undef_cref() { const T *undef = nullptr; assert(undef); return *undef; @@ -423,6 +423,9 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { void visit(const Sigmoid &node) override { make_map_op(node, operation::Sigmoid::f); } + void visit(const Elu &node) override { + make_map_op(node, operation::Elu::f); + } //------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index e0494e1fe11..86908b331ba 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -81,6 +81,7 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { void visit(const IsNan &) override { add_byte(58); } void visit(const Relu &) override { add_byte(59); } void visit(const Sigmoid &) override { add_byte(60); } + void visit(const Elu &) override { add_byte(61); } // traverse bool open(const Node &node) override { node.accept(*this); return true; } diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 9355cf7a4e4..f314f8a69cb 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -25,6 +25,7 @@ double vespalib_eval_isnan(double a) { return (std::isnan(a) ? 1.0 : 0.0); } double vespalib_eval_approx(double a, double b) { return (vespalib::approx_equal(a, b) ? 1.0 : 0.0); } double vespalib_eval_relu(double a) { return std::max(a, 0.0); } double vespalib_eval_sigmoid(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } +double vespalib_eval_elu(double a) { return (a < 0) ? std::exp(a) - 1.0 : a; } using vespalib::eval::gbdt::Forest; using resolve_function = double (*)(void *ctx, size_t idx); @@ -586,6 +587,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const Sigmoid &) override { make_call_1("vespalib_eval_sigmoid"); } + void visit(const Elu &) override { + make_call_1("vespalib_eval_elu"); + } }; FunctionBuilder::~FunctionBuilder() { } @@ -628,7 +632,7 @@ LLVMWrapper::LLVMWrapper() size_t LLVMWrapper::make_function(size_t num_params, PassParams pass_params, const Node &root, const gbdt::Optimize::Chain &forest_optimizers) -{ +{ std::lock_guard<std::recursive_mutex> guard(_global_llvm_lock); size_t function_id = _functions.size(); FunctionBuilder builder(*_context, *_module, diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h index d3011a54ec0..6860be922f4 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h @@ -18,6 +18,7 @@ extern "C" { double vespalib_eval_approx(double a, double b); double vespalib_eval_relu(double a); double vespalib_eval_sigmoid(double a); + double vespalib_eval_elu(double a); }; namespace vespalib { diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index f86c3e1a84a..0cbc30667f0 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -164,6 +164,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { void visit(const IsNan &node) override { resolve_op1(node); } void visit(const Relu &node) override { resolve_op1(node); } void visit(const Sigmoid &node) override { resolve_op1(node); } + void visit(const Elu &node) override { resolve_op1(node); } //------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h index c5a6fd51373..10b389db792 100644 --- a/eval/src/vespa/eval/eval/node_visitor.h +++ b/eval/src/vespa/eval/eval/node_visitor.h @@ -79,6 +79,7 @@ struct NodeVisitor { virtual void visit(const nodes::IsNan &) = 0; virtual void visit(const nodes::Relu &) = 0; virtual void visit(const nodes::Sigmoid &) = 0; + virtual void visit(const nodes::Elu &) = 0; virtual ~NodeVisitor() {} }; @@ -142,6 +143,7 @@ struct EmptyNodeVisitor : NodeVisitor { void visit(const nodes::IsNan &) override {} void visit(const nodes::Relu &) override {} void visit(const nodes::Sigmoid &) override {} + void visit(const nodes::Elu &) override {} }; } // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp index 42b1a110497..d697db40e7b 100644 --- a/eval/src/vespa/eval/eval/operation.cpp +++ b/eval/src/vespa/eval/eval/operation.cpp @@ -49,6 +49,7 @@ double Max::f(double a, double b) { return std::max(a, b); } double IsNan::f(double a) { return std::isnan(a) ? 1.0 : 0.0; } double Relu::f(double a) { return std::max(a, 0.0); } double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } +double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; } } // namespace vespalib::eval::operation } // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h index 12de7c3deb7..52a0fbabd22 100644 --- a/eval/src/vespa/eval/eval/operation.h +++ b/eval/src/vespa/eval/eval/operation.h @@ -51,6 +51,7 @@ struct Max { static double f(double a, double b); }; struct IsNan { static double f(double a); }; struct Relu { static double f(double a); }; struct Sigmoid { static double f(double a); }; +struct Elu { static double f(double a); }; } // namespace vespalib::eval::operation } // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp index d214486cf21..baa3ee989d4 100644 --- a/eval/src/vespa/eval/eval/test/eval_spec.cpp +++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp @@ -150,6 +150,7 @@ EvalSpec::add_function_call_cases() { .add_case({my_nan}, 1.0).add_case({my_inf}, 0.0).add_case({-my_inf}, 0.0); add_rule({"a", -1.0, 1.0}, "relu(a)", [](double a){ return std::max(a, 0.0); }); add_rule({"a", -1.0, 1.0}, "sigmoid(a)", [](double a){ return 1.0 / (1.0 + std::exp(-1.0 * a)); }); + add_rule({"a", -1.0, 1.0}, "elu(a)", [](double a){ return (a < 0) ? std::exp(a)-1 : a; }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "atan2(a,b)", [](double a, double b){ return std::atan2(a, b); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "ldexp(a,b)", [](double a, double b){ return std::ldexp(a, b); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "pow(a,b)", [](double a, double b){ return std::pow(a, b); }); |