diff options
author | Håvard Pettersen <havardpe@oath.com> | 2019-07-11 10:58:52 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2019-07-11 10:58:52 +0000 |
commit | e8e0d99e3146881f4e5c328aacf7a24fb9140101 (patch) | |
tree | ba88d3a54bdaec76e5dd7890f59b76cac8d31625 | |
parent | 5a0acd6e0a6aa36e26a5142308a0c85fc20a6b0a (diff) |
enable hw dot product for float cells
3 files changed, 64 insertions, 20 deletions
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp index 356625417d8..9bf97f449b3 100644 --- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp +++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp @@ -103,6 +103,7 @@ EvalFixture::ParamRepo make_params() { .add("v05_x5", spec({x(5)}, MyVecSeq(6.0))) .add("v06_x5", spec({x(5)}, MyVecSeq(7.0))) .add("v07_x5f", spec(float_cells({x(5)}), MyVecSeq(7.0))) + .add("v08_x5f", spec(float_cells({x(5)}), MyVecSeq(6.0))) .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0))) .add("m02_x3y3", spec({x(3),y(3)}, MyVecSeq(2.0))); } @@ -183,8 +184,9 @@ void verify_not_compatible(const vespalib::string &a, const vespalib::string &b) TEST("require that type compatibility test is appropriate") { TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[5])")); - TEST_DO(verify_not_compatible("tensor(x[5])", "tensor<float>(x[5])")); - TEST_DO(verify_not_compatible("tensor<float>(x[5])", "tensor<float>(x[5])")); + TEST_DO(verify_compatible("tensor(x[5])", "tensor<float>(x[5])")); + TEST_DO(verify_compatible("tensor<float>(x[5])", "tensor(x[5])")); + TEST_DO(verify_compatible("tensor<float>(x[5])", "tensor<float>(x[5])")); TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(x[6])")); TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(y[5])")); TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[3],y[7],z[9])")); @@ -192,8 +194,10 @@ TEST("require that type compatibility test is appropriate") { TEST_DO(verify_not_compatible("tensor(x[9],y[7],z[5])", "tensor(x[5],y[7],z[9])")); } -TEST("require that optimization is disabled for tensors with non-double cells") { - TEST_DO(assertNotOptimized("reduce(v05_x5*v07_x5f,sum)")); +TEST("require that optimization also works for tensors with non-double cells") { + TEST_DO(assertOptimized("reduce(v05_x5*v07_x5f,sum)")); + TEST_DO(assertOptimized("reduce(v07_x5f*v05_x5,sum)")); + TEST_DO(assertOptimized("reduce(v07_x5f*v08_x5f,sum)")); } //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp index c925f288c4a..9b839e1b12f 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp @@ -18,22 +18,47 @@ using namespace eval::operation; namespace { -TypedCells getCellsRef(const eval::Value &value) { +template <typename T> +ConstArrayRef<T> getCellsRef(const eval::Value &value) { const DenseTensorView &denseTensor = static_cast<const DenseTensorView &>(value); - return denseTensor.cellsRef(); + return denseTensor.cellsRef().typify<T>(); } +template <typename LCT, typename RCT> +struct HWSupport { + static double call(hwaccelrated::IAccelrated *, const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs) { + double result = 0.0; + for (size_t i = 0; i < lhs.size(); ++i) { + result += (lhs[i] * rhs[i]); + } + return result; + } +}; +template <> struct HWSupport<float, float> { + static double call(hwaccelrated::IAccelrated *hw, const ConstArrayRef<float> &lhs, const ConstArrayRef<float> &rhs) { + return hw->dotProduct(lhs.cbegin(), rhs.cbegin(), lhs.size()); + } +}; +template <> struct HWSupport<double, double> { + static double call(hwaccelrated::IAccelrated *hw, const ConstArrayRef<double> &lhs, const ConstArrayRef<double> &rhs) { + return hw->dotProduct(lhs.cbegin(), rhs.cbegin(), lhs.size()); + } +}; + +template <typename LCT, typename RCT> void my_dot_product_op(eval::InterpretedFunction::State &state, uint64_t param) { - auto *hw_accelerator = (hwaccelrated::IAccelrated *)(param); - TypedCells lhsCells = getCellsRef(state.peek(1)); - TypedCells rhsCells = getCellsRef(state.peek(0)); - size_t numCells = std::min(lhsCells.size, rhsCells.size); - const ConstArrayRef<double> lhs = lhsCells.typify<double>(); - const ConstArrayRef<double> rhs = rhsCells.typify<double>(); - double result = hw_accelerator->dotProduct(lhs.cbegin(), rhs.cbegin(), numCells); + auto *hw = (hwaccelrated::IAccelrated *)(param); + auto lhs = getCellsRef<LCT>(state.peek(1)); + auto rhs = getCellsRef<RCT>(state.peek(0)); + double result = HWSupport<LCT,RCT>::call(hw, lhs, rhs); state.pop_pop_push(state.stash.create<eval::DoubleValue>(result)); } +struct MyDotProductOp { + template <typename LCT, typename RCT> + static auto get_fun() { return my_dot_product_op<LCT,RCT>; } +}; + } // namespace vespalib::tensor::<unnamed> DenseDotProductFunction::DenseDotProductFunction(const eval::TensorFunction &lhs_in, @@ -46,18 +71,15 @@ DenseDotProductFunction::DenseDotProductFunction(const eval::TensorFunction &lhs eval::InterpretedFunction::Instruction DenseDotProductFunction::compile_self(Stash &) const { - return eval::InterpretedFunction::Instruction(my_dot_product_op, (uint64_t)(_hwAccelerator.get())); + auto op = select_2<MyDotProductOp>(lhs().result_type().cell_type(), + rhs().result_type().cell_type()); + return eval::InterpretedFunction::Instruction(op, (uint64_t)(_hwAccelerator.get())); } bool DenseDotProductFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) { - if (lhs.cell_type() != ValueType::CellType::DOUBLE || - rhs.cell_type() != ValueType::CellType::DOUBLE) - { - return false; // non-double cell types not supported - } - return (res.is_double() && lhs.is_dense() && (rhs == lhs)); + return (res.is_double() && lhs.is_dense() && (rhs.dimensions() == lhs.dimensions())); } const TensorFunction & diff --git a/eval/src/vespa/eval/tensor/dense/typed_cells.h b/eval/src/vespa/eval/tensor/dense/typed_cells.h index d1b6058bfbe..98f95d54d9b 100644 --- a/eval/src/vespa/eval/tensor/dense/typed_cells.h +++ b/eval/src/vespa/eval/tensor/dense/typed_cells.h @@ -93,4 +93,22 @@ auto dispatch_2(A1 &&a, const TypedCells &b, Args &&...args) { abort(); } +template <typename T, typename... Args> +auto select_1(CellType a_type) { + switch(a_type) { + case CellType::DOUBLE: return T::template get_fun<double, Args...>(); + case CellType::FLOAT: return T::template get_fun<float, Args...>(); + } + abort(); +} + +template <typename T> +auto select_2(CellType a_type, CellType b_type) { + switch(b_type) { + case CellType::DOUBLE: return select_1<T, double>(a_type); + case CellType::FLOAT: return select_1<T, float>(a_type); + } + abort(); +} + } // namespace |