summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-01-31 11:28:40 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-01-31 11:28:40 +0000
commit48c161deff1b8d1cb0a546e3516e8ef0a706f4bf (patch)
tree68f43f55726dbda5295b3ed43d93f33afe13d569 /eval
parent4d175c3c37d6ffada13dd15023d575f8e663351e (diff)
allow any intermediate result as input for inner products
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp24
1 files changed, 11 insertions, 13 deletions
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp
index c5ebe837151..48eefa3eed4 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp
@@ -46,7 +46,7 @@ bool isDenseXWProduct(const ValueType &res, const ValueType &vec, const ValueTyp
return false;
}
-const TensorFunction &createDenseXWProduct(const ValueType &res, const Inject &vec, const Inject &mat, Stash &stash) {
+const TensorFunction &createDenseXWProduct(const ValueType &res, const TensorFunction &vec, const TensorFunction &mat, Stash &stash) {
bool common_is_inner = (mat.result_type().dimension_index(vec.result_type().dimensions()[0].name) == 1);
return stash.create<DenseXWProductFunction>(res, vec, mat,
vec.result_type().dimensions()[0].size,
@@ -62,18 +62,16 @@ struct InnerProductFunctionOptimizer
const ValueType &result_type = reduce->result_type();
const Join *join = as<Join>(reduce->child());
if (join && (join->function() == Mul::f)) {
- const Inject *lhs = as<Inject>(join->lhs());
- const Inject *rhs = as<Inject>(join->rhs());
- if (lhs && rhs) {
- if (isDenseDotProduct(result_type, lhs->result_type(), rhs->result_type())) {
- return stash.create<DenseDotProductFunction>(*lhs, *rhs);
- }
- if (isDenseXWProduct(result_type, lhs->result_type(), rhs->result_type())) {
- return createDenseXWProduct(result_type, *lhs, *rhs, stash);
- }
- if (isDenseXWProduct(result_type, rhs->result_type(), lhs->result_type())) {
- return createDenseXWProduct(result_type, *rhs, *lhs, stash);
- }
+ const TensorFunction &lhs = join->lhs();
+ const TensorFunction &rhs = join->rhs();
+ if (isDenseDotProduct(result_type, lhs.result_type(), rhs.result_type())) {
+ return stash.create<DenseDotProductFunction>(lhs, rhs);
+ }
+ if (isDenseXWProduct(result_type, lhs.result_type(), rhs.result_type())) {
+ return createDenseXWProduct(result_type, lhs, rhs, stash);
+ }
+ if (isDenseXWProduct(result_type, rhs.result_type(), lhs.result_type())) {
+ return createDenseXWProduct(result_type, rhs, lhs, stash);
}
}
}