diff options
Diffstat (limited to 'eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp')
-rw-r--r-- | eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp index ba73a578f6f..bcb2e29472c 100644 --- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp +++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp @@ -3,8 +3,10 @@ #include <vespa/eval/eval/fast_value.h> #include <vespa/eval/eval/function.h> #include <vespa/eval/eval/tensor_spec.h> +#include <vespa/eval/eval/tensor_function.h> #include <vespa/eval/eval/operation.h> #include <vespa/eval/eval/interpreted_function.h> +#include <vespa/eval/eval/compile_tensor_function.h> #include <vespa/eval/eval/test/eval_spec.h> #include <vespa/eval/eval/basic_nodes.h> #include <vespa/eval/eval/simple_value.h> @@ -168,4 +170,23 @@ TEST("require that functions with non-compilable lambdas cannot be interpreted") //----------------------------------------------------------------------------- +TEST("require that compilation meta-data can be collected") { + Stash stash; + const auto &x2 = tensor_function::inject(ValueType::from_spec("tensor(x[2])"), 0, stash); + const auto &x3 = tensor_function::inject(ValueType::from_spec("tensor(x[3])"), 1, stash); + const auto &concat_x5 = tensor_function::concat(x3, x2, "x", stash); + const auto &x5 = tensor_function::inject(ValueType::from_spec("tensor(x[5])"), 2, stash); + const auto &mapped_x5 = tensor_function::map(x5, operation::Relu::f, stash); + const auto &flag = tensor_function::inject(ValueType::from_spec("double"), 0, stash); + const auto &root = tensor_function::if_node(flag, concat_x5, mapped_x5, stash); + CTFMetaData meta; + InterpretedFunction ifun(FastValueBuilderFactory::get(), root, meta); + fprintf(stderr, "compilation meta-data:\n"); + for (const auto &step: meta.steps) { + fprintf(stderr, " %s -> %s\n", step.class_name.c_str(), step.symbol_name.c_str()); + } +} + +//----------------------------------------------------------------------------- + TEST_MAIN() { TEST_RUN_ALL(); } |