diff options
author | Håvard Pettersen <havardpe@oath.com> | 2018-03-23 15:12:54 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2018-03-23 15:12:54 +0000 |
commit | d43c441945a032c0053d2736666e2c418c1f4b1f (patch) | |
tree | c3fc291cbc08dabbd8827f006cccd5d0ce2b132c /eval | |
parent | 686694f6a6788c9ebe25f3278c253b0fe015d331 (diff) |
allow multi-dimensional dot product optimization
Diffstat (limited to 'eval')
3 files changed, 85 insertions, 16 deletions
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp index 37f9602565d..60830e4abd7 100644 --- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp +++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp @@ -123,7 +123,11 @@ EvalFixture::ParamRepo make_params() { .add("v07_x3_a", spec({x(3)}, MyVecSeq(8.0)), "any") .add("v08_x3_u", spec({x(3)}, MyVecSeq(9.0)), "tensor(x[])") .add("v09_x4_u", spec({x(4)}, MyVecSeq(3.0)), "tensor(x[])") - .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(0))); + .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0))) + .add("m02_x2y3", spec({x(2),y(3)}, MyVecSeq(2.0))) + .add("m03_x3y2", spec({x(3),y(2)}, MyVecSeq(3.0))) + .add("m04_xuy3", spec({x(3),y(3)}, MyVecSeq(4.0)), "tensor(x[],y[3])") + .add("m05_x3yu", spec({x(3),y(3)}, MyVecSeq(5.0)), "tensor(x[3],y[])"); } EvalFixture::ParamRepo param_repo = make_params(); @@ -183,6 +187,59 @@ TEST("require that expressions similar to dot product are not optimized") { // TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*x)),sum)")); } +TEST("require that multi-dimensional dot product can be optimized") { + TEST_DO(assertOptimized("reduce(m01_x3y3*m02_x2y3,sum)")); + TEST_DO(assertOptimized("reduce(m02_x2y3*m01_x3y3,sum)")); + TEST_DO(assertOptimized("reduce(m01_x3y3*m04_xuy3,sum)")); + TEST_DO(assertOptimized("reduce(m04_xuy3*m01_x3y3,sum)")); + TEST_DO(assertOptimized("reduce(m04_xuy3*m04_xuy3,sum)")); +} + +TEST("require that result must be double to trigger optimization") { + TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum,x,y)")); + TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,x)")); + TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,y)")); +} + +TEST("require that additional dimensions must have matching size") { + TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum)")); + TEST_DO(assertNotOptimized("reduce(m01_x3y3*m03_x3y2,sum)")); + TEST_DO(assertNotOptimized("reduce(m03_x3y2*m01_x3y3,sum)")); + TEST_DO(assertNotOptimized("reduce(m01_x3y3*m05_x3yu,sum)")); + TEST_DO(assertNotOptimized("reduce(m05_x3yu*m01_x3y3,sum)")); +} + +void verify_compatible(const vespalib::string &a, const vespalib::string &b) { + auto a_type = ValueType::from_spec(a); + auto b_type = ValueType::from_spec(b); + EXPECT_TRUE(!a_type.is_error()); + EXPECT_TRUE(!b_type.is_error()); + EXPECT_TRUE(DenseDotProductFunction::compatible_types(ValueType::double_type(), a_type, b_type)); + EXPECT_TRUE(DenseDotProductFunction::compatible_types(ValueType::double_type(), b_type, a_type)); +} + +void verify_not_compatible(const vespalib::string &a, const vespalib::string &b) { + auto a_type = ValueType::from_spec(a); + auto b_type = ValueType::from_spec(b); + EXPECT_TRUE(!a_type.is_error()); + EXPECT_TRUE(!b_type.is_error()); + EXPECT_TRUE(!DenseDotProductFunction::compatible_types(ValueType::double_type(), a_type, b_type)); + EXPECT_TRUE(!DenseDotProductFunction::compatible_types(ValueType::double_type(), b_type, a_type)); +} + +TEST("require that type compatibility test is appropriate") { + TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[5])")); + TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(y[5])")); + TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[3])")); + TEST_DO(verify_compatible("tensor(x[])", "tensor(x[3])")); + TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[5],y[7],z[9])")); + TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[],y[7],z[9])")); + TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[5],z[9])")); + TEST_DO(verify_not_compatible("tensor(x[5],y[],z[9])", "tensor(x[5],y[7],z[9])")); + TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[7],z[5])")); + TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[])", "tensor(x[5],y[7],z[9])")); +} + //----------------------------------------------------------------------------- TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp index ae217935fd9..859a7092ce2 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp @@ -33,17 +33,6 @@ void my_dot_product_op(eval::InterpretedFunction::State &state, uint64_t param) state.pop_pop_push(state.stash.create<eval::DoubleValue>(result)); } -bool is1dDenseTensor(const ValueType &type) { - return (type.is_dense() && (type.dimensions().size() == 1)); -} - -bool isDenseDotProduct(const ValueType &res, const ValueType &lhsType, const ValueType &rhsType) { - return (res.is_double() && - is1dDenseTensor(lhsType) && - is1dDenseTensor(rhsType) && - (lhsType.dimensions()[0].name == rhsType.dimensions()[0].name)); -} - } // namespace vespalib::tensor::<unnamed> DenseDotProductFunction::DenseDotProductFunction(const eval::TensorFunction &lhs_in, @@ -59,16 +48,38 @@ DenseDotProductFunction::compile_self(Stash &) const return eval::InterpretedFunction::Instruction(my_dot_product_op, (uint64_t)(_hwAccelerator.get())); } +bool +DenseDotProductFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) +{ + if (!res.is_double() || !lhs.is_dense() || !rhs.is_dense() || + (lhs.dimensions().size() != rhs.dimensions().size()) || + (lhs.dimensions().empty())) + { + return false; + } + for (size_t i = 0; i < lhs.dimensions().size(); ++i) { + const auto &ldim = lhs.dimensions()[i]; + const auto &rdim = rhs.dimensions()[i]; + bool first = (i == 0); + bool name_mismatch = (ldim.name != rdim.name); + bool size_mismatch = ((ldim.size != rdim.size) || !ldim.is_bound()); + if (name_mismatch || (!first && size_mismatch)) { + return false; + } + } + return true; +} + const TensorFunction & DenseDotProductFunction::optimize(const eval::TensorFunction &expr, Stash &stash) { - const Reduce *reduce = as<Reduce>(expr); + auto reduce = as<Reduce>(expr); if (reduce && (reduce->aggr() == Aggr::SUM)) { - const Join *join = as<Join>(reduce->child()); + auto join = as<Join>(reduce->child()); if (join && (join->function() == Mul::f)) { const TensorFunction &lhs = join->lhs(); const TensorFunction &rhs = join->rhs(); - if (isDenseDotProduct(expr.result_type(), lhs.result_type(), rhs.result_type())) { + if (compatible_types(expr.result_type(), lhs.result_type(), rhs.result_type())) { return stash.create<DenseDotProductFunction>(lhs, rhs); } } diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h index 46b04a446d4..d6181d33887 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h +++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h @@ -14,12 +14,13 @@ class DenseDotProductFunction : public eval::tensor_function::Op2 { private: hwaccelrated::IAccelrated::UP _hwAccelerator; - + using ValueType = eval::ValueType; public: DenseDotProductFunction(const eval::TensorFunction &lhs_in, const eval::TensorFunction &rhs_in); eval::InterpretedFunction::Instruction compile_self(Stash &stash) const override; bool result_is_mutable() const override { return true; } + static bool compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs); static const eval::TensorFunction &optimize(const eval::TensorFunction &expr, Stash &stash); }; |