diff options
-rw-r--r-- | eval/src/tests/eval/function/function_test.cpp | 95 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/basic_nodes.cpp | 33 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/basic_nodes.h | 28 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/call_nodes.h | 10 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/fast_forest.cpp | 12 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/key_gen.cpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp | 8 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/make_tensor_function.cpp | 10 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/operator_nodes.cpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/operator_nodes.h | 6 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/vm_forest.cpp | 14 | ||||
-rw-r--r-- | searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp | 10 | ||||
-rw-r--r-- | searchlib/src/tests/features/constant/constant_test.cpp | 46 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/constant_feature.cpp | 27 |
14 files changed, 215 insertions, 88 deletions
diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp index 4dff2934873..aca19e2ccc9 100644 --- a/eval/src/tests/eval/function/function_test.cpp +++ b/eval/src/tests/eval/function/function_test.cpp @@ -3,12 +3,15 @@ #include <vespa/eval/eval/function.h> #include <vespa/eval/eval/operator_nodes.h> #include <vespa/eval/eval/node_traverser.h> +#include <vespa/eval/eval/value_codec.h> #include <set> #include <vespa/eval/eval/test/eval_spec.h> +#include <vespa/eval/eval/test/gen_spec.h> #include <vespa/eval/eval/check_type.h> using namespace vespalib::eval; using namespace vespalib::eval::nodes; +using vespalib::eval::test::GenSpec; std::vector<vespalib::string> params({"x", "y", "z", "w"}); @@ -351,7 +354,7 @@ TEST("require that Not child can be accessed") { const Node &root = f->root(); EXPECT_TRUE(!root.is_leaf()); ASSERT_EQUAL(1u, root.num_children()); - EXPECT_EQUAL(1.0, root.get_child(0).get_const_value()); + EXPECT_EQUAL(1.0, root.get_child(0).get_const_double_value()); } TEST("require that If children can be accessed") { @@ -359,9 +362,9 @@ TEST("require that If children can be accessed") { const Node &root = f->root(); EXPECT_TRUE(!root.is_leaf()); ASSERT_EQUAL(3u, root.num_children()); - EXPECT_EQUAL(1.0, root.get_child(0).get_const_value()); - EXPECT_EQUAL(2.0, root.get_child(1).get_const_value()); - EXPECT_EQUAL(3.0, root.get_child(2).get_const_value()); + EXPECT_EQUAL(1.0, root.get_child(0).get_const_double_value()); + EXPECT_EQUAL(2.0, root.get_child(1).get_const_double_value()); + EXPECT_EQUAL(3.0, root.get_child(2).get_const_double_value()); } TEST("require that Operator children can be accessed") { @@ -369,8 +372,8 @@ TEST("require that Operator children can be accessed") { const Node &root = f->root(); EXPECT_TRUE(!root.is_leaf()); ASSERT_EQUAL(2u, root.num_children()); - EXPECT_EQUAL(1.0, root.get_child(0).get_const_value()); - EXPECT_EQUAL(2.0, root.get_child(1).get_const_value()); + EXPECT_EQUAL(1.0, root.get_child(0).get_const_double_value()); + EXPECT_EQUAL(2.0, root.get_child(1).get_const_double_value()); } TEST("require that Call children can be accessed") { @@ -378,8 +381,8 @@ TEST("require that Call children can be accessed") { const Node &root = f->root(); EXPECT_TRUE(!root.is_leaf()); ASSERT_EQUAL(2u, root.num_children()); - EXPECT_EQUAL(1.0, root.get_child(0).get_const_value()); - EXPECT_EQUAL(2.0, root.get_child(1).get_const_value()); + EXPECT_EQUAL(1.0, root.get_child(0).get_const_double_value()); + EXPECT_EQUAL(2.0, root.get_child(1).get_const_double_value()); } struct MyNodeHandler : public NodeHandler { @@ -498,7 +501,7 @@ TEST("require that node types can be checked") { TEST("require that parameter is param, but not const") { EXPECT_TRUE(Function::parse("x")->root().is_param()); - EXPECT_TRUE(!Function::parse("x")->root().is_const()); + EXPECT_TRUE(!Function::parse("x")->root().is_const_double()); } TEST("require that inverted parameter is not param") { @@ -506,43 +509,43 @@ TEST("require that inverted parameter is not param") { } TEST("require that number is const, but not param") { - EXPECT_TRUE(Function::parse("123")->root().is_const()); + EXPECT_TRUE(Function::parse("123")->root().is_const_double()); EXPECT_TRUE(!Function::parse("123")->root().is_param()); } TEST("require that string is const") { - EXPECT_TRUE(Function::parse("\"x\"")->root().is_const()); + EXPECT_TRUE(Function::parse("\"x\"")->root().is_const_double()); } TEST("require that neg is const if sub-expression is const") { - EXPECT_TRUE(Function::parse("-123")->root().is_const()); - EXPECT_TRUE(!Function::parse("-x")->root().is_const()); + EXPECT_TRUE(Function::parse("-123")->root().is_const_double()); + EXPECT_TRUE(!Function::parse("-x")->root().is_const_double()); } TEST("require that not is const if sub-expression is const") { - EXPECT_TRUE(Function::parse("!1")->root().is_const()); - EXPECT_TRUE(!Function::parse("!x")->root().is_const()); + EXPECT_TRUE(Function::parse("!1")->root().is_const_double()); + EXPECT_TRUE(!Function::parse("!x")->root().is_const_double()); } TEST("require that operators are cost if both children are const") { - EXPECT_TRUE(!Function::parse("x+y")->root().is_const()); - EXPECT_TRUE(!Function::parse("1+y")->root().is_const()); - EXPECT_TRUE(!Function::parse("x+2")->root().is_const()); - EXPECT_TRUE(Function::parse("1+2")->root().is_const()); + EXPECT_TRUE(!Function::parse("x+y")->root().is_const_double()); + EXPECT_TRUE(!Function::parse("1+y")->root().is_const_double()); + EXPECT_TRUE(!Function::parse("x+2")->root().is_const_double()); + EXPECT_TRUE(Function::parse("1+2")->root().is_const_double()); } TEST("require that set membership is never tagged as const (NB: avoids jit recursion)") { - EXPECT_TRUE(!Function::parse("x in [x,y,z]")->root().is_const()); - EXPECT_TRUE(!Function::parse("1 in [x,y,z]")->root().is_const()); - EXPECT_TRUE(!Function::parse("1 in [1,y,z]")->root().is_const()); - EXPECT_TRUE(!Function::parse("1 in [1,2,3]")->root().is_const()); + EXPECT_TRUE(!Function::parse("x in [x,y,z]")->root().is_const_double()); + EXPECT_TRUE(!Function::parse("1 in [x,y,z]")->root().is_const_double()); + EXPECT_TRUE(!Function::parse("1 in [1,y,z]")->root().is_const_double()); + EXPECT_TRUE(!Function::parse("1 in [1,2,3]")->root().is_const_double()); } TEST("require that calls are cost if all parameters are const") { - EXPECT_TRUE(!Function::parse("max(x,y)")->root().is_const()); - EXPECT_TRUE(!Function::parse("max(1,y)")->root().is_const()); - EXPECT_TRUE(!Function::parse("max(x,2)")->root().is_const()); - EXPECT_TRUE(Function::parse("max(1,2)")->root().is_const()); + EXPECT_TRUE(!Function::parse("max(x,y)")->root().is_const_double()); + EXPECT_TRUE(!Function::parse("max(1,y)")->root().is_const_double()); + EXPECT_TRUE(!Function::parse("max(x,2)")->root().is_const_double()); + EXPECT_TRUE(Function::parse("max(1,2)")->root().is_const_double()); } //----------------------------------------------------------------------------- @@ -1059,4 +1062,42 @@ TEST_FF("require that all conformance test expressions can be parsed", //----------------------------------------------------------------------------- +TEST("require that constant double value can be (pre-)calculated") { + auto expect = GenSpec(42).gen(); + auto f = Function::parse("21+21"); + ASSERT_TRUE(!f->has_error()); + const Node &root = f->root(); + auto value = root.get_const_value(); + ASSERT_TRUE(value); + EXPECT_EQUAL(spec_from_value(*value), expect); +} + +TEST("require that constant tensor value can be (pre-)calculated") { + auto expect = GenSpec().idx("x", 10).gen(); + auto f = Function::parse("concat(tensor(x[4])(x+1),tensor(x[6])(x+5),x)"); + ASSERT_TRUE(!f->has_error()); + const Node &root = f->root(); + auto value = root.get_const_value(); + ASSERT_TRUE(value); + EXPECT_EQUAL(spec_from_value(*value), expect); +} + +TEST("require that non-const value cannot be (pre-)calculated") { + auto f = Function::parse("a+b"); + ASSERT_TRUE(!f->has_error()); + const Node &root = f->root(); + auto value = root.get_const_value(); + EXPECT_TRUE(value.get() == nullptr); +} + +TEST("require that parse error does not produce a const value") { + auto f = Function::parse("this is a parse error"); + EXPECT_TRUE(f->has_error()); + const Node &root = f->root(); + auto value = root.get_const_value(); + EXPECT_TRUE(value.get() == nullptr); +} + +//----------------------------------------------------------------------------- + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/eval/basic_nodes.cpp b/eval/src/vespa/eval/eval/basic_nodes.cpp index d7fa76bf1cc..98ce50b2543 100644 --- a/eval/src/vespa/eval/eval/basic_nodes.cpp +++ b/eval/src/vespa/eval/eval/basic_nodes.cpp @@ -5,6 +5,8 @@ #include "node_visitor.h" #include "interpreted_function.h" #include "simple_value.h" +#include "fast_value.h" +#include "node_tools.h" namespace vespalib::eval::nodes { @@ -21,8 +23,9 @@ struct Frame { } // namespace vespalib::eval::nodes::<unnamed> double -Node::get_const_value() const { - assert(is_const()); +Node::get_const_double_value() const +{ + assert(is_const_double()); NodeTypes node_types(*this); InterpretedFunction function(SimpleValueBuilderFactory::get(), *this, node_types); NoParams no_params; @@ -30,6 +33,24 @@ Node::get_const_value() const { return function.eval(ctx, no_params).as_double(); } +Value::UP +Node::get_const_value() const +{ + if (nodes::as<nodes::Error>(*this)) { + // cannot get const value for parse error + return {nullptr}; + } + if (NodeTools::min_num_params(*this) != 0) { + // cannot get const value for non-const sub-expression + return {nullptr}; + } + NodeTypes node_types(*this); + InterpretedFunction function(SimpleValueBuilderFactory::get(), *this, node_types); + NoParams no_params; + InterpretedFunction::Context ctx(function); + return FastValueBuilderFactory::get().copy(function.eval(ctx, no_params)); +} + void Node::traverse(NodeTraverser &traverser) const { @@ -69,16 +90,16 @@ If::If(Node_UP cond_in, Node_UP true_expr_in, Node_UP false_expr_in, double p_tr auto less = as<Less>(cond()); auto in = as<In>(cond()); auto inverted = as<Not>(cond()); - bool true_is_subtree = (true_expr().is_tree() || true_expr().is_const()); - bool false_is_subtree = (false_expr().is_tree() || false_expr().is_const()); + bool true_is_subtree = (true_expr().is_tree() || true_expr().is_const_double()); + bool false_is_subtree = (false_expr().is_tree() || false_expr().is_const_double()); if (true_is_subtree && false_is_subtree) { if (less) { - _is_tree = (less->lhs().is_param() && less->rhs().is_const()); + _is_tree = (less->lhs().is_param() && less->rhs().is_const_double()); } else if (in) { _is_tree = in->child().is_param(); } else if (inverted) { if (auto ge = as<GreaterEqual>(inverted->child())) { - _is_tree = (ge->lhs().is_param() && ge->rhs().is_const()); + _is_tree = (ge->lhs().is_param() && ge->rhs().is_const_double()); } } } diff --git a/eval/src/vespa/eval/eval/basic_nodes.h b/eval/src/vespa/eval/eval/basic_nodes.h index c1192585f7c..c6b19f6ce12 100644 --- a/eval/src/vespa/eval/eval/basic_nodes.h +++ b/eval/src/vespa/eval/eval/basic_nodes.h @@ -2,6 +2,7 @@ #pragma once +#include "value.h" #include "string_stuff.h" #include <vespa/vespalib/util/hdr_abort.h> #include <vespa/vespalib/stllike/string.h> @@ -48,9 +49,10 @@ struct DumpContext { struct Node { virtual bool is_forest() const { return false; } virtual bool is_tree() const { return false; } - virtual bool is_const() const { return false; } + virtual bool is_const_double() const { return false; } virtual bool is_param() const { return false; } - virtual double get_const_value() const; + virtual double get_const_double_value() const; + Value::UP get_const_value() const; void traverse(NodeTraverser &traverser) const; virtual vespalib::string dump(DumpContext &ctx) const = 0; virtual void accept(NodeVisitor &visitor) const = 0; @@ -92,8 +94,8 @@ private: double _value; public: Number(double value_in) : _value(value_in) {} - virtual bool is_const() const override { return true; } - virtual double get_const_value() const override { return value(); } + virtual bool is_const_double() const override { return true; } + double get_const_double_value() const override { return value(); } double value() const { return _value; } vespalib::string dump(DumpContext &) const override { return make_string("%g", _value); @@ -120,8 +122,8 @@ private: vespalib::string _value; public: String(const vespalib::string &value_in) : _value(value_in) {} - bool is_const() const override { return true; } - double get_const_value() const override { return hash(); } + bool is_const_double() const override { return true; } + double get_const_double_value() const override { return hash(); } const vespalib::string &value() const { return _value; } uint32_t hash() const { return hash_code(_value.data(), _value.size()); } vespalib::string dump(DumpContext &) const override { @@ -137,7 +139,7 @@ private: public: In(Node_UP child) : _child(std::move(child)), _entries() {} void add_entry(Node_UP entry) { - assert(entry->is_const()); + assert(entry->is_const_double()); _entries.push_back(std::move(entry)); } size_t num_entries() const { return _entries.size(); } @@ -171,10 +173,10 @@ public: class Neg : public Node { private: Node_UP _child; - bool _is_const; + bool _is_const_double; public: - Neg(Node_UP child_in) : _child(std::move(child_in)), _is_const(_child->is_const()) {} - bool is_const() const override { return _is_const; } + Neg(Node_UP child_in) : _child(std::move(child_in)), _is_const_double(_child->is_const_double()) {} + bool is_const_double() const override { return _is_const_double; } const Node &child() const { return *_child; } size_t num_children() const override { return _child ? 1 : 0; } const Node &get_child(size_t idx) const override { @@ -198,10 +200,10 @@ public: class Not : public Node { private: Node_UP _child; - bool _is_const; + bool _is_const_double; public: - Not(Node_UP child_in) : _child(std::move(child_in)), _is_const(_child->is_const()) {} - bool is_const() const override { return _is_const; } + Not(Node_UP child_in) : _child(std::move(child_in)), _is_const_double(_child->is_const_double()) {} + bool is_const_double() const override { return _is_const_double; } const Node &child() const { return *_child; } size_t num_children() const override { return _child ? 1 : 0; } const Node &get_child(size_t idx) const override { diff --git a/eval/src/vespa/eval/eval/call_nodes.h b/eval/src/vespa/eval/eval/call_nodes.h index c5a41756005..2a7d4173e64 100644 --- a/eval/src/vespa/eval/eval/call_nodes.h +++ b/eval/src/vespa/eval/eval/call_nodes.h @@ -25,12 +25,12 @@ private: vespalib::string _name; size_t _num_params; std::vector<Node_UP> _args; - bool _is_const; + bool _is_const_double; public: Call(const vespalib::string &name_in, size_t num_params_in) - : _name(name_in), _num_params(num_params_in), _is_const(false) {} + : _name(name_in), _num_params(num_params_in), _is_const_double(false) {} ~Call(); - bool is_const() const override { return _is_const; } + bool is_const_double() const override { return _is_const_double; } const vespalib::string &name() const { return _name; } size_t num_params() const { return _num_params; } size_t num_args() const { return _args.size(); } @@ -45,9 +45,9 @@ public: } virtual void bind_next(Node_UP arg_in) { if (_args.empty()) { - _is_const = arg_in->is_const(); + _is_const_double = arg_in->is_const_double(); } else { - _is_const = (_is_const && arg_in->is_const()); + _is_const_double = (_is_const_double && arg_in->is_const_double()); } _args.push_back(std::move(arg_in)); } diff --git a/eval/src/vespa/eval/eval/fast_forest.cpp b/eval/src/vespa/eval/eval/fast_forest.cpp index 47932ff14fb..4eea4a5cce7 100644 --- a/eval/src/vespa/eval/eval/fast_forest.cpp +++ b/eval/src/vespa/eval/eval/fast_forest.cpp @@ -85,26 +85,26 @@ State::encode_node(uint32_t tree_id, const nodes::Node &node) if (less) { auto symbol = nodes::as<nodes::Symbol>(less->lhs()); assert(symbol); - assert(less->rhs().is_const()); + assert(less->rhs().is_const_double()); size_t feature = symbol->id(); assert(feature < cmp_nodes.size()); - cmp_nodes[feature].emplace_back(less->rhs().get_const_value(), tree_id, true_leafs, true); + cmp_nodes[feature].emplace_back(less->rhs().get_const_double_value(), tree_id, true_leafs, true); } else { assert(inverted); auto ge = nodes::as<nodes::GreaterEqual>(inverted->child()); assert(ge); auto symbol = nodes::as<nodes::Symbol>(ge->lhs()); assert(symbol); - assert(ge->rhs().is_const()); + assert(ge->rhs().is_const_double()); size_t feature = symbol->id(); assert(feature < cmp_nodes.size()); - cmp_nodes[feature].emplace_back(ge->rhs().get_const_value(), tree_id, true_leafs, false); + cmp_nodes[feature].emplace_back(ge->rhs().get_const_double_value(), tree_id, true_leafs, false); } return BitRange::join(true_leafs, false_leafs); } else { - assert(node.is_const()); + assert(node.is_const_double()); BitRange leaf_range(leafs[tree_id].size()); - leafs[tree_id].push_back(node.get_const_value()); + leafs[tree_id].push_back(node.get_const_double_value()); return leaf_range; } } diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index 80803d3b2a2..a8fb205f124 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -28,7 +28,7 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { void visit(const In &node) override { add_byte( 4); add_size(node.num_entries()); for (size_t i = 0; i < node.num_entries(); ++i) { - add_double(node.get_entry(i).get_const_value()); + add_double(node.get_entry(i).get_const_double_value()); } } void visit(const Neg &) override { add_byte( 5); } diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index fce9abb7316..42911a56c14 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -61,7 +61,7 @@ struct SetMemberHash : PluginState { vespalib::hash_set<double> members; explicit SetMemberHash(const In &in) : members(in.num_entries() * 3) { for (size_t i = 0; i < in.num_entries(); ++i) { - members.insert(in.get_entry(i).get_const_value()); + members.insert(in.get_entry(i).get_const_double_value()); } } static bool check_membership(const PluginState *state, double value) { @@ -260,8 +260,8 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { //------------------------------------------------------------------------- bool open(const Node &node) override { - if (node.is_const()) { - push_double(node.get_const_value()); + if (node.is_const_double()) { + push_double(node.get_const_double_value()); return false; } if (!inside_forest && (pass_params != PassParams::SEPARATE) && node.is_forest()) { @@ -412,7 +412,7 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { // build explicit code to check all set members llvm::Value *found = builder.getFalse(); for (size_t i = 0; i < item.num_entries(); ++i) { - llvm::Value *elem = llvm::ConstantFP::get(builder.getDoubleTy(), item.get_entry(i).get_const_value()); + llvm::Value *elem = llvm::ConstantFP::get(builder.getDoubleTy(), item.get_entry(i).get_const_double_value()); llvm::Value *elem_eq = builder.CreateFCmpOEQ(lhs, elem, "elem_eq"); found = builder.CreateOr(found, elem_eq, "found"); } diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index 15f188db51a..b65c3d5aaa7 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -83,14 +83,14 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { bool maybe_make_const(const Node &node) { if (auto create = as<TensorCreate>(node)) { - bool is_const = true; + bool is_const_double = true; for (size_t i = 0; i < create->num_children(); ++i) { - is_const &= create->get_child(i).is_const(); + is_const_double &= create->get_child(i).is_const_double(); } - if (is_const) { + if (is_const_double) { TensorSpec spec(create->type().to_spec()); for (size_t i = 0; i < create->num_children(); ++i) { - spec.add(create->get_child_address(i), create->get_child(i).get_const_value()); + spec.add(create->get_child_address(i), create->get_child(i).get_const_double_value()); } make_const(node, *stash.create<Value::UP>(value_from_spec(spec, factory))); return true; @@ -172,7 +172,7 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const In &node) override { auto my_in = std::make_unique<In>(std::make_unique<Symbol>(0)); for (size_t i = 0; i < node.num_entries(); ++i) { - my_in->add_entry(std::make_unique<Number>(node.get_entry(i).get_const_value())); + my_in->add_entry(std::make_unique<Number>(node.get_entry(i).get_const_double_value())); } auto my_fun = Function::create(std::move(my_in), {"x"}); const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(*my_fun, PassParams::SEPARATE)); diff --git a/eval/src/vespa/eval/eval/operator_nodes.cpp b/eval/src/vespa/eval/eval/operator_nodes.cpp index 4c66268dfa2..98072b99324 100644 --- a/eval/src/vespa/eval/eval/operator_nodes.cpp +++ b/eval/src/vespa/eval/eval/operator_nodes.cpp @@ -13,7 +13,7 @@ Operator::Operator(const vespalib::string &op_str_in, int priority_in, Order ord _order(order_in), _lhs(), _rhs(), - _is_const(false) + _is_const_double(false) {} Operator::~Operator() { } diff --git a/eval/src/vespa/eval/eval/operator_nodes.h b/eval/src/vespa/eval/eval/operator_nodes.h index eafd817d42c..5562659f0f7 100644 --- a/eval/src/vespa/eval/eval/operator_nodes.h +++ b/eval/src/vespa/eval/eval/operator_nodes.h @@ -32,7 +32,7 @@ private: Order _order; Node_UP _lhs; Node_UP _rhs; - bool _is_const; + bool _is_const_double; public: Operator(const vespalib::string &op_str_in, int priority_in, Order order_in); @@ -42,7 +42,7 @@ public: Order order() const { return _order; } const Node &lhs() const { return *_lhs; } const Node &rhs() const { return *_rhs; } - bool is_const() const override { return _is_const; } + bool is_const_double() const override { return _is_const_double; } size_t num_children() const override { return (_lhs && _rhs) ? 2 : 0; } const Node &get_child(size_t idx) const override { assert(idx < 2); @@ -67,7 +67,7 @@ public: virtual void bind(Node_UP lhs_in, Node_UP rhs_in) { _lhs = std::move(lhs_in); _rhs = std::move(rhs_in); - _is_const = (_lhs->is_const() && _rhs->is_const()); + _is_const_double = (_lhs->is_const_double() && _rhs->is_const_double()); } vespalib::string dump(DumpContext &ctx) const override { diff --git a/eval/src/vespa/eval/eval/vm_forest.cpp b/eval/src/vespa/eval/eval/vm_forest.cpp index e0fac9405ce..a31c5f502ac 100644 --- a/eval/src/vespa/eval/eval/vm_forest.cpp +++ b/eval/src/vespa/eval/eval/vm_forest.cpp @@ -136,8 +136,8 @@ void encode_less(const nodes::Less &less, auto symbol = nodes::as<nodes::Symbol>(less.lhs()); assert(symbol); model_out.push_back(uint32_t(symbol->id()) << 12); - assert(less.rhs().is_const()); - encode_const(less.rhs().get_const_value(), model_out); + assert(less.rhs().is_const_double()); + encode_const(less.rhs().get_const_double_value(), model_out); size_t skip_idx = model_out.size(); model_out.push_back(0); // left child size placeholder uint32_t left_type = encode_node(left_child, model_out); @@ -157,7 +157,7 @@ void encode_in(const nodes::In &in, size_t set_size_idx = model_out.size(); model_out.push_back(in.num_entries()); for (size_t i = 0; i < in.num_entries(); ++i) { - encode_large_const(in.get_entry(i).get_const_value(), model_out); + encode_large_const(in.get_entry(i).get_const_double_value(), model_out); } size_t left_idx = model_out.size(); uint32_t left_type = encode_node(left_child, model_out); @@ -176,8 +176,8 @@ void encode_inverted(const nodes::Not &inverted, auto symbol = nodes::as<nodes::Symbol>(ge->lhs()); assert(symbol); model_out.push_back(uint32_t(symbol->id()) << 12); - assert(ge->rhs().is_const()); - encode_const(ge->rhs().get_const_value(), model_out); + assert(ge->rhs().is_const_double()); + encode_const(ge->rhs().get_const_double_value(), model_out); size_t skip_idx = model_out.size(); model_out.push_back(0); // left child size placeholder uint32_t left_type = encode_node(left_child, model_out); @@ -204,8 +204,8 @@ uint32_t encode_node(const nodes::Node &node_in, std::vector<uint32_t> &model_ou return INVERTED; } } else { - assert(node_in.is_const()); - encode_const(node_in.get_const_value(), model_out); + assert(node_in.is_const_double()); + encode_const(node_in.get_const_double_value(), model_out); return LEAF; } } diff --git a/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp b/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp index f2a9ee2932f..3445d64c477 100644 --- a/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp +++ b/searchlib/src/apps/vespa-ranking-expression-analyzer/vespa-ranking-expression-analyzer.cpp @@ -95,11 +95,11 @@ struct FunctionInfo { if (node) { auto lhs_symbol = as<Symbol>(node->lhs()); auto rhs_symbol = as<Symbol>(node->rhs()); - if (lhs_symbol && node->rhs().is_const()) { - inputs[lhs_symbol->id()].cmp_with.push_back(node->rhs().get_const_value()); + if (lhs_symbol && node->rhs().is_const_double()) { + inputs[lhs_symbol->id()].cmp_with.push_back(node->rhs().get_const_double_value()); } - if (node->lhs().is_const() && rhs_symbol) { - inputs[rhs_symbol->id()].cmp_with.push_back(node->lhs().get_const_value()); + if (node->lhs().is_const_double() && rhs_symbol) { + inputs[rhs_symbol->id()].cmp_with.push_back(node->lhs().get_const_double_value()); } } } @@ -108,7 +108,7 @@ struct FunctionInfo { if (node) { if (auto symbol = as<Symbol>(node->child())) { for (size_t i = 0; i < node->num_entries(); ++i) { - inputs[symbol->id()].cmp_with.push_back(node->get_entry(i).get_const_value()); + inputs[symbol->id()].cmp_with.push_back(node->get_entry(i).get_const_double_value()); } } } diff --git a/searchlib/src/tests/features/constant/constant_test.cpp b/searchlib/src/tests/features/constant/constant_test.cpp index 9c8480c1da2..1ef985f9e36 100644 --- a/searchlib/src/tests/features/constant/constant_test.cpp +++ b/searchlib/src/tests/features/constant/constant_test.cpp @@ -1,4 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + #include <vespa/vespalib/testkit/test_kit.h> #include <iostream> #include <vespa/searchlib/features/setup.h> @@ -7,9 +8,11 @@ #include <vespa/searchlib/fef/test/indexenvironment.h> #include <vespa/eval/eval/function.h> #include <vespa/eval/eval/simple_value.h> +#include <vespa/eval/eval/node_types.h> #include <vespa/eval/eval/tensor_spec.h> #include <vespa/eval/eval/value.h> #include <vespa/eval/eval/test/value_compare.h> +#include <vespa/vespalib/util/stringfmt.h> using search::feature_t; using namespace search::fef; @@ -19,9 +22,11 @@ using namespace search::features; using vespalib::eval::DoubleValue; using vespalib::eval::Function; using vespalib::eval::SimpleValue; +using vespalib::eval::NodeTypes; using vespalib::eval::TensorSpec; using vespalib::eval::Value; using vespalib::eval::ValueType; +using vespalib::make_string_short::fmt; namespace { @@ -68,12 +73,18 @@ struct ExecFixture std::move(type), std::move(tensor)); } - void addDouble(const vespalib::string &name, const double value) { test.getIndexEnv().addConstantValue(name, ValueType::double_type(), std::make_unique<DoubleValue>(value)); } + void addTypeValue(const vespalib::string &name, const vespalib::string &type, const vespalib::string &value) { + auto &props = test.getIndexEnv().getProperties(); + auto type_prop = fmt("constant(%s).type", name.c_str()); + auto value_prop = fmt("constant(%s).value", name.c_str()); + props.add(type_prop, type); + props.add(value_prop, value); + } }; TEST_F("require that missing constant is detected", @@ -108,5 +119,38 @@ TEST_F("require that existing double constant is detected", EXPECT_EQUAL(42.0, f.executeDouble()); } +//----------------------------------------------------------------------------- + +TEST_F("require that constants can be functional", ExecFixture("constant(foo)")) { + f.addTypeValue("foo", "tensor(x{})", "tensor(x{}):{a:3,b:5,c:7}"); + EXPECT_TRUE(f.setup()); + auto expect = make_tensor(TensorSpec("tensor(x{})") + .add({{"x","b"}}, 5) + .add({{"x","c"}}, 7) + .add({{"x","a"}}, 3)); + EXPECT_EQUAL(*expect, f.executeTensor()); +} + +TEST_F("require that functional constant type must match the expression result", ExecFixture("constant(foo)")) { + f.addTypeValue("foo", "tensor<float>(x{})", "tensor(x{}):{a:3,b:5,c:7}"); + EXPECT_TRUE(!f.setup()); +} + +TEST_F("require that functional constant must parse without errors", ExecFixture("constant(foo)")) { + f.addTypeValue("foo", "double", "this is parse error"); + EXPECT_TRUE(!f.setup()); +} + +TEST_F("require that non-const functional constant is not allowed", ExecFixture("constant(foo)")) { + f.addTypeValue("foo", "tensor(x{})", "tensor(x{}):{a:a,b:5,c:7}"); + EXPECT_TRUE(!f.setup()); +} + +TEST_F("require that functional constant must have non-error type", ExecFixture("constant(foo)")) { + f.addTypeValue("foo", "error", "impossible to create value with error type"); + EXPECT_TRUE(!f.setup()); +} + +//----------------------------------------------------------------------------- TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/vespa/searchlib/features/constant_feature.cpp b/searchlib/src/vespa/searchlib/features/constant_feature.cpp index 5eedb5834bf..9de4d351584 100644 --- a/searchlib/src/vespa/searchlib/features/constant_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/constant_feature.cpp @@ -3,6 +3,8 @@ #include "constant_feature.h" #include "valuefeature.h" #include <vespa/searchlib/fef/featureexecutor.h> +#include <vespa/searchlib/fef/properties.h> +#include <vespa/eval/eval/function.h> #include <vespa/eval/eval/value_cache/constant_value.h> #include <vespa/vespalib/util/stash.h> @@ -11,6 +13,10 @@ LOG_SETUP(".features.constant_feature"); using namespace search::fef; +using vespalib::eval::ValueType; +using vespalib::eval::Function; +using vespalib::eval::SimpleConstantValue; + namespace search::features { /** @@ -62,13 +68,26 @@ ConstantBlueprint::setup(const IIndexEnvironment &env, _key = params[0].getValue(); _value = env.getConstantValue(_key); if (!_value) { - fail("Constant '%s' not found", _key.c_str()); + auto type_prop = env.getProperties().lookup(getName(), "type"); + auto value_prop = env.getProperties().lookup(getName(), "value"); + if ((type_prop.size() == 1) && (value_prop.size() == 1)) { + auto type = ValueType::from_spec(type_prop.get()); + auto value = Function::parse(value_prop.get())->root().get_const_value(); + if (!type.is_error() && value && (value->type() == type)) { + _value = std::make_unique<SimpleConstantValue>(std::move(value)); + } else { + fail("Constant '%s' has invalid spec: type='%s', value='%s'", + _key.c_str(), type_prop.get().c_str(), value_prop.get().c_str()); + } + } else { + fail("Constant '%s' not found", _key.c_str()); + } } else if (_value->type().is_error()) { fail("Constant '%s' has invalid type", _key.c_str()); } - FeatureType output_type = _value ? - FeatureType::object(_value->type()) : - FeatureType::number(); + FeatureType output_type = _value + ? FeatureType::object(_value->type()) + : FeatureType::number(); describeOutput("out", "The constant looked up in index environment using the given key.", output_type); return (_value && !_value->type().is_error()); |