summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@yahoo-inc.com>2018-02-08 15:16:27 +0000
committerArne Juul <arnej@yahoo-inc.com>2018-02-09 10:46:45 +0000
commite27b4915529572bd568e2ca4bb81307d33f8123f (patch)
tree00714d9485896425742f46e6af73511e8ce7f22b /eval
parent0efd07d305c109f95593e5363ef9bb34d8e7ed3b (diff)
refactor dot product test
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp220
1 files changed, 118 insertions, 102 deletions
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index 71bbacc7806..fb48e445180 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -7,6 +7,8 @@
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/tensor/dense/dense_tensor_builder.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>
+#include <vespa/eval/eval/test/tensor_model.hpp>
+#include <vespa/eval/eval/test/eval_fixture.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/stash.h>
@@ -15,128 +17,68 @@ LOG_SETUP("dense_dot_product_function_test");
using namespace vespalib;
using namespace vespalib::eval;
+using namespace vespalib::eval::test;
using namespace vespalib::tensor;
-tensor::Tensor::UP
-makeTensor(size_t numCells, double cellBias)
-{
- DenseTensorBuilder builder;
- DenseTensorBuilder::Dimension dim = builder.defineDimension("x", numCells);
- for (size_t i = 0; i < numCells; ++i) {
- builder.addLabel(dim, i).addCell(i + cellBias);
- }
- return builder.build();
+const TensorEngine &prod_engine = DefaultTensorEngine::ref();
+
+struct MyVecSeq : Sequence {
+ double bias;
+ double operator[](size_t i) const override { return (i + bias); }
+ MyVecSeq(double cellBias) : bias(cellBias) {}
+};
+
+TensorSpec makeTensor(size_t numCells, double cellBias) {
+ return spec({x(numCells)}, MyVecSeq(cellBias));
}
-double
-calcDotProduct(const DenseTensor &lhs, const DenseTensor &rhs)
-{
- size_t numCells = std::min(lhs.cellsRef().size(), rhs.cellsRef().size());
+const double leftBias = 3.0;
+const double rightBias = 5.0;
+
+double calcDotProduct(size_t numCells) {
double result = 0;
for (size_t i = 0; i < numCells; ++i) {
- result += (lhs.cellsRef()[i] * rhs.cellsRef()[i]);
+ result += (i + leftBias) * (i + rightBias);
}
return result;
}
-const DenseTensor &
-asDenseTensor(const tensor::Tensor &tensor)
-{
- return dynamic_cast<const DenseTensor &>(tensor);
-}
-
-class FunctionInput
-{
-private:
- tensor::Tensor::UP _lhsTensor;
- tensor::Tensor::UP _rhsTensor;
- const DenseTensor &_lhsDenseTensor;
- const DenseTensor &_rhsDenseTensor;
- std::vector<Value::CREF> _params;
-
-public:
- FunctionInput(size_t lhsNumCells, size_t rhsNumCells)
- : _lhsTensor(makeTensor(lhsNumCells, 3.0)),
- _rhsTensor(makeTensor(rhsNumCells, 5.0)),
- _lhsDenseTensor(asDenseTensor(*_lhsTensor)),
- _rhsDenseTensor(asDenseTensor(*_rhsTensor))
- {
- _params.emplace_back(_lhsDenseTensor);
- _params.emplace_back(_rhsDenseTensor);
- }
- SimpleObjectParams get() const { return SimpleObjectParams(_params); }
- const Value &param(size_t idx) const { return _params[idx]; }
- double expectedDotProduct() const {
- return calcDotProduct(_lhsDenseTensor, _rhsDenseTensor);
- }
+void check_gen_with_result(size_t l, size_t r, double wanted) {
+ EvalFixture::ParamRepo param_repo;
+ param_repo.add("a", makeTensor(l, leftBias));
+ param_repo.add("b", makeTensor(r, rightBias));
+ vespalib::string expr = "reduce(a*b,sum,x)";
+ EvalFixture evaluator(prod_engine, expr, param_repo, true);
+ EXPECT_EQUAL(spec(wanted), evaluator.result());
+ EXPECT_EQUAL(evaluator.result(), EvalFixture::ref(expr, param_repo));
+ auto info = evaluator.find_all<DenseDotProductFunction>();
+ EXPECT_EQUAL(info.size(), 1u);
};
-struct Fixture
-{
- FunctionInput input;
- tensor_function::Inject a;
- tensor_function::Inject b;
- DenseDotProductFunction function;
- Fixture(size_t lhsNumCells, size_t rhsNumCells);
- ~Fixture();
- double eval() const {
- InterpretedFunction ifun(DefaultTensorEngine::ref(), function);
- InterpretedFunction::Context ictx(ifun);
- const Value &result = ifun.eval(ictx, input.get());
- ASSERT_TRUE(result.is_double());
- LOG(info, "eval(): (%s) * (%s) = %f",
- input.param(0).type().to_spec().c_str(),
- input.param(1).type().to_spec().c_str(),
- result.as_double());
- return result.as_double();
- }
-};
-
-Fixture::Fixture(size_t lhsNumCells, size_t rhsNumCells)
- : input(lhsNumCells, rhsNumCells),
- a(input.param(0).type(), 0),
- b(input.param(1).type(), 1),
- function(a, b)
-{ }
-
-Fixture::~Fixture() { }
-
-void
-assertDotProduct(size_t numCells)
-{
- Fixture f(numCells, numCells);
- EXPECT_EQUAL(f.input.expectedDotProduct(), f.eval());
-}
+// this should not be possible to set up:
+// TEST("require that empty dot product is correct")
-void
-assertDotProduct(size_t lhsNumCells, size_t rhsNumCells)
-{
- Fixture f(lhsNumCells, rhsNumCells);
- EXPECT_EQUAL(f.input.expectedDotProduct(), f.eval());
+TEST("require that basic dot product with equal sizes is correct") {
+ check_gen_with_result(2, 2, (3.0 * 5.0) + (4.0 * 6.0));
}
-TEST_F("require that empty dot product is correct", Fixture(0, 0))
-{
- EXPECT_EQUAL(0.0, f.eval());
+TEST("require that basic dot product with un-equal sizes is correct") {
+ check_gen_with_result(2, 3, (3.0 * 5.0) + (4.0 * 6.0));
+ check_gen_with_result(3, 2, (3.0 * 5.0) + (4.0 * 6.0));
}
-TEST_F("require that basic dot product with equal sizes is correct", Fixture(2, 2))
-{
- EXPECT_EQUAL((3.0 * 5.0) + (4.0 * 6.0), f.eval());
-}
+//-----------------------------------------------------------------------------
-TEST_F("require that basic dot product with un-equal sizes is correct", Fixture(2, 3))
-{
- EXPECT_EQUAL((3.0 * 5.0) + (4.0 * 6.0), f.eval());
+void assertDotProduct(size_t numCells) {
+ check_gen_with_result(numCells, numCells, calcDotProduct(numCells));
}
-TEST_F("require that basic dot product with un-equal sizes is correct", Fixture(3, 2))
-{
- EXPECT_EQUAL((3.0 * 5.0) + (4.0 * 6.0), f.eval());
+void assertDotProduct(size_t lhsNumCells, size_t rhsNumCells) {
+ size_t numCells = std::min(lhsNumCells, rhsNumCells);
+ check_gen_with_result(lhsNumCells, rhsNumCells, calcDotProduct(numCells));
}
-TEST("require that dot product with equal sizes is correct")
-{
+TEST("require that dot product with equal sizes is correct") {
TEST_DO(assertDotProduct(8));
TEST_DO(assertDotProduct(16));
TEST_DO(assertDotProduct(32));
@@ -156,9 +98,9 @@ TEST("require that dot product with equal sizes is correct")
TEST_DO(assertDotProduct(1024 + 3));
}
-TEST("require that dot product with un-equal sizes is correct")
-{
+TEST("require that dot product with un-equal sizes is correct") {
TEST_DO(assertDotProduct(8, 8 + 3));
+ TEST_DO(assertDotProduct(8 + 3, 8));
TEST_DO(assertDotProduct(16, 16 + 3));
TEST_DO(assertDotProduct(32, 32 + 3));
TEST_DO(assertDotProduct(64, 64 + 3));
@@ -168,4 +110,78 @@ TEST("require that dot product with un-equal sizes is correct")
TEST_DO(assertDotProduct(1024, 1024 + 3));
}
+//-----------------------------------------------------------------------------
+
+EvalFixture::ParamRepo make_params() {
+ return EvalFixture::ParamRepo()
+ .add("v01_x1", spec({x(1)}, MyVecSeq(2.0)))
+ .add("v02_x3", spec({x(3)}, MyVecSeq(4.0)))
+ .add("v03_x3", spec({x(3)}, MyVecSeq(5.0)))
+ .add("v04_y3", spec({y(3)}, MyVecSeq(10)))
+ .add("v05_x5", spec({x(5)}, MyVecSeq(6.0)))
+ .add("v06_x5", spec({x(5)}, MyVecSeq(7.0)))
+ .add("v07_x3_a", spec({x(3)}, MyVecSeq(8.0)), "any")
+ .add("v08_x3_u", spec({x(3)}, MyVecSeq(9.0)), "tensor(x[])")
+ .add("v09_x4_u", spec({x(4)}, MyVecSeq(3.0)), "tensor(x[])")
+ .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(0)));
+}
+EvalFixture::ParamRepo param_repo = make_params();
+
+void assertOptimized(const vespalib::string &expr) {
+ EvalFixture fixture(prod_engine, expr, param_repo, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ auto info = fixture.find_all<DenseDotProductFunction>();
+ EXPECT_EQUAL(info.size(), 1u);
+}
+
+void assertNotOptimized(const vespalib::string &expr) {
+ EvalFixture fixture(prod_engine, expr, param_repo, true);
+ EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
+ auto info = fixture.find_all<DenseDotProductFunction>();
+ EXPECT_TRUE(info.empty());
+}
+
+TEST("require that dot product is not optimized for unknown types") {
+ TEST_DO(assertNotOptimized("reduce(v02_x3*v07_x3_a,sum)"));
+ TEST_DO(assertNotOptimized("reduce(v07_x3_a*v03_x3,sum)"));
+}
+
+TEST("require that dot product works with tensor function") {
+ TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)"));
+ TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum,x)"));
+ TEST_DO(assertOptimized("reduce(join(v05_x5,v06_x5,f(x,y)(x*y)),sum)"));
+ TEST_DO(assertOptimized("reduce(join(v05_x5,v06_x5,f(x,y)(x*y)),sum,x)"));
+}
+
+TEST("require that dot product with compatible dimensions is optimized") {
+ TEST_DO(assertOptimized("reduce(v01_x1*v01_x1,sum)"));
+ TEST_DO(assertOptimized("reduce(v02_x3*v03_x3,sum)"));
+ TEST_DO(assertOptimized("reduce(v05_x5*v06_x5,sum)"));
+
+ TEST_DO(assertOptimized("reduce(v02_x3*v06_x5,sum)"));
+ TEST_DO(assertOptimized("reduce(v05_x5*v03_x3,sum)"));
+ TEST_DO(assertOptimized("reduce(v08_x3_u*v05_x5,sum)"));
+ TEST_DO(assertOptimized("reduce(v05_x5*v08_x3_u,sum)"));
+}
+
+TEST("require that dot product with incompatible dimensions is NOT optimized") {
+ TEST_DO(assertNotOptimized("reduce(v02_x3*v04_y3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(v04_y3*v02_x3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(v08_x3_u*v04_y3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(v04_y3*v08_x3_u,sum)"));
+ TEST_DO(assertNotOptimized("reduce(v02_x3*m01_x3y3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(m01_x3y3*v02_x3,sum)"));
+}
+
+TEST("require that expressions similar to dot product are not optimized") {
+ TEST_DO(assertNotOptimized("reduce(v02_x3*v03_x3,prod)"));
+ TEST_DO(assertNotOptimized("reduce(v02_x3+v03_x3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(x+y)),sum)"));
+ TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(x*x)),sum)"));
+ TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*y)),sum)"));
+ // TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*x)),sum)"));
+}
+
+//-----------------------------------------------------------------------------
+
TEST_MAIN() { TEST_RUN_ALL(); }