summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@yahoo-inc.com>2017-12-01 13:33:51 +0000
committerArne Juul <arnej@yahoo-inc.com>2017-12-01 13:33:51 +0000
commitfa25907d34e991707eecb29aa0df3471244d9a40 (patch)
treee662be97d0570e1670c9a05ed5aed92f19d91cfe /eval
parent0ab21f333b201f25f532e4911894488b8c52ed11 (diff)
consolidate tests
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp16
1 files changed, 4 insertions, 12 deletions
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index 3ab7bd63f16..5c62d319dc3 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -40,10 +40,8 @@ std::vector<eval::Value::CREF> wrap(std::vector<eval::Value::CREF> params) {
void verify_result(const TensorSpec &v, const TensorSpec &m, bool happy) {
Stash stash;
-
Value::UP ref_vec = ref_engine.from_spec(v);
Value::UP ref_mat = ref_engine.from_spec(m);
-
const Value &joined = ref_engine.join(*ref_vec, *ref_mat, operation::Mul::f, stash);
const Value &expect = ref_engine.reduce(joined, Aggr::SUM, {"x"}, stash);
@@ -54,7 +52,6 @@ void verify_result(const TensorSpec &v, const TensorSpec &m, bool happy) {
prod_vec->type().dimensions()[0].size,
expect.type().dimensions()[0].size,
happy);
-
const Value &actual1 = fun1.eval(wrap({*prod_vec, *prod_mat}), stash);
TEST_DO(verify_equal(expect, actual1));
@@ -66,15 +63,6 @@ void verify_result(const TensorSpec &v, const TensorSpec &m, bool happy) {
TEST_DO(verify_equal(expect, actual2));
}
-TensorSpec vec_x1 = TensorSpec("tensor(x[1])").add({{"x", 0}}, 3.0);
-TensorSpec m_x1y1 = TensorSpec("tensor(x[1],y[1])").add({{"x",0},{"y",0}}, 5);
-TensorSpec m_o1x1 = TensorSpec("tensor(o[1],x[1])").add({{"x",0},{"o",0}}, 7);
-
-TEST("require that basic product with size 1 is correct") {
- verify_result(vec_x1, m_o1x1, true);
- verify_result(vec_x1, m_x1y1, false);
-}
-
TensorSpec make_vector(const vespalib::string &name, size_t sz) {
TensorSpec ret(make_string("tensor(%s[%zu])", name.c_str(), sz));
for (size_t i = 0; i < sz; ++i) {
@@ -82,6 +70,7 @@ TensorSpec make_vector(const vespalib::string &name, size_t sz) {
}
return ret;
}
+
TensorSpec make_matrix(const vespalib::string &d1name, size_t d1sz,
const vespalib::string &d2name, size_t d2sz)
{
@@ -97,6 +86,9 @@ TensorSpec make_matrix(const vespalib::string &d1name, size_t d1sz,
}
TEST("require that xw product gives same results as reference join/reduce") {
+ verify_result(make_vector("x", 1), make_matrix("o", 1, "x", 1), true);
+ verify_result(make_vector("x", 1), make_matrix("x", 1, "y", 1), false);
+
verify_result(make_vector("x", 3), make_matrix("o", 2, "x", 3), true);
verify_result(make_vector("x", 3), make_matrix("x", 3, "y", 2), false);