aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/eval/compile_tensor_function.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/vespa/eval/eval/compile_tensor_function.cpp')
-rw-r--r--eval/src/vespa/eval/eval/compile_tensor_function.cpp40
1 files changed, 26 insertions, 14 deletions
diff --git a/eval/src/vespa/eval/eval/compile_tensor_function.cpp b/eval/src/vespa/eval/eval/compile_tensor_function.cpp
index 425910d6249..02a38f7ad42 100644
--- a/eval/src/vespa/eval/eval/compile_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/compile_tensor_function.cpp
@@ -35,18 +35,23 @@ struct Frame {
};
struct ProgramCompiler {
- const ValueBuilderFactory &factory;
- Stash &stash;
+ CTFContext ctx;
std::vector<Frame> stack;
std::vector<Instruction> prog;
- CTFMetaData *meta;
- ProgramCompiler(const ValueBuilderFactory &factory_in, Stash &stash_in, CTFMetaData *meta_in)
- : factory(factory_in), stash(stash_in), stack(), prog(), meta(meta_in) {}
+ ProgramCompiler(const CTFContext &ctx_in)
+ : ctx(ctx_in), stack(), prog() {}
~ProgramCompiler();
+ void add_meta(const TensorFunction &node, const Instruction &instr, CTFMetaData *nested) {
+ ctx.meta->steps.emplace_back(getClassName(node), instr.resolve_symbol());
+ if (nested) {
+ ctx.meta->steps.back().nested = nested->extract();
+ }
+ }
+
void maybe_add_meta(const TensorFunction &node, const Instruction &instr) {
- if (meta != nullptr) {
- meta->steps.emplace_back(getClassName(node), instr.resolve_symbol());
+ if (ctx.meta != nullptr) {
+ add_meta(node, instr, nullptr);
}
}
@@ -56,11 +61,11 @@ struct ProgramCompiler {
void open(const TensorFunction &node) {
if (auto if_node = as<tensor_function::If>(node)) {
- append(compile_tensor_function(factory, if_node->cond(), stash, meta));
+ append(compile_tensor_function(if_node->cond(), ctx));
maybe_add_meta(node, Instruction(op_skip_if_false));
- auto true_prog = compile_tensor_function(factory, if_node->true_child(), stash, meta);
+ auto true_prog = compile_tensor_function(if_node->true_child(), ctx);
maybe_add_meta(node, Instruction(op_skip));
- auto false_prog = compile_tensor_function(factory, if_node->false_child(), stash, meta);
+ auto false_prog = compile_tensor_function(if_node->false_child(), ctx);
true_prog.emplace_back(op_skip, false_prog.size());
prog.emplace_back(op_skip_if_false, true_prog.size());
append(true_prog);
@@ -71,8 +76,15 @@ struct ProgramCompiler {
}
void close(const TensorFunction &node) {
- prog.push_back(node.compile_self(factory, stash));
- maybe_add_meta(node, prog.back());
+ if (ctx.meta == nullptr) {
+ prog.push_back(node.compile_self(ctx));
+ } else {
+ CTFMetaData sub_meta;
+ CTFContext sub_context = ctx;
+ sub_context.meta = &sub_meta;
+ prog.push_back(node.compile_self(sub_context));
+ add_meta(node, prog.back(), &sub_meta);
+ }
}
std::vector<Instruction> compile(const TensorFunction &function) {
@@ -94,8 +106,8 @@ ProgramCompiler::~ProgramCompiler() = default;
CTFMetaData::~CTFMetaData() = default;
-std::vector<Instruction> compile_tensor_function(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash, CTFMetaData *meta) {
- ProgramCompiler compiler(factory, stash, meta);
+std::vector<Instruction> compile_tensor_function(const TensorFunction &function, const CTFContext &ctx) {
+ ProgramCompiler compiler(ctx);
return compiler.compile(function);
}