summaryrefslogtreecommitdiffstats
path: root/vdslib
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2023-05-02 15:21:40 +0000
committerGeir Storli <geirst@yahooinc.com>2023-05-02 15:21:40 +0000
commitcde6875d5697545930e0e0c9bda6abf3b365cccc (patch)
tree11e8c6e8aa7422327838591b37f6ea140f8d3722 /vdslib
parentb83cc57a23c24cba060e884ade5f056cd46c5a82 (diff)
Deserialize match features in SearchResult used in streaming search.
Diffstat (limited to 'vdslib')
-rw-r--r--vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java63
-rw-r--r--vdslib/src/vespa/vdslib/container/searchresult.cpp25
2 files changed, 73 insertions, 15 deletions
diff --git a/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java b/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java
index e263c292bc8..b7c9b1b71b5 100644
--- a/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java
+++ b/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java
@@ -1,28 +1,40 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vdslib;
+import com.yahoo.data.access.helpers.MatchFeatureData;
import com.yahoo.vespa.objects.BufferSerializer;
import com.yahoo.vespa.objects.Deserializer;
import java.io.UnsupportedEncodingException;
import java.nio.ByteOrder;
+import java.util.ArrayList;
import java.util.Map;
+import java.util.Optional;
import java.util.TreeMap;
public class SearchResult {
public static class Hit implements Comparable<Hit> {
private String docId;
- private double rank;
+ private double rank;
+ private MatchFeatureData.HitValue matchFeatures;
public Hit(Hit h) {
docId = h.docId;
rank = h.rank;
+ matchFeatures = h.matchFeatures;
}
public Hit(String docId, double rank) {
this.rank = rank;
this.docId = docId;
+ this.matchFeatures = null;
+ }
+ final public String getDocId() { return docId; }
+ final public double getRank() { return rank; }
+ final public Optional<MatchFeatureData.HitValue> getMatchFeatures() {
+ return Optional.ofNullable(matchFeatures);
}
- final public String getDocId() { return docId; }
- final public double getRank() { return rank; }
final public void setRank(double rank) { this.rank = rank; }
+ final public void setMatchFeatures(MatchFeatureData.HitValue matchFeatures) {
+ this.matchFeatures = matchFeatures;
+ }
public int compareTo(Hit h) {
return (h.rank < rank) ? -1 : (h.rank > rank) ? 1 : 0; // Sort order: descending
}
@@ -49,12 +61,19 @@ public class SearchResult {
private Hit[] hits;
private TreeMap<Integer, byte []> aggregatorList;
private TreeMap<Integer, byte []> groupingList;
+ private static int EXTENSION_FLAGS_PRESENT = -1;
+ private static int MATCH_FEATURES_PRESENT_MASK = 1;
public SearchResult(Deserializer buf) {
BufferSerializer bser = (BufferSerializer) buf; // TODO: dirty cast. must do this differently
bser.order(ByteOrder.BIG_ENDIAN);
this.totalHits = buf.getInt(null);
int numHits = buf.getInt(null);
+ int extensionFlags = 0;
+ if (hasExtensionFlags(numHits)) {
+ extensionFlags = buf.getInt(null);
+ numHits = buf.getInt(null);
+ }
hits = new Hit[numHits];
if (numHits != 0) {
int docIdBufferLength = buf.getInt(null);
@@ -101,7 +120,45 @@ public class SearchResult {
groupingList.put(aggrId, buf.getBytes(null, aggrLength));
}
+ if (hasMatchFeatures(extensionFlags)) {
+ deserializeMatchFeatures(buf, numHits);
+ }
}
+
+ private void deserializeMatchFeatures(Deserializer buf, int numHits) {
+ var featureNames = new ArrayList<String>();
+ int numFeatures = buf.getInt(null);
+ for (int i = 0; i < numFeatures; ++i) {
+ featureNames.add(buf.getString(null));
+ }
+ var factory = new MatchFeatureData(featureNames);
+ for (int i = 0; i < numHits; ++i) {
+ var matchFeatures = factory.addHit();
+ for (int j = 0; j < numFeatures; ++j) {
+ byte featureType = buf.getByte(null);
+ if (isDoubleFeature(featureType)) {
+ matchFeatures.set(j, buf.getDouble(null));
+ } else {
+ int bufLength = buf.getInt(null);
+ matchFeatures.set(j, buf.getBytes(null, bufLength));
+ }
+ }
+ hits[i].setMatchFeatures(matchFeatures);
+ }
+ }
+
+ private static boolean hasExtensionFlags(int numHits) {
+ return numHits == EXTENSION_FLAGS_PRESENT;
+ }
+
+ private static boolean hasMatchFeatures(int extensionFlags) {
+ return (extensionFlags & MATCH_FEATURES_PRESENT_MASK) != 0;
+ }
+
+ private static boolean isDoubleFeature(byte featureType) {
+ return featureType == 0;
+ }
+
/**
* Constructs a new message from a byte buffer.
*
diff --git a/vdslib/src/vespa/vdslib/container/searchresult.cpp b/vdslib/src/vespa/vdslib/container/searchresult.cpp
index c8bc331d1a8..e348c9d9e13 100644
--- a/vdslib/src/vespa/vdslib/container/searchresult.cpp
+++ b/vdslib/src/vespa/vdslib/container/searchresult.cpp
@@ -11,21 +11,21 @@ namespace vdslib {
namespace {
// Magic value for hit count to enable extension flags
-constexpr uint32_t enable_extension_flags_magic = 0xffffffffu;
+constexpr uint32_t extension_flags_present = 0xffffffffu;
// Extension flag values
-constexpr uint32_t match_features_present = 1;
+constexpr uint32_t match_features_present_mask = 1;
// Selector values for feature value
constexpr uint8_t feature_value_is_double = 0;
constexpr uint8_t feature_value_is_data = 1;
inline bool has_match_features(uint32_t extension_flags) {
- return ((extension_flags & match_features_present) != 0);
+ return ((extension_flags & match_features_present_mask) != 0);
}
inline bool must_serialize_extension_flags(uint32_t extension_flags, uint32_t hit_count) {
- return ((extension_flags != 0) || (hit_count == enable_extension_flags_magic));
+ return ((extension_flags != 0) || (hit_count == extension_flags_present));
}
}
@@ -155,7 +155,7 @@ SearchResult::deserialize(document::ByteBuffer & buf)
uint32_t numResults(0), bufSize(0);
buf.getIntNetwork(tmp); numResults = tmp;
uint32_t extension_flags = 0u;
- if (numResults == enable_extension_flags_magic) {
+ if (numResults == extension_flags_present) {
buf.getIntNetwork(tmp);
extension_flags = tmp;
buf.getIntNetwork(tmp);
@@ -189,7 +189,7 @@ void SearchResult::serialize(vespalib::GrowableByteBuffer & buf) const
uint32_t hitCount = std::min(_hits.size(), _wantedHits);
uint32_t extension_flags = calc_extension_flags(hitCount);
if (must_serialize_extension_flags(extension_flags, hitCount)) {
- buf.putInt(enable_extension_flags_magic);
+ buf.putInt(extension_flags_present);
buf.putInt(extension_flags);
}
buf.putInt(hitCount);
@@ -241,7 +241,7 @@ SearchResult::calc_extension_flags(uint32_t hit_count) const noexcept
{
uint32_t extension_flags = 0u;
if (!_match_features.names.empty() && hit_count != 0) {
- extension_flags |= match_features_present;
+ extension_flags |= match_features_present_mask;
}
return extension_flags;
}
@@ -251,7 +251,7 @@ SearchResult::get_match_features_serialized_size(uint32_t hit_count) const noexc
{
uint32_t size = sizeof(uint32_t);
for (auto& name : _match_features.names) {
- size += sizeof(uint32_t) + name.size();
+ size += sizeof(uint32_t) + name.size() + 1;
}
for (uint32_t i = 0; i < hit_count; ++i) {
auto mfv = get_match_feature_values(i);
@@ -271,7 +271,7 @@ SearchResult::serialize_match_features(vespalib::GrowableByteBuffer& buf, uint32
{
buf.putInt(_match_features.names.size());
for (auto& name : _match_features.names) {
- buf.putString(name);
+ buf.put_c_string(name);
}
for (uint32_t i = 0; i < hit_count; ++i) {
auto mfv = get_match_feature_values(i);
@@ -301,10 +301,11 @@ SearchResult::deserialize_match_features(document::ByteBuffer& buf)
_match_features.names.resize(num_features);
for (auto& name : _match_features.names) {
buf.getIntNetwork(tmp);
- name.resize(tmp);
- if (tmp != 0) {
- buf.getBytes(&name[0], tmp);
+ if (tmp > 1) {
+ name.resize(tmp - 1);
+ buf.getBytes(&name[0], tmp - 1);
}
+ buf.getByte(selector); // Read and ignore the nul-termination.
}
uint32_t hit_count = _hits.size();
uint32_t num_values = num_features * hit_count;