diff options
7 files changed, 136 insertions, 31 deletions
diff --git a/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/Messages60TestCase.java b/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/Messages60TestCase.java index 22650fcdbf8..940217aa2b4 100644 --- a/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/Messages60TestCase.java +++ b/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/Messages60TestCase.java @@ -633,6 +633,8 @@ public class Messages60TestCase extends MessagesTestBase { @Override public void run() throws Exception { + test_result_with_match_features(); + Routable routable = deserialize("QueryResultMessage-1", DocumentProtocol.MESSAGE_QUERYRESULT, Language.CPP); assertTrue(routable instanceof QueryResultMessage); @@ -647,9 +649,11 @@ public class Messages60TestCase extends MessagesTestBase { com.yahoo.vdslib.SearchResult.Hit h = msg.getResult().getHit(0); assertEquals(89.0, h.getRank(), 1E-6); assertEquals("doc1", h.getDocId()); + assertFalse(h.getMatchFeatures().isPresent()); h = msg.getResult().getHit(1); assertEquals(109.0, h.getRank(), 1E-6); assertEquals("doc17", h.getDocId()); + assertFalse(h.getMatchFeatures().isPresent()); routable = deserialize("QueryResultMessage-3", DocumentProtocol.MESSAGE_QUERYRESULT, Language.CPP); assertTrue(routable instanceof QueryResultMessage); @@ -659,9 +663,11 @@ public class Messages60TestCase extends MessagesTestBase { h = msg.getResult().getHit(0); assertEquals(109.0, h.getRank(), 1E-6); assertEquals("doc17", h.getDocId()); + assertFalse(h.getMatchFeatures().isPresent()); h = msg.getResult().getHit(1); assertEquals(89.0, h.getRank(), 1E-6); assertEquals("doc1", h.getDocId()); + assertFalse(h.getMatchFeatures().isPresent()); routable = deserialize("QueryResultMessage-4", DocumentProtocol.MESSAGE_QUERYRESULT, Language.CPP); assertTrue(routable instanceof QueryResultMessage); @@ -673,32 +679,55 @@ public class Messages60TestCase extends MessagesTestBase { assertEquals(89.0, h.getRank(), 1E-6); assertEquals("doc1", h.getDocId()); byte[] b = ((SearchResult.HitWithSortBlob)h).getSortBlob(); - assertEquals(9, b.length); - byte[] e = { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '2' }; - for (int i = 0; i < b.length; i++) { - assertEquals(e[i], b[i]); - } + assertEqualsData(new byte[] { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '2' }, b); + h = msg.getResult().getHit(1); assertTrue(h instanceof SearchResult.HitWithSortBlob); assertEquals(109.0, h.getRank(), 1E-6); assertEquals("doc17", h.getDocId()); b = ((SearchResult.HitWithSortBlob)h).getSortBlob(); - assertEquals(9, b.length); - byte[] d = { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '1' }; - for (int i = 0; i < b.length; i++) { - assertEquals(d[i], b[i]); - } + assertEqualsData(new byte[] { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '1' }, b); + h = msg.getResult().getHit(2); assertTrue(h instanceof SearchResult.HitWithSortBlob); assertEquals(90.0, h.getRank(), 1E-6); assertEquals("doc18", h.getDocId()); b = ((SearchResult.HitWithSortBlob)h).getSortBlob(); - assertEquals(9, b.length); - byte[] c = { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '3' }; - for (int i = 0; i < b.length; i++) { - assertEquals(c[i], b[i]); + assertEqualsData(new byte[] { 's', 'o', 'r', 't', 'd', 'a', 't', 'a', '3' }, b); + } + + void assertEqualsData(byte[] exp, byte[] act) { + assertEquals(exp.length, act.length); + for (int i = 0; i < exp.length; ++i) { + assertEquals(exp[i], act[i]); } } + + void test_result_with_match_features() { + Routable routable = deserialize("QueryResultMessage-6", DocumentProtocol.MESSAGE_QUERYRESULT, Language.CPP); + assertTrue(routable instanceof QueryResultMessage); + + var msg = (QueryResultMessage)routable; + assertEquals(2, msg.getResult().getHitCount()); + + var h = msg.getResult().getHit(0); + assertTrue(h instanceof SearchResult.Hit); + assertEquals(7.0, h.getRank(), 1E-6); + assertEquals("doc2", h.getDocId()); + assertTrue(h.getMatchFeatures().isPresent()); + var mf = h.getMatchFeatures().get(); + assertEquals(12.0, mf.field("foo").asDouble(), 1E-6); + assertEqualsData(new byte[] { 'T', 'h', 'e', 'r', 'e' }, mf.field("bar").asData()); + + h = msg.getResult().getHit(1); + assertTrue(h instanceof SearchResult.Hit); + assertEquals(5.0, h.getRank(), 1E-6); + assertEquals("doc1", h.getDocId()); + assertTrue(h.getMatchFeatures().isPresent()); + mf = h.getMatchFeatures().get(); + assertEquals(1.0, mf.field("foo").asDouble(), 1E-6); + assertEqualsData(new byte[] { 'H', 'i' }, mf.field("bar").asData()); + } } public class testGetBucketListReply implements RunnableTest { diff --git a/documentapi/src/tests/messages/messages60test.cpp b/documentapi/src/tests/messages/messages60test.cpp index 89ab20373e7..58295ae2395 100644 --- a/documentapi/src/tests/messages/messages60test.cpp +++ b/documentapi/src/tests/messages/messages60test.cpp @@ -687,7 +687,7 @@ Messages60Test::testQueryResultMessage() sr3.set_match_features(FeatureValues(mf)); sr3.sort(); - EXPECT_EQUAL(MESSAGE_BASE_LENGTH + 123u, serialize("QueryResultMessage-6", qrm3)); + EXPECT_EQUAL(MESSAGE_BASE_LENGTH + 125u, serialize("QueryResultMessage-6", qrm3)); routable = deserialize("QueryResultMessage-6", DocumentProtocol::MESSAGE_QUERYRESULT, LANG_CPP); if (!EXPECT_TRUE(routable)) { return false; @@ -709,6 +709,10 @@ Messages60Test::testQueryResultMessage() EXPECT_EQUAL(2u, mfv.size()); EXPECT_EQUAL(1.0, mfv[0].as_double()); EXPECT_EQUAL("Hi", mfv[1].as_data().make_string()); + const auto& mf_names = dr->get_match_features().names; + EXPECT_EQUAL(2u, mf_names.size()); + EXPECT_EQUAL("foo", mf_names[0]); + EXPECT_EQUAL("bar", mf_names[1]); return true; } diff --git a/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat b/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat Binary files differindex 229441aa9ba..efe7f8546a9 100644 --- a/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat +++ b/documentapi/test/crosslanguagefiles/6.221-cpp-QueryResultMessage-6.dat diff --git a/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java b/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java index e263c292bc8..b7c9b1b71b5 100644 --- a/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java +++ b/vdslib/src/main/java/com/yahoo/vdslib/SearchResult.java @@ -1,28 +1,40 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vdslib; +import com.yahoo.data.access.helpers.MatchFeatureData; import com.yahoo.vespa.objects.BufferSerializer; import com.yahoo.vespa.objects.Deserializer; import java.io.UnsupportedEncodingException; import java.nio.ByteOrder; +import java.util.ArrayList; import java.util.Map; +import java.util.Optional; import java.util.TreeMap; public class SearchResult { public static class Hit implements Comparable<Hit> { private String docId; - private double rank; + private double rank; + private MatchFeatureData.HitValue matchFeatures; public Hit(Hit h) { docId = h.docId; rank = h.rank; + matchFeatures = h.matchFeatures; } public Hit(String docId, double rank) { this.rank = rank; this.docId = docId; + this.matchFeatures = null; + } + final public String getDocId() { return docId; } + final public double getRank() { return rank; } + final public Optional<MatchFeatureData.HitValue> getMatchFeatures() { + return Optional.ofNullable(matchFeatures); } - final public String getDocId() { return docId; } - final public double getRank() { return rank; } final public void setRank(double rank) { this.rank = rank; } + final public void setMatchFeatures(MatchFeatureData.HitValue matchFeatures) { + this.matchFeatures = matchFeatures; + } public int compareTo(Hit h) { return (h.rank < rank) ? -1 : (h.rank > rank) ? 1 : 0; // Sort order: descending } @@ -49,12 +61,19 @@ public class SearchResult { private Hit[] hits; private TreeMap<Integer, byte []> aggregatorList; private TreeMap<Integer, byte []> groupingList; + private static int EXTENSION_FLAGS_PRESENT = -1; + private static int MATCH_FEATURES_PRESENT_MASK = 1; public SearchResult(Deserializer buf) { BufferSerializer bser = (BufferSerializer) buf; // TODO: dirty cast. must do this differently bser.order(ByteOrder.BIG_ENDIAN); this.totalHits = buf.getInt(null); int numHits = buf.getInt(null); + int extensionFlags = 0; + if (hasExtensionFlags(numHits)) { + extensionFlags = buf.getInt(null); + numHits = buf.getInt(null); + } hits = new Hit[numHits]; if (numHits != 0) { int docIdBufferLength = buf.getInt(null); @@ -101,7 +120,45 @@ public class SearchResult { groupingList.put(aggrId, buf.getBytes(null, aggrLength)); } + if (hasMatchFeatures(extensionFlags)) { + deserializeMatchFeatures(buf, numHits); + } } + + private void deserializeMatchFeatures(Deserializer buf, int numHits) { + var featureNames = new ArrayList<String>(); + int numFeatures = buf.getInt(null); + for (int i = 0; i < numFeatures; ++i) { + featureNames.add(buf.getString(null)); + } + var factory = new MatchFeatureData(featureNames); + for (int i = 0; i < numHits; ++i) { + var matchFeatures = factory.addHit(); + for (int j = 0; j < numFeatures; ++j) { + byte featureType = buf.getByte(null); + if (isDoubleFeature(featureType)) { + matchFeatures.set(j, buf.getDouble(null)); + } else { + int bufLength = buf.getInt(null); + matchFeatures.set(j, buf.getBytes(null, bufLength)); + } + } + hits[i].setMatchFeatures(matchFeatures); + } + } + + private static boolean hasExtensionFlags(int numHits) { + return numHits == EXTENSION_FLAGS_PRESENT; + } + + private static boolean hasMatchFeatures(int extensionFlags) { + return (extensionFlags & MATCH_FEATURES_PRESENT_MASK) != 0; + } + + private static boolean isDoubleFeature(byte featureType) { + return featureType == 0; + } + /** * Constructs a new message from a byte buffer. * diff --git a/vdslib/src/vespa/vdslib/container/searchresult.cpp b/vdslib/src/vespa/vdslib/container/searchresult.cpp index c8bc331d1a8..e348c9d9e13 100644 --- a/vdslib/src/vespa/vdslib/container/searchresult.cpp +++ b/vdslib/src/vespa/vdslib/container/searchresult.cpp @@ -11,21 +11,21 @@ namespace vdslib { namespace { // Magic value for hit count to enable extension flags -constexpr uint32_t enable_extension_flags_magic = 0xffffffffu; +constexpr uint32_t extension_flags_present = 0xffffffffu; // Extension flag values -constexpr uint32_t match_features_present = 1; +constexpr uint32_t match_features_present_mask = 1; // Selector values for feature value constexpr uint8_t feature_value_is_double = 0; constexpr uint8_t feature_value_is_data = 1; inline bool has_match_features(uint32_t extension_flags) { - return ((extension_flags & match_features_present) != 0); + return ((extension_flags & match_features_present_mask) != 0); } inline bool must_serialize_extension_flags(uint32_t extension_flags, uint32_t hit_count) { - return ((extension_flags != 0) || (hit_count == enable_extension_flags_magic)); + return ((extension_flags != 0) || (hit_count == extension_flags_present)); } } @@ -155,7 +155,7 @@ SearchResult::deserialize(document::ByteBuffer & buf) uint32_t numResults(0), bufSize(0); buf.getIntNetwork(tmp); numResults = tmp; uint32_t extension_flags = 0u; - if (numResults == enable_extension_flags_magic) { + if (numResults == extension_flags_present) { buf.getIntNetwork(tmp); extension_flags = tmp; buf.getIntNetwork(tmp); @@ -189,7 +189,7 @@ void SearchResult::serialize(vespalib::GrowableByteBuffer & buf) const uint32_t hitCount = std::min(_hits.size(), _wantedHits); uint32_t extension_flags = calc_extension_flags(hitCount); if (must_serialize_extension_flags(extension_flags, hitCount)) { - buf.putInt(enable_extension_flags_magic); + buf.putInt(extension_flags_present); buf.putInt(extension_flags); } buf.putInt(hitCount); @@ -241,7 +241,7 @@ SearchResult::calc_extension_flags(uint32_t hit_count) const noexcept { uint32_t extension_flags = 0u; if (!_match_features.names.empty() && hit_count != 0) { - extension_flags |= match_features_present; + extension_flags |= match_features_present_mask; } return extension_flags; } @@ -251,7 +251,7 @@ SearchResult::get_match_features_serialized_size(uint32_t hit_count) const noexc { uint32_t size = sizeof(uint32_t); for (auto& name : _match_features.names) { - size += sizeof(uint32_t) + name.size(); + size += sizeof(uint32_t) + name.size() + 1; } for (uint32_t i = 0; i < hit_count; ++i) { auto mfv = get_match_feature_values(i); @@ -271,7 +271,7 @@ SearchResult::serialize_match_features(vespalib::GrowableByteBuffer& buf, uint32 { buf.putInt(_match_features.names.size()); for (auto& name : _match_features.names) { - buf.putString(name); + buf.put_c_string(name); } for (uint32_t i = 0; i < hit_count; ++i) { auto mfv = get_match_feature_values(i); @@ -301,10 +301,11 @@ SearchResult::deserialize_match_features(document::ByteBuffer& buf) _match_features.names.resize(num_features); for (auto& name : _match_features.names) { buf.getIntNetwork(tmp); - name.resize(tmp); - if (tmp != 0) { - buf.getBytes(&name[0], tmp); + if (tmp > 1) { + name.resize(tmp - 1); + buf.getBytes(&name[0], tmp - 1); } + buf.getByte(selector); // Read and ignore the nul-termination. } uint32_t hit_count = _hits.size(); uint32_t num_values = num_features * hit_count; diff --git a/vespalib/src/vespa/vespalib/util/growablebytebuffer.cpp b/vespalib/src/vespa/vespalib/util/growablebytebuffer.cpp index 424ad7ba470..fa0dc9bf99e 100644 --- a/vespalib/src/vespa/vespalib/util/growablebytebuffer.cpp +++ b/vespalib/src/vespa/vespalib/util/growablebytebuffer.cpp @@ -76,6 +76,14 @@ GrowableByteBuffer::putString(vespalib::stringref v) } void +GrowableByteBuffer::put_c_string(vespalib::stringref v) +{ + putInt(v.size() + 1); + putBytes(v.data(), v.size()); + putByte(0); +} + +void GrowableByteBuffer::putByte(uint8_t v) { putBytes(reinterpret_cast<const char*>(&v), sizeof(v)); diff --git a/vespalib/src/vespa/vespalib/util/growablebytebuffer.h b/vespalib/src/vespa/vespalib/util/growablebytebuffer.h index b0fb30606d4..61698868dba 100644 --- a/vespalib/src/vespa/vespalib/util/growablebytebuffer.h +++ b/vespalib/src/vespa/vespalib/util/growablebytebuffer.h @@ -68,11 +68,17 @@ public: void putDouble(double v); /** - Adds a string to the buffer. + Adds a string to the buffer (without nul-termination). */ void putString(vespalib::stringref v); /** + * Adds a string to the buffer (including nul-termination). + * This matches com.yahoo.vespa.objects.Deserializer.getString. + */ + void put_c_string(vespalib::stringref v); + + /** Adds a single byte to the buffer. */ void putByte(uint8_t v); |