aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/instruction/mixed_inner_product_function.h
blob: d5967c2114d7fce97af0ab34dd4ff4b5684590b0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#pragma once

#include <vespa/eval/eval/tensor_function.h>

namespace vespalib::eval {

/**
 * Tensor function for a dot product inside a mixed tensor.
 *
 * Optimized tensor function for a dot product inside a bigger (possibly
 * mixed) tensor.  To trigger this, the function must be in the form
 * reduce((mixed tensor)*(vector),sum,dimension names)
 * with "vector" being a dense tensor whose dimensions are exactly the
 * reduced dimensions. "mixed tensor" must contain all of these
 * dimensions, and they must also be the innermost (alphabetically last)
 * indexed dimensions in the mixed tensor.
 * Simple example:
 *   mixed: tensor(category{},x[32])
 *   vector: tensor(x[32])
 *   expression: reduce(mixed*vector,sum,x)
 *   result: tensor(category{})
 * More complex example:
 *   mixed: tensor<double>(a{},b[31],c{},d[42],e{},f[5],g{})
 *   vector: tensor<float>(d[42],f[5])
 *   expression: reduce(mixed*vector,sum,d,f)
 *   result: tensor<double>(a{},b[31],c{},e{},g{})
 * Note:
 * if the bigger tensor is dense, other optimizers are likely
 * to pick up the operation, even if this function could also
 * handle it.
 **/
class MixedInnerProductFunction : public tensor_function::Op2
{
public:
    MixedInnerProductFunction(const ValueType &res_type_in,
                              const TensorFunction &mixed_child,
                              const TensorFunction &vector_child);
    InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
    bool result_is_mutable() const override { return true; }
    static bool compatible_types(const ValueType &res, const ValueType &mixed, const ValueType &dense);
    static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash);
};

} // namespace