summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java51
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java9
-rw-r--r--configdefinitions/src/vespa/dispatch.def12
-rw-r--r--container-core/src/main/java/com/yahoo/container/handler/Coverage.java6
-rw-r--r--container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java20
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java17
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java35
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java15
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java3
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java145
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java7
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/ResponseMonitor.java13
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java10
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java32
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java30
-rw-r--r--container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java180
-rw-r--r--container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java28
-rw-r--r--container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java30
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java132
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java8
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java4
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java35
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java2
27 files changed, 683 insertions, 153 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 59aa5b3ba53..259ac5227ae 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -202,8 +202,9 @@ public class ConvertedModel {
// Add expressions
Map<String, ExpressionFunction> expressions = new HashMap<>();
- for (Pair<String, ExpressionFunction> output : model.outputExpressions()) {
- addExpression(output.getSecond(), output.getFirst(),
+ for (ImportedModel.ImportedFunction outputFunction : model.outputExpressions()) {
+ ExpressionFunction expression = asExpressionFunction(outputFunction);
+ addExpression(expression, expression.getName(),
constantsReplacedByFunctions,
model, store, profile, queryProfiles,
expressions);
@@ -218,6 +219,23 @@ public class ConvertedModel {
return expressions;
}
+ private static ExpressionFunction asExpressionFunction(ImportedModel.ImportedFunction function) {
+ try {
+ Map<String, TensorType> argumentTypes = new HashMap<>();
+ for (Map.Entry<String, String> entry : function.argumentTypes().entrySet())
+ argumentTypes.put(entry.getKey(), TensorType.fromSpec(entry.getValue()));
+
+ return new ExpressionFunction(function.name(),
+ function.arguments(),
+ new RankingExpression(function.expression()),
+ argumentTypes,
+ function.returnType().map(TensorType::fromSpec));
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException("Gor an illegal argument from importing " + function.name(), e);
+ }
+ }
+
private static void addExpression(ExpressionFunction expression,
String expressionName,
Set<String> constantsReplacedByFunctions,
@@ -248,7 +266,9 @@ public class ConvertedModel {
return store.readExpressions();
}
- private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName,
+ String constantValueString) {
+ Tensor constantValue = Tensor.from(constantValueString);
store.writeSmallConstant(constantName, constantValue);
profile.addConstant(constantName, asValue(constantValue));
}
@@ -258,7 +278,8 @@ public class ConvertedModel {
QueryProfileRegistry queryProfiles,
Set<String> constantsReplacedByFunctions,
String constantName,
- Tensor constantValue) {
+ String constantValueString) {
+ Tensor constantValue = Tensor.from(constantValueString);
RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName);
if (rankingExpressionFunctionOverridingConstant != null) {
TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles));
@@ -306,14 +327,14 @@ public class ConvertedModel {
Set<String> functionNames = new HashSet<>();
addFunctionNamesIn(expression.getRoot(), functionNames, model);
for (String functionName : functionNames) {
- TensorType requiredType = model.inputs().get(functionName);
- if (requiredType == null) continue; // Not a required function
+ Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
+ if ( ! requiredType.isPresent()) continue; // Not a required function
RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
if (rankingExpressionFunction == null)
throw new IllegalArgumentException("Model refers input '" + functionName +
- "' of type " + requiredType + " but this function is not present in " +
- profile);
+ "' of type " + requiredType.get() +
+ " but this function is not present in " + profile);
// TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
// phase and summary features), as it may only resolve correctly given those bindings
// Or, probably better, annotate the functions with type constraints here and verify during general
@@ -321,12 +342,12 @@ public class ConvertedModel {
TensorType actualType = rankingExpressionFunction.function().getBody().getRoot().type(profile.typeContext(queryProfiles));
if ( actualType == null)
throw new IllegalArgumentException("Model refers input '" + functionName +
- "' of type " + requiredType +
+ "' of type " + requiredType.get() +
" which must be produced by a function in the rank profile, but " +
"this function references a feature which is not declared");
- if ( ! actualType.isAssignableTo(requiredType))
+ if ( ! actualType.isAssignableTo(requiredType.get()))
throw new IllegalArgumentException("Model refers input '" + functionName + "'. " +
- typeMismatchExplanation(requiredType, actualType));
+ typeMismatchExplanation(requiredType.get(), actualType));
}
}
@@ -339,7 +360,7 @@ public class ConvertedModel {
/** Add the generated functions to the rank profile */
private static void addGeneratedFunctions(ImportedModel model, RankProfile profile) {
- model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, v.copy()));
+ model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, RankingExpression.from(v)));
}
/**
@@ -383,7 +404,7 @@ public class ConvertedModel {
List<ExpressionNode> children = ((TensorFunctionNode)node).children();
if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.inputs().containsKey(referenceNode.getName())) {
+ if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
return reduceBatchDimensionExpression(tensorFunction, typeContext);
}
}
@@ -391,7 +412,7 @@ public class ConvertedModel {
}
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.inputs().containsKey(referenceNode.getName())) {
+ if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
}
}
@@ -487,7 +508,7 @@ public class ConvertedModel {
if (referenceNode.getOutput() == null) { // function references cannot specify outputs
names.add(referenceNode.getName());
if (model.functions().containsKey(referenceNode.getName())) {
- addFunctionNamesIn(model.functions().get(referenceNode.getName()).getRoot(), names, model);
+ addFunctionNamesIn(RankingExpression.from(model.functions().get(referenceNode.getName())).getRoot(), names, model);
}
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java
index a1b3a0cd9c0..b3b530448fc 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/IndexedSearchCluster.java
@@ -422,6 +422,15 @@ public class IndexedSearchCluster extends SearchCluster
}
builder.maxNodesDownPerGroup(rootDispatch.getMaxNodesDownPerFixedRow());
builder.useMultilevelDispatch(useMultilevelDispatchSetup());
+ builder.searchableCopies(rootDispatch.getSearchableCopies());
+ if (searchCoverage != null) {
+ if (searchCoverage.getMinimum() != null)
+ builder.minSearchCoverage(searchCoverage.getMinimum());
+ if (searchCoverage.getMinWaitAfterCoverageFactor() != null)
+ builder.minWaitAfterCoverageFactor(searchCoverage.getMinWaitAfterCoverageFactor());
+ if (searchCoverage.getMaxWaitAfterCoverageFactor() != null)
+ builder.maxWaitAfterCoverageFactor(searchCoverage.getMaxWaitAfterCoverageFactor());
+ }
}
@Override
diff --git a/configdefinitions/src/vespa/dispatch.def b/configdefinitions/src/vespa/dispatch.def
index dce6098ee9b..50989c3ef74 100644
--- a/configdefinitions/src/vespa/dispatch.def
+++ b/configdefinitions/src/vespa/dispatch.def
@@ -19,6 +19,18 @@ distributionPolicy enum { ROUNDROBIN, ADAPTIVE } default=ROUNDROBIN
# Is multi-level dispatch configured for this cluster
useMultilevelDispatch bool default=false
+# Number of document copies
+searchableCopies long default=1
+
+# Minimum search coverage required before returning the results of a query
+minSearchCoverage double default=100
+
+# Minimum wait time for full coverage after minimum coverage is achieved, factored based on time left at minimum coverage
+minWaitAfterCoverageFactor double default=0
+
+# Maximum wait time for full coverage after minimum coverage is achieved, factored based on time left at minimum coverage
+maxWaitAfterCoverageFactor double default=1
+
# The unique key of a search node
node[].key int
diff --git a/container-core/src/main/java/com/yahoo/container/handler/Coverage.java b/container-core/src/main/java/com/yahoo/container/handler/Coverage.java
index 4a937068d81..84cc0734e7c 100644
--- a/container-core/src/main/java/com/yahoo/container/handler/Coverage.java
+++ b/container-core/src/main/java/com/yahoo/container/handler/Coverage.java
@@ -28,9 +28,9 @@ public class Coverage {
EXPLICITLY_FULL, EXPLICITLY_INCOMPLETE, DOCUMENT_COUNT;
}
- private final static int DEGRADED_BY_MATCH_PHASE = 1;
- private final static int DEGRADED_BY_TIMEOUT = 2;
- private final static int DEGRADED_BY_ADAPTIVE_TIMEOUT = 4;
+ public final static int DEGRADED_BY_MATCH_PHASE = 1;
+ public final static int DEGRADED_BY_TIMEOUT = 2;
+ public final static int DEGRADED_BY_ADAPTIVE_TIMEOUT = 4;
/**
* Build an invalid instance to initiate manually.
diff --git a/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java b/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java
index de4d9c9fe8b..f40550f1f70 100644
--- a/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java
+++ b/container-search/src/main/java/com/yahoo/fs4/mplex/FS4Channel.java
@@ -1,6 +1,13 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.fs4.mplex;
+import com.yahoo.concurrent.SystemTimer;
+import com.yahoo.fs4.BasicPacket;
+import com.yahoo.fs4.ChannelTimeoutException;
+import com.yahoo.fs4.Packet;
+import com.yahoo.search.Query;
+import com.yahoo.search.dispatch.ResponseMonitor;
+
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@@ -9,12 +16,6 @@ import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
-import com.yahoo.concurrent.SystemTimer;
-import com.yahoo.fs4.BasicPacket;
-import com.yahoo.fs4.ChannelTimeoutException;
-import com.yahoo.fs4.Packet;
-import com.yahoo.search.Query;
-
/**
* This class is used to represent a "channel" in the FS4 protocol.
* A channel represents a session between a client and the fdispatch.
@@ -34,6 +35,7 @@ public class FS4Channel {
volatile private BlockingQueue<BasicPacket> responseQueue;
private Query query;
private boolean isPingChannel = false;
+ private ResponseMonitor<FS4Channel> monitor;
/** for unit testing. do not use */
protected FS4Channel () {
@@ -197,6 +199,9 @@ public class FS4Channel {
throws InterruptedException, InvalidChannelException
{
ensureValidQ().put(packet);
+ if(monitor != null) {
+ monitor.responseAvailable(this);
+ }
}
/**
@@ -241,4 +246,7 @@ public class FS4Channel {
return "fs4 channel " + channelId + (isValid() ? " [valid]" : " [invalid]");
}
+ public void setResponseMonitor(ResponseMonitor<FS4Channel> monitor) {
+ this.monitor = monitor;
+ }
}
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java
index f5d082635ab..8fa8bdb66bf 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4InvokerFactory.java
@@ -50,7 +50,7 @@ public class FS4InvokerFactory {
public SearchInvoker getSearchInvoker(Query query, Node node) {
Backend backend = fs4ResourcePool.getBackend(node.hostname(), node.fs4port());
- return new FS4SearchInvoker(searcher, query, backend.openChannel(), node);
+ return new FS4SearchInvoker(searcher, query, backend.openChannel(), Optional.of(node));
}
/**
@@ -70,14 +70,14 @@ public class FS4InvokerFactory {
* list is invalid and the remaining coverage is not sufficient
*/
public Optional<SearchInvoker> getSearchInvoker(Query query, int groupId, List<Node> nodes, boolean acceptIncompleteCoverage) {
- Map<Integer, SearchInvoker> invokers = new HashMap<>();
+ List<SearchInvoker> invokers = new ArrayList<>(nodes.size());
Set<Integer> failed = null;
for (Node node : nodes) {
boolean nodeAdded = false;
if (node.isWorking()) {
Backend backend = fs4ResourcePool.getBackend(node.hostname(), node.fs4port());
if (backend.probeConnection()) {
- invokers.put(node.key(), new FS4SearchInvoker(searcher, query, backend.openChannel(), node));
+ invokers.add(node.key(), new FS4SearchInvoker(searcher, query, backend.openChannel(), Optional.of(node)));
nodeAdded = true;
}
}
@@ -99,7 +99,7 @@ public class FS4InvokerFactory {
}
if (!searchCluster.isPartialGroupCoverageSufficient(groupId, success)) {
if (acceptIncompleteCoverage) {
- createCoverageErrorInvoker(invokers, nodes, failed);
+ invokers.add(createCoverageErrorInvoker(nodes, failed));
} else {
return Optional.empty();
}
@@ -107,13 +107,13 @@ public class FS4InvokerFactory {
}
if (invokers.size() == 1) {
- return Optional.of(invokers.values().iterator().next());
+ return Optional.of(invokers.get(0));
} else {
- return Optional.of(new InterleavedSearchInvoker(invokers));
+ return Optional.of(new InterleavedSearchInvoker(invokers, searchCluster));
}
}
- private void createCoverageErrorInvoker(Map<Integer, SearchInvoker> invokers, List<Node> nodes, Set<Integer> failed) {
+ private SearchInvoker createCoverageErrorInvoker(List<Node> nodes, Set<Integer> failed) {
long activeDocuments = 0;
StringBuilder down = new StringBuilder("Connection failure on nodes with distribution-keys: ");
Integer key = null;
@@ -129,7 +129,8 @@ public class FS4InvokerFactory {
}
}
Coverage coverage = new Coverage(0, activeDocuments, 0);
- invokers.put(key, new SearchErrorInvoker(ErrorMessage.createBackendCommunicationError(down.toString()), coverage));
+ coverage.setNodesTried(1);
+ return new SearchErrorInvoker(ErrorMessage.createBackendCommunicationError(down.toString()), coverage);
}
public FillInvoker getFillInvoker(Query query, Node node) {
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java
index 98676890bdf..da32cfc4fda 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FS4SearchInvoker.java
@@ -6,14 +6,13 @@ import com.yahoo.fs4.ChannelTimeoutException;
import com.yahoo.fs4.Packet;
import com.yahoo.fs4.QueryPacket;
import com.yahoo.fs4.QueryResultPacket;
-import com.yahoo.fs4.mplex.Backend;
import com.yahoo.fs4.mplex.FS4Channel;
import com.yahoo.fs4.mplex.InvalidChannelException;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
+import com.yahoo.search.dispatch.ResponseMonitor;
import com.yahoo.search.dispatch.SearchInvoker;
import com.yahoo.search.dispatch.searchcluster.Node;
-import com.yahoo.search.result.Coverage;
import com.yahoo.search.result.ErrorMessage;
import java.io.IOException;
@@ -30,29 +29,21 @@ import static java.util.Arrays.asList;
*
* @author ollivir
*/
-public class FS4SearchInvoker extends SearchInvoker {
+public class FS4SearchInvoker extends SearchInvoker implements ResponseMonitor<FS4Channel> {
private final VespaBackEndSearcher searcher;
private FS4Channel channel;
- private final Optional<Node> node;
private ErrorMessage pendingSearchError = null;
private Query query = null;
private QueryPacket queryPacket = null;
- public FS4SearchInvoker(VespaBackEndSearcher searcher, Query query, FS4Channel channel, Node node) {
+ public FS4SearchInvoker(VespaBackEndSearcher searcher, Query query, FS4Channel channel, Optional<Node> node) {
+ super(node);
this.searcher = searcher;
- this.node = Optional.of(node);
this.channel = channel;
channel.setQuery(query);
- }
-
- // fdispatch code path
- public FS4SearchInvoker(VespaBackEndSearcher searcher, Query query, Backend backend) {
- this.searcher = searcher;
- this.node = Optional.empty();
- this.channel = backend.openChannel();
- channel.setQuery(query);
+ channel.setResponseMonitor(this);
}
@Override
@@ -68,6 +59,8 @@ public class FS4SearchInvoker extends SearchInvoker {
this.query = query;
this.queryPacket = queryPacket;
+ channel.setResponseMonitor(this);
+
try {
boolean couldSend = channel.sendPacket(queryPacket);
if (!couldSend) {
@@ -115,7 +108,7 @@ public class FS4SearchInvoker extends SearchInvoker {
searcher.addMetaInfo(query, queryPacket.getQueryPacketData(), resultPacket, result);
- searcher.addUnfilledHits(result, resultPacket.getDocuments(), false, queryPacket.getQueryPacketData(), cacheKey, node.map(Node::key));
+ searcher.addUnfilledHits(result, resultPacket.getDocuments(), false, queryPacket.getQueryPacketData(), cacheKey, distributionKey());
Packet[] packets;
CacheControl cacheControl = searcher.getCacheControl();
PacketWrapper packetWrapper = cacheControl.lookup(cacheKey, query);
@@ -130,7 +123,7 @@ public class FS4SearchInvoker extends SearchInvoker {
} else {
packets = new Packet[1];
packets[0] = resultPacket;
- cacheControl.cache(cacheKey, query, new DocsumPacketKey[0], packets, node.map(Node::key));
+ cacheControl.cache(cacheKey, query, new DocsumPacketKey[0], packets, distributionKey());
}
}
return asList(result);
@@ -138,10 +131,7 @@ public class FS4SearchInvoker extends SearchInvoker {
private List<Result> errorResult(ErrorMessage errorMessage) {
Result error = new Result(query, errorMessage);
- node.ifPresent(n -> {
- Coverage coverage = new Coverage(0, n.getActiveDocuments(), 0);
- error.setCoverage(coverage);
- });
+ getErrorCoverage().ifPresent(error::setCoverage);
return Arrays.asList(error);
}
@@ -164,4 +154,9 @@ public class FS4SearchInvoker extends SearchInvoker {
private boolean isLoggingFine() {
return getLogger().isLoggable(Level.FINE);
}
+
+ @Override
+ public void responseAvailable(FS4Channel from) {
+ responseAvailable();
+ }
}
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java
index a98c34295ee..209f6faefa0 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/FastSearcher.java
@@ -222,7 +222,7 @@ public class FastSearcher extends VespaBackEndSearcher {
if(direct.isPresent()) {
return fs4InvokerFactory.getSearchInvoker(query, direct.get());
}
- return new FS4SearchInvoker(this, query, dispatchBackend);
+ return new FS4SearchInvoker(this, query, dispatchBackend.openChannel(), Optional.empty());
}
/**
@@ -284,6 +284,7 @@ public class FastSearcher extends VespaBackEndSearcher {
result.hits().addAll(partialResult.hits().asUnorderedHits());
}
if (finalCoverage != null) {
+ adjustCoverageDegradedReason(finalCoverage);
result.setCoverage(finalCoverage);
}
@@ -301,6 +302,18 @@ public class FastSearcher extends VespaBackEndSearcher {
return result;
}
+ private void adjustCoverageDegradedReason(Coverage coverage) {
+ int asked = coverage.getNodesTried();
+ int answered = coverage.getNodes();
+ if (asked > answered) {
+ int searchableCopies = (int) dispatcher.searchCluster().dispatchConfig().searchableCopies();
+ int missingNodes = (asked - answered) - (searchableCopies - 1);
+ if (missingNodes > 0) {
+ coverage.setDegradedReason(com.yahoo.container.handler.Coverage.DEGRADED_BY_TIMEOUT);
+ }
+ }
+ }
+
private static @NonNull Optional<String> quotedSummaryClass(String summaryClass) {
return Optional.of(summaryClass == null ? "[null]" : quote(summaryClass));
}
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java
index 0382f47457e..1ca64be7924 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java
@@ -20,7 +20,6 @@ import com.yahoo.vespa.config.search.DispatchConfig;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
-import java.util.Map;
import java.util.Optional;
import java.util.Set;
@@ -134,7 +133,7 @@ public class Dispatcher extends AbstractComponent {
int max = Integer.min(searchCluster.orderedGroups().size(), MAX_GROUP_SELECTION_ATTEMPTS);
Set<Integer> rejected = null;
for (int i = 0; i < max; i++) {
- Optional<Group> groupInCluster = loadBalancer.takeGroupForQuery(rejected);
+ Optional<Group> groupInCluster = loadBalancer.takeGroup(rejected);
if (!groupInCluster.isPresent()) {
// No groups available
break;
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java
index 9ff43df87cf..83647b332e6 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java
@@ -5,12 +5,25 @@ import com.yahoo.fs4.QueryPacket;
import com.yahoo.prelude.fastsearch.CacheKey;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
+import com.yahoo.search.dispatch.searchcluster.SearchCluster;
+import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.vespa.config.search.DispatchConfig;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
+import java.util.IdentityHashMap;
import java.util.List;
-import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+import static com.yahoo.container.handler.Coverage.DEGRADED_BY_ADAPTIVE_TIMEOUT;
+import static com.yahoo.container.handler.Coverage.DEGRADED_BY_TIMEOUT;
/**
* InterleavedSearchInvoker uses multiple {@link SearchInvoker} objects to interface with content
@@ -19,11 +32,25 @@ import java.util.Map;
*
* @author ollivir
*/
-public class InterleavedSearchInvoker extends SearchInvoker {
- private final Collection<SearchInvoker> invokers;
+public class InterleavedSearchInvoker extends SearchInvoker implements ResponseMonitor<SearchInvoker> {
+ private static final Logger log = Logger.getLogger(InterleavedSearchInvoker.class.getName());
+
+ private final Set<SearchInvoker> invokers;
+ private final SearchCluster searchCluster;
+ private final LinkedBlockingQueue<SearchInvoker> availableForProcessing;
+ private Query query;
+
+ private boolean adaptiveTimeoutCalculated = false;
+ private long adaptiveTimeoutMin = 0;
+ private long adaptiveTimeoutMax = 0;
+ private long deadline = 0;
- public InterleavedSearchInvoker(Map<Integer, SearchInvoker> invokers) {
- this.invokers = new ArrayList<>(invokers.values());
+ public InterleavedSearchInvoker(Collection<SearchInvoker> invokers, SearchCluster searchCluster) {
+ super(Optional.empty());
+ this.invokers = Collections.newSetFromMap(new IdentityHashMap<>());
+ this.invokers.addAll(invokers);
+ this.searchCluster = searchCluster;
+ this.availableForProcessing = newQueue();
}
/**
@@ -33,27 +60,109 @@ public class InterleavedSearchInvoker extends SearchInvoker {
*/
@Override
protected void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException {
+ this.query = query;
+ invokers.forEach(invoker -> invoker.setMonitor(this));
+ deadline = currentTime() + query.getTimeLeft();
+
int originalHits = query.getHits();
int originalOffset = query.getOffset();
query.setHits(query.getHits() + query.getOffset());
query.setOffset(0);
+
for (SearchInvoker invoker : invokers) {
invoker.sendSearchRequest(query, null);
}
+
query.setHits(originalHits);
query.setOffset(originalOffset);
}
@Override
protected List<Result> getSearchResults(CacheKey cacheKey) throws IOException {
+ int requests = invokers.size();
+ int responses = 0;
List<Result> results = new ArrayList<>();
- for (SearchInvoker invoker : invokers) {
- results.addAll(invoker.getSearchResults(cacheKey));
+ long nextTimeout = query.getTimeLeft();
+ try {
+ while (!invokers.isEmpty() && nextTimeout >= 0) {
+ SearchInvoker invoker = availableForProcessing.poll(nextTimeout, TimeUnit.MILLISECONDS);
+ if (invoker == null) {
+ if (log.isLoggable(Level.FINE)) {
+ log.fine("Search timed out with " + requests + " requests made, " + responses + " responses received");
+ }
+ break;
+ } else {
+ invokers.remove(invoker);
+ results.addAll(invoker.getSearchResults(cacheKey));
+ responses++;
+ }
+ nextTimeout = nextTimeout(requests, responses);
+ }
+ } catch (InterruptedException e) {
+ throw new RuntimeException("Interrupted while waiting for search results", e);
}
+
+ insertTimeoutErrors(results);
return results;
}
+ private void insertTimeoutErrors(List<Result> results) {
+ int degradedReason = adaptiveTimeoutCalculated ? DEGRADED_BY_ADAPTIVE_TIMEOUT : DEGRADED_BY_TIMEOUT;
+
+ for (SearchInvoker invoker : invokers) {
+ Optional<Integer> dk = invoker.distributionKey();
+ String message;
+ if (dk.isPresent()) {
+ message = "Backend communication timeout on node with distribution-key " + dk.get();
+ } else {
+ message = "Backend communication timeout";
+ }
+ Result error = new Result(query, ErrorMessage.createBackendCommunicationError(message));
+ invoker.getErrorCoverage().ifPresent(coverage -> {
+ coverage.setDegradedReason(degradedReason);
+ error.setCoverage(coverage);
+ });
+ results.add(error);
+ }
+ }
+
+ private long nextTimeout(int requests, int responses) {
+ DispatchConfig config = searchCluster.dispatchConfig();
+ double minimumCoverage = config.minSearchCoverage();
+
+ if (requests == responses || minimumCoverage >= 100.0) {
+ return query.getTimeLeft();
+ }
+ int minimumResponses = (int) (requests * minimumCoverage / 100.0);
+
+ if (responses < minimumResponses) {
+ return query.getTimeLeft();
+ }
+
+ long timeLeft = query.getTimeLeft();
+ if (!adaptiveTimeoutCalculated) {
+ adaptiveTimeoutMin = (long) (timeLeft * config.minWaitAfterCoverageFactor());
+ adaptiveTimeoutMax = (long) (timeLeft * config.maxWaitAfterCoverageFactor());
+ adaptiveTimeoutCalculated = true;
+ }
+
+ long now = currentTime();
+ int pendingQueries = requests - responses;
+ double missWidth = ((100.0 - config.minSearchCoverage()) * requests) / 100.0 - 1.0;
+ double slopedWait = adaptiveTimeoutMin;
+ if (pendingQueries > 1 && missWidth > 0.0) {
+ slopedWait += ((adaptiveTimeoutMax - adaptiveTimeoutMin) * (pendingQueries - 1)) / missWidth;
+ }
+ long nextAdaptive = (long) slopedWait;
+ if (now + nextAdaptive >= deadline) {
+ return deadline - now;
+ }
+ deadline = now + nextAdaptive;
+
+ return nextAdaptive;
+ }
+
@Override
protected void release() {
if (!invokers.isEmpty()) {
@@ -61,4 +170,26 @@ public class InterleavedSearchInvoker extends SearchInvoker {
invokers.clear();
}
}
+
+ @Override
+ public void responseAvailable(SearchInvoker from) {
+ if (availableForProcessing != null) {
+ availableForProcessing.add(from);
+ }
+ }
+
+ @Override
+ protected void setMonitor(ResponseMonitor<SearchInvoker> monitor) {
+ // never to be called
+ }
+
+ // For overriding in tests
+ protected long currentTime() {
+ return System.currentTimeMillis();
+ }
+
+ // For overriding in tests
+ protected LinkedBlockingQueue<SearchInvoker> newQueue() {
+ return new LinkedBlockingQueue<>();
+ }
}
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java b/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java
index 222ae6a4ea0..df6384cf81c 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/LoadBalancer.java
@@ -1,7 +1,6 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.dispatch;
-import com.yahoo.search.Query;
import com.yahoo.search.dispatch.searchcluster.Group;
import com.yahoo.search.dispatch.searchcluster.SearchCluster;
@@ -45,13 +44,13 @@ public class LoadBalancer {
}
/**
- * Select and allocate the search cluster group which is to be used for the provided query. Callers <b>must</b> call
+ * Select and allocate the search cluster group which is to be used for the next search query. Callers <b>must</b> call
* {@link #releaseGroup} symmetrically for each taken allocation.
*
* @param rejectedGroups if not null, the load balancer will only return groups with IDs not in the set
* @return The node group to target, or <i>empty</i> if the internal dispatch logic cannot be used
*/
- public Optional<Group> takeGroupForQuery(Set<Integer> rejectedGroups) {
+ public Optional<Group> takeGroup(Set<Integer> rejectedGroups) {
if (scoreboard == null) {
return Optional.empty();
}
@@ -60,7 +59,7 @@ public class LoadBalancer {
}
/**
- * Release an allocation given by {@link #takeGroupForQuery}. The release must be done exactly once for each allocation.
+ * Release an allocation given by {@link #takeGroup}. The release must be done exactly once for each allocation.
*
* @param group
* previously allocated group
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/ResponseMonitor.java b/container-search/src/main/java/com/yahoo/search/dispatch/ResponseMonitor.java
new file mode 100644
index 00000000000..c2e81d43677
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/ResponseMonitor.java
@@ -0,0 +1,13 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.dispatch;
+
+/**
+ * Classes implementing ResponseMonitor can be informed by monitored objects
+ * that a response is available for processing. The responseAvailable method
+ * must be thread-safe.
+ *
+ * @author ollivir
+ */
+public interface ResponseMonitor<T> {
+ void responseAvailable(T from);
+}
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java
index d5c505aa31b..01da3c20745 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/SearchErrorInvoker.java
@@ -11,6 +11,7 @@ import com.yahoo.search.result.ErrorMessage;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
+import java.util.Optional;
/**
* A search invoker that will immediately produce an error that occurred during
@@ -23,8 +24,10 @@ public class SearchErrorInvoker extends SearchInvoker {
private final ErrorMessage message;
private Query query;
private final Coverage coverage;
+ private ResponseMonitor<SearchInvoker> monitor;
public SearchErrorInvoker(ErrorMessage message, Coverage coverage) {
+ super(Optional.empty());
this.message = message;
this.coverage = coverage;
}
@@ -36,6 +39,9 @@ public class SearchErrorInvoker extends SearchInvoker {
@Override
protected void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException {
this.query = query;
+ if(monitor != null) {
+ monitor.responseAvailable(this);
+ }
}
@Override
@@ -52,4 +58,8 @@ public class SearchErrorInvoker extends SearchInvoker {
// nothing to do
}
+ @Override
+ protected void setMonitor(ResponseMonitor<SearchInvoker> monitor) {
+ this.monitor = monitor;
+ }
}
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java
index 53e09823f32..2691b32d631 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/SearchInvoker.java
@@ -5,9 +5,12 @@ import com.yahoo.fs4.QueryPacket;
import com.yahoo.prelude.fastsearch.CacheKey;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
+import com.yahoo.search.dispatch.searchcluster.Node;
+import com.yahoo.search.result.Coverage;
import java.io.IOException;
import java.util.List;
+import java.util.Optional;
/**
* SearchInvoker encapsulates an allocated connection for running a single search query.
@@ -16,6 +19,13 @@ import java.util.List;
* @author ollivir
*/
public abstract class SearchInvoker extends CloseableInvoker {
+ private final Optional<Node> node;
+ private ResponseMonitor<SearchInvoker> monitor;
+
+ protected SearchInvoker(Optional<Node> node) {
+ this.node = node;
+ }
+
/**
* Retrieve the hits for the given {@link Query}. The invoker may return more than one result, in which case the caller is responsible
* for merging the results. If multiple results are returned and the search query had a hit offset other than zero, that offset is
@@ -29,4 +39,26 @@ public abstract class SearchInvoker extends CloseableInvoker {
protected abstract void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException;
protected abstract List<Result> getSearchResults(CacheKey cacheKey) throws IOException;
+
+ protected void setMonitor(ResponseMonitor<SearchInvoker> monitor) {
+ this.monitor = monitor;
+ }
+
+ protected void responseAvailable() {
+ if(monitor != null) {
+ monitor.responseAvailable(this);
+ }
+ }
+
+ protected Optional<Integer> distributionKey() {
+ return node.map(Node::key);
+ }
+
+ protected Optional<Coverage> getErrorCoverage() {
+ if(node.isPresent()) {
+ return Optional.of(new Coverage(0, node.get().getActiveDocuments(), 0));
+ } else {
+ return Optional.empty();
+ }
+ }
}
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java
index b8d76906f70..8e278f78d7a 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java
@@ -36,10 +36,7 @@ public class SearchCluster implements NodeManager<Node> {
private static final Logger log = Logger.getLogger(SearchCluster.class.getName());
- /** The min active docs a group must have to be considered up, as a % of the average active docs of the other groups */
- private final double minActivedocsCoveragePercentage;
- private final double minGroupCoverage;
- private final int maxNodesDownPerGroup;
+ private final DispatchConfig dispatchConfig;
private final int size;
private final String clusterId;
private final ImmutableMap<Integer, Group> groups;
@@ -62,20 +59,14 @@ public class SearchCluster implements NodeManager<Node> {
private final FS4ResourcePool fs4ResourcePool;
public SearchCluster(String clusterId, DispatchConfig dispatchConfig, FS4ResourcePool fs4ResourcePool, int containerClusterSize, VipStatus vipStatus) {
- this(clusterId, dispatchConfig.minActivedocsPercentage(), dispatchConfig.minGroupCoverage(), dispatchConfig.maxNodesDownPerGroup(),
- toNodes(dispatchConfig), fs4ResourcePool, containerClusterSize, vipStatus);
- }
-
- public SearchCluster(String clusterId, double minActivedocsCoverage, double minGroupCoverage, int maxNodesDownPerGroup, List<Node> nodes, FS4ResourcePool fs4ResourcePool,
- int containerClusterSize, VipStatus vipStatus) {
this.clusterId = clusterId;
- this.minActivedocsCoveragePercentage = minActivedocsCoverage;
- this.minGroupCoverage = minGroupCoverage;
- this.maxNodesDownPerGroup = maxNodesDownPerGroup;
- this.size = nodes.size();
+ this.dispatchConfig = dispatchConfig;
+ this.size = dispatchConfig.node().size();
this.fs4ResourcePool = fs4ResourcePool;
this.vipStatus = vipStatus;
+ List<Node> nodes = toNodes(dispatchConfig);
+
// Create groups
ImmutableMap.Builder<Integer, Group> groupsBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<Integer, List<Node>> group : nodes.stream().collect(Collectors.groupingBy(Node::group)).entrySet()) {
@@ -143,6 +134,10 @@ public class SearchCluster implements NodeManager<Node> {
return nodesBuilder.build();
}
+ public DispatchConfig dispatchConfig() {
+ return dispatchConfig;
+ }
+
/** Returns the number of nodes in this cluster (across all groups) */
public int size() { return size; }
@@ -286,7 +281,7 @@ public class SearchCluster implements NodeManager<Node> {
if (averageDocumentsInOtherGroups > 0) {
double coverage = 100.0 * (double) activeDocuments / averageDocumentsInOtherGroups;
- sufficientCoverage = coverage >= minActivedocsCoveragePercentage;
+ sufficientCoverage = coverage >= dispatchConfig.minActivedocsPercentage();
}
if (sufficientCoverage) {
sufficientCoverage = isGroupNodeCoverageSufficient(nodes);
@@ -302,7 +297,8 @@ public class SearchCluster implements NodeManager<Node> {
}
}
int numNodes = nodes.size();
- int nodesAllowedDown = maxNodesDownPerGroup + (int) (((double) numNodes * (100.0 - minGroupCoverage)) / 100.0);
+ int nodesAllowedDown = dispatchConfig.maxNodesDownPerGroup()
+ + (int) (((double) numNodes * (100.0 - dispatchConfig.minGroupCoverage())) / 100.0);
return nodesUp + nodesAllowedDown >= numNodes;
}
@@ -325,7 +321,7 @@ public class SearchCluster implements NodeManager<Node> {
*/
public boolean isPartialGroupCoverageSufficient(int groupId, List<Node> nodes) {
if (orderedGroups.size() == 1) {
- return nodes.size() >= groupSize() - maxNodesDownPerGroup;
+ return nodes.size() >= groupSize() - dispatchConfig.maxNodesDownPerGroup();
}
long sumOfActiveDocuments = 0;
int otherGroups = 0;
diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java
new file mode 100644
index 00000000000..69458f25f93
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java
@@ -0,0 +1,180 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.dispatch;
+
+import com.yahoo.fs4.QueryPacket;
+import com.yahoo.prelude.fastsearch.CacheKey;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.dispatch.searchcluster.Node;
+import com.yahoo.search.dispatch.searchcluster.SearchCluster;
+import com.yahoo.test.ManualClock;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+import static com.yahoo.search.dispatch.MockSearchCluster.createDispatchConfig;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * @author ollivir
+ */
+public class InterleavedSearchInvokerTest {
+ private ManualClock clock = new ManualClock(Instant.now());
+ private Query query = new TestQuery();
+ private LinkedList<Event> expectedEvents = new LinkedList<>();
+ private List<SearchInvoker> invokers = new ArrayList<>();
+
+ @Test
+ public void requireThatAdaptiveTimeoutsAreNotUsedWithFullCoverageRequirement() throws IOException {
+ SearchCluster cluster = new MockSearchCluster("!", createDispatchConfig(100.0), 1, 3);
+ SearchInvoker invoker = createInterleavedInvoker(cluster, 3);
+
+ expectedEvents.add(new Event(5000, 100, 0));
+ expectedEvents.add(new Event(4900, 100, 1));
+ expectedEvents.add(new Event(4800, 100, 2));
+
+ invoker.search(query, null, null);
+
+ assertTrue("All test scenario events processed", expectedEvents.isEmpty());
+ }
+
+ @Test
+ public void requireThatTimeoutsAreNotMarkedAsAdaptive() throws IOException {
+ SearchCluster cluster = new MockSearchCluster("!", createDispatchConfig(100.0), 1, 3);
+ SearchInvoker invoker = createInterleavedInvoker(cluster, 3);
+
+ expectedEvents.add(new Event(5000, 300, 0));
+ expectedEvents.add(new Event(4700, 300, 1));
+ expectedEvents.add(null);
+
+ List<Result> results = invoker.search(query, null, null);
+
+ assertTrue("All test scenario events processed", expectedEvents.isEmpty());
+ assertNotNull("Last invoker is marked as an error", results.get(2).hits().getErrorHit());
+ assertTrue("Timed out invoker is a normal timeout", results.get(2).getCoverage(false).isDegradedByTimeout());
+ }
+
+ @Test
+ public void requireThatAdaptiveTimeoutDecreasesTimeoutWhenCoverageIsReached() throws IOException {
+ SearchCluster cluster = new MockSearchCluster("!", createDispatchConfig(50.0), 1, 4);
+ SearchInvoker invoker = createInterleavedInvoker(cluster, 4);
+
+ expectedEvents.add(new Event(5000, 100, 0));
+ expectedEvents.add(new Event(4900, 100, 1));
+ expectedEvents.add(new Event(2400, 100, 2));
+ expectedEvents.add(new Event(0, 0, null));
+
+ List<Result> results = invoker.search(query, null, null);
+
+ assertTrue("All test scenario events processed", expectedEvents.isEmpty());
+ assertNotNull("Last invoker is marked as an error", results.get(3).hits().getErrorHit());
+ assertTrue("Timed out invoker is an adaptive timeout", results.get(3).getCoverage(false).isDegradedByAdapativeTimeout());
+ }
+
+ private InterleavedSearchInvoker createInterleavedInvoker(SearchCluster searchCluster, int numInvokers) {
+ for (int i = 0; i < numInvokers; i++) {
+ invokers.add(new TestInvoker());
+ }
+
+ return new InterleavedSearchInvoker(invokers, searchCluster) {
+ @Override
+ protected long currentTime() {
+ return clock.millis();
+ }
+
+ @Override
+ protected LinkedBlockingQueue<SearchInvoker> newQueue() {
+ return new LinkedBlockingQueue<SearchInvoker>() {
+ @Override
+ public SearchInvoker poll(long timeout, TimeUnit timeUnit) throws InterruptedException {
+ assertFalse(expectedEvents.isEmpty());
+ Event ev = expectedEvents.removeFirst();
+ if (ev == null) {
+ return null;
+ } else {
+ return ev.process(query, timeout);
+ }
+ }
+ };
+ }
+ };
+ }
+
+ private class Event {
+ Long expectedTimeout;
+ long delay;
+ Integer invokerIndex;
+
+ public Event(Integer expectedTimeout, int delay, Integer invokerIndex) {
+ this.expectedTimeout = (long) expectedTimeout;
+ this.delay = delay;
+ this.invokerIndex = invokerIndex;
+ }
+
+ public SearchInvoker process(Query query, long currentTimeout) {
+ if (expectedTimeout != null) {
+ assertEquals("Expecting timeout to be " + expectedTimeout, (long) expectedTimeout, currentTimeout);
+ }
+ clock.advance(Duration.ofMillis(delay));
+ if (query.getTimeLeft() < 0) {
+ fail("Test sequence ran out of time window");
+ }
+ if (invokerIndex == null) {
+ return null;
+ } else {
+ return invokers.get(invokerIndex);
+ }
+ }
+ }
+
+ private class TestInvoker extends SearchInvoker {
+ protected TestInvoker() {
+ super(Optional.of(new Node(42, "?", 0, 0)));
+ }
+
+ @Override
+ protected void sendSearchRequest(Query query, QueryPacket queryPacket) throws IOException {
+ }
+
+ @Override
+ protected List<Result> getSearchResults(CacheKey cacheKey) throws IOException {
+ return Collections.singletonList(new Result(query));
+ }
+
+ @Override
+ protected void release() {
+ }
+ }
+
+ public class TestQuery extends Query {
+ private long start = clock.millis();
+
+ public TestQuery() {
+ super();
+ setTimeout(5000);
+ }
+
+ @Override
+ public long getStartTime() {
+ return start;
+ }
+
+ @Override
+ public long getDurationTime() {
+ return clock.millis() - start;
+ }
+ }
+}
diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java
index 38a753360d8..c056423a9c4 100644
--- a/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java
+++ b/container-search/src/test/java/com/yahoo/search/dispatch/LoadBalancerTest.java
@@ -7,9 +7,9 @@ import com.yahoo.search.dispatch.searchcluster.SearchCluster;
import junit.framework.AssertionFailedError;
import org.junit.Test;
-import java.util.Arrays;
import java.util.Optional;
+import static com.yahoo.search.dispatch.MockSearchCluster.createDispatchConfig;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.not;
@@ -22,10 +22,10 @@ public class LoadBalancerTest {
@Test
public void requreThatLoadBalancerServesSingleNodeSetups() {
Node n1 = new Node(0, "test-node1", 0, 0);
- SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1), null, 1, null);
+ SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1), null, 1, null);
LoadBalancer lb = new LoadBalancer(cluster, true);
- Optional<Group> grp = lb.takeGroupForQuery(null);
+ Optional<Group> grp = lb.takeGroup(null);
Group group = grp.orElseGet(() -> {
throw new AssertionFailedError("Expected a SearchCluster.Group");
});
@@ -36,10 +36,10 @@ public class LoadBalancerTest {
public void requreThatLoadBalancerServesMultiGroupSetups() {
Node n1 = new Node(0, "test-node1", 0, 0);
Node n2 = new Node(1, "test-node2", 1, 1);
- SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2), null, 1, null);
+ SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2), null, 1, null);
LoadBalancer lb = new LoadBalancer(cluster, true);
- Optional<Group> grp = lb.takeGroupForQuery(null);
+ Optional<Group> grp = lb.takeGroup(null);
Group group = grp.orElseGet(() -> {
throw new AssertionFailedError("Expected a SearchCluster.Group");
});
@@ -52,10 +52,10 @@ public class LoadBalancerTest {
Node n2 = new Node(1, "test-node2", 1, 0);
Node n3 = new Node(0, "test-node3", 0, 1);
Node n4 = new Node(1, "test-node4", 1, 1);
- SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2, n3, n4), null, 2, null);
+ SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2, n3, n4), null, 2, null);
LoadBalancer lb = new LoadBalancer(cluster, true);
- Optional<Group> grp = lb.takeGroupForQuery(null);
+ Optional<Group> grp = lb.takeGroup(null);
assertThat(grp.isPresent(), is(true));
}
@@ -63,18 +63,18 @@ public class LoadBalancerTest {
public void requreThatLoadBalancerReturnsDifferentGroups() {
Node n1 = new Node(0, "test-node1", 0, 0);
Node n2 = new Node(1, "test-node2", 1, 1);
- SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2), null, 1, null);
+ SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2), null, 1, null);
LoadBalancer lb = new LoadBalancer(cluster, true);
// get first group
- Optional<Group> grp = lb.takeGroupForQuery(null);
+ Optional<Group> grp = lb.takeGroup(null);
Group group = grp.get();
int id1 = group.id();
// release allocation
lb.releaseGroup(group);
// get second group
- grp = lb.takeGroupForQuery(null);
+ grp = lb.takeGroup(null);
group = grp.get();
assertThat(group.id(), not(equalTo(id1)));
}
@@ -83,16 +83,16 @@ public class LoadBalancerTest {
public void requreThatLoadBalancerReturnsGroupWithShortestQueue() {
Node n1 = new Node(0, "test-node1", 0, 0);
Node n2 = new Node(1, "test-node2", 1, 1);
- SearchCluster cluster = new SearchCluster("a", 88.0, 99.0, 0, Arrays.asList(n1, n2), null, 1, null);
+ SearchCluster cluster = new SearchCluster("a", createDispatchConfig(n1, n2), null, 1, null);
LoadBalancer lb = new LoadBalancer(cluster, true);
// get first group
- Optional<Group> grp = lb.takeGroupForQuery(null);
+ Optional<Group> grp = lb.takeGroup(null);
Group group = grp.get();
int id1 = group.id();
// get second group
- grp = lb.takeGroupForQuery(null);
+ grp = lb.takeGroup(null);
group = grp.get();
int id2 = group.id();
assertThat(id2, not(equalTo(id1)));
@@ -100,7 +100,7 @@ public class LoadBalancerTest {
lb.releaseGroup(group);
// get third group
- grp = lb.takeGroupForQuery(null);
+ grp = lb.takeGroup(null);
group = grp.get();
assertThat(group.id(), equalTo(id2));
}
diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java
index fc505097472..f7b92419b52 100644
--- a/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java
+++ b/container-search/src/test/java/com/yahoo/search/dispatch/MockSearchCluster.java
@@ -6,9 +6,9 @@ import com.google.common.collect.ImmutableMultimap;
import com.yahoo.search.dispatch.searchcluster.Group;
import com.yahoo.search.dispatch.searchcluster.Node;
import com.yahoo.search.dispatch.searchcluster.SearchCluster;
+import com.yahoo.vespa.config.search.DispatchConfig;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.List;
import java.util.Optional;
@@ -22,7 +22,11 @@ public class MockSearchCluster extends SearchCluster {
private final ImmutableMultimap<String, Node> nodesByHost;
public MockSearchCluster(String clusterId, int groups, int nodesPerGroup) {
- super(clusterId, 100, 100, 0, Collections.emptyList(), null, 1, null);
+ this(clusterId, createDispatchConfig(), groups, nodesPerGroup);
+ }
+
+ public MockSearchCluster(String clusterId, DispatchConfig dispatchConfig, int groups, int nodesPerGroup) {
+ super(clusterId, dispatchConfig, null, 1, null);
ImmutableMap.Builder<Integer, Group> groupBuilder = ImmutableMap.builder();
ImmutableMultimap.Builder<String, Node> hostBuilder = ImmutableMultimap.builder();
@@ -58,7 +62,7 @@ public class MockSearchCluster extends SearchCluster {
}
public Optional<Group> group(int n) {
- if(n < numGroups) {
+ if (n < numGroups) {
return Optional.of(groups.get(n));
} else {
return Optional.empty();
@@ -80,4 +84,24 @@ public class MockSearchCluster extends SearchCluster {
public void failed(Node node) {
node.setWorking(false);
}
+
+ public static DispatchConfig createDispatchConfig(Node... nodes) {
+ return createDispatchConfig(100.0, nodes);
+ }
+
+ public static DispatchConfig createDispatchConfig(double minSearchCoverage, Node... nodes) {
+ DispatchConfig.Builder builder = new DispatchConfig.Builder();
+ builder.minActivedocsPercentage(88.0);
+ builder.minGroupCoverage(99.0);
+ builder.maxNodesDownPerGroup(0);
+ builder.minSearchCoverage(minSearchCoverage);
+ if(minSearchCoverage < 100.0) {
+ builder.minWaitAfterCoverageFactor(0);
+ builder.maxWaitAfterCoverageFactor(0.5);
+ }
+ for (Node n : nodes) {
+ builder.node(new DispatchConfig.Node.Builder().key(n.key()).host(n.hostname()).port(n.fs4port()).group(n.group()));
+ }
+ return new DispatchConfig(builder);
+ }
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 2866a2c76b2..c2235b9abe9 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
@@ -59,18 +60,29 @@ public class ImportedModel {
/** Returns an immutable map of the inputs of this */
public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); }
+ // CFG
+ public Optional<String> inputTypeSpec(String input) {
+ return Optional.ofNullable(inputs.get(input)).map(TensorType::toString);
+ }
+
/**
- * Returns an immutable map of the small constants of this.
+ * Returns an immutable map of the small constants of this, represented as strings on the standard tensor form.
* These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
*/
- public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
+ // CFG
+ public Map<String, String> smallConstants() { return asTensorStrings(smallConstants); }
+
+ boolean hasSmallConstant(String name) { return smallConstants.containsKey(name); }
/**
* Returns an immutable map of the large constants of this.
* These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
* For TensorFlow this corresponds to Variable files stored separately.
*/
- public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
+ // CFG
+ public Map<String, String> largeConstants() { return asTensorStrings(largeConstants); }
+
+ boolean hasLargeConstant(String name) { return largeConstants.containsKey(name); }
/**
* Returns an immutable map of the expressions of this - corresponding to graph nodes
@@ -79,11 +91,14 @@ public class ImportedModel {
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
+ // TODO: Most of the usage of the above can be replaced by a faster expressionNames method
+
/**
* Returns an immutable map of the functions that are part of this model.
* Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification.
*/
- public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); }
+ // CFG
+ public Map<String, String> functions() { return asExpressionStrings(functions); }
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
@@ -108,43 +123,60 @@ public class ImportedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- public List<Pair<String, ExpressionFunction>> outputExpressions() {
- List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
+ // CFG
+ public List<ImportedFunction> outputExpressions() {
+ List<ImportedFunction> functions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
- expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
- signatureEntry.getValue().outputExpression(outputEntry.getKey())
- .withName(signatureEntry.getKey() + "." + outputEntry.getKey())));
+ functions.add(signatureEntry.getValue().outputFunction(outputEntry.getKey(),
+ signatureEntry.getKey() + "." + outputEntry.getKey()));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
- expressions.add(new Pair<>(signatureEntry.getKey(),
- new ExpressionFunction(signatureEntry.getKey(),
- new ArrayList<>(signatureEntry.getValue().inputs().values()),
- expressions().get(signatureEntry.getKey()),
- signatureEntry.getValue().inputMap(),
- Optional.empty())));
+ functions.add(new ImportedFunction(signatureEntry.getKey(),
+ new ArrayList<>(signatureEntry.getValue().inputs().values()),
+ expressions().get(signatureEntry.getKey()),
+ signatureEntry.getValue().inputMap(),
+ Optional.empty()));
}
if (signatures().isEmpty()) { // fallback for models without signatures
if (expressions().size() == 1) {
Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
- expressions.add(new Pair<>(singleEntry.getKey(),
- new ExpressionFunction(singleEntry.getKey(),
- new ArrayList<>(inputs.keySet()),
- singleEntry.getValue(),
- inputs,
- Optional.empty())));
+ functions.add(new ImportedFunction(singleEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ singleEntry.getValue(),
+ inputs,
+ Optional.empty()));
}
else {
for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
- expressions.add(new Pair<>(expressionEntry.getKey(),
- new ExpressionFunction(expressionEntry.getKey(),
- new ArrayList<>(inputs.keySet()),
- expressionEntry.getValue(),
- inputs,
- Optional.empty())));
+ functions.add(new ImportedFunction(expressionEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ expressionEntry.getValue(),
+ inputs,
+ Optional.empty()));
}
}
}
- return expressions;
+ return functions;
+ }
+
+ private Map<String, String> asTensorStrings(Map<String, Tensor> map) {
+ HashMap<String, String> values = new HashMap<>();
+ for (Map.Entry<String, Tensor> entry : map.entrySet()) {
+ Tensor tensor = entry.getValue();
+ // TODO: See Tensor.toStandardString
+ if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty())
+ values.put(entry.getKey(), tensor.toString());
+ else
+ values.put(entry.getKey(), tensor.type() + ":" + tensor);
+ }
+ return values;
+ }
+
+ private Map<String, String> asExpressionStrings(Map<String, RankingExpression> map) {
+ HashMap<String, String> values = new HashMap<>();
+ for (Map.Entry<String, RankingExpression> entry : map.entrySet())
+ values.put(entry.getKey(), entry.getValue().getRoot().toString());
+ return values;
}
/**
@@ -213,6 +245,17 @@ public class ImportedModel {
Optional.empty());
}
+ /** Returns the expression this output references as an imported function */
+ public ImportedFunction outputFunction(String outputName, String functionName) {
+ return new ImportedFunction(functionName,
+ new ArrayList<>(inputs.values()),
+ owner().expressions().get(outputs.get(outputName)),
+ inputMap(),
+ Optional.empty());
+ }
+
+ // CFG
+
@Override
public String toString() { return "signature '" + name + "'"; }
@@ -223,4 +266,37 @@ public class ImportedModel {
}
+ // CFG
+ public static class ImportedFunction {
+
+ private final String name;
+ private final List<String> arguments;
+ private final Map<String, String> argumentTypes;
+ private final String expression;
+ private final Optional<String> returnType;
+
+ public ImportedFunction(String name, List<String> arguments, RankingExpression expression,
+ Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) {
+ this.name = name;
+ this.arguments = arguments;
+ this.expression = expression.getRoot().toString();
+ this.argumentTypes = asStrings(argumentTypes);
+ this.returnType = returnType.map(TensorType::toString);
+ }
+
+ private static Map<String, String> asStrings(Map<String, TensorType> map) {
+ Map<String, String> stringMap = new HashMap<>();
+ for (Map.Entry<String, TensorType> entry : map.entrySet())
+ stringMap.put(entry.getKey(), entry.getValue().toString());
+ return stringMap;
+ }
+
+ public String name() { return name; }
+ public List<String> arguments() { return Collections.unmodifiableList(arguments); }
+ public Map<String, String> argumentTypes() { return Collections.unmodifiableMap(argumentTypes); }
+ public String expression() { return expression; }
+ public Optional<String> returnType() { return returnType; }
+
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
index 1b7532631e1..bfdaaca1dd7 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
@@ -69,6 +69,7 @@ public class ImportedModels {
* models directory works
* @return the model at this path or null if none
*/
+ // CFG
public ImportedModel get(File modelPath) {
return importedModels.get(toName(modelPath));
}
@@ -78,6 +79,7 @@ public class ImportedModels {
}
/** Returns an immutable collection of all the imported models */
+ // CFG
public Collection<ImportedModel> all() {
return importedModels.values();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index cb095e81147..8a885938bf9 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -121,7 +121,7 @@ public abstract class ModelImporter {
private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
String name = operation.vespaName();
- if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
+ if (model.hasLargeConstant(name) || model.hasSmallConstant(name)) {
return operation.function();
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index d3996da9b58..315456c2613 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -28,13 +28,13 @@ public class OnnxMnistSoftmaxImportTestCase {
// Check constants
assertEquals(2, model.largeConstants().size());
- Tensor constant0 = model.largeConstants().get("test_Variable");
+ Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable"));
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.largeConstants().get("test_Variable_1");
+ Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1"));
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type());
assertEquals(10, constant1.size());
@@ -84,8 +84,8 @@ public class OnnxMnistSoftmaxImportTestCase {
private Context contextFrom(ImportedModel result) {
MapContext context = new MapContext();
- result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
- result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
return context;
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
index 6215997d8f9..be676186017 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
@@ -24,13 +24,13 @@ public class TensorFlowMnistSoftmaxImportTestCase {
// Check constants
Assert.assertEquals(2, model.get().largeConstants().size());
- Tensor constant0 = model.get().largeConstants().get("test_Variable_read");
+ Tensor constant0 = Tensor.from(model.get().largeConstants().get("test_Variable_read"));
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.get().largeConstants().get("test_Variable_1_read");
+ Tensor constant1 = Tensor.from(model.get().largeConstants().get("test_Variable_1_read"));
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
constant1.type());
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
index c3b82cccb46..4ff0c96d369 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
@@ -93,8 +93,8 @@ public class TestableTensorFlowModel {
static Context contextFrom(ImportedModel result) {
TestableModelContext context = new TestableModelContext();
- result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
- result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
return context;
}
@@ -108,7 +108,7 @@ public class TestableTensorFlowModel {
private void evaluateFunction(Context context, ImportedModel model, String functionName) {
if (!context.names().contains(functionName)) {
- RankingExpression e = model.functions().get(functionName);
+ RankingExpression e = RankingExpression.from(model.functions().get(functionName));
evaluateFunctionDependencies(context, model, e.getRoot());
context.put(functionName, new TensorValue(e.evaluate(context).asTensor()));
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 483ccd330e0..1ee22c69c23 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -230,7 +230,7 @@ public interface Tensor {
* @return the tensor on the standard string format
*/
static String toStandardString(Tensor tensor) {
- if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) // explicitly output type TODO: Never do that?
+ if (tensor.isEmpty() && ! tensor.type().dimensions().isEmpty()) // explicitly output type TODO: Always do that
return tensor.type() + ":" + contentToString(tensor);
else
return contentToString(tensor);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 000f33696f2..fa32d385004 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -76,32 +76,41 @@ class TensorParser {
}
private static Tensor fromCellString(Tensor.Builder builder, String s) {
- s = s.trim().substring(1).trim();
- while (s.length() > 1) {
- int keyOrTensorEnd = s.indexOf('}');
+ int index = 1;
+ index = skipSpace(index, s);
+ while (index + 1 < s.length()) {
+ int keyOrTensorEnd = s.indexOf('}', index);
TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type());
- if (keyOrTensorEnd < s.length() - 1) { // Key end: This has a key - otherwise TensorAdress is empty
- addLabels(s.substring(0, keyOrTensorEnd + 1), addressBuilder);
- s = s.substring(keyOrTensorEnd + 1).trim();
- if ( ! s.startsWith(":"))
- throw new IllegalArgumentException("Expecting a ':' after " + s + ", got '" + s + "'");
- s = s.substring(1);
+ if (keyOrTensorEnd < s.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty
+ addLabels(s.substring(index, keyOrTensorEnd + 1), addressBuilder);
+ index = keyOrTensorEnd + 1;
+ index = skipSpace(index, s);
+ if ( s.charAt(index) != ':')
+ throw new IllegalArgumentException("Expecting a ':' after " + s.substring(index) + ", got '" + s + "'");
+ index++;
}
- int valueEnd = s.indexOf(',');
+ int valueEnd = s.indexOf(',', index);
if (valueEnd < 0) { // last value
- valueEnd = s.indexOf("}");
+ valueEnd = s.indexOf('}', index);
if (valueEnd < 0)
throw new IllegalArgumentException("A tensor string must end by '}'");
}
TensorAddress address = addressBuilder.build();
- Double value = asDouble(address, s.substring(0, valueEnd).trim());
+ Double value = asDouble(address, s.substring(index, valueEnd).trim());
builder.cell(address, value);
- s = s.substring(valueEnd+1).trim();
+ index = valueEnd+1;
+ index = skipSpace(index, s);
}
return builder.build();
}
+ private static int skipSpace(int index, String s) {
+ while (index < s.length() && s.charAt(index) == ' ')
+ index++;
+ return index;
+ }
+
/** Creates a tenor address from a string on the form {dimension1:label1,dimension2:label2,...} */
private static void addLabels(String mapAddressString, TensorAddress.Builder builder) {
mapAddressString = mapAddressString.trim();
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 38a8329bff1..122b6019884 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -101,7 +101,7 @@ public class TensorTestCase {
" {x:0,y:0,z:1}:1, {x:0,y:1,z:1}:2, {x:1,y:0,z:1}:2, {x:1,y:1,z:1}:3, {x:2,y:0,z:1}:3, {x:2,y:1,z:1}:4 }"),
Tensor.range(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build()));
assertEquals(Tensor.from("{ {x:0,y:0,z:0}:1, {x:0,y:1,z:0}:0, {x:1,y:0,z:0}:0, {x:1,y:1,z:0}:0, {x:2,y:0,z:0}:0, {x:2,y:1,z:0}:0, "+
- " {x:0,y:0,z:1}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:1}:1, {x:2,y:0,z:1}:0, {x:2,y:1,z:1}:00 }"),
+ " {x:0,y:0,z:1}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:1}:1, {x:2,y:0,z:1}:0, {x:2,y:1,z:1}:00 } "),
Tensor.diag(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build()));
assertEquals(Tensor.from("{ {x:1}:0, {x:3}:1, {x:9}:0 }"), Tensor.from("{ {x:1}:1, {x:3}:5, {x:9}:3 }").argmax("x"));
}