diff options
Diffstat (limited to 'eval/src/vespa/eval/eval/compile_tensor_function.cpp')
-rw-r--r-- | eval/src/vespa/eval/eval/compile_tensor_function.cpp | 40 |
1 files changed, 26 insertions, 14 deletions
diff --git a/eval/src/vespa/eval/eval/compile_tensor_function.cpp b/eval/src/vespa/eval/eval/compile_tensor_function.cpp index 425910d6249..02a38f7ad42 100644 --- a/eval/src/vespa/eval/eval/compile_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/compile_tensor_function.cpp @@ -35,18 +35,23 @@ struct Frame { }; struct ProgramCompiler { - const ValueBuilderFactory &factory; - Stash &stash; + CTFContext ctx; std::vector<Frame> stack; std::vector<Instruction> prog; - CTFMetaData *meta; - ProgramCompiler(const ValueBuilderFactory &factory_in, Stash &stash_in, CTFMetaData *meta_in) - : factory(factory_in), stash(stash_in), stack(), prog(), meta(meta_in) {} + ProgramCompiler(const CTFContext &ctx_in) + : ctx(ctx_in), stack(), prog() {} ~ProgramCompiler(); + void add_meta(const TensorFunction &node, const Instruction &instr, CTFMetaData *nested) { + ctx.meta->steps.emplace_back(getClassName(node), instr.resolve_symbol()); + if (nested) { + ctx.meta->steps.back().nested = nested->extract(); + } + } + void maybe_add_meta(const TensorFunction &node, const Instruction &instr) { - if (meta != nullptr) { - meta->steps.emplace_back(getClassName(node), instr.resolve_symbol()); + if (ctx.meta != nullptr) { + add_meta(node, instr, nullptr); } } @@ -56,11 +61,11 @@ struct ProgramCompiler { void open(const TensorFunction &node) { if (auto if_node = as<tensor_function::If>(node)) { - append(compile_tensor_function(factory, if_node->cond(), stash, meta)); + append(compile_tensor_function(if_node->cond(), ctx)); maybe_add_meta(node, Instruction(op_skip_if_false)); - auto true_prog = compile_tensor_function(factory, if_node->true_child(), stash, meta); + auto true_prog = compile_tensor_function(if_node->true_child(), ctx); maybe_add_meta(node, Instruction(op_skip)); - auto false_prog = compile_tensor_function(factory, if_node->false_child(), stash, meta); + auto false_prog = compile_tensor_function(if_node->false_child(), ctx); true_prog.emplace_back(op_skip, false_prog.size()); prog.emplace_back(op_skip_if_false, true_prog.size()); append(true_prog); @@ -71,8 +76,15 @@ struct ProgramCompiler { } void close(const TensorFunction &node) { - prog.push_back(node.compile_self(factory, stash)); - maybe_add_meta(node, prog.back()); + if (ctx.meta == nullptr) { + prog.push_back(node.compile_self(ctx)); + } else { + CTFMetaData sub_meta; + CTFContext sub_context = ctx; + sub_context.meta = &sub_meta; + prog.push_back(node.compile_self(sub_context)); + add_meta(node, prog.back(), &sub_meta); + } } std::vector<Instruction> compile(const TensorFunction &function) { @@ -94,8 +106,8 @@ ProgramCompiler::~ProgramCompiler() = default; CTFMetaData::~CTFMetaData() = default; -std::vector<Instruction> compile_tensor_function(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash, CTFMetaData *meta) { - ProgramCompiler compiler(factory, stash, meta); +std::vector<Instruction> compile_tensor_function(const TensorFunction &function, const CTFContext &ctx) { + ProgramCompiler compiler(ctx); return compiler.compile(function); } |