diff options
27 files changed, 683 insertions, 153 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 59aa5b3ba53..259ac5227ae 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -202,8 +202,9 @@ public class ConvertedModel { // Add expressions Map<String, ExpressionFunction> expressions = new HashMap<>(); - for (Pair<String, ExpressionFunction> output : model.outputExpressions()) { - addExpression(output.getSecond(), output.getFirst(), + for (ImportedModel.ImportedFunction outputFunction : model.outputExpressions()) { + ExpressionFunction expression = asExpressionFunction(outputFunction); + addExpression(expression, expression.getName(), constantsReplacedByFunctions, model, store, profile, queryProfiles, expressions); @@ -218,6 +219,23 @@ public class ConvertedModel { return expressions; } + /** Converts an imported function into an ExpressionFunction, parsing its expression and its type specs; a ParseException is rethrown as IllegalArgumentException. */ + private static ExpressionFunction asExpressionFunction(ImportedModel.ImportedFunction function) { + try { + Map<String, TensorType> argumentTypes = new HashMap<>(); + for (Map.Entry<String, String> entry : function.argumentTypes().entrySet()) + argumentTypes.put(entry.getKey(), TensorType.fromSpec(entry.getValue())); + + return new ExpressionFunction(function.name(), + function.arguments(), + new RankingExpression(function.expression()), + argumentTypes, + function.returnType().map(TensorType::fromSpec)); + } + catch (ParseException e) { + throw new IllegalArgumentException("Got an illegal argument from importing " + function.name(), e); + } + } + private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, @@ -248,7 +266,9 @@ public class ConvertedModel { return store.readExpressions(); } - private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + private static 
void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, + String constantValueString) { + Tensor constantValue = Tensor.from(constantValueString); store.writeSmallConstant(constantName, constantValue); profile.addConstant(constantName, asValue(constantValue)); } @@ -258,7 +278,8 @@ public class ConvertedModel { QueryProfileRegistry queryProfiles, Set<String> constantsReplacedByFunctions, String constantName, - Tensor constantValue) { + String constantValueString) { + Tensor constantValue = Tensor.from(constantValueString); RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName); if (rankingExpressionFunctionOverridingConstant != null) { TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles)); @@ -306,14 +327,14 @@ public class ConvertedModel { Set<String> functionNames = new HashSet<>(); addFunctionNamesIn(expression.getRoot(), functionNames, model); for (String functionName : functionNames) { - TensorType requiredType = model.inputs().get(functionName); - if (requiredType == null) continue; // Not a required function + Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); + if ( ! 
requiredType.isPresent()) continue; // Not a required function RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); if (rankingExpressionFunction == null) throw new IllegalArgumentException("Model refers input '" + functionName + - "' of type " + requiredType + " but this function is not present in " + - profile); + "' of type " + requiredType.get() + + " but this function is not present in " + profile); // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second // phase and summary features), as it may only resolve correctly given those bindings // Or, probably better, annotate the functions with type constraints here and verify during general @@ -321,12 +342,12 @@ public class ConvertedModel { TensorType actualType = rankingExpressionFunction.function().getBody().getRoot().type(profile.typeContext(queryProfiles)); if ( actualType == null) throw new IllegalArgumentException("Model refers input '" + functionName + - "' of type " + requiredType + + "' of type " + requiredType.get() + " which must be produced by a function in the rank profile, but " + "this function references a feature which is not declared"); - if ( ! actualType.isAssignableTo(requiredType)) + if ( ! actualType.isAssignableTo(requiredType.get())) throw new IllegalArgumentException("Model refers input '" + functionName + "'. 
" + - typeMismatchExplanation(requiredType, actualType)); + typeMismatchExplanation(requiredType.get(), actualType)); } } @@ -339,7 +360,7 @@ public class ConvertedModel { /** Add the generated functions to the rank profile */ private static void addGeneratedFunctions(ImportedModel model, RankProfile profile) { - model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, v.copy())); + model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, RankingExpression.from(v))); } /** @@ -383,7 +404,7 @@ public class ConvertedModel { List<ExpressionNode> children = ((TensorFunctionNode)node).children(); if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.inputs().containsKey(referenceNode.getName())) { + if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { return reduceBatchDimensionExpression(tensorFunction, typeContext); } } @@ -391,7 +412,7 @@ public class ConvertedModel { } if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - if (model.inputs().containsKey(referenceNode.getName())) { + if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); } } @@ -487,7 +508,7 @@ public class ConvertedModel { if (referenceNode.getOutput() == null) { // function references cannot specify outputs names.add(referenceNode.getName()); if (model.functions().containsKey(referenceNode.getName())) { - addFunctionNamesIn(model.functions().get(referenceNode.getName()).getRoot(), names, model); + addFunctionNamesIn(RankingExpression.from(model.functions().get(referenceNode.getName())).getRoot(), names, model); } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java index 
a1b3a0cd9c0..b3b530448fc 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java @@ -422,6 +422,15 @@ public class IndexedSearchCluster extends SearchCluster } builder.maxNodesDownPerGroup(rootDispatch.getMaxNodesDownPerFixedRow()); builder.useMultilevelDispatch(useMultilevelDispatchSetup()); + builder.searchableCopies(rootDispatch.getSearchableCopies()); + if (searchCoverage != null) { + if (searchCoverage.getMinimum() != null) + builder.minSearchCoverage(searchCoverage.getMinimum()); + if (searchCoverage.getMinWaitAfterCoverageFactor() != null) + builder.minWaitAfterCoverageFactor(searchCoverage.getMinWaitAfterCoverageFactor()); + if (searchCoverage.getMaxWaitAfterCoverageFactor() != null) + builder.maxWaitAfterCoverageFactor(searchCoverage.getMaxWaitAfterCoverageFactor()); + } } @Override diff --git a/configdefinitions/src/vespa/dispatch.def b/configdefinitions/src/vespa/dispatch.def index dce6098ee9b..50989c3ef74 100644 --- a/configdefinitions/src/vespa/dispatch.def +++ b/configdefinitions/src/vespa/dispatch.def @@ -19,6 +19,18 @@ distributionPolicy enum { ROUNDROBIN, ADAPTIVE } default=ROUNDROBIN # Is multi-level dispatch configured for this cluster useMultilevelDispatch bool default=false +# Number of document copies +searchableCopies long default=1 + +# Minimum search coverage required before returning the results of a query +minSearchCoverage double default=100 + +# Minimum wait time for full coverage after minimum coverage is achieved, factored based on time left at minimum coverage +minWaitAfterCoverageFactor double default=0 + +# Maximum wait time for full coverage after minimum coverage is achieved, factored based on time left at minimum coverage +maxWaitAfterCoverageFactor double default=1 + # The unique key of a search node node[].key int diff --git a/container-core/src/main/java/com/yahoo/container/handler/Coverage.java 
b/container-core/src/main/java/com/yahoo/container/handler/Coverage.java index 4a937068d81..84cc0734e7c 100644 --- a/container-core/src/main/java/com/yahoo/container/handler/Coverage.java +++ b/container-core/src/main/java/com/yahoo/container/handler/Coverage.java @@ -28,9 +28,9 @@ public class Coverage { EXPLICITLY_FULL, EXPLICITLY_INCOMPLETE, DOCUMENT_COUNT; } - private final static int DEGRADED_BY_MATCH_PHASE = 1; - private final static int DEGRADED_BY_TIMEOUT = 2; - private final static int DEGRADED_BY_ADAPTIVE_TIMEOUT = 4; + public final static int DEGRADED_BY_MATCH_PHASE = 1; + public final static int DEGRADED_BY_TIMEOUT = 2; + public final static int DEGRADED_BY_ADAPTIVE_TIMEOUT = 4; /** * Build an invalid instance to initiate manually. diff --git a/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java b/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java index de4d9c9fe8b..f40550f1f70 100644 --- a/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java +++ b/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java @@ -1,6 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.fs4.mplex; +import com.yahoo.concurrent.SystemTimer; +import com.yahoo.fs4.BasicPacket; +import com.yahoo.fs4.ChannelTimeoutException; +import com.yahoo.fs4.Packet; +import com.yahoo.search.Query; +import com.yahoo.search.dispatch.ResponseMonitor; + import java.io.IOException; import java.util.ArrayList; import java.util.List; @@ -9,12 +16,6 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; -import com.yahoo.concurrent.SystemTimer; -import com.yahoo.fs4.BasicPacket; -import com.yahoo.fs4.ChannelTimeoutException; -import com.yahoo.fs4.Packet; -import com.yahoo.search.Query; - /** * This class is used to represent a "channel" in the FS4 protocol. 
* A channel represents a session between a client and the fdispatch. @@ -34,6 +35,7 @@ public class FS4Channel { volatile private BlockingQueue<BasicPacket> responseQueue; private Query query; private boolean isPingChannel = false; + private ResponseMonitor<FS4Channel> monitor; /** for unit testing. do not use */ protected FS4Channel () { @@ -197,6 +199,9 @@ public class FS4Channel { throws InterruptedException, InvalidChannelException { ensureValidQ().put(packet); + if(monitor != null) { + monitor.responseAvailable(this); + } } /** @@ -241,4 +246,7 @@ public class FS4Channel { return "fs4 channel " + channelId + (isValid() ? " [valid]" : " [invalid]"); } + public void setResponseMonitor(ResponseMonitor<FS4Channel> monitor) { + this.monitor = monitor; + } } diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java index f5d082635ab..8fa8bdb66bf 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java @@ -50,7 +50,7 @@ public class FS4InvokerFactory { public SearchInvoker getSearchInvoker(Query query, Node node) { Backend backend = fs4ResourcePool.getBackend(node.hostname(), node.fs4port()); - return new FS4SearchInvoker(searcher, query, backend.openChannel(), node); + return new FS4SearchInvoker(searcher, query, backend.openChannel(), Optional.of(node)); } /** @@ -70,14 +70,14 @@ public class FS4InvokerFactory { * list is invalid and the remaining coverage is not sufficient */ public Optional<SearchInvoker> getSearchInvoker(Query query, int groupId, List<Node> nodes, boolean acceptIncompleteCoverage) { - Map<Integer, SearchInvoker> invokers = new HashMap<>(); + List<SearchInvoker> invokers = new ArrayList<>(nodes.size()); Set<Integer> failed = null; for (Node node : nodes) { boolean nodeAdded = false; if 
(node.isWorking()) { Backend backend = fs4ResourcePool.getBackend(node.hostname(), node.fs4port()); if (backend.probeConnection()) { - invokers.put(node.key(), new FS4SearchInvoker(searcher, query, backend.openChannel(), node)); + /* append; node.key() is not a valid list position */ + invokers.add(new FS4SearchInvoker(searcher, query, backend.openChannel(), Optional.of(node))); nodeAdded = true; } } @@ -99,7 +99,7 @@ public class FS4InvokerFactory { } if (!searchCluster.isPartialGroupCoverageSufficient(groupId, success)) { if (acceptIncompleteCoverage) { - createCoverageErrorInvoker(invokers, nodes, failed); + invokers.add(createCoverageErrorInvoker(nodes, failed)); } else { return Optional.empty(); } @@ -107,13 +107,13 @@ public class FS4InvokerFactory { } if (invokers.size() == 1) { - return Optional.of(invokers.values().iterator().next()); + return Optional.of(invokers.get(0)); } else { - return Optional.of(new InterleavedSearchInvoker(invokers)); + return Optional.of(new InterleavedSearchInvoker(invokers, searchCluster)); } } - private void createCoverageErrorInvoker(Map<Integer, SearchInvoker> invokers, List<Node> nodes, Set<Integer> failed) { + private SearchInvoker createCoverageErrorInvoker(List<Node> nodes, Set<Integer> failed) { long activeDocuments = 0; StringBuilder down = new StringBuilder("Connection failure on nodes with distribution-keys: "); Integer key = null; @@ -129,7 +129,8 @@ public class FS4InvokerFactory { } } Coverage coverage = new Coverage(0, activeDocuments, 0); - invokers.put(key, new SearchErrorInvoker(ErrorMessage.createBackendCommunicationError(down.toString()), coverage)); + coverage.setNodesTried(1); + return new SearchErrorInvoker(ErrorMessage.createBackendCommunicationError(down.toString()), coverage); } public FillInvoker getFillInvoker(Query query, Node node) { diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java index 98676890bdf..da32cfc4fda 
100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java @@ -6,14 +6,13 @@ import com.yahoo.fs4.ChannelTimeoutException; import com.yahoo.fs4.Packet; import com.yahoo.fs4.QueryPacket; import com.yahoo.fs4.QueryResultPacket; -import com.yahoo.fs4.mplex.Backend; import com.yahoo.fs4.mplex.FS4Channel; import com.yahoo.fs4.mplex.InvalidChannelException; import com.yahoo.search.Query; import com.yahoo.search.Result; +import com.yahoo.search.dispatch.ResponseMonitor; import com.yahoo.search.dispatch.SearchInvoker; import com.yahoo.search.dispatch.searchcluster.Node; -import com.yahoo.search.result.Coverage; import com.yahoo.search.result.ErrorMessage; import java.io.IOException; @@ -30,29 +29,21 @@ import static java.util.Arrays.asList; * * @author ollivir */ -public class FS4SearchInvoker extends SearchInvoker { +public class FS4SearchInvoker extends SearchInvoker implements ResponseMonitor<FS4Channel> { private final VespaBackEndSearcher searcher; private FS4Channel channel; - private final Optional<Node> node; private ErrorMessage pendingSearchError = null; private Query query = null; private QueryPacket queryPacket = null; - public FS4SearchInvoker(VespaBackEndSearcher searcher, Query query, FS4Channel channel, Node node) { + public FS4SearchInvoker(VespaBackEndSearcher searcher, Query query, FS4Channel channel, Optional<Node> node) { + super(node); this.searcher = searcher; - this.node = Optional.of(node); this.channel = channel; channel.setQuery(query); - } - - // fdispatch code path - public FS4SearchInvoker(VespaBackEndSearcher searcher, Query query, Backend backend) { - this.searcher = searcher; - this.node = Optional.empty(); - this.channel = backend.openChannel(); - channel.setQuery(query); + channel.setResponseMonitor(this); } @Override @@ -68,6 +59,8 @@ public class FS4SearchInvoker extends SearchInvoker { this.query = query; 
this.queryPacket = queryPacket; + channel.setResponseMonitor(this); + try { boolean couldSend = channel.sendPacket(queryPacket); if (!couldSend) { @@ -115,7 +108,7 @@ public class FS4SearchInvoker extends SearchInvoker { searcher.addMetaInfo(query, queryPacket.getQueryPacketData(), resultPacket, result); - searcher.addUnfilledHits(result, resultPacket.getDocuments(), false, queryPacket.getQueryPacketData(), cacheKey, node.map(Node::key)); + searcher.addUnfilledHits(result, resultPacket.getDocuments(), false, queryPacket.getQueryPacketData(), cacheKey, distributionKey()); Packet[] packets; CacheControl cacheControl = searcher.getCacheControl(); PacketWrapper packetWrapper = cacheControl.lookup(cacheKey, query); @@ -130,7 +123,7 @@ public class FS4SearchInvoker extends SearchInvoker { } else { packets = new Packet[1]; packets[0] = resultPacket; - cacheControl.cache(cacheKey, query, new DocsumPacketKey[0], packets, node.map(Node::key)); + cacheControl.cache(cacheKey, query, new DocsumPacketKey[0], packets, distributionKey()); } } return asList(result); @@ -138,10 +131,7 @@ public class FS4SearchInvoker extends SearchInvoker { private List<Result> errorResult(ErrorMessage errorMessage) { Result error = new Result(query, errorMessage); - node.ifPresent(n -> { - Coverage coverage = new Coverage(0, n.getActiveDocuments(), 0); - error.setCoverage(coverage); - }); + getErrorCoverage().ifPresent(error::setCoverage); return Arrays.asList(error); } @@ -164,4 +154,9 @@ public class FS4SearchInvoker extends SearchInvoker { private boolean isLoggingFine() { return getLogger().isLoggable(Level.FINE); } + + @Override + public void responseAvailable(FS4Channel from) { + responseAvailable(); + } } diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java index a98c34295ee..209f6faefa0 100644 --- 
a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java @@ -222,7 +222,7 @@ public class FastSearcher extends VespaBackEndSearcher { if(direct.isPresent()) { return fs4InvokerFactory.getSearchInvoker(query, direct.get()); } - return new FS4SearchInvoker(this, query, dispatchBackend); + return new FS4SearchInvoker(this, query, dispatchBackend.openChannel(), Optional.empty()); } /** @@ -284,6 +284,7 @@ public class FastSearcher extends VespaBackEndSearcher { result.hits().addAll(partialResult.hits().asUnorderedHits()); } if (finalCoverage != null) { + adjustCoverageDegradedReason(finalCoverage); result.setCoverage(finalCoverage); } @@ -301,6 +302,18 @@ public class FastSearcher extends VespaBackEndSearcher { return result; } + private void adjustCoverageDegradedReason(Coverage coverage) { + int asked = coverage.getNodesTried(); + int answered = coverage.getNodes(); + if (asked > answered) { + int searchableCopies = (int) dispatcher.searchCluster().dispatchConfig().searchableCopies(); + int missingNodes = (asked - answered) - (searchableCopies - 1); + if (missingNodes > 0) { + coverage.setDegradedReason(com.yahoo.container.handler.Coverage.DEGRADED_BY_TIMEOUT); + } + } + } + private static @NonNull Optional<String> quotedSummaryClass(String summaryClass) { return Optional.of(summaryClass == null ? 
"[null]" : quote(summaryClass)); } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java index 0382f47457e..1ca64be7924 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java @@ -20,7 +20,6 @@ import com.yahoo.vespa.config.search.DispatchConfig; import java.util.Arrays; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Set; @@ -134,7 +133,7 @@ public class Dispatcher extends AbstractComponent { int max = Integer.min(searchCluster.orderedGroups().size(), MAX_GROUP_SELECTION_ATTEMPTS); Set<Integer> rejected = null; for (int i = 0; i < max; i++) { - Optional<Group> groupInCluster = loadBalancer.takeGroupForQuery(rejected); + Optional<Group> groupInCluster = loadBalancer.takeGroup(rejected); if (!groupInCluster.isPresent()) { // No groups available break; diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java index 9ff43df87cf..83647b332e6 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java @@ -5,12 +5,25 @@ import com.yahoo.fs4.QueryPacket; import com.yahoo.prelude.fastsearch.CacheKey; import com.yahoo.search.Query; import com.yahoo.search.Result; +import com.yahoo.search.dispatch.searchcluster.SearchCluster; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.vespa.config.search.DispatchConfig; import java.io.IOException; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.IdentityHashMap; import java.util.List; -import java.util.Map; +import 
java.util.Optional; +import java.util.Set; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static com.yahoo.container.handler.Coverage.DEGRADED_BY_ADAPTIVE_TIMEOUT; +import static com.yahoo.container.handler.Coverage.DEGRADED_BY_TIMEOUT; /** * InterleavedSearchInvoker uses multiple {@link SearchInvoker} objects to interface with content @@ -19,11 +32,25 @@ import java.util.Map; * * @author ollivir */ -public class InterleavedSearchInvoker extends SearchInvoker { - private final Collection<SearchInvoker> invokers; +public class InterleavedSearchInvoker extends SearchInvoker implements ResponseMonitor<SearchInvoker> { + private static final Logger log = Logger.getLogger(InterleavedSearchInvoker.class.getName()); + + private final Set<SearchInvoker> invokers; + private final SearchCluster searchCluster; + private final LinkedBlockingQueue<SearchInvoker> availableForProcessing; + private Query query; + + private boolean adaptiveTimeoutCalculated = false; + private long adaptiveTimeoutMin = 0; + private long adaptiveTimeoutMax = 0; + private long deadline = 0; - public InterleavedSearchInvoker(Map<Integer, SearchInvoker> invokers) { - this.invokers = new ArrayList<>(invokers.values()); + public InterleavedSearchInvoker(Collection<SearchInvoker> invokers, SearchCluster searchCluster) { + super(Optional.empty()); + this.invokers = Collections.newSetFromMap(new IdentityHashMap<>()); + this.invokers.addAll(invokers); + this.searchCluster = searchCluster; + this.availableForProcessing = newQueue(); } /** @@ -33,27 +60,109 @@ public class InterleavedSearchInvoker extends SearchInvoker { */ @Override protected void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException { + this.query = query; + invokers.forEach(invoker -> invoker.setMonitor(this)); + deadline = currentTime() + query.getTimeLeft(); + int originalHits = query.getHits(); int 
originalOffset = query.getOffset(); query.setHits(query.getHits() + query.getOffset()); query.setOffset(0); + for (SearchInvoker invoker : invokers) { invoker.sendSearchRequest(query, null); } + query.setHits(originalHits); query.setOffset(originalOffset); } @Override protected List<Result> getSearchResults(CacheKey cacheKey) throws IOException { + int requests = invokers.size(); + int responses = 0; List<Result> results = new ArrayList<>(); - for (SearchInvoker invoker : invokers) { - results.addAll(invoker.getSearchResults(cacheKey)); + long nextTimeout = query.getTimeLeft(); + try { + while (!invokers.isEmpty() && nextTimeout >= 0) { + SearchInvoker invoker = availableForProcessing.poll(nextTimeout, TimeUnit.MILLISECONDS); + if (invoker == null) { + if (log.isLoggable(Level.FINE)) { + log.fine("Search timed out with " + requests + " requests made, " + responses + " responses received"); + } + break; + } else { + invokers.remove(invoker); + results.addAll(invoker.getSearchResults(cacheKey)); + responses++; + } + nextTimeout = nextTimeout(requests, responses); + } + } catch (InterruptedException e) { + /* restore the interrupt status so callers can still observe the interruption */ + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while waiting for search results", e); } + + insertTimeoutErrors(results); return results; } + private void insertTimeoutErrors(List<Result> results) { + int degradedReason = adaptiveTimeoutCalculated ? 
DEGRADED_BY_ADAPTIVE_TIMEOUT : DEGRADED_BY_TIMEOUT; + + for (SearchInvoker invoker : invokers) { + Optional<Integer> dk = invoker.distributionKey(); + String message; + if (dk.isPresent()) { + message = "Backend communication timeout on node with distribution-key " + dk.get(); + } else { + message = "Backend communication timeout"; + } + Result error = new Result(query, ErrorMessage.createBackendCommunicationError(message)); + invoker.getErrorCoverage().ifPresent(coverage -> { + coverage.setDegradedReason(degradedReason); + error.setCoverage(coverage); + }); + results.add(error); + } + } + + private long nextTimeout(int requests, int responses) { + DispatchConfig config = searchCluster.dispatchConfig(); + double minimumCoverage = config.minSearchCoverage(); + + if (requests == responses || minimumCoverage >= 100.0) { + return query.getTimeLeft(); + } + int minimumResponses = (int) (requests * minimumCoverage / 100.0); + + if (responses < minimumResponses) { + return query.getTimeLeft(); + } + + long timeLeft = query.getTimeLeft(); + if (!adaptiveTimeoutCalculated) { + adaptiveTimeoutMin = (long) (timeLeft * config.minWaitAfterCoverageFactor()); + adaptiveTimeoutMax = (long) (timeLeft * config.maxWaitAfterCoverageFactor()); + adaptiveTimeoutCalculated = true; + } + + long now = currentTime(); + int pendingQueries = requests - responses; + double missWidth = ((100.0 - config.minSearchCoverage()) * requests) / 100.0 - 1.0; + double slopedWait = adaptiveTimeoutMin; + if (pendingQueries > 1 && missWidth > 0.0) { + slopedWait += ((adaptiveTimeoutMax - adaptiveTimeoutMin) * (pendingQueries - 1)) / missWidth; + } + long nextAdaptive = (long) slopedWait; + if (now + nextAdaptive >= deadline) { + return deadline - now; + } + deadline = now + nextAdaptive; + + return nextAdaptive; + } + @Override protected void release() { if (!invokers.isEmpty()) { @@ -61,4 +170,26 @@ public class InterleavedSearchInvoker extends SearchInvoker { invokers.clear(); } } + + @Override + public 
void responseAvailable(SearchInvoker from) { + if (availableForProcessing != null) { + availableForProcessing.add(from); + } + } + + @Override + protected void setMonitor(ResponseMonitor<SearchInvoker> monitor) { + // never to be called + } + + // For overriding in tests + protected long currentTime() { + return System.currentTimeMillis(); + } + + // For overriding in tests + protected LinkedBlockingQueue<SearchInvoker> newQueue() { + return new LinkedBlockingQueue<>(); + } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java b/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java index 222ae6a4ea0..df6384cf81c 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java @@ -1,7 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.dispatch; -import com.yahoo.search.Query; import com.yahoo.search.dispatch.searchcluster.Group; import com.yahoo.search.dispatch.searchcluster.SearchCluster; @@ -45,13 +44,13 @@ public class LoadBalancer { } /** - * Select and allocate the search cluster group which is to be used for the provided query. Callers <b>must</b> call + * Select and allocate the search cluster group which is to be used for the next search query. Callers <b>must</b> call * {@link #releaseGroup} symmetrically for each taken allocation. 
* * @param rejectedGroups if not null, the load balancer will only return groups with IDs not in the set * @return The node group to target, or <i>empty</i> if the internal dispatch logic cannot be used */ - public Optional<Group> takeGroupForQuery(Set<Integer> rejectedGroups) { + public Optional<Group> takeGroup(Set<Integer> rejectedGroups) { if (scoreboard == null) { return Optional.empty(); } @@ -60,7 +59,7 @@ public class LoadBalancer { } /** - * Release an allocation given by {@link #takeGroupForQuery}. The release must be done exactly once for each allocation. + * Release an allocation given by {@link #takeGroup}. The release must be done exactly once for each allocation. * * @param group * previously allocated group diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/ResponseMonitor.java b/container-search/src/main/java/com/yahoo/search/dispatch/ResponseMonitor.java new file mode 100644 index 00000000000..c2e81d43677 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/ResponseMonitor.java @@ -0,0 +1,13 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +/** + * Classes implementing ResponseMonitor can be informed by monitored objects + * that a response is available for processing. The responseAvailable method + * must be thread-safe. 
+ * + * @author ollivir + */ +public interface ResponseMonitor<T> { + void responseAvailable(T from); +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java index d5c505aa31b..01da3c20745 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java @@ -11,6 +11,7 @@ import com.yahoo.search.result.ErrorMessage; import java.io.IOException; import java.util.Arrays; import java.util.List; +import java.util.Optional; /** * A search invoker that will immediately produce an error that occurred during @@ -23,8 +24,10 @@ public class SearchErrorInvoker extends SearchInvoker { private final ErrorMessage message; private Query query; private final Coverage coverage; + private ResponseMonitor<SearchInvoker> monitor; public SearchErrorInvoker(ErrorMessage message, Coverage coverage) { + super(Optional.empty()); this.message = message; this.coverage = coverage; } @@ -36,6 +39,9 @@ public class SearchErrorInvoker extends SearchInvoker { @Override protected void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException { this.query = query; + if(monitor != null) { + monitor.responseAvailable(this); + } } @Override @@ -52,4 +58,8 @@ public class SearchErrorInvoker extends SearchInvoker { // nothing to do } + @Override + protected void setMonitor(ResponseMonitor<SearchInvoker> monitor) { + this.monitor = monitor; + } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java index 53e09823f32..2691b32d631 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java @@ -5,9 +5,12 @@ import 
com.yahoo.fs4.QueryPacket; import com.yahoo.prelude.fastsearch.CacheKey; import com.yahoo.search.Query; import com.yahoo.search.Result; +import com.yahoo.search.dispatch.searchcluster.Node; +import com.yahoo.search.result.Coverage; import java.io.IOException; import java.util.List; +import java.util.Optional; /** * SearchInvoker encapsulates an allocated connection for running a single search query. @@ -16,6 +19,13 @@ import java.util.List; * @author ollivir */ public abstract class SearchInvoker extends CloseableInvoker { + private final Optional<Node> node; + private ResponseMonitor<SearchInvoker> monitor; + + protected SearchInvoker(Optional<Node> node) { + this.node = node; + } + /** * Retrieve the hits for the given {@link Query}. The invoker may return more than one result, in which case the caller is responsible * for merging the results. If multiple results are returned and the search query had a hit offset other than zero, that offset is @@ -29,4 +39,26 @@ public abstract class SearchInvoker extends CloseableInvoker { protected abstract void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException; protected abstract List<Result> getSearchResults(CacheKey cacheKey) throws IOException; + + protected void setMonitor(ResponseMonitor<SearchInvoker> monitor) { + this.monitor = monitor; + } + + protected void responseAvailable() { + if(monitor != null) { + monitor.responseAvailable(this); + } + } + + protected Optional<Integer> distributionKey() { + return node.map(Node::key); + } + + protected Optional<Coverage> getErrorCoverage() { + if(node.isPresent()) { + return Optional.of(new Coverage(0, node.get().getActiveDocuments(), 0)); + } else { + return Optional.empty(); + } + } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java index b8d76906f70..8e278f78d7a 100644 --- 
a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java @@ -36,10 +36,7 @@ public class SearchCluster implements NodeManager<Node> { private static final Logger log = Logger.getLogger(SearchCluster.class.getName()); - /** The min active docs a group must have to be considered up, as a % of the average active docs of the other groups */ - private final double minActivedocsCoveragePercentage; - private final double minGroupCoverage; - private final int maxNodesDownPerGroup; + private final DispatchConfig dispatchConfig; private final int size; private final String clusterId; private final ImmutableMap<Integer, Group> groups; @@ -62,20 +59,14 @@ public class SearchCluster implements NodeManager<Node> { private final FS4ResourcePool fs4ResourcePool; public SearchCluster(String clusterId, DispatchConfig dispatchConfig, FS4ResourcePool fs4ResourcePool, int containerClusterSize, VipStatus vipStatus) { - this(clusterId, dispatchConfig.minActivedocsPercentage(), dispatchConfig.minGroupCoverage(), dispatchConfig.maxNodesDownPerGroup(), - toNodes(dispatchConfig), fs4ResourcePool, containerClusterSize, vipStatus); - } - - public SearchCluster(String clusterId, double minActivedocsCoverage, double minGroupCoverage, int maxNodesDownPerGroup, List<Node> nodes, FS4ResourcePool fs4ResourcePool, - int containerClusterSize, VipStatus vipStatus) { this.clusterId = clusterId; - this.minActivedocsCoveragePercentage = minActivedocsCoverage; - this.minGroupCoverage = minGroupCoverage; - this.maxNodesDownPerGroup = maxNodesDownPerGroup; - this.size = nodes.size(); + this.dispatchConfig = dispatchConfig; + this.size = dispatchConfig.node().size(); this.fs4ResourcePool = fs4ResourcePool; this.vipStatus = vipStatus; + List<Node> nodes = toNodes(dispatchConfig); + // Create groups ImmutableMap.Builder<Integer, Group> groupsBuilder = new ImmutableMap.Builder<>(); 
for (Map.Entry<Integer, List<Node>> group : nodes.stream().collect(Collectors.groupingBy(Node::group)).entrySet()) { @@ -143,6 +134,10 @@ public class SearchCluster implements NodeManager<Node> { return nodesBuilder.build(); } + public DispatchConfig dispatchConfig() { + return dispatchConfig; + } + /** Returns the number of nodes in this cluster (across all groups) */ public int size() { return size; } @@ -286,7 +281,7 @@ public class SearchCluster implements NodeManager<Node> { if (averageDocumentsInOtherGroups > 0) { double coverage = 100.0 * (double) activeDocuments / averageDocumentsInOtherGroups; - sufficientCoverage = coverage >= minActivedocsCoveragePercentage; + sufficientCoverage = coverage >= dispatchConfig.minActivedocsPercentage(); } if (sufficientCoverage) { sufficientCoverage = isGroupNodeCoverageSufficient(nodes); @@ -302,7 +297,8 @@ public class SearchCluster implements NodeManager<Node> { } } int numNodes = nodes.size(); - int nodesAllowedDown = maxNodesDownPerGroup + (int) (((double) numNodes * (100.0 - minGroupCoverage)) / 100.0); + int nodesAllowedDown = dispatchConfig.maxNodesDownPerGroup() + + (int) (((double) numNodes * (100.0 - dispatchConfig.minGroupCoverage())) / 100.0); return nodesUp + nodesAllowedDown >= numNodes; } @@ -325,7 +321,7 @@ public class SearchCluster implements NodeManager<Node> { */ public boolean isPartialGroupCoverageSufficient(int groupId, List<Node> nodes) { if (orderedGroups.size() == 1) { - return nodes.size() >= groupSize() - maxNodesDownPerGroup; + return nodes.size() >= groupSize() - dispatchConfig.maxNodesDownPerGroup(); } long sumOfActiveDocuments = 0; int otherGroups = 0; diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java new file mode 100644 index 00000000000..69458f25f93 --- /dev/null +++ 
b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java @@ -0,0 +1,180 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import com.yahoo.fs4.QueryPacket; +import com.yahoo.prelude.fastsearch.CacheKey; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.dispatch.searchcluster.Node; +import com.yahoo.search.dispatch.searchcluster.SearchCluster; +import com.yahoo.test.ManualClock; +import org.junit.Test; + +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static com.yahoo.search.dispatch.MockSearchCluster.createDispatchConfig; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * @author ollivir + */ +public class InterleavedSearchInvokerTest { + private ManualClock clock = new ManualClock(Instant.now()); + private Query query = new TestQuery(); + private LinkedList<Event> expectedEvents = new LinkedList<>(); + private List<SearchInvoker> invokers = new ArrayList<>(); + + @Test + public void requireThatAdaptiveTimeoutsAreNotUsedWithFullCoverageRequirement() throws IOException { + SearchCluster cluster = new MockSearchCluster("!", createDispatchConfig(100.0), 1, 3); + SearchInvoker invoker = createInterleavedInvoker(cluster, 3); + + expectedEvents.add(new Event(5000, 100, 0)); + expectedEvents.add(new Event(4900, 100, 1)); + expectedEvents.add(new Event(4800, 100, 2)); + + invoker.search(query, null, null); + + assertTrue("All test scenario 
events processed", expectedEvents.isEmpty()); + } + + @Test + public void requireThatTimeoutsAreNotMarkedAsAdaptive() throws IOException { + SearchCluster cluster = new MockSearchCluster("!", createDispatchConfig(100.0), 1, 3); + SearchInvoker invoker = createInterleavedInvoker(cluster, 3); + + expectedEvents.add(new Event(5000, 300, 0)); + expectedEvents.add(new Event(4700, 300, 1)); + expectedEvents.add(null); + + List<Result> results = invoker.search(query, null, null); + + assertTrue("All test scenario events processed", expectedEvents.isEmpty()); + assertNotNull("Last invoker is marked as an error", results.get(2).hits().getErrorHit()); + assertTrue("Timed out invoker is a normal timeout", results.get(2).getCoverage(false).isDegradedByTimeout()); + } + + @Test + public void requireThatAdaptiveTimeoutDecreasesTimeoutWhenCoverageIsReached() throws IOException { + SearchCluster cluster = new MockSearchCluster("!", createDispatchConfig(50.0), 1, 4); + SearchInvoker invoker = createInterleavedInvoker(cluster, 4); + + expectedEvents.add(new Event(5000, 100, 0)); + expectedEvents.add(new Event(4900, 100, 1)); + expectedEvents.add(new Event(2400, 100, 2)); + expectedEvents.add(new Event(0, 0, null)); + + List<Result> results = invoker.search(query, null, null); + + assertTrue("All test scenario events processed", expectedEvents.isEmpty()); + assertNotNull("Last invoker is marked as an error", results.get(3).hits().getErrorHit()); + assertTrue("Timed out invoker is an adaptive timeout", results.get(3).getCoverage(false).isDegradedByAdapativeTimeout()); + } + + private InterleavedSearchInvoker createInterleavedInvoker(SearchCluster searchCluster, int numInvokers) { + for (int i = 0; i < numInvokers; i++) { + invokers.add(new TestInvoker()); + } + + return new InterleavedSearchInvoker(invokers, searchCluster) { + @Override + protected long currentTime() { + return clock.millis(); + } + + @Override + protected LinkedBlockingQueue<SearchInvoker> newQueue() { + return new 
LinkedBlockingQueue<SearchInvoker>() { + @Override + public SearchInvoker poll(long timeout, TimeUnit timeUnit) throws InterruptedException { + assertFalse(expectedEvents.isEmpty()); + Event ev = expectedEvents.removeFirst(); + if (ev == null) { + return null; + } else { + return ev.process(query, timeout); + } + } + }; + } + }; + } + + private class Event { + Long expectedTimeout; + long delay; + Integer invokerIndex; + + public Event(Integer expectedTimeout, int delay, Integer invokerIndex) { + this.expectedTimeout = (long) expectedTimeout; + this.delay = delay; + this.invokerIndex = invokerIndex; + } + + public SearchInvoker process(Query query, long currentTimeout) { + if (expectedTimeout != null) { + assertEquals("Expecting timeout to be " + expectedTimeout, (long) expectedTimeout, currentTimeout); + } + clock.advance(Duration.ofMillis(delay)); + if (query.getTimeLeft() < 0) { + fail("Test sequence ran out of time window"); + } + if (invokerIndex == null) { + return null; + } else { + return invokers.get(invokerIndex); + } + } + } + + private class TestInvoker extends SearchInvoker { + protected TestInvoker() { + super(Optional.of(new Node(42, "?", 0, 0))); + } + + @Override + protected void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException { + } + + @Override + protected List<Result> getSearchResults(CacheKey cacheKey) throws IOException { + return Collections.singletonList(new Result(query)); + } + + @Override + protected void release() { + } + } + + public class TestQuery extends Query { + private long start = clock.millis(); + + public TestQuery() { + super(); + setTimeout(5000); + } + + @Override + public long getStartTime() { + return start; + } + + @Override + public long getDurationTime() { + return clock.millis() - start; + } + } +} diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java index 
38a753360d8..c056423a9c4 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java @@ -7,9 +7,9 @@ import com.yahoo.search.dispatch.searchcluster.SearchCluster; import junit.framework.AssertionFailedError; import org.junit.Test; -import java.util.Arrays; import java.util.Optional; +import static com.yahoo.search.dispatch.MockSearchCluster.createDispatchConfig; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -22,10 +22,10 @@ public class LoadBalancerTest { @Test public void requreThatLoadBalancerServesSingleNodeSetups() { Node n1 = new Node(0, "test-node1", 0, 0); - SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1), null, 1, null); + SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1), null, 1, null); LoadBalancer lb = new LoadBalancer(cluster, true); - Optional<Group> grp = lb.takeGroupForQuery(null); + Optional<Group> grp = lb.takeGroup(null); Group group = grp.orElseGet(() -> { throw new AssertionFailedError("Expected a SearchCluster.Group"); }); @@ -36,10 +36,10 @@ public class LoadBalancerTest { public void requreThatLoadBalancerServesMultiGroupSetups() { Node n1 = new Node(0, "test-node1", 0, 0); Node n2 = new Node(1, "test-node2", 1, 1); - SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2), null, 1, null); + SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2), null, 1, null); LoadBalancer lb = new LoadBalancer(cluster, true); - Optional<Group> grp = lb.takeGroupForQuery(null); + Optional<Group> grp = lb.takeGroup(null); Group group = grp.orElseGet(() -> { throw new AssertionFailedError("Expected a SearchCluster.Group"); }); @@ -52,10 +52,10 @@ public class LoadBalancerTest { Node n2 = new Node(1, "test-node2", 1, 0); Node n3 = new Node(0, "test-node3", 0, 
1); Node n4 = new Node(1, "test-node4", 1, 1); - SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2, n3, n4), null, 2, null); + SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2, n3, n4), null, 2, null); LoadBalancer lb = new LoadBalancer(cluster, true); - Optional<Group> grp = lb.takeGroupForQuery(null); + Optional<Group> grp = lb.takeGroup(null); assertThat(grp.isPresent(), is(true)); } @@ -63,18 +63,18 @@ public class LoadBalancerTest { public void requreThatLoadBalancerReturnsDifferentGroups() { Node n1 = new Node(0, "test-node1", 0, 0); Node n2 = new Node(1, "test-node2", 1, 1); - SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2), null, 1, null); + SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2), null, 1, null); LoadBalancer lb = new LoadBalancer(cluster, true); // get first group - Optional<Group> grp = lb.takeGroupForQuery(null); + Optional<Group> grp = lb.takeGroup(null); Group group = grp.get(); int id1 = group.id(); // release allocation lb.releaseGroup(group); // get second group - grp = lb.takeGroupForQuery(null); + grp = lb.takeGroup(null); group = grp.get(); assertThat(group.id(), not(equalTo(id1))); } @@ -83,16 +83,16 @@ public class LoadBalancerTest { public void requreThatLoadBalancerReturnsGroupWithShortestQueue() { Node n1 = new Node(0, "test-node1", 0, 0); Node n2 = new Node(1, "test-node2", 1, 1); - SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2), null, 1, null); + SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2), null, 1, null); LoadBalancer lb = new LoadBalancer(cluster, true); // get first group - Optional<Group> grp = lb.takeGroupForQuery(null); + Optional<Group> grp = lb.takeGroup(null); Group group = grp.get(); int id1 = group.id(); // get second group - grp = lb.takeGroupForQuery(null); + grp = lb.takeGroup(null); group = grp.get(); int id2 = group.id(); 
assertThat(id2, not(equalTo(id1))); @@ -100,7 +100,7 @@ public class LoadBalancerTest { lb.releaseGroup(group); // get third group - grp = lb.takeGroupForQuery(null); + grp = lb.takeGroup(null); group = grp.get(); assertThat(group.id(), equalTo(id2)); } diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java index fc505097472..f7b92419b52 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java @@ -6,9 +6,9 @@ import com.google.common.collect.ImmutableMultimap; import com.yahoo.search.dispatch.searchcluster.Group; import com.yahoo.search.dispatch.searchcluster.Node; import com.yahoo.search.dispatch.searchcluster.SearchCluster; +import com.yahoo.vespa.config.search.DispatchConfig; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Optional; @@ -22,7 +22,11 @@ public class MockSearchCluster extends SearchCluster { private final ImmutableMultimap<String, Node> nodesByHost; public MockSearchCluster(String clusterId, int groups, int nodesPerGroup) { - super(clusterId, 100, 100, 0, Collections.emptyList(), null, 1, null); + this(clusterId, createDispatchConfig(), groups, nodesPerGroup); + } + + public MockSearchCluster(String clusterId, DispatchConfig dispatchConfig, int groups, int nodesPerGroup) { + super(clusterId, dispatchConfig, null, 1, null); ImmutableMap.Builder<Integer, Group> groupBuilder = ImmutableMap.builder(); ImmutableMultimap.Builder<String, Node> hostBuilder = ImmutableMultimap.builder(); @@ -58,7 +62,7 @@ public class MockSearchCluster extends SearchCluster { } public Optional<Group> group(int n) { - if(n < numGroups) { + if (n < numGroups) { return Optional.of(groups.get(n)); } else { return Optional.empty(); @@ -80,4 +84,24 @@ public class MockSearchCluster extends 
SearchCluster { public void failed(Node node) { node.setWorking(false); } + + public static DispatchConfig createDispatchConfig(Node... nodes) { + return createDispatchConfig(100.0, nodes); + } + + public static DispatchConfig createDispatchConfig(double minSearchCoverage, Node... nodes) { + DispatchConfig.Builder builder = new DispatchConfig.Builder(); + builder.minActivedocsPercentage(88.0); + builder.minGroupCoverage(99.0); + builder.maxNodesDownPerGroup(0); + builder.minSearchCoverage(minSearchCoverage); + if(minSearchCoverage < 100.0) { + builder.minWaitAfterCoverageFactor(0); + builder.maxWaitAfterCoverageFactor(0.5); + } + for (Node n : nodes) { + builder.node(new DispatchConfig.Node.Builder().key(n.key()).host(n.hostname()).port(n.fs4port()).group(n.group())); + } + return new DispatchConfig(builder); + } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 2866a2c76b2..c2235b9abe9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package ai.vespa.rankingexpression.importer; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; @@ -59,18 +60,29 @@ public class ImportedModel { /** Returns an immutable map of the inputs of this */ public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); } + // CFG + public Optional<String> inputTypeSpec(String input) { + return Optional.ofNullable(inputs.get(input)).map(TensorType::toString); + } + /** - * Returns an immutable map of the small constants of this. + * Returns an immutable map of the small constants of this, represented as strings on the standard tensor form. * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. */ - public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } + // CFG + public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); } + + boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); } /** * Returns an immutable map of the large constants of this. * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. * For TensorFlow this corresponds to Variable files stored separately. 
*/ - public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } + // CFG + public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); } + + boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); } /** * Returns an immutable map of the expressions of this - corresponding to graph nodes @@ -79,11 +91,14 @@ public class ImportedModel { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } + // TODO: Most of the usage of the above can be replaced by a faster expressionNames method + /** * Returns an immutable map of the functions that are part of this model. * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification. */ - public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); } + // CFG + public Map<String, String> functions() { return asExpressionStrings(functions); } /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } @@ -108,43 +123,60 @@ public class ImportedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public List<Pair<String, ExpressionFunction>> outputExpressions() { - List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); + // CFG + public List<ImportedFunction> outputExpressions() { + List<ImportedFunction> functions = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) - expressions.add(new Pair<>(signatureEntry.getKey() + "." 
+ outputEntry.getKey(), - signatureEntry.getValue().outputExpression(outputEntry.getKey()) - .withName(signatureEntry.getKey() + "." + outputEntry.getKey()))); + functions.add(signatureEntry.getValue().outputFunction(outputEntry.getKey(), + signatureEntry.getKey() + "." + outputEntry.getKey())); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs - expressions.add(new Pair<>(signatureEntry.getKey(), - new ExpressionFunction(signatureEntry.getKey(), - new ArrayList<>(signatureEntry.getValue().inputs().values()), - expressions().get(signatureEntry.getKey()), - signatureEntry.getValue().inputMap(), - Optional.empty()))); + functions.add(new ImportedFunction(signatureEntry.getKey(), + new ArrayList<>(signatureEntry.getValue().inputs().values()), + expressions().get(signatureEntry.getKey()), + signatureEntry.getValue().inputMap(), + Optional.empty())); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); - expressions.add(new Pair<>(singleEntry.getKey(), - new ExpressionFunction(singleEntry.getKey(), - new ArrayList<>(inputs.keySet()), - singleEntry.getValue(), - inputs, - Optional.empty()))); + functions.add(new ImportedFunction(singleEntry.getKey(), + new ArrayList<>(inputs.keySet()), + singleEntry.getValue(), + inputs, + Optional.empty())); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { - expressions.add(new Pair<>(expressionEntry.getKey(), - new ExpressionFunction(expressionEntry.getKey(), - new ArrayList<>(inputs.keySet()), - expressionEntry.getValue(), - inputs, - Optional.empty()))); + functions.add(new ImportedFunction(expressionEntry.getKey(), + new ArrayList<>(inputs.keySet()), + expressionEntry.getValue(), + inputs, + Optional.empty())); } } } - return expressions; + return functions; + } + + private Map<String, String> 
asTensorStrings(Map<String, Tensor> map) { + HashMap<String, String> values = new HashMap<>(); + for (Map.Entry<String, Tensor> entry : map.entrySet()) { + Tensor tensor = entry.getValue(); + // TODO: See Tensor.toStandardString + if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) + values.put(entry.getKey(), tensor.toString()); + else + values.put(entry.getKey(), tensor.type() + ":" + tensor); + } + return values; + } + + private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) { + HashMap<String, String> values = new HashMap<>(); + for (Map.Entry<String, RankingExpression> entry : map.entrySet()) + values.put(entry.getKey(), entry.getValue().getRoot().toString()); + return values; } /** @@ -213,6 +245,17 @@ public class ImportedModel { Optional.empty()); } + /** Returns the expression this output references as an imported function */ + public ImportedFunction outputFunction(String outputName, String functionName) { + return new ImportedFunction(functionName, + new ArrayList<>(inputs.values()), + owner().expressions().get(outputs.get(outputName)), + inputMap(), + Optional.empty()); + } + + // CFG + @Override public String toString() { return "signature '" + name + "'"; } @@ -223,4 +266,37 @@ public class ImportedModel { } + // CFG + public static class ImportedFunction { + + private final String name; + private final List<String> arguments; + private final Map<String, String> argumentTypes; + private final String expression; + private final Optional<String> returnType; + + public ImportedFunction(String name, List<String> arguments, RankingExpression expression, + Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) { + this.name = name; + this.arguments = arguments; + this.expression = expression.getRoot().toString(); + this.argumentTypes = asStrings(argumentTypes); + this.returnType = returnType.map(TensorType::toString); + } + + private static Map<String, String> asStrings(Map<String, TensorType> map) { 
+ Map<String, String> stringMap = new HashMap<>(); + for (Map.Entry<String, TensorType> entry : map.entrySet()) + stringMap.put(entry.getKey(), entry.getValue().toString()); + return stringMap; + } + + public String name() { return name; } + public List<String> arguments() { return Collections.unmodifiableList(arguments); } + public Map<String, String> argumentTypes() { return Collections.unmodifiableMap(argumentTypes); } + public String expression() { return expression; } + public Optional<String> returnType() { return returnType; } + + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java index 1b7532631e1..bfdaaca1dd7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java @@ -69,6 +69,7 @@ public class ImportedModels { * models directory works * @return the model at this path or null if none */ + // CFG public ImportedModel get(File modelPath) { return importedModels.get(toName(modelPath)); } @@ -78,6 +79,7 @@ public class ImportedModels { } /** Returns an immutable collection of all the imported models */ + // CFG public Collection<ImportedModel> all() { return importedModels.values(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index cb095e81147..8a885938bf9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -121,7 +121,7 @@ public abstract class ModelImporter { private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { String name = 
operation.vespaName(); - if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) { return operation.function(); } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index d3996da9b58..315456c2613 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -28,13 +28,13 @@ public class OnnxMnistSoftmaxImportTestCase { // Check constants assertEquals(2, model.largeConstants().size()); - Tensor constant0 = model.largeConstants().get("test_Variable"); + Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable")); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.largeConstants().get("test_Variable_1"); + Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1")); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); @@ -84,8 +84,8 @@ public class OnnxMnistSoftmaxImportTestCase { private Context contextFrom(ImportedModel result) { MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + 
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); return context; } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java index 6215997d8f9..be676186017 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java @@ -24,13 +24,13 @@ public class TensorFlowMnistSoftmaxImportTestCase { // Check constants Assert.assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().largeConstants().get("test_Variable_read"); + Tensor constant0 = Tensor.from(model.get().largeConstants().get("test_Variable_read")); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read"); + Tensor constant1 = Tensor.from(model.get().largeConstants().get("test_Variable_1_read")); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index c3b82cccb46..4ff0c96d369 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -93,8 +93,8 @@ public class 
TestableTensorFlowModel { static Context contextFrom(ImportedModel result) { TestableModelContext context = new TestableModelContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); return context; } @@ -108,7 +108,7 @@ public class TestableTensorFlowModel { private void evaluateFunction(Context context, ImportedModel model, String functionName) { if (!context.names().contains(functionName)) { - RankingExpression e = model.functions().get(functionName); + RankingExpression e = RankingExpression.from(model.functions().get(functionName)); evaluateFunctionDependencies(context, model, e.getRoot()); context.put(functionName, new TensorValue(e.evaluate(context).asTensor())); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 483ccd330e0..1ee22c69c23 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -230,7 +230,7 @@ public interface Tensor { * @return the tensor on the standard string format */ static String toStandardString(Tensor tensor) { - if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) // explicitly output type TODO: Never do that? + if (tensor.isEmpty() && ! 
tensor.type().dimensions().isEmpty()) // explicitly output type TODO: Always do that return tensor.type() + ":" + contentToString(tensor); else return contentToString(tensor); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 000f33696f2..fa32d385004 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -76,32 +76,41 @@ class TensorParser { } private static Tensor fromCellString(Tensor.Builder builder, String s) { - s = s.trim().substring(1).trim(); - while (s.length() > 1) { - int keyOrTensorEnd = s.indexOf('}'); + int index = 1; + index = skipSpace(index, s); + while (index + 1 < s.length()) { + int keyOrTensorEnd = s.indexOf('}', index); TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type()); - if (keyOrTensorEnd < s.length() - 1) { // Key end: This has a key - otherwise TensorAdress is empty - addLabels(s.substring(0, keyOrTensorEnd + 1), addressBuilder); - s = s.substring(keyOrTensorEnd + 1).trim(); - if ( ! 
s.startsWith(":")) - throw new IllegalArgumentException("Expecting a ':' after " + s + ", got '" + s + "'"); - s = s.substring(1); + if (keyOrTensorEnd < s.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty + addLabels(s.substring(index, keyOrTensorEnd + 1), addressBuilder); + index = keyOrTensorEnd + 1; + index = skipSpace(index, s); + if ( s.charAt(index) != ':') + throw new IllegalArgumentException("Expecting a ':' after " + s.substring(index) + ", got '" + s + "'"); + index++; } - int valueEnd = s.indexOf(','); + int valueEnd = s.indexOf(',', index); if (valueEnd < 0) { // last value - valueEnd = s.indexOf("}"); + valueEnd = s.indexOf('}', index); if (valueEnd < 0) throw new IllegalArgumentException("A tensor string must end by '}'"); } TensorAddress address = addressBuilder.build(); - Double value = asDouble(address, s.substring(0, valueEnd).trim()); + Double value = asDouble(address, s.substring(index, valueEnd).trim()); builder.cell(address, value); - s = s.substring(valueEnd+1).trim(); + index = valueEnd+1; + index = skipSpace(index, s); } return builder.build(); } + private static int skipSpace(int index, String s) { + while (index < s.length() && s.charAt(index) == ' ') + index++; + return index; + } + /** Creates a tenor address from a string on the form {dimension1:label1,dimension2:label2,...} */ private static void addLabels(String mapAddressString, TensorAddress.Builder builder) { mapAddressString = mapAddressString.trim(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 38a8329bff1..122b6019884 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -101,7 +101,7 @@ public class TensorTestCase { " {x:0,y:0,z:1}:1, {x:0,y:1,z:1}:2, {x:1,y:0,z:1}:2, {x:1,y:1,z:1}:3, {x:2,y:0,z:1}:3, {x:2,y:1,z:1}:4 }"), Tensor.range(new 
TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); assertEquals(Tensor.from("{ {x:0,y:0,z:0}:1, {x:0,y:1,z:0}:0, {x:1,y:0,z:0}:0, {x:1,y:1,z:0}:0, {x:2,y:0,z:0}:0, {x:2,y:1,z:0}:0, "+ - " {x:0,y:0,z:1}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:1}:1, {x:2,y:0,z:1}:0, {x:2,y:1,z:1}:00 }"), + " {x:0,y:0,z:1}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:1}:1, {x:2,y:0,z:1}:0, {x:2,y:1,z:1}:00 } "), Tensor.diag(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x")); } |