summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2017-10-26 14:54:00 +0000
committerHåvard Pettersen <havardpe@oath.com>2017-10-26 14:54:00 +0000
commit51a6686b3b0dd58ccf352aac41feacd797efddc1 (patch)
treec9cbb4ef82f4245291d8ff60a8a597ec2189b765 /eval
parent56947142b32ed745072782da08bdadf7c4be0c52 (diff)
normalize to fewer operation primitives
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp4
-rw-r--r--eval/src/vespa/eval/eval/interpreted_function.cpp228
2 files changed, 124 insertions, 108 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index fb31601bf0f..0e548a3b82e 100644
--- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -129,6 +129,10 @@ TEST("require that interpreted function instructions have expected size") {
EXPECT_EQUAL(sizeof(InterpretedFunction::Instruction), 16u);
}
+TEST("require that function pointers can be passed as instruction parameters") {
+ EXPECT_EQUAL(sizeof(&operation::Add::f), sizeof(uint64_t));
+}
+
TEST("require that basic addition works") {
Function function = Function::parse("a+10");
InterpretedFunction interpreted(SimpleTensorEngine::ref(), function, NodeTypes());
diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp
index 41339791f47..09cc2d57e80 100644
--- a/eval/src/vespa/eval/eval/interpreted_function.cpp
+++ b/eval/src/vespa/eval/eval/interpreted_function.cpp
@@ -18,6 +18,8 @@ namespace {
using namespace nodes;
using State = InterpretedFunction::State;
using Instruction = InterpretedFunction::Instruction;
+using map_fun_t = double (*)(double);
+using join_fun_t = double (*)(double, double);
//-----------------------------------------------------------------------------
@@ -32,6 +34,13 @@ const T &unwrap_param(uint64_t param) { return *((const T *)param); }
//-----------------------------------------------------------------------------
+uint64_t to_param(map_fun_t value) { return (uint64_t)value; }
+uint64_t to_param(join_fun_t value) { return (uint64_t)value; }
+map_fun_t to_map_fun(uint64_t param) { return (map_fun_t)param; }
+join_fun_t to_join_fun(uint64_t param) { return (join_fun_t)param; }
+
+//-----------------------------------------------------------------------------
+
void op_load_const(State &state, uint64_t param) {
state.stack.push_back(unwrap_param<Value>(param));
}
@@ -46,18 +55,6 @@ void op_load_let(State &state, uint64_t param) {
//-----------------------------------------------------------------------------
-template <typename OP1>
-void op_unary(State &state, uint64_t) {
- state.replace(1, state.engine.map(state.peek(0), OP1::f, state.stash));
-}
-
-template <typename OP2>
-void op_binary(State &state, uint64_t) {
- state.replace(2, state.engine.join(state.peek(1), state.peek(0), OP2::f, state.stash));
-}
-
-//-----------------------------------------------------------------------------
-
void op_skip(State &state, uint64_t param) {
state.program_offset += param;
}
@@ -102,13 +99,11 @@ void op_not_member(State &state, uint64_t) {
//-----------------------------------------------------------------------------
void op_tensor_map(State &state, uint64_t param) {
- const CompiledFunction &cfun = unwrap_param<CompiledFunction>(param);
- state.replace(1, state.engine.map(state.peek(0), cfun.get_function<1>(), state.stash));
+ state.replace(1, state.engine.map(state.peek(0), to_map_fun(param), state.stash));
}
void op_tensor_join(State &state, uint64_t param) {
- const CompiledFunction &cfun = unwrap_param<CompiledFunction>(param);
- state.replace(2, state.engine.join(state.peek(1), state.peek(0), cfun.get_function<2>(), state.stash));
+ state.replace(2, state.engine.join(state.peek(1), state.peek(0), to_join_fun(param), state.stash));
}
using ReduceParams = std::pair<Aggr,std::vector<vespalib::string>>;
@@ -216,8 +211,25 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
//-------------------------------------------------------------------------
+ void make_const_op(const Node &node, const Value &value) {
+ (void) node;
+ program.emplace_back(op_load_const, wrap_param<Value>(value));
+ }
+
+ void make_map_op(const Node &node, map_fun_t function) {
+ (void) node;
+ program.emplace_back(op_tensor_map, to_param(function));
+ }
+
+ void make_join_op(const Node &node, join_fun_t function) {
+ (void) node;
+ program.emplace_back(op_tensor_join, to_param(function));
+ }
+
+ //-------------------------------------------------------------------------
+
void visit(const Number &node) override {
- program.emplace_back(op_load_const, wrap_param<Value>(stash.create<DoubleValue>(node.value())));
+ make_const_op(node, stash.create<DoubleValue>(node.value()));
}
void visit(const Symbol &node) override {
if (node.id() >= 0) { // param value
@@ -228,16 +240,16 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
}
}
void visit(const String &node) override {
- program.emplace_back(op_load_const, wrap_param<Value>(stash.create<DoubleValue>(node.hash())));
+ make_const_op(node, stash.create<DoubleValue>(node.hash()));
}
void visit(const Array &node) override {
- program.emplace_back(op_load_const, wrap_param<Value>(stash.create<DoubleValue>(node.size())));
+ make_const_op(node, stash.create<DoubleValue>(node.size()));
}
- void visit(const Neg &) override {
- program.emplace_back(op_unary<operation::Neg>);
+ void visit(const Neg &node) override {
+ make_map_op(node, operation::Neg::f);
}
- void visit(const Not &) override {
- program.emplace_back(op_unary<operation::Not>);
+ void visit(const Not &node) override {
+ make_map_op(node, operation::Not::f);
}
void visit(const If &node) override {
node.cond().traverse(*this);
@@ -256,16 +268,16 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
node.expr().traverse(*this);
program.emplace_back(op_evict_let);
}
- void visit(const Error &) override {
- program.emplace_back(op_load_const, wrap_param<Value>(stash.create<ErrorValue>()));
+ void visit(const Error &node) override {
+ make_const_op(node, ErrorValue::instance);
}
void visit(const TensorMap &node) override {
const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE));
- program.emplace_back(op_tensor_map, wrap_param<CompiledFunction>(token.get()->get()));
+ make_map_op(node, token.get()->get().get_function<1>());
}
void visit(const TensorJoin &node) override {
const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE));
- program.emplace_back(op_tensor_join, wrap_param<CompiledFunction>(token.get()->get()));
+ make_join_op(node, token.get()->get().get_function<2>());
}
void visit(const TensorReduce &node) override {
if ((node.aggr() == Aggr::SUM) && is_typed(node) && is_typed_tensor_product_of_params(node.get_child(0))) {
@@ -306,50 +318,50 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
spec.add(addr, fun(&params[0]));
} while (step_labels(params, type));
auto tensor = tensor_engine.create(spec);
- program.emplace_back(op_load_const, wrap_param<Value>(stash.create<TensorValue>(std::move(tensor))));
+ make_const_op(node, stash.create<TensorValue>(std::move(tensor)));
}
void visit(const TensorConcat &node) override {
vespalib::string &dimension = stash.create<vespalib::string>(node.dimension());
program.emplace_back(op_tensor_concat, wrap_param<vespalib::string>(dimension));
}
- void visit(const Add &) override {
- program.emplace_back(op_binary<operation::Add>);
+ void visit(const Add &node) override {
+ make_join_op(node, operation::Add::f);
}
- void visit(const Sub &) override {
- program.emplace_back(op_binary<operation::Sub>);
+ void visit(const Sub &node) override {
+ make_join_op(node, operation::Sub::f);
}
- void visit(const Mul &) override {
- program.emplace_back(op_binary<operation::Mul>);
+ void visit(const Mul &node) override {
+ make_join_op(node, operation::Mul::f);
}
- void visit(const Div &) override {
- program.emplace_back(op_binary<operation::Div>);
+ void visit(const Div &node) override {
+ make_join_op(node, operation::Div::f);
}
- void visit(const Mod &) override {
- program.emplace_back(op_binary<operation::Mod>);
+ void visit(const Mod &node) override {
+ make_join_op(node, operation::Mod::f);
}
- void visit(const Pow &) override {
- program.emplace_back(op_binary<operation::Pow>);
+ void visit(const Pow &node) override {
+ make_join_op(node, operation::Pow::f);
}
- void visit(const Equal &) override {
- program.emplace_back(op_binary<operation::Equal>);
+ void visit(const Equal &node) override {
+ make_join_op(node, operation::Equal::f);
}
- void visit(const NotEqual &) override {
- program.emplace_back(op_binary<operation::NotEqual>);
+ void visit(const NotEqual &node) override {
+ make_join_op(node, operation::NotEqual::f);
}
- void visit(const Approx &) override {
- program.emplace_back(op_binary<operation::Approx>);
+ void visit(const Approx &node) override {
+ make_join_op(node, operation::Approx::f);
}
- void visit(const Less &) override {
- program.emplace_back(op_binary<operation::Less>);
+ void visit(const Less &node) override {
+ make_join_op(node, operation::Less::f);
}
- void visit(const LessEqual &) override {
- program.emplace_back(op_binary<operation::LessEqual>);
+ void visit(const LessEqual &node) override {
+ make_join_op(node, operation::LessEqual::f);
}
- void visit(const Greater &) override {
- program.emplace_back(op_binary<operation::Greater>);
+ void visit(const Greater &node) override {
+ make_join_op(node, operation::Greater::f);
}
- void visit(const GreaterEqual &) override {
- program.emplace_back(op_binary<operation::GreaterEqual>);
+ void visit(const GreaterEqual &node) override {
+ make_join_op(node, operation::GreaterEqual::f);
}
void visit(const In &node) override {
std::vector<size_t> checks;
@@ -371,86 +383,86 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
}
program.emplace_back(op_not_member);
}
- void visit(const And &) override {
- program.emplace_back(op_binary<operation::And>);
+ void visit(const And &node) override {
+ make_join_op(node, operation::And::f);
}
- void visit(const Or &) override {
- program.emplace_back(op_binary<operation::Or>);
+ void visit(const Or &node) override {
+ make_join_op(node, operation::Or::f);
}
- void visit(const Cos &) override {
- program.emplace_back(op_unary<operation::Cos>);
+ void visit(const Cos &node) override {
+ make_map_op(node, operation::Cos::f);
}
- void visit(const Sin &) override {
- program.emplace_back(op_unary<operation::Sin>);
+ void visit(const Sin &node) override {
+ make_map_op(node, operation::Sin::f);
}
- void visit(const Tan &) override {
- program.emplace_back(op_unary<operation::Tan>);
+ void visit(const Tan &node) override {
+ make_map_op(node, operation::Tan::f);
}
- void visit(const Cosh &) override {
- program.emplace_back(op_unary<operation::Cosh>);
+ void visit(const Cosh &node) override {
+ make_map_op(node, operation::Cosh::f);
}
- void visit(const Sinh &) override {
- program.emplace_back(op_unary<operation::Sinh>);
+ void visit(const Sinh &node) override {
+ make_map_op(node, operation::Sinh::f);
}
- void visit(const Tanh &) override {
- program.emplace_back(op_unary<operation::Tanh>);
+ void visit(const Tanh &node) override {
+ make_map_op(node, operation::Tanh::f);
}
- void visit(const Acos &) override {
- program.emplace_back(op_unary<operation::Acos>);
+ void visit(const Acos &node) override {
+ make_map_op(node, operation::Acos::f);
}
- void visit(const Asin &) override {
- program.emplace_back(op_unary<operation::Asin>);
+ void visit(const Asin &node) override {
+ make_map_op(node, operation::Asin::f);
}
- void visit(const Atan &) override {
- program.emplace_back(op_unary<operation::Atan>);
+ void visit(const Atan &node) override {
+ make_map_op(node, operation::Atan::f);
}
- void visit(const Exp &) override {
- program.emplace_back(op_unary<operation::Exp>);
+ void visit(const Exp &node) override {
+ make_map_op(node, operation::Exp::f);
}
- void visit(const Log10 &) override {
- program.emplace_back(op_unary<operation::Log10>);
+ void visit(const Log10 &node) override {
+ make_map_op(node, operation::Log10::f);
}
- void visit(const Log &) override {
- program.emplace_back(op_unary<operation::Log>);
+ void visit(const Log &node) override {
+ make_map_op(node, operation::Log::f);
}
- void visit(const Sqrt &) override {
- program.emplace_back(op_unary<operation::Sqrt>);
+ void visit(const Sqrt &node) override {
+ make_map_op(node, operation::Sqrt::f);
}
- void visit(const Ceil &) override {
- program.emplace_back(op_unary<operation::Ceil>);
+ void visit(const Ceil &node) override {
+ make_map_op(node, operation::Ceil::f);
}
- void visit(const Fabs &) override {
- program.emplace_back(op_unary<operation::Fabs>);
+ void visit(const Fabs &node) override {
+ make_map_op(node, operation::Fabs::f);
}
- void visit(const Floor &) override {
- program.emplace_back(op_unary<operation::Floor>);
+ void visit(const Floor &node) override {
+ make_map_op(node, operation::Floor::f);
}
- void visit(const Atan2 &) override {
- program.emplace_back(op_binary<operation::Atan2>);
+ void visit(const Atan2 &node) override {
+ make_join_op(node, operation::Atan2::f);
}
- void visit(const Ldexp &) override {
- program.emplace_back(op_binary<operation::Ldexp>);
+ void visit(const Ldexp &node) override {
+ make_join_op(node, operation::Ldexp::f);
}
- void visit(const Pow2 &) override {
- program.emplace_back(op_binary<operation::Pow>);
+ void visit(const Pow2 &node) override {
+ make_join_op(node, operation::Pow::f);
}
- void visit(const Fmod &) override {
- program.emplace_back(op_binary<operation::Mod>);
+ void visit(const Fmod &node) override {
+ make_join_op(node, operation::Mod::f);
}
- void visit(const Min &) override {
- program.emplace_back(op_binary<operation::Min>);
+ void visit(const Min &node) override {
+ make_join_op(node, operation::Min::f);
}
- void visit(const Max &) override {
- program.emplace_back(op_binary<operation::Max>);
+ void visit(const Max &node) override {
+ make_join_op(node, operation::Max::f);
}
- void visit(const IsNan &) override {
- program.emplace_back(op_unary<operation::IsNan>);
+ void visit(const IsNan &node) override {
+ make_map_op(node, operation::IsNan::f);
}
- void visit(const Relu &) override {
- program.emplace_back(op_unary<operation::Relu>);
+ void visit(const Relu &node) override {
+ make_map_op(node, operation::Relu::f);
}
- void visit(const Sigmoid &) override {
- program.emplace_back(op_unary<operation::Sigmoid>);
+ void visit(const Sigmoid &node) override {
+ make_map_op(node, operation::Sigmoid::f);
}
//-------------------------------------------------------------------------