diff options
Diffstat (limited to 'streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp')
-rw-r--r-- | streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp b/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp index 059d3c9f597..362aaeda938 100644 --- a/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp +++ b/streamingvisitors/src/tests/hitcollector/hitcollector_test.cpp @@ -11,6 +11,7 @@ #include <vespa/eval/eval/value_codec.h> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/objects/nbostream.h> +#include <vespa/vespalib/util/stringfmt.h> using namespace document; using namespace search::fef; @@ -22,7 +23,39 @@ using vespalib::eval::DoubleValue; using vespalib::eval::SimpleValue; using vespalib::eval::TensorSpec; using vespalib::eval::Value; +using vespalib::make_string_short::fmt; +using FeatureValue = FeatureSet::Value; + +namespace { + +double as_double(const FeatureValue& v) { + EXPECT_TRUE(v.is_double()); + return v.as_double(); +} + +TensorSpec as_spec(const FeatureValue& v) { + EXPECT_TRUE(v.is_data()); + auto mem = v.as_data(); + nbostream buf(mem.data, mem.size); + return spec_from_value(*SimpleValue::from_stream(buf)); +} + +ConstArrayRef<FeatureValue> as_value_slice(FeatureValues& mf, uint32_t index, uint32_t num_features) +{ + return { mf.values.data() + index * num_features, num_features }; +} + +void check_match_features(ConstArrayRef<FeatureValue> v, uint32_t docid) +{ + SCOPED_TRACE(fmt("Checking docid %u for expected match features", docid)); + // The following values should have been set by MyRankProgram::run() + EXPECT_EQ(10 + docid, as_double(v[0])); + EXPECT_EQ(30 + docid, as_double(v[1])); + EXPECT_EQ(TensorSpec("tensor(x{})").add({{"x", "a"}}, 20 + docid), as_spec(v[2])); +} + +} namespace streaming { @@ -326,6 +359,33 @@ TEST_F(HitCollectorTest, feature_set) assertHit(30, 4, 2, sr); } +TEST_F(HitCollectorTest, match_features) +{ + HitCollector hc(3); + + addHit(hc, 0, 10); + addHit(hc, 1, 50); // on heap + addHit(hc, 2, 20); + addHit(hc, 3, 40); // on heap + addHit(hc, 4, 30); // on heap + + MyRankProgram rankProgram; + FeatureResolver resolver(rankProgram.get_resolver()); + search::StringStringMap renames; + renames["bar"] = "qux"; + auto mf = hc.get_match_features(rankProgram, resolver, renames); + auto num_features = resolver.num_features(); + + EXPECT_EQ(num_features, mf.names.size()); + EXPECT_EQ("foo", mf.names[0]); + EXPECT_EQ("qux", mf.names[1]); + EXPECT_EQ("baz", mf.names[2]); + EXPECT_EQ(num_features * 3, mf.values.size()); + check_match_features(as_value_slice(mf, 0, num_features), 1); + check_match_features(as_value_slice(mf, 1, num_features), 3); + check_match_features(as_value_slice(mf, 2, num_features), 4); +} + } // namespace streaming GTEST_MAIN_RUN_ALL_TESTS() |