summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHaavard <havardpe@yahoo-inc.com>2017-01-24 13:28:51 +0000
committerHaavard <havardpe@yahoo-inc.com>2017-01-24 13:28:51 +0000
commit8408c3a5abaa1380aa68ff9e1d2c6051dc4b3ec9 (patch)
treec295ac512cf28818974b29434bfbaabcd9eb66a9 /eval
parent0935ce05d877d763e8a87c82f40c5441172a5ef6 (diff)
added support for tensor lambdas
and some conformance testing
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/interpreted_function.cpp34
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_conformance.cpp18
2 files changed, 49 insertions, 3 deletions
diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp
index 47bd483bba4..9ad98ce6579 100644
--- a/eval/src/vespa/eval/eval/interpreted_function.cpp
+++ b/eval/src/vespa/eval/eval/interpreted_function.cpp
@@ -12,6 +12,7 @@
#include "tensor_spec.h"
#include "simple_tensor_engine.h"
#include <vespa/vespalib/util/classname.h>
+#include <vespa/eval/eval/llvm/compile_cache.h>
namespace vespalib {
namespace eval {
@@ -164,6 +165,20 @@ void op_tensor_function_arg_arg(State &state, uint64_t param) {
//-----------------------------------------------------------------------------
+bool step_labels(std::vector<double> &labels, const ValueType &type) {
+ for (size_t idx = labels.size(); idx-- > 0; ) {
+ labels[idx] += 1.0;
+ if (size_t(labels[idx]) < type.dimensions()[idx].size) {
+ return true;
+ } else {
+ labels[idx] = 0.0;
+ }
+ }
+ return false;
+}
+
+//-----------------------------------------------------------------------------
+
struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
std::vector<Instruction> &program;
Stash &stash;
@@ -282,9 +297,22 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
// TODO(havardpe): add actual evaluation
program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>()));
}
- virtual void visit(const TensorLambda &) {
- // TODO(havardpe): add actual evaluation
- program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>()));
+ virtual void visit(const TensorLambda &node) {
+ const auto &type = node.type();
+ TensorSpec spec(type.to_spec());
+ const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::ARRAY));
+ auto fun = token.get()->get().get_function();
+ std::vector<double> params(type.dimensions().size(), 0.0);
+ assert(token.get()->get().num_params() == params.size());
+ do {
+ TensorSpec::Address addr;
+ for (size_t i = 0; i < params.size(); ++i) {
+ addr.emplace(type.dimensions()[i].name, size_t(params[i]));
+ }
+ spec.add(addr, fun(&params[0]));
+ } while (step_labels(params, type));
+ auto tensor = tensor_engine.create(spec);
+ program.emplace_back(op_load_const, wrap_param<Value>(stash.create<TensorValue>(std::move(tensor))));
}
virtual void visit(const TensorConcat &) {
// TODO(havardpe): add actual evaluation
diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
index ae9b3e98a15..646fb19cc21 100644
--- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
+++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
@@ -1098,6 +1098,23 @@ struct TestContext {
//-------------------------------------------------------------------------
+ void test_tensor_lambda(const vespalib::string &expr, const TensorSpec &expect) {
+ EXPECT_EQUAL(Expr_V(expr).eval(engine).tensor(), expect);
+ }
+
+ void test_tensor_lambda() {
+ TEST_DO(test_tensor_lambda("tensor(x[10])(x+1)", spec(x(10), N())));
+ TEST_DO(test_tensor_lambda("tensor(x[5],y[4])(x*4+(y+1))", spec({x(5),y(4)}, N())));
+ TEST_DO(test_tensor_lambda("tensor(x[5],y[4])(x==y)", spec({x(5),y(4)},
+ Seq({ 1.0, 0.0, 0.0, 0.0,
+ 0.0, 1.0, 0.0, 0.0,
+ 0.0, 0.0, 1.0, 0.0,
+ 0.0, 0.0, 0.0, 1.0,
+ 0.0, 0.0, 0.0, 0.0}))));
+ }
+
+ //-------------------------------------------------------------------------
+
void run_tests() {
TEST_DO(test_tensor_create_type());
TEST_DO(test_tensor_equality());
@@ -1108,6 +1125,7 @@ struct TestContext {
TEST_DO(test_dot_product());
TEST_DO(test_concat());
TEST_DO(test_rename());
+ TEST_DO(test_tensor_lambda());
}
};