aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/expression/numericfunctionnode.cpp
blob: d9c664e5cde37245a97d6c222d013bb8c43af2b5 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "numericfunctionnode.h"
#include <stdexcept>

namespace search::expression {

// Runtime type registration for this abstract node type, naming
// MultiArgFunctionNode as its base. The macro is defined elsewhere; presumably
// it provides the getClass()/classId machinery that the inherits() checks in
// this file rely on -- TODO confirm against the macro definition.
IMPLEMENT_ABSTRACT_EXPRESSIONNODE(NumericFunctionNode,  MultiArgFunctionNode);

// Defaulted. _handler is a smart pointer (it is reset() below), so it starts
// out empty; onPrepare() creates it.
NumericFunctionNode::NumericFunctionNode() = default;
NumericFunctionNode::~NumericFunctionNode() = default;

// Copy-constructs the MultiArgFunctionNode base, but deliberately leaves the
// handler empty. Handlers are constructed with a reference to their owning
// node (see onPrepare), so rhs's handler cannot be reused by the copy.
// onCalculate() dereferences _handler, so onPrepare() has to run on the copy
// before it is calculated.
NumericFunctionNode::NumericFunctionNode(const NumericFunctionNode & rhs)
    : MultiArgFunctionNode(rhs),
      _handler()
{ }

// Assigns the MultiArgFunctionNode base state and discards the current
// handler. Handlers hold a reference to the node they were created for
// (see onPrepare), so they are never carried across nodes. onPrepare() must
// be called again before the next calculation.
NumericFunctionNode & NumericFunctionNode::operator = (const NumericFunctionNode & rhs)
{
    if (this == &rhs) {
        // Self-assignment: nothing to do, and the existing handler stays valid.
        return *this;
    }
    MultiArgFunctionNode::operator =(rhs);
    _handler.reset();
    return *this;
}

/**
 * Prepares the base node, then chooses the handler that onCalculate()
 * delegates to.
 *
 * - Exactly one argument: a Flatten*Handler is picked from the type of that
 *   argument's result, which must be an integer, float or string vector.
 *   These handlers call the function's flatten() on the argument.
 * - Any other argument count: the handler is picked from the type of this
 *   node's own result. Vector result types are tested before scalar ones.
 *
 * @throws std::runtime_error if no handler matches the inspected type.
 */
void NumericFunctionNode::onPrepare(bool preserveAccurateTypes)
{
    MultiArgFunctionNode::onPrepare(preserveAccurateTypes);

    if (getNumArgs() != 1) {
        const auto & resultClass = getResult()->getClass();
        if (resultClass.inherits(IntegerResultNodeVector::classId)) {
            _handler.reset(new VectorIntegerHandler(*this));
        } else if (resultClass.inherits(FloatResultNodeVector::classId)) {
            _handler.reset(new VectorFloatHandler(*this));
        } else if (resultClass.inherits(StringResultNodeVector::classId)) {
            _handler.reset(new VectorStringHandler(*this));
        } else if (resultClass.inherits(IntegerResultNode::classId)) {
            _handler.reset(new ScalarIntegerHandler(*this));
        } else if (resultClass.inherits(FloatResultNode::classId)) {
            _handler.reset(new ScalarFloatHandler(*this));
        } else if (resultClass.inherits(StringResultNode::classId)) {
            _handler.reset(new ScalarStringHandler(*this));
        } else if (resultClass.inherits(RawResultNode::classId)) {
            _handler.reset(new ScalarRawHandler(*this));
        } else {
            throw std::runtime_error(vespalib::make_string("NumericFunctionNode::onPrepare does not handle results of type %s", resultClass.name()));
        }
        return;
    }

    // Single argument: the handler is keyed on that argument's result type.
    const auto & argClass = getArg(0).getResult()->getClass();
    if (argClass.inherits(IntegerResultNodeVector::classId)) {
        _handler.reset(new FlattenIntegerHandler(*this));
    } else if (argClass.inherits(FloatResultNodeVector::classId)) {
        _handler.reset(new FlattenFloatHandler(*this));
    } else if (argClass.inherits(StringResultNodeVector::classId)) {
        _handler.reset(new FlattenStringHandler(*this));
    } else {
        throw std::runtime_error(vespalib::string("No FlattenHandler for ") + argClass.name());
    }
}

// Feeds each argument's result to the handler chosen in onPrepare().
// - The first argument goes through handleFirst().
// - Every later argument goes through handle(), in order.
// The handlers accumulate into their own `_result`, so the ResultNode
// out-parameter is not used here. Always reports success.
// Assumes onPrepare() has run and args is non-empty; args[0] is read
// unconditionally. NOTE(review): confirm callers guarantee this.
bool NumericFunctionNode::onCalculate(const ExpressionNodeVector & args, ResultNode &) const
{
    _handler->handleFirst(*args[0]->getResult());
    const size_t argCount = args.size();
    for (size_t idx = 1; idx < argCount; ++idx) {
        _handler->handle(*args[idx]->getResult());
    }
    return true;
}

/**
 * Folds one further argument into the vector result built so far. Each
 * pairing is done through the function's executeIterative().
 *
 * - Vector argument longer than the result: the result is first grown to the
 *   argument's length by repeating its existing elements cyclically.
 * - Vector argument shorter than the result: the argument's elements are
 *   reused cyclically.
 * - Non-vector argument: it is applied to every element of the result.
 *
 * If either the current result or a vector argument is empty there is no
 * element to pair with. The cyclic indexing would then compute `x % 0`, which
 * is undefined behaviour. The result is therefore cleared: "anything op empty"
 * yields an empty vector.
 */
template <typename T>
void NumericFunctionNode::VectorHandler<T>::handle(const ResultNode & arg)
{
    typename T::Vector & result = _result.getVector();
    if (arg.getClass().inherits(ResultNodeVector::classId)) {
        const ResultNodeVector & av = static_cast<const ResultNodeVector &> (arg);
        const size_t argSize(av.size());
        const size_t oldRSize(result.size());
        if ((argSize == 0) || (oldRSize == 0)) {
            // Nothing to pair with; avoid the modulo-by-zero below.
            result.clear();
            return;
        }
        if (argSize > oldRSize) {
            // Grow the result by cycling over its existing elements.
            result.resize(argSize);
            for (size_t i(oldRSize); i < argSize; i++) {
                result[i] = result[i % oldRSize];
            }
        }
        // A shorter argument is reused cyclically across the result.
        for (size_t i(0), m(result.size()); i < m; i++) {
            function().executeIterative(av.get(i % argSize), result[i]);
        }
    } else {
        for (size_t i(0), m(result.size()); i < m; i++) {
            function().executeIterative(arg, result[i]);
        }
    }
}

// Seeds the vector result from the first argument.
// - Vector argument: its elements are copied one by one via set().
// - Any other argument: it becomes the single element of a one-element result.
template <typename T>
void NumericFunctionNode::VectorHandler<T>::handleFirst(const ResultNode & arg)
{
    typename T::Vector & out = _result.getVector();
    if ( ! arg.getClass().inherits(ResultNodeVector::classId)) {
        out.resize(1);
        out[0].set(arg);
        return;
    }
    const ResultNodeVector & source = static_cast<const ResultNodeVector &>(arg);
    const size_t count = source.size();
    out.resize(count);
    for (size_t idx = 0; idx < count; ++idx) {
        out[idx].set(source.get(idx));
    }
}


void NumericFunctionNode::ScalarIntegerHandler::handle(const ResultNode & arg)
{
    function().executeIterative(arg, _result);
}

void NumericFunctionNode::ScalarFloatHandler::handle(const ResultNode & arg)
{
    function().executeIterative(arg, _result);
}

void NumericFunctionNode::ScalarStringHandler::handle(const ResultNode & arg)
{
    function().executeIterative(arg, _result);
}

void NumericFunctionNode::ScalarRawHandler::handle(const ResultNode & arg)
{
    function().executeIterative(arg, _result);
}

void NumericFunctionNode::FlattenIntegerHandler::handle(const ResultNode & arg)
{
    _result.set(_initial);
    function().flatten(static_cast<const ResultNodeVector &> (arg), _result);
}

void NumericFunctionNode::FlattenFloatHandler::handle(const ResultNode & arg)
{
    _result.set(_initial);
    function().flatten(static_cast<const ResultNodeVector &> (arg), _result);
}

void NumericFunctionNode::FlattenStringHandler::handle(const ResultNode & arg)
{
    _result.set(_initial);
    function().flatten(static_cast<const ResultNodeVector &> (arg), _result);
}

}

// this function was added by ../../forcelink.sh
// It is empty and has external linkage. Presumably it is referenced from
// elsewhere so the linker keeps this object file, and with it the type
// registration above -- TODO confirm against forcelink.sh.
void forcelink_file_searchlib_expression_numericfunctionnode() {}