aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/instruction/join_with_number_function.cpp
blob: c41867242f0967382f3d3c6b1958d4f799547ac7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "join_with_number_function.h"
#include <vespa/vespalib/objects/objectvisitor.h>
#include <vespa/eval/eval/value.h>
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/inline_operation.h>
#include <vespa/vespalib/util/typify.h>

using namespace vespalib::eval::tensor_function;
using namespace vespalib::eval::operation;

namespace vespalib::eval {

using Instruction = InterpretedFunction::Instruction;
using State = InterpretedFunction::State;
using vespalib::eval::tensor_function::unwrap_param;
using vespalib::eval::tensor_function::wrap_param;

namespace {

// Compile-time data handed to my_number_join_op via wrap_param/unwrap_param:
// the result type of the join and the binary cell function to apply.
struct JoinWithNumberParam {
    const ValueType res_type;   // type of the value produced by the join
    const join_fun_t function;  // cell-level join function
    JoinWithNumberParam(const ValueType &r, join_fun_t f)
      : res_type(r),
        function(f)
    {
    }
};

// Provides the array the join results are written into.
// - not inplace: a new uninitialized array with as many cells as the input,
//   obtained from 'stash'.
// - inplace: the input cells themselves (const removed); only valid when the
//   input and output cell types are identical, which is checked statically.
template <typename ICT, typename OCT, bool inplace>
ArrayRef<OCT> make_dst_cells(ConstArrayRef<ICT> src_cells, Stash &stash) {
    if constexpr (!inplace) {
        return stash.create_uninitialized_array<OCT>(src_cells.size());
    } else {
        static_assert(std::is_same_v<ICT,OCT>);
        return unconstify(src_cells);
    }
}

// Interpreted operation joining a tensor with a number.
//
// The two operands are read with state.peek(): when 'swap' is true (the number
// was the left operand) the tensor is at peek(0) and the number at peek(1),
// otherwise the other way round. In the swapped case the function object is
// wrapped in SwapArgs2, presumably so the join function still receives its
// arguments in the original (lhs, rhs) order.
//
// With 'inplace' the tensor's own cells are overwritten and the tensor is
// pushed back as the result; otherwise the results go into newly created
// cells that are wrapped in a ValueView sharing the tensor's index.
template <typename ICT, typename OCT, typename Fun, bool inplace, bool swap>
void my_number_join_op(State &state, uint64_t param_in) {
    const auto &param = unwrap_param<JoinWithNumberParam>(param_in);
    using OP = std::conditional_t<swap, SwapArgs2<Fun>, Fun>;
    OP join_op(param.function);

    const Value &tensor_value = state.peek(swap ? 0 : 1);
    OCT scalar = state.peek(swap ? 1 : 0).as_double();

    auto in_cells = tensor_value.cells().typify<ICT>();
    auto out_cells = make_dst_cells<ICT, OCT, inplace>(in_cells, state.stash);
    apply_op2_vec_num(out_cells.begin(), in_cells.begin(), scalar, out_cells.size(), join_op);

    if constexpr (inplace) {
        // the tensor's cells now hold the result
        state.pop_pop_push(tensor_value);
    } else {
        const auto &result = state.stash.create<ValueView>(param.res_type, tensor_value.index(), TypedCells(out_cells));
        state.pop_pop_push(result);
    }
}

// Typify target choosing the my_number_join_op instantiation from:
// the tensor's cell meta (CM), the join function (Fun), whether the tensor
// operand is mutable (PrimaryMutable) and whether the number was the left
// operand (NumberWasLeft).
struct SelectJoinWithNumberOp {
    template<typename CM, typename Fun,
             typename PrimaryMutable, typename NumberWasLeft>
    static auto invoke() {
        constexpr CellMeta in_meta = CM::value;
        // cell meta used for the number operand
        constexpr CellMeta number_meta(CellType::DOUBLE, true);
        constexpr CellMeta out_meta = CellMeta::join(in_meta, number_meta);
        using ICT = CellValueType<in_meta.cell_type>;
        using OCT = CellValueType<out_meta.cell_type>;
        // Only write into the tensor's own cells when it is mutable and the
        // cell type does not change.
        constexpr bool reuse_cells = PrimaryMutable::value && std::is_same_v<ICT,OCT>;
        return my_number_join_op<ICT, OCT, Fun, reuse_cells, NumberWasLeft::value>;
    }
};

} // namespace <unnamed>

// Builds the specialized node from 'original', keeping its result type,
// operands and join function. 'tensor_was_right' tells which operand is the
// tensor (the primary); the other one is the number.
JoinWithNumberFunction::JoinWithNumberFunction(const Join &original, bool tensor_was_right)
    : tensor_function::Op2(original.result_type(), original.lhs(), original.rhs()),
      _primary(!tensor_was_right ? Primary::LHS : Primary::RHS),
      _function(original.function())
{
}

JoinWithNumberFunction::~JoinWithNumberFunction() = default;

bool
JoinWithNumberFunction::primary_is_mutable() const {
    if (_primary == Primary::LHS) {
        return lhs().result_is_mutable();
    } else {
        return rhs().result_is_mutable();
    }
}

using MyTypify = TypifyValue<TypifyCellMeta,vespalib::TypifyBool,operation::TypifyOp2>;

// Compiles this node into an interpreted instruction.
//
// The concrete operation is selected from four inputs:
// - the tensor operand's cell meta;
// - the join function;
// - whether the tensor operand is mutable (enabling in-place evaluation);
// - whether the tensor was the right operand.
//
// The node's result type and join function are stored in the stash and passed
// to the operation as its parameter.
InterpretedFunction::Instruction
JoinWithNumberFunction::compile_self(const CTFContext &ctx) const
{
    const auto &param = ctx.stash.create<JoinWithNumberParam>(result_type(), _function);
    // Bind by reference: the type is only read here, so there is no need to
    // copy the operand's ValueType.
    const ValueType &input_type = (_primary == Primary::LHS) ? lhs().result_type() : rhs().result_type();
    assert(result_type() == ValueType::join(input_type, ValueType::double_type()));
    const bool tensor_was_right = (_primary == Primary::RHS);
    auto op = typify_invoke<4,MyTypify,SelectJoinWithNumberOp>(input_type.cell_meta(),
                                                               _function,
                                                               primary_is_mutable(),
                                                               tensor_was_right);
    return Instruction(op, wrap_param<JoinWithNumberParam>(param));
}

// Reports the base Op2 properties, then which side the tensor was on and
// whether the tensor operand is mutable.
void
JoinWithNumberFunction::visit_self(vespalib::ObjectVisitor &visitor) const
{
    Super::visit_self(visitor);
    const bool tensor_on_right = (_primary == Primary::RHS);
    visitor.visitBool("tensor_was_right", tensor_on_right);
    visitor.visitBool("primary_is_mutable", primary_is_mutable());
}

// Rewrites a Join whose result is not a double and where one operand is a
// double into a JoinWithNumberFunction allocated in 'stash'. The left operand
// is checked first. Every other expression is returned unchanged.
const TensorFunction &
JoinWithNumberFunction::optimize(const TensorFunction &expr, Stash &stash)
{
    if (expr.result_type().is_double()) {
        return expr;
    }
    const auto *join = as<Join>(expr);
    if (join == nullptr) {
        return expr;
    }
    if (join->lhs().result_type().is_double()) {
        // number on the left => tensor was right
        return stash.create<JoinWithNumberFunction>(*join, true);
    }
    if (join->rhs().result_type().is_double()) {
        // number on the right => tensor was left
        return stash.create<JoinWithNumberFunction>(*join, false);
    }
    return expr;
}

} // namespace