aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp')
-rw-r--r--eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp20
1 files changed, 11 insertions, 9 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
index 4ba715ea192..ac7e0f6d126 100644
--- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
+++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp
@@ -199,16 +199,18 @@ TEST("require that functions with non-interpretable complex lambdas cannot be in
//-----------------------------------------------------------------------------
TEST("require that compilation meta-data can be collected") {
- Stash stash;
- const auto &x2 = tensor_function::inject(ValueType::from_spec("tensor(x[2])"), 0, stash);
- const auto &x3 = tensor_function::inject(ValueType::from_spec("tensor(x[3])"), 1, stash);
- const auto &concat_x5 = tensor_function::concat(x3, x2, "x", stash);
- const auto &x5 = tensor_function::inject(ValueType::from_spec("tensor(x[5])"), 2, stash);
- const auto &mapped_x5 = tensor_function::map(x5, operation::Relu::f, stash);
- const auto &flag = tensor_function::inject(ValueType::from_spec("double"), 0, stash);
- const auto &root = tensor_function::if_node(flag, concat_x5, mapped_x5, stash);
+ auto fun = Function::parse("if(flag,concat(x2,x3,x),map(x5,f(x)(relu(x))))");
+ fprintf(stderr, "%s\n", fun->dump_as_lambda().c_str());
+ ASSERT_TRUE(fun->dump_as_lambda().starts_with("f(flag,x2,x3,x5)"));
+ std::vector<ValueType> param_types({ValueType::from_spec("double"),
+ ValueType::from_spec("tensor(x[2])"),
+ ValueType::from_spec("tensor(x[3])"),
+ ValueType::from_spec("tensor(x[5])")});
+ NodeTypes types(*fun, param_types);
+ ASSERT_FALSE(types.get_type(fun->root()).is_error());
+ ASSERT_TRUE(types.errors().empty());
CTFMetaData meta;
- InterpretedFunction ifun(FastValueBuilderFactory::get(), root, &meta);
+ auto ifun = InterpretedFunction::opts(FastValueBuilderFactory::get()).meta(&meta).make(fun->root(), types);
fprintf(stderr, "compilation meta-data:\n");
for (const auto &step: meta.steps) {
fprintf(stderr, " %s -> %s\n", step.class_name.c_str(), step.symbol_name.c_str());