diff options
author | Håvard Pettersen <havardpe@yahooinc.com> | 2023-11-08 13:55:05 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@yahooinc.com> | 2023-11-13 09:44:26 +0000 |
commit | 8753ece3f5f374a5cffd360f2009ef1de9ec9d35 (patch) | |
tree | 82340fe27f9df59a1b6650f9f348f7c6d4a5a106 /eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp | |
parent | 8435f3c414ccbc2dae11e69f7592f04f279144c5 (diff) |
enable nested ctf meta datahavardpe/enable-nested-ctf-meta-data
Diffstat (limited to 'eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp')
-rw-r--r-- | eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp index 4ba715ea192..ac7e0f6d126 100644 --- a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp +++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp @@ -199,16 +199,18 @@ TEST("require that functions with non-interpretable complex lambdas cannot be in //----------------------------------------------------------------------------- TEST("require that compilation meta-data can be collected") { - Stash stash; - const auto &x2 = tensor_function::inject(ValueType::from_spec("tensor(x[2])"), 0, stash); - const auto &x3 = tensor_function::inject(ValueType::from_spec("tensor(x[3])"), 1, stash); - const auto &concat_x5 = tensor_function::concat(x3, x2, "x", stash); - const auto &x5 = tensor_function::inject(ValueType::from_spec("tensor(x[5])"), 2, stash); - const auto &mapped_x5 = tensor_function::map(x5, operation::Relu::f, stash); - const auto &flag = tensor_function::inject(ValueType::from_spec("double"), 0, stash); - const auto &root = tensor_function::if_node(flag, concat_x5, mapped_x5, stash); + auto fun = Function::parse("if(flag,concat(x2,x3,x),map(x5,f(x)(relu(x))))"); + fprintf(stderr, "%s\n", fun->dump_as_lambda().c_str()); + ASSERT_TRUE(fun->dump_as_lambda().starts_with("f(flag,x2,x3,x5)")); + std::vector<ValueType> param_types({ValueType::from_spec("double"), + ValueType::from_spec("tensor(x[2])"), + ValueType::from_spec("tensor(x[3])"), + ValueType::from_spec("tensor(x[5])")}); + NodeTypes types(*fun, param_types); + ASSERT_FALSE(types.get_type(fun->root()).is_error()); + ASSERT_TRUE(types.errors().empty()); CTFMetaData meta; - InterpretedFunction ifun(FastValueBuilderFactory::get(), root, &meta); + auto ifun = InterpretedFunction::opts(FastValueBuilderFactory::get()).meta(&meta).make(fun->root(), types); fprintf(stderr, "compilation meta-data:\n"); for (const auto &step: meta.steps) { fprintf(stderr, " %s -> %s\n", step.class_name.c_str(), step.symbol_name.c_str()); |