summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@yahooinc.com>2023-11-02 12:59:45 +0000
committerHåvard Pettersen <havardpe@yahooinc.com>2023-11-02 12:59:45 +0000
commiteab972dbf7ff39285d6a23da42dc3cafa0ca2721 (patch)
treeb947bd76e242825b07d7148464b4c6b9267c7f26 /eval
parentc70562eb766a5205fba797f62456652919e7cd3d (diff)
handle and test recursive issues with interpreted functions
also disallow map_subspaces in compiled functions
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/compiled_function/compiled_function_test.cpp1
-rw-r--r--eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp30
-rw-r--r--eval/src/vespa/eval/eval/function.cpp8
-rw-r--r--eval/src/vespa/eval/eval/function.h6
-rw-r--r--eval/src/vespa/eval/eval/interpreted_function.cpp35
-rw-r--r--eval/src/vespa/eval/eval/llvm/compiled_function.cpp3
-rw-r--r--eval/src/vespa/eval/eval/test/eval_spec.cpp2
7 files changed, 72 insertions, 13 deletions
diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
index 071e4766629..7b1f9a84b6d 100644
--- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
+++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
@@ -50,6 +50,7 @@ TEST("require that lazy parameter passing works") {
std::vector<vespalib::string> unsupported = {
"map(",
+ "map_subspaces(",
"join(",
"merge(",
"reduce(",
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index ca35b8db66d..4ba715ea192 100644
--- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -150,18 +150,20 @@ TEST("require that basic addition works") {
//-----------------------------------------------------------------------------
-TEST("require that functions with non-compilable lambdas cannot be interpreted") {
+TEST("require that functions with non-compilable simple lambdas cannot be interpreted") {
auto good_map = Function::parse("map(a,f(x)(x+1))");
auto good_join = Function::parse("join(a,b,f(x,y)(x+y))");
+ auto good_merge = Function::parse("merge(a,b,f(x,y)(x+y))");
auto bad_map = Function::parse("map(a,f(x)(map(x,f(i)(i+1))))");
auto bad_join = Function::parse("join(a,b,f(x,y)(join(x,y,f(i,j)(i+j))))");
- for (const Function *good: {good_map.get(), good_join.get()}) {
+ auto bad_merge = Function::parse("merge(a,b,f(x,y)(join(x,y,f(i,j)(i+j))))");
+ for (const Function *good: {good_map.get(), good_join.get(), good_merge.get()}) {
if (!EXPECT_TRUE(!good->has_error())) {
fprintf(stderr, "parse error: %s\n", good->get_error().c_str());
}
EXPECT_TRUE(!InterpretedFunction::detect_issues(*good));
}
- for (const Function *bad: {bad_map.get(), bad_join.get()}) {
+ for (const Function *bad: {bad_map.get(), bad_join.get(), bad_merge.get()}) {
if (!EXPECT_TRUE(!bad->has_error())) {
fprintf(stderr, "parse error: %s\n", bad->get_error().c_str());
}
@@ -172,6 +174,28 @@ TEST("require that functions with non-compilable lambdas cannot be interpreted")
<< std::endl;
}
+TEST("require that functions with non-interpretable complex lambdas cannot be interpreted") {
+ auto good_tensor_lambda = Function::parse("tensor(x[5])(map(x,f(y)(y)))");
+ auto good_map_subspaces = Function::parse("map_subspaces(a,f(x)(concat(x,x,y)))");
+ auto bad_tensor_lambda = Function::parse("tensor(x[5])(map(x,f(y)(map(y,f(i)(i+1)))))");
+ auto bad_map_subspaces = Function::parse("map_subspaces(a,f(x)(map(x,f(y)(map(y,f(i)(i+1))))))");
+ for (const Function *good: {good_tensor_lambda.get(), good_map_subspaces.get()}) {
+ if (!EXPECT_TRUE(!good->has_error())) {
+ fprintf(stderr, "parse error: %s\n", good->get_error().c_str());
+ }
+ EXPECT_TRUE(!InterpretedFunction::detect_issues(*good));
+ }
+ for (const Function *bad: {bad_tensor_lambda.get(), bad_map_subspaces.get()}) {
+ if (!EXPECT_TRUE(!bad->has_error())) {
+ fprintf(stderr, "parse error: %s\n", bad->get_error().c_str());
+ }
+ EXPECT_TRUE(InterpretedFunction::detect_issues(*bad));
+ }
+ std::cerr << "Example function issues:" << std::endl
+ << InterpretedFunction::detect_issues(*bad_map_subspaces).list
+ << std::endl;
+}
+
//-----------------------------------------------------------------------------
TEST("require that compilation meta-data can be collected") {
diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp
index 1c4dcc3b5db..a39d8dda228 100644
--- a/eval/src/vespa/eval/eval/function.cpp
+++ b/eval/src/vespa/eval/eval/function.cpp
@@ -1127,4 +1127,12 @@ Function::unwrap(vespalib::stringref input,
//-----------------------------------------------------------------------------
+void
+Function::Issues::add_nested_issues(const vespalib::string &context, const Issues &issues)
+{
+ for (const auto &issue: issues.list) {
+ list.push_back(context + ": " + issue);
+ }
+}
+
}
diff --git a/eval/src/vespa/eval/eval/function.h b/eval/src/vespa/eval/eval/function.h
index 0f79d66ead6..ae23d0093fb 100644
--- a/eval/src/vespa/eval/eval/function.h
+++ b/eval/src/vespa/eval/eval/function.h
@@ -74,8 +74,10 @@ public:
**/
struct Issues {
std::vector<vespalib::string> list;
- operator bool() const { return !list.empty(); }
- Issues(std::vector<vespalib::string> &&list_in) : list(std::move(list_in)) {}
+ operator bool() const noexcept { return !list.empty(); }
+ Issues() noexcept : list() {}
+ Issues(std::vector<vespalib::string> &&list_in) noexcept : list(std::move(list_in)) {}
+ void add_nested_issues(const vespalib::string &context, const Issues &issues);
};
};
diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp
index e4304049b8e..c0aa7d1703b 100644
--- a/eval/src/vespa/eval/eval/interpreted_function.cpp
+++ b/eval/src/vespa/eval/eval/interpreted_function.cpp
@@ -18,7 +18,7 @@ namespace vespalib::eval {
namespace {
-const Function *get_lambda(const nodes::Node &node) {
+const Function *get_simple_lambda(const nodes::Node &node) {
if (auto ptr = nodes::as<nodes::TensorMap>(node)) {
return &ptr->lambda();
}
@@ -31,6 +31,16 @@ const Function *get_lambda(const nodes::Node &node) {
return nullptr;
}
+const Function *get_complex_lambda(const nodes::Node &node) {
+ if (auto ptr = nodes::as<nodes::TensorLambda>(node)) {
+ return &ptr->lambda();
+ }
+ if (auto ptr = nodes::as<nodes::TensorMapSubspaces>(node)) {
+ return &ptr->lambda();
+ }
+ return nullptr;
+}
+
void my_nop(InterpretedFunction::State &, uint64_t) {}
} // namespace vespalib::<unnamed>
@@ -148,18 +158,29 @@ Function::Issues
InterpretedFunction::detect_issues(const Function &function)
{
struct NotSupported : NodeTraverser {
- std::vector<vespalib::string> issues;
+ Function::Issues issues;
bool open(const nodes::Node &) override { return true; }
void close(const nodes::Node &node) override {
- auto lambda = get_lambda(node);
- if (lambda && CompiledFunction::detect_issues(*lambda)) {
- issues.push_back(make_string("lambda function that cannot be compiled within %s",
- getClassName(node).c_str()));
+ // map/join/merge: simple scalar lambdas must be compilable with llvm
+ if (auto lambda = get_simple_lambda(node)) {
+ auto inner_issues = CompiledFunction::detect_issues(*lambda);
+ if (inner_issues) {
+ auto ctx = make_string("within %s simple lambda", getClassName(node).c_str());
+ issues.add_nested_issues(ctx, inner_issues);
+ }
+ }
+ // tensor lambda/map_subspaces: complex lambdas that may be interpreted and use tensor math
+ if (auto lambda = get_complex_lambda(node)) {
+ auto inner_issues = InterpretedFunction::detect_issues(*lambda);
+ if (inner_issues) {
+ auto ctx = make_string("within %s complex lambda", getClassName(node).c_str());
+ issues.add_nested_issues(ctx, inner_issues);
+ }
}
}
} checker;
function.root().traverse(checker);
- return Function::Issues(std::move(checker.issues));
+ return std::move(checker.issues);
}
InterpretedFunction::EvalSingle::EvalSingle(const ValueBuilderFactory &factory, Instruction op, const LazyParams &params)
diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
index 50a8f731942..bd52b30b708 100644
--- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
+++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
@@ -128,6 +128,7 @@ CompiledFunction::detect_issues(const nodes::Node &node)
bool open(const nodes::Node &) override { return true; }
void close(const nodes::Node &node) override {
if (nodes::check_type<nodes::TensorMap,
+ nodes::TensorMapSubspaces,
nodes::TensorJoin,
nodes::TensorMerge,
nodes::TensorReduce,
@@ -139,7 +140,7 @@ CompiledFunction::detect_issues(const nodes::Node &node)
nodes::TensorPeek>(node))
{
issues.push_back(make_string("unsupported node type: %s",
- getClassName(node).c_str()));
+ getClassName(node).c_str()));
}
}
} checker;
diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp
index af88b2a526a..72664168114 100644
--- a/eval/src/vespa/eval/eval/test/eval_spec.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp
@@ -198,6 +198,8 @@ void
EvalSpec::add_tensor_operation_cases() {
add_rule({"a", -1.0, 1.0}, "map(a,f(x)(sin(x)))", [](double x){ return std::sin(x); });
add_rule({"a", -1.0, 1.0}, "map(a,f(x)(x*x*3))", [](double x){ return ((x * x) * 3); });
+ add_rule({"a", -1.0, 1.0}, "map_subspaces(a,f(x)(sin(x)))", [](double x){ return std::sin(x); });
+ add_rule({"a", -1.0, 1.0}, "map_subspaces(a,f(x)(x*x*3))", [](double x){ return ((x * x) * 3); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x*y*3))", [](double x, double y){ return ((x * y) * 3); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); });