From e27b4915529572bd568e2ca4bb81307d33f8123f Mon Sep 17 00:00:00 2001
From: Arne Juul
Date: Thu, 8 Feb 2018 15:16:27 +0000
Subject: refactor dot product test

---
 .../dense_dot_product_function_test.cpp | 220 +++++++++++----------
 1 file changed, 118 insertions(+), 102 deletions(-)

diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index 71bbacc7806..fb48e445180 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -7,6 +7,8 @@
 #include
 #include
 #include
+#include
+#include
 #include
 #include
@@ -15,128 +17,68 @@ LOG_SETUP("dense_dot_product_function_test");
 using namespace vespalib;
 using namespace vespalib::eval;
+using namespace vespalib::eval::test;
 using namespace vespalib::tensor;
 
-tensor::Tensor::UP
-makeTensor(size_t numCells, double cellBias)
-{
-    DenseTensorBuilder builder;
-    DenseTensorBuilder::Dimension dim = builder.defineDimension("x", numCells);
-    for (size_t i = 0; i < numCells; ++i) {
-        builder.addLabel(dim, i).addCell(i + cellBias);
-    }
-    return builder.build();
+const TensorEngine &prod_engine = DefaultTensorEngine::ref();
+
+struct MyVecSeq : Sequence {
+    double bias;
+    double operator[](size_t i) const override { return (i + bias); }
+    MyVecSeq(double cellBias) : bias(cellBias) {}
+};
+
+TensorSpec makeTensor(size_t numCells, double cellBias) {
+    return spec({x(numCells)}, MyVecSeq(cellBias));
 }
 
-double
-calcDotProduct(const DenseTensor &lhs, const DenseTensor &rhs)
-{
-    size_t numCells = std::min(lhs.cellsRef().size(), rhs.cellsRef().size());
+const double leftBias = 3.0;
+const double rightBias = 5.0;
+
+double calcDotProduct(size_t numCells) {
     double result = 0;
     for (size_t i = 0; i < numCells; ++i) {
-        result += (lhs.cellsRef()[i] * rhs.cellsRef()[i]);
+        result += (i + leftBias) * (i + rightBias);
     }
     return result;
 }
 
-const DenseTensor &
-asDenseTensor(const tensor::Tensor &tensor)
-{
-    return dynamic_cast<const DenseTensor &>(tensor);
-}
-
-class FunctionInput
-{
-private:
-    tensor::Tensor::UP _lhsTensor;
-    tensor::Tensor::UP _rhsTensor;
-    const DenseTensor &_lhsDenseTensor;
-    const DenseTensor &_rhsDenseTensor;
-    std::vector<Value::CREF> _params;
-
-public:
-    FunctionInput(size_t lhsNumCells, size_t rhsNumCells)
-        : _lhsTensor(makeTensor(lhsNumCells, 3.0)),
-          _rhsTensor(makeTensor(rhsNumCells, 5.0)),
-          _lhsDenseTensor(asDenseTensor(*_lhsTensor)),
-          _rhsDenseTensor(asDenseTensor(*_rhsTensor))
-    {
-        _params.emplace_back(_lhsDenseTensor);
-        _params.emplace_back(_rhsDenseTensor);
-    }
-    SimpleObjectParams get() const { return SimpleObjectParams(_params); }
-    const Value &param(size_t idx) const { return _params[idx]; }
-    double expectedDotProduct() const {
-        return calcDotProduct(_lhsDenseTensor, _rhsDenseTensor);
-    }
+void check_gen_with_result(size_t l, size_t r, double wanted) {
+    EvalFixture::ParamRepo param_repo;
+    param_repo.add("a", makeTensor(l, leftBias));
+    param_repo.add("b", makeTensor(r, rightBias));
+    vespalib::string expr = "reduce(a*b,sum,x)";
+    EvalFixture evaluator(prod_engine, expr, param_repo, true);
+    EXPECT_EQUAL(spec(wanted), evaluator.result());
+    EXPECT_EQUAL(evaluator.result(), EvalFixture::ref(expr, param_repo));
+    auto info = evaluator.find_all<DenseDotProductFunction>();
+    EXPECT_EQUAL(info.size(), 1u);
 };
 
-struct Fixture
-{
-    FunctionInput input;
-    tensor_function::Inject a;
-    tensor_function::Inject b;
-    DenseDotProductFunction function;
-    Fixture(size_t lhsNumCells, size_t rhsNumCells);
-    ~Fixture();
-    double eval() const {
-        InterpretedFunction ifun(DefaultTensorEngine::ref(), function);
-        InterpretedFunction::Context ictx(ifun);
-        const Value &result = ifun.eval(ictx, input.get());
-        ASSERT_TRUE(result.is_double());
-        LOG(info, "eval(): (%s) * (%s) = %f",
-            input.param(0).type().to_spec().c_str(),
-            input.param(1).type().to_spec().c_str(),
-            result.as_double());
-        return result.as_double();
-    }
-};
-
-Fixture::Fixture(size_t lhsNumCells, size_t rhsNumCells)
-    : input(lhsNumCells, rhsNumCells),
-      a(input.param(0).type(), 0),
-      b(input.param(1).type(), 1),
-      function(a, b)
-{ }
-
-Fixture::~Fixture() { }
-
-void
-assertDotProduct(size_t numCells)
-{
-    Fixture f(numCells, numCells);
-    EXPECT_EQUAL(f.input.expectedDotProduct(), f.eval());
-}
+// this should not be possible to set up:
+// TEST("require that empty dot product is correct")
 
-void
-assertDotProduct(size_t lhsNumCells, size_t rhsNumCells)
-{
-    Fixture f(lhsNumCells, rhsNumCells);
-    EXPECT_EQUAL(f.input.expectedDotProduct(), f.eval());
+TEST("require that basic dot product with equal sizes is correct") {
+    check_gen_with_result(2, 2, (3.0 * 5.0) + (4.0 * 6.0));
 }
 
-TEST_F("require that empty dot product is correct", Fixture(0, 0))
-{
-    EXPECT_EQUAL(0.0, f.eval());
+TEST("require that basic dot product with un-equal sizes is correct") {
+    check_gen_with_result(2, 3, (3.0 * 5.0) + (4.0 * 6.0));
+    check_gen_with_result(3, 2, (3.0 * 5.0) + (4.0 * 6.0));
 }
 
-TEST_F("require that basic dot product with equal sizes is correct", Fixture(2, 2))
-{
-    EXPECT_EQUAL((3.0 * 5.0) + (4.0 * 6.0), f.eval());
-}
+//-----------------------------------------------------------------------------
 
-TEST_F("require that basic dot product with un-equal sizes is correct", Fixture(2, 3))
-{
-    EXPECT_EQUAL((3.0 * 5.0) + (4.0 * 6.0), f.eval());
+void assertDotProduct(size_t numCells) {
+    check_gen_with_result(numCells, numCells, calcDotProduct(numCells));
 }
 
-TEST_F("require that basic dot product with un-equal sizes is correct", Fixture(3, 2))
-{
-    EXPECT_EQUAL((3.0 * 5.0) + (4.0 * 6.0), f.eval());
+void assertDotProduct(size_t lhsNumCells, size_t rhsNumCells) {
+    size_t numCells = std::min(lhsNumCells, rhsNumCells);
+    check_gen_with_result(lhsNumCells, rhsNumCells, calcDotProduct(numCells));
 }
 
-TEST("require that dot product with equal sizes is correct")
-{
+TEST("require that dot product with equal sizes is correct") {
     TEST_DO(assertDotProduct(8));
     TEST_DO(assertDotProduct(16));
     TEST_DO(assertDotProduct(32));
@@ -156,9 +98,9 @@ TEST("require that dot product with equal sizes is correct")
     TEST_DO(assertDotProduct(1024 + 3));
 }
 
-TEST("require that dot product with un-equal sizes is correct")
-{
+TEST("require that dot product with un-equal sizes is correct") {
     TEST_DO(assertDotProduct(8, 8 + 3));
+    TEST_DO(assertDotProduct(8 + 3, 8));
     TEST_DO(assertDotProduct(16, 16 + 3));
     TEST_DO(assertDotProduct(32, 32 + 3));
     TEST_DO(assertDotProduct(64, 64 + 3));
@@ -168,4 +110,78 @@
     TEST_DO(assertDotProduct(1024, 1024 + 3));
 }
+//-----------------------------------------------------------------------------
+
+EvalFixture::ParamRepo make_params() {
+    return EvalFixture::ParamRepo()
+        .add("v01_x1", spec({x(1)}, MyVecSeq(2.0)))
+        .add("v02_x3", spec({x(3)}, MyVecSeq(4.0)))
+        .add("v03_x3", spec({x(3)}, MyVecSeq(5.0)))
+        .add("v04_y3", spec({y(3)}, MyVecSeq(10)))
+        .add("v05_x5", spec({x(5)}, MyVecSeq(6.0)))
.add("v06_x5", spec({x(5)}, MyVecSeq(7.0))) + .add("v07_x3_a", spec({x(3)}, MyVecSeq(8.0)), "any") + .add("v08_x3_u", spec({x(3)}, MyVecSeq(9.0)), "tensor(x[])") + .add("v09_x4_u", spec({x(4)}, MyVecSeq(3.0)), "tensor(x[])") + .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(0))); +} +EvalFixture::ParamRepo param_repo = make_params(); + +void assertOptimized(const vespalib::string &expr) { + EvalFixture fixture(prod_engine, expr, param_repo, true); + EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); + auto info = fixture.find_all(); + EXPECT_EQUAL(info.size(), 1u); +} + +void assertNotOptimized(const vespalib::string &expr) { + EvalFixture fixture(prod_engine, expr, param_repo, true); + EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo)); + auto info = fixture.find_all(); + EXPECT_TRUE(info.empty()); +} + +TEST("require that dot product is not optimized for unknown types") { + TEST_DO(assertNotOptimized("reduce(v02_x3*v07_x3_a,sum)")); + TEST_DO(assertNotOptimized("reduce(v07_x3_a*v03_x3,sum)")); +} + +TEST("require that dot product works with tensor function") { + TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)")); + TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum,x)")); + TEST_DO(assertOptimized("reduce(join(v05_x5,v06_x5,f(x,y)(x*y)),sum)")); + TEST_DO(assertOptimized("reduce(join(v05_x5,v06_x5,f(x,y)(x*y)),sum,x)")); +} + +TEST("require that dot product with compatible dimensions is optimized") { + TEST_DO(assertOptimized("reduce(v01_x1*v01_x1,sum)")); + TEST_DO(assertOptimized("reduce(v02_x3*v03_x3,sum)")); + TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)")); + + TEST_DO(assertOptimized("reduce(v02_x3*v06_x5,sum)")); + TEST_DO(assertOptimized("reduce(v05_x5*v03_x3,sum)")); + TEST_DO(assertOptimized("reduce(v08_x3_u*v05_x5,sum)")); + TEST_DO(assertOptimized("reduce(v05_x5*v08_x3_u,sum)")); +} + +TEST("require that dot product with incompatible dimensions is NOT optimized") { + TEST_DO(assertNotOptimized("reduce(v02_x3*v04_y3,sum)")); + TEST_DO(assertNotOptimized("reduce(v04_y3*v02_x3,sum)")); + TEST_DO(assertNotOptimized("reduce(v08_x3_u*v04_y3,sum)")); + TEST_DO(assertNotOptimized("reduce(v04_y3*v08_x3_u,sum)")); + TEST_DO(assertNotOptimized("reduce(v02_x3*m01_x3y3,sum)")); + TEST_DO(assertNotOptimized("reduce(m01_x3y3*v02_x3,sum)")); +} + +TEST("require that expressions similar to dot product are not optimized") { + TEST_DO(assertNotOptimized("reduce(v02_x3*v03_x3,prod)")); + TEST_DO(assertNotOptimized("reduce(v02_x3+v03_x3,sum)")); + TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(x+y)),sum)")); + TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(x*x)),sum)")); + TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*y)),sum)")); + // TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*x)),sum)")); +} + +//----------------------------------------------------------------------------- + TEST_MAIN() { TEST_RUN_ALL(); } -- cgit v1.2.3