aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp')
-rw-r--r--eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp21
1 files changed, 21 insertions, 0 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index ba73a578f6f..bcb2e29472c 100644
--- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -3,8 +3,10 @@
#include <vespa/eval/eval/fast_value.h>
#include <vespa/eval/eval/function.h>
#include <vespa/eval/eval/tensor_spec.h>
+#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/interpreted_function.h>
+#include <vespa/eval/eval/compile_tensor_function.h>
#include <vespa/eval/eval/test/eval_spec.h>
#include <vespa/eval/eval/basic_nodes.h>
#include <vespa/eval/eval/simple_value.h>
@@ -168,4 +170,23 @@ TEST("require that functions with non-compilable lambdas cannot be interpreted")
//-----------------------------------------------------------------------------
+TEST("require that compilation meta-data can be collected") {
+ Stash stash;
+ const auto &x2 = tensor_function::inject(ValueType::from_spec("tensor(x[2])"), 0, stash);
+ const auto &x3 = tensor_function::inject(ValueType::from_spec("tensor(x[3])"), 1, stash);
+ const auto &concat_x5 = tensor_function::concat(x3, x2, "x", stash);
+ const auto &x5 = tensor_function::inject(ValueType::from_spec("tensor(x[5])"), 2, stash);
+ const auto &mapped_x5 = tensor_function::map(x5, operation::Relu::f, stash);
+ const auto &flag = tensor_function::inject(ValueType::from_spec("double"), 0, stash);
+ const auto &root = tensor_function::if_node(flag, concat_x5, mapped_x5, stash);
+ CTFMetaData meta;
+ InterpretedFunction ifun(FastValueBuilderFactory::get(), root, meta);
+ fprintf(stderr, "compilation meta-data:\n");
+ for (const auto &step: meta.steps) {
+ fprintf(stderr, " %s -> %s\n", step.class_name.c_str(), step.symbol_name.c_str());
+ }
+}
+
+//-----------------------------------------------------------------------------
+
TEST_MAIN() { TEST_RUN_ALL(); }