diff options
author | Håvard Pettersen <havardpe@oath.com> | 2018-01-31 11:28:40 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2018-01-31 11:28:40 +0000 |
commit | 48c161deff1b8d1cb0a546e3516e8ef0a706f4bf (patch) | |
tree | 68f43f55726dbda5295b3ed43d93f33afe13d569 /eval | |
parent | 4d175c3c37d6ffada13dd15023d575f8e663351e (diff) |
allow any intermediate result as input for inner products
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp | 24 |
1 files changed, 11 insertions, 13 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp index c5ebe837151..48eefa3eed4 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp @@ -46,7 +46,7 @@ bool isDenseXWProduct(const ValueType &res, const ValueType &vec, const ValueTyp return false; } -const TensorFunction &createDenseXWProduct(const ValueType &res, const Inject &vec, const Inject &mat, Stash &stash) { +const TensorFunction &createDenseXWProduct(const ValueType &res, const TensorFunction &vec, const TensorFunction &mat, Stash &stash) { bool common_is_inner = (mat.result_type().dimension_index(vec.result_type().dimensions()[0].name) == 1); return stash.create<DenseXWProductFunction>(res, vec, mat, vec.result_type().dimensions()[0].size, @@ -62,18 +62,16 @@ struct InnerProductFunctionOptimizer const ValueType &result_type = reduce->result_type(); const Join *join = as<Join>(reduce->child()); if (join && (join->function() == Mul::f)) { - const Inject *lhs = as<Inject>(join->lhs()); - const Inject *rhs = as<Inject>(join->rhs()); - if (lhs && rhs) { - if (isDenseDotProduct(result_type, lhs->result_type(), rhs->result_type())) { - return stash.create<DenseDotProductFunction>(*lhs, *rhs); - } - if (isDenseXWProduct(result_type, lhs->result_type(), rhs->result_type())) { - return createDenseXWProduct(result_type, *lhs, *rhs, stash); - } - if (isDenseXWProduct(result_type, rhs->result_type(), lhs->result_type())) { - return createDenseXWProduct(result_type, *rhs, *lhs, stash); - } + const TensorFunction &lhs = join->lhs(); + const TensorFunction &rhs = join->rhs(); + if (isDenseDotProduct(result_type, lhs.result_type(), rhs.result_type())) { + return stash.create<DenseDotProductFunction>(lhs, rhs); + } + if (isDenseXWProduct(result_type, lhs.result_type(), rhs.result_type())) { + return createDenseXWProduct(result_type, lhs, rhs, stash); + } + if (isDenseXWProduct(result_type, rhs.result_type(), lhs.result_type())) { + return createDenseXWProduct(result_type, rhs, lhs, stash); } } } |