diff options
Diffstat (limited to 'eval/src/tests/instruction/generic_peek/generic_peek_test.cpp')
-rw-r--r-- | eval/src/tests/instruction/generic_peek/generic_peek_test.cpp | 34 |
1 files changed, 34 insertions, 0 deletions
diff --git a/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp b/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp index 7dc2d756e6b..3874b254ad8 100644 --- a/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp +++ b/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp @@ -2,6 +2,7 @@ #include <vespa/eval/eval/simple_value.h> #include <vespa/eval/eval/fast_value.h> +#include <vespa/eval/eval/tensor_function.h> #include <vespa/eval/eval/value_codec.h> #include <vespa/eval/instruction/generic_peek.h> #include <vespa/eval/eval/interpreted_function.h> @@ -111,6 +112,37 @@ TensorSpec perform_generic_peek(const TensorSpec &a, const ValueType &result_typ return spec_from_value(single.eval(my_stack)); } +TensorSpec tensor_function_peek(const TensorSpec &a, const ValueType &result_type, + PeekSpec spec, const ValueBuilderFactory &factory) +{ + Stash stash; + auto param = value_from_spec(a, factory); + EXPECT_FALSE(param->type().is_error()); + EXPECT_FALSE(result_type.is_error()); + std::vector<Value::CREF> my_stack; + my_stack.push_back(*param); + const auto &func_double = tensor_function::inject(ValueType::double_type(), 1, stash); + std::map<vespalib::string, std::variant<TensorSpec::Label, TensorFunction::CREF>> func_spec; + for (auto & [dim_name, label_or_child] : spec) { + if (std::holds_alternative<size_t>(label_or_child)) { + // here, label_or_child is a size_t specifying the value + // this child should produce (but cast to signed first, + // to allow negative values) + ssize_t child_value = std::get<size_t>(label_or_child); + my_stack.push_back(stash.create<DoubleValue>(double(child_value))); + func_spec.emplace(dim_name, func_double); + } else { + auto label = std::get<TensorSpec::Label>(label_or_child); + func_spec.emplace(dim_name, label); + } + } + const auto &func_param = tensor_function::inject(param->type(), 0, stash); + const auto &peek_node = tensor_function::peek(func_param, func_spec, stash); + auto my_op = peek_node.compile_self(factory, stash); + InterpretedFunction::EvalSingle single(factory, my_op); + return spec_from_value(single.eval(my_stack)); +} + vespalib::string to_str(const PeekSpec &spec) { vespalib::asciistream os; os << "{ "; @@ -149,6 +181,8 @@ void verify_peek_equal(const TensorSpec &input, expect.to_string().c_str())); auto actual = perform_generic_peek(input, result_type, spec, factory); EXPECT_EQ(actual, expect); + auto from_func = tensor_function_peek(input, result_type, spec, factory); + EXPECT_EQ(from_func, expect); } void fill_dims_and_check(const TensorSpec &input, |