aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2019-12-03 15:36:48 +0000
committerHåvard Pettersen <havardpe@oath.com>2019-12-04 09:56:58 +0000
commit500c3d014b5a0a60008df8760592cbe19e566f09 (patch)
tree29f94ac4a49c6d148e9d29bb03d0f210230d8924 /eval
parent9ef0b57cf1e3b7dc2035ff710946b95509bef701 (diff)
tensor lambda is now syntactic sugar for tensor create
perform constant-value folding for tensor create
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/function/function_test.cpp21
-rw-r--r--eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp8
-rw-r--r--eval/src/tests/tensor/dense_tensor_create_function/dense_tensor_create_function_test.cpp2
-rw-r--r--eval/src/vespa/eval/eval/function.cpp126
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp20
5 files changed, 125 insertions, 52 deletions
diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp
index d93b57ab5b2..6932ba46eab 100644
--- a/eval/src/tests/eval/function/function_test.cpp
+++ b/eval/src/tests/eval/function/function_test.cpp
@@ -808,9 +808,9 @@ TEST("require that tensor rename dimension lists must have equal size") {
//-----------------------------------------------------------------------------
TEST("require that tensor lambda can be parsed") {
- EXPECT_EQUAL("tensor(x[10])(x)", Function::parse({}, "tensor(x[10])(x)").dump());
- EXPECT_EQUAL("tensor(x[10],y[10])(x==y)", Function::parse({}, "tensor(x[10],y[10])(x==y)").dump());
- EXPECT_EQUAL("tensor(x[10],y[10])(x==y)", Function::parse({}, " tensor ( x [ 10 ] , y [ 10 ] ) ( x == y ) ").dump());
+ EXPECT_EQUAL("tensor(x[3]):{{x:0}:0,{x:1}:1,{x:2}:2}", Function::parse({}, "tensor(x[3])(x)").dump());
+ EXPECT_EQUAL("tensor(x[2],y[2]):{{x:0,y:0}:(0==0),{x:0,y:1}:(0==1),{x:1,y:0}:(1==0),{x:1,y:1}:(1==1)}",
+ Function::parse({}, " tensor ( x [ 2 ] , y [ 2 ] ) ( x == y ) ").dump());
}
TEST("require that tensor lambda requires appropriate tensor type") {
@@ -819,8 +819,9 @@ TEST("require that tensor lambda requires appropriate tensor type") {
verify_error("tensor()(x==y)", "[tensor()]...[invalid tensor type]...[(x==y)]");
}
-TEST("require that tensor lambda can only use dimension names") {
- verify_error("tensor(x[10],y[10])(x==z)", "[tensor(x[10],y[10])(x==z]...[unknown symbol: 'z']...[)]");
+TEST("require that tensor lambda can use non-dimension symbols") {
+ EXPECT_EQUAL("tensor(x[2]):{{x:0}:(0==a),{x:1}:(1==a)}",
+ Function::parse({"a"}, "tensor(x[2])(x==a)").dump());
}
//-----------------------------------------------------------------------------
@@ -961,6 +962,16 @@ TEST("require that empty tensor peek is not allowed") {
//-----------------------------------------------------------------------------
+TEST("require that nested tensor lambda using tensor peek can be parsed") {
+ vespalib::string expect("tensor(x[2]):{{x:0}:tensor(y[2]):{{y:0}:((0+0)+a),{y:1}:((0+1)+a)}{y:(0)},"
+ "{x:1}:tensor(y[2]):{{y:0}:((1+0)+a),{y:1}:((1+1)+a)}{y:(1)}}");
+ EXPECT_EQUAL(Function::parse(expect).dump(), expect);
+ auto fun = Function::parse("tensor(x[2])(tensor(y[2])(x+y+a){y:(x)})");
+ EXPECT_EQUAL(fun.dump(), expect);
+}
+
+//-----------------------------------------------------------------------------
+
TEST("require that tensor concat can be parsed") {
EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, "concat(a,b,d)").dump());
EXPECT_EQUAL("concat(a,b,d)", Function::parse({"a", "b"}, " concat ( a , b , d ) ").dump());
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index 9112a8b1712..c9108ee74ce 100644
--- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -151,24 +151,22 @@ TEST("require that basic addition works") {
TEST("require that functions with non-compilable lambdas cannot be interpreted") {
auto good_map = Function::parse("map(a,f(x)(x+1))");
auto good_join = Function::parse("join(a,b,f(x,y)(x+y))");
- auto good_tensor = Function::parse("tensor(a[10],b[10])(a+b)");
auto bad_map = Function::parse("map(a,f(x)(map(x,f(i)(i+1))))");
auto bad_join = Function::parse("join(a,b,f(x,y)(join(x,y,f(i,j)(i+j))))");
- auto bad_tensor = Function::parse("tensor(a[10],b[10])(join(a,b,f(i,j)(i+j)))");
- for (const Function *good: {&good_map, &good_join, &good_tensor}) {
+ for (const Function *good: {&good_map, &good_join}) {
if (!EXPECT_TRUE(!good->has_error())) {
fprintf(stderr, "parse error: %s\n", good->get_error().c_str());
}
EXPECT_TRUE(!InterpretedFunction::detect_issues(*good));
}
- for (const Function *bad: {&bad_map, &bad_join, &bad_tensor}) {
+ for (const Function *bad: {&bad_map, &bad_join}) {
if (!EXPECT_TRUE(!bad->has_error())) {
fprintf(stderr, "parse error: %s\n", bad->get_error().c_str());
}
EXPECT_TRUE(InterpretedFunction::detect_issues(*bad));
}
std::cerr << "Example function issues:" << std::endl
- << InterpretedFunction::detect_issues(bad_tensor).list
+ << InterpretedFunction::detect_issues(bad_join).list
<< std::endl;
}
diff --git a/eval/src/tests/tensor/dense_tensor_create_function/dense_tensor_create_function_test.cpp b/eval/src/tests/tensor/dense_tensor_create_function/dense_tensor_create_function_test.cpp
index 0fa09e6c46e..6ae0681d2ff 100644
--- a/eval/src/tests/tensor/dense_tensor_create_function/dense_tensor_create_function_test.cpp
+++ b/eval/src/tests/tensor/dense_tensor_create_function/dense_tensor_create_function_test.cpp
@@ -43,7 +43,7 @@ void verify(const vespalib::string &expr, size_t expect_optimized_cnt, size_t ex
//-----------------------------------------------------------------------------
TEST("require that tensor create can be optimized") {
- TEST_DO(verify("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", 1, 0));
+ TEST_DO(verify("tensor(x[3]):{{x:0}:1,{x:1}:2,{x:2}:3}", 0, 0)); // NB: const value
TEST_DO(verify("tensor(x[3]):{{x:0}:a,{x:1}:b,{x:2}:c}", 1, 0));
TEST_DO(verify("tensor<float>(x[3]):{{x:0}:a,{x:1}:b,{x:2}:c}", 1, 0));
TEST_DO(verify("tensor(x[3]):{{x:0}:a+b,{x:1}:b-c,{x:2}:c*a}", 1, 0));
diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp
index 56899debc26..a7eaf14e9cd 100644
--- a/eval/src/vespa/eval/eval/function.cpp
+++ b/eval/src/vespa/eval/eval/function.cpp
@@ -10,6 +10,7 @@
#include <vespa/vespalib/locale/c.h>
#include <cctype>
#include <map>
+#include <set>
namespace vespalib::eval {
@@ -30,17 +31,28 @@ bool has_duplicates(const std::vector<vespalib::string> &list) {
return false;
}
+bool step_labels(std::vector<size_t> &labels, const ValueType &type) {
+ for (size_t idx = labels.size(); idx-- > 0; ) {
+ if (++labels[idx] < type.dimensions()[idx].size) {
+ return true;
+ } else {
+ labels[idx] = 0;
+ }
+ }
+ return false;
+}
+
//-----------------------------------------------------------------------------
class Params {
private:
std::map<vespalib::string,size_t> _params;
protected:
- size_t lookup(vespalib::stringref token) const {
+ size_t lookup(const vespalib::string &token) const {
auto result = _params.find(token);
return (result == _params.end()) ? UNDEF : result->second;
}
- size_t lookup_add(vespalib::stringref token) {
+ size_t lookup_add(const vespalib::string &token) {
size_t result = lookup(token);
if (result == UNDEF) {
result = _params.size();
@@ -51,7 +63,7 @@ protected:
public:
static const size_t UNDEF = -1;
virtual bool implicit() const = 0;
- virtual size_t resolve(vespalib::stringref token) const = 0;
+ virtual size_t resolve(const vespalib::string &token) const = 0;
std::vector<vespalib::string> extract() const {
std::vector<vespalib::string> params_out;
params_out.resize(_params.size());
@@ -71,30 +83,27 @@ struct ExplicitParams : Params {
}
}
bool implicit() const override { return false; }
- size_t resolve(vespalib::stringref token) const override {
+ size_t resolve(const vespalib::string &token) const override {
return lookup(token);
}
};
struct ImplicitParams : Params {
bool implicit() const override { return true; }
- size_t resolve(vespalib::stringref token) const override {
+ size_t resolve(const vespalib::string &token) const override {
return const_cast<ImplicitParams*>(this)->lookup_add(token);
}
};
//-----------------------------------------------------------------------------
-class ResolveContext
-{
-private:
- const Params &_params;
- const SymbolExtractor *_symbol_extractor;
-public:
- ResolveContext(const Params &params, const SymbolExtractor *symbol_extractor)
- : _params(params), _symbol_extractor(symbol_extractor) {}
- size_t resolve_param(const vespalib::string &name) const { return _params.resolve(name); }
- const SymbolExtractor *symbol_extractor() const { return _symbol_extractor; }
+struct ResolveContext {
+ const Params &params;
+ const SymbolExtractor *symbol_extractor;
+ const std::map<vespalib::string,size_t*> *aliases;
+ ResolveContext(const Params &params_in, const SymbolExtractor *symbol_extractor_in,
+ const std::map<vespalib::string,size_t*> *aliases_in)
+ : params(params_in), symbol_extractor(symbol_extractor_in), aliases(aliases_in) {}
};
class ParseContext
@@ -118,7 +127,7 @@ public:
_scratch(), _failure(),
_expression_stack(), _operator_stack(),
_operator_mark(0),
- _resolve_stack({ResolveContext(params, symbol_extractor)})
+ _resolve_stack({ResolveContext(params, symbol_extractor, nullptr)})
{
if (_pos < _end) {
_curr = *_pos;
@@ -141,13 +150,34 @@ public:
return _resolve_stack.back();
}
- void push_resolve_context(const Params &params, const SymbolExtractor *symbol_extractor) {
- _resolve_stack.emplace_back(params, symbol_extractor);
+ void push_resolve_context(const Params &params) {
+ _resolve_stack.emplace_back(params, nullptr, nullptr);
+ }
+
+ void push_resolve_context(const std::map<vespalib::string,size_t*> &aliases) {
+ assert(!_resolve_stack.empty());
+ _resolve_stack.emplace_back(resolver().params, resolver().symbol_extractor, &aliases);
}
void pop_resolve_context() {
assert(!_resolve_stack.empty());
_resolve_stack.pop_back();
+ assert(!_resolve_stack.empty());
+ }
+
+ bool has_alias(const vespalib::string &ident) const {
+ if (auto aliases = resolver().aliases) {
+ return (aliases->find(ident) != aliases->end());
+ }
+ return false;
+ }
+
+ size_t get_alias_value(const vespalib::string &ident) const {
+ auto aliases = resolver().aliases;
+ assert(aliases);
+ auto pos = aliases->find(ident);
+ assert(pos != aliases->end());
+ return *(pos->second);
}
void fail(const vespalib::string &msg) {
@@ -209,19 +239,18 @@ public:
}
size_t resolve_parameter(const vespalib::string &name) const {
- return resolver().resolve_param(name);
+ return resolver().params.resolve(name);
}
void extract_symbol(vespalib::string &symbol_out, InputMark before_symbol) {
- const SymbolExtractor *symbol_extractor = resolver().symbol_extractor();
- if (symbol_extractor == nullptr) {
+ if (resolver().symbol_extractor == nullptr) {
return;
}
symbol_out.clear();
restore_input_mark(before_symbol);
if (!eos()) {
const char *new_pos = nullptr;
- symbol_extractor->extract_symbol(_pos, _end, new_pos, symbol_out);
+ resolver().symbol_extractor->extract_symbol(_pos, _end, new_pos, symbol_out);
if ((new_pos != nullptr) && (new_pos > _pos) && (new_pos <= _end)) {
_pos = new_pos;
_curr = (_pos < _end) ? *_pos : 0;
@@ -232,7 +261,7 @@ public:
}
Node_UP get_result() {
- if (!eos() || (num_expressions() != 1) || (num_operators() > 0)) {
+ if (!eos() || (num_expressions() != 1) || (num_operators() > 0) || (_resolve_stack.size() != 1)) {
fail("incomplete parse");
}
if (!_failure.empty()) {
@@ -515,7 +544,7 @@ Function parse_lambda(ParseContext &ctx, size_t num_params) {
ctx.eat('f');
auto param_names = get_ident_list(ctx, true);
ExplicitParams params(param_names);
- ctx.push_resolve_context(params, nullptr);
+ ctx.push_resolve_context(params);
ctx.skip_spaces();
ctx.eat('(');
Node_UP lambda_root = get_expression(ctx);
@@ -694,15 +723,30 @@ void parse_tensor_create(ParseContext &ctx, const ValueType &type,
}
void parse_tensor_lambda(ParseContext &ctx, const ValueType &type) {
- auto param_names = type.dimension_names();
- ExplicitParams params(param_names);
- ctx.push_resolve_context(params, nullptr);
ctx.skip_spaces();
ctx.eat('(');
- Function lambda(get_expression(ctx), std::move(param_names));
+ ParseContext::InputMark before_expr = ctx.get_input_mark();
+ std::vector<size_t> params(type.dimensions().size(), 0);
+ std::map<vespalib::string,size_t*> my_aliases;
+ if (auto parent_aliases = ctx.resolver().aliases) {
+ my_aliases = *parent_aliases;
+ }
+ for (size_t i = 0; i < params.size(); ++i) {
+ my_aliases.emplace(type.dimensions()[i].name, &params[i]);
+ }
+ ctx.push_resolve_context(my_aliases);
+ nodes::TensorCreate::Spec create_spec;
+ do {
+ ctx.restore_input_mark(before_expr);
+ TensorSpec::Address address;
+ for (size_t i = 0; i < params.size(); ++i) {
+ address.emplace(type.dimensions()[i].name, params[i]);
+ }
+ create_spec.emplace(std::move(address), get_expression(ctx));
+ } while (!ctx.failed() && step_labels(params, type));
ctx.eat(')');
ctx.pop_resolve_context();
- ctx.push_expression(std::make_unique<nodes::TensorLambda>(std::move(type), std::move(lambda)));
+ ctx.push_expression(std::make_unique<nodes::TensorCreate>(type, std::move(create_spec)));
}
bool maybe_parse_tensor_generator(ParseContext &ctx) {
@@ -806,23 +850,25 @@ bool maybe_parse_call(ParseContext &ctx, const vespalib::string &name) {
return false;
}
-size_t parse_symbol(ParseContext &ctx, vespalib::string &name, ParseContext::InputMark before_name) {
- ctx.extract_symbol(name, before_name);
- return ctx.resolve_parameter(name);
-}
-
void parse_symbol_or_call(ParseContext &ctx) {
ParseContext::InputMark before_name = ctx.get_input_mark();
vespalib::string name = get_ident(ctx, true);
bool was_tensor_generate = ((name == "tensor") && maybe_parse_tensor_generator(ctx));
if (!was_tensor_generate && !maybe_parse_call(ctx, name)) {
- size_t id = parse_symbol(ctx, name, before_name);
- if (name.empty()) {
- ctx.fail("missing value");
- } else if (id == Params::UNDEF) {
- ctx.fail(make_string("unknown symbol: '%s'", name.c_str()));
+ ctx.extract_symbol(name, before_name);
+ if (ctx.has_alias(name)) {
+ ctx.push_expression(Node_UP(new nodes::Number(ctx.get_alias_value(name))));
} else {
- ctx.push_expression(Node_UP(new nodes::Symbol(id)));
+ if (name.empty()) {
+ ctx.fail("missing value");
+ } else {
+ size_t id = ctx.resolve_parameter(name);
+ if (id == Params::UNDEF) {
+ ctx.fail(make_string("unknown symbol: '%s'", name.c_str()));
+ } else {
+ ctx.push_expression(Node_UP(new nodes::Symbol(id)));
+ }
+ }
}
}
}
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index 3b74a7a0e23..0087bf5efe7 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -97,6 +97,24 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
stack.back() = tensor_function::concat(a, b, dimension, stash);
}
+ bool maybe_make_const(const Node &node) {
+ if (auto create = as<TensorCreate>(node)) {
+ bool is_const = true;
+ for (size_t i = 0; i < create->num_children(); ++i) {
+ is_const &= create->get_child(i).is_const();
+ }
+ if (is_const) {
+ TensorSpec spec(create->type().to_spec());
+ for (size_t i = 0; i < create->num_children(); ++i) {
+ spec.add(create->get_child_address(i), create->get_child(i).get_const_value());
+ }
+ make_const(node, *stash.create<Value::UP>(tensor_engine.from_spec(spec)));
+ return true;
+ }
+ }
+ return false;
+ }
+
void make_create(const TensorCreate &node) {
assert(stack.size() >= node.num_children());
std::map<TensorSpec::Address, tensor_function::Node::CREF> spec;
@@ -348,7 +366,7 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
//-------------------------------------------------------------------------
- bool open(const Node &) override { return true; }
+ bool open(const Node &node) override { return !maybe_make_const(node); }
void close(const Node &node) override { node.accept(*this); }
};