From 13d53f151b8fef06e77c82aa380a0d11bf053a79 Mon Sep 17 00:00:00 2001 From: Håvard Pettersen Date: Wed, 16 Jun 2021 11:57:57 +0000 Subject: add 'bit(a,b)' math function --- eval/src/apps/tensor_conformance/generate.cpp | 30 ++++++++++++++++++++++ .../inline_operation/inline_operation_test.cpp | 1 + eval/src/tests/eval/node_tools/node_tools_test.cpp | 1 + eval/src/tests/eval/node_types/node_types_test.cpp | 1 + eval/src/vespa/eval/eval/call_nodes.cpp | 1 + eval/src/vespa/eval/eval/call_nodes.h | 1 + eval/src/vespa/eval/eval/key_gen.cpp | 1 + eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp | 9 +++++++ eval/src/vespa/eval/eval/llvm/llvm_wrapper.h | 1 + eval/src/vespa/eval/eval/make_tensor_function.cpp | 3 +++ eval/src/vespa/eval/eval/node_tools.cpp | 1 + eval/src/vespa/eval/eval/node_types.cpp | 1 + eval/src/vespa/eval/eval/node_visitor.h | 2 ++ eval/src/vespa/eval/eval/operation.cpp | 7 +++++ eval/src/vespa/eval/eval/operation.h | 1 + eval/src/vespa/eval/eval/test/eval_spec.cpp | 11 ++++++++ .../vespa/eval/eval/test/reference_evaluation.cpp | 3 +++ eval/src/vespa/eval/eval/visit_stuff.cpp | 1 + 18 files changed, 76 insertions(+) (limited to 'eval') diff --git a/eval/src/apps/tensor_conformance/generate.cpp b/eval/src/apps/tensor_conformance/generate.cpp index 9ae33c1234f..01936a879d0 100644 --- a/eval/src/apps/tensor_conformance/generate.cpp +++ b/eval/src/apps/tensor_conformance/generate.cpp @@ -232,11 +232,24 @@ void generate_join_expr(const vespalib::string &expr, const Sequence &seq, TestB } } +void generate_join_expr(const vespalib::string &expr, const Sequence &seq_a, const Sequence &seq_b, TestBuilder &dst) { + for (const auto &layouts: join_layouts) { + GenSpec a = GenSpec::from_desc(layouts.first).seq(seq_a); + GenSpec b = GenSpec::from_desc(layouts.second).seq(seq_b); + generate(expr, a, b, dst); + } +} + void generate_op2_join(const vespalib::string &op2_expr, const Sequence &seq, TestBuilder &dst) { generate_join_expr(op2_expr, seq, dst); generate_join_expr(fmt("join(a,b,f(a,b)(%s))", op2_expr.c_str()), seq, dst); } +[[maybe_unused]] void generate_op2_join(const vespalib::string &op2_expr, const Sequence &seq_a, const Sequence &seq_b, TestBuilder &dst) { + generate_join_expr(op2_expr, seq_a, seq_b, dst); + generate_join_expr(fmt("join(a,b,f(a,b)(%s))", op2_expr.c_str()), seq_a, seq_b, dst); +} + void generate_join(TestBuilder &dst) { generate_op2_join("a+b", Div16(N()), dst); generate_op2_join("a-b", Div16(N()), dst); @@ -259,6 +272,8 @@ void generate_join(TestBuilder &dst) { generate_op2_join("fmod(a,b)", Div16(N()), dst); generate_op2_join("min(a,b)", Div16(N()), dst); generate_op2_join("max(a,b)", Div16(N()), dst); + // TODO: test bit(a,b) when implemented in Java + // generate_op2_join("bit(a,b)", Seq({-128, -43, -1, 0, 85, 127}), Seq({0, 1, 2, 3, 4, 5, 6, 7}), dst); // inverted lambda generate_join_expr("join(a,b,f(a,b)(b-a))", Div16(N()), dst); // custom lambda @@ -276,11 +291,24 @@ void generate_merge_expr(const vespalib::string &expr, const Sequence &seq, Test } } +void generate_merge_expr(const vespalib::string &expr, const Sequence &seq_a, const Sequence &seq_b, TestBuilder &dst) { + for (const auto &layouts: merge_layouts) { + GenSpec a = GenSpec::from_desc(layouts.first).seq(seq_a); + GenSpec b = GenSpec::from_desc(layouts.second).seq(seq_b); + generate(expr, a, b, dst); + } +} + void generate_op2_merge(const vespalib::string &op2_expr, const Sequence &seq, TestBuilder &dst) { generate_merge_expr(op2_expr, seq, dst); generate_merge_expr(fmt("merge(a,b,f(a,b)(%s))", op2_expr.c_str()), seq, dst); } +[[maybe_unused]] void generate_op2_merge(const vespalib::string &op2_expr, const Sequence &seq_a, const Sequence &seq_b, TestBuilder &dst) { + generate_merge_expr(op2_expr, seq_a, seq_b, dst); + generate_merge_expr(fmt("merge(a,b,f(a,b)(%s))", op2_expr.c_str()), seq_a, seq_b, dst); +} + void generate_merge(TestBuilder &dst) { generate_op2_merge("a+b", Div16(N()), dst); generate_op2_merge("a-b", Div16(N()), dst); @@ -303,6 +331,8 @@ void generate_merge(TestBuilder &dst) { generate_op2_merge("fmod(a,b)", Div16(N()), dst); generate_op2_merge("min(a,b)", Div16(N()), dst); generate_op2_merge("max(a,b)", Div16(N()), dst); + // TODO: test bit(a,b) when implemented in Java + // generate_op2_merge("bit(a,b)", Seq({-128, -43, -1, 0, 85, 127}), Seq({0, 1, 2, 3, 4, 5, 6, 7}), dst); // inverted lambda generate_merge_expr("merge(a,b,f(a,b)(b-a))", Div16(N()), dst); // custom lambda diff --git a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp index de5a3fbf395..ae5f503b680 100644 --- a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp +++ b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp @@ -115,6 +115,7 @@ TEST(InlineOperationTest, op2_lambdas_are_recognized) { EXPECT_EQ(as_op2("fmod(a,b)"), &Mod::f); EXPECT_EQ(as_op2("min(a,b)"), &Min::f); EXPECT_EQ(as_op2("max(a,b)"), &Max::f); + EXPECT_EQ(as_op2("bit(a,b)"), &Bit::f); } TEST(InlineOperationTest, op2_lambdas_are_recognized_with_different_parameter_names) { diff --git a/eval/src/tests/eval/node_tools/node_tools_test.cpp b/eval/src/tests/eval/node_tools/node_tools_test.cpp index 13185065f57..e8296c01d73 100644 --- a/eval/src/tests/eval/node_tools/node_tools_test.cpp +++ b/eval/src/tests/eval/node_tools/node_tools_test.cpp @@ -100,6 +100,7 @@ TEST("require that call node types can be copied") { TEST_DO(verify_copy("sigmoid(a)")); TEST_DO(verify_copy("elu(a)")); TEST_DO(verify_copy("erf(a)")); + TEST_DO(verify_copy("bit(a,b)")); } TEST("require that tensor node types can NOT be copied (yet)") { diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index 504f66ac717..b2373f0d8f5 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -218,6 +218,7 @@ TEST("require that various operations resolve appropriate type") { TEST_DO(verify_op1("sigmoid(%s)")); // Sigmoid TEST_DO(verify_op1("elu(%s)")); // Elu TEST_DO(verify_op1("erf(%s)")); // Erf + TEST_DO(verify_op2("bit(%s,%s)")); // Bit } TEST("require that map resolves correct type") { diff --git a/eval/src/vespa/eval/eval/call_nodes.cpp b/eval/src/vespa/eval/eval/call_nodes.cpp index 2fc25bdbc77..798583cf89a 100644 --- a/eval/src/vespa/eval/eval/call_nodes.cpp +++ b/eval/src/vespa/eval/eval/call_nodes.cpp @@ -43,6 +43,7 @@ CallRepo::CallRepo() : _map() { add(nodes::Sigmoid()); add(nodes::Elu()); add(nodes::Erf()); + add(nodes::Bit()); } } // namespace vespalib::eval::nodes diff --git a/eval/src/vespa/eval/eval/call_nodes.h b/eval/src/vespa/eval/eval/call_nodes.h index 2a7d4173e64..945aba69596 100644 --- a/eval/src/vespa/eval/eval/call_nodes.h +++ b/eval/src/vespa/eval/eval/call_nodes.h @@ -139,6 +139,7 @@ struct Relu : CallHelper { Relu() : Helper("relu", 1) {} }; struct Sigmoid : CallHelper { Sigmoid() : Helper("sigmoid", 1) {} }; struct Elu : CallHelper { Elu() : Helper("elu", 1) {} }; struct Erf : CallHelper { Erf() : Helper("erf", 1) {} }; +struct Bit : CallHelper { Bit() : Helper("bit", 2) {} }; //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index a8fb205f124..a40a8887119 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -87,6 +87,7 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { void visit(const Sigmoid &) override { add_byte(60); } void visit(const Elu &) override { add_byte(61); } void visit(const Erf &) override { add_byte(62); } + void visit(const Bit &) override { add_byte(63); } // traverse bool open(const Node &node) override { node.accept(*this); return true; } diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 42911a56c14..9a99c4fedd7 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -29,6 +29,12 @@ double vespalib_eval_approx(double a, double b) { return (vespalib::approx_equal double vespalib_eval_relu(double a) { return std::max(a, 0.0); } double vespalib_eval_sigmoid(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } double vespalib_eval_elu(double a) { return (a < 0) ? std::exp(a) - 1.0 : a; } +double vespalib_eval_bit(double a, double b) { + // must match Bit::f + int8_t value = (int8_t) a; + uint32_t n = (uint32_t) b; + return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0; +} using vespalib::eval::gbdt::Forest; using resolve_function = double (*)(void *ctx, size_t idx); @@ -646,6 +652,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const Erf &) override { make_call_1("erf"); } + void visit(const Bit &) override { + make_call_2("vespalib_eval_bit"); + } }; FunctionBuilder::~FunctionBuilder() { } diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h index 040c0bdb73f..e04b477750d 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.h @@ -19,6 +19,7 @@ extern "C" { double vespalib_eval_relu(double a); double vespalib_eval_sigmoid(double a); double vespalib_eval_elu(double a); + double vespalib_eval_bit(double a, double b); }; namespace vespalib::eval { diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index b65c3d5aaa7..498be2a738b 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -357,6 +357,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const Erf &node) override { make_map(node, operation::Erf::f); } + void visit(const Bit &node) override { + make_join(node, operation::Bit::f); + } //------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp index e7341bc1755..fa2d16a2271 100644 --- a/eval/src/vespa/eval/eval/node_tools.cpp +++ b/eval/src/vespa/eval/eval/node_tools.cpp @@ -182,6 +182,7 @@ struct CopyNode : NodeTraverser, NodeVisitor { void visit(const Sigmoid &node) override { copy_call(node); } void visit(const Elu &node) override { copy_call(node); } void visit(const Erf &node) override { copy_call(node); } + void visit(const Bit &node) override { copy_call(node); } // traverse nodes bool open(const Node &) override { return !error; } diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index 63da6d79c6f..8622fd734f1 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -278,6 +278,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { void visit(const Sigmoid &node) override { resolve_op1(node); } void visit(const Elu &node) override { resolve_op1(node); } void visit(const Erf &node) override { resolve_op1(node); } + void visit(const Bit &node) override { resolve_op2(node); } //------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h index 172cd48fe2a..475bbf5405c 100644 --- a/eval/src/vespa/eval/eval/node_visitor.h +++ b/eval/src/vespa/eval/eval/node_visitor.h @@ -85,6 +85,7 @@ struct NodeVisitor { virtual void visit(const nodes::Sigmoid &) = 0; virtual void visit(const nodes::Elu &) = 0; virtual void visit(const nodes::Erf &) = 0; + virtual void visit(const nodes::Bit &) = 0; virtual ~NodeVisitor() {} }; @@ -154,6 +155,7 @@ struct EmptyNodeVisitor : NodeVisitor { void visit(const nodes::Sigmoid &) override {} void visit(const nodes::Elu &) override {} void visit(const nodes::Erf &) override {} + void visit(const nodes::Bit &) override {} }; } // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp index b97ac3f2261..36b922539f4 100644 --- a/eval/src/vespa/eval/eval/operation.cpp +++ b/eval/src/vespa/eval/eval/operation.cpp @@ -50,6 +50,12 @@ double Relu::f(double a) { return std::max(a, 0.0); } double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; } double Erf::f(double a) { return std::erf(a); } +double Bit::f(double a, double b) { + // must match vespalib_eval_bit + int8_t value = (int8_t) a; + uint32_t n = (uint32_t) b; + return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0; +} //----------------------------------------------------------------------------- double Inv::f(double a) { return (1.0 / a); } double Square::f(double a) { return (a * a); } @@ -143,6 +149,7 @@ std::map make_op2_map() { add_op2(map, "fmod(a,b)", Mod::f); add_op2(map, "min(a,b)", Min::f); add_op2(map, "max(a,b)", Max::f); + add_op2(map, "bit(a,b)", Bit::f); return map; } diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h index 3170c868214..438b510b714 100644 --- a/eval/src/vespa/eval/eval/operation.h +++ b/eval/src/vespa/eval/eval/operation.h @@ -49,6 +49,7 @@ struct Relu { static double f(double a); }; struct Sigmoid { static double f(double a); }; struct Elu { static double f(double a); }; struct Erf { static double f(double a); }; +struct Bit { static double f(double a, double b); }; //----------------------------------------------------------------------------- struct Inv { static double f(double a); }; struct Square { static double f(double a); }; diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp index 63a3a23d9ae..5d51a1d23b5 100644 --- a/eval/src/vespa/eval/eval/test/eval_spec.cpp +++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp @@ -158,6 +158,17 @@ EvalSpec::add_function_call_cases() { add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "fmod(a,b)", [](double a, double b){ return std::fmod(a, b); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "min(a,b)", [](double a, double b){ return std::min(a, b); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "max(a,b)", [](double a, double b){ return std::max(a, b); }); + add_expression({"a", "b"}, "bit(a,b)") + .add_case({-128, 7}, 1.0).add_case({-128, 6}, 0.0).add_case({-128, 5}, 0.0).add_case({-128, 4}, 0.0) + .add_case({-128, 3}, 0.0).add_case({-128, 2}, 0.0).add_case({-128, 1}, 0.0).add_case({-128, 0}, 0.0) + .add_case({-43, 7}, 1.0).add_case({-43, 6}, 1.0).add_case({-43, 5}, 0.0).add_case({-43, 4}, 1.0) + .add_case({-43, 3}, 0.0).add_case({-43, 2}, 1.0).add_case({-43, 1}, 0.0).add_case({-43, 0}, 1.0) + .add_case({0, 7}, 0.0).add_case({0, 6}, 0.0).add_case({0, 5}, 0.0).add_case({0, 4}, 0.0) + .add_case({0, 3}, 0.0).add_case({0, 2}, 0.0).add_case({0, 1}, 0.0).add_case({0, 0}, 0.0) + .add_case({85, 7}, 0.0).add_case({85, 6}, 1.0).add_case({85, 5}, 0.0).add_case({85, 4}, 1.0) + .add_case({85, 3}, 0.0).add_case({85, 2}, 1.0).add_case({85, 1}, 0.0).add_case({85, 0}, 1.0) + .add_case({127, 7}, 0.0).add_case({127, 6}, 1.0).add_case({127, 5}, 1.0).add_case({127, 4}, 1.0) + .add_case({127, 3}, 1.0).add_case({127, 2}, 1.0).add_case({127, 1}, 1.0).add_case({127, 0}, 1.0); } void diff --git a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp index 4824751bb14..58e4b91f6d9 100644 --- a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp +++ b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp @@ -335,6 +335,9 @@ struct EvalNode : public NodeVisitor { void visit(const Erf &node) override { eval_map(node.get_child(0), operation::Erf::f); } + void visit(const Bit &node) override { + eval_join(node.get_child(0), node.get_child(1), operation::Bit::f); + } }; TensorSpec eval_node(const Node &node, const std::vector ¶ms) { diff --git a/eval/src/vespa/eval/eval/visit_stuff.cpp b/eval/src/vespa/eval/eval/visit_stuff.cpp index 9306a720837..786562d823f 100644 --- a/eval/src/vespa/eval/eval/visit_stuff.cpp +++ b/eval/src/vespa/eval/eval/visit_stuff.cpp @@ -59,6 +59,7 @@ vespalib::string name_of(join_fun_t fun) { if (fun == operation::Ldexp::f) return "ldexp"; if (fun == operation::Min::f) return "min"; if (fun == operation::Max::f) return "max"; + if (fun == operation::Bit::f) return "bit"; return "[other join function]"; } -- cgit v1.2.3 From 4bf43f52a153cc2dea91aae8e48cf0f782f511f6 Mon Sep 17 00:00:00 2001 From: Håvard Pettersen Date: Mon, 21 Jun 2021 08:52:05 +0000 Subject: use common implementation for 'bit' function --- eval/src/vespa/eval/eval/extract_bit.h | 13 +++++++++++++ eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp | 8 ++------ eval/src/vespa/eval/eval/operation.cpp | 8 ++------ 3 files changed, 17 insertions(+), 12 deletions(-) create mode 100644 eval/src/vespa/eval/eval/extract_bit.h (limited to 'eval') diff --git a/eval/src/vespa/eval/eval/extract_bit.h b/eval/src/vespa/eval/eval/extract_bit.h new file mode 100644 index 00000000000..ecf56b33b02 --- /dev/null +++ b/eval/src/vespa/eval/eval/extract_bit.h @@ -0,0 +1,13 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +namespace vespalib::eval { + +inline double extract_bit(double a, double b) { + int8_t value = (int8_t) a; + uint32_t n = (uint32_t) b; + return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0; +} + +} diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 9a99c4fedd7..2a9b7815aa8 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -4,6 +4,7 @@ #include "llvm_wrapper.h" #include #include +#include #include #include #include @@ -29,12 +30,7 @@ double vespalib_eval_approx(double a, double b) { return (vespalib::approx_equal double vespalib_eval_relu(double a) { return std::max(a, 0.0); } double vespalib_eval_sigmoid(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } double vespalib_eval_elu(double a) { return (a < 0) ? std::exp(a) - 1.0 : a; } -double vespalib_eval_bit(double a, double b) { - // must match Bit::f - int8_t value = (int8_t) a; - uint32_t n = (uint32_t) b; - return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0; -} +double vespalib_eval_bit(double a, double b) { return vespalib::eval::extract_bit(a, b); } using vespalib::eval::gbdt::Forest; using resolve_function = double (*)(void *ctx, size_t idx); diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp index 36b922539f4..a82a79e6bc4 100644 --- a/eval/src/vespa/eval/eval/operation.cpp +++ b/eval/src/vespa/eval/eval/operation.cpp @@ -3,6 +3,7 @@ #include "operation.h" #include "function.h" #include "key_gen.h" +#include "extract_bit.h" #include #include @@ -50,12 +51,7 @@ double Relu::f(double a) { return std::max(a, 0.0); } double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; } double Erf::f(double a) { return std::erf(a); } -double Bit::f(double a, double b) { - // must match vespalib_eval_bit - int8_t value = (int8_t) a; - uint32_t n = (uint32_t) b; - return ((n < 8) && bool(value & (1 << n))) ? 1.0 : 0.0; -} +double Bit::f(double a, double b) { return extract_bit(a, b); } //----------------------------------------------------------------------------- double Inv::f(double a) { return (1.0 / a); } double Square::f(double a) { return (a * a); } -- cgit v1.2.3