summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-06-05 14:45:30 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-06-11 09:42:00 +0000
commit41df43d2296e910f4b0cec24b040ec51cfc9f7d0 (patch)
tree6d12616f2b9b0a022094fec1946454084ed70717 /eval
parent51abe86dad7be6ced30bc3b0a2fcce4359525820 (diff)
common code for operation inlining
- add common code to make selecting the appropriate template function easier (vespa/vespalib/util/typify.h)
- enable detection of lambda functions matching all low-level operations (lookup_op1, lookup_op2)
- add typifiers to decide which low-level operations should be inlined (TypifyOp1, TypifyOp2)
- integrate into dense_simple_join as a pilot customer
Diffstat (limited to 'eval')
-rw-r--r--eval/CMakeLists.txt1
-rw-r--r--eval/src/tests/eval/inline_operation/CMakeLists.txt9
-rw-r--r--eval/src/tests/eval/inline_operation/inline_operation_test.cpp156
-rw-r--r--eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp2
-rw-r--r--eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp2
-rw-r--r--eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp2
-rw-r--r--eval/src/vespa/eval/eval/inline_operation.h67
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp39
-rw-r--r--eval/src/vespa/eval/eval/operation.cpp95
-rw-r--r--eval/src/vespa/eval/eval/operation.h9
-rw-r--r--eval/src/vespa/eval/eval/value_type.h12
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp123
12 files changed, 391 insertions, 126 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index b68440795d4..67f9fa19dc0 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -17,6 +17,7 @@ vespa_define_module(
src/tests/eval/function
src/tests/eval/function_speed
src/tests/eval/gbdt
+ src/tests/eval/inline_operation
src/tests/eval/interpreted_function
src/tests/eval/node_tools
src/tests/eval/node_types
diff --git a/eval/src/tests/eval/inline_operation/CMakeLists.txt b/eval/src/tests/eval/inline_operation/CMakeLists.txt
new file mode 100644
index 00000000000..04cdbca3abf
--- /dev/null
+++ b/eval/src/tests/eval/inline_operation/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_inline_operation_test_app TEST
+ SOURCES
+ inline_operation_test.cpp
+ DEPENDS
+ vespaeval
+ gtest
+)
+vespa_add_test(NAME eval_inline_operation_test_app COMMAND eval_inline_operation_test_app)
diff --git a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
new file mode 100644
index 00000000000..4520176e276
--- /dev/null
+++ b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
@@ -0,0 +1,156 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/eval/inline_operation.h>
+#include <vespa/eval/eval/function.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib::eval;
+using namespace vespalib::eval::operation;
+
+template <typename T> struct IsInlined { constexpr static bool value = true; };
+template <> struct IsInlined<CallOp1> { constexpr static bool value = false; };
+template <> struct IsInlined<CallOp2> { constexpr static bool value = false; };
+
+template <typename T> double test_op1(op1_t ref, double a, bool inlined) {
+ T op(ref);
+ EXPECT_EQ(IsInlined<T>::value, inlined);
+ EXPECT_EQ(op(a), ref(a));
+ return op(a);
+};
+
+template <typename T> double test_op2(op2_t ref, double a, double b, bool inlined) {
+ T op(ref);
+ EXPECT_EQ(IsInlined<T>::value, inlined);
+ EXPECT_EQ(op(a,b), ref(a,b));
+ return op(a,b);
+};
+
+op1_t as_op1(const vespalib::string &str) {
+ auto fun = Function::parse({"a"}, str);
+ auto res = lookup_op1(*fun);
+ EXPECT_TRUE(res.has_value());
+ return res.value();
+}
+
+op2_t as_op2(const vespalib::string &str) {
+ auto fun = Function::parse({"a", "b"}, str);
+ auto res = lookup_op2(*fun);
+ EXPECT_TRUE(res.has_value());
+ return res.value();
+}
+
+TEST(InlineOperationTest, op1_lambdas_are_recognized) {
+ EXPECT_EQ(as_op1("-a"), Neg::f);
+ EXPECT_EQ(as_op1("!a"), Not::f);
+ EXPECT_EQ(as_op1("cos(a)"), Cos::f);
+ EXPECT_EQ(as_op1("sin(a)"), Sin::f);
+ EXPECT_EQ(as_op1("tan(a)"), Tan::f);
+ EXPECT_EQ(as_op1("cosh(a)"), Cosh::f);
+ EXPECT_EQ(as_op1("sinh(a)"), Sinh::f);
+ EXPECT_EQ(as_op1("tanh(a)"), Tanh::f);
+ EXPECT_EQ(as_op1("acos(a)"), Acos::f);
+ EXPECT_EQ(as_op1("asin(a)"), Asin::f);
+ EXPECT_EQ(as_op1("atan(a)"), Atan::f);
+ EXPECT_EQ(as_op1("exp(a)"), Exp::f);
+ EXPECT_EQ(as_op1("log10(a)"), Log10::f);
+ EXPECT_EQ(as_op1("log(a)"), Log::f);
+ EXPECT_EQ(as_op1("sqrt(a)"), Sqrt::f);
+ EXPECT_EQ(as_op1("ceil(a)"), Ceil::f);
+ EXPECT_EQ(as_op1("fabs(a)"), Fabs::f);
+ EXPECT_EQ(as_op1("floor(a)"), Floor::f);
+ EXPECT_EQ(as_op1("isNan(a)"), IsNan::f);
+ EXPECT_EQ(as_op1("relu(a)"), Relu::f);
+ EXPECT_EQ(as_op1("sigmoid(a)"), Sigmoid::f);
+ EXPECT_EQ(as_op1("elu(a)"), Elu::f);
+}
+
+TEST(InlineOperationTest, op1_lambdas_are_recognized_with_different_parameter_names) {
+ EXPECT_EQ(lookup_op1(*Function::parse({"x"}, "-x")).value(), Neg::f);
+ EXPECT_EQ(lookup_op1(*Function::parse({"x"}, "!x")).value(), Not::f);
+}
+
+TEST(InlineOperationTest, non_op1_lambdas_are_not_recognized) {
+ EXPECT_FALSE(lookup_op1(*Function::parse({"a"}, "a*a")).has_value());
+ EXPECT_FALSE(lookup_op1(*Function::parse({"a", "b"}, "a+b")).has_value());
+}
+
+TEST(InlineOperationTest, op2_lambdas_are_recognized) {
+ EXPECT_EQ(as_op2("a+b"), Add::f);
+ EXPECT_EQ(as_op2("a-b"), Sub::f);
+ EXPECT_EQ(as_op2("a*b"), Mul::f);
+ EXPECT_EQ(as_op2("a/b"), Div::f);
+ EXPECT_EQ(as_op2("a%b"), Mod::f);
+ EXPECT_EQ(as_op2("a^b"), Pow::f);
+ EXPECT_EQ(as_op2("a==b"), Equal::f);
+ EXPECT_EQ(as_op2("a!=b"), NotEqual::f);
+ EXPECT_EQ(as_op2("a~=b"), Approx::f);
+ EXPECT_EQ(as_op2("a<b"), Less::f);
+ EXPECT_EQ(as_op2("a<=b"), LessEqual::f);
+ EXPECT_EQ(as_op2("a>b"), Greater::f);
+ EXPECT_EQ(as_op2("a>=b"), GreaterEqual::f);
+ EXPECT_EQ(as_op2("a&&b"), And::f);
+ EXPECT_EQ(as_op2("a||b"), Or::f);
+ EXPECT_EQ(as_op2("atan2(a,b)"), Atan2::f);
+ EXPECT_EQ(as_op2("ldexp(a,b)"), Ldexp::f);
+ EXPECT_EQ(as_op2("pow(a,b)"), Pow::f);
+ EXPECT_EQ(as_op2("fmod(a,b)"), Mod::f);
+ EXPECT_EQ(as_op2("min(a,b)"), Min::f);
+ EXPECT_EQ(as_op2("max(a,b)"), Max::f);
+}
+
+TEST(InlineOperationTest, op2_lambdas_are_recognized_with_different_parameter_names) {
+ EXPECT_EQ(lookup_op2(*Function::parse({"x", "y"}, "x+y")).value(), Add::f);
+ EXPECT_EQ(lookup_op2(*Function::parse({"x", "y"}, "x-y")).value(), Sub::f);
+}
+
+TEST(InlineOperationTest, non_op2_lambdas_are_not_recognized) {
+ EXPECT_FALSE(lookup_op2(*Function::parse({"a"}, "-a")).has_value());
+ EXPECT_FALSE(lookup_op2(*Function::parse({"a", "b"}, "b+a")).has_value());
+}
+
+TEST(InlineOperationTest, generic_op1_wrapper_works) {
+ CallOp1 op(Neg::f);
+ EXPECT_EQ(op(3), -3);
+ EXPECT_EQ(op(-5), 5);
+}
+
+TEST(InlineOperationTest, generic_op2_wrapper_works) {
+ CallOp2 op(Add::f);
+ EXPECT_EQ(op(2,3), 5);
+ EXPECT_EQ(op(3,7), 10);
+}
+
+TEST(InlineOperationTest, inline_op2_example_works) {
+ op2_t ignored = nullptr;
+ InlineOp2<Add> op(ignored);
+ EXPECT_EQ(op(2,3), 5);
+ EXPECT_EQ(op(3,7), 10);
+}
+
+TEST(InlineOperationTest, parameter_swap_wrapper_works) {
+ CallOp2 op(Sub::f);
+ SwapArgs2<CallOp2> swap_op(Sub::f);
+ EXPECT_EQ(op(2,3), -1);
+ EXPECT_EQ(swap_op(2,3), 1);
+ EXPECT_EQ(op(3,7), -4);
+ EXPECT_EQ(swap_op(3,7), 4);
+}
+
+TEST(InlineOperationTest, resolved_op1_works) {
+ auto a = TypifyOp1::resolve(Neg::f, [](auto t){ return test_op1<typename decltype(t)::type>(Neg::f, 2.0, false); });
+ // putting the lambda inside the EXPECT does not work
+ EXPECT_EQ(a, -2.0);
+}
+
+TEST(InlineOperationTest, resolved_op2_works) {
+ auto a = TypifyOp2::resolve(Add::f, [](auto t){ return test_op2<typename decltype(t)::type>(Add::f, 2.0, 5.0, true); });
+ auto b = TypifyOp2::resolve(Mul::f, [](auto t){ return test_op2<typename decltype(t)::type>(Mul::f, 5.0, 3.0, true); });
+ auto c = TypifyOp2::resolve(Sub::f, [](auto t){ return test_op2<typename decltype(t)::type>(Sub::f, 8.0, 5.0, false); });
+ // putting the lambda inside the EXPECT does not work
+ EXPECT_EQ(a, 7.0);
+ EXPECT_EQ(b, 15.0);
+ EXPECT_EQ(c, 3.0);
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp b/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp
index a571837b8e9..92fdbfade46 100644
--- a/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp
+++ b/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp
@@ -67,7 +67,6 @@ TEST("require that matmul can be optimized") {
TEST("require that matmul with lambda can be optimized") {
TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, true, true));
- TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*x)),sum,d)", 2, 3, 5, true, true));
}
TEST("require that expressions similar to matmul are not optimized") {
@@ -75,6 +74,7 @@ TEST("require that expressions similar to matmul are not optimized") {
TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum,b)"));
TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,prod,d)"));
TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*x)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x+y)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*x)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*y)),sum,d)"));
diff --git a/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp b/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
index c0823248538..f9c563c9bf8 100644
--- a/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
+++ b/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
@@ -78,7 +78,6 @@ TEST("require that single multi matmul can be optimized") {
TEST("require that multi matmul with lambda can be optimized") {
TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, 6, true, true));
- TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*x)),sum,d)", 2, 3, 5, 6, true, true));
}
TEST("require that expressions similar to multi matmul are not optimized") {
@@ -86,6 +85,7 @@ TEST("require that expressions similar to multi matmul are not optimized") {
TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,b)"));
TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,prod,d)"));
TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*x)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x+y)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*x)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*y)),sum,d)"));
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index 0b924451907..3ecc3f66cda 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -130,13 +130,13 @@ TEST("require that xw product gives same results as reference join/reduce") {
TEST("require that various variants of xw product can be optimized") {
TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(x*y)),sum,y)", 3, 2, true));
- TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)", 3, 2, true));
}
TEST("require that expressions similar to xw product are not optimized") {
TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum,x)"));
TEST_DO(verify_not_optimized("reduce(y3*x2y3,prod,y)"));
TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)"));
TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x+y)),sum,y)"));
TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x*x)),sum,y)"));
TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*y)),sum,y)"));
diff --git a/eval/src/vespa/eval/eval/inline_operation.h b/eval/src/vespa/eval/eval/inline_operation.h
new file mode 100644
index 00000000000..493de9ea56c
--- /dev/null
+++ b/eval/src/vespa/eval/eval/inline_operation.h
@@ -0,0 +1,67 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "operation.h"
+#include <vespa/vespalib/util/typify.h>
+
+namespace vespalib::eval::operation {
+
+//-----------------------------------------------------------------------------
+
+struct CallOp1 {
+ op1_t my_op1;
+ CallOp1(op1_t op1) : my_op1(op1) {}
+ double operator()(double a) const { return my_op1(a); }
+};
+
+struct TypifyOp1 {
+ template <typename T> using Result = TypifyResultType<T>;
+ template <typename F> static decltype(auto) resolve(op1_t value, F &&f) {
+ (void) value;
+ return f(Result<CallOp1>());
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+struct CallOp2 {
+ op2_t my_op2;
+ CallOp2(op2_t op2) : my_op2(op2) {}
+ op2_t get() const { return my_op2; }
+ double operator()(double a, double b) const { return my_op2(a, b); }
+};
+
+template <typename Op2>
+struct SwapArgs2 {
+ Op2 op2;
+ SwapArgs2(op2_t op2_in) : op2(op2_in) {}
+ template <typename A, typename B> constexpr auto operator()(A a, B b) const { return op2(b, a); }
+};
+
+template <typename T> struct InlineOp2;
+template <> struct InlineOp2<Add> {
+ InlineOp2(op2_t) {}
+ template <typename A, typename B> constexpr auto operator()(A a, B b) const { return (a+b); }
+};
+template <> struct InlineOp2<Mul> {
+ InlineOp2(op2_t) {}
+ template <typename A, typename B> constexpr auto operator()(A a, B b) const { return (a*b); }
+};
+
+struct TypifyOp2 {
+ template <typename T> using Result = TypifyResultType<T>;
+ template <typename F> static decltype(auto) resolve(op2_t value, F &&f) {
+ if (value == Add::f) {
+ return f(Result<InlineOp2<Add>>());
+ } else if (value == Mul::f) {
+ return f(Result<InlineOp2<Mul>>());
+ } else {
+ return f(Result<CallOp2>());
+ }
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+}
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index 3a73a3b8784..02d16caae6b 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -15,25 +15,6 @@ namespace vespalib::eval {
namespace {
using namespace nodes;
-using map_fun_t = double (*)(double);
-using join_fun_t = double (*)(double, double);
-
-//-----------------------------------------------------------------------------
-
-// TODO(havardpe): generic function pointer resolving for all single
-// operation lambdas.
-
-template <typename OP2>
-bool is_op2(const Function &lambda) {
- if (lambda.num_params() == 2) {
- if (auto op2 = as<OP2>(lambda.root())) {
- auto sym1 = as<Symbol>(op2->lhs());
- auto sym2 = as<Symbol>(op2->rhs());
- return (sym1 && sym2 && (sym1->id() != sym2->id()));
- }
- }
- return false;
-}
//-----------------------------------------------------------------------------
@@ -63,13 +44,13 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
stack.back() = tensor_function::reduce(a, aggr, dimensions, stash);
}
- void make_map(const Node &, map_fun_t function) {
+ void make_map(const Node &, operation::op1_t function) {
assert(stack.size() >= 1);
const auto &a = stack.back().get();
stack.back() = tensor_function::map(a, function, stash);
}
- void make_join(const Node &, join_fun_t function) {
+ void make_join(const Node &, operation::op2_t function) {
assert(stack.size() >= 2);
const auto &b = stack.back().get();
stack.pop_back();
@@ -77,7 +58,7 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
stack.back() = tensor_function::join(a, b, function, stash);
}
- void make_merge(const Node &, join_fun_t function) {
+ void make_merge(const Node &, operation::op2_t function) {
assert(stack.size() >= 2);
const auto &b = stack.back().get();
stack.pop_back();
@@ -203,14 +184,16 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
abort();
}
void visit(const TensorMap &node) override {
- const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE));
- make_map(node, token.get()->get().get_function<1>());
+ if (auto op1 = operation::lookup_op1(node.lambda())) {
+ make_map(node, op1.value());
+ } else {
+ const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE));
+ make_map(node, token.get()->get().get_function<1>());
+ }
}
void visit(const TensorJoin &node) override {
- if (is_op2<Mul>(node.lambda())) {
- make_join(node, operation::Mul::f);
- } else if (is_op2<Add>(node.lambda())) {
- make_join(node, operation::Add::f);
+ if (auto op2 = operation::lookup_op2(node.lambda())) {
+ make_join(node, op2.value());
} else {
const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE));
make_join(node, token.get()->get().get_function<2>());
diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp
index fa0a99de461..581f65c0e31 100644
--- a/eval/src/vespa/eval/eval/operation.cpp
+++ b/eval/src/vespa/eval/eval/operation.cpp
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "operation.h"
+#include "function.h"
+#include "key_gen.h"
#include <vespa/vespalib/util/approx.h>
#include <algorithm>
@@ -48,4 +50,97 @@ double Relu::f(double a) { return std::max(a, 0.0); }
double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; }
+namespace {
+
+template <typename T>
+void add_op(std::map<vespalib::string,T> &map, const Function &fun, T op) {
+ assert(!fun.has_error());
+ auto key = gen_key(fun, PassParams::SEPARATE);
+ auto res = map.emplace(key, op);
+ assert(res.second);
+}
+
+template <typename T>
+std::optional<T> lookup_op(const std::map<vespalib::string,T> &map, const Function &fun) {
+ auto key = gen_key(fun, PassParams::SEPARATE);
+ auto pos = map.find(key);
+ if (pos != map.end()) {
+ return pos->second;
+ }
+ return std::nullopt;
+}
+
+void add_op1(std::map<vespalib::string,op1_t> &map, const vespalib::string &expr, op1_t op) {
+ add_op(map, *Function::parse({"a"}, expr), op);
+}
+
+void add_op2(std::map<vespalib::string,op2_t> &map, const vespalib::string &expr, op2_t op) {
+ add_op(map, *Function::parse({"a", "b"}, expr), op);
+}
+
+std::map<vespalib::string,op1_t> make_op1_map() {
+ std::map<vespalib::string,op1_t> map;
+ add_op1(map, "-a", Neg::f);
+ add_op1(map, "!a", Not::f);
+ add_op1(map, "cos(a)", Cos::f);
+ add_op1(map, "sin(a)", Sin::f);
+ add_op1(map, "tan(a)", Tan::f);
+ add_op1(map, "cosh(a)", Cosh::f);
+ add_op1(map, "sinh(a)", Sinh::f);
+ add_op1(map, "tanh(a)", Tanh::f);
+ add_op1(map, "acos(a)", Acos::f);
+ add_op1(map, "asin(a)", Asin::f);
+ add_op1(map, "atan(a)", Atan::f);
+ add_op1(map, "exp(a)", Exp::f);
+ add_op1(map, "log10(a)", Log10::f);
+ add_op1(map, "log(a)", Log::f);
+ add_op1(map, "sqrt(a)", Sqrt::f);
+ add_op1(map, "ceil(a)", Ceil::f);
+ add_op1(map, "fabs(a)", Fabs::f);
+ add_op1(map, "floor(a)", Floor::f);
+ add_op1(map, "isNan(a)", IsNan::f);
+ add_op1(map, "relu(a)", Relu::f);
+ add_op1(map, "sigmoid(a)", Sigmoid::f);
+ add_op1(map, "elu(a)", Elu::f);
+ return map;
+}
+
+std::map<vespalib::string,op2_t> make_op2_map() {
+ std::map<vespalib::string,op2_t> map;
+ add_op2(map, "a+b", Add::f);
+ add_op2(map, "a-b", Sub::f);
+ add_op2(map, "a*b", Mul::f);
+ add_op2(map, "a/b", Div::f);
+ add_op2(map, "a%b", Mod::f);
+ add_op2(map, "a^b", Pow::f);
+ add_op2(map, "a==b", Equal::f);
+ add_op2(map, "a!=b", NotEqual::f);
+ add_op2(map, "a~=b", Approx::f);
+ add_op2(map, "a<b", Less::f);
+ add_op2(map, "a<=b", LessEqual::f);
+ add_op2(map, "a>b", Greater::f);
+ add_op2(map, "a>=b", GreaterEqual::f);
+ add_op2(map, "a&&b", And::f);
+ add_op2(map, "a||b", Or::f);
+ add_op2(map, "atan2(a,b)", Atan2::f);
+ add_op2(map, "ldexp(a,b)", Ldexp::f);
+ add_op2(map, "pow(a,b)", Pow::f);
+ add_op2(map, "fmod(a,b)", Mod::f);
+ add_op2(map, "min(a,b)", Min::f);
+ add_op2(map, "max(a,b)", Max::f);
+ return map;
+}
+
+} // namespace <unnamed>
+
+std::optional<op1_t> lookup_op1(const Function &fun) {
+ static const std::map<vespalib::string,op1_t> map = make_op1_map();
+ return lookup_op(map, fun);
+}
+
+std::optional<op2_t> lookup_op2(const Function &fun) {
+ static const std::map<vespalib::string,op2_t> map = make_op2_map();
+ return lookup_op(map, fun);
+}
+
}
diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h
index fa99f51a308..a80193e704d 100644
--- a/eval/src/vespa/eval/eval/operation.h
+++ b/eval/src/vespa/eval/eval/operation.h
@@ -1,6 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
+#include <optional>
+
+namespace vespalib::eval { class Function; }
namespace vespalib::eval::operation {
@@ -46,4 +49,10 @@ struct Relu { static double f(double a); };
struct Sigmoid { static double f(double a); };
struct Elu { static double f(double a); };
+using op1_t = double (*)(double);
+using op2_t = double (*)(double, double);
+
+std::optional<op1_t> lookup_op1(const Function &fun);
+std::optional<op2_t> lookup_op2(const Function &fun);
+
}
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index 3e91240048b..a8ae9c44bb0 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -2,6 +2,7 @@
#pragma once
+#include <vespa/vespalib/util/typify.h>
#include <vespa/vespalib/stllike/string.h>
#include <vector>
@@ -104,4 +105,15 @@ template <typename CT> inline ValueType::CellType get_cell_type();
template <> inline ValueType::CellType get_cell_type<double>() { return ValueType::CellType::DOUBLE; }
template <> inline ValueType::CellType get_cell_type<float>() { return ValueType::CellType::FLOAT; }
+struct TypifyCellType {
+ template <typename T> using Result = TypifyResultType<T>;
+ template <typename F> static decltype(auto) resolve(ValueType::CellType value, F &&f) {
+ switch(value) {
+ case ValueType::CellType::DOUBLE: return f(Result<double>());
+ case ValueType::CellType::FLOAT: return f(Result<float>());
+ }
+ abort();
+ }
+};
+
} // namespace
diff --git a/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp
index 6b0d65c0743..c358c9d618d 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp
@@ -5,6 +5,8 @@
#include <vespa/vespalib/objects/objectvisitor.h>
#include <vespa/eval/eval/value.h>
#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/eval/inline_operation.h>
+#include <vespa/vespalib/util/typify.h>
#include <optional>
#include <algorithm>
@@ -16,6 +18,7 @@ using eval::Value;
using eval::ValueType;
using eval::TensorFunction;
using eval::TensorEngine;
+using eval::TypifyCellType;
using eval::as;
using namespace eval::operation;
@@ -30,6 +33,18 @@ using State = eval::InterpretedFunction::State;
namespace {
+struct TypifyOverlap {
+ template <Overlap VALUE> using Result = TypifyResultValue<Overlap, VALUE>;
+ template <typename F> static decltype(auto) resolve(Overlap value, F &&f) {
+ switch (value) {
+ case Overlap::INNER: return f(Result<Overlap::INNER>());
+ case Overlap::OUTER: return f(Result<Overlap::OUTER>());
+ case Overlap::FULL: return f(Result<Overlap::FULL>());
+ }
+ abort();
+ }
+};
+
struct JoinParams {
const ValueType &result_type;
size_t factor;
@@ -38,44 +53,17 @@ struct JoinParams {
: result_type(result_type_in), factor(factor_in), function(function_in) {}
};
-struct CallFun {
- join_fun_t function;
- CallFun(const JoinParams &params) : function(params.function) {}
- double eval(double a, double b) const { return function(a, b); }
-};
-
-struct AddFun {
- AddFun(const JoinParams &) {}
- template <typename A, typename B>
- auto eval(A a, B b) const { return (a + b); }
-};
-
-struct MulFun {
- MulFun(const JoinParams &) {}
- template <typename A, typename B>
- auto eval(A a, B b) const { return (a * b); }
-};
-
-// needed for asymmetric operations like Sub and Div
-template <typename Fun>
-struct SwapFun {
- Fun fun;
- SwapFun(const JoinParams &params) : fun(params) {}
- template <typename A, typename B>
- auto eval(A a, B b) const { return fun.eval(b, a); }
-};
-
template <typename OCT, typename PCT, typename SCT, typename Fun>
void apply_fun_1_to_n(OCT *dst, const PCT *pri, SCT sec, size_t n, const Fun &fun) {
for (size_t i = 0; i < n; ++i) {
- dst[i] = fun.eval(pri[i], sec);
+ dst[i] = fun(pri[i], sec);
}
}
template <typename OCT, typename PCT, typename SCT, typename Fun>
void apply_fun_n_to_n(OCT *dst, const PCT *pri, const SCT *sec, size_t n, const Fun &fun) {
for (size_t i = 0; i < n; ++i) {
- dst[i] = fun.eval(pri[i], sec[i]);
+ dst[i] = fun(pri[i], sec[i]);
}
}
@@ -93,9 +81,9 @@ void my_simple_join_op(State &state, uint64_t param) {
using PCT = typename std::conditional<swap,RCT,LCT>::type;
using SCT = typename std::conditional<swap,LCT,RCT>::type;
using OCT = typename eval::UnifyCellTypes<PCT,SCT>::type;
- using OP = typename std::conditional<swap,SwapFun<Fun>,Fun>::type;
+ using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type;
const JoinParams &params = *(JoinParams*)param;
- OP my_op(params);
+ OP my_op(params.function);
auto pri_cells = DenseTensorView::typify_cells<PCT>(state.peek(swap ? 0 : 1));
auto sec_cells = DenseTensorView::typify_cells<SCT>(state.peek(swap ? 1 : 0));
auto dst_cells = make_dst_cells<OCT, pri_mut>(pri_cells, state.stash);
@@ -122,67 +110,13 @@ void my_simple_join_op(State &state, uint64_t param) {
//-----------------------------------------------------------------------------
-template <typename Fun, bool swap, Overlap overlap, bool pri_mut>
-struct MySimpleJoinOp {
- template <typename LCT, typename RCT>
- static auto get_fun() { return my_simple_join_op<LCT,RCT,Fun,swap,overlap,pri_mut>; }
-};
-
-template <bool swap, Overlap overlap, bool pri_mut>
-op_function my_select_4(ValueType::CellType lct,
- ValueType::CellType rct,
- join_fun_t fun_hint)
-{
- if (fun_hint == Add::f) {
- return select_2<MySimpleJoinOp<AddFun,swap,overlap,pri_mut>>(lct, rct);
- } else if (fun_hint == Mul::f) {
- return select_2<MySimpleJoinOp<MulFun,swap,overlap,pri_mut>>(lct, rct);
- } else {
- return select_2<MySimpleJoinOp<CallFun,swap,overlap,pri_mut>>(lct, rct);
- }
-}
-
-template <bool swap, Overlap overlap>
-op_function my_select_3(ValueType::CellType lct,
- ValueType::CellType rct,
- bool pri_mut,
- join_fun_t fun_hint)
-{
- if (pri_mut) {
- return my_select_4<swap, overlap, true>(lct, rct, fun_hint);
- } else {
- return my_select_4<swap, overlap, false>(lct, rct, fun_hint);
+struct MyGetFun {
+ template <typename R1, typename R2, typename R3, typename R4, typename R5, typename R6> static auto invoke() {
+ return my_simple_join_op<R1, R2, R3, R4::value, R5::value, R6::value>;
}
-}
-
-template <bool swap>
-op_function my_select_2(ValueType::CellType lct,
- ValueType::CellType rct,
- Overlap overlap,
- bool pri_mut,
- join_fun_t fun_hint)
-{
- switch (overlap) {
- case Overlap::INNER: return my_select_3<swap, Overlap::INNER>(lct, rct, pri_mut, fun_hint);
- case Overlap::OUTER: return my_select_3<swap, Overlap::OUTER>(lct, rct, pri_mut, fun_hint);
- case Overlap::FULL: return my_select_3<swap, Overlap::FULL>(lct, rct, pri_mut, fun_hint);
- }
- abort();
-}
+};
-op_function my_select(ValueType::CellType lct,
- ValueType::CellType rct,
- Primary primary,
- Overlap overlap,
- bool pri_mut,
- join_fun_t fun_hint)
-{
- switch (primary) {
- case Primary::LHS: return my_select_2<false>(lct, rct, overlap, pri_mut, fun_hint);
- case Primary::RHS: return my_select_2<true>(lct, rct, overlap, pri_mut, fun_hint);
- }
- abort();
-}
+using MyTypify = TypifyValue<TypifyCellType,TypifyOp2,TypifyBool,TypifyOverlap>;
//-----------------------------------------------------------------------------
@@ -280,11 +214,10 @@ Instruction
DenseSimpleJoinFunction::compile_self(const TensorEngine &, Stash &stash) const
{
const JoinParams &params = stash.create<JoinParams>(result_type(), factor(), function());
- auto op = my_select(lhs().result_type().cell_type(),
- rhs().result_type().cell_type(),
- _primary, _overlap,
- primary_is_mutable(),
- function());
+ auto op = typify_invoke<6,MyTypify,MyGetFun>(lhs().result_type().cell_type(),
+ rhs().result_type().cell_type(),
+ function(), (_primary == Primary::RHS),
+ _overlap, primary_is_mutable());
static_assert(sizeof(uint64_t) == sizeof(&params));
return Instruction(op, (uint64_t)(&params));
}