summaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/eval/basic_nodes.cpp
blob: eac0a7d97e812b0d9084d8d4bd229aafe5c399eb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "basic_nodes.h"
#include "node_traverser.h"
#include "node_visitor.h"
#include "interpreted_function.h"
#include "simple_value.h"
#include "fast_value.h"
#include "node_tools.h"
#include <vespa/vespalib/util/stringfmt.h>
#include <cassert>
#include <vector>

namespace vespalib::eval::nodes {

namespace {

// One entry of the explicit stack used by Node::traverse: a node plus a
// cursor pointing at the first child that has not been handed out yet.
struct Frame {
    const Node &node;
    size_t child_idx = 0;

    explicit Frame(const Node &node_in) noexcept : node(node_in) {}

    // true while at least one child of 'node' has not been handed out
    bool has_next_child() const {
        return child_idx < node.num_children();
    }

    // hand out the next child and move the cursor past it
    const Node &next_child() {
        const size_t current = child_idx;
        ++child_idx;
        return node.get_child(current);
    }
};

} // namespace vespalib::eval::nodes::<unnamed>

/**
 * Render this numeric constant with printf-style "%g" formatting.
 * The dump context is not needed for a leaf constant.
 **/
vespalib::string
Number::dump(DumpContext & /*ctx*/) const {
    const double value = _value;
    return make_string("%g", value);
}

/**
 * Render as "if(<cond>,<true_expr>,<false_expr>)". The branch probability
 * is appended as a fourth argument only when it differs from 0.5.
 * Sub-expressions are dumped in order: condition, true branch, false branch.
 **/
vespalib::string
If::dump(DumpContext &ctx) const {
    vespalib::string out("if(");
    out += _cond->dump(ctx);
    out += ",";
    out += _true_expr->dump(ctx);
    out += ",";
    out += _false_expr->dump(ctx);
    const bool has_explicit_probability = (_p_true != 0.5);
    if (has_explicit_probability) {
        out += make_string(",%g", _p_true);
    }
    out += ")";
    return out;
}
double
Node::get_const_double_value() const
{
    assert(is_const_double());
    NodeTypes node_types(*this);
    InterpretedFunction function(SimpleValueBuilderFactory::get(), *this, node_types);
    NoParams no_params;
    InterpretedFunction::Context ctx(function);
    return function.eval(ctx, no_params).as_double();
}

Value::UP
Node::get_const_value() const
{
    if (nodes::as<nodes::Error>(*this)) {
        // cannot get const value for parse error
        return {nullptr};
    }
    if (NodeTools::min_num_params(*this) != 0) {
        // cannot get const value for non-const sub-expression
        return {nullptr};
    }
    NodeTypes node_types(*this);
    InterpretedFunction function(SimpleValueBuilderFactory::get(), *this, node_types);
    NoParams no_params;
    InterpretedFunction::Context ctx(function);
    return FastValueBuilderFactory::get().copy(function.eval(ctx, no_params));
}

/**
 * Depth-first walk over this node and its descendants using an explicit
 * stack (presumably to avoid deep recursion on large expressions).
 *
 * The traverser's open() is called when a node is reached. If it returns
 * false, that node's children are skipped and close() is NOT called for it.
 * Otherwise close() is called once all of the node's children are done.
 **/
void
Node::traverse(NodeTraverser &traverser) const
{
    if (traverser.open(*this)) {
        std::vector<Frame> pending;
        pending.emplace_back(*this);
        while (!pending.empty()) {
            Frame &top = pending.back();
            if (!top.has_next_child()) {
                traverser.close(top.node);
                pending.pop_back();
                continue;
            }
            const Node &child = top.next_child();
            if (traverser.open(child)) {
                // may reallocate; 'top' is not used after this point
                pending.emplace_back(child);
            }
        }
    }
}

void Number::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void Symbol::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void String::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void In    ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void Neg   ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void Not   ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void If    ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
void Error ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }

/**
 * Take ownership of the condition and both branches and record the
 * probability of taking the true branch.
 *
 * Also classifies the node (_is_tree; presumably a decision-tree shape).
 * _is_tree is true only when both branches are themselves trees or constant
 * doubles AND the condition has one of these forms:
 *   param < const,  param in ...,  !(param >= const)
 **/
If::If(Node_UP cond_in, Node_UP true_expr_in, Node_UP false_expr_in, double p_true_in)
    : _cond(std::move(cond_in)),
      _true_expr(std::move(true_expr_in)),
      _false_expr(std::move(false_expr_in)),
      _p_true(p_true_in),
      _is_tree(false)
{
    const bool true_ok = (true_expr().is_tree() || true_expr().is_const_double());
    const bool false_ok = (false_expr().is_tree() || false_expr().is_const_double());
    if (!(true_ok && false_ok)) {
        return;
    }
    const auto &condition = cond();
    if (auto less = as<Less>(condition)) {
        _is_tree = (less->lhs().is_param() && less->rhs().is_const_double());
    } else if (auto in = as<In>(condition)) {
        _is_tree = in->child().is_param();
    } else if (auto inverted = as<Not>(condition)) {
        if (auto ge = as<GreaterEqual>(inverted->child())) {
            _is_tree = (ge->lhs().is_param() && ge->rhs().is_const_double());
        }
    }
}

}