diff options
author | Håvard Pettersen <havardpe@oath.com> | 2019-12-03 15:36:48 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2019-12-04 09:56:58 +0000 |
commit | 500c3d014b5a0a60008df8760592cbe19e566f09 (patch) | |
tree | 29f94ac4a49c6d148e9d29bb03d0f210230d8924 /eval | |
parent | 9ef0b57cf1e3b7dc2035ff710946b95509bef701 (diff) |
tensor lambda is now syntactic sugar for tensor create
perform constant-value folding for tensor create
Diffstat (limited to 'eval')
5 files changed, 125 insertions, 52 deletions
diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp index d93b57ab5b2..6932ba46eab 100644 --- a/eval/src/tests/eval/function/function_test.cpp +++ b/eval/src/tests/eval/function/function_test.cpp @@ -808,9 +808,9 @@ TEST("require that tensor rename dimension lists must have equal size") { //----------------------------------------------------------------------------- TEST("require that tensor lambda can be parsed") { - EXPECT_EQUAL("tensor(x[10])(x)", Function::parse({}, "tensor(x[10])(x)").dump()); - EXPECT_EQUAL("tensor(x[10],y[10])(x==y)", Function::parse({}, "tensor(x[10],y[10])(x==y)").dump()); - EXPECT_EQUAL("tensor(x[10],y[10])(x==y)", Function::parse({}, " tensor ( x [ 10 ] , y [ 10 ] ) ( x == y ) ").dump()); + EXPECT_EQUAL("tensor(x[3]):{{x:0}:0,{x:1}:1,{x:2}:2}", Function::parse({}, "tensor(x[3])(x)").dump()); + EXPECT_EQUAL("tensor(x[2],y[2]):{{x:0,y:0}:(0==0),{x:0,y:1}:(0==1),{x:1,y:0}:(1==0),{x:1,y:1}:(1==1)}", + Function::parse({}, " tensor ( x [ 2 ] , y [ 2 ] ) ( x == y ) ").dump()); } TEST("require that tensor lambda requires appropriate tensor type") { @@ -819,8 +819,9 @@ TEST("require that tensor lambda requires appropriate tensor type") { verify_error("tensor()(x==y)", "[tensor()]...[invalid tensor type]...[(x==y)]"); } -TEST("require that tensor lambda can only use dimension names") { - verify_error("tensor(x[10],y[10])(x==z)", "[tensor(x[10],y[10])(x==z]...[unknown symbol: 'z']...[)]"); +TEST("require that tensor lambda can use non-dimension symbols") { + EXPECT_EQUAL("tensor(x[2]):{{x:0}:(0==a),{x:1}:(1==a)}", + Function::parse({"a"}, "tensor(x[2])(x==a)").dump()); } //----------------------------------------------------------------------------- @@ -961,6 +962,16 @@ TEST("require that empty tensor peek is not allowed") { //----------------------------------------------------------------------------- +TEST("require that nested tensor lambda using tensor peek can be parsed") { + vespalib::string expect("tensor(x[2]):{{x:0}:tensor(y[2]):{{y:0}:((0+0)+a),{y:1}:((0+1)+a)}{y:(0)}," + "{x:1}:tensor(y[2]):{{y:0}:((1+0)+a),{y:1}:((1+1)+a)}{y:(1)}}"); + EXPECT_EQUAL(Function::parse(expect).dump(), expect); + auto fun = Function::parse("tensor(x[2])(tensor(y[2])(x+y+a){y:(x)})"); + EXPECT_EQUAL(fun.dump(), expect); +} + +//----------------------------------------------------------------------------- + TEST("require that tensor concat can be parsed") { EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, "concat(a,b,d)").dump()); EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, " concat ( a , b , d ) ").dump()); diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp index 9112a8b1712..c9108ee74ce 100644 --- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp +++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp @@ -151,24 +151,22 @@ TEST("require that basic addition works") { TEST("require that functions with non-compilable lambdas cannot be interpreted") { auto good_map = Function::parse("map(a,f(x)(x+1))"); auto good_join = Function::parse("join(a,b,f(x,y)(x+y))"); - auto good_tensor = Function::parse("tensor(a[10],b[10])(a+b)"); auto bad_map = Function::parse("map(a,f(x)(map(x,f(i)(i+1))))"); auto bad_join = Function::parse("join(a,b,f(x,y)(join(x,y,f(i,j)(i+j))))"); - auto bad_tensor = Function::parse("tensor(a[10],b[10])(join(a,b,f(i,j)(i+j)))"); - for (const Function *good: {&good_map, &good_join, &good_tensor}) { + for (const Function *good: {&good_map, &good_join}) { if (!EXPECT_TRUE(!good->has_error())) { fprintf(stderr, "parse error: %s\n", good->get_error().c_str()); } EXPECT_TRUE(!InterpretedFunction::detect_issues(*good)); } - for (const Function *bad: {&bad_map, &bad_join, &bad_tensor}) { + for (const Function *bad: {&bad_map, &bad_join}) { if (!EXPECT_TRUE(!bad->has_error())) { fprintf(stderr, "parse error: %s\n", bad->get_error().c_str()); } EXPECT_TRUE(InterpretedFunction::detect_issues(*bad)); } std::cerr << "Example function issues:" << std::endl - << InterpretedFunction::detect_issues(bad_tensor).list + << InterpretedFunction::detect_issues(bad_join).list << std::endl; } diff --git a/eval/src/tests/tensor/dense_tensor_create_function/dense_tensor_create_function_test.cpp b/eval/src/tests/tensor/dense_tensor_create_function/dense_tensor_create_function_test.cpp index 0fa09e6c46e..6ae0681d2ff 100644 --- a/eval/src/tests/tensor/dense_tensor_create_function/dense_tensor_create_function_test.cpp +++ b/eval/src/tests/tensor/dense_tensor_create_function/dense_tensor_create_function_test.cpp @@ -43,7 +43,7 @@ void verify(const vespalib::string &expr, size_t expect_optimized_cnt, size_t ex //----------------------------------------------------------------------------- TEST("require that tensor create can be optimized") { - TEST_DO(verify("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", 1, 0)); + TEST_DO(verify("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", 0, 0)); // NB: const value TEST_DO(verify("tensor(x[3]):{{x:0}:a,{x:1}:b,{x:2}:c}", 1, 0)); TEST_DO(verify("tensor<float>(x[3]):{{x:0}:a,{x:1}:b,{x:2}:c}", 1, 0)); TEST_DO(verify("tensor(x[3]):{{x:0}:a+b,{x:1}:b-c,{x:2}:c*a}", 1, 0)); diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp index 56899debc26..a7eaf14e9cd 100644 --- a/eval/src/vespa/eval/eval/function.cpp +++ b/eval/src/vespa/eval/eval/function.cpp @@ -10,6 +10,7 @@ #include <vespa/vespalib/locale/c.h> #include <cctype> #include <map> +#include <set> namespace vespalib::eval { @@ -30,17 +31,28 @@ bool has_duplicates(const std::vector<vespalib::string> &list) { return false; } +bool step_labels(std::vector<size_t> &labels, const ValueType &type) { + for (size_t idx = labels.size(); idx-- > 0; ) { + if (++labels[idx] < type.dimensions()[idx].size) { + return true; + } else { + labels[idx] = 0; + } + } + return false; +} + //----------------------------------------------------------------------------- class Params { private: std::map<vespalib::string,size_t> _params; protected: - size_t lookup(vespalib::stringref token) const { + size_t lookup(const vespalib::string &token) const { auto result = _params.find(token); return (result == _params.end()) ? UNDEF : result->second; } - size_t lookup_add(vespalib::stringref token) { + size_t lookup_add(const vespalib::string &token) { size_t result = lookup(token); if (result == UNDEF) { result = _params.size(); @@ -51,7 +63,7 @@ protected: public: static const size_t UNDEF = -1; virtual bool implicit() const = 0; - virtual size_t resolve(vespalib::stringref token) const = 0; + virtual size_t resolve(const vespalib::string &token) const = 0; std::vector<vespalib::string> extract() const { std::vector<vespalib::string> params_out; params_out.resize(_params.size()); @@ -71,30 +83,27 @@ struct ExplicitParams : Params { } } bool implicit() const override { return false; } - size_t resolve(vespalib::stringref token) const override { + size_t resolve(const vespalib::string &token) const override { return lookup(token); } }; struct ImplicitParams : Params { bool implicit() const override { return true; } - size_t resolve(vespalib::stringref token) const override { + size_t resolve(const vespalib::string &token) const override { return const_cast<ImplicitParams*>(this)->lookup_add(token); } }; //----------------------------------------------------------------------------- -class ResolveContext -{ -private: - const Params &_params; - const SymbolExtractor *_symbol_extractor; -public: - ResolveContext(const Params ¶ms, const SymbolExtractor *symbol_extractor) - : _params(params), _symbol_extractor(symbol_extractor) {} - size_t resolve_param(const vespalib::string &name) const { return _params.resolve(name); } - const SymbolExtractor *symbol_extractor() const { return _symbol_extractor; } +struct ResolveContext { + const Params ¶ms; + const SymbolExtractor *symbol_extractor; + const std::map<vespalib::string,size_t*> *aliases; + ResolveContext(const Params ¶ms_in, const SymbolExtractor *symbol_extractor_in, + const std::map<vespalib::string,size_t*> *aliases_in) + : params(params_in), symbol_extractor(symbol_extractor_in), aliases(aliases_in) {} }; class ParseContext @@ -118,7 +127,7 @@ public: _scratch(), _failure(), _expression_stack(), _operator_stack(), _operator_mark(0), - _resolve_stack({ResolveContext(params, symbol_extractor)}) + _resolve_stack({ResolveContext(params, symbol_extractor, nullptr)}) { if (_pos < _end) { _curr = *_pos; @@ -141,13 +150,34 @@ public: return _resolve_stack.back(); } - void push_resolve_context(const Params ¶ms, const SymbolExtractor *symbol_extractor) { - _resolve_stack.emplace_back(params, symbol_extractor); + void push_resolve_context(const Params ¶ms) { + _resolve_stack.emplace_back(params, nullptr, nullptr); + } + + void push_resolve_context(const std::map<vespalib::string,size_t*> &aliases) { + assert(!_resolve_stack.empty()); + _resolve_stack.emplace_back(resolver().params, resolver().symbol_extractor, &aliases); } void pop_resolve_context() { assert(!_resolve_stack.empty()); _resolve_stack.pop_back(); + assert(!_resolve_stack.empty()); + } + + bool has_alias(const vespalib::string &ident) const { + if (auto aliases = resolver().aliases) { + return (aliases->find(ident) != aliases->end()); + } + return false; + } + + size_t get_alias_value(const vespalib::string &ident) const { + auto aliases = resolver().aliases; + assert(aliases); + auto pos = aliases->find(ident); + assert(pos != aliases->end()); + return *(pos->second); } void fail(const vespalib::string &msg) { @@ -209,19 +239,18 @@ public: } size_t resolve_parameter(const vespalib::string &name) const { - return resolver().resolve_param(name); + return resolver().params.resolve(name); } void extract_symbol(vespalib::string &symbol_out, InputMark before_symbol) { - const SymbolExtractor *symbol_extractor = resolver().symbol_extractor(); - if (symbol_extractor == nullptr) { + if (resolver().symbol_extractor == nullptr) { return; } symbol_out.clear(); restore_input_mark(before_symbol); if (!eos()) { const char *new_pos = nullptr; - symbol_extractor->extract_symbol(_pos, _end, new_pos, symbol_out); + resolver().symbol_extractor->extract_symbol(_pos, _end, new_pos, symbol_out); if ((new_pos != nullptr) && (new_pos > _pos) && (new_pos <= _end)) { _pos = new_pos; _curr = (_pos < _end) ? *_pos : 0; @@ -232,7 +261,7 @@ public: } Node_UP get_result() { - if (!eos() || (num_expressions() != 1) || (num_operators() > 0)) { + if (!eos() || (num_expressions() != 1) || (num_operators() > 0) || (_resolve_stack.size() != 1)) { fail("incomplete parse"); } if (!_failure.empty()) { @@ -515,7 +544,7 @@ Function parse_lambda(ParseContext &ctx, size_t num_params) { ctx.eat('f'); auto param_names = get_ident_list(ctx, true); ExplicitParams params(param_names); - ctx.push_resolve_context(params, nullptr); + ctx.push_resolve_context(params); ctx.skip_spaces(); ctx.eat('('); Node_UP lambda_root = get_expression(ctx); @@ -694,15 +723,30 @@ void parse_tensor_create(ParseContext &ctx, const ValueType &type, } void parse_tensor_lambda(ParseContext &ctx, const ValueType &type) { - auto param_names = type.dimension_names(); - ExplicitParams params(param_names); - ctx.push_resolve_context(params, nullptr); ctx.skip_spaces(); ctx.eat('('); - Function lambda(get_expression(ctx), std::move(param_names)); + ParseContext::InputMark before_expr = ctx.get_input_mark(); + std::vector<size_t> params(type.dimensions().size(), 0); + std::map<vespalib::string,size_t*> my_aliases; + if (auto parent_aliases = ctx.resolver().aliases) { + my_aliases = *parent_aliases; + } + for (size_t i = 0; i < params.size(); ++i) { + my_aliases.emplace(type.dimensions()[i].name, ¶ms[i]); + } + ctx.push_resolve_context(my_aliases); + nodes::TensorCreate::Spec create_spec; + do { + ctx.restore_input_mark(before_expr); + TensorSpec::Address address; + for (size_t i = 0; i < params.size(); ++i) { + address.emplace(type.dimensions()[i].name, params[i]); + } + create_spec.emplace(std::move(address), get_expression(ctx)); + } while (!ctx.failed() && step_labels(params, type)); ctx.eat(')'); ctx.pop_resolve_context(); - ctx.push_expression(std::make_unique<nodes::TensorLambda>(std::move(type), std::move(lambda))); + ctx.push_expression(std::make_unique<nodes::TensorCreate>(type, std::move(create_spec))); } bool maybe_parse_tensor_generator(ParseContext &ctx) { @@ -806,23 +850,25 @@ bool maybe_parse_call(ParseContext &ctx, const vespalib::string &name) { return false; } -size_t parse_symbol(ParseContext &ctx, vespalib::string &name, ParseContext::InputMark before_name) { - ctx.extract_symbol(name, before_name); - return ctx.resolve_parameter(name); -} - void parse_symbol_or_call(ParseContext &ctx) { ParseContext::InputMark before_name = ctx.get_input_mark(); vespalib::string name = get_ident(ctx, true); bool was_tensor_generate = ((name == "tensor") && maybe_parse_tensor_generator(ctx)); if (!was_tensor_generate && !maybe_parse_call(ctx, name)) { - size_t id = parse_symbol(ctx, name, before_name); - if (name.empty()) { - ctx.fail("missing value"); - } else if (id == Params::UNDEF) { - ctx.fail(make_string("unknown symbol: '%s'", name.c_str())); + ctx.extract_symbol(name, before_name); + if (ctx.has_alias(name)) { + ctx.push_expression(Node_UP(new nodes::Number(ctx.get_alias_value(name)))); } else { - ctx.push_expression(Node_UP(new nodes::Symbol(id))); + if (name.empty()) { + ctx.fail("missing value"); + } else { + size_t id = ctx.resolve_parameter(name); + if (id == Params::UNDEF) { + ctx.fail(make_string("unknown symbol: '%s'", name.c_str())); + } else { + ctx.push_expression(Node_UP(new nodes::Symbol(id))); + } + } } } } diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index 3b74a7a0e23..0087bf5efe7 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -97,6 +97,24 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { stack.back() = tensor_function::concat(a, b, dimension, stash); } + bool maybe_make_const(const Node &node) { + if (auto create = as<TensorCreate>(node)) { + bool is_const = true; + for (size_t i = 0; i < create->num_children(); ++i) { + is_const &= create->get_child(i).is_const(); + } + if (is_const) { + TensorSpec spec(create->type().to_spec()); + for (size_t i = 0; i < create->num_children(); ++i) { + spec.add(create->get_child_address(i), create->get_child(i).get_const_value()); + } + make_const(node, *stash.create<Value::UP>(tensor_engine.from_spec(spec))); + return true; + } + } + return false; + } + void make_create(const TensorCreate &node) { assert(stack.size() >= node.num_children()); std::map<TensorSpec::Address, tensor_function::Node::CREF> spec; @@ -348,7 +366,7 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { //------------------------------------------------------------------------- - bool open(const Node &) override { return true; } + bool open(const Node &node) override { return !maybe_make_const(node); } void close(const Node &node) override { node.accept(*this); } }; |