diff options
author | Olli Virtanen <olli.virtanen@oath.com> | 2019-03-27 14:55:30 +0100 |
---|---|---|
committer | Olli Virtanen <olli.virtanen@oath.com> | 2019-03-27 14:55:30 +0100 |
commit | 2d68d291bacdcea237cb3f6c3e5f85aa61845b88 (patch) | |
tree | 8f45ae2022e0a812e18513398e242a0a121166dd /container-search/src | |
parent | 5e33bb54604989bb2cef605572f7750d45eb630a (diff) |
Retrieve document summaries over jrt/protobuf
Diffstat (limited to 'container-search/src')
12 files changed, 422 insertions, 172 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/VespaBackEndSearcher.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/VespaBackEndSearcher.java index dccda0bf733..df72720a46c 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/VespaBackEndSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/VespaBackEndSearcher.java @@ -52,7 +52,6 @@ import java.util.logging.Logger; * * @author baldersheim */ -@SuppressWarnings("deprecation") public abstract class VespaBackEndSearcher extends PingableSearcher { static final CompoundName PACKET_COMPRESSION_LIMIT = new CompoundName("packetcompressionlimit"); diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java index cf5bcedcf51..74d9c38b273 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java @@ -48,7 +48,7 @@ public class Dispatcher extends AbstractComponent { private static final CompoundName dispatchInternal = new CompoundName("dispatch.internal"); /** If enabled, search queries will use protobuf rpc */ - private static final CompoundName dispatchProtobuf = new CompoundName("dispatch.protobuf"); + public static final CompoundName dispatchProtobuf = new CompoundName("dispatch.protobuf"); /** A model of the search cluster this dispatches to */ private final SearchCluster searchCluster; diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java index 019e07221a6..4422538cff6 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/Client.java @@ -18,42 +18,46 @@ interface Client { int uncompressedLength, byte[] compressedSlime, RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds); - void search(NodeConnection node, CompressionType compression, - int uncompressedLength, byte[] compressedPayload, RpcSearchInvoker responseReceiver, - double timeoutSeconds); + void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength, + byte[] compressedPayload, ResponseReceiver responseReceiver, double timeoutSeconds); /** Creates a connection to a particular node in this */ NodeConnection createConnection(String hostname, int port); - class GetDocsumsResponseOrError { + interface ResponseReceiver { + void receive(ResponseOrError<ProtobufResponse> response); + } + + class ResponseOrError<T> { + final Optional<T> response; + final Optional<String> error; - // One of these will be non empty and the other not - private Optional<GetDocsumsResponse> response; - private Optional<String> error; + public static <T> ResponseOrError<T> fromResponse(T response) { + return new ResponseOrError<>(response); + } - public static GetDocsumsResponseOrError fromResponse(GetDocsumsResponse response) { - return new GetDocsumsResponseOrError(Optional.of(response), Optional.empty()); + public static <T> ResponseOrError<T> fromError(String error) { + return new ResponseOrError<T>(error); } - public static GetDocsumsResponseOrError fromError(String error) { - return new GetDocsumsResponseOrError(Optional.empty(), Optional.of(error)); + ResponseOrError(T response) { + this.response = Optional.of(response); + this.error = Optional.empty(); } - private GetDocsumsResponseOrError(Optional<GetDocsumsResponse> response, Optional<String> error) { - this.response = response; - this.error = error; + ResponseOrError(String error) { + this.response = Optional.empty(); + this.error = Optional.of(error); } /** Returns the response, or empty if there is an error */ - public Optional<GetDocsumsResponse> response() { return response; } + public Optional<T> response() { return response; } /** Returns the error or empty if there is a response */ public Optional<String> error() { return error; } - } class GetDocsumsResponse { - private final byte compression; private final int uncompressedSize; private final byte[] compressedSlimeBytes; @@ -91,38 +95,12 @@ interface Client { } - class SearchResponseOrError { - // One of these will be non empty and the other not - private Optional<SearchResponse> response; - private Optional<String> error; - - public static SearchResponseOrError fromResponse(SearchResponse response) { - return new SearchResponseOrError(Optional.of(response), Optional.empty()); - } - - public static SearchResponseOrError fromError(String error) { - return new SearchResponseOrError(Optional.empty(), Optional.of(error)); - } - - private SearchResponseOrError(Optional<SearchResponse> response, Optional<String> error) { - this.response = response; - this.error = error; - } - - /** Returns the response, or empty if there is an error */ - public Optional<SearchResponse> response() { return response; } - - /** Returns the error or empty if there is a response */ - public Optional<String> error() { return error; } - - } - - class SearchResponse { + class ProtobufResponse { private final byte compression; private final int uncompressedSize; private final byte[] compressedPayload; - public SearchResponse(byte compression, int uncompressedSize, byte[] compressedPayload) { + public ProtobufResponse(byte compression, int uncompressedSize, byte[] compressedPayload) { this.compression = compression; this.uncompressedSize = uncompressedSize; this.compressedPayload = compressedPayload; diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/MapConverter.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/MapConverter.java index 817ecfe0091..74828dd6740 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/MapConverter.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/MapConverter.java @@ -10,46 +10,42 @@ import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.function.Consumer; /** * @author ollivir */ public class MapConverter { - @FunctionalInterface - public interface PropertyInserter<T> { - void add(T prop); - } - - public static void convertMapTensors(Map<String, Object> map, PropertyInserter<TensorProperty.Builder> inserter) { + public static void convertMapTensors(Map<String, Object> map, Consumer<TensorProperty.Builder> inserter) { for (var entry : map.entrySet()) { var value = entry.getValue(); if (value instanceof Tensor) { byte[] tensor = TypedBinaryFormat.encode((Tensor) value); - inserter.add(TensorProperty.newBuilder().setName(entry.getKey()).setValue(ByteString.copyFrom(tensor))); + inserter.accept(TensorProperty.newBuilder().setName(entry.getKey()).setValue(ByteString.copyFrom(tensor))); } } } - public static void convertMapStrings(Map<String, Object> map, PropertyInserter<StringProperty.Builder> inserter) { + public static void convertMapStrings(Map<String, Object> map, Consumer<StringProperty.Builder> inserter) { for (var entry : map.entrySet()) { var value = entry.getValue(); if (!(value instanceof Tensor)) { - inserter.add(StringProperty.newBuilder().setName(entry.getKey()).addValues(value.toString())); + inserter.accept(StringProperty.newBuilder().setName(entry.getKey()).addValues(value.toString())); } } } - public static void convertStringMultiMap(Map<String, List<String>> map, PropertyInserter<StringProperty.Builder> inserter) { + public static void convertStringMultiMap(Map<String, List<String>> map, Consumer<StringProperty.Builder> inserter) { for (var entry : map.entrySet()) { var values = entry.getValue(); if (values != null) { - inserter.add(StringProperty.newBuilder().setName(entry.getKey()).addAllValues(values)); + inserter.accept(StringProperty.newBuilder().setName(entry.getKey()).addAllValues(values)); } } } - public static void convertMultiMap(Map<String, List<Object>> map, PropertyInserter<StringProperty.Builder> stringInserter, - PropertyInserter<TensorProperty.Builder> tensorInserter) { + public static void convertMultiMap(Map<String, List<Object>> map, Consumer<StringProperty.Builder> stringInserter, + Consumer<TensorProperty.Builder> tensorInserter) { for (var entry : map.entrySet()) { if (entry.getValue() != null) { var key = entry.getKey(); @@ -58,14 +54,14 @@ public class MapConverter { if (value != null) { if (value instanceof Tensor) { byte[] tensor = TypedBinaryFormat.encode((Tensor) value); - tensorInserter.add(TensorProperty.newBuilder().setName(key).setValue(ByteString.copyFrom(tensor))); + tensorInserter.accept(TensorProperty.newBuilder().setName(key).setValue(ByteString.copyFrom(tensor))); } else { stringValues.add(value.toString()); } } } if (!stringValues.isEmpty()) { - stringInserter.add(StringProperty.newBuilder().setName(key).addAllValues(stringValues)); + stringInserter.accept(StringProperty.newBuilder().setName(key).addAllValues(stringValues)); } } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java index 1c1c9ccd115..9903aacdda0 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java @@ -1,7 +1,8 @@ package com.yahoo.search.dispatch.rpc; import ai.vespa.searchlib.searchprotocol.protobuf.SearchProtocol; -import ai.vespa.searchlib.searchprotocol.protobuf.SearchProtocol.SearchRequest.Builder; +import ai.vespa.searchlib.searchprotocol.protobuf.SearchProtocol.StringProperty; +import ai.vespa.searchlib.searchprotocol.protobuf.SearchProtocol.TensorProperty; import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import com.yahoo.document.GlobalId; @@ -15,11 +16,10 @@ import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.grouping.vespa.GroupingExecutor; import com.yahoo.search.query.Model; +import com.yahoo.search.query.QueryTree; import com.yahoo.search.query.Ranking; import com.yahoo.search.query.Sorting; import com.yahoo.search.query.Sorting.Order; -import com.yahoo.search.query.ranking.RankFeatures; -import com.yahoo.search.query.ranking.RankProperties; import com.yahoo.search.result.Coverage; import com.yahoo.search.result.Relevance; import com.yahoo.searchlib.aggregation.Grouping; @@ -28,31 +28,24 @@ import com.yahoo.vespa.objects.BufferSerializer; import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; public class ProtobufSerialization { private static final int INITIAL_SERIALIZATION_BUFFER_SIZE = 10 * 1024; - public static byte[] serializeQuery(Query query, String serverId, boolean includeQueryData) { - return convertFromQuery(query, serverId, includeQueryData).toByteArray(); + public static byte[] serializeSearchRequest(Query query, String serverId) { + return convertFromQuery(query, serverId).toByteArray(); } - public static byte[] serializeResult(Result searchResult) { - return convertFromResult(searchResult).toByteArray(); - } - - public static Result deserializeToResult(byte[] payload, Query query, VespaBackEndSearcher searcher) - throws InvalidProtocolBufferException { - var protobuf = SearchProtocol.SearchReply.parseFrom(payload); - var result = convertToResult(query, protobuf, searcher.getDocumentDatabase(query)); - return result; - } - - private static SearchProtocol.SearchRequest convertFromQuery(Query query, String serverId, boolean includeQueryData) { + private static SearchProtocol.SearchRequest convertFromQuery(Query query, String serverId) { var builder = SearchProtocol.SearchRequest.newBuilder().setHits(query.getHits()).setOffset(query.getOffset()) .setTimeout((int) query.getTimeLeft()); - mergeToRequestFromRanking(query.getRanking(), builder, includeQueryData); - mergeToRequestFromModel(query.getModel(), builder); + var documentDb = query.getModel().getDocumentDb(); + if (documentDb != null) { + builder.setDocumentType(documentDb); + } + builder.setQueryTreeBlob(serializeQueryTree(query.getModel().getQueryTree())); if (query.getGroupingSessionCache() || query.getRanking().getQueryCache()) { // TODO verify that the session key is included whenever rank properties would have been @@ -71,71 +64,101 @@ public class ProtobufSerialization { gbuf.getBuf().flip(); builder.setGroupingBlob(ByteString.copyFrom(gbuf.getBuf().getByteBuffer())); } - if (query.getGroupingSessionCache()) { builder.setCacheGrouping(true); } + mergeToSearchRequestFromRanking(query.getRanking(), builder); + return builder.build(); } - private static void mergeToRequestFromModel(Model model, SearchProtocol.SearchRequest.Builder builder) { - if (model.getDocumentDb() != null) { - builder.setDocumentType(model.getDocumentDb()); + private static void mergeToSearchRequestFromRanking(Ranking ranking, SearchProtocol.SearchRequest.Builder builder) { + builder.setRankProfile(ranking.getProfile()); + + if (ranking.getQueryCache()) { + builder.setCacheQuery(true); } - int bufferSize = INITIAL_SERIALIZATION_BUFFER_SIZE; - boolean success = false; - while (!success) { - try { - ByteBuffer treeBuffer = ByteBuffer.allocate(bufferSize); - model.getQueryTree().encode(treeBuffer); - treeBuffer.flip(); - builder.setQueryTreeBlob(ByteString.copyFrom(treeBuffer)); - success = true; - } catch (java.nio.BufferOverflowException e) { - bufferSize *= 2; - } + if (ranking.getSorting() != null) { + mergeToSearchRequestFromSorting(ranking.getSorting(), builder); } + if (ranking.getLocation() != null) { + builder.setGeoLocation(ranking.getLocation().toString()); + } + + var featureMap = ranking.getFeatures().asMap(); + MapConverter.convertMapStrings(featureMap, builder::addFeatureOverrides); + MapConverter.convertMapTensors(featureMap, builder::addTensorFeatureOverrides); + mergeRankProperties(ranking, builder::addRankProperties, builder::addTensorRankProperties); } - private static void mergeToRequestFromSorting(Sorting sorting, SearchProtocol.SearchRequest.Builder builder, boolean includeQueryData) { + private static void mergeToSearchRequestFromSorting(Sorting sorting, SearchProtocol.SearchRequest.Builder builder) { for (var field : sorting.fieldOrders()) { - var sortField = SearchProtocol.SortField.newBuilder().setField(field.getSorter().getName()) + var sortField = SearchProtocol.SortField.newBuilder() + .setField(field.getSorter().getName()) .setAscending(field.getSortOrder() == Order.ASCENDING).build(); builder.addSorting(sortField); } } - private static void mergeToRequestFromRanking(Ranking ranking, SearchProtocol.SearchRequest.Builder builder, boolean includeQueryData) { - builder.setRankProfile(ranking.getProfile()); - if (ranking.getQueryCache()) { - builder.setCacheQuery(true); + public static SearchProtocol.DocsumRequest.Builder createDocsumRequestBuilder(Query query, String serverId, String summaryClass, + boolean includeQueryData) { + var builder = SearchProtocol.DocsumRequest.newBuilder() + .setTimeout((int) query.getTimeLeft()) + .setDumpFeatures(query.properties().getBoolean(Ranking.RANKFEATURES, false)); + + if (summaryClass != null) { + builder.setSummaryClass(summaryClass); } - if (ranking.getSorting() != null) { - mergeToRequestFromSorting(ranking.getSorting(), builder, includeQueryData); + + var documentDb = query.getModel().getDocumentDb(); + if (documentDb != null) { + builder.setDocumentType(documentDb); } - if (ranking.getLocation() != null) { - builder.setGeoLocation(ranking.getLocation().toString()); + + var ranking = query.getRanking(); + if (ranking.getQueryCache()) { + builder.setCacheQuery(true); + builder.setSessionKey(query.getSessionId(serverId).toString()); } - mergeToRequestFromRankFeatures(ranking.getFeatures(), builder, includeQueryData); - mergeToRequestFromRankProperties(ranking.getProperties(), builder, includeQueryData); - } + builder.setRankProfile(query.getRanking().getProfile()); - private static void mergeToRequestFromRankFeatures(RankFeatures features, SearchProtocol.SearchRequest.Builder builder, boolean includeQueryData) { if (includeQueryData) { - MapConverter.convertMapStrings(features.asMap(), builder::addFeatureOverrides); - MapConverter.convertMapTensors(features.asMap(), builder::addTensorFeatureOverrides); + mergeQueryDataToDocsumRequest(query, builder); } + + return builder; } - private static void mergeToRequestFromRankProperties(RankProperties properties, Builder builder, boolean includeQueryData) { - if (includeQueryData) { - MapConverter.convertMultiMap(properties.asMap(), propB -> { - if (!GetDocSumsPacket.sessionIdKey.equals(propB.getName())) { - builder.addRankProperties(propB); - } - }, builder::addTensorRankProperties); + public static byte[] serializeDocsumRequest(SearchProtocol.DocsumRequest.Builder builder, List<FastHit> documents) { + builder.clearGlobalIds(); + for (var hit : documents) { + builder.addGlobalIds(ByteString.copyFrom(hit.getGlobalId().getRawId())); } + return builder.build().toByteArray(); + } + + private static void mergeQueryDataToDocsumRequest(Query query, SearchProtocol.DocsumRequest.Builder builder) { + var ranking = query.getRanking(); + var featureMap = ranking.getFeatures().asMap(); + + builder.setQueryTreeBlob(serializeQueryTree(query.getModel().getQueryTree())); + builder.setGeoLocation(ranking.getLocation().toString()); + MapConverter.convertMapStrings(featureMap, builder::addFeatureOverrides); + MapConverter.convertMapTensors(featureMap, builder::addTensorFeatureOverrides); + MapConverter.convertStringMultiMap(query.getPresentation().getHighlight().getHighlightTerms(), builder::addHighlightTerms); + mergeRankProperties(ranking, builder::addRankProperties, builder::addTensorRankProperties); + } + + public static byte[] serializeResult(Result searchResult) { + return convertFromResult(searchResult).toByteArray(); + } + + public static Result deserializeToSearchResult(byte[] payload, Query query, VespaBackEndSearcher searcher) + throws InvalidProtocolBufferException { + var protobuf = SearchProtocol.SearchReply.parseFrom(payload); + var result = convertToResult(query, protobuf, searcher.getDocumentDatabase(query)); + return result; } private static Result convertToResult(Query query, SearchProtocol.SearchReply protobuf, DocumentDatabase documentDatabase) { @@ -211,4 +234,26 @@ public class ProtobufSerialization { return builder.build(); } + private static ByteString serializeQueryTree(QueryTree queryTree) { + int bufferSize = INITIAL_SERIALIZATION_BUFFER_SIZE; + while (true) { + try { + ByteBuffer treeBuffer = ByteBuffer.allocate(bufferSize); + queryTree.encode(treeBuffer); + treeBuffer.flip(); + return ByteString.copyFrom(treeBuffer); + } catch (java.nio.BufferOverflowException e) { + bufferSize *= 2; + } + } + } + + private static void mergeRankProperties(Ranking ranking, Consumer<StringProperty.Builder> stringProperties, + Consumer<TensorProperty.Builder> tensorProperties) { + MapConverter.convertMultiMap(ranking.getProperties().asMap(), propB -> { + if (!GetDocSumsPacket.sessionIdKey.equals(propB.getName())) { + stringProperties.accept(propB); + } + }, tensorProperties); + } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java index 32a7917d43c..2aa01b05955 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcClient.java @@ -22,7 +22,6 @@ import java.util.List; * @author bratseth */ class RpcClient implements Client { - private final Supervisor supervisor = new Supervisor(new Transport()); @Override @@ -44,15 +43,15 @@ class RpcClient implements Client { } @Override - public void search(NodeConnection node, CompressionType compression, int uncompressedLength, byte[] compressedPayload, - RpcSearchInvoker responseReceiver, double timeoutSeconds) { - Request request = new Request("vespa.searchprotocol.search"); + public void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength, byte[] compressedPayload, + ResponseReceiver responseReceiver, double timeoutSeconds) { + Request request = new Request(rpcMethod); request.parameters().add(new Int8Value(compression.getCode())); request.parameters().add(new Int32Value(uncompressedLength)); request.parameters().add(new DataValue(compressedPayload)); RpcNodeConnection rpcNode = ((RpcNodeConnection) node); - rpcNode.invokeAsync(request, timeoutSeconds, new RpcSearchResponseWaiter(rpcNode, responseReceiver)); + rpcNode.invokeAsync(request, timeoutSeconds, new RpcProtobufResponseWaiter(rpcNode, responseReceiver)); } private static class RpcNodeConnection implements NodeConnection { @@ -111,40 +110,37 @@ class RpcClient implements Client { @Override public void handleRequestDone(Request requestWithResponse) { if (requestWithResponse.isError()) { - handler.receive(GetDocsumsResponseOrError.fromError("Error response from " + node + ": " + - requestWithResponse.errorMessage())); + handler.receive(ResponseOrError.fromError("Error response from " + node + ": " + requestWithResponse.errorMessage())); return; } Values returnValues = requestWithResponse.returnValues(); if (returnValues.size() < 3) { - handler.receive(GetDocsumsResponseOrError.fromError("Invalid getDocsums response from " + node + - ": Expected 3 return arguments, got " + - returnValues.size())); + handler.receive(ResponseOrError.fromError( + "Invalid getDocsums response from " + node + ": Expected 3 return arguments, got " + returnValues.size())); return; } byte compression = returnValues.get(0).asInt8(); int uncompressedSize = returnValues.get(1).asInt32(); byte[] compressedSlimeBytes = returnValues.get(2).asData(); + @SuppressWarnings("unchecked") // TODO: Non-protobuf rpc docsums to be removed soon List<FastHit> hits = (List<FastHit>) requestWithResponse.getContext(); - handler.receive(GetDocsumsResponseOrError.fromResponse(new GetDocsumsResponse(compression, - uncompressedSize, - compressedSlimeBytes, - hits))); + handler.receive( + ResponseOrError.fromResponse(new GetDocsumsResponse(compression, uncompressedSize, compressedSlimeBytes, hits))); } } - private static class RpcSearchResponseWaiter implements RequestWaiter { + private static class RpcProtobufResponseWaiter implements RequestWaiter { /** The node to which we made the request we are waiting for - for error messages only */ private final RpcNodeConnection node; /** The handler to which the response is forwarded */ - private final RpcSearchInvoker handler; + private final ResponseReceiver handler; - public RpcSearchResponseWaiter(RpcNodeConnection node, RpcSearchInvoker handler) { + public RpcProtobufResponseWaiter(RpcNodeConnection node, ResponseReceiver handler) { this.node = node; this.handler = handler; } @@ -152,13 +148,13 @@ class RpcClient implements Client { @Override public void handleRequestDone(Request requestWithResponse) { if (requestWithResponse.isError()) { - handler.receive(SearchResponseOrError.fromError("Error response from " + node + ": " + requestWithResponse.errorMessage())); + handler.receive(ResponseOrError.fromError("Error response from " + node + ": " + requestWithResponse.errorMessage())); return; } Values returnValues = requestWithResponse.returnValues(); if (returnValues.size() < 3) { - handler.receive(SearchResponseOrError.fromError( + handler.receive(ResponseOrError.fromError( "Invalid getDocsums response from " + node + ": Expected 3 return arguments, got " + returnValues.size())); return; } @@ -166,7 +162,7 @@ class RpcClient implements Client { byte compression = returnValues.get(0).asInt8(); int uncompressedSize = returnValues.get(1).asInt32(); byte[] compressedPayload = returnValues.get(2).asData(); - handler.receive(SearchResponseOrError.fromResponse(new SearchResponse(compression, uncompressedSize, compressedPayload))); + handler.receive(ResponseOrError.fromResponse(new ProtobufResponse(compression, uncompressedSize, compressedPayload))); } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java index b7286997514..760f7486923 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcFillInvoker.java @@ -13,6 +13,7 @@ import com.yahoo.prelude.fastsearch.TimeoutException; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.dispatch.FillInvoker; +import com.yahoo.search.dispatch.rpc.Client.GetDocsumsResponse; import com.yahoo.search.query.SessionId; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.Hit; @@ -102,7 +103,7 @@ public class RpcFillInvoker extends FillInvoker { Client.NodeConnection node = resourcePool.nodeConnections().get(nodeId); if (node == null) { String error = "Could not fill hits from unknown node " + nodeId; - responseReceiver.receive(Client.GetDocsumsResponseOrError.fromError(error)); + responseReceiver.receive(Client.ResponseOrError.fromError(error)); result.hits().addError(ErrorMessage.createEmptyDocsums(error)); log.warning("Got hits with partid " + nodeId + ", which is not included in the current dispatch config"); return; @@ -143,7 +144,7 @@ public class RpcFillInvoker extends FillInvoker { /** Receiver of the responses to a set of getDocsums requests */ public static class GetDocsumsResponseReceiver { - private final BlockingQueue<Client.GetDocsumsResponseOrError> responses; + private final BlockingQueue<Client.ResponseOrError<GetDocsumsResponse>> responses; private final Compressor compressor; private final Result result; @@ -161,7 +162,7 @@ public class RpcFillInvoker extends FillInvoker { } /** Called by a thread belonging to the client when a valid response becomes available */ - public void receive(Client.GetDocsumsResponseOrError response) { + public void receive(Client.ResponseOrError<GetDocsumsResponse> response) { responses.add(response); } @@ -181,7 +182,7 @@ public class RpcFillInvoker extends FillInvoker { if (timeLeftMs <= 0) { throwTimeout(); } - Client.GetDocsumsResponseOrError response = responses.poll(timeLeftMs, TimeUnit.MILLISECONDS); + Client.ResponseOrError<GetDocsumsResponse> response = responses.poll(timeLeftMs, TimeUnit.MILLISECONDS); if (response == null) throwTimeout(); skippedHits += processResponse(response, summaryClass, documentDb); @@ -197,7 +198,7 @@ public class RpcFillInvoker extends FillInvoker { } } - private int processResponse(Client.GetDocsumsResponseOrError responseOrError, + private int processResponse(Client.ResponseOrError<GetDocsumsResponse> responseOrError, String summaryClass, DocumentDatabase documentDb) { if (responseOrError.error().isPresent()) { diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcInvokerFactory.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcInvokerFactory.java index f17e7d63431..c1b164aaeaa 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcInvokerFactory.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcInvokerFactory.java @@ -6,6 +6,7 @@ import com.yahoo.prelude.fastsearch.VespaBackEndSearcher; import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; import com.yahoo.search.Result; +import com.yahoo.search.dispatch.Dispatcher; import com.yahoo.search.dispatch.FillInvoker; import com.yahoo.search.dispatch.InvokerFactory; import com.yahoo.search.dispatch.SearchInvoker; @@ -36,8 +37,15 @@ public class RpcInvokerFactory extends InvokerFactory { @Override public Optional<FillInvoker> createFillInvoker(VespaBackEndSearcher searcher, Result result) { Query query = result.getQuery(); + + boolean summaryNeedsQuery = searcher.summaryNeedsQuery(query); + + if(query.properties().getBoolean(Dispatcher.dispatchProtobuf, false)) { + return Optional.of(new RpcProtobufFillInvoker(rpcResourcePool, searcher.getDocumentDatabase(query), searcher.getServerId(), + summaryNeedsQuery)); + } if (query.properties().getBoolean(dispatchSummaries, true) - && ! searcher.summaryNeedsQuery(query) + && ! summaryNeedsQuery && query.getRanking().getLocation() == null) { return Optional.of(new RpcFillInvoker(rpcResourcePool, searcher.getDocumentDatabase(query))); diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java new file mode 100644 index 00000000000..317586d963d --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcProtobufFillInvoker.java @@ -0,0 +1,227 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch.rpc; + +import ai.vespa.searchlib.searchprotocol.protobuf.SearchProtocol; +import com.google.protobuf.InvalidProtocolBufferException; +import com.yahoo.collections.ListMap; +import com.yahoo.collections.Pair; +import com.yahoo.compress.CompressionType; +import com.yahoo.compress.Compressor; +import com.yahoo.container.protect.Error; +import com.yahoo.data.access.Inspector; +import com.yahoo.data.access.slime.SlimeAdapter; +import com.yahoo.prelude.fastsearch.DocumentDatabase; +import com.yahoo.prelude.fastsearch.FastHit; +import com.yahoo.prelude.fastsearch.TimeoutException; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.dispatch.FillInvoker; +import com.yahoo.search.dispatch.rpc.Client.ProtobufResponse; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.result.Hit; +import com.yahoo.slime.ArrayTraverser; +import com.yahoo.slime.BinaryFormat; + +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * {@link FillInvoker} implementation using Protobuf over JRT + * + * @author bratseth + * @author ollivir + */ +public class RpcProtobufFillInvoker extends FillInvoker { + private final String RPC_METHOD = "vespa.searchprotocol.getDocsums"; + + private static final Logger log = Logger.getLogger(RpcProtobufFillInvoker.class.getName()); + + private final DocumentDatabase documentDb; + private final RpcResourcePool resourcePool; + private final boolean summaryNeedsQuery; + private final String serverId; + + private BlockingQueue<Pair<Client.ResponseOrError<ProtobufResponse>, List<FastHit>>> responses; + + /** Whether we have already logged/notified about an error - to avoid spamming */ + private boolean hasReportedError = false; + + /** The number of responses we should receive (and process) before this is complete */ + private int outstandingResponses; + + RpcProtobufFillInvoker(RpcResourcePool resourcePool, DocumentDatabase documentDb, String serverId, boolean summaryNeedsQuery) { + this.documentDb = documentDb; + this.resourcePool = resourcePool; + this.serverId = serverId; + this.summaryNeedsQuery = summaryNeedsQuery; + } + + @Override + protected void sendFillRequest(Result result, String summaryClass) { + ListMap<Integer, FastHit> hitsByNode = hitsByNode(result); + + CompressionType compression = CompressionType + .valueOf(result.getQuery().properties().getString(RpcResourcePool.dispatchCompression, "LZ4").toUpperCase()); + + result.getQuery().trace(false, 5, "Sending ", hitsByNode.size(), " summary fetch requests with jrt/protobuf"); + + outstandingResponses = hitsByNode.size(); + responses = new LinkedBlockingQueue<>(outstandingResponses); + + var builder = ProtobufSerialization.createDocsumRequestBuilder(result.getQuery(), serverId, summaryClass, summaryNeedsQuery); + for (Map.Entry<Integer, List<FastHit>> nodeHits : hitsByNode.entrySet()) { + var payload = ProtobufSerialization.serializeDocsumRequest(builder, nodeHits.getValue()); + sendDocsumsRequest(nodeHits.getKey(), nodeHits.getValue(), payload, compression, result); + } + } + + @Override + protected void getFillResults(Result result, String summaryClass) { + try { + processResponses(result, summaryClass); + result.hits().setSorted(false); + result.analyzeHits(); + } catch (TimeoutException e) { + result.hits().addError(ErrorMessage.createTimeout("Summary data is incomplete: " + e.getMessage())); + } + } + + @Override + protected void release() { + // nothing to release + } + + /** Called by a thread belonging to the client when a valid response becomes available */ + public void receive(Client.ResponseOrError<ProtobufResponse> response, List<FastHit> hitsContext) { + responses.add(new Pair<>(response, hitsContext)); + } + + /** Return a map of hits by their search node (partition) id */ + private static ListMap<Integer, FastHit> hitsByNode(Result result) { + ListMap<Integer, FastHit> hitsByNode = new ListMap<>(); + for (Iterator<Hit> i = result.hits().unorderedDeepIterator(); i.hasNext();) { + Hit h = i.next(); + if (!(h instanceof FastHit)) + continue; + FastHit hit = (FastHit) h; + + hitsByNode.put(hit.getDistributionKey(), hit); + } + return hitsByNode; + } + + /** Send a docsums request to a node. Responses will be added to the given receiver. */ + private void sendDocsumsRequest(int nodeId, List<FastHit> hits, byte[] payload, CompressionType compression, Result result) { + Client.NodeConnection node = resourcePool.nodeConnections().get(nodeId); + if (node == null) { + String error = "Could not fill hits from unknown node " + nodeId; + receive(Client.ResponseOrError.fromError(error), hits); + result.hits().addError(ErrorMessage.createEmptyDocsums(error)); + log.warning("Got hits with partid " + nodeId + ", which is not included in the current dispatch config"); + return; + } + + Query query = result.getQuery(); + double timeoutSeconds = ((double) query.getTimeLeft() - 3.0) / 1000.0; + Compressor.Compression compressionResult = resourcePool.compressor().compress(compression, payload); + resourcePool.client().request(RPC_METHOD, node, compressionResult.type(), payload.length, compressionResult.data(), + roe -> receive(roe, hits), timeoutSeconds); + } + + private void processResponses(Result result, String summaryClass) throws TimeoutException { + try { + int skippedHits = 0; + while (outstandingResponses > 0) { + long timeLeftMs = result.getQuery().getTimeLeft(); + if (timeLeftMs <= 0) { + throwTimeout(); + } + var responseAndHits = responses.poll(timeLeftMs, TimeUnit.MILLISECONDS); + if (responseAndHits == null) { + throwTimeout(); + } + var response = responseAndHits.getFirst(); + var hitsContext = responseAndHits.getSecond(); + skippedHits += processResponse(result, response, hitsContext, summaryClass); + outstandingResponses--; + } + if (skippedHits != 0) { + result.hits().addError(ErrorMessage + .createEmptyDocsums("Missing hit summary data for summary " + summaryClass + " for " + skippedHits + " hits")); + } + } catch (InterruptedException e) { + // TODO: Add error + } + } + + private int processResponse(Result result, Client.ResponseOrError<ProtobufResponse> responseOrError, List<FastHit> hitsContext, + String summaryClass) { + if (responseOrError.error().isPresent()) { + if (hasReportedError) { + return 0; + } + String error = responseOrError.error().get(); + result.hits().addError(ErrorMessage.createBackendCommunicationError(error)); + log.log(Level.WARNING, "Error fetching summary data: " + error); + hasReportedError = true; + } else { + Client.ProtobufResponse response = responseOrError.response().get(); + CompressionType compression = CompressionType.valueOf(response.compression()); + byte[] responseBytes = resourcePool.compressor().decompress(response.compressedPayload(), compression, + response.uncompressedSize()); + return fill(result, hitsContext, summaryClass, responseBytes); + } + return 0; + } + + private void addErrors(Result result, com.yahoo.slime.Inspector errors) { + errors.traverse((ArrayTraverser) (index, value) -> { + int errorCode = ("timeout".equalsIgnoreCase(value.field("type").asString())) ? Error.TIMEOUT.code : Error.UNSPECIFIED.code; + result.hits().addError(new ErrorMessage(errorCode, value.field("message").asString(), value.field("details").asString())); + }); + } + + private int fill(Result result, List<FastHit> hits, String summaryClass, byte[] payload) { + try { + var protobuf = SearchProtocol.DocsumReply.parseFrom(payload); + var root = BinaryFormat.decode(protobuf.getSlimeSummaries().toByteArray()).get(); + var errors = root.field("errors"); + boolean hasErrors = errors.valid() && (errors.entries() > 0); + if (hasErrors) { + addErrors(result, errors); + } + + Inspector summaries = new SlimeAdapter(root.field("docsums")); + if (!summaries.valid()) { + return 0; // No summaries; Perhaps we requested a non-existing summary class + } + int skippedHits = 0; + for (int i = 0; i < hits.size(); i++) { + Inspector summary = summaries.entry(i).field("docsum"); + if (summary.fieldCount() != 0) { + hits.get(i).setField(Hit.SDDOCNAME_FIELD, documentDb.getName()); + hits.get(i).addSummary(documentDb.getDocsumDefinitionSet().getDocsum(summaryClass), summary); + hits.get(i).setFilled(summaryClass); + } else { + skippedHits++; + } + } + return skippedHits; + } catch (InvalidProtocolBufferException ex) { + log.log(Level.WARNING, "Invalid response to docsum request", ex); + result.hits().addError(ErrorMessage.createInternalServerError("Invalid response to docsum request from backend")); + return 0; + } + } + + private void throwTimeout() throws TimeoutException { + throw new TimeoutException("Timed out waiting for summary data. " + outstandingResponses + " responses outstanding."); + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java index 88d77c760e3..f8a5ee2f8c1 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/RpcSearchInvoker.java @@ -9,7 +9,7 @@ import com.yahoo.prelude.fastsearch.VespaBackEndSearcher; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.dispatch.SearchInvoker; -import com.yahoo.search.dispatch.rpc.Client.SearchResponse; +import com.yahoo.search.dispatch.rpc.Client.ProtobufResponse; import com.yahoo.search.dispatch.searchcluster.Node; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.searchchain.Execution; @@ -25,11 +25,13 @@ import java.util.concurrent.TimeUnit; * * @author ollivir */ -public class RpcSearchInvoker extends SearchInvoker { +public class RpcSearchInvoker extends SearchInvoker implements Client.ResponseReceiver { + private final String RPC_METHOD = "vespa.searchprotocol.search"; + private final VespaBackEndSearcher searcher; private final Node node; private final RpcResourcePool resourcePool; - private final BlockingQueue<Client.SearchResponseOrError> responses; + private final BlockingQueue<Client.ResponseOrError<ProtobufResponse>> responses; private Query query; @@ -50,15 +52,16 @@ public class RpcSearchInvoker extends SearchInvoker { Client.NodeConnection nodeConnection = resourcePool.nodeConnections().get(node.key()); if (nodeConnection == null) { - responses.add(Client.SearchResponseOrError.fromError("Could send search to unknown node " + node.key())); + responses.add(Client.ResponseOrError.fromError("Could not send search to unknown node " + node.key())); responseAvailable(); return; } + query.trace(false, 5, "Sending search request with jrt/protobuf to node with dist key ", node.key()); - var payload = ProtobufSerialization.serializeQuery(query, searcher.getServerId(), true); + var payload = ProtobufSerialization.serializeSearchRequest(query, searcher.getServerId()); double timeoutSeconds = ((double) query.getTimeLeft() - 3.0) / 1000.0; Compressor.Compression compressionResult = resourcePool.compressor().compress(compression, payload); - resourcePool.client().search(nodeConnection, compressionResult.type(), payload.length, compressionResult.data(), this, + resourcePool.client().request(RPC_METHOD, nodeConnection, compressionResult.type(), payload.length, compressionResult.data(), this, timeoutSeconds); } @@ -68,7 +71,7 @@ public class RpcSearchInvoker extends SearchInvoker { if (timeLeftMs <= 0) { return errorResult(query, ErrorMessage.createTimeout("Timeout while waiting for " + getName())); } - Client.SearchResponseOrError response = null; + Client.ResponseOrError<ProtobufResponse> response = null; try { response = responses.poll(timeLeftMs, TimeUnit.MILLISECONDS); } catch (InterruptedException e) { @@ -84,11 +87,11 @@ public class RpcSearchInvoker extends SearchInvoker { return errorResult(query, ErrorMessage.createInternalServerError("Neither error nor result available")); } - SearchResponse searchResponse = response.response().get(); - CompressionType compression = CompressionType.valueOf(searchResponse.compression()); - byte[] payload = resourcePool.compressor().decompress(searchResponse.compressedPayload(), compression, - searchResponse.uncompressedSize()); - var result = ProtobufSerialization.deserializeToResult(payload, query, searcher); + ProtobufResponse protobufResponse = response.response().get(); + CompressionType compression = CompressionType.valueOf(protobufResponse.compression()); + byte[] payload = resourcePool.compressor().decompress(protobufResponse.compressedPayload(), compression, + protobufResponse.uncompressedSize()); + var result = ProtobufSerialization.deserializeToSearchResult(payload, query, searcher); result.hits().unorderedIterator().forEachRemaining(hit -> { if(hit instanceof FastHit) { FastHit fhit = (FastHit) hit; @@ -106,7 +109,7 @@ public class RpcSearchInvoker extends SearchInvoker { // nothing to release } - public void receive(Client.SearchResponseOrError response) { + public void receive(Client.ResponseOrError<ProtobufResponse> response) { responses.add(response); responseAvailable(); } diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java index f9b628e594a..687d3e728c0 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MockClient.java @@ -7,9 +7,6 @@ import com.yahoo.document.GlobalId; import com.yahoo.document.idstring.IdIdString; import com.yahoo.prelude.fastsearch.FastHit; import com.yahoo.search.Result; -import com.yahoo.search.dispatch.rpc.Client; -import com.yahoo.search.dispatch.rpc.RpcFillInvoker; -import com.yahoo.search.dispatch.rpc.RpcSearchInvoker; import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.BinaryFormat; import com.yahoo.slime.Cursor; @@ -44,7 +41,7 @@ public class MockClient implements Client { int uncompressedSize, byte[] compressedSlime, RpcFillInvoker.GetDocsumsResponseReceiver responseReceiver, double timeoutSeconds) { if (malfunctioning) { - responseReceiver.receive(GetDocsumsResponseOrError.fromError("Malfunctioning")); + responseReceiver.receive(ResponseOrError.fromError("Malfunctioning")); return; } @@ -74,25 +71,25 @@ public class MockClient implements Client { Compressor.Compression compressionResult = compressor.compress(compression, slimeBytes); GetDocsumsResponse response = new GetDocsumsResponse(compressionResult.type().getCode(), slimeBytes.length, compressionResult.data(), hitsContext); - responseReceiver.receive(GetDocsumsResponseOrError.fromResponse(response)); + responseReceiver.receive(ResponseOrError.fromResponse(response)); } @Override - public void search(NodeConnection node, CompressionType compression, int uncompressedLength, byte[] compressedPayload, - RpcSearchInvoker responseReceiver, double timeoutSeconds) { + public void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength, byte[] compressedPayload, + ResponseReceiver responseReceiver, double timeoutSeconds) { if (malfunctioning) { - responseReceiver.receive(SearchResponseOrError.fromError("Malfunctioning")); + responseReceiver.receive(ResponseOrError.fromError("Malfunctioning")); return; } if(searchResult == null) { - responseReceiver.receive(SearchResponseOrError.fromError("No result defined")); + responseReceiver.receive(ResponseOrError.fromError("No result defined")); return; } var payload = ProtobufSerialization.serializeResult(searchResult); var compressionResult = compressor.compress(compression, payload); - var response = new SearchResponse(compressionResult.type().getCode(), payload.length, compressionResult.data()); - responseReceiver.receive(SearchResponseOrError.fromResponse(response)); + var response = new ProtobufResponse(compressionResult.type().getCode(), payload.length, compressionResult.data()); + responseReceiver.receive(ResponseOrError.fromResponse(response)); } public void setDocsumReponse(String nodeId, int docId, String docsumClass, Map<String, Object> docsumValues) { diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java index 689be53de23..b9d894f9eb5 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/RpcSearchInvokerTest.java @@ -54,8 +54,8 @@ public class RpcSearchInvokerTest { AtomicInteger lengthHolder) { return new Client() { @Override - public void search(NodeConnection node, CompressionType compression, int uncompressedLength, byte[] compressedPayload, - RpcSearchInvoker responseReceiver, double timeoutSeconds) { + public void request(String rpcMethod, NodeConnection node, CompressionType compression, int uncompressedLength, + byte[] compressedPayload, ResponseReceiver responseReceiver, double timeoutSeconds) { compressionTypeHolder.set(compression); payloadHolder.set(compressedPayload); lengthHolder.set(uncompressedLength); |