From cf6ec15a40afd10f2ab6d7ad7e908c95ce713b80 Mon Sep 17 00:00:00 2001 From: Tor Egge Date: Fri, 28 Apr 2023 16:44:48 +0200 Subject: Serialize match features in vdslib::SearchResult. --- documentapi/src/tests/messages/messages60test.cpp | 48 ++++++++ .../6.221-cpp-QueryResultMessage-6.dat | Bin 0 -> 128 bytes vdslib/src/tests/container/searchresulttest.cpp | 90 +++++++++++--- vdslib/src/vespa/vdslib/container/searchresult.cpp | 136 +++++++++++++++++++++ vdslib/src/vespa/vdslib/container/searchresult.h | 6 + 5 files changed, 262 insertions(+), 18 deletions(-) create mode 100644 documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat diff --git a/documentapi/src/tests/messages/messages60test.cpp b/documentapi/src/tests/messages/messages60test.cpp index 12cecefb072..89ab20373e7 100644 --- a/documentapi/src/tests/messages/messages60test.cpp +++ b/documentapi/src/tests/messages/messages60test.cpp @@ -11,9 +11,19 @@ #include #include #include +#include +#include using document::DataType; using document::DocumentTypeRepo; +using vespalib::FeatureValues; + +namespace { + +std::vector doc1_mf_data{'H', 'i'}; +std::vector doc2_mf_data{'T', 'h', 'e', 'r', 'e'}; + +} template struct Unwrap { @@ -661,6 +671,44 @@ Messages60Test::testQueryResultMessage() EXPECT_EQUAL(memcmp("sortdata3", buf, sz), 0); EXPECT_EQUAL(rank, vdslib::SearchResult::RankType(90)); EXPECT_EQUAL(strcmp("doc18", docId), 0); + + QueryResultMessage qrm3; + auto& sr3(qrm3.getSearchResult()); + sr3.addHit(0, "doc1", 5); + sr3.addHit(1, "doc2", 7); + FeatureValues mf; + mf.names.push_back("foo"); + mf.names.push_back("bar"); + mf.values.resize(4); + mf.values[0].set_double(1.0); + mf.values[1].set_data({ doc1_mf_data.data(), doc1_mf_data.size()}); + mf.values[2].set_double(12.0); + mf.values[3].set_data({ doc2_mf_data.data(), doc2_mf_data.size()}); + sr3.set_match_features(FeatureValues(mf)); + sr3.sort(); + + EXPECT_EQUAL(MESSAGE_BASE_LENGTH + 123u, serialize("QueryResultMessage-6", qrm3)); + routable = deserialize("QueryResultMessage-6", DocumentProtocol::MESSAGE_QUERYRESULT, LANG_CPP); + if (!EXPECT_TRUE(routable)) { + return false; + } + dm = static_cast(routable.get()); + dr = &dm->getSearchResult(); + EXPECT_EQUAL(size_t(2), dr->getHitCount()); + dr->getHit(0, docId, rank); + EXPECT_EQUAL(vdslib::SearchResult::RankType(7), rank); + EXPECT_EQUAL(strcmp("doc2", docId), 0); + dr->getHit(1, docId, rank); + EXPECT_EQUAL(vdslib::SearchResult::RankType(5), rank); + EXPECT_EQUAL(strcmp("doc1", docId), 0); + auto mfv = dr->get_match_feature_values(0); + EXPECT_EQUAL(2u, mfv.size()); + EXPECT_EQUAL(12.0, mfv[0].as_double()); + EXPECT_EQUAL("There", mfv[1].as_data().make_string()); + mfv = dr->get_match_feature_values(1); + EXPECT_EQUAL(2u, mfv.size()); + EXPECT_EQUAL(1.0, mfv[0].as_double()); + EXPECT_EQUAL("Hi", mfv[1].as_data().make_string()); return true; } diff --git a/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat b/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat new file mode 100644 index 00000000000..229441aa9ba Binary files /dev/null and b/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat differ diff --git a/vdslib/src/tests/container/searchresulttest.cpp b/vdslib/src/tests/container/searchresulttest.cpp index 2b27aeeb95c..f757c19a58b 100644 --- a/vdslib/src/tests/container/searchresulttest.cpp +++ b/vdslib/src/tests/container/searchresulttest.cpp @@ -1,24 +1,79 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include +#include #include +#include +#include +#include using vespalib::FeatureValues; using FeatureValue = vespalib::FeatureSet::Value; +using ConvertedValue = std::variant; namespace vdslib { namespace { -std::vector to_doubles(vespalib::ConstArrayRef v) { - std::vector result; +std::vector doc1_mf_data{'H', 'i'}; +std::vector doc2_mf_data{'T', 'h', 'e', 'r', 'e'}; + + +std::vector convert(vespalib::ConstArrayRef v) { + std::vector result; for (auto& iv : v) { - EXPECT_TRUE(iv.is_double()); - result.emplace_back(iv.as_double()); + if (iv.is_data()) { + result.emplace_back(iv.as_data().make_stringref()); + } else { + result.emplace_back(iv.as_double()); + } } return result; } +std::vector serialize(const SearchResult& sr) { + auto serialized_size = sr.getSerializedSize(); + vespalib::GrowableByteBuffer buf; + sr.serialize(buf); + EXPECT_EQ(serialized_size, buf.position()); + return { buf.getBuffer(), buf.getBuffer() + buf.position() }; +} + +void deserialize(SearchResult& sr, vespalib::ConstArrayRef buf) +{ + document::ByteBuffer dbuf(buf.data(), buf.size()); + sr.deserialize(dbuf); + EXPECT_EQ(0, dbuf.getRemaining()); +} + +void populate(SearchResult& sr, FeatureValues& mf) +{ + sr.addHit(7, "doc1", 5); + sr.addHit(8, "doc2", 7); + mf.names.push_back("foo"); + mf.names.push_back("bar"); + mf.values.resize(4); + mf.values[0].set_double(1.0); + mf.values[1].set_data({doc1_mf_data.data(), doc1_mf_data.size()}); + mf.values[2].set_double(12.0); + mf.values[3].set_data({doc2_mf_data.data(), doc2_mf_data.size()}); + sr.set_match_features(FeatureValues(mf)); +} + +void check_match_features(SearchResult& sr, const vespalib::string& label, bool sort_remap) +{ + SCOPED_TRACE(label); + EXPECT_EQ((std::vector{1.0, "Hi"}), convert(sr.get_match_feature_values(sort_remap ? 1 : 0))); + EXPECT_EQ((std::vector{12.0, "There"}), convert(sr.get_match_feature_values(sort_remap ? 0 : 1))); +} + +void check_match_features(std::vector buf, const vespalib::string& label, bool sort_remap) +{ + SearchResult sr; + deserialize(sr, buf); + check_match_features(sr, label, sort_remap); +} + } TEST(SearchResultTest, test_simple) @@ -86,28 +141,27 @@ TEST(SearchResultTest, test_simple_sort_data) TEST(SearchResultTest, test_match_features) { SearchResult sr; - sr.addHit(7, "doc1", 5); - sr.addHit(8, "doc2", 7); FeatureValues mf; - mf.names.push_back("foo"); - mf.names.push_back("bar"); - mf.values.resize(4); - mf.values[0].set_double(1.0); - mf.values[1].set_double(7.0); - mf.values[2].set_double(12.0); - mf.values[3].set_double(13.0); - sr.set_match_features(FeatureValues(mf)); + populate(sr, mf); EXPECT_EQ(mf.names, sr.get_match_features().names); EXPECT_EQ(mf.values, sr.get_match_features().values); - EXPECT_EQ((std::vector{ 1.0, 7.0}), to_doubles(sr.get_match_feature_values(0))); - EXPECT_EQ((std::vector{ 12.0, 13.0}), to_doubles(sr.get_match_feature_values(1))); + check_match_features(sr, "unsorted", false); sr.sort(); // Sorting does not change the stored match features EXPECT_EQ(mf.names, sr.get_match_features().names); EXPECT_EQ(mf.values, sr.get_match_features().values); // Sorting affects retrieval of the stored matched features - EXPECT_EQ((std::vector{ 12.0, 13.0}), to_doubles(sr.get_match_feature_values(0))); - EXPECT_EQ((std::vector{ 1.0, 7.0}), to_doubles(sr.get_match_feature_values(1))); + check_match_features(sr, "sorted", true); +} + +TEST(SearchResultTest, test_deserialized_match_features) +{ + SearchResult sr; + FeatureValues mf; + populate(sr, mf); + check_match_features(serialize(sr), "deserialized unsorted", false); + sr.sort(); + check_match_features(serialize(sr), "deserialized sorted", true); } } diff --git a/vdslib/src/vespa/vdslib/container/searchresult.cpp b/vdslib/src/vespa/vdslib/container/searchresult.cpp index c110eaa792d..c8bc331d1a8 100644 --- a/vdslib/src/vespa/vdslib/container/searchresult.cpp +++ b/vdslib/src/vespa/vdslib/container/searchresult.cpp @@ -8,6 +8,28 @@ namespace vdslib { +namespace { + +// Magic value for hit count to enable extension flags +constexpr uint32_t enable_extension_flags_magic = 0xffffffffu; + +// Extension flag values +constexpr uint32_t match_features_present = 1; + +// Selector values for feature value +constexpr uint8_t feature_value_is_double = 0; +constexpr uint8_t feature_value_is_data = 1; + +inline bool has_match_features(uint32_t extension_flags) { + return ((extension_flags & match_features_present) != 0); +} + +inline bool must_serialize_extension_flags(uint32_t extension_flags, uint32_t hit_count) { + return ((extension_flags != 0) || (hit_count == enable_extension_flags_magic)); +} + +} + void AggregatorList::add(size_t id, const vespalib::MallocPtr & aggrBlob) { insert(value_type(id, aggrBlob)); @@ -132,6 +154,13 @@ SearchResult::deserialize(document::ByteBuffer & buf) buf.getIntNetwork(tmp); _totalHits = tmp; uint32_t numResults(0), bufSize(0); buf.getIntNetwork(tmp); numResults = tmp; + uint32_t extension_flags = 0u; + if (numResults == enable_extension_flags_magic) { + buf.getIntNetwork(tmp); + extension_flags = tmp; + buf.getIntNetwork(tmp); + numResults = tmp; + } if (numResults > 0) { buf.getIntNetwork(tmp); bufSize = tmp; _docIdBuffer.reset(new vespalib::MallocPtr(bufSize)); @@ -149,12 +178,20 @@ SearchResult::deserialize(document::ByteBuffer & buf) _sortBlob.deserialize(buf); _aggregatorList.deserialize(buf); _groupingList.deserialize(buf); + if (has_match_features(extension_flags)) { + deserialize_match_features(buf); + } } void SearchResult::serialize(vespalib::GrowableByteBuffer & buf) const { buf.putInt(_totalHits); uint32_t hitCount = std::min(_hits.size(), _wantedHits); + uint32_t extension_flags = calc_extension_flags(hitCount); + if (must_serialize_extension_flags(extension_flags, hitCount)) { + buf.putInt(enable_extension_flags_magic); + buf.putInt(extension_flags); + } buf.putInt(hitCount); if (hitCount > 0) { uint32_t sz = getBufCount(); @@ -180,17 +217,116 @@ void SearchResult::serialize(vespalib::GrowableByteBuffer & buf) const } _aggregatorList.serialize(buf); _groupingList.serialize(buf); + if (has_match_features(extension_flags)) { + serialize_match_features(buf, hitCount); + } } uint32_t SearchResult::getSerializedSize() const { uint32_t hitCount = std::min(_hits.size(), _wantedHits); + uint32_t extension_flags = calc_extension_flags(hitCount); + uint32_t extension_flags_overhead = must_serialize_extension_flags(extension_flags, hitCount) ? (2 * sizeof(uint32_t)) : 0; + uint32_t match_features_size = has_match_features(extension_flags) ? get_match_features_serialized_size(hitCount) : 0; return _aggregatorList.getSerializedSize() + _groupingList.getSerializedSize() + _sortBlob.getSerializedSize() + + extension_flags_overhead + + match_features_size + ((hitCount > 0) ? ((4 * 3) + getBufCount() + sizeof(RankType)*hitCount) : 8); } +uint32_t +SearchResult::calc_extension_flags(uint32_t hit_count) const noexcept +{ + uint32_t extension_flags = 0u; + if (!_match_features.names.empty() && hit_count != 0) { + extension_flags |= match_features_present; + } + return extension_flags; +} + +uint32_t +SearchResult::get_match_features_serialized_size(uint32_t hit_count) const noexcept +{ + uint32_t size = sizeof(uint32_t); + for (auto& name : _match_features.names) { + size += sizeof(uint32_t) + name.size(); + } + for (uint32_t i = 0; i < hit_count; ++i) { + auto mfv = get_match_feature_values(i); + for (auto& value : mfv) { + if (value.is_data()) { + size += sizeof(uint8_t) + sizeof(uint32_t) + value.as_data().size; + } else { + size += sizeof(uint8_t) + sizeof(double); + } + } + } + return size; +} + +void +SearchResult::serialize_match_features(vespalib::GrowableByteBuffer& buf, uint32_t hit_count) const +{ + buf.putInt(_match_features.names.size()); + for (auto& name : _match_features.names) { + buf.putString(name); + } + for (uint32_t i = 0; i < hit_count; ++i) { + auto mfv = get_match_feature_values(i); + for (auto& value : mfv) { + if (value.is_data()) { + buf.putByte(feature_value_is_data); + auto mem = value.as_data(); + buf.putInt(mem.size); + buf.putBytes(mem.data, mem.size); + } else { + buf.putByte(feature_value_is_double); + buf.putDouble(value.as_double()); + } + } + } +} + +void +SearchResult::deserialize_match_features(document::ByteBuffer& buf) +{ + int32_t tmp(0); + double dtmp(0.0); + uint8_t selector(0); + std::vector scratch; + buf.getIntNetwork(tmp); + uint32_t num_features = tmp; + _match_features.names.resize(num_features); + for (auto& name : _match_features.names) { + buf.getIntNetwork(tmp); + name.resize(tmp); + if (tmp != 0) { + buf.getBytes(&name[0], tmp); + } + } + uint32_t hit_count = _hits.size(); + uint32_t num_values = num_features * hit_count; + _match_features.values.resize(num_values); + for (auto& value : _match_features.values) { + buf.getByte(selector); + if (selector == feature_value_is_data) { + buf.getIntNetwork(tmp); + scratch.resize(tmp); + if (!scratch.empty()) { + buf.getBytes(scratch.data(), scratch.size()); + } + value.set_data({ scratch.data(), scratch.size() }); + } else if (selector == feature_value_is_double) { + buf.getDoubleNetwork(dtmp); + value.set_double(dtmp); + } else { + abort(); + } + } +} + void SearchResult::addHit(uint32_t lid, const char * docId, RankType rank, const void * sortData, size_t sz) { addHit(lid, docId, rank); diff --git a/vdslib/src/vespa/vdslib/container/searchresult.h b/vdslib/src/vespa/vdslib/container/searchresult.h index bc25ad76631..8bb8df82627 100644 --- a/vdslib/src/vespa/vdslib/container/searchresult.h +++ b/vdslib/src/vespa/vdslib/container/searchresult.h @@ -132,6 +132,12 @@ private: } }; size_t getBufCount() const { return _numDocIdBytes; } + + uint32_t calc_extension_flags(uint32_t hit_count) const noexcept; + uint32_t get_match_features_serialized_size(uint32_t hit_count) const noexcept; + void serialize_match_features(vespalib::GrowableByteBuffer& buf, uint32_t hit_count) const; + void deserialize_match_features(document::ByteBuffer& buf); + using DocIdBuffer = std::shared_ptr; uint32_t _totalHits; size_t _wantedHits; -- cgit v1.2.3