diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-06-05 14:45:30 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-06-11 09:42:00 +0000 |
commit | 41df43d2296e910f4b0cec24b040ec51cfc9f7d0 (patch) | |
tree | 6d12616f2b9b0a022094fec1946454084ed70717 /eval | |
parent | 51abe86dad7be6ced30bc3b0a2fcce4359525820 (diff) |
common code for operation inlining
- add common code to make selecting the appropriate template function
easier (vespa/vespalib/util/typify.h)
- enable detection of lambda functions matching all low-level
operations. (lookup_op1, lookup_op2)
- add typifiers to decide which low-level operations should be inlined
(TypifyOp1, TypifyOp2)
- integrate into dense_simple_join as a pilot customer
Diffstat (limited to 'eval')
12 files changed, 391 insertions, 126 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index b68440795d4..67f9fa19dc0 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -17,6 +17,7 @@ vespa_define_module( src/tests/eval/function src/tests/eval/function_speed src/tests/eval/gbdt + src/tests/eval/inline_operation src/tests/eval/interpreted_function src/tests/eval/node_tools src/tests/eval/node_types diff --git a/eval/src/tests/eval/inline_operation/CMakeLists.txt b/eval/src/tests/eval/inline_operation/CMakeLists.txt new file mode 100644 index 00000000000..04cdbca3abf --- /dev/null +++ b/eval/src/tests/eval/inline_operation/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_inline_operation_test_app TEST + SOURCES + inline_operation_test.cpp + DEPENDS + vespaeval + gtest +) +vespa_add_test(NAME eval_inline_operation_test_app COMMAND eval_inline_operation_test_app) diff --git a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp new file mode 100644 index 00000000000..4520176e276 --- /dev/null +++ b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp @@ -0,0 +1,156 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/eval/eval/operation.h> +#include <vespa/eval/eval/inline_operation.h> +#include <vespa/eval/eval/function.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib::eval; +using namespace vespalib::eval::operation; + +template <typename T> struct IsInlined { constexpr static bool value = true; }; +template <> struct IsInlined<CallOp1> { constexpr static bool value = false; }; +template <> struct IsInlined<CallOp2> { constexpr static bool value = false; }; + +template <typename T> double test_op1(op1_t ref, double a, bool inlined) { + T op(ref); + EXPECT_EQ(IsInlined<T>::value, inlined); + EXPECT_EQ(op(a), ref(a)); + return op(a); +}; + +template <typename T> double test_op2(op2_t ref, double a, double b, bool inlined) { + T op(ref); + EXPECT_EQ(IsInlined<T>::value, inlined); + EXPECT_EQ(op(a,b), ref(a,b)); + return op(a,b); +}; + +op1_t as_op1(const vespalib::string &str) { + auto fun = Function::parse({"a"}, str); + auto res = lookup_op1(*fun); + EXPECT_TRUE(res.has_value()); + return res.value(); +} + +op2_t as_op2(const vespalib::string &str) { + auto fun = Function::parse({"a", "b"}, str); + auto res = lookup_op2(*fun); + EXPECT_TRUE(res.has_value()); + return res.value(); +} + +TEST(InlineOperationTest, op1_lambdas_are_recognized) { + EXPECT_EQ(as_op1("-a"), Neg::f); + EXPECT_EQ(as_op1("!a"), Not::f); + EXPECT_EQ(as_op1("cos(a)"), Cos::f); + EXPECT_EQ(as_op1("sin(a)"), Sin::f); + EXPECT_EQ(as_op1("tan(a)"), Tan::f); + EXPECT_EQ(as_op1("cosh(a)"), Cosh::f); + EXPECT_EQ(as_op1("sinh(a)"), Sinh::f); + EXPECT_EQ(as_op1("tanh(a)"), Tanh::f); + EXPECT_EQ(as_op1("acos(a)"), Acos::f); + EXPECT_EQ(as_op1("asin(a)"), Asin::f); + EXPECT_EQ(as_op1("atan(a)"), Atan::f); + EXPECT_EQ(as_op1("exp(a)"), Exp::f); + EXPECT_EQ(as_op1("log10(a)"), Log10::f); + EXPECT_EQ(as_op1("log(a)"), Log::f); + EXPECT_EQ(as_op1("sqrt(a)"), Sqrt::f); + EXPECT_EQ(as_op1("ceil(a)"), Ceil::f); + EXPECT_EQ(as_op1("fabs(a)"), Fabs::f); + EXPECT_EQ(as_op1("floor(a)"), Floor::f); + EXPECT_EQ(as_op1("isNan(a)"), IsNan::f); + EXPECT_EQ(as_op1("relu(a)"), Relu::f); + EXPECT_EQ(as_op1("sigmoid(a)"), Sigmoid::f); + EXPECT_EQ(as_op1("elu(a)"), Elu::f); +} + +TEST(InlineOperationTest, op1_lambdas_are_recognized_with_different_parameter_names) { + EXPECT_EQ(lookup_op1(*Function::parse({"x"}, "-x")).value(), Neg::f); + EXPECT_EQ(lookup_op1(*Function::parse({"x"}, "!x")).value(), Not::f); +} + +TEST(InlineOperationTest, non_op1_lambdas_are_not_recognized) { + EXPECT_FALSE(lookup_op1(*Function::parse({"a"}, "a*a")).has_value()); + EXPECT_FALSE(lookup_op1(*Function::parse({"a", "b"}, "a+b")).has_value()); +} + +TEST(InlineOperationTest, op2_lambdas_are_recognized) { + EXPECT_EQ(as_op2("a+b"), Add::f); + EXPECT_EQ(as_op2("a-b"), Sub::f); + EXPECT_EQ(as_op2("a*b"), Mul::f); + EXPECT_EQ(as_op2("a/b"), Div::f); + EXPECT_EQ(as_op2("a%b"), Mod::f); + EXPECT_EQ(as_op2("a^b"), Pow::f); + EXPECT_EQ(as_op2("a==b"), Equal::f); + EXPECT_EQ(as_op2("a!=b"), NotEqual::f); + EXPECT_EQ(as_op2("a~=b"), Approx::f); + EXPECT_EQ(as_op2("a<b"), Less::f); + EXPECT_EQ(as_op2("a<=b"), LessEqual::f); + EXPECT_EQ(as_op2("a>b"), Greater::f); + EXPECT_EQ(as_op2("a>=b"), GreaterEqual::f); + EXPECT_EQ(as_op2("a&&b"), And::f); + EXPECT_EQ(as_op2("a||b"), Or::f); + EXPECT_EQ(as_op2("atan2(a,b)"), Atan2::f); + EXPECT_EQ(as_op2("ldexp(a,b)"), Ldexp::f); + EXPECT_EQ(as_op2("pow(a,b)"), Pow::f); + EXPECT_EQ(as_op2("fmod(a,b)"), Mod::f); + EXPECT_EQ(as_op2("min(a,b)"), Min::f); + EXPECT_EQ(as_op2("max(a,b)"), Max::f); +} + +TEST(InlineOperationTest, op2_lambdas_are_recognized_with_different_parameter_names) { + EXPECT_EQ(lookup_op2(*Function::parse({"x", "y"}, "x+y")).value(), Add::f); + EXPECT_EQ(lookup_op2(*Function::parse({"x", "y"}, "x-y")).value(), Sub::f); +} + +TEST(InlineOperationTest, non_op2_lambdas_are_not_recognized) { + EXPECT_FALSE(lookup_op2(*Function::parse({"a"}, "-a")).has_value()); + EXPECT_FALSE(lookup_op2(*Function::parse({"a", "b"}, "b+a")).has_value()); +} + +TEST(InlineOperationTest, generic_op1_wrapper_works) { + CallOp1 op(Neg::f); + EXPECT_EQ(op(3), -3); + EXPECT_EQ(op(-5), 5); +} + +TEST(InlineOperationTest, generic_op2_wrapper_works) { + CallOp2 op(Add::f); + EXPECT_EQ(op(2,3), 5); + EXPECT_EQ(op(3,7), 10); +} + +TEST(InlineOperationTest, inline_op2_example_works) { + op2_t ignored = nullptr; + InlineOp2<Add> op(ignored); + EXPECT_EQ(op(2,3), 5); + EXPECT_EQ(op(3,7), 10); +} + +TEST(InlineOperationTest, parameter_swap_wrapper_works) { + CallOp2 op(Sub::f); + SwapArgs2<CallOp2> swap_op(Sub::f); + EXPECT_EQ(op(2,3), -1); + EXPECT_EQ(swap_op(2,3), 1); + EXPECT_EQ(op(3,7), -4); + EXPECT_EQ(swap_op(3,7), 4); +} + +TEST(InlineOperationTest, resolved_op1_works) { + auto a = TypifyOp1::resolve(Neg::f, [](auto t){ return test_op1<typename decltype(t)::type>(Neg::f, 2.0, false); }); + // putting the lambda inside the EXPECT does not work + EXPECT_EQ(a, -2.0); +} + +TEST(InlineOperationTest, resolved_op2_works) { + auto a = TypifyOp2::resolve(Add::f, [](auto t){ return test_op2<typename decltype(t)::type>(Add::f, 2.0, 5.0, true); }); + auto b = TypifyOp2::resolve(Mul::f, [](auto t){ return test_op2<typename decltype(t)::type>(Mul::f, 5.0, 3.0, true); }); + auto c = TypifyOp2::resolve(Sub::f, [](auto t){ return test_op2<typename decltype(t)::type>(Sub::f, 8.0, 5.0, false); }); + // putting the lambda inside the EXPECT does not work + EXPECT_EQ(a, 7.0); + EXPECT_EQ(b, 15.0); + EXPECT_EQ(c, 3.0); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp b/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp index a571837b8e9..92fdbfade46 100644 --- a/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp +++ b/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp @@ -67,7 +67,6 @@ TEST("require that matmul can be optimized") { TEST("require that matmul with lambda can be optimized") { TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, true, true)); - TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*x)),sum,d)", 2, 3, 5, true, true)); } TEST("require that expressions similar to matmul are not optimized") { @@ -75,6 +74,7 @@ TEST("require that expressions similar to matmul are not optimized") { TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum,b)")); TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,prod,d)")); TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum)")); + TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*x)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x+y)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*x)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*y)),sum,d)")); diff --git a/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp b/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp index c0823248538..f9c563c9bf8 100644 --- a/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp +++ b/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp @@ -78,7 +78,6 @@ TEST("require that single multi matmul can be optimized") { TEST("require that multi matmul with lambda can be optimized") { TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, 6, true, true)); - TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*x)),sum,d)", 2, 3, 5, 6, true, true)); } TEST("require that expressions similar to multi matmul are not optimized") { @@ -86,6 +85,7 @@ TEST("require that expressions similar to multi matmul are not optimized") { TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,b)")); TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,prod,d)")); TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum)")); + TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*x)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x+y)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*x)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*y)),sum,d)")); diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp index 0b924451907..3ecc3f66cda 100644 --- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp +++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp @@ -130,13 +130,13 @@ TEST("require that xw product gives same results as reference join/reduce") { TEST("require that various variants of xw product can be optimized") { TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(x*y)),sum,y)", 3, 2, true)); - TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)", 3, 2, true)); } TEST("require that expressions similar to xw product are not optimized") { TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum,x)")); TEST_DO(verify_not_optimized("reduce(y3*x2y3,prod,y)")); TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum)")); + TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)")); TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x+y)),sum,y)")); TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x*x)),sum,y)")); TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*y)),sum,y)")); diff --git a/eval/src/vespa/eval/eval/inline_operation.h b/eval/src/vespa/eval/eval/inline_operation.h new file mode 100644 index 00000000000..493de9ea56c --- /dev/null +++ b/eval/src/vespa/eval/eval/inline_operation.h @@ -0,0 +1,67 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "operation.h" +#include <vespa/vespalib/util/typify.h> + +namespace vespalib::eval::operation { + +//----------------------------------------------------------------------------- + +struct CallOp1 { + op1_t my_op1; + CallOp1(op1_t op1) : my_op1(op1) {} + double operator()(double a) const { return my_op1(a); } +}; + +struct TypifyOp1 { + template <typename T> using Result = TypifyResultType<T>; + template <typename F> static decltype(auto) resolve(op1_t value, F &&f) { + (void) value; + return f(Result<CallOp1>()); + } +}; + +//----------------------------------------------------------------------------- + +struct CallOp2 { + op2_t my_op2; + CallOp2(op2_t op2) : my_op2(op2) {} + op2_t get() const { return my_op2; } + double operator()(double a, double b) const { return my_op2(a, b); } +}; + +template <typename Op2> +struct SwapArgs2 { + Op2 op2; + SwapArgs2(op2_t op2_in) : op2(op2_in) {} + template <typename A, typename B> constexpr auto operator()(A a, B b) const { return op2(b, a); } +}; + +template <typename T> struct InlineOp2; +template <> struct InlineOp2<Add> { + InlineOp2(op2_t) {} + template <typename A, typename B> constexpr auto operator()(A a, B b) const { return (a+b); } +}; +template <> struct InlineOp2<Mul> { + InlineOp2(op2_t) {} + template <typename A, typename B> constexpr auto operator()(A a, B b) const { return (a*b); } +}; + +struct TypifyOp2 { + template <typename T> using Result = TypifyResultType<T>; + template <typename F> static decltype(auto) resolve(op2_t value, F &&f) { + if (value == Add::f) { + return f(Result<InlineOp2<Add>>()); + } else if (value == Mul::f) { + return f(Result<InlineOp2<Mul>>()); + } else { + return f(Result<CallOp2>()); + } + } +}; + +//----------------------------------------------------------------------------- + +} diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index 3a73a3b8784..02d16caae6b 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -15,25 +15,6 @@ namespace vespalib::eval { namespace { using namespace nodes; -using map_fun_t = double (*)(double); -using join_fun_t = double (*)(double, double); - -//----------------------------------------------------------------------------- - -// TODO(havardpe): generic function pointer resolving for all single -// operation lambdas. - -template <typename OP2> -bool is_op2(const Function &lambda) { - if (lambda.num_params() == 2) { - if (auto op2 = as<OP2>(lambda.root())) { - auto sym1 = as<Symbol>(op2->lhs()); - auto sym2 = as<Symbol>(op2->rhs()); - return (sym1 && sym2 && (sym1->id() != sym2->id())); - } - } - return false; -} //----------------------------------------------------------------------------- @@ -63,13 +44,13 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { stack.back() = tensor_function::reduce(a, aggr, dimensions, stash); } - void make_map(const Node &, map_fun_t function) { + void make_map(const Node &, operation::op1_t function) { assert(stack.size() >= 1); const auto &a = stack.back().get(); stack.back() = tensor_function::map(a, function, stash); } - void make_join(const Node &, join_fun_t function) { + void make_join(const Node &, operation::op2_t function) { assert(stack.size() >= 2); const auto &b = stack.back().get(); stack.pop_back(); @@ -77,7 +58,7 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { stack.back() = tensor_function::join(a, b, function, stash); } - void make_merge(const Node &, join_fun_t function) { + void make_merge(const Node &, operation::op2_t function) { assert(stack.size() >= 2); const auto &b = stack.back().get(); stack.pop_back(); @@ -203,14 +184,16 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { abort(); } void visit(const TensorMap &node) override { - const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); - make_map(node, token.get()->get().get_function<1>()); + if (auto op1 = operation::lookup_op1(node.lambda())) { + make_map(node, op1.value()); + } else { + const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); + make_map(node, token.get()->get().get_function<1>()); + } } void visit(const TensorJoin &node) override { - if (is_op2<Mul>(node.lambda())) { - make_join(node, operation::Mul::f); - } else if (is_op2<Add>(node.lambda())) { - make_join(node, operation::Add::f); + if (auto op2 = operation::lookup_op2(node.lambda())) { + make_join(node, op2.value()); } else { const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); make_join(node, token.get()->get().get_function<2>()); diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp index fa0a99de461..581f65c0e31 100644 --- a/eval/src/vespa/eval/eval/operation.cpp +++ b/eval/src/vespa/eval/eval/operation.cpp @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "operation.h" +#include "function.h" +#include "key_gen.h" #include <vespa/vespalib/util/approx.h> #include <algorithm> @@ -48,4 +50,97 @@ double Relu::f(double a) { return std::max(a, 0.0); } double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; } +namespace { + +template <typename T> +void add_op(std::map<vespalib::string,T> &map, const Function &fun, T op) { + assert(!fun.has_error()); + auto key = gen_key(fun, PassParams::SEPARATE); + auto res = map.emplace(key, op); + assert(res.second); +} + +template <typename T> +std::optional<T> lookup_op(const std::map<vespalib::string,T> &map, const Function &fun) { + auto key = gen_key(fun, PassParams::SEPARATE); + auto pos = map.find(key); + if (pos != map.end()) { + return pos->second; + } + return std::nullopt; +} + +void add_op1(std::map<vespalib::string,op1_t> &map, const vespalib::string &expr, op1_t op) { + add_op(map, *Function::parse({"a"}, expr), op); +} + +void add_op2(std::map<vespalib::string,op2_t> &map, const vespalib::string &expr, op2_t op) { + add_op(map, *Function::parse({"a", "b"}, expr), op); +} + +std::map<vespalib::string,op1_t> make_op1_map() { + std::map<vespalib::string,op1_t> map; + add_op1(map, "-a", Neg::f); + add_op1(map, "!a", Not::f); + add_op1(map, "cos(a)", Cos::f); + add_op1(map, "sin(a)", Sin::f); + add_op1(map, "tan(a)", Tan::f); + add_op1(map, "cosh(a)", Cosh::f); + add_op1(map, "sinh(a)", Sinh::f); + add_op1(map, "tanh(a)", Tanh::f); + add_op1(map, "acos(a)", Acos::f); + add_op1(map, "asin(a)", Asin::f); + add_op1(map, "atan(a)", Atan::f); + add_op1(map, "exp(a)", Exp::f); + add_op1(map, "log10(a)", Log10::f); + add_op1(map, "log(a)", Log::f); + add_op1(map, "sqrt(a)", Sqrt::f); + add_op1(map, "ceil(a)", Ceil::f); + add_op1(map, "fabs(a)", Fabs::f); + add_op1(map, "floor(a)", Floor::f); + add_op1(map, "isNan(a)", IsNan::f); + add_op1(map, "relu(a)", Relu::f); + add_op1(map, "sigmoid(a)", Sigmoid::f); + add_op1(map, "elu(a)", Elu::f); + return map; +} + +std::map<vespalib::string,op2_t> make_op2_map() { + std::map<vespalib::string,op2_t> map; + add_op2(map, "a+b", Add::f); + add_op2(map, "a-b", Sub::f); + add_op2(map, "a*b", Mul::f); + add_op2(map, "a/b", Div::f); + add_op2(map, "a%b", Mod::f); + add_op2(map, "a^b", Pow::f); + add_op2(map, "a==b", Equal::f); + add_op2(map, "a!=b", NotEqual::f); + add_op2(map, "a~=b", Approx::f); + add_op2(map, "a<b", Less::f); + add_op2(map, "a<=b", LessEqual::f); + add_op2(map, "a>b", Greater::f); + add_op2(map, "a>=b", GreaterEqual::f); + add_op2(map, "a&&b", And::f); + add_op2(map, "a||b", Or::f); + add_op2(map, "atan2(a,b)", Atan2::f); + add_op2(map, "ldexp(a,b)", Ldexp::f); + add_op2(map, "pow(a,b)", Pow::f); + add_op2(map, "fmod(a,b)", Mod::f); + add_op2(map, "min(a,b)", Min::f); + add_op2(map, "max(a,b)", Max::f); + return map; +} + +} // namespace <unnamed> + +std::optional<op1_t> lookup_op1(const Function &fun) { + static const std::map<vespalib::string,op1_t> map = make_op1_map(); + return lookup_op(map, fun); +} + +std::optional<op2_t> lookup_op2(const Function &fun) { + static const std::map<vespalib::string,op2_t> map = make_op2_map(); + return lookup_op(map, fun); +} + } diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h index fa99f51a308..a80193e704d 100644 --- a/eval/src/vespa/eval/eval/operation.h +++ b/eval/src/vespa/eval/eval/operation.h @@ -1,6 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once +#include <optional> + +namespace vespalib::eval { class Function; } namespace vespalib::eval::operation { @@ -46,4 +49,10 @@ struct Relu { static double f(double a); }; struct Sigmoid { static double f(double a); }; struct Elu { static double f(double a); }; +using op1_t = double (*)(double); +using op2_t = double (*)(double, double); + +std::optional<op1_t> lookup_op1(const Function &fun); +std::optional<op2_t> lookup_op2(const Function &fun); + } diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index 3e91240048b..a8ae9c44bb0 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -2,6 +2,7 @@ #pragma once +#include <vespa/vespalib/util/typify.h> #include <vespa/vespalib/stllike/string.h> #include <vector> @@ -104,4 +105,15 @@ template <typename CT> inline ValueType::CellType get_cell_type(); template <> inline ValueType::CellType get_cell_type<double>() { return ValueType::CellType::DOUBLE; } template <> inline ValueType::CellType get_cell_type<float>() { return ValueType::CellType::FLOAT; } +struct TypifyCellType { + template <typename T> using Result = TypifyResultType<T>; + template <typename F> static decltype(auto) resolve(ValueType::CellType value, F &&f) { + switch(value) { + case ValueType::CellType::DOUBLE: return f(Result<double>()); + case ValueType::CellType::FLOAT: return f(Result<float>()); + } + abort(); + } +}; + } // namespace diff --git a/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp index 6b0d65c0743..c358c9d618d 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp @@ -5,6 +5,8 @@ #include <vespa/vespalib/objects/objectvisitor.h> #include <vespa/eval/eval/value.h> #include <vespa/eval/eval/operation.h> +#include <vespa/eval/eval/inline_operation.h> +#include <vespa/vespalib/util/typify.h> #include <optional> #include <algorithm> @@ -16,6 +18,7 @@ using eval::Value; using eval::ValueType; using eval::TensorFunction; using eval::TensorEngine; +using eval::TypifyCellType; using eval::as; using namespace eval::operation; @@ -30,6 +33,18 @@ using State = eval::InterpretedFunction::State; namespace { +struct TypifyOverlap { + template <Overlap VALUE> using Result = TypifyResultValue<Overlap, VALUE>; + template <typename F> static decltype(auto) resolve(Overlap value, F &&f) { + switch (value) { + case Overlap::INNER: return f(Result<Overlap::INNER>()); + case Overlap::OUTER: return f(Result<Overlap::OUTER>()); + case Overlap::FULL: return f(Result<Overlap::FULL>()); + } + abort(); + } +}; + struct JoinParams { const ValueType &result_type; size_t factor; @@ -38,44 +53,17 @@ struct JoinParams { : result_type(result_type_in), factor(factor_in), function(function_in) {} }; -struct CallFun { - join_fun_t function; - CallFun(const JoinParams ¶ms) : function(params.function) {} - double eval(double a, double b) const { return function(a, b); } -}; - -struct AddFun { - AddFun(const JoinParams &) {} - template <typename A, typename B> - auto eval(A a, B b) const { return (a + b); } -}; - -struct MulFun { - MulFun(const JoinParams &) {} - template <typename A, typename B> - auto eval(A a, B b) const { return (a * b); } -}; - -// needed for asymmetric operations like Sub and Div -template <typename Fun> -struct SwapFun { - Fun fun; - SwapFun(const JoinParams ¶ms) : fun(params) {} - template <typename A, typename B> - auto eval(A a, B b) const { return fun.eval(b, a); } -}; - template <typename OCT, typename PCT, typename SCT, typename Fun> void apply_fun_1_to_n(OCT *dst, const PCT *pri, SCT sec, size_t n, const Fun &fun) { for (size_t i = 0; i < n; ++i) { - dst[i] = fun.eval(pri[i], sec); + dst[i] = fun(pri[i], sec); } } template <typename OCT, typename PCT, typename SCT, typename Fun> void apply_fun_n_to_n(OCT *dst, const PCT *pri, const SCT *sec, size_t n, const Fun &fun) { for (size_t i = 0; i < n; ++i) { - dst[i] = fun.eval(pri[i], sec[i]); + dst[i] = fun(pri[i], sec[i]); } } @@ -93,9 +81,9 @@ void my_simple_join_op(State &state, uint64_t param) { using PCT = typename std::conditional<swap,RCT,LCT>::type; using SCT = typename std::conditional<swap,LCT,RCT>::type; using OCT = typename eval::UnifyCellTypes<PCT,SCT>::type; - using OP = typename std::conditional<swap,SwapFun<Fun>,Fun>::type; + using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type; const JoinParams ¶ms = *(JoinParams*)param; - OP my_op(params); + OP my_op(params.function); auto pri_cells = DenseTensorView::typify_cells<PCT>(state.peek(swap ? 0 : 1)); auto sec_cells = DenseTensorView::typify_cells<SCT>(state.peek(swap ? 1 : 0)); auto dst_cells = make_dst_cells<OCT, pri_mut>(pri_cells, state.stash); @@ -122,67 +110,13 @@ void my_simple_join_op(State &state, uint64_t param) { //----------------------------------------------------------------------------- -template <typename Fun, bool swap, Overlap overlap, bool pri_mut> -struct MySimpleJoinOp { - template <typename LCT, typename RCT> - static auto get_fun() { return my_simple_join_op<LCT,RCT,Fun,swap,overlap,pri_mut>; } -}; - -template <bool swap, Overlap overlap, bool pri_mut> -op_function my_select_4(ValueType::CellType lct, - ValueType::CellType rct, - join_fun_t fun_hint) -{ - if (fun_hint == Add::f) { - return select_2<MySimpleJoinOp<AddFun,swap,overlap,pri_mut>>(lct, rct); - } else if (fun_hint == Mul::f) { - return select_2<MySimpleJoinOp<MulFun,swap,overlap,pri_mut>>(lct, rct); - } else { - return select_2<MySimpleJoinOp<CallFun,swap,overlap,pri_mut>>(lct, rct); - } -} - -template <bool swap, Overlap overlap> -op_function my_select_3(ValueType::CellType lct, - ValueType::CellType rct, - bool pri_mut, - join_fun_t fun_hint) -{ - if (pri_mut) { - return my_select_4<swap, overlap, true>(lct, rct, fun_hint); - } else { - return my_select_4<swap, overlap, false>(lct, rct, fun_hint); +struct MyGetFun { + template <typename R1, typename R2, typename R3, typename R4, typename R5, typename R6> static auto invoke() { + return my_simple_join_op<R1, R2, R3, R4::value, R5::value, R6::value>; } -} - -template <bool swap> -op_function my_select_2(ValueType::CellType lct, - ValueType::CellType rct, - Overlap overlap, - bool pri_mut, - join_fun_t fun_hint) -{ - switch (overlap) { - case Overlap::INNER: return my_select_3<swap, Overlap::INNER>(lct, rct, pri_mut, fun_hint); - case Overlap::OUTER: return my_select_3<swap, Overlap::OUTER>(lct, rct, pri_mut, fun_hint); - case Overlap::FULL: return my_select_3<swap, Overlap::FULL>(lct, rct, pri_mut, fun_hint); - } - abort(); -} +}; -op_function my_select(ValueType::CellType lct, - ValueType::CellType rct, - Primary primary, - Overlap overlap, - bool pri_mut, - join_fun_t fun_hint) -{ - switch (primary) { - case Primary::LHS: return my_select_2<false>(lct, rct, overlap, pri_mut, fun_hint); - case Primary::RHS: return my_select_2<true>(lct, rct, overlap, pri_mut, fun_hint); - } - abort(); -} +using MyTypify = TypifyValue<TypifyCellType,TypifyOp2,TypifyBool,TypifyOverlap>; //----------------------------------------------------------------------------- @@ -280,11 +214,10 @@ Instruction DenseSimpleJoinFunction::compile_self(const TensorEngine &, Stash &stash) const { const JoinParams ¶ms = stash.create<JoinParams>(result_type(), factor(), function()); - auto op = my_select(lhs().result_type().cell_type(), - rhs().result_type().cell_type(), - _primary, _overlap, - primary_is_mutable(), - function()); + auto op = typify_invoke<6,MyTypify,MyGetFun>(lhs().result_type().cell_type(), + rhs().result_type().cell_type(), + function(), (_primary == Primary::RHS), + _overlap, primary_is_mutable()); static_assert(sizeof(uint64_t) == sizeof(¶ms)); return Instruction(op, (uint64_t)(¶ms)); } |