summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--eval/CMakeLists.txt1
-rw-r--r--eval/src/tests/eval/inline_operation/CMakeLists.txt9
-rw-r--r--eval/src/tests/eval/inline_operation/inline_operation_test.cpp156
-rw-r--r--eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp2
-rw-r--r--eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp2
-rw-r--r--eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp2
-rw-r--r--eval/src/vespa/eval/eval/inline_operation.h67
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp39
-rw-r--r--eval/src/vespa/eval/eval/operation.cpp95
-rw-r--r--eval/src/vespa/eval/eval/operation.h9
-rw-r--r--eval/src/vespa/eval/eval/value_type.h12
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp123
-rw-r--r--vespalib/CMakeLists.txt3
-rw-r--r--vespalib/src/tests/typify/CMakeLists.txt9
-rw-r--r--vespalib/src/tests/typify/typify_test.cpp124
-rw-r--r--vespalib/src/vespa/vespalib/util/typify.h96
16 files changed, 622 insertions, 127 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index b68440795d4..67f9fa19dc0 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -17,6 +17,7 @@ vespa_define_module(
src/tests/eval/function
src/tests/eval/function_speed
src/tests/eval/gbdt
+ src/tests/eval/inline_operation
src/tests/eval/interpreted_function
src/tests/eval/node_tools
src/tests/eval/node_types
diff --git a/eval/src/tests/eval/inline_operation/CMakeLists.txt b/eval/src/tests/eval/inline_operation/CMakeLists.txt
new file mode 100644
index 00000000000..04cdbca3abf
--- /dev/null
+++ b/eval/src/tests/eval/inline_operation/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_inline_operation_test_app TEST
+ SOURCES
+ inline_operation_test.cpp
+ DEPENDS
+ vespaeval
+ gtest
+)
+vespa_add_test(NAME eval_inline_operation_test_app COMMAND eval_inline_operation_test_app)
diff --git a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
new file mode 100644
index 00000000000..4520176e276
--- /dev/null
+++ b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
@@ -0,0 +1,156 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/eval/inline_operation.h>
+#include <vespa/eval/eval/function.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib::eval;
+using namespace vespalib::eval::operation;
+
+template <typename T> struct IsInlined { constexpr static bool value = true; };
+template <> struct IsInlined<CallOp1> { constexpr static bool value = false; };
+template <> struct IsInlined<CallOp2> { constexpr static bool value = false; };
+
+template <typename T> double test_op1(op1_t ref, double a, bool inlined) {
+ T op(ref);
+ EXPECT_EQ(IsInlined<T>::value, inlined);
+ EXPECT_EQ(op(a), ref(a));
+ return op(a);
+};
+
+template <typename T> double test_op2(op2_t ref, double a, double b, bool inlined) {
+ T op(ref);
+ EXPECT_EQ(IsInlined<T>::value, inlined);
+ EXPECT_EQ(op(a,b), ref(a,b));
+ return op(a,b);
+};
+
+op1_t as_op1(const vespalib::string &str) {
+ auto fun = Function::parse({"a"}, str);
+ auto res = lookup_op1(*fun);
+ EXPECT_TRUE(res.has_value());
+ return res.value();
+}
+
+op2_t as_op2(const vespalib::string &str) {
+ auto fun = Function::parse({"a", "b"}, str);
+ auto res = lookup_op2(*fun);
+ EXPECT_TRUE(res.has_value());
+ return res.value();
+}
+
+TEST(InlineOperationTest, op1_lambdas_are_recognized) {
+ EXPECT_EQ(as_op1("-a"), Neg::f);
+ EXPECT_EQ(as_op1("!a"), Not::f);
+ EXPECT_EQ(as_op1("cos(a)"), Cos::f);
+ EXPECT_EQ(as_op1("sin(a)"), Sin::f);
+ EXPECT_EQ(as_op1("tan(a)"), Tan::f);
+ EXPECT_EQ(as_op1("cosh(a)"), Cosh::f);
+ EXPECT_EQ(as_op1("sinh(a)"), Sinh::f);
+ EXPECT_EQ(as_op1("tanh(a)"), Tanh::f);
+ EXPECT_EQ(as_op1("acos(a)"), Acos::f);
+ EXPECT_EQ(as_op1("asin(a)"), Asin::f);
+ EXPECT_EQ(as_op1("atan(a)"), Atan::f);
+ EXPECT_EQ(as_op1("exp(a)"), Exp::f);
+ EXPECT_EQ(as_op1("log10(a)"), Log10::f);
+ EXPECT_EQ(as_op1("log(a)"), Log::f);
+ EXPECT_EQ(as_op1("sqrt(a)"), Sqrt::f);
+ EXPECT_EQ(as_op1("ceil(a)"), Ceil::f);
+ EXPECT_EQ(as_op1("fabs(a)"), Fabs::f);
+ EXPECT_EQ(as_op1("floor(a)"), Floor::f);
+ EXPECT_EQ(as_op1("isNan(a)"), IsNan::f);
+ EXPECT_EQ(as_op1("relu(a)"), Relu::f);
+ EXPECT_EQ(as_op1("sigmoid(a)"), Sigmoid::f);
+ EXPECT_EQ(as_op1("elu(a)"), Elu::f);
+}
+
+TEST(InlineOperationTest, op1_lambdas_are_recognized_with_different_parameter_names) {
+ EXPECT_EQ(lookup_op1(*Function::parse({"x"}, "-x")).value(), Neg::f);
+ EXPECT_EQ(lookup_op1(*Function::parse({"x"}, "!x")).value(), Not::f);
+}
+
+TEST(InlineOperationTest, non_op1_lambdas_are_not_recognized) {
+ EXPECT_FALSE(lookup_op1(*Function::parse({"a"}, "a*a")).has_value());
+ EXPECT_FALSE(lookup_op1(*Function::parse({"a", "b"}, "a+b")).has_value());
+}
+
+TEST(InlineOperationTest, op2_lambdas_are_recognized) {
+ EXPECT_EQ(as_op2("a+b"), Add::f);
+ EXPECT_EQ(as_op2("a-b"), Sub::f);
+ EXPECT_EQ(as_op2("a*b"), Mul::f);
+ EXPECT_EQ(as_op2("a/b"), Div::f);
+ EXPECT_EQ(as_op2("a%b"), Mod::f);
+ EXPECT_EQ(as_op2("a^b"), Pow::f);
+ EXPECT_EQ(as_op2("a==b"), Equal::f);
+ EXPECT_EQ(as_op2("a!=b"), NotEqual::f);
+ EXPECT_EQ(as_op2("a~=b"), Approx::f);
+ EXPECT_EQ(as_op2("a<b"), Less::f);
+ EXPECT_EQ(as_op2("a<=b"), LessEqual::f);
+ EXPECT_EQ(as_op2("a>b"), Greater::f);
+ EXPECT_EQ(as_op2("a>=b"), GreaterEqual::f);
+ EXPECT_EQ(as_op2("a&&b"), And::f);
+ EXPECT_EQ(as_op2("a||b"), Or::f);
+ EXPECT_EQ(as_op2("atan2(a,b)"), Atan2::f);
+ EXPECT_EQ(as_op2("ldexp(a,b)"), Ldexp::f);
+ EXPECT_EQ(as_op2("pow(a,b)"), Pow::f);
+ EXPECT_EQ(as_op2("fmod(a,b)"), Mod::f);
+ EXPECT_EQ(as_op2("min(a,b)"), Min::f);
+ EXPECT_EQ(as_op2("max(a,b)"), Max::f);
+}
+
+TEST(InlineOperationTest, op2_lambdas_are_recognized_with_different_parameter_names) {
+ EXPECT_EQ(lookup_op2(*Function::parse({"x", "y"}, "x+y")).value(), Add::f);
+ EXPECT_EQ(lookup_op2(*Function::parse({"x", "y"}, "x-y")).value(), Sub::f);
+}
+
+TEST(InlineOperationTest, non_op2_lambdas_are_not_recognized) {
+ EXPECT_FALSE(lookup_op2(*Function::parse({"a"}, "-a")).has_value());
+ EXPECT_FALSE(lookup_op2(*Function::parse({"a", "b"}, "b+a")).has_value());
+}
+
+TEST(InlineOperationTest, generic_op1_wrapper_works) {
+ CallOp1 op(Neg::f);
+ EXPECT_EQ(op(3), -3);
+ EXPECT_EQ(op(-5), 5);
+}
+
+TEST(InlineOperationTest, generic_op2_wrapper_works) {
+ CallOp2 op(Add::f);
+ EXPECT_EQ(op(2,3), 5);
+ EXPECT_EQ(op(3,7), 10);
+}
+
+TEST(InlineOperationTest, inline_op2_example_works) {
+ op2_t ignored = nullptr;
+ InlineOp2<Add> op(ignored);
+ EXPECT_EQ(op(2,3), 5);
+ EXPECT_EQ(op(3,7), 10);
+}
+
+TEST(InlineOperationTest, parameter_swap_wrapper_works) {
+ CallOp2 op(Sub::f);
+ SwapArgs2<CallOp2> swap_op(Sub::f);
+ EXPECT_EQ(op(2,3), -1);
+ EXPECT_EQ(swap_op(2,3), 1);
+ EXPECT_EQ(op(3,7), -4);
+ EXPECT_EQ(swap_op(3,7), 4);
+}
+
+TEST(InlineOperationTest, resolved_op1_works) {
+ auto a = TypifyOp1::resolve(Neg::f, [](auto t){ return test_op1<typename decltype(t)::type>(Neg::f, 2.0, false); });
+ // putting the lambda inside the EXPECT does not work
+ EXPECT_EQ(a, -2.0);
+}
+
+TEST(InlineOperationTest, resolved_op2_works) {
+ auto a = TypifyOp2::resolve(Add::f, [](auto t){ return test_op2<typename decltype(t)::type>(Add::f, 2.0, 5.0, true); });
+ auto b = TypifyOp2::resolve(Mul::f, [](auto t){ return test_op2<typename decltype(t)::type>(Mul::f, 5.0, 3.0, true); });
+ auto c = TypifyOp2::resolve(Sub::f, [](auto t){ return test_op2<typename decltype(t)::type>(Sub::f, 8.0, 5.0, false); });
+ // putting the lambda inside the EXPECT does not work
+ EXPECT_EQ(a, 7.0);
+ EXPECT_EQ(b, 15.0);
+ EXPECT_EQ(c, 3.0);
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp b/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp
index a571837b8e9..92fdbfade46 100644
--- a/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp
+++ b/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp
@@ -67,7 +67,6 @@ TEST("require that matmul can be optimized") {
TEST("require that matmul with lambda can be optimized") {
TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, true, true));
- TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*x)),sum,d)", 2, 3, 5, true, true));
}
TEST("require that expressions similar to matmul are not optimized") {
@@ -75,6 +74,7 @@ TEST("require that expressions similar to matmul are not optimized") {
TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum,b)"));
TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,prod,d)"));
TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*x)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x+y)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*x)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*y)),sum,d)"));
diff --git a/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp b/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
index c0823248538..f9c563c9bf8 100644
--- a/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
+++ b/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
@@ -78,7 +78,6 @@ TEST("require that single multi matmul can be optimized") {
TEST("require that multi matmul with lambda can be optimized") {
TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, 6, true, true));
- TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*x)),sum,d)", 2, 3, 5, 6, true, true));
}
TEST("require that expressions similar to multi matmul are not optimized") {
@@ -86,6 +85,7 @@ TEST("require that expressions similar to multi matmul are not optimized") {
TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,b)"));
TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,prod,d)"));
TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*x)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x+y)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*x)),sum,d)"));
TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*y)),sum,d)"));
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index 0b924451907..3ecc3f66cda 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -130,13 +130,13 @@ TEST("require that xw product gives same results as reference join/reduce") {
TEST("require that various variants of xw product can be optimized") {
TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(x*y)),sum,y)", 3, 2, true));
- TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)", 3, 2, true));
}
TEST("require that expressions similar to xw product are not optimized") {
TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum,x)"));
TEST_DO(verify_not_optimized("reduce(y3*x2y3,prod,y)"));
TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum)"));
+ TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)"));
TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x+y)),sum,y)"));
TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x*x)),sum,y)"));
TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*y)),sum,y)"));
diff --git a/eval/src/vespa/eval/eval/inline_operation.h b/eval/src/vespa/eval/eval/inline_operation.h
new file mode 100644
index 00000000000..493de9ea56c
--- /dev/null
+++ b/eval/src/vespa/eval/eval/inline_operation.h
@@ -0,0 +1,67 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "operation.h"
+#include <vespa/vespalib/util/typify.h>
+
+namespace vespalib::eval::operation {
+
+//-----------------------------------------------------------------------------
+
+struct CallOp1 {
+ op1_t my_op1;
+ CallOp1(op1_t op1) : my_op1(op1) {}
+ double operator()(double a) const { return my_op1(a); }
+};
+
+struct TypifyOp1 {
+ template <typename T> using Result = TypifyResultType<T>;
+ template <typename F> static decltype(auto) resolve(op1_t value, F &&f) {
+ (void) value;
+ return f(Result<CallOp1>());
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+struct CallOp2 {
+ op2_t my_op2;
+ CallOp2(op2_t op2) : my_op2(op2) {}
+ op2_t get() const { return my_op2; }
+ double operator()(double a, double b) const { return my_op2(a, b); }
+};
+
+template <typename Op2>
+struct SwapArgs2 {
+ Op2 op2;
+ SwapArgs2(op2_t op2_in) : op2(op2_in) {}
+ template <typename A, typename B> constexpr auto operator()(A a, B b) const { return op2(b, a); }
+};
+
+template <typename T> struct InlineOp2;
+template <> struct InlineOp2<Add> {
+ InlineOp2(op2_t) {}
+ template <typename A, typename B> constexpr auto operator()(A a, B b) const { return (a+b); }
+};
+template <> struct InlineOp2<Mul> {
+ InlineOp2(op2_t) {}
+ template <typename A, typename B> constexpr auto operator()(A a, B b) const { return (a*b); }
+};
+
+struct TypifyOp2 {
+ template <typename T> using Result = TypifyResultType<T>;
+ template <typename F> static decltype(auto) resolve(op2_t value, F &&f) {
+ if (value == Add::f) {
+ return f(Result<InlineOp2<Add>>());
+ } else if (value == Mul::f) {
+ return f(Result<InlineOp2<Mul>>());
+ } else {
+ return f(Result<CallOp2>());
+ }
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+}
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index 3a73a3b8784..02d16caae6b 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -15,25 +15,6 @@ namespace vespalib::eval {
namespace {
using namespace nodes;
-using map_fun_t = double (*)(double);
-using join_fun_t = double (*)(double, double);
-
-//-----------------------------------------------------------------------------
-
-// TODO(havardpe): generic function pointer resolving for all single
-// operation lambdas.
-
-template <typename OP2>
-bool is_op2(const Function &lambda) {
- if (lambda.num_params() == 2) {
- if (auto op2 = as<OP2>(lambda.root())) {
- auto sym1 = as<Symbol>(op2->lhs());
- auto sym2 = as<Symbol>(op2->rhs());
- return (sym1 && sym2 && (sym1->id() != sym2->id()));
- }
- }
- return false;
-}
//-----------------------------------------------------------------------------
@@ -63,13 +44,13 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
stack.back() = tensor_function::reduce(a, aggr, dimensions, stash);
}
- void make_map(const Node &, map_fun_t function) {
+ void make_map(const Node &, operation::op1_t function) {
assert(stack.size() >= 1);
const auto &a = stack.back().get();
stack.back() = tensor_function::map(a, function, stash);
}
- void make_join(const Node &, join_fun_t function) {
+ void make_join(const Node &, operation::op2_t function) {
assert(stack.size() >= 2);
const auto &b = stack.back().get();
stack.pop_back();
@@ -77,7 +58,7 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
stack.back() = tensor_function::join(a, b, function, stash);
}
- void make_merge(const Node &, join_fun_t function) {
+ void make_merge(const Node &, operation::op2_t function) {
assert(stack.size() >= 2);
const auto &b = stack.back().get();
stack.pop_back();
@@ -203,14 +184,16 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
abort();
}
void visit(const TensorMap &node) override {
- const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE));
- make_map(node, token.get()->get().get_function<1>());
+ if (auto op1 = operation::lookup_op1(node.lambda())) {
+ make_map(node, op1.value());
+ } else {
+ const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE));
+ make_map(node, token.get()->get().get_function<1>());
+ }
}
void visit(const TensorJoin &node) override {
- if (is_op2<Mul>(node.lambda())) {
- make_join(node, operation::Mul::f);
- } else if (is_op2<Add>(node.lambda())) {
- make_join(node, operation::Add::f);
+ if (auto op2 = operation::lookup_op2(node.lambda())) {
+ make_join(node, op2.value());
} else {
const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE));
make_join(node, token.get()->get().get_function<2>());
diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp
index fa0a99de461..581f65c0e31 100644
--- a/eval/src/vespa/eval/eval/operation.cpp
+++ b/eval/src/vespa/eval/eval/operation.cpp
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "operation.h"
+#include "function.h"
+#include "key_gen.h"
#include <vespa/vespalib/util/approx.h>
#include <algorithm>
@@ -48,4 +50,97 @@ double Relu::f(double a) { return std::max(a, 0.0); }
double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; }
+namespace {
+
+template <typename T>
+void add_op(std::map<vespalib::string,T> &map, const Function &fun, T op) {
+ assert(!fun.has_error());
+ auto key = gen_key(fun, PassParams::SEPARATE);
+ auto res = map.emplace(key, op);
+ assert(res.second);
+}
+
+template <typename T>
+std::optional<T> lookup_op(const std::map<vespalib::string,T> &map, const Function &fun) {
+ auto key = gen_key(fun, PassParams::SEPARATE);
+ auto pos = map.find(key);
+ if (pos != map.end()) {
+ return pos->second;
+ }
+ return std::nullopt;
+}
+
+void add_op1(std::map<vespalib::string,op1_t> &map, const vespalib::string &expr, op1_t op) {
+ add_op(map, *Function::parse({"a"}, expr), op);
+}
+
+void add_op2(std::map<vespalib::string,op2_t> &map, const vespalib::string &expr, op2_t op) {
+ add_op(map, *Function::parse({"a", "b"}, expr), op);
+}
+
+std::map<vespalib::string,op1_t> make_op1_map() {
+ std::map<vespalib::string,op1_t> map;
+ add_op1(map, "-a", Neg::f);
+ add_op1(map, "!a", Not::f);
+ add_op1(map, "cos(a)", Cos::f);
+ add_op1(map, "sin(a)", Sin::f);
+ add_op1(map, "tan(a)", Tan::f);
+ add_op1(map, "cosh(a)", Cosh::f);
+ add_op1(map, "sinh(a)", Sinh::f);
+ add_op1(map, "tanh(a)", Tanh::f);
+ add_op1(map, "acos(a)", Acos::f);
+ add_op1(map, "asin(a)", Asin::f);
+ add_op1(map, "atan(a)", Atan::f);
+ add_op1(map, "exp(a)", Exp::f);
+ add_op1(map, "log10(a)", Log10::f);
+ add_op1(map, "log(a)", Log::f);
+ add_op1(map, "sqrt(a)", Sqrt::f);
+ add_op1(map, "ceil(a)", Ceil::f);
+ add_op1(map, "fabs(a)", Fabs::f);
+ add_op1(map, "floor(a)", Floor::f);
+ add_op1(map, "isNan(a)", IsNan::f);
+ add_op1(map, "relu(a)", Relu::f);
+ add_op1(map, "sigmoid(a)", Sigmoid::f);
+ add_op1(map, "elu(a)", Elu::f);
+ return map;
+}
+
+std::map<vespalib::string,op2_t> make_op2_map() {
+ std::map<vespalib::string,op2_t> map;
+ add_op2(map, "a+b", Add::f);
+ add_op2(map, "a-b", Sub::f);
+ add_op2(map, "a*b", Mul::f);
+ add_op2(map, "a/b", Div::f);
+ add_op2(map, "a%b", Mod::f);
+ add_op2(map, "a^b", Pow::f);
+ add_op2(map, "a==b", Equal::f);
+ add_op2(map, "a!=b", NotEqual::f);
+ add_op2(map, "a~=b", Approx::f);
+ add_op2(map, "a<b", Less::f);
+ add_op2(map, "a<=b", LessEqual::f);
+ add_op2(map, "a>b", Greater::f);
+ add_op2(map, "a>=b", GreaterEqual::f);
+ add_op2(map, "a&&b", And::f);
+ add_op2(map, "a||b", Or::f);
+ add_op2(map, "atan2(a,b)", Atan2::f);
+ add_op2(map, "ldexp(a,b)", Ldexp::f);
+ add_op2(map, "pow(a,b)", Pow::f);
+ add_op2(map, "fmod(a,b)", Mod::f);
+ add_op2(map, "min(a,b)", Min::f);
+ add_op2(map, "max(a,b)", Max::f);
+ return map;
+}
+
+} // namespace <unnamed>
+
+std::optional<op1_t> lookup_op1(const Function &fun) {
+ static const std::map<vespalib::string,op1_t> map = make_op1_map();
+ return lookup_op(map, fun);
+}
+
+std::optional<op2_t> lookup_op2(const Function &fun) {
+ static const std::map<vespalib::string,op2_t> map = make_op2_map();
+ return lookup_op(map, fun);
+}
+
}
diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h
index fa99f51a308..a80193e704d 100644
--- a/eval/src/vespa/eval/eval/operation.h
+++ b/eval/src/vespa/eval/eval/operation.h
@@ -1,6 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
+#include <optional>
+
+namespace vespalib::eval { class Function; }
namespace vespalib::eval::operation {
@@ -46,4 +49,10 @@ struct Relu { static double f(double a); };
struct Sigmoid { static double f(double a); };
struct Elu { static double f(double a); };
+using op1_t = double (*)(double);
+using op2_t = double (*)(double, double);
+
+std::optional<op1_t> lookup_op1(const Function &fun);
+std::optional<op2_t> lookup_op2(const Function &fun);
+
}
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index 3e91240048b..a8ae9c44bb0 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -2,6 +2,7 @@
#pragma once
+#include <vespa/vespalib/util/typify.h>
#include <vespa/vespalib/stllike/string.h>
#include <vector>
@@ -104,4 +105,15 @@ template <typename CT> inline ValueType::CellType get_cell_type();
template <> inline ValueType::CellType get_cell_type<double>() { return ValueType::CellType::DOUBLE; }
template <> inline ValueType::CellType get_cell_type<float>() { return ValueType::CellType::FLOAT; }
+struct TypifyCellType {
+ template <typename T> using Result = TypifyResultType<T>;
+ template <typename F> static decltype(auto) resolve(ValueType::CellType value, F &&f) {
+ switch(value) {
+ case ValueType::CellType::DOUBLE: return f(Result<double>());
+ case ValueType::CellType::FLOAT: return f(Result<float>());
+ }
+ abort();
+ }
+};
+
} // namespace
diff --git a/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp
index 6b0d65c0743..c358c9d618d 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp
@@ -5,6 +5,8 @@
#include <vespa/vespalib/objects/objectvisitor.h>
#include <vespa/eval/eval/value.h>
#include <vespa/eval/eval/operation.h>
+#include <vespa/eval/eval/inline_operation.h>
+#include <vespa/vespalib/util/typify.h>
#include <optional>
#include <algorithm>
@@ -16,6 +18,7 @@ using eval::Value;
using eval::ValueType;
using eval::TensorFunction;
using eval::TensorEngine;
+using eval::TypifyCellType;
using eval::as;
using namespace eval::operation;
@@ -30,6 +33,18 @@ using State = eval::InterpretedFunction::State;
namespace {
+struct TypifyOverlap {
+ template <Overlap VALUE> using Result = TypifyResultValue<Overlap, VALUE>;
+ template <typename F> static decltype(auto) resolve(Overlap value, F &&f) {
+ switch (value) {
+ case Overlap::INNER: return f(Result<Overlap::INNER>());
+ case Overlap::OUTER: return f(Result<Overlap::OUTER>());
+ case Overlap::FULL: return f(Result<Overlap::FULL>());
+ }
+ abort();
+ }
+};
+
struct JoinParams {
const ValueType &result_type;
size_t factor;
@@ -38,44 +53,17 @@ struct JoinParams {
: result_type(result_type_in), factor(factor_in), function(function_in) {}
};
-struct CallFun {
- join_fun_t function;
- CallFun(const JoinParams &params) : function(params.function) {}
- double eval(double a, double b) const { return function(a, b); }
-};
-
-struct AddFun {
- AddFun(const JoinParams &) {}
- template <typename A, typename B>
- auto eval(A a, B b) const { return (a + b); }
-};
-
-struct MulFun {
- MulFun(const JoinParams &) {}
- template <typename A, typename B>
- auto eval(A a, B b) const { return (a * b); }
-};
-
-// needed for asymmetric operations like Sub and Div
-template <typename Fun>
-struct SwapFun {
- Fun fun;
- SwapFun(const JoinParams &params) : fun(params) {}
- template <typename A, typename B>
- auto eval(A a, B b) const { return fun.eval(b, a); }
-};
-
template <typename OCT, typename PCT, typename SCT, typename Fun>
void apply_fun_1_to_n(OCT *dst, const PCT *pri, SCT sec, size_t n, const Fun &fun) {
for (size_t i = 0; i < n; ++i) {
- dst[i] = fun.eval(pri[i], sec);
+ dst[i] = fun(pri[i], sec);
}
}
template <typename OCT, typename PCT, typename SCT, typename Fun>
void apply_fun_n_to_n(OCT *dst, const PCT *pri, const SCT *sec, size_t n, const Fun &fun) {
for (size_t i = 0; i < n; ++i) {
- dst[i] = fun.eval(pri[i], sec[i]);
+ dst[i] = fun(pri[i], sec[i]);
}
}
@@ -93,9 +81,9 @@ void my_simple_join_op(State &state, uint64_t param) {
using PCT = typename std::conditional<swap,RCT,LCT>::type;
using SCT = typename std::conditional<swap,LCT,RCT>::type;
using OCT = typename eval::UnifyCellTypes<PCT,SCT>::type;
- using OP = typename std::conditional<swap,SwapFun<Fun>,Fun>::type;
+ using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type;
const JoinParams &params = *(JoinParams*)param;
- OP my_op(params);
+ OP my_op(params.function);
auto pri_cells = DenseTensorView::typify_cells<PCT>(state.peek(swap ? 0 : 1));
auto sec_cells = DenseTensorView::typify_cells<SCT>(state.peek(swap ? 1 : 0));
auto dst_cells = make_dst_cells<OCT, pri_mut>(pri_cells, state.stash);
@@ -122,67 +110,13 @@ void my_simple_join_op(State &state, uint64_t param) {
//-----------------------------------------------------------------------------
-template <typename Fun, bool swap, Overlap overlap, bool pri_mut>
-struct MySimpleJoinOp {
- template <typename LCT, typename RCT>
- static auto get_fun() { return my_simple_join_op<LCT,RCT,Fun,swap,overlap,pri_mut>; }
-};
-
-template <bool swap, Overlap overlap, bool pri_mut>
-op_function my_select_4(ValueType::CellType lct,
- ValueType::CellType rct,
- join_fun_t fun_hint)
-{
- if (fun_hint == Add::f) {
- return select_2<MySimpleJoinOp<AddFun,swap,overlap,pri_mut>>(lct, rct);
- } else if (fun_hint == Mul::f) {
- return select_2<MySimpleJoinOp<MulFun,swap,overlap,pri_mut>>(lct, rct);
- } else {
- return select_2<MySimpleJoinOp<CallFun,swap,overlap,pri_mut>>(lct, rct);
- }
-}
-
-template <bool swap, Overlap overlap>
-op_function my_select_3(ValueType::CellType lct,
- ValueType::CellType rct,
- bool pri_mut,
- join_fun_t fun_hint)
-{
- if (pri_mut) {
- return my_select_4<swap, overlap, true>(lct, rct, fun_hint);
- } else {
- return my_select_4<swap, overlap, false>(lct, rct, fun_hint);
+struct MyGetFun {
+ template <typename R1, typename R2, typename R3, typename R4, typename R5, typename R6> static auto invoke() {
+ return my_simple_join_op<R1, R2, R3, R4::value, R5::value, R6::value>;
}
-}
-
-template <bool swap>
-op_function my_select_2(ValueType::CellType lct,
- ValueType::CellType rct,
- Overlap overlap,
- bool pri_mut,
- join_fun_t fun_hint)
-{
- switch (overlap) {
- case Overlap::INNER: return my_select_3<swap, Overlap::INNER>(lct, rct, pri_mut, fun_hint);
- case Overlap::OUTER: return my_select_3<swap, Overlap::OUTER>(lct, rct, pri_mut, fun_hint);
- case Overlap::FULL: return my_select_3<swap, Overlap::FULL>(lct, rct, pri_mut, fun_hint);
- }
- abort();
-}
+};
-op_function my_select(ValueType::CellType lct,
- ValueType::CellType rct,
- Primary primary,
- Overlap overlap,
- bool pri_mut,
- join_fun_t fun_hint)
-{
- switch (primary) {
- case Primary::LHS: return my_select_2<false>(lct, rct, overlap, pri_mut, fun_hint);
- case Primary::RHS: return my_select_2<true>(lct, rct, overlap, pri_mut, fun_hint);
- }
- abort();
-}
+using MyTypify = TypifyValue<TypifyCellType,TypifyOp2,TypifyBool,TypifyOverlap>;
//-----------------------------------------------------------------------------
@@ -280,11 +214,10 @@ Instruction
DenseSimpleJoinFunction::compile_self(const TensorEngine &, Stash &stash) const
{
const JoinParams &params = stash.create<JoinParams>(result_type(), factor(), function());
- auto op = my_select(lhs().result_type().cell_type(),
- rhs().result_type().cell_type(),
- _primary, _overlap,
- primary_is_mutable(),
- function());
+ auto op = typify_invoke<6,MyTypify,MyGetFun>(lhs().result_type().cell_type(),
+ rhs().result_type().cell_type(),
+ function(), (_primary == Primary::RHS),
+ _overlap, primary_is_mutable());
static_assert(sizeof(uint64_t) == sizeof(&params));
return Instruction(op, (uint64_t)(&params));
}
diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt
index 2675bc16bf2..1ca9816a921 100644
--- a/vespalib/CMakeLists.txt
+++ b/vespalib/CMakeLists.txt
@@ -24,8 +24,8 @@ vespa_define_module(
src/tests/assert
src/tests/barrier
src/tests/benchmark_timer
- src/tests/btree
src/tests/box
+ src/tests/btree
src/tests/closure
src/tests/component
src/tests/compress
@@ -128,6 +128,7 @@ vespa_define_module(
src/tests/tutorial/minimal
src/tests/tutorial/simple
src/tests/tutorial/threads
+ src/tests/typify
src/tests/util/generationhandler
src/tests/util/generationhandler_stress
src/tests/util/md5
diff --git a/vespalib/src/tests/typify/CMakeLists.txt b/vespalib/src/tests/typify/CMakeLists.txt
new file mode 100644
index 00000000000..29e95af1988
--- /dev/null
+++ b/vespalib/src/tests/typify/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(vespalib_typify_test_app TEST
+ SOURCES
+ typify_test.cpp
+ DEPENDS
+ vespalib
+ gtest
+)
+vespa_add_test(NAME vespalib_typify_test_app COMMAND vespalib_typify_test_app)
diff --git a/vespalib/src/tests/typify/typify_test.cpp b/vespalib/src/tests/typify/typify_test.cpp
new file mode 100644
index 00000000000..4c3f1c512ca
--- /dev/null
+++ b/vespalib/src/tests/typify/typify_test.cpp
@@ -0,0 +1,124 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/util/typify.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib;
+
+struct A { static constexpr int value_from_type = 1; };
+struct B { static constexpr int value_from_type = 2; };
+
+struct MyIntA { int value; };
+struct MyIntB { int value; };
+struct MyIntC { int value; }; // no typifier for this type
+
+// MyIntA -> A or B
+struct TypifyMyIntA {
+ template <typename T> using Result = TypifyResultType<T>;
+ template <typename F> static decltype(auto) resolve(MyIntA value, F &&f) {
+ if (value.value == 1) {
+ return f(Result<A>());
+ } else if (value.value == 2) {
+ return f(Result<B>());
+ }
+ abort();
+ }
+};
+
+// MyIntB -> TypifyResultValue<int,1> or TypifyResultValue<int,2>
+struct TypifyMyIntB {
+ template <int VALUE> using Result = TypifyResultValue<int,VALUE>;
+ template <typename F> static decltype(auto) resolve(MyIntB value, F &&f) {
+ if (value.value == 1) {
+ return f(Result<1>());
+ } else if (value.value == 2) {
+ return f(Result<2>());
+ }
+ abort();
+ }
+};
+
+using TX = TypifyValue<TypifyBool, TypifyMyIntA, TypifyMyIntB>;
+
+//-----------------------------------------------------------------------------
+
+struct GetFromType {
+ template <typename T> static int invoke() { return T::value_from_type; }
+};
+
+TEST(TypifyTest, simple_type_typification_works) {
+ auto res1 = typify_invoke<1,TX,GetFromType>(MyIntA{1});
+ auto res2 = typify_invoke<1,TX,GetFromType>(MyIntA{2});
+ EXPECT_EQ(res1, 1);
+ EXPECT_EQ(res2, 2);
+}
+
+struct GetFromValue {
+ template <typename R> static int invoke() { return R::value; }
+};
+
+TEST(TypifyTest, simple_value_typification_works) {
+ auto res1 = typify_invoke<1,TX,GetFromValue>(MyIntB{1});
+ auto res2 = typify_invoke<1,TX,GetFromValue>(MyIntB{2});
+ EXPECT_EQ(res1, 1);
+ EXPECT_EQ(res2, 2);
+}
+
+struct MaybeSum {
+ template <typename F1, typename V1, typename F2, typename V2> static int invoke(MyIntC v3) {
+ int res = 0;
+ if (F1::value) {
+ res += V1::value_from_type;
+ }
+ if (F2::value) {
+ res += V2::value;
+ }
+ res += v3.value;
+ return res;
+ }
+};
+
+TEST(TypifyTest, complex_typification_works) {
+ auto res1 = typify_invoke<4,TX,MaybeSum>(false, MyIntA{2}, false, MyIntB{1}, MyIntC{4});
+ auto res2 = typify_invoke<4,TX,MaybeSum>(false, MyIntA{2}, true, MyIntB{1}, MyIntC{4});
+ auto res3 = typify_invoke<4,TX,MaybeSum>(true, MyIntA{2}, false, MyIntB{1}, MyIntC{4});
+ auto res4 = typify_invoke<4,TX,MaybeSum>(true, MyIntA{2}, true, MyIntB{1}, MyIntC{4});
+ EXPECT_EQ(res1, 4);
+ EXPECT_EQ(res2, 5);
+ EXPECT_EQ(res3, 6);
+ EXPECT_EQ(res4, 7);
+}
+
+struct Singleton {
+ virtual int get() const = 0;
+ virtual ~Singleton() {}
+};
+
+template <int A, int B>
+struct MySingleton : Singleton {
+ MySingleton() = default;
+ MySingleton(const MySingleton &) = delete;
+ MySingleton &operator=(const MySingleton &) = delete;
+ int get() const override { return A + B; }
+};
+
+struct GetSingleton {
+ template <typename A, typename B>
+ static const Singleton &invoke() {
+ static MySingleton<A::value, B::value> obj;
+ return obj;
+ }
+};
+
+TEST(TypifyTest, typify_invoke_can_return_object_reference) {
+ const Singleton &s1 = typify_invoke<2,TX,GetSingleton>(MyIntB{1}, MyIntB{1});
+ const Singleton &s2 = typify_invoke<2,TX,GetSingleton>(MyIntB{2}, MyIntB{2});
+ const Singleton &s3 = typify_invoke<2,TX,GetSingleton>(MyIntB{2}, MyIntB{2});
+ EXPECT_EQ(s1.get(), 2);
+ EXPECT_EQ(s2.get(), 4);
+ EXPECT_EQ(s3.get(), 4);
+ EXPECT_NE(&s1, &s2);
+ EXPECT_EQ(&s2, &s3);
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/vespalib/src/vespa/vespalib/util/typify.h b/vespalib/src/vespa/vespalib/util/typify.h
new file mode 100644
index 00000000000..0ee624a95b6
--- /dev/null
+++ b/vespalib/src/vespa/vespalib/util/typify.h
@@ -0,0 +1,96 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <stddef.h>
+#include <utility>
+
+namespace vespalib {
+
+//-----------------------------------------------------------------------------
+
+/**
+ * Typification result for values resolving into actual types.
+ **/
+template <typename T> struct TypifyResultType {
+ static constexpr bool is_type = true;
+ using type = T;
+};
+
+/**
+ * Typification result for values resolving into compile-time values
+ * which are also types as long as they are kept inside their result
+ * wrappers.
+ **/
+template <typename T, T VALUE> struct TypifyResultValue {
+ static constexpr bool is_type = false;
+ static constexpr T value = VALUE;
+};
+
+/**
+ * A Typifier is able to take a run-time value and resolve it into a
+ * type. The resolve result is passed to the specified function in the
+ * form of a thin result wrapper.
+ **/
+struct TypifyBool {
+ template <bool VALUE> using Result = TypifyResultValue<bool, VALUE>;
+ template <typename F> static decltype(auto) resolve(bool value, F &&f) {
+ if (value) {
+ return f(Result<true>());
+ } else {
+ return f(Result<false>());
+ }
+ }
+};
+
+//-----------------------------------------------------------------------------
+
+/**
+ * Template used to combine individual typifiers into a typifier able
+ * to resolve multiple types.
+ **/
+template <typename ...Ts> struct TypifyValue : Ts... { using Ts::resolve...; };
+
+//-----------------------------------------------------------------------------
+
+template <size_t N, typename Typify, typename Target, typename ...Rs> struct TypifyInvoke {
+ static decltype(auto) select() {
+ static_assert(sizeof...(Rs) == N);
+ return Target::template invoke<Rs...>();
+ }
+ template <typename T, typename ...Args> static decltype(auto) select(T &&value, Args &&...args) {
+ if constexpr (N == sizeof...(Rs)) {
+ return Target::template invoke<Rs...>(std::forward<T>(value), std::forward<Args>(args)...);
+ } else {
+ return Typify::resolve(value, [&](auto t)->decltype(auto)
+ {
+ using X = decltype(t);
+ if constexpr (X::is_type) {
+ return TypifyInvoke<N, Typify, Target, Rs..., typename X::type>::select(std::forward<Args>(args)...);
+ } else {
+ return TypifyInvoke<N, Typify, Target, Rs..., X>::select(std::forward<Args>(args)...);
+ }
+ });
+ }
+ }
+};
+
+/**
+ * Typify the N first parameters using Typify (typically an
+ * instantiation of the TypifyValue template) and forward the
+ * remaining parameters to the Target::invoke template function with
+ * the typification results from the N first parameters as template
+ * parameters. Note that typification results that are types are
+ * unwrapped before being used as template parameters while
+ * typification results that are compile-time values are kept in their
+ * wrappers when passed as template parameters. Please refer to the
+ * unit test for examples.
+ **/
+template <size_t N, typename Typify, typename Target, typename ...Args> decltype(auto) typify_invoke(Args && ...args) {
+ static_assert(N > 0);
+ return TypifyInvoke<N,Typify,Target>::select(std::forward<Args>(args)...);
+}
+
+//-----------------------------------------------------------------------------
+
+}