summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2019-07-11 10:58:52 +0000
committerHåvard Pettersen <havardpe@oath.com>2019-07-11 10:58:52 +0000
commite8e0d99e3146881f4e5c328aacf7a24fb9140101 (patch)
treeba88d3a54bdaec76e5dd7890f59b76cac8d31625 /eval
parent5a0acd6e0a6aa36e26a5142308a0c85fc20a6b0a (diff)
enable hw dot product for float cells
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp12
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp54
-rw-r--r--eval/src/vespa/eval/tensor/dense/typed_cells.h18
3 files changed, 64 insertions, 20 deletions
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index 356625417d8..9bf97f449b3 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -103,6 +103,7 @@ EvalFixture::ParamRepo make_params() {
.add("v05_x5", spec({x(5)}, MyVecSeq(6.0)))
.add("v06_x5", spec({x(5)}, MyVecSeq(7.0)))
.add("v07_x5f", spec(float_cells({x(5)}), MyVecSeq(7.0)))
+ .add("v08_x5f", spec(float_cells({x(5)}), MyVecSeq(6.0)))
.add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0)))
.add("m02_x3y3", spec({x(3),y(3)}, MyVecSeq(2.0)));
}
@@ -183,8 +184,9 @@ void verify_not_compatible(const vespalib::string &a, const vespalib::string &b)
TEST("require that type compatibility test is appropriate") {
TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[5])"));
- TEST_DO(verify_not_compatible("tensor(x[5])", "tensor<float>(x[5])"));
- TEST_DO(verify_not_compatible("tensor<float>(x[5])", "tensor<float>(x[5])"));
+ TEST_DO(verify_compatible("tensor(x[5])", "tensor<float>(x[5])"));
+ TEST_DO(verify_compatible("tensor<float>(x[5])", "tensor(x[5])"));
+ TEST_DO(verify_compatible("tensor<float>(x[5])", "tensor<float>(x[5])"));
TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(x[6])"));
TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(y[5])"));
TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[3],y[7],z[9])"));
@@ -192,8 +194,10 @@ TEST("require that type compatibility test is appropriate") {
TEST_DO(verify_not_compatible("tensor(x[9],y[7],z[5])", "tensor(x[5],y[7],z[9])"));
}
-TEST("require that optimization is disabled for tensors with non-double cells") {
- TEST_DO(assertNotOptimized("reduce(v05_x5*v07_x5f,sum)"));
+TEST("require that optimization also works for tensors with non-double cells") {
+ TEST_DO(assertOptimized("reduce(v05_x5*v07_x5f,sum)"));
+ TEST_DO(assertOptimized("reduce(v07_x5f*v05_x5,sum)"));
+ TEST_DO(assertOptimized("reduce(v07_x5f*v08_x5f,sum)"));
}
//-----------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
index c925f288c4a..9b839e1b12f 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
@@ -18,22 +18,47 @@ using namespace eval::operation;
namespace {
-TypedCells getCellsRef(const eval::Value &value) {
+template <typename T>
+ConstArrayRef<T> getCellsRef(const eval::Value &value) {
const DenseTensorView &denseTensor = static_cast<const DenseTensorView &>(value);
- return denseTensor.cellsRef();
+ return denseTensor.cellsRef().typify<T>();
}
+template <typename LCT, typename RCT>
+struct HWSupport {
+ static double call(hwaccelrated::IAccelrated *, const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs) {
+ double result = 0.0;
+ for (size_t i = 0; i < lhs.size(); ++i) {
+ result += (lhs[i] * rhs[i]);
+ }
+ return result;
+ }
+};
+template <> struct HWSupport<float, float> {
+ static double call(hwaccelrated::IAccelrated *hw, const ConstArrayRef<float> &lhs, const ConstArrayRef<float> &rhs) {
+ return hw->dotProduct(lhs.cbegin(), rhs.cbegin(), lhs.size());
+ }
+};
+template <> struct HWSupport<double, double> {
+ static double call(hwaccelrated::IAccelrated *hw, const ConstArrayRef<double> &lhs, const ConstArrayRef<double> &rhs) {
+ return hw->dotProduct(lhs.cbegin(), rhs.cbegin(), lhs.size());
+ }
+};
+
+template <typename LCT, typename RCT>
void my_dot_product_op(eval::InterpretedFunction::State &state, uint64_t param) {
- auto *hw_accelerator = (hwaccelrated::IAccelrated *)(param);
- TypedCells lhsCells = getCellsRef(state.peek(1));
- TypedCells rhsCells = getCellsRef(state.peek(0));
- size_t numCells = std::min(lhsCells.size, rhsCells.size);
- const ConstArrayRef<double> lhs = lhsCells.typify<double>();
- const ConstArrayRef<double> rhs = rhsCells.typify<double>();
- double result = hw_accelerator->dotProduct(lhs.cbegin(), rhs.cbegin(), numCells);
+ auto *hw = (hwaccelrated::IAccelrated *)(param);
+ auto lhs = getCellsRef<LCT>(state.peek(1));
+ auto rhs = getCellsRef<RCT>(state.peek(0));
+ double result = HWSupport<LCT,RCT>::call(hw, lhs, rhs);
state.pop_pop_push(state.stash.create<eval::DoubleValue>(result));
}
+struct MyDotProductOp {
+ template <typename LCT, typename RCT>
+ static auto get_fun() { return my_dot_product_op<LCT,RCT>; }
+};
+
} // namespace vespalib::tensor::<unnamed>
DenseDotProductFunction::DenseDotProductFunction(const eval::TensorFunction &lhs_in,
@@ -46,18 +71,15 @@ DenseDotProductFunction::DenseDotProductFunction(const eval::TensorFunction &lhs
eval::InterpretedFunction::Instruction
DenseDotProductFunction::compile_self(Stash &) const
{
- return eval::InterpretedFunction::Instruction(my_dot_product_op, (uint64_t)(_hwAccelerator.get()));
+ auto op = select_2<MyDotProductOp>(lhs().result_type().cell_type(),
+ rhs().result_type().cell_type());
+ return eval::InterpretedFunction::Instruction(op, (uint64_t)(_hwAccelerator.get()));
}
bool
DenseDotProductFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs)
{
- if (lhs.cell_type() != ValueType::CellType::DOUBLE ||
- rhs.cell_type() != ValueType::CellType::DOUBLE)
- {
- return false; // non-double cell types not supported
- }
- return (res.is_double() && lhs.is_dense() && (rhs == lhs));
+ return (res.is_double() && lhs.is_dense() && (rhs.dimensions() == lhs.dimensions()));
}
const TensorFunction &
diff --git a/eval/src/vespa/eval/tensor/dense/typed_cells.h b/eval/src/vespa/eval/tensor/dense/typed_cells.h
index d1b6058bfbe..98f95d54d9b 100644
--- a/eval/src/vespa/eval/tensor/dense/typed_cells.h
+++ b/eval/src/vespa/eval/tensor/dense/typed_cells.h
@@ -93,4 +93,22 @@ auto dispatch_2(A1 &&a, const TypedCells &b, Args &&...args) {
abort();
}
+template <typename T, typename... Args>
+auto select_1(CellType a_type) {
+ switch(a_type) {
+ case CellType::DOUBLE: return T::template get_fun<double, Args...>();
+ case CellType::FLOAT: return T::template get_fun<float, Args...>();
+ }
+ abort();
+}
+
+template <typename T>
+auto select_2(CellType a_type, CellType b_type) {
+ switch(b_type) {
+ case CellType::DOUBLE: return select_1<T, double>(a_type);
+ case CellType::FLOAT: return select_1<T, float>(a_type);
+ }
+ abort();
+}
+
} // namespace