aboutsummaryrefslogtreecommitdiffstats
path: root/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp')
-rw-r--r--streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp60
1 files changed, 60 insertions, 0 deletions
diff --git a/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp b/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp
index 059d3c9f597..362aaeda938 100644
--- a/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp
+++ b/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp
@@ -11,6 +11,7 @@
#include <vespa/eval/eval/value_codec.h>
#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/vespalib/objects/nbostream.h>
+#include <vespa/vespalib/util/stringfmt.h>
using namespace document;
using namespace search::fef;
@@ -22,7 +23,39 @@ using vespalib::eval::DoubleValue;
using vespalib::eval::SimpleValue;
using vespalib::eval::TensorSpec;
using vespalib::eval::Value;
+using vespalib::make_string_short::fmt;
+using FeatureValue = FeatureSet::Value;
+
+namespace {
+
+double as_double(const FeatureValue& v) {
+ EXPECT_TRUE(v.is_double());
+ return v.as_double();
+}
+
+TensorSpec as_spec(const FeatureValue& v) {
+ EXPECT_TRUE(v.is_data());
+ auto mem = v.as_data();
+ nbostream buf(mem.data, mem.size);
+ return spec_from_value(*SimpleValue::from_stream(buf));
+}
+
+ConstArrayRef<FeatureValue> as_value_slice(FeatureValues& mf, uint32_t index, uint32_t num_features)
+{
+ return { mf.values.data() + index * num_features, num_features };
+}
+
+void check_match_features(ConstArrayRef<FeatureValue> v, uint32_t docid)
+{
+ SCOPED_TRACE(fmt("Checking docid %u for expected match features", docid));
+ // The following values should have been set by MyRankProgram::run()
+ EXPECT_EQ(10 + docid, as_double(v[0]));
+ EXPECT_EQ(30 + docid, as_double(v[1]));
+ EXPECT_EQ(TensorSpec("tensor(x{})").add({{"x", "a"}}, 20 + docid), as_spec(v[2]));
+}
+
+}
namespace streaming {
@@ -326,6 +359,33 @@ TEST_F(HitCollectorTest, feature_set)
assertHit(30, 4, 2, sr);
}
+TEST_F(HitCollectorTest, match_features)
+{
+ HitCollector hc(3);
+
+ addHit(hc, 0, 10);
+ addHit(hc, 1, 50); // on heap
+ addHit(hc, 2, 20);
+ addHit(hc, 3, 40); // on heap
+ addHit(hc, 4, 30); // on heap
+
+ MyRankProgram rankProgram;
+ FeatureResolver resolver(rankProgram.get_resolver());
+ search::StringStringMap renames;
+ renames["bar"] = "qux";
+ auto mf = hc.get_match_features(rankProgram, resolver, renames);
+ auto num_features = resolver.num_features();
+
+ EXPECT_EQ(num_features, mf.names.size());
+ EXPECT_EQ("foo", mf.names[0]);
+ EXPECT_EQ("qux", mf.names[1]);
+ EXPECT_EQ("baz", mf.names[2]);
+ EXPECT_EQ(num_features * 3, mf.values.size());
+ check_match_features(as_value_slice(mf, 0, num_features), 1);
+ check_match_features(as_value_slice(mf, 1, num_features), 3);
+ check_match_features(as_value_slice(mf, 2, num_features), 4);
+}
+
} // namespace streaming
GTEST_MAIN_RUN_ALL_TESTS()