summaryrefslogtreecommitdiffstats
path: root/eval/src/tests
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-01-18 13:31:57 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-01-18 13:31:57 +0000
commit04ed2c4337ed3398bb110e9a845efcd2eb954577 (patch)
tree8c3b1779721a1ff61dd3e5255fce3159ad80bd45 /eval/src/tests
parent15b416446c7a6e674a14b5de0e0243fdab83f340 (diff)
run cross-language tensor conformance tests using tensor functions
also pass tensor engine to tensor function eval function
Diffstat (limited to 'eval/src/tests')
-rw-r--r--eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp2
-rw-r--r--eval/src/tests/eval/tensor_function/tensor_function_test.cpp2
-rw-r--r--eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp3
-rw-r--r--eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp4
4 files changed, 6 insertions, 5 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index a4222df6e00..802f9555360 100644
--- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -98,7 +98,7 @@ struct MyEvalTest : test::EvalSpec::EvalTest {
Stash stash;
NodeTypes node_types = NodeTypes(function, std::vector<ValueType>(params.params.size(), ValueType::double_type()));
const auto &tfun = make_tensor_function(engine, function.root(), node_types, stash);
- const Value &result_value = tfun.eval(params, stash);
+ const Value &result_value = tfun.eval(engine, params, stash);
report_result(result_value.is_double(), result_value.as_double(), expected_result, description);
}
};
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index 076dcbc8e28..4e52fd8b47b 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -32,7 +32,7 @@ struct EvalCtx {
tensors[idx] = std::move(tensor);
}
const Value &eval(const TensorFunction &fun) {
- return fun.eval(SimpleObjectParams(params), stash);
+ return fun.eval(engine, SimpleObjectParams(params), stash);
}
const TensorFunction &compile(const tensor_function::Node &expr) {
return engine.compile(expr, stash);
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index 3463bee3447..8b4b1497243 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -2,6 +2,7 @@
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/eval/tensor_function.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/dense/dense_dot_product_function.h>
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/tensor/dense/dense_tensor_builder.h>
@@ -78,7 +79,7 @@ struct Fixture
~Fixture();
double eval() const {
Stash stash;
- const Value &result = function.eval(input.get(), stash);
+ const Value &result = function.eval(DefaultTensorEngine::ref(), input.get(), stash);
ASSERT_TRUE(result.is_double());
LOG(info, "eval(): (%s) * (%s) = %f",
input.param(0).type().to_spec().c_str(),
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index f27a2073159..5e18df10921 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -52,14 +52,14 @@ void verify_result(const TensorSpec &v, const TensorSpec &m, bool happy) {
prod_vec->type().dimensions()[0].size,
expect.type().dimensions()[0].size,
happy);
- const Value &actual1 = fun1.eval(wrap({*prod_vec, *prod_mat}), stash);
+ const Value &actual1 = fun1.eval(prod_engine, wrap({*prod_vec, *prod_mat}), stash);
TEST_DO(verify_equal(expect, actual1));
DenseXWProductFunction fun2(expect.type(), 1, 0,
prod_vec->type().dimensions()[0].size,
expect.type().dimensions()[0].size,
happy);
- const Value &actual2 = fun2.eval(wrap({*prod_mat, *prod_vec}), stash);
+ const Value &actual2 = fun2.eval(prod_engine, wrap({*prod_mat, *prod_vec}), stash);
TEST_DO(verify_equal(expect, actual2));
}