summary | refs | log | tree | commit | diff | stats
path: root/vdslib
diff options
context:
space:
mode:
author: Tor Egge <Tor.Egge@online.no> 2023-04-28 16:44:48 +0200
committer: Tor Egge <Tor.Egge@online.no> 2023-04-28 16:44:48 +0200
commit: cf6ec15a40afd10f2ab6d7ad7e908c95ce713b80 (patch)
tree: a8cb243b6e9dfb1412c8a6e9221190d20d60462d /vdslib
parent: d8f8731f6e91337f241912f71cab12e5c3febf00 (diff)
Serialize match features in vdslib::SearchResult.
Diffstat (limited to 'vdslib')
-rw-r--r-- vdslib/src/tests/container/searchresulttest.cpp | 90
-rw-r--r-- vdslib/src/vespa/vdslib/container/searchresult.cpp | 136
-rw-r--r-- vdslib/src/vespa/vdslib/container/searchresult.h | 6
3 files changed, 214 insertions, 18 deletions
diff --git a/vdslib/src/tests/container/searchresulttest.cpp b/vdslib/src/tests/container/searchresulttest.cpp
index 2b27aeeb95c..f757c19a58b 100644
--- a/vdslib/src/tests/container/searchresulttest.cpp
+++ b/vdslib/src/tests/container/searchresulttest.cpp
@@ -1,24 +1,79 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/vdslib/container/searchresult.h>
+#include <vespa/document/util/bytebuffer.h>
#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/util/arrayref.h>
+#include <vespa/vespalib/util/growablebytebuffer.h>
+#include <variant>
using vespalib::FeatureValues;
using FeatureValue = vespalib::FeatureSet::Value;
+using ConvertedValue = std::variant<double, std::string>;
namespace vdslib {
namespace {
-std::vector<double> to_doubles(vespalib::ConstArrayRef<FeatureValue> v) {
- std::vector<double> result;
+std::vector<char> doc1_mf_data{'H', 'i'}; // raw bytes used by populate() as the data-typed match feature checked for "doc1"
+std::vector<char> doc2_mf_data{'T', 'h', 'e', 'r', 'e'}; // raw bytes used by populate() as the data-typed match feature checked for "doc2"
+
+
+std::vector<ConvertedValue> convert(vespalib::ConstArrayRef<FeatureValue> v) { // Test helper: turns each feature value into a ConvertedValue so whole rows can be compared with EXPECT_EQ.
+ std::vector<ConvertedValue> result;
for (auto& iv : v) {
- EXPECT_TRUE(iv.is_double());
- result.emplace_back(iv.as_double());
+ if (iv.is_data()) { // data-typed value: keep its payload (via make_stringref())
+ result.emplace_back(iv.as_data().make_stringref());
+ } else { // anything else is read as a double
+ result.emplace_back(iv.as_double());
+ }
}
return result;
}
+std::vector<char> serialize(const SearchResult& sr) { // Test helper: serializes sr and returns the bytes that were written.
+ auto serialized_size = sr.getSerializedSize(); // taken before serializing so it can be cross-checked below
+ vespalib::GrowableByteBuffer buf;
+ sr.serialize(buf);
+ EXPECT_EQ(serialized_size, buf.position()); // getSerializedSize() must match the number of bytes serialize() wrote
+ return { buf.getBuffer(), buf.getBuffer() + buf.position() }; // copy of the range [buffer start, current position)
+}
+
+void deserialize(SearchResult& sr, vespalib::ConstArrayRef<char> buf) // Test helper: deserializes the bytes in buf into sr.
+{
+ document::ByteBuffer dbuf(buf.data(), buf.size());
+ sr.deserialize(dbuf);
+ EXPECT_EQ(0, dbuf.getRemaining()); // deserialization is expected to consume the whole buffer
+}
+
+void populate(SearchResult& sr, FeatureValues& mf) // Test helper: adds two hits to sr and fills mf with two features per hit, then installs a copy of mf in sr.
+{
+ sr.addHit(7, "doc1", 5); // arguments are (lid, docId, rank)
+ sr.addHit(8, "doc2", 7);
+ mf.names.push_back("foo");
+ mf.names.push_back("bar");
+ mf.values.resize(4); // 2 features x 2 hits
+ mf.values[0].set_double(1.0); // values[0..1] are what check_match_features() expects for the first added hit
+ mf.values[1].set_data({doc1_mf_data.data(), doc1_mf_data.size()});
+ mf.values[2].set_double(12.0); // values[2..3] are what check_match_features() expects for the second added hit
+ mf.values[3].set_data({doc2_mf_data.data(), doc2_mf_data.size()});
+ sr.set_match_features(FeatureValues(mf)); // pass a copy so the caller can still compare against mf afterwards
+}
+
+void check_match_features(SearchResult& sr, const vespalib::string& label, bool sort_remap) // Verifies the per-hit match feature rows of sr; label is attached to any failure via SCOPED_TRACE.
+{
+ SCOPED_TRACE(label);
+ EXPECT_EQ((std::vector<ConvertedValue>{1.0, "Hi"}), convert(sr.get_match_feature_values(sort_remap ? 1 : 0))); // with sort_remap the two hit indexes are expected to be swapped
+ EXPECT_EQ((std::vector<ConvertedValue>{12.0, "There"}), convert(sr.get_match_feature_values(sort_remap ? 0 : 1)));
+}
+
+void check_match_features(std::vector<char> buf, const vespalib::string& label, bool sort_remap) // Overload: deserializes buf into a fresh SearchResult and runs the same checks on it.
+{
+ SearchResult sr;
+ deserialize(sr, buf);
+ check_match_features(sr, label, sort_remap);
+}
+
}
TEST(SearchResultTest, test_simple)
@@ -86,28 +141,27 @@ TEST(SearchResultTest, test_simple_sort_data)
TEST(SearchResultTest, test_match_features)
{
SearchResult sr;
- sr.addHit(7, "doc1", 5);
- sr.addHit(8, "doc2", 7);
FeatureValues mf;
- mf.names.push_back("foo");
- mf.names.push_back("bar");
- mf.values.resize(4);
- mf.values[0].set_double(1.0);
- mf.values[1].set_double(7.0);
- mf.values[2].set_double(12.0);
- mf.values[3].set_double(13.0);
- sr.set_match_features(FeatureValues(mf));
+ populate(sr, mf); // two hits, each with one double and one data-typed feature
EXPECT_EQ(mf.names, sr.get_match_features().names);
EXPECT_EQ(mf.values, sr.get_match_features().values);
- EXPECT_EQ((std::vector<double>{ 1.0, 7.0}), to_doubles(sr.get_match_feature_values(0)));
- EXPECT_EQ((std::vector<double>{ 12.0, 13.0}), to_doubles(sr.get_match_feature_values(1)));
+ check_match_features(sr, "unsorted", false); // before sort(): rows are retrieved in insertion order
sr.sort();
// Sorting does not change the stored match features
EXPECT_EQ(mf.names, sr.get_match_features().names);
EXPECT_EQ(mf.values, sr.get_match_features().values);
// Sorting affects retrieval of the stored matched features
- EXPECT_EQ((std::vector<double>{ 12.0, 13.0}), to_doubles(sr.get_match_feature_values(0)));
- EXPECT_EQ((std::vector<double>{ 1.0, 7.0}), to_doubles(sr.get_match_feature_values(1)));
+ check_match_features(sr, "sorted", true); // after sort(): the two rows are expected at swapped indexes
+}
+
+TEST(SearchResultTest, test_deserialized_match_features) // Round-trips match features through serialize()/deserialize(), both before and after sorting.
+{
+ SearchResult sr;
+ FeatureValues mf;
+ populate(sr, mf);
+ check_match_features(serialize(sr), "deserialized unsorted", false);
+ sr.sort();
+ check_match_features(serialize(sr), "deserialized sorted", true); // the deserialized copy is expected to show the rows at swapped indexes
}
}
diff --git a/vdslib/src/vespa/vdslib/container/searchresult.cpp b/vdslib/src/vespa/vdslib/container/searchresult.cpp
index c110eaa792d..c8bc331d1a8 100644
--- a/vdslib/src/vespa/vdslib/container/searchresult.cpp
+++ b/vdslib/src/vespa/vdslib/container/searchresult.cpp
@@ -8,6 +8,28 @@
namespace vdslib {
+namespace {
+
+// Magic value written in place of the hit count to signal that an extension flags word, and then the real hit count, follow
+constexpr uint32_t enable_extension_flags_magic = 0xffffffffu;
+
+// Extension flag values (tested as bit masks against the extension flags word)
+constexpr uint32_t match_features_present = 1;
+
+// Selector byte written before each serialized feature value to tell which kind follows
+constexpr uint8_t feature_value_is_double = 0;
+constexpr uint8_t feature_value_is_data = 1;
+
+inline bool has_match_features(uint32_t extension_flags) { // true when the match_features_present bit is set
+ return ((extension_flags & match_features_present) != 0);
+}
+
+inline bool must_serialize_extension_flags(uint32_t extension_flags, uint32_t hit_count) { // true when any flag is set, or when the hit count itself equals the magic value and so could not be told apart from the marker
+ return ((extension_flags != 0) || (hit_count == enable_extension_flags_magic));
+}
+
+}
+
void AggregatorList::add(size_t id, const vespalib::MallocPtr & aggrBlob)
{
insert(value_type(id, aggrBlob));
@@ -132,6 +154,13 @@ SearchResult::deserialize(document::ByteBuffer & buf)
buf.getIntNetwork(tmp); _totalHits = tmp;
uint32_t numResults(0), bufSize(0);
buf.getIntNetwork(tmp); numResults = tmp;
+ uint32_t extension_flags = 0u;
+ if (numResults == enable_extension_flags_magic) {
+ buf.getIntNetwork(tmp);
+ extension_flags = tmp;
+ buf.getIntNetwork(tmp);
+ numResults = tmp;
+ }
if (numResults > 0) {
buf.getIntNetwork(tmp); bufSize = tmp;
_docIdBuffer.reset(new vespalib::MallocPtr(bufSize));
@@ -149,12 +178,20 @@ SearchResult::deserialize(document::ByteBuffer & buf)
_sortBlob.deserialize(buf);
_aggregatorList.deserialize(buf);
_groupingList.deserialize(buf);
+ if (has_match_features(extension_flags)) {
+ deserialize_match_features(buf);
+ }
}
void SearchResult::serialize(vespalib::GrowableByteBuffer & buf) const
{
buf.putInt(_totalHits);
uint32_t hitCount = std::min(_hits.size(), _wantedHits);
+ uint32_t extension_flags = calc_extension_flags(hitCount);
+ if (must_serialize_extension_flags(extension_flags, hitCount)) {
+ buf.putInt(enable_extension_flags_magic);
+ buf.putInt(extension_flags);
+ }
buf.putInt(hitCount);
if (hitCount > 0) {
uint32_t sz = getBufCount();
@@ -180,17 +217,116 @@ void SearchResult::serialize(vespalib::GrowableByteBuffer & buf) const
}
_aggregatorList.serialize(buf);
_groupingList.serialize(buf);
+ if (has_match_features(extension_flags)) {
+ serialize_match_features(buf, hitCount);
+ }
}
uint32_t SearchResult::getSerializedSize() const
{
uint32_t hitCount = std::min(_hits.size(), _wantedHits);
+ uint32_t extension_flags = calc_extension_flags(hitCount); // same flags computation as serialize() uses
+ uint32_t extension_flags_overhead = must_serialize_extension_flags(extension_flags, hitCount) ? (2 * sizeof(uint32_t)) : 0; // the magic marker plus the flags word
+ uint32_t match_features_size = has_match_features(extension_flags) ? get_match_features_serialized_size(hitCount) : 0; // trailing match features section, if flagged
return _aggregatorList.getSerializedSize() +
_groupingList.getSerializedSize() +
_sortBlob.getSerializedSize() +
+ extension_flags_overhead +
+ match_features_size +
((hitCount > 0) ? ((4 * 3) + getBufCount() + sizeof(RankType)*hitCount) : 8);
}
+uint32_t
+SearchResult::calc_extension_flags(uint32_t hit_count) const noexcept // Computes the extension flags word to serialize for the given number of hits.
+{
+ uint32_t extension_flags = 0u;
+ if (!_match_features.names.empty() && hit_count != 0) { // match features are flagged only when there are feature names and at least one hit
+ extension_flags |= match_features_present;
+ }
+ return extension_flags;
+}
+
+uint32_t
+SearchResult::get_match_features_serialized_size(uint32_t hit_count) const noexcept // Number of bytes counted for the match features section for the first hit_count hits; intended to mirror serialize_match_features().
+{
+ uint32_t size = sizeof(uint32_t); // the feature-name count
+ for (auto& name : _match_features.names) {
+ size += sizeof(uint32_t) + name.size(); // each name: 32-bit length followed by its bytes
+ }
+ for (uint32_t i = 0; i < hit_count; ++i) {
+ auto mfv = get_match_feature_values(i); // feature values for hit i
+ for (auto& value : mfv) {
+ if (value.is_data()) {
+ size += sizeof(uint8_t) + sizeof(uint32_t) + value.as_data().size; // selector byte + 32-bit length + payload
+ } else {
+ size += sizeof(uint8_t) + sizeof(double); // selector byte + double
+ }
+ }
+ }
+ return size;
+}
+
+void
+SearchResult::serialize_match_features(vespalib::GrowableByteBuffer& buf, uint32_t hit_count) const // Appends the match features section to buf: name count, names, then the feature values of each of the first hit_count hits.
+{
+ buf.putInt(_match_features.names.size()); // number of feature names
+ for (auto& name : _match_features.names) {
+ buf.putString(name);
+ }
+ for (uint32_t i = 0; i < hit_count; ++i) {
+ auto mfv = get_match_feature_values(i); // feature values for hit i
+ for (auto& value : mfv) {
+ if (value.is_data()) {
+ buf.putByte(feature_value_is_data); // selector byte first, so the reader knows which kind follows
+ auto mem = value.as_data();
+ buf.putInt(mem.size); // payload length, then the payload bytes
+ buf.putBytes(mem.data, mem.size);
+ } else {
+ buf.putByte(feature_value_is_double);
+ buf.putDouble(value.as_double());
+ }
+ }
+ }
+}
+
+void
+SearchResult::deserialize_match_features(document::ByteBuffer& buf) // Reads the match features section from buf into _match_features (names first, then one value per feature per hit).
+{
+ int32_t tmp(0);
+ double dtmp(0.0);
+ uint8_t selector(0);
+ std::vector<char> scratch; // reused as a staging buffer for every data-typed value
+ buf.getIntNetwork(tmp);
+ uint32_t num_features = tmp; // NOTE(review): tmp is a signed int32 read from the buffer and is not validated before being used as a count
+ _match_features.names.resize(num_features);
+ for (auto& name : _match_features.names) {
+ buf.getIntNetwork(tmp); // name length
+ name.resize(tmp); // NOTE(review): a negative tmp would convert to a very large size here - consider validating
+ if (tmp != 0) {
+ buf.getBytes(&name[0], tmp);
+ }
+ }
+ uint32_t hit_count = _hits.size(); // presumably _hits was already filled in by deserialize() before this is called - confirm
+ uint32_t num_values = num_features * hit_count; // NOTE(review): 32-bit multiplication, can wrap for large inputs
+ _match_features.values.resize(num_values);
+ for (auto& value : _match_features.values) {
+ buf.getByte(selector); // selector byte tells which kind of value follows
+ if (selector == feature_value_is_data) {
+ buf.getIntNetwork(tmp); // payload length
+ scratch.resize(tmp); // NOTE(review): same unvalidated signed-to-size conversion as for the names
+ if (!scratch.empty()) {
+ buf.getBytes(scratch.data(), scratch.size());
+ }
+ value.set_data({ scratch.data(), scratch.size() }); // presumably set_data copies the bytes, since scratch is reused on the next iteration - confirm
+ } else if (selector == feature_value_is_double) {
+ buf.getDoubleNetwork(dtmp);
+ value.set_double(dtmp);
+ } else {
+ abort(); // NOTE(review): an unknown selector byte terminates the process
+ }
+ }
+}
+
void SearchResult::addHit(uint32_t lid, const char * docId, RankType rank, const void * sortData, size_t sz)
{
addHit(lid, docId, rank);
diff --git a/vdslib/src/vespa/vdslib/container/searchresult.h b/vdslib/src/vespa/vdslib/container/searchresult.h
index bc25ad76631..8bb8df82627 100644
--- a/vdslib/src/vespa/vdslib/container/searchresult.h
+++ b/vdslib/src/vespa/vdslib/container/searchresult.h
@@ -132,6 +132,12 @@ private:
}
};
size_t getBufCount() const { return _numDocIdBytes; }
+
+ uint32_t calc_extension_flags(uint32_t hit_count) const noexcept;
+ uint32_t get_match_features_serialized_size(uint32_t hit_count) const noexcept;
+ void serialize_match_features(vespalib::GrowableByteBuffer& buf, uint32_t hit_count) const;
+ void deserialize_match_features(document::ByteBuffer& buf);
+
using DocIdBuffer = std::shared_ptr<vespalib::MallocPtr>;
uint32_t _totalHits;
size_t _wantedHits;