diff options
author | Geir Storli <geirstorli@yahoo.no> | 2016-11-21 11:43:05 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-21 11:43:05 +0100 |
commit | 37e782eed02a1b4f8aacade1cfeed651f3c0e870 (patch) | |
tree | e68a7447cb8e9a37fe065ab2f88baa6b045bcf46 /vespalib | |
parent | 81b959e4472035f9c44b3f6c8e54a6cc41ec4259 (diff) | |
parent | d100f7cea6200f055e30738dc2d183a23f31fd9a (diff) |
Merge pull request #1142 from yahoo/havardpe/parse-more-tensor-operations
Havardpe/parse more tensor operations
Diffstat (limited to 'vespalib')
16 files changed, 645 insertions, 140 deletions
diff --git a/vespalib/src/tests/eval/compile_cache/compile_cache_test.cpp b/vespalib/src/tests/eval/compile_cache/compile_cache_test.cpp index 1bbe9d2fa0b..f80df8090d9 100644 --- a/vespalib/src/tests/eval/compile_cache/compile_cache_test.cpp +++ b/vespalib/src/tests/eval/compile_cache/compile_cache_test.cpp @@ -69,11 +69,14 @@ struct CheckKeys : test::EvalSpec::EvalTest { virtual void next_expression(const std::vector<vespalib::string> ¶m_names, const vespalib::string &expression) override { - if (check_key(gen_key(Function::parse(param_names, expression), PassParams::ARRAY)) || - check_key(gen_key(Function::parse(param_names, expression), PassParams::SEPARATE))) - { - failed = true; - fprintf(stderr, "key collision for: %s\n", expression.c_str()); + Function function = Function::parse(param_names, expression); + if (!CompiledFunction::detect_issues(function)) { + if (check_key(gen_key(function, PassParams::ARRAY)) || + check_key(gen_key(function, PassParams::SEPARATE))) + { + failed = true; + fprintf(stderr, "key collision for: %s\n", expression.c_str()); + } } } virtual void handle_case(const std::vector<vespalib::string> &, diff --git a/vespalib/src/tests/eval/compiled_function/compiled_function_test.cpp b/vespalib/src/tests/eval/compiled_function/compiled_function_test.cpp index 2dc065fe08e..121bfe7b1ee 100644 --- a/vespalib/src/tests/eval/compiled_function/compiled_function_test.cpp +++ b/vespalib/src/tests/eval/compiled_function/compiled_function_test.cpp @@ -40,7 +40,10 @@ TEST("require that array parameter passing works") { std::vector<vespalib::string> unsupported = { "sum(", "map(", - "join(" + "join(", + "reduce(", + "rename(", + "tensor(" }; bool is_unsupported(const vespalib::string &expression) { @@ -59,28 +62,34 @@ struct MyEvalTest : test::EvalSpec::EvalTest { size_t fail_cnt = 0; bool print_pass = false; bool print_fail = false; - virtual void next_expression(const std::vector<vespalib::string> &, - const vespalib::string &) override {} - virtual void handle_case(const std::vector<vespalib::string> ¶m_names, - const std::vector<double> ¶m_values, - const vespalib::string &expression, - double expected_result) override + virtual void next_expression(const std::vector<vespalib::string> ¶m_names, + const vespalib::string &expression) override { Function function = Function::parse(param_names, expression); + ASSERT_TRUE(!function.has_error()); bool is_supported = !is_unsupported(expression); bool has_issues = CompiledFunction::detect_issues(function); if (is_supported == has_issues) { const char *supported_str = is_supported ? "supported" : "not supported"; const char *issues_str = has_issues ? "has issues" : "does not have issues"; print_fail && fprintf(stderr, "expression %s is %s, but %s\n", - as_string(param_names, param_values, expression).c_str(), - supported_str, issues_str); + expression.c_str(), supported_str, issues_str); ++fail_cnt; } + } + virtual void handle_case(const std::vector<vespalib::string> ¶m_names, + const std::vector<double> ¶m_values, + const vespalib::string &expression, + double expected_result) override + { + Function function = Function::parse(param_names, expression); + ASSERT_TRUE(!function.has_error()); + bool is_supported = !is_unsupported(expression); + bool has_issues = CompiledFunction::detect_issues(function); if (is_supported && !has_issues) { - CompiledFunction cfun(Function::parse(param_names, expression), PassParams::ARRAY); + CompiledFunction cfun(function, PassParams::ARRAY); auto fun = cfun.get_function(); - EXPECT_EQUAL(cfun.num_params(), param_values.size()); + ASSERT_EQUAL(cfun.num_params(), param_values.size()); double result = fun(¶m_values[0]); if (is_same(expected_result, result)) { print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n", diff --git a/vespalib/src/tests/eval/function/function_test.cpp b/vespalib/src/tests/eval/function/function_test.cpp index 6e4963050d2..afb0defdd91 100644 --- a/vespalib/src/tests/eval/function/function_test.cpp +++ b/vespalib/src/tests/eval/function/function_test.cpp @@ -756,6 +756,8 @@ TEST("require that tensor operations can be nested") { EXPECT_EQUAL("sum(sum(sum(a)),dim)", Function::parse("sum(sum(sum(a)),dim)").dump()); } +//----------------------------------------------------------------------------- + TEST("require that tensor map can be parsed") { EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse("map(a,f(x)(x+1))").dump()); EXPECT_EQUAL("map(a,f(x)(x+1))", Function::parse(" map ( a , f ( x ) ( x + 1 ) ) ").dump()); @@ -784,4 +786,105 @@ TEST("require that outer let bindings are hidden within a lambda") { //----------------------------------------------------------------------------- +TEST("require that tensor reduce can be parsed") { + EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, "reduce(x,sum,a,b,c)").dump()); + EXPECT_EQUAL("reduce(x,sum,a,b,c)", Function::parse({"x"}, " reduce ( x , sum , a , b , c ) ").dump()); + EXPECT_EQUAL("reduce(x,sum)", Function::parse({"x"}, "reduce(x,sum)").dump()); + EXPECT_EQUAL("reduce(x,sum)", Function::parse({"x"}, "reduce( x , sum )").dump()); + EXPECT_EQUAL("reduce(x,avg)", Function::parse({"x"}, "reduce(x,avg)").dump()); + EXPECT_EQUAL("reduce(x,count)", Function::parse({"x"}, "reduce(x,count)").dump()); + EXPECT_EQUAL("reduce(x,prod)", Function::parse({"x"}, "reduce(x,prod)").dump()); + EXPECT_EQUAL("reduce(x,sum)", Function::parse({"x"}, "reduce(x,sum)").dump()); + EXPECT_EQUAL("reduce(x,min)", Function::parse({"x"}, "reduce(x,min)").dump()); + EXPECT_EQUAL("reduce(x,max)", Function::parse({"x"}, "reduce(x,max)").dump()); +} + +TEST("require that tensor reduce with unknown aggregator fails") { + verify_error("reduce(x,bogus)", "[reduce(x,bogus]...[unknown aggregator: 'bogus']...[)]"); +} + +TEST("require that tensor reduce with duplicate dimensions fails") { + verify_error("reduce(x,sum,a,a)", "[reduce(x,sum,a,a]...[duplicate identifiers]...[)]"); +} + +//----------------------------------------------------------------------------- + +TEST("require that tensor rename can be parsed") { + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,b)").dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),(b))").dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,a,(b))").dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename(x,(a),b)").dump()); + EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename(x,(a,b),(b,a))").dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , a , b )").dump()); + EXPECT_EQUAL("rename(x,a,b)", Function::parse({"x"}, "rename( x , ( a ) , ( b ) )").dump()); + EXPECT_EQUAL("rename(x,(a,b),(b,a))", Function::parse({"x"}, "rename( x , ( a , b ) , ( b , a ) )").dump()); +} + +TEST("require that tensor rename dimension lists cannot be empty") { + verify_error("rename(x,,b)", "[rename(x,]...[missing identifier]...[,b)]"); + verify_error("rename(x,a,)", "[rename(x,a,]...[missing identifier]...[)]"); + verify_error("rename(x,(),b)", "[rename(x,()]...[missing identifiers]...[,b)]"); + verify_error("rename(x,a,())", "[rename(x,a,()]...[missing identifiers]...[)]"); +} + +TEST("require that tensor rename dimension lists cannot contain duplicates") { + verify_error("rename(x,(a,a),(b,a))", "[rename(x,(a,a)]...[duplicate identifiers]...[,(b,a))]"); + verify_error("rename(x,(a,b),(b,b))", "[rename(x,(a,b),(b,b)]...[duplicate identifiers]...[)]"); +} + +TEST("require that tensor rename dimension lists must have equal size") { + verify_error("rename(x,(a,b),(b))", "[rename(x,(a,b),(b)]...[dimension list size mismatch]...[)]"); + verify_error("rename(x,(a),(b,a))", "[rename(x,(a),(b,a)]...[dimension list size mismatch]...[)]"); +} + +//----------------------------------------------------------------------------- + +TEST("require that tensor lambda can be parsed") { + EXPECT_EQUAL("tensor(x[10])(x)", Function::parse({""}, "tensor(x[10])(x)").dump()); + EXPECT_EQUAL("tensor(x[10],y[10])(x==y)", Function::parse({""}, "tensor(x[10],y[10])(x==y)").dump()); + EXPECT_EQUAL("tensor(x[10],y[10])(x==y)", Function::parse({""}, " tensor ( x [ 10 ] , y [ 10 ] ) ( x == y ) ").dump()); +} + +TEST("require that tensor lambda requires appropriate tensor type") { + verify_error("tensor(x[10],y[])(x==y)", "[tensor(x[10],y[])]...[invalid tensor type]...[(x==y)]"); + verify_error("tensor(x[10],y{})(x==y)", "[tensor(x[10],y{})]...[invalid tensor type]...[(x==y)]"); + verify_error("tensor()(x==y)", "[tensor()]...[invalid tensor type]...[(x==y)]"); +} + +TEST("require that tensor lambda can only use dimension names") { + verify_error("tensor(x[10],y[10])(x==z)", "[tensor(x[10],y[10])(x==z]...[unknown symbol: 'z']...[)]"); +} + +//----------------------------------------------------------------------------- + +struct CheckExpressions : test::EvalSpec::EvalTest { + bool failed = false; + size_t seen_cnt = 0; + virtual void next_expression(const std::vector<vespalib::string> ¶m_names, + const vespalib::string &expression) override + { + Function function = Function::parse(param_names, expression); + if (function.has_error()) { + failed = true; + fprintf(stderr, "parse error: %s\n", function.get_error().c_str()); + } + ++seen_cnt; + } + virtual void handle_case(const std::vector<vespalib::string> &, + const std::vector<double> &, + const vespalib::string &, + double) override {} +}; + +TEST_FF("require that all conformance test expressions can be parsed", + CheckExpressions(), test::EvalSpec()) +{ + f2.add_all_cases(); + f2.each_case(f1); + EXPECT_TRUE(!f1.failed); + EXPECT_GREATER(f1.seen_cnt, 42u); +} + +//----------------------------------------------------------------------------- + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/vespalib/src/tests/eval/interpreted_function/interpreted_function_test.cpp index 51aacfd2272..b12dfdd2df7 100644 --- a/vespalib/src/tests/eval/interpreted_function/interpreted_function_test.cpp +++ b/vespalib/src/tests/eval/interpreted_function/interpreted_function_test.cpp @@ -15,37 +15,74 @@ using vespalib::Stash; //----------------------------------------------------------------------------- +std::vector<vespalib::string> unsupported = { + "map(", + "join(", + "reduce(", + "rename(", + "tensor(" +}; + +bool is_unsupported(const vespalib::string &expression) { + for (const auto &prefix: unsupported) { + if (starts_with(expression, prefix)) { + return true; + } + } + return false; +} + +//----------------------------------------------------------------------------- + struct MyEvalTest : test::EvalSpec::EvalTest { size_t pass_cnt = 0; size_t fail_cnt = 0; bool print_pass = false; bool print_fail = false; - virtual void next_expression(const std::vector<vespalib::string> &, - const vespalib::string &) override {} + virtual void next_expression(const std::vector<vespalib::string> ¶m_names, + const vespalib::string &expression) override + { + Function function = Function::parse(param_names, expression); + ASSERT_TRUE(!function.has_error()); + bool is_supported = !is_unsupported(expression); + bool has_issues = InterpretedFunction::detect_issues(function); + if (is_supported == has_issues) { + const char *supported_str = is_supported ? "supported" : "not supported"; + const char *issues_str = has_issues ? "has issues" : "does not have issues"; + print_fail && fprintf(stderr, "expression %s is %s, but %s\n", + expression.c_str(), supported_str, issues_str); + ++fail_cnt; + } + } virtual void handle_case(const std::vector<vespalib::string> ¶m_names, const std::vector<double> ¶m_values, const vespalib::string &expression, double expected_result) override { - Function fun = Function::parse(param_names, expression); - EXPECT_EQUAL(fun.num_params(), param_values.size()); - InterpretedFunction ifun(SimpleTensorEngine::ref(), fun, NodeTypes()); - InterpretedFunction::Context ictx; - for (double param: param_values) { - ictx.add_param(param); - } - const Value &result_value = ifun.eval(ictx); - double result = result_value.as_double(); - if (result_value.is_double() && is_same(expected_result, result)) { - print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n", - as_string(param_names, param_values, expression).c_str(), - expected_result); - ++pass_cnt; - } else { - print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n", - as_string(param_names, param_values, expression).c_str(), - expected_result, result); - ++fail_cnt; + Function function = Function::parse(param_names, expression); + ASSERT_TRUE(!function.has_error()); + bool is_supported = !is_unsupported(expression); + bool has_issues = InterpretedFunction::detect_issues(function); + if (is_supported && !has_issues) { + InterpretedFunction ifun(SimpleTensorEngine::ref(), function, NodeTypes()); + ASSERT_EQUAL(ifun.num_params(), param_values.size()); + InterpretedFunction::Context ictx; + for (double param: param_values) { + ictx.add_param(param); + } + const Value &result_value = ifun.eval(ictx); + double result = result_value.as_double(); + if (result_value.is_double() && is_same(expected_result, result)) { + print_pass && fprintf(stderr, "verifying: %s -> %g ... PASS\n", + as_string(param_names, param_values, expression).c_str(), + expected_result); + ++pass_cnt; + } else { + print_fail && fprintf(stderr, "verifying: %s -> %g ... FAIL: got %g\n", + as_string(param_names, param_values, expression).c_str(), + expected_result, result); + ++fail_cnt; + } } } }; diff --git a/vespalib/src/tests/eval/node_types/node_types_test.cpp b/vespalib/src/tests/eval/node_types/node_types_test.cpp index 1dbb5d765ae..33799f0f989 100644 --- a/vespalib/src/tests/eval/node_types/node_types_test.cpp +++ b/vespalib/src/tests/eval/node_types/node_types_test.cpp @@ -7,10 +7,25 @@ using namespace vespalib::eval; +/** + * Hack to avoid parse-conflict between tensor type expressions and + * lambda-generated tensors. This will patch leading identifier 'T' to + * 't' directly in the input stream after we have concluded that this + * is not a lambda-generated tensor in order to parse it out as a + * valid tensor type. This may be reverted later if we add support for + * parser rollback when we fail to parse a lambda-generated tensor. + **/ +void tensor_type_hack(const char *pos_in, const char *end_in) { + if ((pos_in < end_in) && (*pos_in == 'T')) { + const_cast<char *>(pos_in)[0] = 't'; + } +} + struct TypeSpecExtractor : public vespalib::eval::SymbolExtractor { void extract_symbol(const char *pos_in, const char *end_in, const char *&pos_out, vespalib::string &symbol_out) const override { + tensor_type_hack(pos_in, end_in); ValueType type = value_type::parse_spec(pos_in, end_in, pos_out); if (pos_out != nullptr) { symbol_out = type.to_spec(); @@ -18,7 +33,15 @@ struct TypeSpecExtractor : public vespalib::eval::SymbolExtractor { } }; -void verify(const vespalib::string &type_expr, const vespalib::string &type_spec) { +void verify(const vespalib::string &type_expr_in, const vespalib::string &type_spec) { + // replace 'tensor' with 'Tensor' in type expression, see hack above + vespalib::string type_expr = type_expr_in; + for (size_t idx = type_expr.find("tensor"); + idx != type_expr.npos; + idx = type_expr.find("tensor")) + { + type_expr[idx] = 'T'; + } Function function = Function::parse(type_expr, TypeSpecExtractor()); if (!EXPECT_TRUE(!function.has_error())) { fprintf(stderr, "parse error: %s\n", function.get_error().c_str()); diff --git a/vespalib/src/vespa/vespalib/eval/function.cpp b/vespalib/src/vespa/vespalib/eval/function.cpp index bd2e4a84c5d..c26256e69d6 100644 --- a/vespalib/src/vespa/vespalib/eval/function.cpp +++ b/vespalib/src/vespa/vespalib/eval/function.cpp @@ -19,6 +19,29 @@ using nodes::Call_UP; namespace { +bool has_duplicates(const std::vector<vespalib::string> &list) { + for (size_t i = 0; i < list.size(); ++i) { + for (size_t j = (i + 1); j < list.size(); ++j) { + if (list[i] == list[j]) { + return true; + } + } + } + return false; +} + +bool check_tensor_lambda_type(const ValueType &type) { + if (!type.is_tensor() || type.dimensions().empty()) { + return false; + } + for (const auto &dim: type.dimensions()) { + if (!dim.is_indexed() || !dim.is_bound()) { + return false; + } + } + return true; +} + //----------------------------------------------------------------------------- class Params { @@ -400,7 +423,6 @@ void parse_number(ParseContext &ctx) { } else { ctx.fail(make_string("invalid number: '%s'", str.c_str())); } - return; } // NOTE: using non-standard definition of identifiers @@ -413,7 +435,7 @@ bool is_ident(char c, bool first) { (c == '$' && !first)); } -vespalib::string get_ident(ParseContext &ctx) { +vespalib::string get_ident(ParseContext &ctx, bool allow_empty) { ctx.skip_spaces(); vespalib::string ident; if (is_ident(ctx.get(), true)) { @@ -422,6 +444,9 @@ vespalib::string get_ident(ParseContext &ctx) { ident.push_back(ctx.get()); } } + if (!allow_empty && ident.empty()) { + ctx.fail("missing identifier"); + } return ident; } @@ -448,7 +473,7 @@ void parse_if(ParseContext &ctx) { } void parse_let(ParseContext &ctx) { - vespalib::string name = get_ident(ctx); + vespalib::string name = get_ident(ctx, false); ctx.skip_spaces(); ctx.eat(','); parse_expression(ctx); @@ -472,25 +497,50 @@ void parse_call(ParseContext &ctx, Call_UP call) { ctx.push_expression(std::move(call)); } -// (a,b,c) -std::vector<vespalib::string> get_ident_list(ParseContext &ctx) { +// (a,b,c) wrapped +// ,a,b,c -> ) not wrapped +std::vector<vespalib::string> get_ident_list(ParseContext &ctx, bool wrapped) { std::vector<vespalib::string> list; - ctx.skip_spaces(); - ctx.eat('('); + if (wrapped) { + ctx.skip_spaces(); + ctx.eat('('); + } for (ctx.skip_spaces(); !ctx.eos() && (ctx.get() != ')'); ctx.skip_spaces()) { - if (!list.empty()) { + if (!list.empty() || !wrapped) { ctx.eat(','); } - list.push_back(get_ident(ctx)); + list.push_back(get_ident(ctx, false)); + } + if (wrapped) { + ctx.eat(')'); + } + if (has_duplicates(list)) { + ctx.fail("duplicate identifiers"); + } + return list; +} + +// a +// (a,b,c) +// cannot be empty +std::vector<vespalib::string> get_idents(ParseContext &ctx) { + std::vector<vespalib::string> list; + ctx.skip_spaces(); + if (ctx.get() == '(') { + list = get_ident_list(ctx, true); + } else { + list.push_back(get_ident(ctx, false)); + } + if (list.empty()) { + ctx.fail("missing identifiers"); } - ctx.eat(')'); return list; } -Function parse_lambda(ParseContext &ctx) { +Function parse_lambda(ParseContext &ctx, size_t num_params) { ctx.skip_spaces(); ctx.eat('f'); - auto param_names = get_ident_list(ctx); + auto param_names = get_ident_list(ctx, true); ExplicitParams params(param_names); ctx.push_resolve_context(params, nullptr); ctx.skip_spaces(); @@ -500,6 +550,10 @@ Function parse_lambda(ParseContext &ctx) { ctx.skip_spaces(); ctx.pop_resolve_context(); Node_UP lambda_root = ctx.pop_expression(); + if (param_names.size() != num_params) { + ctx.fail(make_string("expected lambda with %zu parameter(s), was %zu", + num_params, param_names.size())); + } return Function(std::move(lambda_root), std::move(param_names)); } @@ -507,13 +561,8 @@ void parse_tensor_map(ParseContext &ctx) { parse_expression(ctx); Node_UP child = ctx.pop_expression(); ctx.eat(','); - Function lambda = parse_lambda(ctx); - if (lambda.num_params() == 1) { - ctx.push_expression(std::make_unique<nodes::TensorMap>(std::move(child), std::move(lambda))); - } else { - ctx.fail(make_string("map requires a lambda with 1 parameter, was %zu", - lambda.num_params())); - } + Function lambda = parse_lambda(ctx, 1); + ctx.push_expression(std::make_unique<nodes::TensorMap>(std::move(child), std::move(lambda))); } void parse_tensor_join(ParseContext &ctx) { @@ -523,13 +572,62 @@ void parse_tensor_join(ParseContext &ctx) { parse_expression(ctx); Node_UP rhs = ctx.pop_expression(); ctx.eat(','); - Function lambda = parse_lambda(ctx); - if (lambda.num_params() == 2) { - ctx.push_expression(std::make_unique<nodes::TensorJoin>(std::move(lhs), std::move(rhs), std::move(lambda))); + Function lambda = parse_lambda(ctx, 2); + ctx.push_expression(std::make_unique<nodes::TensorJoin>(std::move(lhs), std::move(rhs), std::move(lambda))); +} + +void parse_tensor_reduce(ParseContext &ctx) { + parse_expression(ctx); + Node_UP child = ctx.pop_expression(); + ctx.eat(','); + auto aggr_name = get_ident(ctx, false); + auto maybe_aggr = nodes::AggrNames::from_name(aggr_name); + if (!maybe_aggr) { + ctx.fail(make_string("unknown aggregator: '%s'", aggr_name.c_str())); + return; + } + auto dimensions = get_ident_list(ctx, false); + ctx.push_expression(std::make_unique<nodes::TensorReduce>(std::move(child), *maybe_aggr, std::move(dimensions))); +} + +void parse_tensor_rename(ParseContext &ctx) { + parse_expression(ctx); + Node_UP child = ctx.pop_expression(); + ctx.eat(','); + auto from = get_idents(ctx); + ctx.skip_spaces(); + ctx.eat(','); + auto to = get_idents(ctx); + if (from.size() != to.size()) { + ctx.fail("dimension list size mismatch"); } else { - ctx.fail(make_string("join requires a lambda with 2 parameters, was %zu", - lambda.num_params())); + ctx.push_expression(std::make_unique<nodes::TensorRename>(std::move(child), std::move(from), std::move(to))); + } + ctx.skip_spaces(); +} + +void parse_tensor_lambda(ParseContext &ctx) { + vespalib::string type_spec("tensor("); + while(!ctx.eos() && (ctx.get() != ')')) { + type_spec.push_back(ctx.get()); + ctx.next(); } + ctx.eat(')'); + type_spec.push_back(')'); + ValueType type = ValueType::from_spec(type_spec); + if (!check_tensor_lambda_type(type)) { + ctx.fail("invalid tensor type"); + return; + } + auto param_names = type.dimension_names(); + ExplicitParams params(param_names); + ctx.push_resolve_context(params, nullptr); + ctx.skip_spaces(); + ctx.eat('('); + parse_expression(ctx); + ctx.pop_resolve_context(); + Function lambda(ctx.pop_expression(), std::move(param_names)); + ctx.push_expression(std::make_unique<nodes::TensorLambda>(std::move(type), std::move(lambda))); } // to be replaced with more generic 'reduce' @@ -538,7 +636,7 @@ void parse_tensor_sum(ParseContext &ctx) { Node_UP child = ctx.pop_expression(); if (ctx.get() == ',') { ctx.next(); - vespalib::string dimension = get_ident(ctx); + vespalib::string dimension = get_ident(ctx, false); ctx.skip_spaces(); ctx.push_expression(Node_UP(new nodes::TensorSum(std::move(child), dimension))); } else { @@ -562,6 +660,12 @@ bool try_parse_call(ParseContext &ctx, const vespalib::string &name) { parse_tensor_map(ctx); } else if (name == "join") { parse_tensor_join(ctx); + } else if (name == "reduce") { + parse_tensor_reduce(ctx); + } else if (name == "rename") { + parse_tensor_rename(ctx); + } else if (name == "tensor") { + parse_tensor_lambda(ctx); } else if (name == "sum") { parse_tensor_sum(ctx); } else { @@ -586,7 +690,7 @@ int parse_symbol(ParseContext &ctx, vespalib::string &name, ParseContext::InputM void parse_symbol_or_call(ParseContext &ctx) { ParseContext::InputMark before_name = ctx.get_input_mark(); - vespalib::string name = get_ident(ctx); + vespalib::string name = get_ident(ctx, true); if (!try_parse_call(ctx, name)) { int id = parse_symbol(ctx, name, before_name); if (name.empty()) { diff --git a/vespalib/src/vespa/vespalib/eval/interpreted_function.cpp b/vespalib/src/vespa/vespalib/eval/interpreted_function.cpp index a4329f97de2..a726e550382 100644 --- a/vespalib/src/vespa/vespalib/eval/interpreted_function.cpp +++ b/vespalib/src/vespa/vespalib/eval/interpreted_function.cpp @@ -274,6 +274,18 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { // TODO(havardpe): add actual evaluation program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>())); } + virtual void visit(const TensorReduce &) { + // TODO(havardpe): add actual evaluation + program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>())); + } + virtual void visit(const TensorRename &) { + // TODO(havardpe): add actual evaluation + program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>())); + } + virtual void visit(const TensorLambda &) { + // TODO(havardpe): add actual evaluation + program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>())); + } virtual void visit(const Add &) { program.emplace_back(op_binary<operation::Add>); } @@ -462,7 +474,10 @@ InterpretedFunction::detect_issues(const Function &function) bool open(const nodes::Node &) override { return true; } void close(const nodes::Node &node) override { if (nodes::check_type<nodes::TensorMap, - nodes::TensorJoin>(node)) { + nodes::TensorJoin, + nodes::TensorReduce, + nodes::TensorRename, + nodes::TensorLambda>(node)) { issues.push_back(make_string("unsupported node type: %s", getClassName(node).c_str())); } diff --git a/vespalib/src/vespa/vespalib/eval/key_gen.cpp b/vespalib/src/vespa/vespalib/eval/key_gen.cpp index 96774c62f78..dc137bc2060 100644 --- a/vespalib/src/vespa/vespalib/eval/key_gen.cpp +++ b/vespalib/src/vespa/vespalib/eval/key_gen.cpp @@ -34,49 +34,52 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { virtual void visit(const If &node) { add_byte( 7); add_double(node.p_true()); } virtual void visit(const Let &) { add_byte( 8); } virtual void visit(const Error &) { add_byte( 9); } - virtual void visit(const TensorSum &) { add_byte(10); } - virtual void visit(const TensorMap &) { add_byte(11); } - virtual void visit(const TensorJoin &) { add_byte(12); } - virtual void visit(const Add &) { add_byte(13); } - virtual void visit(const Sub &) { add_byte(14); } - virtual void visit(const Mul &) { add_byte(15); } - virtual void visit(const Div &) { add_byte(16); } - virtual void visit(const Pow &) { add_byte(17); } - virtual void visit(const Equal &) { add_byte(18); } - virtual void visit(const NotEqual &) { add_byte(19); } - virtual void visit(const Approx &) { add_byte(20); } - virtual void visit(const Less &) { add_byte(21); } - virtual void visit(const LessEqual &) { add_byte(22); } - virtual void visit(const Greater &) { add_byte(23); } - virtual void visit(const GreaterEqual &) { add_byte(24); } - virtual void visit(const In &) { add_byte(25); } - virtual void visit(const And &) { add_byte(26); } - virtual void visit(const Or &) { add_byte(27); } - virtual void visit(const Cos &) { add_byte(28); } - virtual void visit(const Sin &) { add_byte(29); } - virtual void visit(const Tan &) { add_byte(30); } - virtual void visit(const Cosh &) { add_byte(31); } - virtual void visit(const Sinh &) { add_byte(32); } - virtual void visit(const Tanh &) { add_byte(33); } - virtual void visit(const Acos &) { add_byte(34); } - virtual void visit(const Asin &) { add_byte(35); } - virtual void visit(const Atan &) { add_byte(36); } - virtual void visit(const Exp &) { add_byte(37); } - virtual void visit(const Log10 &) { add_byte(38); } - virtual void visit(const Log &) { add_byte(39); } - virtual void visit(const Sqrt &) { add_byte(40); } - virtual void visit(const Ceil &) { add_byte(41); } - virtual void visit(const Fabs &) { add_byte(42); } - virtual void visit(const Floor &) { add_byte(43); } - virtual void visit(const Atan2 &) { add_byte(44); } - virtual void visit(const Ldexp &) { add_byte(45); } - virtual void visit(const Pow2 &) { add_byte(46); } - virtual void visit(const Fmod &) { add_byte(47); } - virtual void visit(const Min &) { add_byte(48); } - virtual void visit(const Max &) { add_byte(49); } - virtual void visit(const IsNan &) { add_byte(50); } - virtual void visit(const Relu &) { add_byte(51); } - virtual void visit(const Sigmoid &) { add_byte(52); } + virtual void visit(const TensorSum &) { add_byte(10); } // dimensions should be part of key + virtual void visit(const TensorMap &) { add_byte(11); } // lambda should be part of key + virtual void visit(const TensorJoin &) { add_byte(12); } // lambda should be part of key + virtual void visit(const TensorReduce &) { add_byte(13); } // aggr/dimensions should be part of key + virtual void visit(const TensorRename &) { add_byte(14); } // dimensions should be part of key + virtual void visit(const TensorLambda &) { add_byte(15); } // type/lambda should be part of key + virtual void visit(const Add &) { add_byte(20); } + virtual void visit(const Sub &) { add_byte(21); } + virtual void visit(const Mul &) { add_byte(22); } + virtual void visit(const Div &) { add_byte(23); } + virtual void visit(const Pow &) { add_byte(24); } + virtual void visit(const Equal &) { add_byte(25); } + virtual void visit(const NotEqual &) { add_byte(26); } + virtual void visit(const Approx &) { add_byte(27); } + virtual void visit(const Less &) { add_byte(28); } + virtual void visit(const LessEqual &) { add_byte(29); } + virtual void visit(const Greater &) { add_byte(30); } + virtual void visit(const GreaterEqual &) { add_byte(31); } + virtual void visit(const In &) { add_byte(32); } + virtual void visit(const And &) { add_byte(33); } + virtual void visit(const Or &) { add_byte(34); } + virtual void visit(const Cos &) { add_byte(35); } + virtual void visit(const Sin &) { add_byte(36); } + virtual void visit(const Tan &) { add_byte(37); } + virtual void visit(const Cosh &) { add_byte(38); } + virtual void visit(const Sinh &) { add_byte(39); } + virtual void visit(const Tanh &) { add_byte(40); } + virtual void visit(const Acos &) { add_byte(41); } + virtual void visit(const Asin &) { add_byte(42); } + virtual void visit(const Atan &) { add_byte(43); } + virtual void visit(const Exp &) { add_byte(44); } + virtual void visit(const Log10 &) { add_byte(45); } + virtual void visit(const Log &) { add_byte(46); } + virtual void visit(const Sqrt &) { add_byte(47); } + virtual void visit(const Ceil &) { add_byte(48); } + virtual void visit(const Fabs &) { add_byte(49); } + virtual void visit(const Floor &) { add_byte(50); } + virtual void visit(const Atan2 &) { add_byte(51); } + virtual void visit(const Ldexp &) { add_byte(52); } + virtual void visit(const Pow2 &) { add_byte(53); } + virtual void visit(const Fmod &) { add_byte(54); } + virtual void visit(const Min &) { add_byte(55); } + virtual void visit(const Max &) { add_byte(56); } + virtual void visit(const IsNan &) { add_byte(57); } + virtual void visit(const Relu &) { add_byte(58); } + virtual void visit(const Sigmoid &) { add_byte(59); } // traverse virtual bool open(const Node &node) { node.accept(*this); return true; } diff --git a/vespalib/src/vespa/vespalib/eval/llvm/compiled_function.cpp b/vespalib/src/vespa/vespalib/eval/llvm/compiled_function.cpp index 9b120b677e2..e8243c27193 100644 --- a/vespalib/src/vespa/vespalib/eval/llvm/compiled_function.cpp +++ b/vespalib/src/vespa/vespalib/eval/llvm/compiled_function.cpp @@ -59,7 +59,10 @@ CompiledFunction::detect_issues(const Function &function) void close(const nodes::Node &node) override { if (nodes::check_type<nodes::TensorSum, nodes::TensorMap, - nodes::TensorJoin>(node)) { + nodes::TensorJoin, + nodes::TensorReduce, + nodes::TensorRename, + nodes::TensorLambda>(node)) { issues.push_back(make_string("unsupported node type: %s", getClassName(node).c_str())); } diff --git a/vespalib/src/vespa/vespalib/eval/llvm/llvm_wrapper.cpp b/vespalib/src/vespa/vespalib/eval/llvm/llvm_wrapper.cpp index a03a5a9248a..7dfe44644ef 100644 --- a/vespalib/src/vespa/vespalib/eval/llvm/llvm_wrapper.cpp +++ b/vespalib/src/vespa/vespalib/eval/llvm/llvm_wrapper.cpp @@ -356,6 +356,15 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { virtual void visit(const TensorJoin &node) { make_error(node.num_children()); } + virtual void visit(const TensorReduce &node) { + make_error(node.num_children()); + } + virtual void visit(const TensorRename &node) { + make_error(node.num_children()); + } + virtual void visit(const TensorLambda &node) { + make_error(node.num_children()); + } // operator nodes diff --git a/vespalib/src/vespa/vespalib/eval/node_types.cpp b/vespalib/src/vespa/vespalib/eval/node_types.cpp index faf1e085b1b..194cbeac1d5 100644 --- a/vespalib/src/vespa/vespalib/eval/node_types.cpp +++ b/vespalib/src/vespa/vespalib/eval/node_types.cpp @@ -169,6 +169,15 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { } virtual void visit(const TensorMap &node) { resolve_op1(node); } virtual void visit(const TensorJoin &node) { resolve_op2(node); } + virtual void visit(const TensorReduce &node) { + bind_type(ValueType::error_type(), node); + } + virtual void visit(const TensorRename &node) { + bind_type(ValueType::error_type(), node); + } + virtual void visit(const TensorLambda &node) { + bind_type(node.type(), node); + } virtual void visit(const Add &node) { resolve_op2(node); } virtual void visit(const Sub &node) { resolve_op2(node); } diff --git a/vespalib/src/vespa/vespalib/eval/node_visitor.h b/vespalib/src/vespa/vespalib/eval/node_visitor.h index d7fa5feff77..c890ffcd034 100644 --- a/vespalib/src/vespa/vespalib/eval/node_visitor.h +++ b/vespalib/src/vespa/vespalib/eval/node_visitor.h @@ -33,6 +33,9 @@ struct NodeVisitor { virtual void visit(const nodes::TensorSum &) = 0; virtual void visit(const nodes::TensorMap &) = 0; virtual void visit(const nodes::TensorJoin &) = 0; + virtual void visit(const nodes::TensorReduce &) = 0; + virtual void visit(const nodes::TensorRename &) = 0; + virtual void visit(const nodes::TensorLambda &) = 0; // operator nodes virtual void visit(const nodes::Add &) = 0; @@ -98,6 +101,9 @@ struct EmptyNodeVisitor : NodeVisitor { virtual void visit(const nodes::TensorSum &) {} virtual void visit(const nodes::TensorMap &) {} virtual void visit(const nodes::TensorJoin &) {} + virtual void visit(const nodes::TensorReduce &) {} + virtual void visit(const nodes::TensorRename &) {} + virtual void visit(const nodes::TensorLambda &) {} virtual void visit(const nodes::Add &) {} virtual void visit(const nodes::Sub &) {} virtual void visit(const nodes::Mul &) {} diff --git a/vespalib/src/vespa/vespalib/eval/tensor_nodes.cpp b/vespalib/src/vespa/vespalib/eval/tensor_nodes.cpp index a63e3e37347..beeb7583e0b 100644 --- a/vespalib/src/vespa/vespalib/eval/tensor_nodes.cpp +++ b/vespalib/src/vespa/vespalib/eval/tensor_nodes.cpp @@ -8,9 +8,55 @@ namespace vespalib { namespace eval { namespace nodes { -void TensorSum ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorJoin::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorSum ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorReduce::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorRename::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorLambda::accept(NodeVisitor &visitor) const { visitor.visit(*this); } + +const AggrNames AggrNames::_instance; + +void +AggrNames::add(Aggr aggr, const vespalib::string &name) +{ + _name_aggr_map[name] = aggr; + _aggr_name_map[aggr] = name; +} + +AggrNames::AggrNames() + : _name_aggr_map(), + _aggr_name_map() +{ + add(Aggr::AVG, "avg"); + add(Aggr::COUNT, "count"); + add(Aggr::PROD, "prod"); + add(Aggr::SUM, "sum"); + add(Aggr::MAX, "max"); + add(Aggr::MIN, "min"); +} + +const vespalib::string * +AggrNames::name_of(Aggr aggr) +{ + const auto &map = _instance._aggr_name_map; + auto result = map.find(aggr); + if (result == map.end()) { + return nullptr; + } + return &(result->second); +} + +const Aggr * +AggrNames::from_name(const vespalib::string &name) +{ + const auto &map = _instance._name_aggr_map; + auto result = map.find(name); + if (result == map.end()) { + return nullptr; + } + return &(result->second); +} } // namespace vespalib::eval::nodes } // namespace vespalib::eval diff --git a/vespalib/src/vespa/vespalib/eval/tensor_nodes.h b/vespalib/src/vespa/vespalib/eval/tensor_nodes.h index a8368c7eec3..28fbfc796c6 100644 --- a/vespalib/src/vespa/vespalib/eval/tensor_nodes.h +++ b/vespalib/src/vespa/vespalib/eval/tensor_nodes.h @@ -20,7 +20,7 @@ public: TensorSum(Node_UP child, const vespalib::string &dimension_in) : _child(std::move(child)), _dimension(dimension_in) {} const vespalib::string &dimension() const { return _dimension; } - virtual vespalib::string dump(DumpContext &ctx) const { + vespalib::string dump(DumpContext &ctx) const override { vespalib::string str; str += "sum("; str += _child->dump(ctx); @@ -31,14 +31,14 @@ public: str += ")"; return str; } - virtual void accept(NodeVisitor &visitor) const; - virtual size_t num_children() const { return 1; } - virtual const Node &get_child(size_t idx) const { + void accept(NodeVisitor &visitor) const override; + size_t num_children() const override { return 1; } + const Node &get_child(size_t idx) const override { (void) idx; assert(idx == 0); return *_child; } - virtual void detach_children(NodeHandler &handler) { + void detach_children(NodeHandler &handler) override { handler.handle(std::move(_child)); } }; @@ -51,7 +51,7 @@ public: TensorMap(Node_UP child, Function lambda) : _child(std::move(child)), _lambda(std::move(lambda)) {} const Function &lambda() const { return _lambda; } - virtual vespalib::string dump(DumpContext &ctx) const { + vespalib::string dump(DumpContext &ctx) const override { vespalib::string str; str += "map("; str += _child->dump(ctx); @@ -60,14 +60,14 @@ public: str += ")"; return str; } - virtual void accept(NodeVisitor &visitor) const; - virtual size_t num_children() const { return 1; } - virtual const Node &get_child(size_t idx) const { + void accept(NodeVisitor &visitor) const override; + size_t num_children() const override { return 1; } + const Node &get_child(size_t idx) const override { (void) idx; assert(idx == 0); return *_child; } - virtual void detach_children(NodeHandler &handler) { + void detach_children(NodeHandler &handler) override { handler.handle(std::move(_child)); } }; @@ -81,7 +81,7 @@ public: TensorJoin(Node_UP lhs, Node_UP rhs, Function lambda) : _lhs(std::move(lhs)), _rhs(std::move(rhs)), _lambda(std::move(lambda)) {} const Function &lambda() const { return _lambda; } - virtual vespalib::string dump(DumpContext &ctx) const { + vespalib::string dump(DumpContext &ctx) const override { vespalib::string str; str += "join("; str += _lhs->dump(ctx); @@ -92,18 +92,135 @@ public: str += ")"; return str; } - virtual void accept(NodeVisitor &visitor) const; - virtual size_t num_children() const { return 2; } - virtual const Node &get_child(size_t idx) const { + void accept(NodeVisitor &visitor) const override ; + size_t num_children() const override { return 2; } + const Node &get_child(size_t idx) const override { assert(idx < 2); return (idx == 0) ? *_lhs : *_rhs; } - virtual void detach_children(NodeHandler &handler) { + void detach_children(NodeHandler &handler) override { handler.handle(std::move(_lhs)); handler.handle(std::move(_rhs)); } }; +enum class Aggr { AVG, COUNT, PROD, SUM, MAX, MIN }; +class AggrNames { +private: + static const AggrNames _instance; + std::map<vespalib::string,Aggr> _name_aggr_map; + std::map<Aggr,vespalib::string> _aggr_name_map; + void add(Aggr aggr, const vespalib::string &name); + AggrNames(); +public: + static const vespalib::string *name_of(Aggr aggr); + static const Aggr *from_name(const vespalib::string &name); +}; + +class TensorReduce : public Node { +private: + Node_UP _child; + Aggr _aggr; + std::vector<vespalib::string> _dimensions; +public: + TensorReduce(Node_UP child, Aggr aggr_in, std::vector<vespalib::string> dimensions_in) + : _child(std::move(child)), _aggr(aggr_in), _dimensions(std::move(dimensions_in)) {} + const std::vector<vespalib::string> &dimensions() const { return _dimensions; } + Aggr aggr() const { return _aggr; } + vespalib::string dump(DumpContext &ctx) const override { + vespalib::string str; + str += "reduce("; + str += _child->dump(ctx); + str += ","; + str += *AggrNames::name_of(_aggr); + for (const auto &dimension: _dimensions) { + str += ","; + str += dimension; + } + str += ")"; + return str; + } + void accept(NodeVisitor &visitor) const override; + size_t num_children() const override { return 1; } + const Node &get_child(size_t idx) const override { + assert(idx == 0); + return *_child; + } + void detach_children(NodeHandler &handler) override { + handler.handle(std::move(_child)); + } +}; + +class TensorRename : public Node { +private: + Node_UP _child; + std::vector<vespalib::string> _from; + std::vector<vespalib::string> _to; + static vespalib::string flatten(const std::vector<vespalib::string> &list) { + if (list.size() == 1) { + return list[0]; + } + vespalib::string str = "("; + for (size_t i = 0; i < list.size(); ++i) { + if (i > 0) { + str += ","; + } + str += list[i]; + } + str += ")"; + return str; + } +public: + TensorRename(Node_UP child, std::vector<vespalib::string> from_in, std::vector<vespalib::string> to_in) + : _child(std::move(child)), _from(std::move(from_in)), _to(std::move(to_in)) {} + const std::vector<vespalib::string> &from() const { return _from; } + const std::vector<vespalib::string> &to() const { return _to; } + vespalib::string dump(DumpContext &ctx) const override { + vespalib::string str; + str += "rename("; + str += _child->dump(ctx); + str += ","; + str += flatten(_from); + str += ","; + str += flatten(_to); + str += ")"; + return str; + } + void accept(NodeVisitor &visitor) const override; + size_t num_children() const override { return 1; } + const Node &get_child(size_t idx) const override { + assert(idx == 0); + return *_child; + } + void detach_children(NodeHandler &handler) override { + handler.handle(std::move(_child)); + } +}; + +class TensorLambda : public Leaf { +private: + ValueType _type; + Function _lambda; +public: + TensorLambda(ValueType type_in, Function lambda) + : _type(std::move(type_in)), _lambda(std::move(lambda)) {} + const ValueType &type() const { return _type; } + const Function &lambda() const { return _lambda; } + vespalib::string dump(DumpContext &) const override { + vespalib::string str = _type.to_spec(); + vespalib::string expr = _lambda.dump(); + if (starts_with(expr, "(")) { + str += expr; + } else { + str += "("; + str += expr; + str += ")"; + } + return str; + } + void accept(NodeVisitor &visitor) const override; +}; + } // namespace vespalib::eval::nodes } // namespace vespalib::eval } // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/eval/test/eval_spec.cpp b/vespalib/src/vespa/vespalib/eval/test/eval_spec.cpp index f132cc7acbc..cd48aa4881f 100644 --- a/vespalib/src/vespa/vespalib/eval/test/eval_spec.cpp +++ b/vespalib/src/vespa/vespalib/eval/test/eval_spec.cpp @@ -10,9 +10,10 @@ namespace vespalib { namespace eval { namespace test { -const double my_nan = std::numeric_limits<double>::quiet_NaN(); -const double my_inf = std::numeric_limits<double>::infinity(); - +constexpr double my_nan = std::numeric_limits<double>::quiet_NaN(); +constexpr double my_inf = std::numeric_limits<double>::infinity(); +constexpr double my_error_value = 31212.0; + vespalib::string EvalSpec::EvalTest::as_string(const std::vector<vespalib::string> ¶m_names, const std::vector<double> ¶m_values, @@ -112,7 +113,6 @@ EvalSpec::add_function_call_cases() { .add_case({my_nan}, 1.0).add_case({my_inf}, 0.0).add_case({-my_inf}, 0.0); add_rule({"a", -1.0, 1.0}, "relu(a)", [](double a){ return std::max(a, 0.0); }); add_rule({"a", -1.0, 1.0}, "sigmoid(a)", [](double a){ return 1.0 / (1.0 + std::exp(-1.0 * a)); }); - add_rule({"a", -1.0, 1.0}, "sum(a)", [](double a){ return a; }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "atan2(a,b)", [](double a, double b){ return std::atan2(a, b); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "ldexp(a,b)", [](double a, double b){ return std::ldexp(a, b); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "pow(a,b)", [](double a, double b){ return std::pow(a, b); }); @@ -122,6 +122,22 @@ EvalSpec::add_function_call_cases() { } void +EvalSpec::add_tensor_operation_cases() { + add_rule({"a", -1.0, 1.0}, "sum(a)", [](double a){ return a; }); + add_rule({"a", -1.0, 1.0}, "map(a,f(x)(sin(x)))", [](double x){ return std::sin(x); }); + add_rule({"a", -1.0, 1.0}, "map(a,f(x)(x+x*3))", [](double x){ return (x + (x * 3)); }); + add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); }); + add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y*3))", [](double x, double y){ return (x + (y * 3)); }); + add_rule({"a", -1.0, 1.0}, "reduce(a,sum)", [](double a){ return a; }); + add_rule({"a", -1.0, 1.0}, "reduce(a,prod)", [](double a){ return a; }); + add_rule({"a", -1.0, 1.0}, "reduce(a,count)", [](double){ return 1.0; }); + add_rule({"a", -1.0, 1.0}, "rename(a,x,y)", [](double){ return my_error_value; }); + add_rule({"a", -1.0, 1.0}, "rename(a,(x,y),(y,x))", [](double){ return my_error_value; }); + add_expression({}, "tensor(x[10])(x)"); + add_expression({}, "tensor(x[10],y[10])(x==y)"); +} + +void EvalSpec::add_comparison_cases() { add_expression({"a", "b"}, "(a==b)") .add_case({my_nan, 2.0}, 0.0) diff --git a/vespalib/src/vespa/vespalib/eval/test/eval_spec.h b/vespalib/src/vespa/vespalib/eval/test/eval_spec.h index af1821c1e8d..582c3b1c1e5 100644 --- a/vespalib/src/vespa/vespalib/eval/test/eval_spec.h +++ b/vespalib/src/vespa/vespalib/eval/test/eval_spec.h @@ -120,20 +120,22 @@ public: virtual ~EvalTest() {} }; //------------------------------------------------------------------------- - void add_terminal_cases(); // a, 1.0 - void add_arithmetic_cases(); // a + b, a ^ b - void add_function_call_cases(); // cos(a), max(a, b) - void add_comparison_cases(); // a < b, c != d - void add_set_membership_cases(); // a in [x, y, z] - void add_boolean_cases(); // 1.0 && 0.0 - void add_if_cases(); // if (a < b, a, b) - void add_let_cases(); // let (a, b + 1, a * a) - void add_complex_cases(); // ... + void add_terminal_cases(); // a, 1.0 + void add_arithmetic_cases(); // a + b, a ^ b + void add_function_call_cases(); // cos(a), max(a, b) + void add_tensor_operation_cases(); // map(a,f(x)(sin(x))) + void add_comparison_cases(); // a < b, c != d + void add_set_membership_cases(); // a in [x, y, z] + void add_boolean_cases(); // 1.0 && 0.0 + void add_if_cases(); // if (a < b, a, b) + void add_let_cases(); // let (a, b + 1, a * a) + void add_complex_cases(); // ... //------------------------------------------------------------------------- void add_all_cases() { add_terminal_cases(); add_arithmetic_cases(); add_function_call_cases(); + add_tensor_operation_cases(); add_comparison_cases(); add_set_membership_cases(); add_boolean_cases(); |