summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-06-18 14:17:09 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-06-18 14:17:09 +0000
commit50924a208ed1682e07a254c93a78987d0d537004 (patch)
tree411805f4e89a1a7619a17163584a4af958a40ab6 /eval
parent418716c7441ec9b20cc4b541f6e4d554796fd2a9 (diff)
added 'erf' function
Diffstat (limited to 'eval')
-rw-r--r--eval/src/apps/tensor_conformance/generate.cpp2
-rw-r--r--eval/src/tests/eval/inline_operation/inline_operation_test.cpp1
-rw-r--r--eval/src/tests/eval/node_tools/node_tools_test.cpp1
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp1
-rw-r--r--eval/src/vespa/eval/eval/call_nodes.cpp1
-rw-r--r--eval/src/vespa/eval/eval/call_nodes.h1
-rw-r--r--eval/src/vespa/eval/eval/key_gen.cpp1
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp3
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp3
-rw-r--r--eval/src/vespa/eval/eval/node_tools.cpp1
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp1
-rw-r--r--eval/src/vespa/eval/eval/node_visitor.h2
-rw-r--r--eval/src/vespa/eval/eval/operation.cpp2
-rw-r--r--eval/src/vespa/eval/eval/operation.h1
-rw-r--r--eval/src/vespa/eval/eval/test/eval_spec.cpp1
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_conformance.cpp1
-rw-r--r--eval/src/vespa/eval/eval/visit_stuff.cpp1
17 files changed, 24 insertions, 0 deletions
diff --git a/eval/src/apps/tensor_conformance/generate.cpp b/eval/src/apps/tensor_conformance/generate.cpp
index 7d48307b786..df1c06593cb 100644
--- a/eval/src/apps/tensor_conformance/generate.cpp
+++ b/eval/src/apps/tensor_conformance/generate.cpp
@@ -100,6 +100,8 @@ void generate_tensor_map(TestBuilder &dst) {
generate_op1_map("relu(a)", operation::Relu::f, Sub2(Div16(N())), dst);
generate_op1_map("sigmoid(a)", operation::Sigmoid::f, Sub2(Div16(N())), dst);
generate_op1_map("elu(a)", operation::Elu::f, Sub2(Div16(N())), dst);
+ // TODO(havardpe): add erf when supported by Java
+ // generate_op1_map("erf(a)", operation::Erf::f, Sub2(Div16(N())), dst);
generate_op1_map("a in [1,5,7,13,42]", MyIn::f, N(), dst);
generate_map_expr("map(a,f(a)((a+1)*2))", MyOp::f, Div16(N()), dst);
}
diff --git a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
index fe9396398da..de5a3fbf395 100644
--- a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
+++ b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
@@ -65,6 +65,7 @@ TEST(InlineOperationTest, op1_lambdas_are_recognized) {
EXPECT_EQ(as_op1("relu(a)"), &Relu::f);
EXPECT_EQ(as_op1("sigmoid(a)"), &Sigmoid::f);
EXPECT_EQ(as_op1("elu(a)"), &Elu::f);
+ EXPECT_EQ(as_op1("erf(a)"), &Erf::f);
//-------------------------------------------
EXPECT_EQ(as_op1("1/a"), &Inv::f);
EXPECT_EQ(as_op1("1.0/a"), &Inv::f);
diff --git a/eval/src/tests/eval/node_tools/node_tools_test.cpp b/eval/src/tests/eval/node_tools/node_tools_test.cpp
index ca89650127e..13185065f57 100644
--- a/eval/src/tests/eval/node_tools/node_tools_test.cpp
+++ b/eval/src/tests/eval/node_tools/node_tools_test.cpp
@@ -99,6 +99,7 @@ TEST("require that call node types can be copied") {
TEST_DO(verify_copy("relu(a)"));
TEST_DO(verify_copy("sigmoid(a)"));
TEST_DO(verify_copy("elu(a)"));
+ TEST_DO(verify_copy("erf(a)"));
}
TEST("require that tensor node types can NOT be copied (yet)") {
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index 7912ec213bc..f595c58ef29 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -191,6 +191,7 @@ TEST("require that various operations resolve appropriate type") {
TEST_DO(verify_op1("relu(%s)")); // Relu
TEST_DO(verify_op1("sigmoid(%s)")); // Sigmoid
TEST_DO(verify_op1("elu(%s)")); // Elu
+ TEST_DO(verify_op1("erf(%s)")); // Erf
}
TEST("require that map resolves correct type") {
diff --git a/eval/src/vespa/eval/eval/call_nodes.cpp b/eval/src/vespa/eval/eval/call_nodes.cpp
index 69a9151a2bb..2fc25bdbc77 100644
--- a/eval/src/vespa/eval/eval/call_nodes.cpp
+++ b/eval/src/vespa/eval/eval/call_nodes.cpp
@@ -42,6 +42,7 @@ CallRepo::CallRepo() : _map() {
add(nodes::Relu());
add(nodes::Sigmoid());
add(nodes::Elu());
+ add(nodes::Erf());
}
} // namespace vespalib::eval::nodes
diff --git a/eval/src/vespa/eval/eval/call_nodes.h b/eval/src/vespa/eval/eval/call_nodes.h
index 8210616750e..c5a41756005 100644
--- a/eval/src/vespa/eval/eval/call_nodes.h
+++ b/eval/src/vespa/eval/eval/call_nodes.h
@@ -138,6 +138,7 @@ struct IsNan : CallHelper<IsNan> { IsNan() : Helper("isNan", 1) {} };
struct Relu : CallHelper<Relu> { Relu() : Helper("relu", 1) {} };
struct Sigmoid : CallHelper<Sigmoid> { Sigmoid() : Helper("sigmoid", 1) {} };
struct Elu : CallHelper<Elu> { Elu() : Helper("elu", 1) {} };
+struct Erf : CallHelper<Erf> { Erf() : Helper("erf", 1) {} };
//-----------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp
index cd31f92f96a..31167be5fe1 100644
--- a/eval/src/vespa/eval/eval/key_gen.cpp
+++ b/eval/src/vespa/eval/eval/key_gen.cpp
@@ -85,6 +85,7 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
void visit(const Relu &) override { add_byte(59); }
void visit(const Sigmoid &) override { add_byte(60); }
void visit(const Elu &) override { add_byte(61); }
+ void visit(const Erf &) override { add_byte(62); }
// traverse
bool open(const Node &node) override { node.accept(*this); return true; }
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index 89f1789e97b..6f9bee025c9 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -644,6 +644,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const Elu &) override {
make_call_1("vespalib_eval_elu");
}
+ void visit(const Erf &) override {
+ make_call_1("erf");
+ }
};
FunctionBuilder::~FunctionBuilder() { }
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index e80633b5c41..c5b5ca59401 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -344,6 +344,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const Elu &node) override {
make_map(node, operation::Elu::f);
}
+ void visit(const Erf &node) override {
+ make_map(node, operation::Erf::f);
+ }
//-------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp
index 7bbe095c060..1c194736138 100644
--- a/eval/src/vespa/eval/eval/node_tools.cpp
+++ b/eval/src/vespa/eval/eval/node_tools.cpp
@@ -180,6 +180,7 @@ struct CopyNode : NodeTraverser, NodeVisitor {
void visit(const Relu &node) override { copy_call(node); }
void visit(const Sigmoid &node) override { copy_call(node); }
void visit(const Elu &node) override { copy_call(node); }
+ void visit(const Erf &node) override { copy_call(node); }
// traverse nodes
bool open(const Node &) override { return !error; }
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index 468b9a58655..cbc96e719e0 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -274,6 +274,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
void visit(const Relu &node) override { resolve_op1(node); }
void visit(const Sigmoid &node) override { resolve_op1(node); }
void visit(const Elu &node) override { resolve_op1(node); }
+ void visit(const Erf &node) override { resolve_op1(node); }
//-------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h
index d3e066c8f53..95a5bec8be7 100644
--- a/eval/src/vespa/eval/eval/node_visitor.h
+++ b/eval/src/vespa/eval/eval/node_visitor.h
@@ -83,6 +83,7 @@ struct NodeVisitor {
virtual void visit(const nodes::Relu &) = 0;
virtual void visit(const nodes::Sigmoid &) = 0;
virtual void visit(const nodes::Elu &) = 0;
+ virtual void visit(const nodes::Erf &) = 0;
virtual ~NodeVisitor() {}
};
@@ -150,6 +151,7 @@ struct EmptyNodeVisitor : NodeVisitor {
void visit(const nodes::Relu &) override {}
void visit(const nodes::Sigmoid &) override {}
void visit(const nodes::Elu &) override {}
+ void visit(const nodes::Erf &) override {}
};
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp
index fa8be4d20bc..b97ac3f2261 100644
--- a/eval/src/vespa/eval/eval/operation.cpp
+++ b/eval/src/vespa/eval/eval/operation.cpp
@@ -49,6 +49,7 @@ double IsNan::f(double a) { return std::isnan(a) ? 1.0 : 0.0; }
double Relu::f(double a) { return std::max(a, 0.0); }
double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; }
+double Erf::f(double a) { return std::erf(a); }
//-----------------------------------------------------------------------------
double Inv::f(double a) { return (1.0 / a); }
double Square::f(double a) { return (a * a); }
@@ -106,6 +107,7 @@ std::map<vespalib::string,op1_t> make_op1_map() {
add_op1(map, "relu(a)", Relu::f);
add_op1(map, "sigmoid(a)", Sigmoid::f);
add_op1(map, "elu(a)", Elu::f);
+ add_op1(map, "erf(a)", Erf::f);
//-------------------------------------
add_op1(map, "1/a", Inv::f);
add_op1(map, "a*a", Square::f);
diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h
index 02d3322f867..3170c868214 100644
--- a/eval/src/vespa/eval/eval/operation.h
+++ b/eval/src/vespa/eval/eval/operation.h
@@ -48,6 +48,7 @@ struct IsNan { static double f(double a); };
struct Relu { static double f(double a); };
struct Sigmoid { static double f(double a); };
struct Elu { static double f(double a); };
+struct Erf { static double f(double a); };
//-----------------------------------------------------------------------------
struct Inv { static double f(double a); };
struct Square { static double f(double a); };
diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp
index dbc20dcf606..b1dfa6d3c9c 100644
--- a/eval/src/vespa/eval/eval/test/eval_spec.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp
@@ -151,6 +151,7 @@ EvalSpec::add_function_call_cases() {
add_rule({"a", -1.0, 1.0}, "relu(a)", [](double a){ return std::max(a, 0.0); });
add_rule({"a", -1.0, 1.0}, "sigmoid(a)", [](double a){ return 1.0 / (1.0 + std::exp(-1.0 * a)); });
add_rule({"a", -1.0, 1.0}, "elu(a)", [](double a){ return (a < 0) ? std::exp(a)-1 : a; });
+ add_rule({"a", -1.0, 1.0}, "erf(a)", [](double a){ return std::erf(a); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "atan2(a,b)", [](double a, double b){ return std::atan2(a, b); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "ldexp(a,b)", [](double a, double b){ return std::ldexp(a, b); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "pow(a,b)", [](double a, double b){ return std::pow(a, b); });
diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
index 41c6dd21e24..95e720cd1a2 100644
--- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
+++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
@@ -412,6 +412,7 @@ struct TestContext {
TEST_DO(test_map_op("relu(a)", operation::Relu::f, Sub2(Div16(N()))));
TEST_DO(test_map_op("sigmoid(a)", operation::Sigmoid::f, Sub2(Div16(N()))));
TEST_DO(test_map_op("elu(a)", operation::Elu::f, Sub2(Div16(N()))));
+ TEST_DO(test_map_op("erf(a)", operation::Erf::f, Sub2(Div16(N()))));
TEST_DO(test_map_op("a in [1,5,7,13,42]", MyIn::f, N()));
TEST_DO(test_map_op("(a+1)*2", MyOp::f, Div16(N())));
}
diff --git a/eval/src/vespa/eval/eval/visit_stuff.cpp b/eval/src/vespa/eval/eval/visit_stuff.cpp
index 821e609ebd0..9306a720837 100644
--- a/eval/src/vespa/eval/eval/visit_stuff.cpp
+++ b/eval/src/vespa/eval/eval/visit_stuff.cpp
@@ -35,6 +35,7 @@ vespalib::string name_of(map_fun_t fun) {
if (fun == operation::Relu::f) return "relu";
if (fun == operation::Sigmoid::f) return "sigmoid";
if (fun == operation::Elu::f) return "elu";
+ if (fun == operation::Erf::f) return "erf";
return "[other map function]";
}