diff options
author | Geir Storli <geirst@yahooinc.com> | 2023-05-02 15:21:40 +0000 |
---|---|---|
committer | Geir Storli <geirst@yahooinc.com> | 2023-05-02 15:21:40 +0000 |
commit | cde6875d5697545930e0e0c9bda6abf3b365cccc (patch) | |
tree | 11e8c6e8aa7422327838591b37f6ea140f8d3722 /vdslib | |
parent | b83cc57a23c24cba060e884ade5f056cd46c5a82 (diff) |
Deserialize match features in SearchResult used in streaming search.
Diffstat (limited to 'vdslib')
-rw-r--r-- | vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java | 63 | ||||
-rw-r--r-- | vdslib/src/vespa/vdslib/container/searchresult.cpp | 25 |
2 files changed, 73 insertions, 15 deletions
diff --git a/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java b/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java index e263c292bc8..b7c9b1b71b5 100644 --- a/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java +++ b/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java @@ -1,28 +1,40 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vdslib; +import com.yahoo.data.access.helpers.MatchFeatureData; import com.yahoo.vespa.objects.BufferSerializer; import com.yahoo.vespa.objects.Deserializer; import java.io.UnsupportedEncodingException; import java.nio.ByteOrder; +import java.util.ArrayList; import java.util.Map; +import java.util.Optional; import java.util.TreeMap; public class SearchResult { public static class Hit implements Comparable<Hit> { private String docId; - private double rank; + private double rank; + private MatchFeatureData.HitValue matchFeatures; public Hit(Hit h) { docId = h.docId; rank = h.rank; + matchFeatures = h.matchFeatures; } public Hit(String docId, double rank) { this.rank = rank; this.docId = docId; + this.matchFeatures = null; + } + final public String getDocId() { return docId; } + final public double getRank() { return rank; } + final public Optional<MatchFeatureData.HitValue> getMatchFeatures() { + return Optional.ofNullable(matchFeatures); } - final public String getDocId() { return docId; } - final public double getRank() { return rank; } final public void setRank(double rank) { this.rank = rank; } + final public void setMatchFeatures(MatchFeatureData.HitValue matchFeatures) { + this.matchFeatures = matchFeatures; + } public int compareTo(Hit h) { return (h.rank < rank) ? -1 : (h.rank > rank) ? 1 : 0; // Sort order: descending } @@ -49,12 +61,19 @@ public class SearchResult { private Hit[] hits; private TreeMap<Integer, byte []> aggregatorList; private TreeMap<Integer, byte []> groupingList; + private static int EXTENSION_FLAGS_PRESENT = -1; + private static int MATCH_FEATURES_PRESENT_MASK = 1; public SearchResult(Deserializer buf) { BufferSerializer bser = (BufferSerializer) buf; // TODO: dirty cast. must do this differently bser.order(ByteOrder.BIG_ENDIAN); this.totalHits = buf.getInt(null); int numHits = buf.getInt(null); + int extensionFlags = 0; + if (hasExtensionFlags(numHits)) { + extensionFlags = buf.getInt(null); + numHits = buf.getInt(null); + } hits = new Hit[numHits]; if (numHits != 0) { int docIdBufferLength = buf.getInt(null); @@ -101,7 +120,45 @@ public class SearchResult { groupingList.put(aggrId, buf.getBytes(null, aggrLength)); } + if (hasMatchFeatures(extensionFlags)) { + deserializeMatchFeatures(buf, numHits); + } } + + private void deserializeMatchFeatures(Deserializer buf, int numHits) { + var featureNames = new ArrayList<String>(); + int numFeatures = buf.getInt(null); + for (int i = 0; i < numFeatures; ++i) { + featureNames.add(buf.getString(null)); + } + var factory = new MatchFeatureData(featureNames); + for (int i = 0; i < numHits; ++i) { + var matchFeatures = factory.addHit(); + for (int j = 0; j < numFeatures; ++j) { + byte featureType = buf.getByte(null); + if (isDoubleFeature(featureType)) { + matchFeatures.set(j, buf.getDouble(null)); + } else { + int bufLength = buf.getInt(null); + matchFeatures.set(j, buf.getBytes(null, bufLength)); + } + } + hits[i].setMatchFeatures(matchFeatures); + } + } + + private static boolean hasExtensionFlags(int numHits) { + return numHits == EXTENSION_FLAGS_PRESENT; + } + + private static boolean hasMatchFeatures(int extensionFlags) { + return (extensionFlags & MATCH_FEATURES_PRESENT_MASK) != 0; + } + + private static boolean isDoubleFeature(byte featureType) { + return featureType == 0; + } + /** * Constructs a new message from a byte buffer. * diff --git a/vdslib/src/vespa/vdslib/container/searchresult.cpp b/vdslib/src/vespa/vdslib/container/searchresult.cpp index c8bc331d1a8..e348c9d9e13 100644 --- a/vdslib/src/vespa/vdslib/container/searchresult.cpp +++ b/vdslib/src/vespa/vdslib/container/searchresult.cpp @@ -11,21 +11,21 @@ namespace vdslib { namespace { // Magic value for hit count to enable extension flags -constexpr uint32_t enable_extension_flags_magic = 0xffffffffu; +constexpr uint32_t extension_flags_present = 0xffffffffu; // Extension flag values -constexpr uint32_t match_features_present = 1; +constexpr uint32_t match_features_present_mask = 1; // Selector values for feature value constexpr uint8_t feature_value_is_double = 0; constexpr uint8_t feature_value_is_data = 1; inline bool has_match_features(uint32_t extension_flags) { - return ((extension_flags & match_features_present) != 0); + return ((extension_flags & match_features_present_mask) != 0); } inline bool must_serialize_extension_flags(uint32_t extension_flags, uint32_t hit_count) { - return ((extension_flags != 0) || (hit_count == enable_extension_flags_magic)); + return ((extension_flags != 0) || (hit_count == extension_flags_present)); } } @@ -155,7 +155,7 @@ SearchResult::deserialize(document::ByteBuffer & buf) uint32_t numResults(0), bufSize(0); buf.getIntNetwork(tmp); numResults = tmp; uint32_t extension_flags = 0u; - if (numResults == enable_extension_flags_magic) { + if (numResults == extension_flags_present) { buf.getIntNetwork(tmp); extension_flags = tmp; buf.getIntNetwork(tmp); @@ -189,7 +189,7 @@ void SearchResult::serialize(vespalib::GrowableByteBuffer & buf) const uint32_t hitCount = std::min(_hits.size(), _wantedHits); uint32_t extension_flags = calc_extension_flags(hitCount); if (must_serialize_extension_flags(extension_flags, hitCount)) { - buf.putInt(enable_extension_flags_magic); + buf.putInt(extension_flags_present); buf.putInt(extension_flags); } buf.putInt(hitCount); @@ -241,7 +241,7 @@ SearchResult::calc_extension_flags(uint32_t hit_count) const noexcept { uint32_t extension_flags = 0u; if (!_match_features.names.empty() && hit_count != 0) { - extension_flags |= match_features_present; + extension_flags |= match_features_present_mask; } return extension_flags; } @@ -251,7 +251,7 @@ SearchResult::get_match_features_serialized_size(uint32_t hit_count) const noexc { uint32_t size = sizeof(uint32_t); for (auto& name : _match_features.names) { - size += sizeof(uint32_t) + name.size(); + size += sizeof(uint32_t) + name.size() + 1; } for (uint32_t i = 0; i < hit_count; ++i) { auto mfv = get_match_feature_values(i); @@ -271,7 +271,7 @@ SearchResult::serialize_match_features(vespalib::GrowableByteBuffer& buf, uint32 { buf.putInt(_match_features.names.size()); for (auto& name : _match_features.names) { - buf.putString(name); + buf.put_c_string(name); } for (uint32_t i = 0; i < hit_count; ++i) { auto mfv = get_match_feature_values(i); @@ -301,10 +301,11 @@ SearchResult::deserialize_match_features(document::ByteBuffer& buf) _match_features.names.resize(num_features); for (auto& name : _match_features.names) { buf.getIntNetwork(tmp); - name.resize(tmp); - if (tmp != 0) { - buf.getBytes(&name[0], tmp); + if (tmp > 1) { + name.resize(tmp - 1); + buf.getBytes(&name[0], tmp - 1); } + buf.getByte(selector); // Read and ignore the nul-termination. } uint32_t hit_count = _hits.size(); uint32_t num_values = num_features * hit_count; |