diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-06-18 14:17:09 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-06-18 14:17:09 +0000 |
commit | 50924a208ed1682e07a254c93a78987d0d537004 (patch) | |
tree | 411805f4e89a1a7619a17163584a4af958a40ab6 /eval | |
parent | 418716c7441ec9b20cc4b541f6e4d554796fd2a9 (diff) |
added 'erf' function
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/apps/tensor_conformance/generate.cpp | 2 | ||||
-rw-r--r-- | eval/src/tests/eval/inline_operation/inline_operation_test.cpp | 1 | ||||
-rw-r--r-- | eval/src/tests/eval/node_tools/node_tools_test.cpp | 1 | ||||
-rw-r--r-- | eval/src/tests/eval/node_types/node_types_test.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/call_nodes.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/call_nodes.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/key_gen.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp | 3 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/make_tensor_function.cpp | 3 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_tools.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_types.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/node_visitor.h | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/operation.cpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/operation.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/eval_spec.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/tensor_conformance.cpp | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/visit_stuff.cpp | 1 |
17 files changed, 24 insertions, 0 deletions
diff --git a/eval/src/apps/tensor_conformance/generate.cpp b/eval/src/apps/tensor_conformance/generate.cpp index 7d48307b786..df1c06593cb 100644 --- a/eval/src/apps/tensor_conformance/generate.cpp +++ b/eval/src/apps/tensor_conformance/generate.cpp @@ -100,6 +100,8 @@ void generate_tensor_map(TestBuilder &dst) { generate_op1_map("relu(a)", operation::Relu::f, Sub2(Div16(N())), dst); generate_op1_map("sigmoid(a)", operation::Sigmoid::f, Sub2(Div16(N())), dst); generate_op1_map("elu(a)", operation::Elu::f, Sub2(Div16(N())), dst); + // TODO(havardpe): add erf when supported by Java + // generate_op1_map("erf(a)", operation::Erf::f, Sub2(Div16(N())), dst); generate_op1_map("a in [1,5,7,13,42]", MyIn::f, N(), dst); generate_map_expr("map(a,f(a)((a+1)*2))", MyOp::f, Div16(N()), dst); } diff --git a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp index fe9396398da..de5a3fbf395 100644 --- a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp +++ b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp @@ -65,6 +65,7 @@ TEST(InlineOperationTest, op1_lambdas_are_recognized) { EXPECT_EQ(as_op1("relu(a)"), &Relu::f); EXPECT_EQ(as_op1("sigmoid(a)"), &Sigmoid::f); EXPECT_EQ(as_op1("elu(a)"), &Elu::f); + EXPECT_EQ(as_op1("erf(a)"), &Erf::f); //------------------------------------------- EXPECT_EQ(as_op1("1/a"), &Inv::f); EXPECT_EQ(as_op1("1.0/a"), &Inv::f); diff --git a/eval/src/tests/eval/node_tools/node_tools_test.cpp b/eval/src/tests/eval/node_tools/node_tools_test.cpp index ca89650127e..13185065f57 100644 --- a/eval/src/tests/eval/node_tools/node_tools_test.cpp +++ b/eval/src/tests/eval/node_tools/node_tools_test.cpp @@ -99,6 +99,7 @@ TEST("require that call node types can be copied") { TEST_DO(verify_copy("relu(a)")); TEST_DO(verify_copy("sigmoid(a)")); TEST_DO(verify_copy("elu(a)")); + TEST_DO(verify_copy("erf(a)")); } TEST("require that tensor node types can NOT be copied (yet)") { diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index 7912ec213bc..f595c58ef29 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -191,6 +191,7 @@ TEST("require that various operations resolve appropriate type") { TEST_DO(verify_op1("relu(%s)")); // Relu TEST_DO(verify_op1("sigmoid(%s)")); // Sigmoid TEST_DO(verify_op1("elu(%s)")); // Elu + TEST_DO(verify_op1("erf(%s)")); // Erf } TEST("require that map resolves correct type") { diff --git a/eval/src/vespa/eval/eval/call_nodes.cpp b/eval/src/vespa/eval/eval/call_nodes.cpp index 69a9151a2bb..2fc25bdbc77 100644 --- a/eval/src/vespa/eval/eval/call_nodes.cpp +++ b/eval/src/vespa/eval/eval/call_nodes.cpp @@ -42,6 +42,7 @@ CallRepo::CallRepo() : _map() { add(nodes::Relu()); add(nodes::Sigmoid()); add(nodes::Elu()); + add(nodes::Erf()); } } // namespace vespalib::eval::nodes diff --git a/eval/src/vespa/eval/eval/call_nodes.h b/eval/src/vespa/eval/eval/call_nodes.h index 8210616750e..c5a41756005 100644 --- a/eval/src/vespa/eval/eval/call_nodes.h +++ b/eval/src/vespa/eval/eval/call_nodes.h @@ -138,6 +138,7 @@ struct IsNan : CallHelper<IsNan> { IsNan() : Helper("isNan", 1) {} }; struct Relu : CallHelper<Relu> { Relu() : Helper("relu", 1) {} }; struct Sigmoid : CallHelper<Sigmoid> { Sigmoid() : Helper("sigmoid", 1) {} }; struct Elu : CallHelper<Elu> { Elu() : Helper("elu", 1) {} }; +struct Erf : CallHelper<Erf> { Erf() : Helper("erf", 1) {} }; //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index cd31f92f96a..31167be5fe1 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -85,6 +85,7 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { void visit(const Relu &) override { add_byte(59); } void visit(const Sigmoid &) override { add_byte(60); } void visit(const Elu &) override { add_byte(61); } + void visit(const Erf &) override { add_byte(62); } // traverse bool open(const Node &node) override { node.accept(*this); return true; } diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 89f1789e97b..6f9bee025c9 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -644,6 +644,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const Elu &) override { make_call_1("vespalib_eval_elu"); } + void visit(const Erf &) override { + make_call_1("erf"); + } }; FunctionBuilder::~FunctionBuilder() { } diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index e80633b5c41..c5b5ca59401 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -344,6 +344,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const Elu &node) override { make_map(node, operation::Elu::f); } + void visit(const Erf &node) override { + make_map(node, operation::Erf::f); + } //------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp index 7bbe095c060..1c194736138 100644 --- a/eval/src/vespa/eval/eval/node_tools.cpp +++ b/eval/src/vespa/eval/eval/node_tools.cpp @@ -180,6 +180,7 @@ struct CopyNode : NodeTraverser, NodeVisitor { void visit(const Relu &node) override { copy_call(node); } void visit(const Sigmoid &node) override { copy_call(node); } void visit(const Elu &node) override { copy_call(node); } + void visit(const Erf &node) override { copy_call(node); } // traverse nodes bool open(const Node &) override { return !error; } diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index 468b9a58655..cbc96e719e0 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -274,6 +274,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { void visit(const Relu &node) override { resolve_op1(node); } void visit(const Sigmoid &node) override { resolve_op1(node); } void visit(const Elu &node) override { resolve_op1(node); } + void visit(const Erf &node) override { resolve_op1(node); } //------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h index d3e066c8f53..95a5bec8be7 100644 --- a/eval/src/vespa/eval/eval/node_visitor.h +++ b/eval/src/vespa/eval/eval/node_visitor.h @@ -83,6 +83,7 @@ struct NodeVisitor { virtual void visit(const nodes::Relu &) = 0; virtual void visit(const nodes::Sigmoid &) = 0; virtual void visit(const nodes::Elu &) = 0; + virtual void visit(const nodes::Erf &) = 0; virtual ~NodeVisitor() {} }; @@ -150,6 +151,7 @@ struct EmptyNodeVisitor : NodeVisitor { void visit(const nodes::Relu &) override {} void visit(const nodes::Sigmoid &) override {} void visit(const nodes::Elu &) override {} + void visit(const nodes::Erf &) override {} }; } // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp index fa8be4d20bc..b97ac3f2261 100644 --- a/eval/src/vespa/eval/eval/operation.cpp +++ b/eval/src/vespa/eval/eval/operation.cpp @@ -49,6 +49,7 @@ double IsNan::f(double a) { return std::isnan(a) ? 1.0 : 0.0; } double Relu::f(double a) { return std::max(a, 0.0); } double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; } +double Erf::f(double a) { return std::erf(a); } //----------------------------------------------------------------------------- double Inv::f(double a) { return (1.0 / a); } double Square::f(double a) { return (a * a); } @@ -106,6 +107,7 @@ std::map<vespalib::string,op1_t> make_op1_map() { add_op1(map, "relu(a)", Relu::f); add_op1(map, "sigmoid(a)", Sigmoid::f); add_op1(map, "elu(a)", Elu::f); + add_op1(map, "erf(a)", Erf::f); //------------------------------------- add_op1(map, "1/a", Inv::f); add_op1(map, "a*a", Square::f); diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h index 02d3322f867..3170c868214 100644 --- a/eval/src/vespa/eval/eval/operation.h +++ b/eval/src/vespa/eval/eval/operation.h @@ -48,6 +48,7 @@ struct IsNan { static double f(double a); }; struct Relu { static double f(double a); }; struct Sigmoid { static double f(double a); }; struct Elu { static double f(double a); }; +struct Erf { static double f(double a); }; //----------------------------------------------------------------------------- struct Inv { static double f(double a); }; struct Square { static double f(double a); }; diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp index dbc20dcf606..b1dfa6d3c9c 100644 --- a/eval/src/vespa/eval/eval/test/eval_spec.cpp +++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp @@ -151,6 +151,7 @@ EvalSpec::add_function_call_cases() { add_rule({"a", -1.0, 1.0}, "relu(a)", [](double a){ return std::max(a, 0.0); }); add_rule({"a", -1.0, 1.0}, "sigmoid(a)", [](double a){ return 1.0 / (1.0 + std::exp(-1.0 * a)); }); add_rule({"a", -1.0, 1.0}, "elu(a)", [](double a){ return (a < 0) ? std::exp(a)-1 : a; }); + add_rule({"a", -1.0, 1.0}, "erf(a)", [](double a){ return std::erf(a); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "atan2(a,b)", [](double a, double b){ return std::atan2(a, b); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "ldexp(a,b)", [](double a, double b){ return std::ldexp(a, b); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "pow(a,b)", [](double a, double b){ return std::pow(a, b); }); diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index 41c6dd21e24..95e720cd1a2 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -412,6 +412,7 @@ struct TestContext { TEST_DO(test_map_op("relu(a)", operation::Relu::f, Sub2(Div16(N())))); TEST_DO(test_map_op("sigmoid(a)", operation::Sigmoid::f, Sub2(Div16(N())))); TEST_DO(test_map_op("elu(a)", operation::Elu::f, Sub2(Div16(N())))); + TEST_DO(test_map_op("erf(a)", operation::Erf::f, Sub2(Div16(N())))); TEST_DO(test_map_op("a in [1,5,7,13,42]", MyIn::f, N())); TEST_DO(test_map_op("(a+1)*2", MyOp::f, Div16(N()))); } diff --git a/eval/src/vespa/eval/eval/visit_stuff.cpp b/eval/src/vespa/eval/eval/visit_stuff.cpp index 821e609ebd0..9306a720837 100644 --- a/eval/src/vespa/eval/eval/visit_stuff.cpp +++ b/eval/src/vespa/eval/eval/visit_stuff.cpp @@ -35,6 +35,7 @@ vespalib::string name_of(map_fun_t fun) { if (fun == operation::Relu::f) return "relu"; if (fun == operation::Sigmoid::f) return "sigmoid"; if (fun == operation::Elu::f) return "elu"; + if (fun == operation::Erf::f) return "erf"; return "[other map function]"; } |