path: root/vespalib
author     Geir Storli <geirstorli@yahoo.no>   2016-11-21 11:43:05 +0100
committer  GitHub <noreply@github.com>         2016-11-21 11:43:05 +0100
commit     37e782eed02a1b4f8aacade1cfeed651f3c0e870 (patch)
tree       e68a7447cb8e9a37fe065ab2f88baa6b045bcf46 /vespalib
parent     81b959e4472035f9c44b3f6c8e54a6cc41ec4259 (diff)
parent     d100f7cea6200f055e30738dc2d183a23f31fd9a (diff)
Merge pull request #1142 from yahoo/havardpe/parse-more-tensor-operations
Havardpe/parse more tensor operations
Diffstat (limited to 'vespalib')
-rw-r--r--  vespalib/src/tests/eval/compile_cache/compile_cache_test.cpp | 13
-rw-r--r--  vespalib/src/tests/eval/compiled_function/compiled_function_test.cpp | 31
-rw-r--r--  vespalib/src/tests/eval/function/function_test.cpp | 103
-rw-r--r--  vespalib/src/tests/eval/interpreted_function/interpreted_function_test.cpp | 79
-rw-r--r--  vespalib/src/tests/eval/node_types/node_types_test.cpp | 25
-rw-r--r--  vespalib/src/vespa/vespalib/eval/function.cpp | 156
-rw-r--r--  vespalib/src/vespa/vespalib/eval/interpreted_function.cpp | 17
-rw-r--r--  vespalib/src/vespa/vespalib/eval/key_gen.cpp | 89
-rw-r--r--  vespalib/src/vespa/vespalib/eval/llvm/compiled_function.cpp | 5
-rw-r--r--  vespalib/src/vespa/vespalib/eval/llvm/llvm_wrapper.cpp | 9
-rw-r--r--  vespalib/src/vespa/vespalib/eval/node_types.cpp | 9
-rw-r--r--  vespalib/src/vespa/vespalib/eval/node_visitor.h | 6
-rw-r--r--  vespalib/src/vespa/vespalib/eval/tensor_nodes.cpp | 52
-rw-r--r--  vespalib/src/vespa/vespalib/eval/tensor_nodes.h | 147
-rw-r--r--  vespalib/src/vespa/vespalib/eval/test/eval_spec.cpp | 24
-rw-r--r--  vespalib/src/vespa/vespalib/eval/test/eval_spec.h | 20
16 files changed, 645 insertions, 140 deletions
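
Before the per-file diffs, a minimal sketch of the expression forms this merge teaches the parser to accept. The include path and the small main() driver are illustrative assumptions; Function::parse(), has_error() and dump() are the calls exercised by the tests below.

// Sketch only: the include path below is assumed, not taken from this diff.
#include <vespa/vespalib/eval/function.h>

using vespalib::eval::Function;

int main() {
    // The three call forms this merge adds to the parser:
    Function reduce = Function::parse({"x"}, "reduce(x,sum,a,b,c)");
    Function rename = Function::parse({"x"}, "rename(x,(a,b),(b,a))");
    Function lambda = Function::parse({""}, "tensor(x[10],y[10])(x==y)");
    // dump() round-trips each of these to the canonical spelling that
    // function_test.cpp checks below.
    bool ok = !reduce.has_error() && !rename.has_error() && !lambda.has_error();
    return ok ? 0 : 1;
}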
diff --git a/vespalib/src/tests/eval/compile_cache/compile_cache_test.cpp b/vespalib/src/tests/eval/compile_cache/compile_cache_test.cpp
index 1bbe9d2fa0b..f80df8090d9 100644
--- a/vespalib/src/tests/eval/compile_cache/compile_cache_test.cpp
+++ b/vespalib/src/tests/eval/compile_cache/compile_cache_test.cpp
@@ -69,11 +69,14 @@ struct CheckKeys : test::EvalSpec::EvalTest {
virtual void next_expression(const std::vector<vespalib::string> &param_names,
const vespalib::string &expression) override
{
- if (check_key(gen_key(Function::parse(param_names, expression), PassParams::ARRAY)) ||
- check_key(gen_key(Function::parse(param_names, expression), PassParams::SEPARATE)))
- {
- failed = true;
- fprintf(stderr, "key collision for: %s\n", expression.c_str());
+ Function function = Function::parse(param_names, expression);
+ if (!CompiledFunction::detect_issues(function)) {
+ if (check_key(gen_key(function, PassParams::ARRAY)) ||
+ check_key(gen_key(function, PassParams::SEPARATE)))
+ {
+ failed = true;
+ fprintf(stderr, "key collision for: %s\n", expression.c_str());
+ }
}
}
virtual void handle_case(const std::vector<vespalib::string> &,
diff --git a/vespalib/src/tests/eval/compiled_function/compiled_function_test.cpp b/vespalib/src/tests/eval/compiled_function/compiled_function_test.cpp
index 2dc065fe08e..121bfe7b1ee 100644
--- a/vespalib/src/tests/eval/compiled_function/compiled_function_test.cpp
+++ b/vespalib/src/tests/eval/compiled_function/compiled_function_test.cpp
@@ -40,7 +40,10 @@ TEST("require that array parameter passing works") {
std::vector<vespalib::string> unsupported = {
"sum(",
"map(",
- "join("
+ "join(",
+ "reduce(",
+ "rename(",
+ "tensor("
};
bool is_unsupported(const vespalib::string &expression) {
@@ -59,28 +62,34 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
size_t fail_cnt = 0;
bool print_pass = false;
bool print_fail = false;
- virtual void next_expression(const std::vector<vespalib::string> &,
- const vespalib::string &) override {}
- virtual void handle_case(const std::vector<vespalib::string> &param_names,
- const std::vector<double> &param_values,
- const vespalib::string &expression,
- double expected_result) override
+ virtual void next_expression(const std::vector<vespalib::string> &param_names,
+ const vespalib::string &expression) override
{
Function function = Function::parse(param_names, expression);
+ ASSERT_TRUE(!function.has_error());
bool is_supported = !is_unsupported(expression);
bool has_issues = CompiledFunction::detect_issues(function);
if (is_supported == has_issues) {
const char *supported_str = is_supported ? "supported" : "not supported";
const char *issues_str = has_issues ? "has issues" : "does not have issues";
print_fail && fprintf(stderr, "expression %s is %s, but %s\n",
- as_string(param_names, param_values, expression).c_str(),
- supported_str, issues_str);
+ expression.c_str(), supported_str, issues_str);
++fail_cnt;
}
+ }
+ virtual void handle_case(const std::vector<vespalib::string> &param_names,
+ const std::vector<double> &param_values,
+ const vespalib::string &expression,
+ double expected_result) override
+ {
+ Function function = Function::parse(param_names, expression);
+ ASSERT_TRUE(!function.has_error());
+ bool is_supported = !is_unsupported(expression);
+ bool has_issues = CompiledFunction::detect_issues(function);
if (is_supported && !has_issues) {
- CompiledFunction cfun(Function::parse(param_names, expression), PassParams::ARRAY);
+ CompiledFunction cfun(function, PassParams::ARRAY);
auto fun = cfun.get_function();
- EXPECT_EQUAL(cfun.num_params(), param_values.size());
+ ASSERT_EQUAL(cfun.num_params(), param_values.size());
double result = fun(&param_values[0]);
if (is_same(expected_result, result)) {
print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n",
diff --git a/vespalib/src/tests/eval/function/function_test.cpp b/vespalib/src/tests/eval/function/function_test.cpp
index 6e4963050d2..afb0defdd91 100644
--- a/vespalib/src/tests/eval/function/function_test.cpp
+++ b/vespalib/src/tests/eval/function/function_test.cpp
@@ -756,6 +756,8 @@ TEST("require that tensor operations can be nested") {
EXPECT_EQUAL("sum(sum(sum(a)),dim)", Function::parse("sum(sum(sum(a)),dim)").dump());
}
+//-----------------------------------------------------------------------------
+
TEST("require that tensor map can be parsed") {
EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse("map(a,f(x)(x+1))").dump());
EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse(" map ( a , f ( x ) ( x + 1 ) ) ").dump());
@@ -784,4 +786,105 @@ TEST("require that outer let bindings are hidden within a lambda") {
//-----------------------------------------------------------------------------
+TEST("require that tensor reduce can be parsed") {
+ EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, "reduce(x,sum,a,b,c)").dump());
+ EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, " reduce ( x , sum , a , b , c ) ").dump());
+ EXPECT_EQUAL("reduce(x,sum)", Function::parse({"x"}, "reduce(x,sum)").dump());
+ EXPECT_EQUAL("reduce(x,sum)", Function::parse({"x"}, "reduce( x , sum )").dump());
+ EXPECT_EQUAL("reduce(x,avg)", Function::parse({"x"}, "reduce(x,avg)").dump());
+ EXPECT_EQUAL("reduce(x,count)", Function::parse({"x"}, "reduce(x,count)").dump());
+ EXPECT_EQUAL("reduce(x,prod)", Function::parse({"x"}, "reduce(x,prod)").dump());
+ EXPECT_EQUAL("reduce(x,sum)", Function::parse({"x"}, "reduce(x,sum)").dump());
+ EXPECT_EQUAL("reduce(x,min)", Function::parse({"x"}, "reduce(x,min)").dump());
+ EXPECT_EQUAL("reduce(x,max)", Function::parse({"x"}, "reduce(x,max)").dump());
+}
+
+TEST("require that tensor reduce with unknown aggregator fails") {
+ verify_error("reduce(x,bogus)", "[reduce(x,bogus]...[unknown aggregator: 'bogus']...[)]");
+}
+
+TEST("require that tensor reduce with duplicate dimensions fails") {
+ verify_error("reduce(x,sum,a,a)", "[reduce(x,sum,a,a]...[duplicate identifiers]...[)]");
+}
+
+//-----------------------------------------------------------------------------
+
+TEST("require that tensor rename can be parsed") {
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,b)").dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),(b))").dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,(b))").dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),b)").dump());
+ EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename(x,(a,b),(b,a))").dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , a , b )").dump());
+ EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , ( a ) , ( b ) )").dump());
+ EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename( x , ( a , b ) , ( b , a ) )").dump());
+}
+
+TEST("require that tensor rename dimension lists cannot be empty") {
+ verify_error("rename(x,,b)", "[rename(x,]...[missing identifier]...[,b)]");
+ verify_error("rename(x,a,)", "[rename(x,a,]...[missing identifier]...[)]");
+ verify_error("rename(x,(),b)", "[rename(x,()]...[missing identifiers]...[,b)]");
+ verify_error("rename(x,a,())", "[rename(x,a,()]...[missing identifiers]...[)]");
+}
+
+TEST("require that tensor rename dimension lists cannot contain duplicates") {
+ verify_error("rename(x,(a,a),(b,a))", "[rename(x,(a,a)]...[duplicate identifiers]...[,(b,a))]");
+ verify_error("rename(x,(a,b),(b,b))", "[rename(x,(a,b),(b,b)]...[duplicate identifiers]...[)]");
+}
+
+TEST("require that tensor rename dimension lists must have equal size") {
+ verify_error("rename(x,(a,b),(b))", "[rename(x,(a,b),(b)]...[dimension list size mismatch]...[)]");
+ verify_error("rename(x,(a),(b,a))", "[rename(x,(a),(b,a)]...[dimension list size mismatch]...[)]");
+}
+
+//-----------------------------------------------------------------------------
+
+TEST("require that tensor lambda can be parsed") {
+ EXPECT_EQUAL("tensor(x[10])(x)", Function::parse({""}, "tensor(x[10])(x)").dump());
+ EXPECT_EQUAL("tensor(x[10],y[10])(x==y)", Function::parse({""}, "tensor(x[10],y[10])(x==y)").dump());
+ EXPECT_EQUAL("tensor(x[10],y[10])(x==y)", Function::parse({""}, " tensor ( x [ 10 ] , y [ 10 ] ) ( x == y ) ").dump());
+}
+
+TEST("require that tensor lambda requires appropriate tensor type") {
+ verify_error("tensor(x[10],y[])(x==y)", "[tensor(x[10],y[])]...[invalid tensor type]...[(x==y)]");
+ verify_error("tensor(x[10],y{})(x==y)", "[tensor(x[10],y{})]...[invalid tensor type]...[(x==y)]");
+ verify_error("tensor()(x==y)", "[tensor()]...[invalid tensor type]...[(x==y)]");
+}
+
+TEST("require that tensor lambda can only use dimension names") {
+ verify_error("tensor(x[10],y[10])(x==z)", "[tensor(x[10],y[10])(x==z]...[unknown symbol: 'z']...[)]");
+}
+
+//-----------------------------------------------------------------------------
+
+struct CheckExpressions : test::EvalSpec::EvalTest {
+ bool failed = false;
+ size_t seen_cnt = 0;
+ virtual void next_expression(const std::vector<vespalib::string> &param_names,
+ const vespalib::string &expression) override
+ {
+ Function function = Function::parse(param_names, expression);
+ if (function.has_error()) {
+ failed = true;
+ fprintf(stderr, "parse error: %s\n", function.get_error().c_str());
+ }
+ ++seen_cnt;
+ }
+ virtual void handle_case(const std::vector<vespalib::string> &,
+ const std::vector<double> &,
+ const vespalib::string &,
+ double) override {}
+};
+
+TEST_FF("require that all conformance test expressions can be parsed",
+ CheckExpressions(), test::EvalSpec())
+{
+ f2.add_all_cases();
+ f2.each_case(f1);
+ EXPECT_TRUE(!f1.failed);
+ EXPECT_GREATER(f1.seen_cnt, 42u);
+}
+
+//-----------------------------------------------------------------------------
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/vespalib/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/vespalib/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index 51aacfd2272..b12dfdd2df7 100644
--- a/vespalib/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/vespalib/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -15,37 +15,74 @@ using vespalib::Stash;
//-----------------------------------------------------------------------------
+std::vector<vespalib::string> unsupported = {
+ "map(",
+ "join(",
+ "reduce(",
+ "rename(",
+ "tensor("
+};
+
+bool is_unsupported(const vespalib::string &expression) {
+ for (const auto &prefix: unsupported) {
+ if (starts_with(expression, prefix)) {
+ return true;
+ }
+ }
+ return false;
+}
+
+//-----------------------------------------------------------------------------
+
struct MyEvalTest : test::EvalSpec::EvalTest {
size_t pass_cnt = 0;
size_t fail_cnt = 0;
bool print_pass = false;
bool print_fail = false;
- virtual void next_expression(const std::vector<vespalib::string> &,
- const vespalib::string &) override {}
+ virtual void next_expression(const std::vector<vespalib::string> &param_names,
+ const vespalib::string &expression) override
+ {
+ Function function = Function::parse(param_names, expression);
+ ASSERT_TRUE(!function.has_error());
+ bool is_supported = !is_unsupported(expression);
+ bool has_issues = InterpretedFunction::detect_issues(function);
+ if (is_supported == has_issues) {
+ const char *supported_str = is_supported ? "supported" : "not supported";
+ const char *issues_str = has_issues ? "has issues" : "does not have issues";
+ print_fail && fprintf(stderr, "expression %s is %s, but %s\n",
+ expression.c_str(), supported_str, issues_str);
+ ++fail_cnt;
+ }
+ }
virtual void handle_case(const std::vector<vespalib::string> &param_names,
const std::vector<double> &param_values,
const vespalib::string &expression,
double expected_result) override
{
- Function fun = Function::parse(param_names, expression);
- EXPECT_EQUAL(fun.num_params(), param_values.size());
- InterpretedFunction ifun(SimpleTensorEngine::ref(), fun, NodeTypes());
- InterpretedFunction::Context ictx;
- for (double param: param_values) {
- ictx.add_param(param);
- }
- const Value &result_value = ifun.eval(ictx);
- double result = result_value.as_double();
- if (result_value.is_double() && is_same(expected_result, result)) {
- print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n",
- as_string(param_names, param_values, expression).c_str(),
- expected_result);
- ++pass_cnt;
- } else {
- print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n",
- as_string(param_names, param_values, expression).c_str(),
- expected_result, result);
- ++fail_cnt;
+ Function function = Function::parse(param_names, expression);
+ ASSERT_TRUE(!function.has_error());
+ bool is_supported = !is_unsupported(expression);
+ bool has_issues = InterpretedFunction::detect_issues(function);
+ if (is_supported && !has_issues) {
+ InterpretedFunction ifun(SimpleTensorEngine::ref(), function, NodeTypes());
+ ASSERT_EQUAL(ifun.num_params(), param_values.size());
+ InterpretedFunction::Context ictx;
+ for (double param: param_values) {
+ ictx.add_param(param);
+ }
+ const Value &result_value = ifun.eval(ictx);
+ double result = result_value.as_double();
+ if (result_value.is_double() && is_same(expected_result, result)) {
+ print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n",
+ as_string(param_names, param_values, expression).c_str(),
+ expected_result);
+ ++pass_cnt;
+ } else {
+ print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n",
+ as_string(param_names, param_values, expression).c_str(),
+ expected_result, result);
+ ++fail_cnt;
+ }
}
}
};
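
The interpreted path exercised by the rewritten test above, pulled out of the test harness; a sketch with assumed include paths and an illustrative helper name (eval_with_params), using only the calls visible in the diff.

#include <vector>
#include <vespa/vespalib/eval/function.h>              // assumed include paths
#include <vespa/vespalib/eval/interpreted_function.h>
#include <vespa/vespalib/eval/node_types.h>
#include <vespa/vespalib/eval/simple_tensor_engine.h>

using namespace vespalib::eval;

// Parse, screen for unsupported nodes, then interpret with parameter values.
double eval_with_params(const vespalib::string &expr,
                        const std::vector<vespalib::string> &names,
                        const std::vector<double> &values)
{
    Function function = Function::parse(names, expr);
    if (function.has_error() || InterpretedFunction::detect_issues(function)) {
        return 0.0; // sketch: a real caller would report the problem
    }
    InterpretedFunction ifun(SimpleTensorEngine::ref(), function, NodeTypes());
    InterpretedFunction::Context ictx;
    for (double param : values) {
        ictx.add_param(param);
    }
    return ifun.eval(ictx).as_double();
}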
diff --git a/vespalib/src/tests/eval/node_types/node_types_test.cpp b/vespalib/src/tests/eval/node_types/node_types_test.cpp
index 1dbb5d765ae..33799f0f989 100644
--- a/vespalib/src/tests/eval/node_types/node_types_test.cpp
+++ b/vespalib/src/tests/eval/node_types/node_types_test.cpp
@@ -7,10 +7,25 @@
using namespace vespalib::eval;
+/**
+ * Hack to avoid parse-conflict between tensor type expressions and
+ * lambda-generated tensors. This will patch leading identifier 'T' to
+ * 't' directly in the input stream after we have concluded that this
+ * is not a lambda-generated tensor in order to parse it out as a
+ * valid tensor type. This may be reverted later if we add support for
+ * parser rollback when we fail to parse a lambda-generated tensor.
+ **/
+void tensor_type_hack(const char *pos_in, const char *end_in) {
+ if ((pos_in < end_in) && (*pos_in == 'T')) {
+ const_cast<char *>(pos_in)[0] = 't';
+ }
+}
+
struct TypeSpecExtractor : public vespalib::eval::SymbolExtractor {
void extract_symbol(const char *pos_in, const char *end_in,
const char *&pos_out, vespalib::string &symbol_out) const override
{
+ tensor_type_hack(pos_in, end_in);
ValueType type = value_type::parse_spec(pos_in, end_in, pos_out);
if (pos_out != nullptr) {
symbol_out = type.to_spec();
@@ -18,7 +33,15 @@ struct TypeSpecExtractor : public vespalib::eval::SymbolExtractor {
}
};
-void verify(const vespalib::string &type_expr, const vespalib::string &type_spec) {
+void verify(const vespalib::string &type_expr_in, const vespalib::string &type_spec) {
+ // replace 'tensor' with 'Tensor' in type expression, see hack above
+ vespalib::string type_expr = type_expr_in;
+ for (size_t idx = type_expr.find("tensor");
+ idx != type_expr.npos;
+ idx = type_expr.find("tensor"))
+ {
+ type_expr[idx] = 'T';
+ }
Function function = Function::parse(type_expr, TypeSpecExtractor());
if (!EXPECT_TRUE(!function.has_error())) {
fprintf(stderr, "parse error: %s\n", function.get_error().c_str());
diff --git a/vespalib/src/vespa/vespalib/eval/function.cpp b/vespalib/src/vespa/vespalib/eval/function.cpp
index bd2e4a84c5d..c26256e69d6 100644
--- a/vespalib/src/vespa/vespalib/eval/function.cpp
+++ b/vespalib/src/vespa/vespalib/eval/function.cpp
@@ -19,6 +19,29 @@ using nodes::Call_UP;
namespace {
+bool has_duplicates(const std::vector<vespalib::string> &list) {
+ for (size_t i = 0; i < list.size(); ++i) {
+ for (size_t j = (i + 1); j < list.size(); ++j) {
+ if (list[i] == list[j]) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+bool check_tensor_lambda_type(const ValueType &type) {
+ if (!type.is_tensor() || type.dimensions().empty()) {
+ return false;
+ }
+ for (const auto &dim: type.dimensions()) {
+ if (!dim.is_indexed() || !dim.is_bound()) {
+ return false;
+ }
+ }
+ return true;
+}
+
//-----------------------------------------------------------------------------
class Params {
@@ -400,7 +423,6 @@ void parse_number(ParseContext &ctx) {
} else {
ctx.fail(make_string("invalid number: '%s'", str.c_str()));
}
- return;
}
// NOTE: using non-standard definition of identifiers
@@ -413,7 +435,7 @@ bool is_ident(char c, bool first) {
(c == '$' && !first));
}
-vespalib::string get_ident(ParseContext &ctx) {
+vespalib::string get_ident(ParseContext &ctx, bool allow_empty) {
ctx.skip_spaces();
vespalib::string ident;
if (is_ident(ctx.get(), true)) {
@@ -422,6 +444,9 @@ vespalib::string get_ident(ParseContext &ctx) {
ident.push_back(ctx.get());
}
}
+ if (!allow_empty && ident.empty()) {
+ ctx.fail("missing identifier");
+ }
return ident;
}
@@ -448,7 +473,7 @@ void parse_if(ParseContext &ctx) {
}
void parse_let(ParseContext &ctx) {
- vespalib::string name = get_ident(ctx);
+ vespalib::string name = get_ident(ctx, false);
ctx.skip_spaces();
ctx.eat(',');
parse_expression(ctx);
@@ -472,25 +497,50 @@ void parse_call(ParseContext &ctx, Call_UP call) {
ctx.push_expression(std::move(call));
}
-// (a,b,c)
-std::vector<vespalib::string> get_ident_list(ParseContext &ctx) {
+// (a,b,c) wrapped
+// ,a,b,c -> ) not wrapped
+std::vector<vespalib::string> get_ident_list(ParseContext &ctx, bool wrapped) {
std::vector<vespalib::string> list;
- ctx.skip_spaces();
- ctx.eat('(');
+ if (wrapped) {
+ ctx.skip_spaces();
+ ctx.eat('(');
+ }
for (ctx.skip_spaces(); !ctx.eos() && (ctx.get() != ')'); ctx.skip_spaces()) {
- if (!list.empty()) {
+ if (!list.empty() || !wrapped) {
ctx.eat(',');
}
- list.push_back(get_ident(ctx));
+ list.push_back(get_ident(ctx, false));
+ }
+ if (wrapped) {
+ ctx.eat(')');
+ }
+ if (has_duplicates(list)) {
+ ctx.fail("duplicate identifiers");
+ }
+ return list;
+}
+
+// a
+// (a,b,c)
+// cannot be empty
+std::vector<vespalib::string> get_idents(ParseContext &ctx) {
+ std::vector<vespalib::string> list;
+ ctx.skip_spaces();
+ if (ctx.get() == '(') {
+ list = get_ident_list(ctx, true);
+ } else {
+ list.push_back(get_ident(ctx, false));
+ }
+ if (list.empty()) {
+ ctx.fail("missing identifiers");
}
- ctx.eat(')');
return list;
}
-Function parse_lambda(ParseContext &ctx) {
+Function parse_lambda(ParseContext &ctx, size_t num_params) {
ctx.skip_spaces();
ctx.eat('f');
- auto param_names = get_ident_list(ctx);
+ auto param_names = get_ident_list(ctx, true);
ExplicitParams params(param_names);
ctx.push_resolve_context(params, nullptr);
ctx.skip_spaces();
@@ -500,6 +550,10 @@ Function parse_lambda(ParseContext &ctx) {
ctx.skip_spaces();
ctx.pop_resolve_context();
Node_UP lambda_root = ctx.pop_expression();
+ if (param_names.size() != num_params) {
+ ctx.fail(make_string("expected lambda with %zu parameter(s), was %zu",
+ num_params, param_names.size()));
+ }
return Function(std::move(lambda_root), std::move(param_names));
}
@@ -507,13 +561,8 @@ void parse_tensor_map(ParseContext &ctx) {
parse_expression(ctx);
Node_UP child = ctx.pop_expression();
ctx.eat(',');
- Function lambda = parse_lambda(ctx);
- if (lambda.num_params() == 1) {
- ctx.push_expression(std::make_unique<nodes::TensorMap>(std::move(child), std::move(lambda)));
- } else {
- ctx.fail(make_string("map requires a lambda with 1 parameter, was %zu",
- lambda.num_params()));
- }
+ Function lambda = parse_lambda(ctx, 1);
+ ctx.push_expression(std::make_unique<nodes::TensorMap>(std::move(child), std::move(lambda)));
}
void parse_tensor_join(ParseContext &ctx) {
@@ -523,13 +572,62 @@ void parse_tensor_join(ParseContext &ctx) {
parse_expression(ctx);
Node_UP rhs = ctx.pop_expression();
ctx.eat(',');
- Function lambda = parse_lambda(ctx);
- if (lambda.num_params() == 2) {
- ctx.push_expression(std::make_unique<nodes::TensorJoin>(std::move(lhs), std::move(rhs), std::move(lambda)));
+ Function lambda = parse_lambda(ctx, 2);
+ ctx.push_expression(std::make_unique<nodes::TensorJoin>(std::move(lhs), std::move(rhs), std::move(lambda)));
+}
+
+void parse_tensor_reduce(ParseContext &ctx) {
+ parse_expression(ctx);
+ Node_UP child = ctx.pop_expression();
+ ctx.eat(',');
+ auto aggr_name = get_ident(ctx, false);
+ auto maybe_aggr = nodes::AggrNames::from_name(aggr_name);
+ if (!maybe_aggr) {
+ ctx.fail(make_string("unknown aggregator: '%s'", aggr_name.c_str()));
+ return;
+ }
+ auto dimensions = get_ident_list(ctx, false);
+ ctx.push_expression(std::make_unique<nodes::TensorReduce>(std::move(child), *maybe_aggr, std::move(dimensions)));
+}
+
+void parse_tensor_rename(ParseContext &ctx) {
+ parse_expression(ctx);
+ Node_UP child = ctx.pop_expression();
+ ctx.eat(',');
+ auto from = get_idents(ctx);
+ ctx.skip_spaces();
+ ctx.eat(',');
+ auto to = get_idents(ctx);
+ if (from.size() != to.size()) {
+ ctx.fail("dimension list size mismatch");
} else {
- ctx.fail(make_string("join requires a lambda with 2 parameters, was %zu",
- lambda.num_params()));
+ ctx.push_expression(std::make_unique<nodes::TensorRename>(std::move(child), std::move(from), std::move(to)));
+ }
+ ctx.skip_spaces();
+}
+
+void parse_tensor_lambda(ParseContext &ctx) {
+ vespalib::string type_spec("tensor(");
+ while(!ctx.eos() && (ctx.get() != ')')) {
+ type_spec.push_back(ctx.get());
+ ctx.next();
}
+ ctx.eat(')');
+ type_spec.push_back(')');
+ ValueType type = ValueType::from_spec(type_spec);
+ if (!check_tensor_lambda_type(type)) {
+ ctx.fail("invalid tensor type");
+ return;
+ }
+ auto param_names = type.dimension_names();
+ ExplicitParams params(param_names);
+ ctx.push_resolve_context(params, nullptr);
+ ctx.skip_spaces();
+ ctx.eat('(');
+ parse_expression(ctx);
+ ctx.pop_resolve_context();
+ Function lambda(ctx.pop_expression(), std::move(param_names));
+ ctx.push_expression(std::make_unique<nodes::TensorLambda>(std::move(type), std::move(lambda)));
}
// to be replaced with more generic 'reduce'
@@ -538,7 +636,7 @@ void parse_tensor_sum(ParseContext &ctx) {
Node_UP child = ctx.pop_expression();
if (ctx.get() == ',') {
ctx.next();
- vespalib::string dimension = get_ident(ctx);
+ vespalib::string dimension = get_ident(ctx, false);
ctx.skip_spaces();
ctx.push_expression(Node_UP(new nodes::TensorSum(std::move(child), dimension)));
} else {
@@ -562,6 +660,12 @@ bool try_parse_call(ParseContext &ctx, const vespalib::string &name) {
parse_tensor_map(ctx);
} else if (name == "join") {
parse_tensor_join(ctx);
+ } else if (name == "reduce") {
+ parse_tensor_reduce(ctx);
+ } else if (name == "rename") {
+ parse_tensor_rename(ctx);
+ } else if (name == "tensor") {
+ parse_tensor_lambda(ctx);
} else if (name == "sum") {
parse_tensor_sum(ctx);
} else {
@@ -586,7 +690,7 @@ int parse_symbol(ParseContext &ctx, vespalib::string &name, ParseContext::InputM
void parse_symbol_or_call(ParseContext &ctx) {
ParseContext::InputMark before_name = ctx.get_input_mark();
- vespalib::string name = get_ident(ctx);
+ vespalib::string name = get_ident(ctx, true);
if (!try_parse_call(ctx, name)) {
int id = parse_symbol(ctx, name, before_name);
if (name.empty()) {
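
The error paths introduced above (unknown aggregator, duplicate identifiers, dimension list size mismatch, invalid tensor type) all surface through the existing Function error state rather than through exceptions. A sketch of how a caller observes them; the include path and the report() helper are illustrative assumptions, and the error strings in the comments come from the tests in this change.

#include <cstdio>
#include <vespa/vespalib/eval/function.h> // assumed include path

using vespalib::eval::Function;

void report(const char *expr) {
    Function f = Function::parse({"x"}, expr);
    if (f.has_error()) {
        fprintf(stderr, "parse error for '%s': %s\n", expr, f.get_error().c_str());
    }
}

int main() {
    report("reduce(x,bogus)");         // unknown aggregator: 'bogus'
    report("rename(x,(a,b),(b))");     // dimension list size mismatch
    report("tensor(x[10],y{})(x==y)"); // invalid tensor type
    return 0;
}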
diff --git a/vespalib/src/vespa/vespalib/eval/interpreted_function.cpp b/vespalib/src/vespa/vespalib/eval/interpreted_function.cpp
index a4329f97de2..a726e550382 100644
--- a/vespalib/src/vespa/vespalib/eval/interpreted_function.cpp
+++ b/vespalib/src/vespa/vespalib/eval/interpreted_function.cpp
@@ -274,6 +274,18 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
// TODO(havardpe): add actual evaluation
program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>()));
}
+ virtual void visit(const TensorReduce &) {
+ // TODO(havardpe): add actual evaluation
+ program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>()));
+ }
+ virtual void visit(const TensorRename &) {
+ // TODO(havardpe): add actual evaluation
+ program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>()));
+ }
+ virtual void visit(const TensorLambda &) {
+ // TODO(havardpe): add actual evaluation
+ program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>()));
+ }
virtual void visit(const Add &) {
program.emplace_back(op_binary<operation::Add>);
}
@@ -462,7 +474,10 @@ InterpretedFunction::detect_issues(const Function &function)
bool open(const nodes::Node &) override { return true; }
void close(const nodes::Node &node) override {
if (nodes::check_type<nodes::TensorMap,
- nodes::TensorJoin>(node)) {
+ nodes::TensorJoin,
+ nodes::TensorReduce,
+ nodes::TensorRename,
+ nodes::TensorLambda>(node)) {
issues.push_back(make_string("unsupported node type: %s",
getClassName(node).c_str()));
}
diff --git a/vespalib/src/vespa/vespalib/eval/key_gen.cpp b/vespalib/src/vespa/vespalib/eval/key_gen.cpp
index 96774c62f78..dc137bc2060 100644
--- a/vespalib/src/vespa/vespalib/eval/key_gen.cpp
+++ b/vespalib/src/vespa/vespalib/eval/key_gen.cpp
@@ -34,49 +34,52 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
virtual void visit(const If &node) { add_byte( 7); add_double(node.p_true()); }
virtual void visit(const Let &) { add_byte( 8); }
virtual void visit(const Error &) { add_byte( 9); }
- virtual void visit(const TensorSum &) { add_byte(10); }
- virtual void visit(const TensorMap &) { add_byte(11); }
- virtual void visit(const TensorJoin &) { add_byte(12); }
- virtual void visit(const Add &) { add_byte(13); }
- virtual void visit(const Sub &) { add_byte(14); }
- virtual void visit(const Mul &) { add_byte(15); }
- virtual void visit(const Div &) { add_byte(16); }
- virtual void visit(const Pow &) { add_byte(17); }
- virtual void visit(const Equal &) { add_byte(18); }
- virtual void visit(const NotEqual &) { add_byte(19); }
- virtual void visit(const Approx &) { add_byte(20); }
- virtual void visit(const Less &) { add_byte(21); }
- virtual void visit(const LessEqual &) { add_byte(22); }
- virtual void visit(const Greater &) { add_byte(23); }
- virtual void visit(const GreaterEqual &) { add_byte(24); }
- virtual void visit(const In &) { add_byte(25); }
- virtual void visit(const And &) { add_byte(26); }
- virtual void visit(const Or &) { add_byte(27); }
- virtual void visit(const Cos &) { add_byte(28); }
- virtual void visit(const Sin &) { add_byte(29); }
- virtual void visit(const Tan &) { add_byte(30); }
- virtual void visit(const Cosh &) { add_byte(31); }
- virtual void visit(const Sinh &) { add_byte(32); }
- virtual void visit(const Tanh &) { add_byte(33); }
- virtual void visit(const Acos &) { add_byte(34); }
- virtual void visit(const Asin &) { add_byte(35); }
- virtual void visit(const Atan &) { add_byte(36); }
- virtual void visit(const Exp &) { add_byte(37); }
- virtual void visit(const Log10 &) { add_byte(38); }
- virtual void visit(const Log &) { add_byte(39); }
- virtual void visit(const Sqrt &) { add_byte(40); }
- virtual void visit(const Ceil &) { add_byte(41); }
- virtual void visit(const Fabs &) { add_byte(42); }
- virtual void visit(const Floor &) { add_byte(43); }
- virtual void visit(const Atan2 &) { add_byte(44); }
- virtual void visit(const Ldexp &) { add_byte(45); }
- virtual void visit(const Pow2 &) { add_byte(46); }
- virtual void visit(const Fmod &) { add_byte(47); }
- virtual void visit(const Min &) { add_byte(48); }
- virtual void visit(const Max &) { add_byte(49); }
- virtual void visit(const IsNan &) { add_byte(50); }
- virtual void visit(const Relu &) { add_byte(51); }
- virtual void visit(const Sigmoid &) { add_byte(52); }
+ virtual void visit(const TensorSum &) { add_byte(10); } // dimensions should be part of key
+ virtual void visit(const TensorMap &) { add_byte(11); } // lambda should be part of key
+ virtual void visit(const TensorJoin &) { add_byte(12); } // lambda should be part of key
+ virtual void visit(const TensorReduce &) { add_byte(13); } // aggr/dimensions should be part of key
+ virtual void visit(const TensorRename &) { add_byte(14); } // dimensions should be part of key
+ virtual void visit(const TensorLambda &) { add_byte(15); } // type/lambda should be part of key
+ virtual void visit(const Add &) { add_byte(20); }
+ virtual void visit(const Sub &) { add_byte(21); }
+ virtual void visit(const Mul &) { add_byte(22); }
+ virtual void visit(const Div &) { add_byte(23); }
+ virtual void visit(const Pow &) { add_byte(24); }
+ virtual void visit(const Equal &) { add_byte(25); }
+ virtual void visit(const NotEqual &) { add_byte(26); }
+ virtual void visit(const Approx &) { add_byte(27); }
+ virtual void visit(const Less &) { add_byte(28); }
+ virtual void visit(const LessEqual &) { add_byte(29); }
+ virtual void visit(const Greater &) { add_byte(30); }
+ virtual void visit(const GreaterEqual &) { add_byte(31); }
+ virtual void visit(const In &) { add_byte(32); }
+ virtual void visit(const And &) { add_byte(33); }
+ virtual void visit(const Or &) { add_byte(34); }
+ virtual void visit(const Cos &) { add_byte(35); }
+ virtual void visit(const Sin &) { add_byte(36); }
+ virtual void visit(const Tan &) { add_byte(37); }
+ virtual void visit(const Cosh &) { add_byte(38); }
+ virtual void visit(const Sinh &) { add_byte(39); }
+ virtual void visit(const Tanh &) { add_byte(40); }
+ virtual void visit(const Acos &) { add_byte(41); }
+ virtual void visit(const Asin &) { add_byte(42); }
+ virtual void visit(const Atan &) { add_byte(43); }
+ virtual void visit(const Exp &) { add_byte(44); }
+ virtual void visit(const Log10 &) { add_byte(45); }
+ virtual void visit(const Log &) { add_byte(46); }
+ virtual void visit(const Sqrt &) { add_byte(47); }
+ virtual void visit(const Ceil &) { add_byte(48); }
+ virtual void visit(const Fabs &) { add_byte(49); }
+ virtual void visit(const Floor &) { add_byte(50); }
+ virtual void visit(const Atan2 &) { add_byte(51); }
+ virtual void visit(const Ldexp &) { add_byte(52); }
+ virtual void visit(const Pow2 &) { add_byte(53); }
+ virtual void visit(const Fmod &) { add_byte(54); }
+ virtual void visit(const Min &) { add_byte(55); }
+ virtual void visit(const Max &) { add_byte(56); }
+ virtual void visit(const IsNan &) { add_byte(57); }
+ virtual void visit(const Relu &) { add_byte(58); }
+ virtual void visit(const Sigmoid &) { add_byte(59); }
// traverse
virtual bool open(const Node &node) { node.accept(*this); return true; }
diff --git a/vespalib/src/vespa/vespalib/eval/llvm/compiled_function.cpp b/vespalib/src/vespa/vespalib/eval/llvm/compiled_function.cpp
index 9b120b677e2..e8243c27193 100644
--- a/vespalib/src/vespa/vespalib/eval/llvm/compiled_function.cpp
+++ b/vespalib/src/vespa/vespalib/eval/llvm/compiled_function.cpp
@@ -59,7 +59,10 @@ CompiledFunction::detect_issues(const Function &function)
void close(const nodes::Node &node) override {
if (nodes::check_type<nodes::TensorSum,
nodes::TensorMap,
- nodes::TensorJoin>(node)) {
+ nodes::TensorJoin,
+ nodes::TensorReduce,
+ nodes::TensorRename,
+ nodes::TensorLambda>(node)) {
issues.push_back(make_string("unsupported node type: %s",
getClassName(node).c_str()));
}
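
For contrast with the interpreted path, the LLVM-compiled path that this detect_issues() extension guards: reduce/rename/tensor expressions are rejected up front, while plain scalar expressions still compile. A sketch with assumed include paths, mirroring compiled_function_test.cpp.

#include <vespa/vespalib/eval/function.h>                 // assumed include paths
#include <vespa/vespalib/eval/llvm/compiled_function.h>

using namespace vespalib::eval;

int main() {
    Function function = Function::parse({"a", "b"}, "a+b*3");
    if (function.has_error() || CompiledFunction::detect_issues(function)) {
        return 1; // detect_issues() is now true for reduce/rename/tensor nodes
    }
    CompiledFunction cfun(function, PassParams::ARRAY);
    auto fun = cfun.get_function();
    double params[2] = {2.0, 5.0};
    return (fun(params) == 17.0) ? 0 : 1; // 2 + 5*3
}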
diff --git a/vespalib/src/vespa/vespalib/eval/llvm/llvm_wrapper.cpp b/vespalib/src/vespa/vespalib/eval/llvm/llvm_wrapper.cpp
index a03a5a9248a..7dfe44644ef 100644
--- a/vespalib/src/vespa/vespalib/eval/llvm/llvm_wrapper.cpp
+++ b/vespalib/src/vespa/vespalib/eval/llvm/llvm_wrapper.cpp
@@ -356,6 +356,15 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
virtual void visit(const TensorJoin &node) {
make_error(node.num_children());
}
+ virtual void visit(const TensorReduce &node) {
+ make_error(node.num_children());
+ }
+ virtual void visit(const TensorRename &node) {
+ make_error(node.num_children());
+ }
+ virtual void visit(const TensorLambda &node) {
+ make_error(node.num_children());
+ }
// operator nodes
diff --git a/vespalib/src/vespa/vespalib/eval/node_types.cpp b/vespalib/src/vespa/vespalib/eval/node_types.cpp
index faf1e085b1b..194cbeac1d5 100644
--- a/vespalib/src/vespa/vespalib/eval/node_types.cpp
+++ b/vespalib/src/vespa/vespalib/eval/node_types.cpp
@@ -169,6 +169,15 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
}
virtual void visit(const TensorMap &node) { resolve_op1(node); }
virtual void visit(const TensorJoin &node) { resolve_op2(node); }
+ virtual void visit(const TensorReduce &node) {
+ bind_type(ValueType::error_type(), node);
+ }
+ virtual void visit(const TensorRename &node) {
+ bind_type(ValueType::error_type(), node);
+ }
+ virtual void visit(const TensorLambda &node) {
+ bind_type(node.type(), node);
+ }
virtual void visit(const Add &node) { resolve_op2(node); }
virtual void visit(const Sub &node) { resolve_op2(node); }
diff --git a/vespalib/src/vespa/vespalib/eval/node_visitor.h b/vespalib/src/vespa/vespalib/eval/node_visitor.h
index d7fa5feff77..c890ffcd034 100644
--- a/vespalib/src/vespa/vespalib/eval/node_visitor.h
+++ b/vespalib/src/vespa/vespalib/eval/node_visitor.h
@@ -33,6 +33,9 @@ struct NodeVisitor {
virtual void visit(const nodes::TensorSum &) = 0;
virtual void visit(const nodes::TensorMap &) = 0;
virtual void visit(const nodes::TensorJoin &) = 0;
+ virtual void visit(const nodes::TensorReduce &) = 0;
+ virtual void visit(const nodes::TensorRename &) = 0;
+ virtual void visit(const nodes::TensorLambda &) = 0;
// operator nodes
virtual void visit(const nodes::Add &) = 0;
@@ -98,6 +101,9 @@ struct EmptyNodeVisitor : NodeVisitor {
virtual void visit(const nodes::TensorSum &) {}
virtual void visit(const nodes::TensorMap &) {}
virtual void visit(const nodes::TensorJoin &) {}
+ virtual void visit(const nodes::TensorReduce &) {}
+ virtual void visit(const nodes::TensorRename &) {}
+ virtual void visit(const nodes::TensorLambda &) {}
virtual void visit(const nodes::Add &) {}
virtual void visit(const nodes::Sub &) {}
virtual void visit(const nodes::Mul &) {}
diff --git a/vespalib/src/vespa/vespalib/eval/tensor_nodes.cpp b/vespalib/src/vespa/vespalib/eval/tensor_nodes.cpp
index a63e3e37347..beeb7583e0b 100644
--- a/vespalib/src/vespa/vespalib/eval/tensor_nodes.cpp
+++ b/vespalib/src/vespa/vespalib/eval/tensor_nodes.cpp
@@ -8,9 +8,55 @@ namespace vespalib {
namespace eval {
namespace nodes {
-void TensorSum ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorJoin::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorSum ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorReduce::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorRename::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorLambda::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+
+const AggrNames AggrNames::_instance;
+
+void
+AggrNames::add(Aggr aggr, const vespalib::string &name)
+{
+ _name_aggr_map[name] = aggr;
+ _aggr_name_map[aggr] = name;
+}
+
+AggrNames::AggrNames()
+ : _name_aggr_map(),
+ _aggr_name_map()
+{
+ add(Aggr::AVG, "avg");
+ add(Aggr::COUNT, "count");
+ add(Aggr::PROD, "prod");
+ add(Aggr::SUM, "sum");
+ add(Aggr::MAX, "max");
+ add(Aggr::MIN, "min");
+}
+
+const vespalib::string *
+AggrNames::name_of(Aggr aggr)
+{
+ const auto &map = _instance._aggr_name_map;
+ auto result = map.find(aggr);
+ if (result == map.end()) {
+ return nullptr;
+ }
+ return &(result->second);
+}
+
+const Aggr *
+AggrNames::from_name(const vespalib::string &name)
+{
+ const auto &map = _instance._name_aggr_map;
+ auto result = map.find(name);
+ if (result == map.end()) {
+ return nullptr;
+ }
+ return &(result->second);
+}
} // namespace vespalib::eval::nodes
} // namespace vespalib::eval
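
The AggrNames registry defined above is a static, bidirectional name/enum map; a small usage sketch (the include path is assumed):

#include <cstdio>
#include <vespa/vespalib/eval/tensor_nodes.h> // assumed include path

using namespace vespalib::eval::nodes;

int main() {
    const Aggr *aggr = AggrNames::from_name("avg");    // nullptr if unknown
    const Aggr *bogus = AggrNames::from_name("bogus"); // -> nullptr
    if (aggr != nullptr) {
        const vespalib::string *name = AggrNames::name_of(*aggr);
        printf("round-trip: %s\n", name->c_str());     // prints "avg"
    }
    return ((aggr != nullptr) && (bogus == nullptr)) ? 0 : 1;
}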
diff --git a/vespalib/src/vespa/vespalib/eval/tensor_nodes.h b/vespalib/src/vespa/vespalib/eval/tensor_nodes.h
index a8368c7eec3..28fbfc796c6 100644
--- a/vespalib/src/vespa/vespalib/eval/tensor_nodes.h
+++ b/vespalib/src/vespa/vespalib/eval/tensor_nodes.h
@@ -20,7 +20,7 @@ public:
TensorSum(Node_UP child, const vespalib::string &dimension_in)
: _child(std::move(child)), _dimension(dimension_in) {}
const vespalib::string &dimension() const { return _dimension; }
- virtual vespalib::string dump(DumpContext &ctx) const {
+ vespalib::string dump(DumpContext &ctx) const override {
vespalib::string str;
str += "sum(";
str += _child->dump(ctx);
@@ -31,14 +31,14 @@ public:
str += ")";
return str;
}
- virtual void accept(NodeVisitor &visitor) const;
- virtual size_t num_children() const { return 1; }
- virtual const Node &get_child(size_t idx) const {
+ void accept(NodeVisitor &visitor) const override;
+ size_t num_children() const override { return 1; }
+ const Node &get_child(size_t idx) const override {
(void) idx;
assert(idx == 0);
return *_child;
}
- virtual void detach_children(NodeHandler &handler) {
+ void detach_children(NodeHandler &handler) override {
handler.handle(std::move(_child));
}
};
@@ -51,7 +51,7 @@ public:
TensorMap(Node_UP child, Function lambda)
: _child(std::move(child)), _lambda(std::move(lambda)) {}
const Function &lambda() const { return _lambda; }
- virtual vespalib::string dump(DumpContext &ctx) const {
+ vespalib::string dump(DumpContext &ctx) const override {
vespalib::string str;
str += "map(";
str += _child->dump(ctx);
@@ -60,14 +60,14 @@ public:
str += ")";
return str;
}
- virtual void accept(NodeVisitor &visitor) const;
- virtual size_t num_children() const { return 1; }
- virtual const Node &get_child(size_t idx) const {
+ void accept(NodeVisitor &visitor) const override;
+ size_t num_children() const override { return 1; }
+ const Node &get_child(size_t idx) const override {
(void) idx;
assert(idx == 0);
return *_child;
}
- virtual void detach_children(NodeHandler &handler) {
+ void detach_children(NodeHandler &handler) override {
handler.handle(std::move(_child));
}
};
@@ -81,7 +81,7 @@ public:
TensorJoin(Node_UP lhs, Node_UP rhs, Function lambda)
: _lhs(std::move(lhs)), _rhs(std::move(rhs)), _lambda(std::move(lambda)) {}
const Function &lambda() const { return _lambda; }
- virtual vespalib::string dump(DumpContext &ctx) const {
+ vespalib::string dump(DumpContext &ctx) const override {
vespalib::string str;
str += "join(";
str += _lhs->dump(ctx);
@@ -92,18 +92,135 @@ public:
str += ")";
return str;
}
- virtual void accept(NodeVisitor &visitor) const;
- virtual size_t num_children() const { return 2; }
- virtual const Node &get_child(size_t idx) const {
+ void accept(NodeVisitor &visitor) const override ;
+ size_t num_children() const override { return 2; }
+ const Node &get_child(size_t idx) const override {
assert(idx < 2);
return (idx == 0) ? *_lhs : *_rhs;
}
- virtual void detach_children(NodeHandler &handler) {
+ void detach_children(NodeHandler &handler) override {
handler.handle(std::move(_lhs));
handler.handle(std::move(_rhs));
}
};
+enum class Aggr { AVG, COUNT, PROD, SUM, MAX, MIN };
+class AggrNames {
+private:
+ static const AggrNames _instance;
+ std::map<vespalib::string,Aggr> _name_aggr_map;
+ std::map<Aggr,vespalib::string> _aggr_name_map;
+ void add(Aggr aggr, const vespalib::string &name);
+ AggrNames();
+public:
+ static const vespalib::string *name_of(Aggr aggr);
+ static const Aggr *from_name(const vespalib::string &name);
+};
+
+class TensorReduce : public Node {
+private:
+ Node_UP _child;
+ Aggr _aggr;
+ std::vector<vespalib::string> _dimensions;
+public:
+ TensorReduce(Node_UP child, Aggr aggr_in, std::vector<vespalib::string> dimensions_in)
+ : _child(std::move(child)), _aggr(aggr_in), _dimensions(std::move(dimensions_in)) {}
+ const std::vector<vespalib::string> &dimensions() const { return _dimensions; }
+ Aggr aggr() const { return _aggr; }
+ vespalib::string dump(DumpContext &ctx) const override {
+ vespalib::string str;
+ str += "reduce(";
+ str += _child->dump(ctx);
+ str += ",";
+ str += *AggrNames::name_of(_aggr);
+ for (const auto &dimension: _dimensions) {
+ str += ",";
+ str += dimension;
+ }
+ str += ")";
+ return str;
+ }
+ void accept(NodeVisitor &visitor) const override;
+ size_t num_children() const override { return 1; }
+ const Node &get_child(size_t idx) const override {
+ assert(idx == 0);
+ return *_child;
+ }
+ void detach_children(NodeHandler &handler) override {
+ handler.handle(std::move(_child));
+ }
+};
+
+class TensorRename : public Node {
+private:
+ Node_UP _child;
+ std::vector<vespalib::string> _from;
+ std::vector<vespalib::string> _to;
+ static vespalib::string flatten(const std::vector<vespalib::string> &list) {
+ if (list.size() == 1) {
+ return list[0];
+ }
+ vespalib::string str = "(";
+ for (size_t i = 0; i < list.size(); ++i) {
+ if (i > 0) {
+ str += ",";
+ }
+ str += list[i];
+ }
+ str += ")";
+ return str;
+ }
+public:
+ TensorRename(Node_UP child, std::vector<vespalib::string> from_in, std::vector<vespalib::string> to_in)
+ : _child(std::move(child)), _from(std::move(from_in)), _to(std::move(to_in)) {}
+ const std::vector<vespalib::string> &from() const { return _from; }
+ const std::vector<vespalib::string> &to() const { return _to; }
+ vespalib::string dump(DumpContext &ctx) const override {
+ vespalib::string str;
+ str += "rename(";
+ str += _child->dump(ctx);
+ str += ",";
+ str += flatten(_from);
+ str += ",";
+ str += flatten(_to);
+ str += ")";
+ return str;
+ }
+ void accept(NodeVisitor &visitor) const override;
+ size_t num_children() const override { return 1; }
+ const Node &get_child(size_t idx) const override {
+ assert(idx == 0);
+ return *_child;
+ }
+ void detach_children(NodeHandler &handler) override {
+ handler.handle(std::move(_child));
+ }
+};
+
+class TensorLambda : public Leaf {
+private:
+ ValueType _type;
+ Function _lambda;
+public:
+ TensorLambda(ValueType type_in, Function lambda)
+ : _type(std::move(type_in)), _lambda(std::move(lambda)) {}
+ const ValueType &type() const { return _type; }
+ const Function &lambda() const { return _lambda; }
+ vespalib::string dump(DumpContext &) const override {
+ vespalib::string str = _type.to_spec();
+ vespalib::string expr = _lambda.dump();
+ if (starts_with(expr, "(")) {
+ str += expr;
+ } else {
+ str += "(";
+ str += expr;
+ str += ")";
+ }
+ return str;
+ }
+ void accept(NodeVisitor &visitor) const override;
+};
+
} // namespace vespalib::eval::nodes
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/vespalib/src/vespa/vespalib/eval/test/eval_spec.cpp b/vespalib/src/vespa/vespalib/eval/test/eval_spec.cpp
index f132cc7acbc..cd48aa4881f 100644
--- a/vespalib/src/vespa/vespalib/eval/test/eval_spec.cpp
+++ b/vespalib/src/vespa/vespalib/eval/test/eval_spec.cpp
@@ -10,9 +10,10 @@ namespace vespalib {
namespace eval {
namespace test {
-const double my_nan = std::numeric_limits<double>::quiet_NaN();
-const double my_inf = std::numeric_limits<double>::infinity();
-
+constexpr double my_nan = std::numeric_limits<double>::quiet_NaN();
+constexpr double my_inf = std::numeric_limits<double>::infinity();
+constexpr double my_error_value = 31212.0;
+
vespalib::string
EvalSpec::EvalTest::as_string(const std::vector<vespalib::string> &param_names,
const std::vector<double> &param_values,
@@ -112,7 +113,6 @@ EvalSpec::add_function_call_cases() {
.add_case({my_nan}, 1.0).add_case({my_inf}, 0.0).add_case({-my_inf}, 0.0);
add_rule({"a", -1.0, 1.0}, "relu(a)", [](double a){ return std::max(a, 0.0); });
add_rule({"a", -1.0, 1.0}, "sigmoid(a)", [](double a){ return 1.0 / (1.0 + std::exp(-1.0 * a)); });
- add_rule({"a", -1.0, 1.0}, "sum(a)", [](double a){ return a; });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "atan2(a,b)", [](double a, double b){ return std::atan2(a, b); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "ldexp(a,b)", [](double a, double b){ return std::ldexp(a, b); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "pow(a,b)", [](double a, double b){ return std::pow(a, b); });
@@ -122,6 +122,22 @@ EvalSpec::add_function_call_cases() {
}
void
+EvalSpec::add_tensor_operation_cases() {
+ add_rule({"a", -1.0, 1.0}, "sum(a)", [](double a){ return a; });
+ add_rule({"a", -1.0, 1.0}, "map(a,f(x)(sin(x)))", [](double x){ return std::sin(x); });
+ add_rule({"a", -1.0, 1.0}, "map(a,f(x)(x+x*3))", [](double x){ return (x + (x * 3)); });
+ add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); });
+ add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y*3))", [](double x, double y){ return (x + (y * 3)); });
+ add_rule({"a", -1.0, 1.0}, "reduce(a,sum)", [](double a){ return a; });
+ add_rule({"a", -1.0, 1.0}, "reduce(a,prod)", [](double a){ return a; });
+ add_rule({"a", -1.0, 1.0}, "reduce(a,count)", [](double){ return 1.0; });
+ add_rule({"a", -1.0, 1.0}, "rename(a,x,y)", [](double){ return my_error_value; });
+ add_rule({"a", -1.0, 1.0}, "rename(a,(x,y),(y,x))", [](double){ return my_error_value; });
+ add_expression({}, "tensor(x[10])(x)");
+ add_expression({}, "tensor(x[10],y[10])(x==y)");
+}
+
+void
EvalSpec::add_comparison_cases() {
add_expression({"a", "b"}, "(a==b)")
.add_case({my_nan, 2.0}, 0.0)
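
The new add_tensor_operation_cases() group is consumed the same way as the existing case groups: an EvalTest subclass visits every expression and every concrete case. A sketch of a minimal driver, matching the TEST_FF usage in the updated tests; the include path and the CountingTest name are illustrative assumptions.

#include <cstddef>
#include <cstdio>
#include <vector>
#include <vespa/vespalib/eval/test/eval_spec.h> // assumed include path

using namespace vespalib::eval;

struct CountingTest : test::EvalSpec::EvalTest {
    size_t expressions = 0;
    size_t cases = 0;
    void next_expression(const std::vector<vespalib::string> &,
                         const vespalib::string &) override { ++expressions; }
    void handle_case(const std::vector<vespalib::string> &,
                     const std::vector<double> &,
                     const vespalib::string &, double) override { ++cases; }
};

int main() {
    test::EvalSpec spec;
    CountingTest test;
    spec.add_all_cases(); // now includes add_tensor_operation_cases()
    spec.each_case(test);
    printf("expressions: %zu, cases: %zu\n", test.expressions, test.cases);
    return 0;
}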
diff --git a/vespalib/src/vespa/vespalib/eval/test/eval_spec.h b/vespalib/src/vespa/vespalib/eval/test/eval_spec.h
index af1821c1e8d..582c3b1c1e5 100644
--- a/vespalib/src/vespa/vespalib/eval/test/eval_spec.h
+++ b/vespalib/src/vespa/vespalib/eval/test/eval_spec.h
@@ -120,20 +120,22 @@ public:
virtual ~EvalTest() {}
};
//-------------------------------------------------------------------------
- void add_terminal_cases(); // a, 1.0
- void add_arithmetic_cases(); // a + b, a ^ b
- void add_function_call_cases(); // cos(a), max(a, b)
- void add_comparison_cases(); // a < b, c != d
- void add_set_membership_cases(); // a in [x, y, z]
- void add_boolean_cases(); // 1.0 && 0.0
- void add_if_cases(); // if (a < b, a, b)
- void add_let_cases(); // let (a, b + 1, a * a)
- void add_complex_cases(); // ...
+ void add_terminal_cases(); // a, 1.0
+ void add_arithmetic_cases(); // a + b, a ^ b
+ void add_function_call_cases(); // cos(a), max(a, b)
+ void add_tensor_operation_cases(); // map(a,f(x)(sin(x)))
+ void add_comparison_cases(); // a < b, c != d
+ void add_set_membership_cases(); // a in [x, y, z]
+ void add_boolean_cases(); // 1.0 && 0.0
+ void add_if_cases(); // if (a < b, a, b)
+ void add_let_cases(); // let (a, b + 1, a * a)
+ void add_complex_cases(); // ...
//-------------------------------------------------------------------------
void add_all_cases() {
add_terminal_cases();
add_arithmetic_cases();
add_function_call_cases();
+ add_tensor_operation_cases();
add_comparison_cases();
add_set_membership_cases();
add_boolean_cases();