author	Håvard Pettersen <havardpe@oath.com>	2018-03-23 15:12:54 +0000
committer	Håvard Pettersen <havardpe@oath.com>	2018-03-23 15:12:54 +0000
commit	d43c441945a032c0053d2736666e2c418c1f4b1f (patch)
tree	c3fc291cbc08dabbd8827f006cccd5d0ce2b132c /eval
parent	686694f6a6788c9ebe25f3278c253b0fe015d331 (diff)
allow multi-dimensional dot product optimization
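The optimization applies when the element-wise product of two dense tensors is sum-reduced over all dimensions. As a minimal sketch of why this is valid (illustrative only, not code from this commit; `flat_dot` is a hypothetical helper): dense tensor cells are stored as one contiguous array, so when both operands name the same dimensions in the same order and all dimensions after the first have matching bound sizes, reducing their product over every dimension is the same as a flat dot product over the raw cell arrays, clipped to the shorter array when the first dimension differs in size.

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Illustrative only: reduce(join(a,b,f(x,y)(x*y)),sum) over dense
    // tensors with compatible types boils down to this loop over the
    // flattened cell arrays, regardless of the number of dimensions.
    double flat_dot(const std::vector<double> &a, const std::vector<double> &b) {
        size_t n = std::min(a.size(), b.size());
        double sum = 0.0;
        for (size_t i = 0; i < n; ++i) {
            sum += a[i] * b[i];
        }
        return sum;
    }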
Diffstat (limited to 'eval')
-rw-r--r--  eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp | 59
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp                      | 39
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h                       |  3
3 files changed, 85 insertions(+), 16 deletions(-)
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index 37f9602565d..60830e4abd7 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -123,7 +123,11 @@ EvalFixture::ParamRepo make_params() {
.add("v07_x3_a", spec({x(3)}, MyVecSeq(8.0)), "any")
.add("v08_x3_u", spec({x(3)}, MyVecSeq(9.0)), "tensor(x[])")
.add("v09_x4_u", spec({x(4)}, MyVecSeq(3.0)), "tensor(x[])")
- .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(0)));
+ .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0)))
+ .add("m02_x2y3", spec({x(2),y(3)}, MyVecSeq(2.0)))
+ .add("m03_x3y2", spec({x(3),y(2)}, MyVecSeq(3.0)))
+ .add("m04_xuy3", spec({x(3),y(3)}, MyVecSeq(4.0)), "tensor(x[],y[3])")
+ .add("m05_x3yu", spec({x(3),y(3)}, MyVecSeq(5.0)), "tensor(x[3],y[])");
}
EvalFixture::ParamRepo param_repo = make_params();
@@ -183,6 +187,59 @@ TEST("require that expressions similar to dot product are not optimized") {
// TEST_DO(assertNotOptimized("reduce(join(v02_x3,v03_x3,f(x,y)(y*x)),sum)"));
}
+TEST("require that multi-dimensional dot product can be optimized") {
+ TEST_DO(assertOptimized("reduce(m01_x3y3*m02_x2y3,sum)"));
+ TEST_DO(assertOptimized("reduce(m02_x2y3*m01_x3y3,sum)"));
+ TEST_DO(assertOptimized("reduce(m01_x3y3*m04_xuy3,sum)"));
+ TEST_DO(assertOptimized("reduce(m04_xuy3*m01_x3y3,sum)"));
+ TEST_DO(assertOptimized("reduce(m04_xuy3*m04_xuy3,sum)"));
+}
+
+TEST("require that result must be double to trigger optimization") {
+ TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum,x,y)"));
+ TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,x)"));
+ TEST_DO(assertNotOptimized("reduce(m01_x3y3*m01_x3y3,sum,y)"));
+}
+
+TEST("require that additional dimensions must have matching size") {
+ TEST_DO(assertOptimized("reduce(m01_x3y3*m01_x3y3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(m01_x3y3*m03_x3y2,sum)"));
+ TEST_DO(assertNotOptimized("reduce(m03_x3y2*m01_x3y3,sum)"));
+ TEST_DO(assertNotOptimized("reduce(m01_x3y3*m05_x3yu,sum)"));
+ TEST_DO(assertNotOptimized("reduce(m05_x3yu*m01_x3y3,sum)"));
+}
+
+void verify_compatible(const vespalib::string &a, const vespalib::string &b) {
+ auto a_type = ValueType::from_spec(a);
+ auto b_type = ValueType::from_spec(b);
+ EXPECT_TRUE(!a_type.is_error());
+ EXPECT_TRUE(!b_type.is_error());
+ EXPECT_TRUE(DenseDotProductFunction::compatible_types(ValueType::double_type(), a_type, b_type));
+ EXPECT_TRUE(DenseDotProductFunction::compatible_types(ValueType::double_type(), b_type, a_type));
+}
+
+void verify_not_compatible(const vespalib::string &a, const vespalib::string &b) {
+ auto a_type = ValueType::from_spec(a);
+ auto b_type = ValueType::from_spec(b);
+ EXPECT_TRUE(!a_type.is_error());
+ EXPECT_TRUE(!b_type.is_error());
+ EXPECT_TRUE(!DenseDotProductFunction::compatible_types(ValueType::double_type(), a_type, b_type));
+ EXPECT_TRUE(!DenseDotProductFunction::compatible_types(ValueType::double_type(), b_type, a_type));
+}
+
+TEST("require that type compatibility test is appropriate") {
+ TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[5])"));
+ TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(y[5])"));
+ TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[3])"));
+ TEST_DO(verify_compatible("tensor(x[])", "tensor(x[3])"));
+ TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[5],y[7],z[9])"));
+ TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[],y[7],z[9])"));
+ TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[5],z[9])"));
+ TEST_DO(verify_not_compatible("tensor(x[5],y[],z[9])", "tensor(x[5],y[7],z[9])"));
+ TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[7],z[5])"));
+ TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[])", "tensor(x[5],y[7],z[9])"));
+}
+
//-----------------------------------------------------------------------------
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
index ae217935fd9..859a7092ce2 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
@@ -33,17 +33,6 @@ void my_dot_product_op(eval::InterpretedFunction::State &state, uint64_t param)
state.pop_pop_push(state.stash.create<eval::DoubleValue>(result));
}
-bool is1dDenseTensor(const ValueType &type) {
- return (type.is_dense() && (type.dimensions().size() == 1));
-}
-
-bool isDenseDotProduct(const ValueType &res, const ValueType &lhsType, const ValueType &rhsType) {
- return (res.is_double() &&
- is1dDenseTensor(lhsType) &&
- is1dDenseTensor(rhsType) &&
- (lhsType.dimensions()[0].name == rhsType.dimensions()[0].name));
-}
-
} // namespace vespalib::tensor::<unnamed>
DenseDotProductFunction::DenseDotProductFunction(const eval::TensorFunction &lhs_in,
@@ -59,16 +48,38 @@ DenseDotProductFunction::compile_self(Stash &) const
return eval::InterpretedFunction::Instruction(my_dot_product_op, (uint64_t)(_hwAccelerator.get()));
}
+bool
+DenseDotProductFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs)
+{
+ if (!res.is_double() || !lhs.is_dense() || !rhs.is_dense() ||
+ (lhs.dimensions().size() != rhs.dimensions().size()) ||
+ (lhs.dimensions().empty()))
+ {
+ return false;
+ }
+ for (size_t i = 0; i < lhs.dimensions().size(); ++i) {
+ const auto &ldim = lhs.dimensions()[i];
+ const auto &rdim = rhs.dimensions()[i];
+ bool first = (i == 0);
+ bool name_mismatch = (ldim.name != rdim.name);
+ bool size_mismatch = ((ldim.size != rdim.size) || !ldim.is_bound());
+ if (name_mismatch || (!first && size_mismatch)) {
+ return false;
+ }
+ }
+ return true;
+}
+
const TensorFunction &
DenseDotProductFunction::optimize(const eval::TensorFunction &expr, Stash &stash)
{
- const Reduce *reduce = as<Reduce>(expr);
+ auto reduce = as<Reduce>(expr);
if (reduce && (reduce->aggr() == Aggr::SUM)) {
- const Join *join = as<Join>(reduce->child());
+ auto join = as<Join>(reduce->child());
if (join && (join->function() == Mul::f)) {
const TensorFunction &lhs = join->lhs();
const TensorFunction &rhs = join->rhs();
- if (isDenseDotProduct(expr.result_type(), lhs.result_type(), rhs.result_type())) {
+ if (compatible_types(expr.result_type(), lhs.result_type(), rhs.result_type())) {
return stash.create<DenseDotProductFunction>(lhs, rhs);
}
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h
index 46b04a446d4..d6181d33887 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h
@@ -14,12 +14,13 @@ class DenseDotProductFunction : public eval::tensor_function::Op2
{
private:
hwaccelrated::IAccelrated::UP _hwAccelerator;
-
+ using ValueType = eval::ValueType;
public:
DenseDotProductFunction(const eval::TensorFunction &lhs_in,
const eval::TensorFunction &rhs_in);
eval::InterpretedFunction::Instruction compile_self(Stash &stash) const override;
bool result_is_mutable() const override { return true; }
+ static bool compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs);
static const eval::TensorFunction &optimize(const eval::TensorFunction &expr, Stash &stash);
};
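
A note on the layout (not part of the commit) explaining why trailing dimensions must match while the first may differ: dense cells are stored in row-major order, so a size mismatch in any dimension after the first shifts every subsequent cell and misaligns the two flattened arrays, while a mismatch in the first dimension only changes where the shorter array ends. A small illustrative sketch:

    #include <cstddef>

    // Illustrative row-major flat index for a cell of tensor(x[X],y[Y]):
    size_t flat_index(size_t x, size_t y, size_t Y) { return (x * Y) + y; }

    // flat_index(1, 0, /*Y=*/3) == 3 but flat_index(1, 0, /*Y=*/2) == 2:
    // a differing size in the trailing dimension y shifts all cells from
    // the second row on, so corresponding (x,y) addresses no longer share
    // a flat index. A differing size in the leading dimension x leaves
    // shared indices aligned and only truncates the longer array, which
    // is why compatible_types() exempts the first dimension from the
    // size check.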