summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-06-16 11:57:57 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-06-18 13:53:21 +0000
commit13d53f151b8fef06e77c82aa380a0d11bf053a79 (patch)
treeab5d74823dd18582d4004817e42c3e9740717d5c /eval
parent1bdde0b8222028ba205f10ca5efaccc791908de3 (diff)
add 'bit(a,b)' math function
Diffstat (limited to 'eval')
-rw-r--r--eval/src/apps/tensor_conformance/generate.cpp30
-rw-r--r--eval/src/tests/eval/inline_operation/inline_operation_test.cpp1
-rw-r--r--eval/src/tests/eval/node_tools/node_tools_test.cpp1
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp1
-rw-r--r--eval/src/vespa/eval/eval/call_nodes.cpp1
-rw-r--r--eval/src/vespa/eval/eval/call_nodes.h1
-rw-r--r--eval/src/vespa/eval/eval/key_gen.cpp1
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp9
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.h1
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp3
-rw-r--r--eval/src/vespa/eval/eval/node_tools.cpp1
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp1
-rw-r--r--eval/src/vespa/eval/eval/node_visitor.h2
-rw-r--r--eval/src/vespa/eval/eval/operation.cpp7
-rw-r--r--eval/src/vespa/eval/eval/operation.h1
-rw-r--r--eval/src/vespa/eval/eval/test/eval_spec.cpp11
-rw-r--r--eval/src/vespa/eval/eval/test/reference_evaluation.cpp3
-rw-r--r--eval/src/vespa/eval/eval/visit_stuff.cpp1
18 files changed, 76 insertions, 0 deletions
diff --git a/eval/src/apps/tensor_conformance/generate.cpp b/eval/src/apps/tensor_conformance/generate.cpp
index 9ae33c1234f..01936a879d0 100644
--- a/eval/src/apps/tensor_conformance/generate.cpp
+++ b/eval/src/apps/tensor_conformance/generate.cpp
@@ -232,11 +232,24 @@ void generate_join_expr(const vespalib::string &expr, const Sequence &seq, TestB
}
}
+void generate_join_expr(const vespalib::string &expr, const Sequence &seq_a, const Sequence &seq_b, TestBuilder &dst) {
+ for (const auto &layouts: join_layouts) {
+ GenSpec a = GenSpec::from_desc(layouts.first).seq(seq_a);
+ GenSpec b = GenSpec::from_desc(layouts.second).seq(seq_b);
+ generate(expr, a, b, dst);
+ }
+}
+
void generate_op2_join(const vespalib::string &op2_expr, const Sequence &seq, TestBuilder &dst) {
generate_join_expr(op2_expr, seq, dst);
generate_join_expr(fmt("join(a,b,f(a,b)(%s))", op2_expr.c_str()), seq, dst);
}
+[[maybe_unused]] void generate_op2_join(const vespalib::string &op2_expr, const Sequence &seq_a, const Sequence &seq_b, TestBuilder &dst) {
+ generate_join_expr(op2_expr, seq_a, seq_b, dst);
+ generate_join_expr(fmt("join(a,b,f(a,b)(%s))", op2_expr.c_str()), seq_a, seq_b, dst);
+}
+
void generate_join(TestBuilder &dst) {
generate_op2_join("a+b", Div16(N()), dst);
generate_op2_join("a-b", Div16(N()), dst);
@@ -259,6 +272,8 @@ void generate_join(TestBuilder &dst) {
generate_op2_join("fmod(a,b)", Div16(N()), dst);
generate_op2_join("min(a,b)", Div16(N()), dst);
generate_op2_join("max(a,b)", Div16(N()), dst);
+ // TODO: test bit(a,b) when implemented in Java
+ // generate_op2_join("bit(a,b)", Seq({-128, -43, -1, 0, 85, 127}), Seq({0, 1, 2, 3, 4, 5, 6, 7}), dst);
// inverted lambda
generate_join_expr("join(a,b,f(a,b)(b-a))", Div16(N()), dst);
// custom lambda
@@ -276,11 +291,24 @@ void generate_merge_expr(const vespalib::string &expr, const Sequence &seq, Test
}
}
+void generate_merge_expr(const vespalib::string &expr, const Sequence &seq_a, const Sequence &seq_b, TestBuilder &dst) {
+ for (const auto &layouts: merge_layouts) {
+ GenSpec a = GenSpec::from_desc(layouts.first).seq(seq_a);
+ GenSpec b = GenSpec::from_desc(layouts.second).seq(seq_b);
+ generate(expr, a, b, dst);
+ }
+}
+
void generate_op2_merge(const vespalib::string &op2_expr, const Sequence &seq, TestBuilder &dst) {
generate_merge_expr(op2_expr, seq, dst);
generate_merge_expr(fmt("merge(a,b,f(a,b)(%s))", op2_expr.c_str()), seq, dst);
}
+[[maybe_unused]] void generate_op2_merge(const vespalib::string &op2_expr, const Sequence &seq_a, const Sequence &seq_b, TestBuilder &dst) {
+ generate_merge_expr(op2_expr, seq_a, seq_b, dst);
+ generate_merge_expr(fmt("merge(a,b,f(a,b)(%s))", op2_expr.c_str()), seq_a, seq_b, dst);
+}
+
void generate_merge(TestBuilder &dst) {
generate_op2_merge("a+b", Div16(N()), dst);
generate_op2_merge("a-b", Div16(N()), dst);
@@ -303,6 +331,8 @@ void generate_merge(TestBuilder &dst) {
generate_op2_merge("fmod(a,b)", Div16(N()), dst);
generate_op2_merge("min(a,b)", Div16(N()), dst);
generate_op2_merge("max(a,b)", Div16(N()), dst);
+ // TODO: test bit(a,b) when implemented in Java
+ // generate_op2_merge("bit(a,b)", Seq({-128, -43, -1, 0, 85, 127}), Seq({0, 1, 2, 3, 4, 5, 6, 7}), dst);
// inverted lambda
generate_merge_expr("merge(a,b,f(a,b)(b-a))", Div16(N()), dst);
// custom lambda
diff --git a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
index de5a3fbf395..ae5f503b680 100644
--- a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
+++ b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
@@ -115,6 +115,7 @@ TEST(InlineOperationTest, op2_lambdas_are_recognized) {
EXPECT_EQ(as_op2("fmod(a,b)"), &Mod::f);
EXPECT_EQ(as_op2("min(a,b)"), &Min::f);
EXPECT_EQ(as_op2("max(a,b)"), &Max::f);
+ EXPECT_EQ(as_op2("bit(a,b)"), &Bit::f);
}
TEST(InlineOperationTest, op2_lambdas_are_recognized_with_different_parameter_names) {
diff --git a/eval/src/tests/eval/node_tools/node_tools_test.cpp b/eval/src/tests/eval/node_tools/node_tools_test.cpp
index 13185065f57..e8296c01d73 100644
--- a/eval/src/tests/eval/node_tools/node_tools_test.cpp
+++ b/eval/src/tests/eval/node_tools/node_tools_test.cpp
@@ -100,6 +100,7 @@ TEST("require that call node types can be copied") {
TEST_DO(verify_copy("sigmoid(a)"));
TEST_DO(verify_copy("elu(a)"));
TEST_DO(verify_copy("erf(a)"));
+ TEST_DO(verify_copy("bit(a,b)"));
}
TEST("require that tensor node types can NOT be copied (yet)") {
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index 504f66ac717..b2373f0d8f5 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -218,6 +218,7 @@ TEST("require that various operations resolve appropriate type") {
TEST_DO(verify_op1("sigmoid(%s)")); // Sigmoid
TEST_DO(verify_op1("elu(%s)")); // Elu
TEST_DO(verify_op1("erf(%s)")); // Erf
+ TEST_DO(verify_op2("bit(%s,%s)")); // Bit
}
TEST("require that map resolves correct type") {
diff --git a/eval/src/vespa/eval/eval/call_nodes.cpp b/eval/src/vespa/eval/eval/call_nodes.cpp
index 2fc25bdbc77..798583cf89a 100644
--- a/eval/src/vespa/eval/eval/call_nodes.cpp
+++ b/eval/src/vespa/eval/eval/call_nodes.cpp
@@ -43,6 +43,7 @@ CallRepo::CallRepo() : _map() {
add(nodes::Sigmoid());
add(nodes::Elu());
add(nodes::Erf());
+ add(nodes::Bit());
}
} // namespace vespalib::eval::nodes
diff --git a/eval/src/vespa/eval/eval/call_nodes.h b/eval/src/vespa/eval/eval/call_nodes.h
index 2a7d4173e64..945aba69596 100644
--- a/eval/src/vespa/eval/eval/call_nodes.h
+++ b/eval/src/vespa/eval/eval/call_nodes.h
@@ -139,6 +139,7 @@ struct Relu : CallHelper<Relu> { Relu() : Helper("relu", 1) {} };
struct Sigmoid : CallHelper<Sigmoid> { Sigmoid() : Helper("sigmoid", 1) {} };
struct Elu : CallHelper<Elu> { Elu() : Helper("elu", 1) {} };
struct Erf : CallHelper<Erf> { Erf() : Helper("erf", 1) {} };
+struct Bit : CallHelper<Bit> { Bit() : Helper("bit", 2) {} };
//-----------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp
index a8fb205f124..a40a8887119 100644
--- a/eval/src/vespa/eval/eval/key_gen.cpp
+++ b/eval/src/vespa/eval/eval/key_gen.cpp
@@ -87,6 +87,7 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
void visit(const Sigmoid &) override { add_byte(60); }
void visit(const Elu &) override { add_byte(61); }
void visit(const Erf &) override { add_byte(62); }
+ void visit(const Bit &) override { add_byte(63); }
// traverse
bool open(const Node &node) override { node.accept(*this); return true; }
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index 42911a56c14..9a99c4fedd7 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -29,6 +29,12 @@ double vespalib_eval_approx(double a, double b) { return (vespalib::approx_equal
double vespalib_eval_relu(double a) { return std::max(a, 0.0); }
double vespalib_eval_sigmoid(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
double vespalib_eval_elu(double a) { return (a < 0) ? std::exp(a) - 1.0 : a; }
+double vespalib_eval_bit(double a, double b) {
+ // must match Bit::f
+ int8_t value = (int8_t) a;
+ uint32_t n = (uint32_t) b;
+ return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0;
+}
using vespalib::eval::gbdt::Forest;
using resolve_function = double (*)(void *ctx, size_t idx);
@@ -646,6 +652,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const Erf &) override {
make_call_1("erf");
}
+ void visit(const Bit &) override {
+ make_call_2("vespalib_eval_bit");
+ }
};
FunctionBuilder::~FunctionBuilder() { }
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
index 040c0bdb73f..e04b477750d 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h
@@ -19,6 +19,7 @@ extern "C" {
double vespalib_eval_relu(double a);
double vespalib_eval_sigmoid(double a);
double vespalib_eval_elu(double a);
+ double vespalib_eval_bit(double a, double b);
};
namespace vespalib::eval {
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index b65c3d5aaa7..498be2a738b 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -357,6 +357,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const Erf &node) override {
make_map(node, operation::Erf::f);
}
+ void visit(const Bit &node) override {
+ make_join(node, operation::Bit::f);
+ }
//-------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp
index e7341bc1755..fa2d16a2271 100644
--- a/eval/src/vespa/eval/eval/node_tools.cpp
+++ b/eval/src/vespa/eval/eval/node_tools.cpp
@@ -182,6 +182,7 @@ struct CopyNode : NodeTraverser, NodeVisitor {
void visit(const Sigmoid &node) override { copy_call(node); }
void visit(const Elu &node) override { copy_call(node); }
void visit(const Erf &node) override { copy_call(node); }
+ void visit(const Bit &node) override { copy_call(node); }
// traverse nodes
bool open(const Node &) override { return !error; }
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index 63da6d79c6f..8622fd734f1 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -278,6 +278,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
void visit(const Sigmoid &node) override { resolve_op1(node); }
void visit(const Elu &node) override { resolve_op1(node); }
void visit(const Erf &node) override { resolve_op1(node); }
+ void visit(const Bit &node) override { resolve_op2(node); }
//-------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h
index 172cd48fe2a..475bbf5405c 100644
--- a/eval/src/vespa/eval/eval/node_visitor.h
+++ b/eval/src/vespa/eval/eval/node_visitor.h
@@ -85,6 +85,7 @@ struct NodeVisitor {
virtual void visit(const nodes::Sigmoid &) = 0;
virtual void visit(const nodes::Elu &) = 0;
virtual void visit(const nodes::Erf &) = 0;
+ virtual void visit(const nodes::Bit &) = 0;
virtual ~NodeVisitor() {}
};
@@ -154,6 +155,7 @@ struct EmptyNodeVisitor : NodeVisitor {
void visit(const nodes::Sigmoid &) override {}
void visit(const nodes::Elu &) override {}
void visit(const nodes::Erf &) override {}
+ void visit(const nodes::Bit &) override {}
};
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp
index b97ac3f2261..36b922539f4 100644
--- a/eval/src/vespa/eval/eval/operation.cpp
+++ b/eval/src/vespa/eval/eval/operation.cpp
@@ -50,6 +50,12 @@ double Relu::f(double a) { return std::max(a, 0.0); }
double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; }
double Erf::f(double a) { return std::erf(a); }
+double Bit::f(double a, double b) {
+ // must match vespalib_eval_bit
+ int8_t value = (int8_t) a;
+ uint32_t n = (uint32_t) b;
+ return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0;
+}
//-----------------------------------------------------------------------------
double Inv::f(double a) { return (1.0 / a); }
double Square::f(double a) { return (a * a); }
@@ -143,6 +149,7 @@ std::map<vespalib::string,op2_t> make_op2_map() {
add_op2(map, "fmod(a,b)", Mod::f);
add_op2(map, "min(a,b)", Min::f);
add_op2(map, "max(a,b)", Max::f);
+ add_op2(map, "bit(a,b)", Bit::f);
return map;
}
diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h
index 3170c868214..438b510b714 100644
--- a/eval/src/vespa/eval/eval/operation.h
+++ b/eval/src/vespa/eval/eval/operation.h
@@ -49,6 +49,7 @@ struct Relu { static double f(double a); };
struct Sigmoid { static double f(double a); };
struct Elu { static double f(double a); };
struct Erf { static double f(double a); };
+struct Bit { static double f(double a, double b); };
//-----------------------------------------------------------------------------
struct Inv { static double f(double a); };
struct Square { static double f(double a); };
diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp
index 63a3a23d9ae..5d51a1d23b5 100644
--- a/eval/src/vespa/eval/eval/test/eval_spec.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp
@@ -158,6 +158,17 @@ EvalSpec::add_function_call_cases() {
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "fmod(a,b)", [](double a, double b){ return std::fmod(a, b); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "min(a,b)", [](double a, double b){ return std::min(a, b); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "max(a,b)", [](double a, double b){ return std::max(a, b); });
+ add_expression({"a", "b"}, "bit(a,b)")
+ .add_case({-128, 7}, 1.0).add_case({-128, 6}, 0.0).add_case({-128, 5}, 0.0).add_case({-128, 4}, 0.0)
+ .add_case({-128, 3}, 0.0).add_case({-128, 2}, 0.0).add_case({-128, 1}, 0.0).add_case({-128, 0}, 0.0)
+ .add_case({-43, 7}, 1.0).add_case({-43, 6}, 1.0).add_case({-43, 5}, 0.0).add_case({-43, 4}, 1.0)
+ .add_case({-43, 3}, 0.0).add_case({-43, 2}, 1.0).add_case({-43, 1}, 0.0).add_case({-43, 0}, 1.0)
+ .add_case({0, 7}, 0.0).add_case({0, 6}, 0.0).add_case({0, 5}, 0.0).add_case({0, 4}, 0.0)
+ .add_case({0, 3}, 0.0).add_case({0, 2}, 0.0).add_case({0, 1}, 0.0).add_case({0, 0}, 0.0)
+ .add_case({85, 7}, 0.0).add_case({85, 6}, 1.0).add_case({85, 5}, 0.0).add_case({85, 4}, 1.0)
+ .add_case({85, 3}, 0.0).add_case({85, 2}, 1.0).add_case({85, 1}, 0.0).add_case({85, 0}, 1.0)
+ .add_case({127, 7}, 0.0).add_case({127, 6}, 1.0).add_case({127, 5}, 1.0).add_case({127, 4}, 1.0)
+ .add_case({127, 3}, 1.0).add_case({127, 2}, 1.0).add_case({127, 1}, 1.0).add_case({127, 0}, 1.0);
}
void
diff --git a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp
index 4824751bb14..58e4b91f6d9 100644
--- a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp
+++ b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp
@@ -335,6 +335,9 @@ struct EvalNode : public NodeVisitor {
void visit(const Erf &node) override {
eval_map(node.get_child(0), operation::Erf::f);
}
+ void visit(const Bit &node) override {
+ eval_join(node.get_child(0), node.get_child(1), operation::Bit::f);
+ }
};
TensorSpec eval_node(const Node &node, const std::vector<TensorSpec> &params) {
diff --git a/eval/src/vespa/eval/eval/visit_stuff.cpp b/eval/src/vespa/eval/eval/visit_stuff.cpp
index 9306a720837..786562d823f 100644
--- a/eval/src/vespa/eval/eval/visit_stuff.cpp
+++ b/eval/src/vespa/eval/eval/visit_stuff.cpp
@@ -59,6 +59,7 @@ vespalib::string name_of(join_fun_t fun) {
if (fun == operation::Ldexp::f) return "ldexp";
if (fun == operation::Min::f) return "min";
if (fun == operation::Max::f) return "max";
+ if (fun == operation::Bit::f) return "bit";
return "[other join function]";
}