aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/Messages60TestCase.java57
-rw-r--r--documentapi/src/tests/messages/messages60test.cpp6
-rw-r--r--documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.datbin128 -> 130 bytes
-rw-r--r--vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java63
-rw-r--r--vdslib/src/vespa/vdslib/container/searchresult.cpp25
-rw-r--r--vespalib/src/vespa/vespalib/util/growablebytebuffer.cpp8
-rw-r--r--vespalib/src/vespa/vespalib/util/growablebytebuffer.h8
7 files changed, 136 insertions, 31 deletions
diff --git a/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/Messages60TestCase.java b/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/Messages60TestCase.java
index 22650fcdbf8..940217aa2b4 100644
--- a/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/Messages60TestCase.java
+++ b/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/Messages60TestCase.java
@@ -633,6 +633,8 @@ public class Messages60TestCase extends MessagesTestBase {
@Override
public void run() throws Exception {
+ test_result_with_match_features();
+
Routable routable = deserialize("QueryResultMessage-1", DocumentProtocol.MESSAGE_QUERYRESULT, Language.CPP);
assertTrue(routable instanceof QueryResultMessage);
@@ -647,9 +649,11 @@ public class Messages60TestCase extends MessagesTestBase {
com.yahoo.vdslib.SearchResult.Hit h = msg.getResult().getHit(0);
assertEquals(89.0, h.getRank(), 1E-6);
assertEquals("doc1", h.getDocId());
+ assertFalse(h.getMatchFeatures().isPresent());
h = msg.getResult().getHit(1);
assertEquals(109.0, h.getRank(), 1E-6);
assertEquals("doc17", h.getDocId());
+ assertFalse(h.getMatchFeatures().isPresent());
routable = deserialize("QueryResultMessage-3", DocumentProtocol.MESSAGE_QUERYRESULT, Language.CPP);
assertTrue(routable instanceof QueryResultMessage);
@@ -659,9 +663,11 @@ public class Messages60TestCase extends MessagesTestBase {
h = msg.getResult().getHit(0);
assertEquals(109.0, h.getRank(), 1E-6);
assertEquals("doc17", h.getDocId());
+ assertFalse(h.getMatchFeatures().isPresent());
h = msg.getResult().getHit(1);
assertEquals(89.0, h.getRank(), 1E-6);
assertEquals("doc1", h.getDocId());
+ assertFalse(h.getMatchFeatures().isPresent());
routable = deserialize("QueryResultMessage-4", DocumentProtocol.MESSAGE_QUERYRESULT, Language.CPP);
assertTrue(routable instanceof QueryResultMessage);
@@ -673,32 +679,55 @@ public class Messages60TestCase extends MessagesTestBase {
assertEquals(89.0, h.getRank(), 1E-6);
assertEquals("doc1", h.getDocId());
byte[] b = ((SearchResult.HitWithSortBlob)h).getSortBlob();
- assertEquals(9, b.length);
- byte[] e = { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '2' };
- for (int i = 0; i < b.length; i++) {
- assertEquals(e[i], b[i]);
- }
+ assertEqualsData(new byte[] { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '2' }, b);
+
h = msg.getResult().getHit(1);
assertTrue(h instanceof SearchResult.HitWithSortBlob);
assertEquals(109.0, h.getRank(), 1E-6);
assertEquals("doc17", h.getDocId());
b = ((SearchResult.HitWithSortBlob)h).getSortBlob();
- assertEquals(9, b.length);
- byte[] d = { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '1' };
- for (int i = 0; i < b.length; i++) {
- assertEquals(d[i], b[i]);
- }
+ assertEqualsData(new byte[] { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '1' }, b);
+
h = msg.getResult().getHit(2);
assertTrue(h instanceof SearchResult.HitWithSortBlob);
assertEquals(90.0, h.getRank(), 1E-6);
assertEquals("doc18", h.getDocId());
b = ((SearchResult.HitWithSortBlob)h).getSortBlob();
- assertEquals(9, b.length);
- byte[] c = { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '3' };
- for (int i = 0; i < b.length; i++) {
- assertEquals(c[i], b[i]);
+ assertEqualsData(new byte[] { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '3' }, b);
+ }
+
+ void assertEqualsData(byte[] exp, byte[] act) {
+ assertEquals(exp.length, act.length);
+ for (int i = 0; i < exp.length; ++i) {
+ assertEquals(exp[i], act[i]);
}
}
+
+ void test_result_with_match_features() {
+ Routable routable = deserialize("QueryResultMessage-6", DocumentProtocol.MESSAGE_QUERYRESULT, Language.CPP);
+ assertTrue(routable instanceof QueryResultMessage);
+
+ var msg = (QueryResultMessage)routable;
+ assertEquals(2, msg.getResult().getHitCount());
+
+ var h = msg.getResult().getHit(0);
+ assertTrue(h instanceof SearchResult.Hit);
+ assertEquals(7.0, h.getRank(), 1E-6);
+ assertEquals("doc2", h.getDocId());
+ assertTrue(h.getMatchFeatures().isPresent());
+ var mf = h.getMatchFeatures().get();
+ assertEquals(12.0, mf.field("foo").asDouble(), 1E-6);
+ assertEqualsData(new byte[] { 'T', 'h', 'e', 'r', 'e' }, mf.field("bar").asData());
+
+ h = msg.getResult().getHit(1);
+ assertTrue(h instanceof SearchResult.Hit);
+ assertEquals(5.0, h.getRank(), 1E-6);
+ assertEquals("doc1", h.getDocId());
+ assertTrue(h.getMatchFeatures().isPresent());
+ mf = h.getMatchFeatures().get();
+ assertEquals(1.0, mf.field("foo").asDouble(), 1E-6);
+ assertEqualsData(new byte[] { 'H', 'i' }, mf.field("bar").asData());
+ }
}
public class testGetBucketListReply implements RunnableTest {
diff --git a/documentapi/src/tests/messages/messages60test.cpp b/documentapi/src/tests/messages/messages60test.cpp
index 89ab20373e7..58295ae2395 100644
--- a/documentapi/src/tests/messages/messages60test.cpp
+++ b/documentapi/src/tests/messages/messages60test.cpp
@@ -687,7 +687,7 @@ Messages60Test::testQueryResultMessage()
sr3.set_match_features(FeatureValues(mf));
sr3.sort();
- EXPECT_EQUAL(MESSAGE_BASE_LENGTH + 123u, serialize("QueryResultMessage-6", qrm3));
+ EXPECT_EQUAL(MESSAGE_BASE_LENGTH + 125u, serialize("QueryResultMessage-6", qrm3));
routable = deserialize("QueryResultMessage-6", DocumentProtocol::MESSAGE_QUERYRESULT, LANG_CPP);
if (!EXPECT_TRUE(routable)) {
return false;
@@ -709,6 +709,10 @@ Messages60Test::testQueryResultMessage()
EXPECT_EQUAL(2u, mfv.size());
EXPECT_EQUAL(1.0, mfv[0].as_double());
EXPECT_EQUAL("Hi", mfv[1].as_data().make_string());
+ const auto& mf_names = dr->get_match_features().names;
+ EXPECT_EQUAL(2u, mf_names.size());
+ EXPECT_EQUAL("foo", mf_names[0]);
+ EXPECT_EQUAL("bar", mf_names[1]);
return true;
}
diff --git a/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat b/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat
index 229441aa9ba..efe7f8546a9 100644
--- a/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat
+++ b/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat
Binary files differ
diff --git a/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java b/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java
index e263c292bc8..b7c9b1b71b5 100644
--- a/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java
+++ b/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java
@@ -1,28 +1,40 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vdslib;
+import com.yahoo.data.access.helpers.MatchFeatureData;
import com.yahoo.vespa.objects.BufferSerializer;
import com.yahoo.vespa.objects.Deserializer;
import java.io.UnsupportedEncodingException;
import java.nio.ByteOrder;
+import java.util.ArrayList;
import java.util.Map;
+import java.util.Optional;
import java.util.TreeMap;
public class SearchResult {
public static class Hit implements Comparable<Hit> {
private String docId;
- private double rank;
+ private double rank;
+ private MatchFeatureData.HitValue matchFeatures;
public Hit(Hit h) {
docId = h.docId;
rank = h.rank;
+ matchFeatures = h.matchFeatures;
}
public Hit(String docId, double rank) {
this.rank = rank;
this.docId = docId;
+ this.matchFeatures = null;
+ }
+ final public String getDocId() { return docId; }
+ final public double getRank() { return rank; }
+ final public Optional<MatchFeatureData.HitValue> getMatchFeatures() {
+ return Optional.ofNullable(matchFeatures);
}
- final public String getDocId() { return docId; }
- final public double getRank() { return rank; }
final public void setRank(double rank) { this.rank = rank; }
+ final public void setMatchFeatures(MatchFeatureData.HitValue matchFeatures) {
+ this.matchFeatures = matchFeatures;
+ }
public int compareTo(Hit h) {
return (h.rank < rank) ? -1 : (h.rank > rank) ? 1 : 0; // Sort order: descending
}
@@ -49,12 +61,19 @@ public class SearchResult {
private Hit[] hits;
private TreeMap<Integer, byte []> aggregatorList;
private TreeMap<Integer, byte []> groupingList;
+ private static int EXTENSION_FLAGS_PRESENT = -1;
+ private static int MATCH_FEATURES_PRESENT_MASK = 1;
public SearchResult(Deserializer buf) {
BufferSerializer bser = (BufferSerializer) buf; // TODO: dirty cast. must do this differently
bser.order(ByteOrder.BIG_ENDIAN);
this.totalHits = buf.getInt(null);
int numHits = buf.getInt(null);
+ int extensionFlags = 0;
+ if (hasExtensionFlags(numHits)) {
+ extensionFlags = buf.getInt(null);
+ numHits = buf.getInt(null);
+ }
hits = new Hit[numHits];
if (numHits != 0) {
int docIdBufferLength = buf.getInt(null);
@@ -101,7 +120,45 @@ public class SearchResult {
groupingList.put(aggrId, buf.getBytes(null, aggrLength));
}
+ if (hasMatchFeatures(extensionFlags)) {
+ deserializeMatchFeatures(buf, numHits);
+ }
}
+
+ private void deserializeMatchFeatures(Deserializer buf, int numHits) {
+ var featureNames = new ArrayList<String>();
+ int numFeatures = buf.getInt(null);
+ for (int i = 0; i < numFeatures; ++i) {
+ featureNames.add(buf.getString(null));
+ }
+ var factory = new MatchFeatureData(featureNames);
+ for (int i = 0; i < numHits; ++i) {
+ var matchFeatures = factory.addHit();
+ for (int j = 0; j < numFeatures; ++j) {
+ byte featureType = buf.getByte(null);
+ if (isDoubleFeature(featureType)) {
+ matchFeatures.set(j, buf.getDouble(null));
+ } else {
+ int bufLength = buf.getInt(null);
+ matchFeatures.set(j, buf.getBytes(null, bufLength));
+ }
+ }
+ hits[i].setMatchFeatures(matchFeatures);
+ }
+ }
+
+ private static boolean hasExtensionFlags(int numHits) {
+ return numHits == EXTENSION_FLAGS_PRESENT;
+ }
+
+ private static boolean hasMatchFeatures(int extensionFlags) {
+ return (extensionFlags & MATCH_FEATURES_PRESENT_MASK) != 0;
+ }
+
+ private static boolean isDoubleFeature(byte featureType) {
+ return featureType == 0;
+ }
+
/**
* Constructs a new message from a byte buffer.
*
diff --git a/vdslib/src/vespa/vdslib/container/searchresult.cpp b/vdslib/src/vespa/vdslib/container/searchresult.cpp
index c8bc331d1a8..e348c9d9e13 100644
--- a/vdslib/src/vespa/vdslib/container/searchresult.cpp
+++ b/vdslib/src/vespa/vdslib/container/searchresult.cpp
@@ -11,21 +11,21 @@ namespace vdslib {
namespace {
// Magic value for hit count to enable extension flags
-constexpr uint32_t enable_extension_flags_magic = 0xffffffffu;
+constexpr uint32_t extension_flags_present = 0xffffffffu;
// Extension flag values
-constexpr uint32_t match_features_present = 1;
+constexpr uint32_t match_features_present_mask = 1;
// Selector values for feature value
constexpr uint8_t feature_value_is_double = 0;
constexpr uint8_t feature_value_is_data = 1;
inline bool has_match_features(uint32_t extension_flags) {
- return ((extension_flags & match_features_present) != 0);
+ return ((extension_flags & match_features_present_mask) != 0);
}
inline bool must_serialize_extension_flags(uint32_t extension_flags, uint32_t hit_count) {
- return ((extension_flags != 0) || (hit_count == enable_extension_flags_magic));
+ return ((extension_flags != 0) || (hit_count == extension_flags_present));
}
}
@@ -155,7 +155,7 @@ SearchResult::deserialize(document::ByteBuffer & buf)
uint32_t numResults(0), bufSize(0);
buf.getIntNetwork(tmp); numResults = tmp;
uint32_t extension_flags = 0u;
- if (numResults == enable_extension_flags_magic) {
+ if (numResults == extension_flags_present) {
buf.getIntNetwork(tmp);
extension_flags = tmp;
buf.getIntNetwork(tmp);
@@ -189,7 +189,7 @@ void SearchResult::serialize(vespalib::GrowableByteBuffer & buf) const
uint32_t hitCount = std::min(_hits.size(), _wantedHits);
uint32_t extension_flags = calc_extension_flags(hitCount);
if (must_serialize_extension_flags(extension_flags, hitCount)) {
- buf.putInt(enable_extension_flags_magic);
+ buf.putInt(extension_flags_present);
buf.putInt(extension_flags);
}
buf.putInt(hitCount);
@@ -241,7 +241,7 @@ SearchResult::calc_extension_flags(uint32_t hit_count) const noexcept
{
uint32_t extension_flags = 0u;
if (!_match_features.names.empty() && hit_count != 0) {
- extension_flags |= match_features_present;
+ extension_flags |= match_features_present_mask;
}
return extension_flags;
}
@@ -251,7 +251,7 @@ SearchResult::get_match_features_serialized_size(uint32_t hit_count) const noexc
{
uint32_t size = sizeof(uint32_t);
for (auto& name : _match_features.names) {
- size += sizeof(uint32_t) + name.size();
+ size += sizeof(uint32_t) + name.size() + 1;
}
for (uint32_t i = 0; i < hit_count; ++i) {
auto mfv = get_match_feature_values(i);
@@ -271,7 +271,7 @@ SearchResult::serialize_match_features(vespalib::GrowableByteBuffer& buf, uint32
{
buf.putInt(_match_features.names.size());
for (auto& name : _match_features.names) {
- buf.putString(name);
+ buf.put_c_string(name);
}
for (uint32_t i = 0; i < hit_count; ++i) {
auto mfv = get_match_feature_values(i);
@@ -301,10 +301,11 @@ SearchResult::deserialize_match_features(document::ByteBuffer& buf)
_match_features.names.resize(num_features);
for (auto& name : _match_features.names) {
buf.getIntNetwork(tmp);
- name.resize(tmp);
- if (tmp != 0) {
- buf.getBytes(&name[0], tmp);
+ if (tmp > 1) {
+ name.resize(tmp - 1);
+ buf.getBytes(&name[0], tmp - 1);
}
+ buf.getByte(selector); // Read and ignore the nul-termination.
}
uint32_t hit_count = _hits.size();
uint32_t num_values = num_features * hit_count;
diff --git a/vespalib/src/vespa/vespalib/util/growablebytebuffer.cpp b/vespalib/src/vespa/vespalib/util/growablebytebuffer.cpp
index 424ad7ba470..fa0dc9bf99e 100644
--- a/vespalib/src/vespa/vespalib/util/growablebytebuffer.cpp
+++ b/vespalib/src/vespa/vespalib/util/growablebytebuffer.cpp
@@ -76,6 +76,14 @@ GrowableByteBuffer::putString(vespalib::stringref v)
}
void
+GrowableByteBuffer::put_c_string(vespalib::stringref v)
+{
+ putInt(v.size() + 1);
+ putBytes(v.data(), v.size());
+ putByte(0);
+}
+
+void
GrowableByteBuffer::putByte(uint8_t v)
{
putBytes(reinterpret_cast<const char*>(&v), sizeof(v));
diff --git a/vespalib/src/vespa/vespalib/util/growablebytebuffer.h b/vespalib/src/vespa/vespalib/util/growablebytebuffer.h
index b0fb30606d4..61698868dba 100644
--- a/vespalib/src/vespa/vespalib/util/growablebytebuffer.h
+++ b/vespalib/src/vespa/vespalib/util/growablebytebuffer.h
@@ -68,11 +68,17 @@ public:
void putDouble(double v);
/**
- Adds a string to the buffer.
+ Adds a string to the buffer (without nul-termination).
*/
void putString(vespalib::stringref v);
/**
+ * Adds a string to the buffer (including nul-termination).
+ * This matches com.yahoo.vespa.objects.Deserializer.getString.
+ */
+ void put_c_string(vespalib::stringref v);
+
+ /**
Adds a single byte to the buffer.
*/
void putByte(uint8_t v);