diff options
author | Håvard Pettersen <havardpe@oath.com> | 2017-10-26 14:54:00 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2017-10-26 14:54:00 +0000 |
commit | 51a6686b3b0dd58ccf352aac41feacd797efddc1 (patch) | |
tree | c9cbb4ef82f4245291d8ff60a8a597ec2189b765 /eval | |
parent | 56947142b32ed745072782da08bdadf7c4be0c52 (diff) |
normalize to fewer operation primitives
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp | 4 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/interpreted_function.cpp | 228 |
2 files changed, 124 insertions, 108 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp index fb31601bf0f..0e548a3b82e 100644 --- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp +++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp @@ -129,6 +129,10 @@ TEST("require that interpreted function instructions have expected size") { EXPECT_EQUAL(sizeof(InterpretedFunction::Instruction), 16u); } +TEST("require that function pointers can be passed as instruction parameters") { + EXPECT_EQUAL(sizeof(&operation::Add::f), sizeof(uint64_t)); +} + TEST("require that basic addition works") { Function function = Function::parse("a+10"); InterpretedFunction interpreted(SimpleTensorEngine::ref(), function, NodeTypes()); diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp index 41339791f47..09cc2d57e80 100644 --- a/eval/src/vespa/eval/eval/interpreted_function.cpp +++ b/eval/src/vespa/eval/eval/interpreted_function.cpp @@ -18,6 +18,8 @@ namespace { using namespace nodes; using State = InterpretedFunction::State; using Instruction = InterpretedFunction::Instruction; +using map_fun_t = double (*)(double); +using join_fun_t = double (*)(double, double); //----------------------------------------------------------------------------- @@ -32,6 +34,13 @@ const T &unwrap_param(uint64_t param) { return *((const T *)param); } //----------------------------------------------------------------------------- +uint64_t to_param(map_fun_t value) { return (uint64_t)value; } +uint64_t to_param(join_fun_t value) { return (uint64_t)value; } +map_fun_t to_map_fun(uint64_t param) { return (map_fun_t)param; } +join_fun_t to_join_fun(uint64_t param) { return (join_fun_t)param; } + +//----------------------------------------------------------------------------- + void op_load_const(State &state, uint64_t param) { state.stack.push_back(unwrap_param<Value>(param)); } @@ -46,18 +55,6 @@ void op_load_let(State &state, uint64_t param) { //----------------------------------------------------------------------------- -template <typename OP1> -void op_unary(State &state, uint64_t) { - state.replace(1, state.engine.map(state.peek(0), OP1::f, state.stash)); -} - -template <typename OP2> -void op_binary(State &state, uint64_t) { - state.replace(2, state.engine.join(state.peek(1), state.peek(0), OP2::f, state.stash)); -} - -//----------------------------------------------------------------------------- - void op_skip(State &state, uint64_t param) { state.program_offset += param; } @@ -102,13 +99,11 @@ void op_not_member(State &state, uint64_t) { //----------------------------------------------------------------------------- void op_tensor_map(State &state, uint64_t param) { - const CompiledFunction &cfun = unwrap_param<CompiledFunction>(param); - state.replace(1, state.engine.map(state.peek(0), cfun.get_function<1>(), state.stash)); + state.replace(1, state.engine.map(state.peek(0), to_map_fun(param), state.stash)); } void op_tensor_join(State &state, uint64_t param) { - const CompiledFunction &cfun = unwrap_param<CompiledFunction>(param); - state.replace(2, state.engine.join(state.peek(1), state.peek(0), cfun.get_function<2>(), state.stash)); + state.replace(2, state.engine.join(state.peek(1), state.peek(0), to_join_fun(param), state.stash)); } using ReduceParams = std::pair<Aggr,std::vector<vespalib::string>>; @@ -216,8 +211,25 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { //------------------------------------------------------------------------- + void make_const_op(const Node &node, const Value &value) { + (void) node; + program.emplace_back(op_load_const, wrap_param<Value>(value)); + } + + void make_map_op(const Node &node, map_fun_t function) { + (void) node; + program.emplace_back(op_tensor_map, to_param(function)); + } + + void make_join_op(const Node &node, join_fun_t function) { + (void) node; + program.emplace_back(op_tensor_join, to_param(function)); + } + + //------------------------------------------------------------------------- + void visit(const Number &node) override { - program.emplace_back(op_load_const, wrap_param<Value>(stash.create<DoubleValue>(node.value()))); + make_const_op(node, stash.create<DoubleValue>(node.value())); } void visit(const Symbol &node) override { if (node.id() >= 0) { // param value @@ -228,16 +240,16 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { } } void visit(const String &node) override { - program.emplace_back(op_load_const, wrap_param<Value>(stash.create<DoubleValue>(node.hash()))); + make_const_op(node, stash.create<DoubleValue>(node.hash())); } void visit(const Array &node) override { - program.emplace_back(op_load_const, wrap_param<Value>(stash.create<DoubleValue>(node.size()))); + make_const_op(node, stash.create<DoubleValue>(node.size())); } - void visit(const Neg &) override { - program.emplace_back(op_unary<operation::Neg>); + void visit(const Neg &node) override { + make_map_op(node, operation::Neg::f); } - void visit(const Not &) override { - program.emplace_back(op_unary<operation::Not>); + void visit(const Not &node) override { + make_map_op(node, operation::Not::f); } void visit(const If &node) override { node.cond().traverse(*this); @@ -256,16 +268,16 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { node.expr().traverse(*this); program.emplace_back(op_evict_let); } - void visit(const Error &) override { - program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>())); + void visit(const Error &node) override { + make_const_op(node, ErrorValue::instance); } void visit(const TensorMap &node) override { const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); - program.emplace_back(op_tensor_map, wrap_param<CompiledFunction>(token.get()->get())); + make_map_op(node, token.get()->get().get_function<1>()); } void visit(const TensorJoin &node) override { const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); - program.emplace_back(op_tensor_join, wrap_param<CompiledFunction>(token.get()->get())); + make_join_op(node, token.get()->get().get_function<2>()); } void visit(const TensorReduce &node) override { if ((node.aggr() == Aggr::SUM) && is_typed(node) && is_typed_tensor_product_of_params(node.get_child(0))) { @@ -306,50 +318,50 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { spec.add(addr, fun(¶ms[0])); } while (step_labels(params, type)); auto tensor = tensor_engine.create(spec); - program.emplace_back(op_load_const, wrap_param<Value>(stash.create<TensorValue>(std::move(tensor)))); + make_const_op(node, stash.create<TensorValue>(std::move(tensor))); } void visit(const TensorConcat &node) override { vespalib::string &dimension = stash.create<vespalib::string>(node.dimension()); program.emplace_back(op_tensor_concat, wrap_param<vespalib::string>(dimension)); } - void visit(const Add &) override { - program.emplace_back(op_binary<operation::Add>); + void visit(const Add &node) override { + make_join_op(node, operation::Add::f); } - void visit(const Sub &) override { - program.emplace_back(op_binary<operation::Sub>); + void visit(const Sub &node) override { + make_join_op(node, operation::Sub::f); } - void visit(const Mul &) override { - program.emplace_back(op_binary<operation::Mul>); + void visit(const Mul &node) override { + make_join_op(node, operation::Mul::f); } - void visit(const Div &) override { - program.emplace_back(op_binary<operation::Div>); + void visit(const Div &node) override { + make_join_op(node, operation::Div::f); } - void visit(const Mod &) override { - program.emplace_back(op_binary<operation::Mod>); + void visit(const Mod &node) override { + make_join_op(node, operation::Mod::f); } - void visit(const Pow &) override { - program.emplace_back(op_binary<operation::Pow>); + void visit(const Pow &node) override { + make_join_op(node, operation::Pow::f); } - void visit(const Equal &) override { - program.emplace_back(op_binary<operation::Equal>); + void visit(const Equal &node) override { + make_join_op(node, operation::Equal::f); } - void visit(const NotEqual &) override { - program.emplace_back(op_binary<operation::NotEqual>); + void visit(const NotEqual &node) override { + make_join_op(node, operation::NotEqual::f); } - void visit(const Approx &) override { - program.emplace_back(op_binary<operation::Approx>); + void visit(const Approx &node) override { + make_join_op(node, operation::Approx::f); } - void visit(const Less &) override { - program.emplace_back(op_binary<operation::Less>); + void visit(const Less &node) override { + make_join_op(node, operation::Less::f); } - void visit(const LessEqual &) override { - program.emplace_back(op_binary<operation::LessEqual>); + void visit(const LessEqual &node) override { + make_join_op(node, operation::LessEqual::f); } - void visit(const Greater &) override { - program.emplace_back(op_binary<operation::Greater>); + void visit(const Greater &node) override { + make_join_op(node, operation::Greater::f); } - void visit(const GreaterEqual &) override { - program.emplace_back(op_binary<operation::GreaterEqual>); + void visit(const GreaterEqual &node) override { + make_join_op(node, operation::GreaterEqual::f); } void visit(const In &node) override { std::vector<size_t> checks; @@ -371,86 +383,86 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser { } program.emplace_back(op_not_member); } - void visit(const And &) override { - program.emplace_back(op_binary<operation::And>); + void visit(const And &node) override { + make_join_op(node, operation::And::f); } - void visit(const Or &) override { - program.emplace_back(op_binary<operation::Or>); + void visit(const Or &node) override { + make_join_op(node, operation::Or::f); } - void visit(const Cos &) override { - program.emplace_back(op_unary<operation::Cos>); + void visit(const Cos &node) override { + make_map_op(node, operation::Cos::f); } - void visit(const Sin &) override { - program.emplace_back(op_unary<operation::Sin>); + void visit(const Sin &node) override { + make_map_op(node, operation::Sin::f); } - void visit(const Tan &) override { - program.emplace_back(op_unary<operation::Tan>); + void visit(const Tan &node) override { + make_map_op(node, operation::Tan::f); } - void visit(const Cosh &) override { - program.emplace_back(op_unary<operation::Cosh>); + void visit(const Cosh &node) override { + make_map_op(node, operation::Cosh::f); } - void visit(const Sinh &) override { - program.emplace_back(op_unary<operation::Sinh>); + void visit(const Sinh &node) override { + make_map_op(node, operation::Sinh::f); } - void visit(const Tanh &) override { - program.emplace_back(op_unary<operation::Tanh>); + void visit(const Tanh &node) override { + make_map_op(node, operation::Tanh::f); } - void visit(const Acos &) override { - program.emplace_back(op_unary<operation::Acos>); + void visit(const Acos &node) override { + make_map_op(node, operation::Acos::f); } - void visit(const Asin &) override { - program.emplace_back(op_unary<operation::Asin>); + void visit(const Asin &node) override { + make_map_op(node, operation::Asin::f); } - void visit(const Atan &) override { - program.emplace_back(op_unary<operation::Atan>); + void visit(const Atan &node) override { + make_map_op(node, operation::Atan::f); } - void visit(const Exp &) override { - program.emplace_back(op_unary<operation::Exp>); + void visit(const Exp &node) override { + make_map_op(node, operation::Exp::f); } - void visit(const Log10 &) override { - program.emplace_back(op_unary<operation::Log10>); + void visit(const Log10 &node) override { + make_map_op(node, operation::Log10::f); } - void visit(const Log &) override { - program.emplace_back(op_unary<operation::Log>); + void visit(const Log &node) override { + make_map_op(node, operation::Log::f); } - void visit(const Sqrt &) override { - program.emplace_back(op_unary<operation::Sqrt>); + void visit(const Sqrt &node) override { + make_map_op(node, operation::Sqrt::f); } - void visit(const Ceil &) override { - program.emplace_back(op_unary<operation::Ceil>); + void visit(const Ceil &node) override { + make_map_op(node, operation::Ceil::f); } - void visit(const Fabs &) override { - program.emplace_back(op_unary<operation::Fabs>); + void visit(const Fabs &node) override { + make_map_op(node, operation::Fabs::f); } - void visit(const Floor &) override { - program.emplace_back(op_unary<operation::Floor>); + void visit(const Floor &node) override { + make_map_op(node, operation::Floor::f); } - void visit(const Atan2 &) override { - program.emplace_back(op_binary<operation::Atan2>); + void visit(const Atan2 &node) override { + make_join_op(node, operation::Atan2::f); } - void visit(const Ldexp &) override { - program.emplace_back(op_binary<operation::Ldexp>); + void visit(const Ldexp &node) override { + make_join_op(node, operation::Ldexp::f); } - void visit(const Pow2 &) override { - program.emplace_back(op_binary<operation::Pow>); + void visit(const Pow2 &node) override { + make_join_op(node, operation::Pow::f); } - void visit(const Fmod &) override { - program.emplace_back(op_binary<operation::Mod>); + void visit(const Fmod &node) override { + make_join_op(node, operation::Mod::f); } - void visit(const Min &) override { - program.emplace_back(op_binary<operation::Min>); + void visit(const Min &node) override { + make_join_op(node, operation::Min::f); } - void visit(const Max &) override { - program.emplace_back(op_binary<operation::Max>); + void visit(const Max &node) override { + make_join_op(node, operation::Max::f); } - void visit(const IsNan &) override { - program.emplace_back(op_unary<operation::IsNan>); + void visit(const IsNan &node) override { + make_map_op(node, operation::IsNan::f); } - void visit(const Relu &) override { - program.emplace_back(op_unary<operation::Relu>); + void visit(const Relu &node) override { + make_map_op(node, operation::Relu::f); } - void visit(const Sigmoid &) override { - program.emplace_back(op_unary<operation::Sigmoid>); + void visit(const Sigmoid &node) override { + make_map_op(node, operation::Sigmoid::f); } //------------------------------------------------------------------------- |