diff options
author | Haavard <havardpe@yahoo-inc.com> | 2017-01-24 13:28:51 +0000 |
---|---|---|
committer | Haavard <havardpe@yahoo-inc.com> | 2017-01-24 13:28:51 +0000 |
commit | 8408c3a5abaa1380aa68ff9e1d2c6051dc4b3ec9 (patch) | |
tree | c295ac512cf28818974b29434bfbaabcd9eb66a9 /eval | |
parent | 0935ce05d877d763e8a87c82f40c5441172a5ef6 (diff) |
added support for tensor lambdas
and some conformance testing
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/eval/interpreted_function.cpp | 34 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/tensor_conformance.cpp | 18 |
2 files changed, 49 insertions, 3 deletions
diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp index 47bd483bba4..9ad98ce6579 100644 --- a/eval/src/vespa/eval/eval/interpreted_function.cpp +++ b/eval/src/vespa/eval/eval/interpreted_function.cpp @@ -12,6 +12,7 @@ #include "tensor_spec.h" #include "simple_tensor_engine.h" #include <vespa/vespalib/util/classname.h> +#include <vespa/eval/eval/llvm/compile_cache.h> namespace vespalib { namespace eval { @@ -164,6 +165,20 @@ void op_tensor_function_arg_arg(State &state, uint64_t param) { //----------------------------------------------------------------------------- +bool step_labels(std::vector<double> &labels, const ValueType &type) { + for (size_t idx = labels.size(); idx-- > 0; ) { + labels[idx] += 1.0; + if (size_t(labels[idx]) < type.dimensions()[idx].size) { + return true; + } else { + labels[idx] = 0.0; + } + } + return false; +} + +//----------------------------------------------------------------------------- + struct ProgramBuilder : public NodeVisitor, public NodeTraverser { std::vector<Instruction> &program; Stash &stash; @@ -282,9 +297,22 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { // TODO(havardpe): add actual evaluation program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>())); } - virtual void visit(const TensorLambda &) { - // TODO(havardpe): add actual evaluation - program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>())); + virtual void visit(const TensorLambda &node) { + const auto &type = node.type(); + TensorSpec spec(type.to_spec()); + const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::ARRAY)); + auto fun = token.get()->get().get_function(); + std::vector<double> params(type.dimensions().size(), 0.0); + assert(token.get()->get().num_params() == params.size()); + do { + TensorSpec::Address addr; + for (size_t i = 0; i < params.size(); ++i) { + addr.emplace(type.dimensions()[i].name, size_t(params[i])); + } + spec.add(addr, fun(¶ms[0])); + } while (step_labels(params, type)); + auto tensor = tensor_engine.create(spec); + program.emplace_back(op_load_const, wrap_param<Value>(stash.create<TensorValue>(std::move(tensor)))); } virtual void visit(const TensorConcat &) { // TODO(havardpe): add actual evaluation diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index ae9b3e98a15..646fb19cc21 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -1098,6 +1098,23 @@ struct TestContext { //------------------------------------------------------------------------- + void test_tensor_lambda(const vespalib::string &expr, const TensorSpec &expect) { + EXPECT_EQUAL(Expr_V(expr).eval(engine).tensor(), expect); + } + + void test_tensor_lambda() { + TEST_DO(test_tensor_lambda("tensor(x[10])(x+1)", spec(x(10), N()))); + TEST_DO(test_tensor_lambda("tensor(x[5],y[4])(x*4+(y+1))", spec({x(5),y(4)}, N()))); + TEST_DO(test_tensor_lambda("tensor(x[5],y[4])(x==y)", spec({x(5),y(4)}, + Seq({ 1.0, 0.0, 0.0, 0.0, + 0.0, 1.0, 0.0, 0.0, + 0.0, 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, 1.0, + 0.0, 0.0, 0.0, 0.0})))); + } + + //------------------------------------------------------------------------- + void run_tests() { TEST_DO(test_tensor_create_type()); TEST_DO(test_tensor_equality()); @@ -1108,6 +1125,7 @@ struct TestContext { TEST_DO(test_dot_product()); TEST_DO(test_concat()); TEST_DO(test_rename()); + TEST_DO(test_tensor_lambda()); } }; |