summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/instruction/generic_peek/generic_peek_test.cpp')
-rw-r--r--eval/src/tests/instruction/generic_peek/generic_peek_test.cpp34
1 files changed, 34 insertions, 0 deletions
diff --git a/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp b/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp
index 7dc2d756e6b..3874b254ad8 100644
--- a/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp
+++ b/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp
@@ -2,6 +2,7 @@
#include <vespa/eval/eval/simple_value.h>
#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/value_codec.h>
#include <vespa/eval/instruction/generic_peek.h>
#include <vespa/eval/eval/interpreted_function.h>
@@ -111,6 +112,37 @@ TensorSpec perform_generic_peek(const TensorSpec &a, const ValueType &result_typ
return spec_from_value(single.eval(my_stack));
}
+TensorSpec tensor_function_peek(const TensorSpec &a, const ValueType &result_type,
+ PeekSpec spec, const ValueBuilderFactory &factory)
+{
+ Stash stash;
+ auto param = value_from_spec(a, factory);
+ EXPECT_FALSE(param->type().is_error());
+ EXPECT_FALSE(result_type.is_error());
+ std::vector<Value::CREF> my_stack;
+ my_stack.push_back(*param);
+ const auto &func_double = tensor_function::inject(ValueType::double_type(), 1, stash);
+ std::map<vespalib::string, std::variant<TensorSpec::Label, TensorFunction::CREF>> func_spec;
+ for (auto & [dim_name, label_or_child] : spec) {
+ if (std::holds_alternative<size_t>(label_or_child)) {
+ // here, label_or_child is a size_t specifying the value
+ // this child should produce (but cast to signed first,
+ // to allow negative values)
+ ssize_t child_value = std::get<size_t>(label_or_child);
+ my_stack.push_back(stash.create<DoubleValue>(double(child_value)));
+ func_spec.emplace(dim_name, func_double);
+ } else {
+ auto label = std::get<TensorSpec::Label>(label_or_child);
+ func_spec.emplace(dim_name, label);
+ }
+ }
+ const auto &func_param = tensor_function::inject(param->type(), 0, stash);
+ const auto &peek_node = tensor_function::peek(func_param, func_spec, stash);
+ auto my_op = peek_node.compile_self(factory, stash);
+ InterpretedFunction::EvalSingle single(factory, my_op);
+ return spec_from_value(single.eval(my_stack));
+}
+
vespalib::string to_str(const PeekSpec &spec) {
vespalib::asciistream os;
os << "{ ";
@@ -149,6 +181,8 @@ void verify_peek_equal(const TensorSpec &input,
expect.to_string().c_str()));
auto actual = perform_generic_peek(input, result_type, spec, factory);
EXPECT_EQ(actual, expect);
+ auto from_func = tensor_function_peek(input, result_type, spec, factory);
+ EXPECT_EQ(from_func, expect);
}
void fill_dims_and_check(const TensorSpec &input,