aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-01-31 16:24:08 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-01-31 16:24:08 +0000
commit5e9a28f18304c19d9622a59d044a7ad560cbb48b (patch)
treee75099de72dc74d194e4f067c643a26d5a56d1a1 /eval
parent6b60b5d9e9a97e6db1c7499c9fc338538bac309e (diff)
use interpreted function to evaluate tensor function
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/tensor_function/tensor_function_test.cpp8
1 files changed, 6 insertions, 2 deletions
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index b2df7eddd46..0308b24e742 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -18,8 +18,10 @@ struct EvalCtx {
ErrorValue error;
std::vector<Value::UP> tensors;
std::vector<Value::CREF> params;
+ InterpretedFunction::UP ifun;
+ std::unique_ptr<InterpretedFunction::Context> ictx;
EvalCtx(const TensorEngine &engine_in)
- : engine(engine_in), stash(), error(), tensors() {}
+ : engine(engine_in), stash(), error(), tensors(), params(), ifun(), ictx() {}
~EvalCtx() {}
size_t add_tensor(Value::UP tensor) {
size_t id = params.size();
@@ -32,7 +34,9 @@ struct EvalCtx {
tensors[idx] = std::move(tensor);
}
const Value &eval(const TensorFunction &fun) {
- return fun.eval(engine, SimpleObjectParams(params), stash);
+ ifun = std::make_unique<InterpretedFunction>(engine, fun);
+ ictx = std::make_unique<InterpretedFunction::Context>(*ifun);
+ return ifun->eval(*ictx, SimpleObjectParams(params));
}
const TensorFunction &compile(const tensor_function::Node &expr) {
return engine.optimize(expr, stash);