diff options
author | Haavard <havardpe@yahoo-inc.com> | 2017-01-27 12:14:10 +0000 |
---|---|---|
committer | Haavard <havardpe@yahoo-inc.com> | 2017-01-27 12:14:10 +0000 |
commit | b5d3f3cd43b8dbc5d44fab4e2002a304cb51a9cd (patch) | |
tree | 111fe33923ebd183d55e6ec193106c05da77772e /eval | |
parent | 409aaeeb2cb87e0dab48b0f8ffdef513a67e62f3 (diff) |
ensure number results are always double values (not tensors)
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/eval/test/tensor_conformance.cpp | 55 |
1 files changed, 35 insertions, 20 deletions
diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index d0d7818003a..65ed40b3619 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -241,6 +241,10 @@ TensorSpec spec(const vespalib::string &type, return spec; } +double as_double(const TensorSpec &spec) { + return spec.cells().empty() ? 0.0 : spec.cells().begin()->second.value; +} + // abstract evaluation wrapper struct Eval { // typed result wrapper @@ -260,7 +264,7 @@ struct Eval { _type = Type::TENSOR; _tensor = value.as_tensor()->engine().to_spec(*value.as_tensor()); if (_tensor.type() == "double") { - _number = _tensor.cells().empty() ? 0.0 : _tensor.cells().begin()->second.value; + _number = as_double(_tensor); } } } @@ -268,11 +272,11 @@ struct Eval { bool is_number() const { return (_type == Type::NUMBER); } bool is_tensor() const { return (_type == Type::TENSOR); } double number() const { - // EXPECT_TRUE(is_number()); + EXPECT_TRUE(is_number()); return _number; } const TensorSpec &tensor() const { - // EXPECT_TRUE(is_tensor()); + EXPECT_TRUE(is_tensor()); return _tensor; } }; @@ -378,7 +382,7 @@ struct Expr_TT : Eval { const Value &make_value(const TensorEngine &engine, const TensorSpec &spec, Stash &stash) { if (spec.type() == "double") { - double number = spec.cells().empty() ? 0.0 : spec.cells().begin()->second.value; + double number = as_double(spec); return stash.create<DoubleValue>(number); } return stash.create<TensorValue>(engine.create(spec)); @@ -528,6 +532,23 @@ const double X = error_value; // NaN value const double my_nan = std::numeric_limits<double>::quiet_NaN(); +void verify_result(const Eval::Result &result, const Eval::Result &expect) { + if (expect.is_number()) { + EXPECT_EQUAL(result.number(), expect.number()); + } else if (expect.is_tensor()) { + EXPECT_EQUAL(result.tensor(), expect.tensor()); + } else { + TEST_FATAL("expected result should be valid"); + } +} + +void verify_result(const Eval::Result &result, const TensorSpec &expect) { + if (expect.type() == "double") { + EXPECT_EQUAL(result.number(), as_double(expect)); + } else { + EXPECT_EQUAL(result.tensor(), expect); + } +} // Test wrapper to avoid passing global test parameters around struct TestContext { @@ -638,13 +659,7 @@ struct TestContext { //------------------------------------------------------------------------- void verify_reduce_result(const Eval &eval, const TensorSpec &a, const Eval::Result &expect) { - if (expect.is_tensor()) { - EXPECT_EQUAL(eval.eval(engine, a).tensor(), expect.tensor()); - } else if (expect.is_number()) { - EXPECT_EQUAL(eval.eval(engine, a).number(), expect.number()); - } else { - TEST_FATAL("expected result should be valid"); - } + TEST_DO(verify_result(eval.eval(engine, a), expect)); } void test_reduce_op(const vespalib::string &name, const BinaryOperation &op, const Sequence &seq) { @@ -711,7 +726,7 @@ struct TestContext { layouts.push_back({x({"a","b","c"}),y(5),z({"i","j","k","l"})}); } for (const Layout &layout: layouts) { - EXPECT_EQUAL(eval.eval(engine, spec(layout, seq)).tensor(), spec(layout, OpSeq(seq, ref_op))); + TEST_DO(verify_result(eval.eval(engine, spec(layout, seq)), spec(layout, OpSeq(seq, ref_op)))); } } @@ -752,7 +767,7 @@ struct TestContext { const TensorSpec &expect, const TensorSpec &lhs, const TensorSpec &rhs) { - EXPECT_EQUAL(safe(eval).eval(engine, lhs, rhs).tensor(), expect); + TEST_DO(verify_result(safe(eval).eval(engine, lhs, rhs), expect)); } void test_fixed_sparse_cases_apply_op(const Eval &eval, @@ -917,7 +932,7 @@ struct TestContext { const BinaryOperation &op) { TEST_DO(test_apply_op(eval, - spec(op.eval(0,0)), spec(0.0), spec(0.0))); + spec(op.eval(0.1,0.2)), spec(0.1), spec(0.2))); TEST_DO(test_apply_op(eval, spec(x(1), Seq({ op.eval(3,5) })), spec(x(1), Seq({ 3 })), @@ -986,8 +1001,8 @@ struct TestContext { TEST_STATE(make_string("lhs shape: %s, rhs shape: %s", lhs_input.type().c_str(), rhs_input.type().c_str()).c_str()); - TensorSpec expect = ImmediateApply(op).eval(ref_engine, lhs_input, rhs_input).tensor(); - EXPECT_EQUAL(safe(eval).eval(engine, lhs_input, rhs_input).tensor(), expect); + Eval::Result expect = ImmediateApply(op).eval(ref_engine, lhs_input, rhs_input); + TEST_DO(verify_result(safe(eval).eval(engine, lhs_input, rhs_input), expect)); } TEST_DO(test_fixed_sparse_cases_apply_op(eval, op)); TEST_DO(test_fixed_dense_cases_apply_op(eval, op)); @@ -1029,7 +1044,7 @@ struct TestContext { const TensorSpec &rhs) { Expr_TT eval("sum(a*b)"); - EXPECT_EQUAL(expect, safe(eval).eval(engine, lhs, rhs).number()); + TEST_DO(verify_result(safe(eval).eval(engine, lhs, rhs), spec(expect))); } void test_dot_product() { @@ -1052,7 +1067,7 @@ struct TestContext { const vespalib::string &dimension) { ImmediateConcat eval(dimension); - EXPECT_EQUAL(eval.eval(engine, a, b).tensor(), expect); + TEST_DO(verify_result(eval.eval(engine, a, b), expect)); } void test_concat() { @@ -1084,7 +1099,7 @@ struct TestContext { const std::vector<vespalib::string> &to) { ImmediateRename eval(from, to); - EXPECT_EQUAL(eval.eval(engine, input).tensor(), expect); + TEST_DO(verify_result(eval.eval(engine, input), expect)); } void test_rename() { @@ -1099,7 +1114,7 @@ struct TestContext { //------------------------------------------------------------------------- void test_tensor_lambda(const vespalib::string &expr, const TensorSpec &expect) { - EXPECT_EQUAL(Expr_V(expr).eval(engine).tensor(), expect); + TEST_DO(verify_result(Expr_V(expr).eval(engine), expect)); } void test_tensor_lambda() { |