diff options
182 files changed, 1878 insertions, 677 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index a54e21aae68..2be3022ce6e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -108,7 +108,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement if (FeatureNames.isSimpleFeature(reference)) { // The argument may be a local identifier bound to the actual value String argument = reference.simpleArgument().get(); - reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument)); + String argumentBinding = getBinding(argument); + reference = Reference.simple(reference.name(), argumentBinding != null ? argumentBinding : argument); return featureTypes.get(reference); } @@ -152,7 +153,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement private Optional<String> boundIdentifier(Reference reference) { if ( ! reference.arguments().isEmpty()) return Optional.empty(); if ( reference.output() != null) return Optional.empty(); - return Optional.ofNullable(bindings.get(reference.name())); + return Optional.ofNullable(getBinding(reference.name())); } private Optional<ExpressionFunction> functionInvocation(Reference reference) { @@ -203,8 +204,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement Map<String, String> bindings = new HashMap<>(formalArguments.size()); for (int i = 0; i < formalArguments.size(); i++) { String identifier = invocationArguments.expressions().get(i).toString(); - identifier = super.bindings.getOrDefault(identifier, identifier); - bindings.put(formalArguments.get(i), identifier); + String identifierBinding = super.getBinding(identifier); + bindings.put(formalArguments.get(i), identifierBinding != null ? identifierBinding : identifier); } return bindings; } @@ -215,7 +216,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement @Override public MapEvaluationTypeContext withBindings(Map<String, String> bindings) { - if (bindings.isEmpty() && this.bindings.isEmpty()) return this; return new MapEvaluationTypeContext(functions(), bindings, featureTypes, currentResolutionCallStack); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java index bf585df9005..271442768a8 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java @@ -53,8 +53,8 @@ public class RankProfileRegistry { if (existingRangProfileWithSameName == null) return; if ( ! overridableRankProfileNames.contains(rankProfileName)) { - throw new IllegalArgumentException("Cannot add rank profile '" + rankProfileName + "' in search definition '" - + rankProfile.getSearch().getName() + "', since it already exists"); + throw new IllegalArgumentException("Duplicate rank profile '" + rankProfileName + "' in " + + rankProfile.getSearch()); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 6192db2654e..1a22b98fd9f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -324,7 +324,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { try { firstPhaseRanking = new RankingExpression(property.getValue()); } catch (ParseException e) { - throw new IllegalArgumentException("Could not parse second phase expression", e); + throw new IllegalArgumentException("Could not parse first phase expression", e); } } else if ("rankingExpression(secondphase).rankingScript".equals(property.getName())) { @@ -406,7 +406,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { return properties; } - private List<Pair<String, String>> deriveRankingPhaseRankProperties(RankingExpression expression, String phase) { + private List<Pair<String, String>> deriveRankingPhaseRankProperties(RankingExpression expression, String phase) { List<Pair<String, String>> properties = new ArrayList<>(); if (expression == null) return properties; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/FunctionShadower.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/FunctionShadower.java index c3ee7d5fc3d..bb2e20a4f05 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/FunctionShadower.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/FunctionShadower.java @@ -44,9 +44,8 @@ public class FunctionShadower extends ExpressionTransformer<RankProfileTransform private ExpressionNode transformFunctionNode(FunctionNode function, RankProfileTransformContext context) { String name = function.getFunction().toString(); RankProfile.RankingExpressionFunction rankingExpressionFunction = context.rankProfile().findFunction(name); - if (rankingExpressionFunction == null) { + if (rankingExpressionFunction == null) return transformChildren(function, context); - } int functionArity = function.getFunction().arity(); if (functionArity != rankingExpressionFunction.function().arguments().size()) diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index cebfa244159..554a36aef86 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -112,7 +112,7 @@ rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" rankprofile[].name "profile7" rankprofile[].fef.property[].name "rankingExpression(reshaped).rankingScript" -rankprofile[].fef.property[].value "tensor<float>(d0[1],x[2])({x:1 - x, y:d0})" +rankprofile[].fef.property[].value "tensor<float>(d0[1],x[2])(attribute(f2){x:1 - x, y:d0})" rankprofile[].fef.property[].name "rankingExpression(reshaped).type" rankprofile[].fef.property[].value "tensor<float>(d0[1],x[2])" rankprofile[].fef.property[].name "vespa.rank.firstphase" @@ -127,3 +127,34 @@ rankprofile[].fef.property[].name "vespa.type.attribute.f4" rankprofile[].fef.property[].value "tensor(x[10],y[20])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" +rankprofile[].name "profile8" +rankprofile[].fef.property[].name "rankingExpression(functionNotLabel).rankingScript" +rankprofile[].fef.property[].value "3" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(firstphase)" +rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" +rankprofile[].fef.property[].value "reduce(tensor(d0[1])(attribute{x:(rankingExpression(functionNotLabel))}), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f2" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" +rankprofile[].fef.property[].name "vespa.type.attribute.f3" +rankprofile[].fef.property[].value "tensor(x{})" +rankprofile[].fef.property[].name "vespa.type.attribute.f4" +rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].name "vespa.type.attribute.f5" +rankprofile[].fef.property[].value "tensor<float>(x[10])" +rankprofile[].name "profile9" +rankprofile[].fef.property[].name "rankingExpression(shadow).rankingScript" +rankprofile[].fef.property[].value "3" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(firstphase)" +rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" +rankprofile[].fef.property[].value "reduce(tensor(shadow[1])(attribute{x:shadow + rankingExpression(shadow)}), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f2" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" +rankprofile[].fef.property[].name "vespa.type.attribute.f3" +rankprofile[].fef.property[].value "tensor(x{})" +rankprofile[].fef.property[].name "vespa.type.attribute.f4" +rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].name "vespa.type.attribute.f5" +rankprofile[].fef.property[].value "tensor<float>(x[10])" + diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index 15d56517a43..c3380bed19c 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -90,4 +90,29 @@ search tensor { } + rank-profile profile8 { + + first-phase { + expression: sum(tensor(d0[1])(attribute{x:(functionNotLabel)})) + } + + function functionNotLabel() { + expression: 3 + } + + } + + rank-profile profile9 { + + # shadow refers to the generate index and shadow() to the function + first-phase { + expression: sum(tensor(shadow[1])(attribute{x: shadow + shadow() })) + } + + function shadow() { + expression: 3 + } + + } + } diff --git a/config/src/tests/api/api.cpp b/config/src/tests/api/api.cpp index 4db66761444..0af2b848ea5 100644 --- a/config/src/tests/api/api.cpp +++ b/config/src/tests/api/api.cpp @@ -32,7 +32,7 @@ TEST_MT_FFF("require that source may be unable to serve config temporarily", 2, ASSERT_TRUE(cfg.get() != NULL); ASSERT_EQUAL("myfoo", cfg->myField); } else { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); f3.myField = "myfoo"; f2.addBuilder("myid", &f3); f1->reload(); diff --git a/config/src/tests/configfetcher/configfetcher.cpp b/config/src/tests/configfetcher/configfetcher.cpp index 607ab0a29a5..be25e913980 100644 --- a/config/src/tests/configfetcher/configfetcher.cpp +++ b/config/src/tests/configfetcher/configfetcher.cpp @@ -69,7 +69,7 @@ TEST("requireThatConfigUpdatesArePerformed") { while (!cb._configured && timer.elapsed().ms() < 20000.0) { if (cb._configured) break; - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } ASSERT_TRUE(cb._configured); ASSERT_TRUE(cb._config); diff --git a/config/src/tests/configretriever/configretriever.cpp b/config/src/tests/configretriever/configretriever.cpp index fc921a324af..87f189ad7d3 100644 --- a/config/src/tests/configretriever/configretriever.cpp +++ b/config/src/tests/configretriever/configretriever.cpp @@ -251,7 +251,7 @@ public: if (configured) { return true; } - FastOS_Thread::Sleep(200); + std::this_thread::sleep_for(200ms); } return configured; } diff --git a/config/src/tests/frt/frt.cpp b/config/src/tests/frt/frt.cpp index ba8279a1999..28dea82bfe7 100644 --- a/config/src/tests/frt/frt.cpp +++ b/config/src/tests/frt/frt.cpp @@ -49,7 +49,7 @@ namespace { while (timer.elapsed().ms() < timeoutInMillis) { if (notified) break; - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return notified; } @@ -260,7 +260,7 @@ TEST_FF("require that request is config task is scheduled", SourceFixture(), FRT f1.conn.scheduler.CheckTasks(); if (f2.result.notified) break; - FastOS_Thread::Sleep(500); + std::this_thread::sleep_for(500ms); } ASSERT_TRUE(f2.result.notified); f2.src.close(); diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java index 3dabf9bc649..395d8853603 100644 --- a/container-search/src/main/java/com/yahoo/search/Query.java +++ b/container-search/src/main/java/com/yahoo/search/Query.java @@ -226,9 +226,8 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { } public static QueryProfileType getArgumentType() { return argumentType; } - /** The aliases of query properties */ - private static Map<String,CompoundName> propertyAliases; + private static Map<String, CompoundName> propertyAliases; static { Map<String,CompoundName> propertyAliasesBuilder = new HashMap<>(); addAliases(Query.getArgumentType(), propertyAliasesBuilder); diff --git a/container-search/src/main/java/com/yahoo/search/federation/FederationSearcher.java b/container-search/src/main/java/com/yahoo/search/federation/FederationSearcher.java index 499cb634295..6e36881ae63 100644 --- a/container-search/src/main/java/com/yahoo/search/federation/FederationSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/federation/FederationSearcher.java @@ -115,7 +115,8 @@ public class FederationSearcher extends ForkingSearcher { this(searchChainResolver, false, PropagateSourceProperties.ALL, null); } - private FederationSearcher(SearchChainResolver searchChainResolver, boolean strictSearchchain, + private FederationSearcher(SearchChainResolver searchChainResolver, + boolean strictSearchchain, PropagateSourceProperties.Enum propagateSourceProperties, TargetSelector targetSelector) { this.searchChainResolver = searchChainResolver; @@ -295,9 +296,11 @@ public class FederationSearcher extends ForkingSearcher { } } - private Object getSourceOrProviderProperty(Query query, CompoundName propertyName, - String sourceName, String providerName, - Object defaultValue) { + private Object getSourceOrProviderProperty(Query query, + CompoundName propertyName, + String sourceName, + String providerName, + Object defaultValue) { Object result = getProperty(query, new SourceKey(sourceName, propertyName.toString())); if (result == null) result = getProperty(query, new ProviderKey(providerName, propertyName.toString())); diff --git a/container-search/src/main/java/com/yahoo/search/federation/selection/FederationTarget.java b/container-search/src/main/java/com/yahoo/search/federation/selection/FederationTarget.java index 7ade9a0eaf9..8ccbe39cc5a 100644 --- a/container-search/src/main/java/com/yahoo/search/federation/selection/FederationTarget.java +++ b/container-search/src/main/java/com/yahoo/search/federation/selection/FederationTarget.java @@ -1,12 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.federation.selection; -import java.util.Optional; import com.yahoo.component.chain.Chain; import com.yahoo.search.Searcher; import com.yahoo.search.searchchain.model.federation.FederationOptions; -import static com.google.common.base.Preconditions.checkNotNull; +import java.util.Objects; /** * Represents a search chain that the federation searcher should send a query to, @@ -22,11 +21,8 @@ public final class FederationTarget<T> { private final T customData; public FederationTarget(Chain<Searcher> chain, FederationOptions federationOptions, T customData) { - checkNotNull(chain); - checkNotNull(federationOptions); - - this.chain = chain; - this.federationOptions = federationOptions; + this.chain = Objects.requireNonNull(chain, "chain cannot be null"); + this.federationOptions = Objects.requireNonNull(federationOptions, "federationOptions cannot be null"); this.customData = customData; } @@ -62,9 +58,7 @@ public final class FederationTarget<T> { @Override public int hashCode() { - int result = chain.hashCode(); - result = 31 * result + federationOptions.hashCode(); - return result; + return Objects.hash(chain, federationOptions); } } diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/PropertyAliases.java b/container-search/src/main/java/com/yahoo/search/query/properties/PropertyAliases.java index ac39e986ff0..a4a82d27f8e 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/PropertyAliases.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/PropertyAliases.java @@ -20,14 +20,14 @@ import java.util.Map; public class PropertyAliases extends Properties { /** A map from aliases to standard names */ - private final Map<String,CompoundName> aliases; + private final Map<String, CompoundName> aliases; /** * Creates an instance with a set of aliases. The given aliases will be used directly by this class. * To make this class immutable and thread safe, relinquish ownership of the parameter map. */ - public PropertyAliases(Map<String,CompoundName> aliases) { - this.aliases=aliases; + public PropertyAliases(Map<String, CompoundName> aliases) { + this.aliases = aliases; } /** @@ -42,20 +42,21 @@ public class PropertyAliases extends Properties { } @Override - public Map<String, Object> listProperties(CompoundName property,Map<String,String> context, - com.yahoo.processing.request.Properties substitution) { - return super.listProperties(unalias(property),context,substitution); + public Map<String, Object> listProperties(CompoundName property, + Map<String,String> context, + com.yahoo.processing.request.Properties substitution) { + return super.listProperties(unalias(property), context, substitution); } @Override - public Object get(CompoundName name,Map<String,String> context, + public Object get(CompoundName name, Map<String,String> context, com.yahoo.processing.request.Properties substitution) { return super.get(unalias(name),context,substitution); } @Override - public void set(CompoundName name,Object value,Map<String,String> context) { - super.set(unalias(name),value,context); + public void set(CompoundName name, Object value, Map<String,String> context) { + super.set(unalias(name), value, context); } } diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java index 4cdd4488f7b..c06c84fcc36 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java @@ -38,13 +38,13 @@ public class QueryProperties extends Properties { } public void setParentQuery(Query query) { - this.query=query; + this.query = query; super.setParentQuery(query); } - @SuppressWarnings("deprecation") @Override - public Object get(CompoundName key, Map<String,String> context, + public Object get(CompoundName key, + Map<String,String> context, com.yahoo.processing.request.Properties substitution) { if (key.size() == 2 && key.first().equals(Model.MODEL)) { Model model = query.getModel(); diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java index 95669f7f05d..26bf189dd3d 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java @@ -33,7 +33,8 @@ enum PathGroup { "/zone/v2/{*}"), /** Paths used for creating and reading user resources. */ - user("/application/v4/user", + user(Optional.of("/api"), + "/application/v4/user", "/athenz/v1/{*}"), /** Paths used for creating tenants with proper access control. */ @@ -176,7 +177,6 @@ enum PathGroup { "/deployment/v1/{*}", "/", "/d/{*}", - "/static/{*}", "/statuspage/v1/{*}"), /** Same as classifiedInfo, but with optional /api prefix */ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java index 26c4bf6292a..d10a4879bf5 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java @@ -37,6 +37,7 @@ import java.util.logging.Logger; public class AthenzApiHandler extends LoggingRequestHandler { private final static Logger log = Logger.getLogger(AthenzApiHandler.class.getName()); + private static final String OPTIONAL_PREFIX = "/api"; private final AthenzFacade athenz; private final AthenzDomain sandboxDomain; @@ -69,7 +70,7 @@ public class AthenzApiHandler extends LoggingRequestHandler { } private HttpResponse get(HttpRequest request) { - Path path = new Path(request.getUri()); + Path path = new Path(request.getUri(), OPTIONAL_PREFIX); if (path.matches("/athenz/v1")) return root(request); if (path.matches("/athenz/v1/domains")) return domainList(request); if (path.matches("/athenz/v1/properties")) return properties(); @@ -79,7 +80,7 @@ public class AthenzApiHandler extends LoggingRequestHandler { } private HttpResponse post(HttpRequest request) { - Path path = new Path(request.getUri()); + Path path = new Path(request.getUri(), OPTIONAL_PREFIX); if (path.matches("/athenz/v1/user")) return signup(request); return ErrorResponse.notFoundError(String.format("No '%s' handler at '%s'", request.getMethod(), request.getUri().getPath())); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java index d37df2cc313..c263054c808 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java @@ -127,6 +127,7 @@ public class ControllerContainerTest { " </handler>\n" + " <handler id='com.yahoo.vespa.hosted.controller.restapi.athenz.AthenzApiHandler'>\n" + " <binding>http://*/athenz/v1/*</binding>\n" + + " <binding>http://*/api/athenz/v1/*</binding>\n" + " </handler>\n" + " <handler id='com.yahoo.vespa.hosted.controller.restapi.zone.v1.ZoneApiHandler'>\n" + " <binding>http://*/zone/v1</binding>\n" + diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java index c90dcbf7e2b..34ee160c449 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java @@ -48,7 +48,7 @@ public class AthenzApiTest extends ControllerContainerTest { new File("property-list.json")); // POST user signup - tester.assertResponse(authenticatedRequest("http://localhost:8080/athenz/v1/user", "", Request.Method.POST), + tester.assertResponse(authenticatedRequest("http://localhost:8080/api/athenz/v1/user", "", Request.Method.POST), "{\"message\":\"User 'bob' added to admin role of 'vespa.vespa.tenants.sandbox'\"}"); } diff --git a/documentapi/src/tests/policies/policies_test.cpp b/documentapi/src/tests/policies/policies_test.cpp index 93c5d51fef5..0c3a39186a7 100644 --- a/documentapi/src/tests/policies/policies_test.cpp +++ b/documentapi/src/tests/policies/policies_test.cpp @@ -284,7 +284,7 @@ Test::assertMirrorReady(const slobrok::api::IMirrorAPI &mirror) if (mirror.ready()) { return; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } ASSERT_TRUE(false); } @@ -297,7 +297,7 @@ Test::assertMirrorContains(const slobrok::api::IMirrorAPI &mirror, const string if (mirror.lookup(pattern).size() == numEntries) { return; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } ASSERT_TRUE(false); } diff --git a/documentapi/src/tests/policies/testframe.cpp b/documentapi/src/tests/policies/testframe.cpp index 4cdc5d4ba14..1ca449816d9 100644 --- a/documentapi/src/tests/policies/testframe.cpp +++ b/documentapi/src/tests/policies/testframe.cpp @@ -8,6 +8,8 @@ #include <vespa/messagebus/testlib/simpleprotocol.h> #include <vespa/messagebus/testlib/simplereply.h> #include <vespa/messagebus/network/rpcnetworkparams.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".testframe"); @@ -297,7 +299,7 @@ TestFrame::waitSlobrok(const string &pattern, uint32_t cnt) if (res.size() == cnt) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } LOG(error, "Slobrok failed to resolve '%s' to %d recipients in time.", pattern.c_str(), cnt); return false; diff --git a/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp b/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp index 18dd525b066..e82a184d8b2 100644 --- a/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp +++ b/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp @@ -1,12 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "externslobrokpolicy.h" -#include <vespa/vespalib/text/stringtokenizer.h> #include <vespa/messagebus/routing/routingcontext.h> +#include <vespa/vespalib/text/stringtokenizer.h> +#include <vespa/vespalib/util/time.h> #include <vespa/fnet/frt/frt.h> #include <vespa/slobrok/sbmirror.h> #include <vespa/fnet/transport.h> -#include <vespa/fastos/thread.h> +#include <thread> using slobrok::api::IMirrorAPI; using slobrok::api::MirrorAPI; @@ -82,7 +83,7 @@ ExternSlobrokPolicy::lookup(mbus::RoutingContext& context, const string& pattern if (_firstTry) { int count = 0; while (entries.empty() && count < 100) { - FastOS_Thread::Sleep(50); + std::this_thread::sleep_for(50ms); entries = mirror.lookup(pattern); count++; } diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp index 6932ba46eab..6b9b52900c7 100644 --- a/eval/src/tests/eval/function/function_test.cpp +++ b/eval/src/tests/eval/function/function_test.cpp @@ -951,11 +951,22 @@ TEST("require that tensor peek can contain expressions") { TEST_DO(verify_parse("t{x:1,y:(foo)}", "f(t,foo)(t{x:1,y:(foo)})")); } +TEST("require that trivial tensor peek string/number expressions are converted to verbatim labels") { + TEST_DO(verify_parse("t{x:\"foo\"}", "f(t)(t{x:foo})")); + TEST_DO(verify_parse("t{x:(\"foo\")}", "f(t)(t{x:foo})")); + TEST_DO(verify_parse("t{x:5.5}", "f(t)(t{x:5})")); + TEST_DO(verify_parse("t{x:(5.5)}", "f(t)(t{x:5})")); +} + TEST("require that tensor peek can contain extra whitespace") { TEST_DO(verify_parse(" t { x : 1 + bar , y : foo + 2 } ", "f(t,bar,foo)(t{x:(1+bar),y:(foo+2)})")); } +TEST("require that converted tensor peek string expression must be valid identifier") { + TEST_DO(verify_error("x{a:\"5.5\"}", "[x{a:\"5.5\"]...[invalid identifier: '5.5']...[}]")); +} + TEST("require that empty tensor peek is not allowed") { TEST_DO(verify_error("x{}", "[x{}]...[empty peek spec]...[]")); } @@ -963,8 +974,8 @@ TEST("require that empty tensor peek is not allowed") { //----------------------------------------------------------------------------- TEST("require that nested tensor lambda using tensor peek can be parsed") { - vespalib::string expect("tensor(x[2]):{{x:0}:tensor(y[2]):{{y:0}:((0+0)+a),{y:1}:((0+1)+a)}{y:(0)}," - "{x:1}:tensor(y[2]):{{y:0}:((1+0)+a),{y:1}:((1+1)+a)}{y:(1)}}"); + vespalib::string expect("tensor(x[2]):{{x:0}:tensor(y[2]):{{y:0}:((0+0)+a),{y:1}:((0+1)+a)}{y:0}," + "{x:1}:tensor(y[2]):{{y:0}:((1+0)+a),{y:1}:((1+1)+a)}{y:1}}"); EXPECT_EQUAL(Function::parse(expect).dump(), expect); auto fun = Function::parse("tensor(x[2])(tensor(y[2])(x+y+a){y:(x)})"); EXPECT_EQUAL(fun.dump(), expect); diff --git a/eval/src/vespa/eval/eval/basic_nodes.h b/eval/src/vespa/eval/eval/basic_nodes.h index 52f39510a9b..62af30fb71f 100644 --- a/eval/src/vespa/eval/eval/basic_nodes.h +++ b/eval/src/vespa/eval/eval/basic_nodes.h @@ -122,7 +122,7 @@ public: String(const vespalib::string &value_in) : _value(value_in) {} bool is_const() const override { return true; } double get_const_value() const override { return hash(); } - const vespalib::string value() const { return _value; } + const vespalib::string &value() const { return _value; } uint32_t hash() const { return hash_code(_value.data(), _value.size()); } vespalib::string dump(DumpContext &ctx) const override; void accept(NodeVisitor &visitor) const override; diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp index a7eaf14e9cd..593451d6838 100644 --- a/eval/src/vespa/eval/eval/function.cpp +++ b/eval/src/vespa/eval/eval/function.cpp @@ -445,6 +445,14 @@ bool is_ident(char c, bool first) { (c == '$' && !first)); } +bool is_ident(const vespalib::string &str) { + bool result = str.empty() ? false : is_ident(str[0], true); + for (size_t i = 1; result && (i < str.size()); ++i) { + result &= is_ident(str[i], false); + } + return result; +} + vespalib::string get_ident(ParseContext &ctx, bool allow_empty) { ctx.skip_spaces(); vespalib::string ident; @@ -797,7 +805,19 @@ void parse_tensor_peek(ParseContext &ctx) { peek_spec.emplace(dim_name, label); } else { ctx.restore_input_mark(before_label); - peek_spec.emplace(dim_name, get_expression(ctx)); + auto expr = get_expression(ctx); + if (auto num = nodes::as<nodes::Number>(*expr)) { + size_t index(num->value()); + peek_spec.emplace(dim_name, make_string("%zu", index)); + } else if (auto str = nodes::as<nodes::String>(*expr)) { + if (is_ident(str->value())) { + peek_spec.emplace(dim_name, str->value()); + } else { + ctx.fail(make_string("invalid identifier: '%s'", str->value().c_str())); + } + } else { + peek_spec.emplace(dim_name, std::move(expr)); + } } } } diff --git a/fastos/src/tests/processtest.cpp b/fastos/src/tests/processtest.cpp index a6729dbb783..5a78eff1d36 100644 --- a/fastos/src/tests/processtest.cpp +++ b/fastos/src/tests/processtest.cpp @@ -3,6 +3,8 @@ #include <vespa/fastos/process.h> #include <vespa/fastos/timestamp.h> +using namespace std::chrono_literals; + class MyListener : public FastOS_ProcessRedirectListener { private: @@ -119,7 +121,7 @@ public: xproc->WriteStdin(nullptr, 0); } - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } if(i == 10) diff --git a/fastos/src/tests/thread_bounce_test.cpp b/fastos/src/tests/thread_bounce_test.cpp index f7bb7ee1260..84506938455 100644 --- a/fastos/src/tests/thread_bounce_test.cpp +++ b/fastos/src/tests/thread_bounce_test.cpp @@ -43,7 +43,7 @@ class Thread_Bounce_Test : public ThreadTestBase int left = static_cast<int>(checkTime.elapsed().ms()); while (left < 1000) { - FastOS_Thread::Sleep(1000 - left); + std::this_thread::sleep_for(std::chrono::milliseconds(1000 - left)); left = static_cast<int>(checkTime.elapsed().ms()); } diff --git a/fastos/src/tests/thread_mutex_test.cpp b/fastos/src/tests/thread_mutex_test.cpp index d49cf37163d..6d3f8c3c5f0 100644 --- a/fastos/src/tests/thread_mutex_test.cpp +++ b/fastos/src/tests/thread_mutex_test.cpp @@ -132,7 +132,7 @@ class Thread_Mutex_Test : public ThreadTestBase { bool lockrc; - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); for(int i=0; i<5; i++) { @@ -145,7 +145,7 @@ class Thread_Mutex_Test : public ThreadTestBase } } - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); lockrc = mtx.try_lock(); Progress(lockrc, "We should get the mutex lock now (%s)", diff --git a/fastos/src/tests/thread_sleep_test.cpp b/fastos/src/tests/thread_sleep_test.cpp index 7fd3412b7c3..209b7d3f880 100644 --- a/fastos/src/tests/thread_sleep_test.cpp +++ b/fastos/src/tests/thread_sleep_test.cpp @@ -20,7 +20,7 @@ class Thread_Sleep_Test : public ThreadTestBase Progress(rc, "Creating Thread"); Progress(true, "Sleeping 3 seconds"); - FastOS_Thread::Sleep(3000); + std::this_thread::sleep_for(3s); } Progress(true, "Closing threadpool..."); diff --git a/fastos/src/tests/thread_stats_test.cpp b/fastos/src/tests/thread_stats_test.cpp index 3633c12bcaa..a9d304d411f 100644 --- a/fastos/src/tests/thread_stats_test.cpp +++ b/fastos/src/tests/thread_stats_test.cpp @@ -31,7 +31,7 @@ class Thread_Stats_Test : public ThreadTestBase job[0].ownThread = pool.NewThread(this, static_cast<void *>(&job[0])); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads); @@ -44,7 +44,7 @@ class Thread_Stats_Test : public ThreadTestBase job[1].ownThread = pool.NewThread(this, static_cast<void *>(&job[1])); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads); @@ -57,7 +57,7 @@ class Thread_Stats_Test : public ThreadTestBase job[0].ownThread->SetBreakFlag(); job[1].ownThread->SetBreakFlag(); - FastOS_Thread::Sleep(3000); + std::this_thread::sleep_for(3s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 2, "Inactive threads = %d", inactiveThreads); @@ -72,7 +72,7 @@ class Thread_Stats_Test : public ThreadTestBase job[0].code = WAIT_FOR_BREAK_FLAG; job[0].ownThread = pool.NewThread(this, static_cast<void *>(&job[0])); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 1, "Inactive threads = %d", inactiveThreads); @@ -84,7 +84,7 @@ class Thread_Stats_Test : public ThreadTestBase job[1].code = WAIT_FOR_BREAK_FLAG; job[1].ownThread = pool.NewThread(this, static_cast<void *>(&job[1])); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads); @@ -97,7 +97,7 @@ class Thread_Stats_Test : public ThreadTestBase job[0].ownThread->SetBreakFlag(); job[1].ownThread->SetBreakFlag(); - FastOS_Thread::Sleep(3000); + std::this_thread::sleep_for(3s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 2, "Inactive threads = %d", inactiveThreads); diff --git a/fastos/src/tests/thread_test_base.hpp b/fastos/src/tests/thread_test_base.hpp index 7966e95b369..c4f7ed76ea7 100644 --- a/fastos/src/tests/thread_test_base.hpp +++ b/fastos/src/tests/thread_test_base.hpp @@ -3,6 +3,7 @@ #pragma once #include <chrono> +#include <thread> static volatile int64_t number; #define INCREASE_NUMBER_AMOUNT 10000 @@ -47,7 +48,7 @@ public: } } - FastOS_Thread::Sleep(500); + std::this_thread::sleep_for(500ms); if(threadsFinished) break; @@ -88,7 +89,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) Progress(true, "Thread printing message: [%s]", job->message); job->result = strlen(job->message); - FastOS_Thread::Sleep(3000); + std::this_thread::sleep_for(3s); break; } @@ -109,7 +110,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) number = number + 2; if(i == sleepOn) - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } guard = std::unique_lock<std::mutex>(); @@ -123,7 +124,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) { for(;;) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); if(thread->GetBreakFlag()) { @@ -192,7 +193,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) case WAIT2SEC_AND_SIGNALCOND: { - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); job->condition->notify_one(); job->result = 1; break; @@ -202,7 +203,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) { { std::lock_guard<std::mutex> guard(*job->mutex); - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); } job->result = 1; break; @@ -210,7 +211,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) case WAIT_2_SEC: { - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); job->result = 1; break; } diff --git a/fastos/src/tests/threadtest.cpp b/fastos/src/tests/threadtest.cpp index 9507bb1e5d7..0a8a0d2bf02 100644 --- a/fastos/src/tests/threadtest.cpp +++ b/fastos/src/tests/threadtest.cpp @@ -43,7 +43,7 @@ class ThreadTest : public ThreadTestBase if(waitingThreads == numWait) break; - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } } @@ -336,7 +336,7 @@ class ThreadTest : public ThreadTestBase // Threads are not guaranteed to have entered sleep yet, // as this test only tests for result code // Wait another second to be sure. - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } void SignalTest () diff --git a/fastos/src/vespa/fastos/thread.cpp b/fastos/src/vespa/fastos/thread.cpp index 3df8fa584a7..3e2f2674d97 100644 --- a/fastos/src/vespa/fastos/thread.cpp +++ b/fastos/src/vespa/fastos/thread.cpp @@ -352,17 +352,12 @@ void FastOS_ThreadInterface::Join () // FastOS_Runnable // ---------------------------------------------------------------------- -FastOS_Runnable::FastOS_Runnable(void) +FastOS_Runnable::FastOS_Runnable() : _thread(nullptr) { } -FastOS_Runnable::~FastOS_Runnable(void) +FastOS_Runnable::~FastOS_Runnable() { // assert(_thread == nullptr); } - -void FastOS_Runnable::Detach(void) -{ - _thread = nullptr; -} diff --git a/fastos/src/vespa/fastos/thread.h b/fastos/src/vespa/fastos/thread.h index c025a48d563..257acbc92d3 100644 --- a/fastos/src/vespa/fastos/thread.h +++ b/fastos/src/vespa/fastos/thread.h @@ -148,7 +148,7 @@ public: /** * Destructor. Closes pool if necessary. */ - virtual ~FastOS_ThreadPool(void); + virtual ~FastOS_ThreadPool(); /** @@ -168,9 +168,9 @@ public: * Get the stack size used for threads in this pool. * @return Stack size in bytes. */ - int GetStackSize(void) const { return _stackSize; } + int GetStackSize() const { return _stackSize; } - int GetStackGuardSize(void) const { return 0; } + int GetStackGuardSize() const { return 0; } /** * Close the threadpool. This involves setting the break flag on @@ -347,15 +347,7 @@ public: /** * Destructor. */ - virtual ~FastOS_ThreadInterface (){} - - /** - * Sleep for x milliseconds. Attempting to sleep for <1 milliseconds - * will result in failure. - * @param ms Number of milliseconds to sleep. - * @return Boolean success/failure - */ - static bool Sleep(int ms); + virtual ~FastOS_ThreadInterface () {} /** * Instruct a thread to exit. This could be used in conjunction with @@ -469,7 +461,7 @@ public: */ class FastOS_Runnable { -protected: +private: friend class FastOS_ThreadInterface; FastOS_ThreadInterface *_thread; @@ -498,10 +490,9 @@ public: */ virtual void Run(FastOS_ThreadInterface *thisThread, void *arguments)=0; - FastOS_ThreadInterface *GetThread(void) { return _thread; } - const FastOS_ThreadInterface *GetThread(void) const { return _thread; } - bool HasThread(void) const { return _thread != nullptr; } - void Detach(void); + FastOS_ThreadInterface *GetThread() { return _thread; } + const FastOS_ThreadInterface *GetThread() const { return _thread; } + bool HasThread() const { return _thread != nullptr; } }; #include <vespa/fastos/unix_thread.h> diff --git a/fastos/src/vespa/fastos/unix_process.cpp b/fastos/src/vespa/fastos/unix_process.cpp index 86d285059b8..4d4197f5354 100644 --- a/fastos/src/vespa/fastos/unix_process.cpp +++ b/fastos/src/vespa/fastos/unix_process.cpp @@ -40,6 +40,8 @@ extern char **environ; #endif +using namespace std::chrono_literals; + static pid_t safe_fork () { pid_t pid; @@ -1629,7 +1631,7 @@ FastOS_UNIX_ProcessStarter::Wait(FastOS_UNIX_Process *process, } } - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return rc; diff --git a/fastos/src/vespa/fastos/unix_thread.cpp b/fastos/src/vespa/fastos/unix_thread.cpp index 5218bde2630..9e48727deb3 100644 --- a/fastos/src/vespa/fastos/unix_thread.cpp +++ b/fastos/src/vespa/fastos/unix_thread.cpp @@ -83,18 +83,6 @@ FastOS_UNIX_Thread::~FastOS_UNIX_Thread() } } -bool FastOS_UNIX_Thread::Sleep (int ms) -{ - bool rc=false; - - if (ms > 0) { - usleep(ms*1000); - rc = true; - } - - return rc; -} - FastOS_ThreadId FastOS_UNIX_Thread::GetThreadId () { return _handle; diff --git a/fastos/src/vespa/fastos/unix_thread.h b/fastos/src/vespa/fastos/unix_thread.h index c6e0b040fc7..35df3f5745f 100644 --- a/fastos/src/vespa/fastos/unix_thread.h +++ b/fastos/src/vespa/fastos/unix_thread.h @@ -36,7 +36,6 @@ public: ~FastOS_UNIX_Thread(); - static bool Sleep (int ms); FastOS_ThreadId GetThreadId () override; static bool CompareThreadIds (FastOS_ThreadId a, FastOS_ThreadId b); static FastOS_ThreadId GetCurrentThreadId (); diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 93df3722da5..b651c5c1a22 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -166,6 +166,12 @@ public class Flags { "Takes effect on restart of config server", NODE_TYPE, HOSTNAME); + public static final UnboundBooleanFlag USE_TLS_FOR_ZOOKEEPER_CLIENT = defineFeatureFlag( + "use-tls-for-zookeeper-client", false, + "Whether to use TLS for ZooKeeper clients", + "Takes effect on restart of process", + NODE_TYPE, HOSTNAME); + public static final UnboundBooleanFlag USE_OLD_METRICS_CHECKS = defineFeatureFlag( "use-old-metrics-checks", true, "Whether to use old metrics checks", diff --git a/fnet/src/examples/timeout/timeout.cpp b/fnet/src/examples/timeout/timeout.cpp index 1d6ecc11909..23dfbeb9070 100644 --- a/fnet/src/examples/timeout/timeout.cpp +++ b/fnet/src/examples/timeout/timeout.cpp @@ -2,7 +2,8 @@ #include <vespa/fnet/fnet.h> #include <vespa/fastos/app.h> -#include <chrono> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("timeout"); @@ -55,7 +56,7 @@ MyApp::Main() transport.Start(&pool); // stable-state operation - FastOS_Thread::Sleep(500); + std::this_thread::sleep_for(500ms); FNET_Packet *packet; FNET_Context context; @@ -64,7 +65,7 @@ MyApp::Main() t = clock::now(); timeout.Schedule(2.0); // timeout in 2 seconds - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); timeout.Unschedule(); // cancel timeout ms = (clock::now() - t); diff --git a/fnet/src/tests/frt/rpc/detach_return_invoke.cpp b/fnet/src/tests/frt/rpc/detach_return_invoke.cpp index 43a61cd9bcd..95dbe672909 100644 --- a/fnet/src/tests/frt/rpc/detach_return_invoke.cpp +++ b/fnet/src/tests/frt/rpc/detach_return_invoke.cpp @@ -54,7 +54,7 @@ TEST("detach return invoke") { if (receptor.req != 0) { break; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } req->SubRef(); target->SubRef(); diff --git a/fnet/src/vespa/fnet/frt/invoker.h b/fnet/src/vespa/fnet/frt/invoker.h index 64adf66688e..0838ef84dd3 100644 --- a/fnet/src/vespa/fnet/frt/invoker.h +++ b/fnet/src/vespa/fnet/frt/invoker.h @@ -5,7 +5,6 @@ #include "rpcrequest.h" #include <vespa/fnet/task.h> #include <vespa/fnet/ipackethandler.h> -#include <vespa/fastos/thread.h> #include <mutex> #include <condition_variable> diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateAdapter.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateAdapter.java index 7296e4ce61e..bb89ce736f7 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateAdapter.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateAdapter.java @@ -207,7 +207,9 @@ public class FieldUpdateAdapter implements UpdateAdapter { for (Iterator<FieldValue> it = arr.fieldValueIterator(); it.hasNext();) { FieldValue childVal = it.next(); for (ValueUpdate childUpd : createValueUpdates(childVal, upd.getUpdate())) { - ret.add(new MapValueUpdate(childVal, childUpd)); + // The array update is always directed towards a particular array index, which is + // kept as the _value_ in the original update. + ret.add(new MapValueUpdate(upd.getValue(), childUpd)); } } return ret; diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java index e25af74333d..e51f7984d65 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java @@ -88,7 +88,15 @@ public abstract class FieldUpdateHelper { return val; } else if (upd instanceof MapValueUpdate) { if (val instanceof Array) { - return createFieldValue(val, ((MapValueUpdate)upd).getUpdate()); + var nestedUpdate = ((MapValueUpdate)upd).getUpdate(); + if (nestedUpdate instanceof AssignValueUpdate) { + // Can't assign an array's value type directly to the array, so we have to add it as a + // singular element to the partial document. + ((Array)val).add(nestedUpdate.getValue()); + return val; + } else { + return createFieldValue(val, nestedUpdate); + } } else if (val instanceof MapFieldValue) { throw new UnsupportedOperationException("Can not map into a " + val.getClass().getName() + "."); } else if (val instanceof StructuredFieldValue) { diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/DocumentUpdateTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/DocumentUpdateTestCase.java index beed3053692..a6362e71594 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/DocumentUpdateTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/DocumentUpdateTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.document.update.AddValueUpdate; import com.yahoo.document.update.AssignValueUpdate; import com.yahoo.document.update.FieldUpdate; import com.yahoo.document.update.ValueUpdate; +import com.yahoo.document.update.MapValueUpdate; import com.yahoo.vespa.indexinglanguage.expressions.Expression; import com.yahoo.vespa.indexinglanguage.parser.ParseException; import org.junit.Test; @@ -66,14 +67,19 @@ public class DocumentUpdateTestCase { assertTrue(upd.getCreateIfNonExistent()); } - @Test - public void assign_updates_to_structs_are_preserved() throws ParseException { - var docType = new DocumentType("my_input"); + private static StructDataType makeStructType() { var structType = new StructDataType("foobarstruct"); var fooField = new Field("foo", DataType.STRING); var barField = new Field("bar", DataType.STRING); structType.addField(fooField); structType.addField(barField); + return structType; + } + + @Test + public void assign_updates_to_structs_are_preserved() throws ParseException { + var docType = new DocumentType("my_input"); + var structType = makeStructType(); docType.addField(new Field("mystruct", structType)); var upd = new DocumentUpdate(docType, "id:scheme:my_input::"); @@ -91,4 +97,59 @@ public class DocumentUpdateTestCase { var av = (AssignValueUpdate)valueUpdate; assertEquals(av.getValue(), updatedStruct); } + + @Test + public void assign_matched_array_of_structs_element_update_is_preserved() throws ParseException { + var docType = new DocumentType("my_input"); + var structType = makeStructType(); + var arrayType = ArrayDataType.getArray(structType); + docType.addField(new Field("my_array", arrayType)); + + var updatedStruct = new Struct(structType); + updatedStruct.setFieldValue("foo", new StringFieldValue("new groovy value")); + updatedStruct.setFieldValue("bar", new StringFieldValue("totally tubular!")); + + var upd = new DocumentUpdate(docType, "id:scheme:my_input::"); + var assignUpdate = ValueUpdate.createAssign(updatedStruct); + upd.addFieldUpdate(FieldUpdate.createMap(docType.getField("my_array"), + new IntegerFieldValue(2), assignUpdate)); + + upd = Expression.execute(Expression.fromString("input my_array | passthrough my_array"), upd); + + assertEquals(upd.fieldUpdates().size(), 1); + var fieldUpdate = upd.getFieldUpdate("my_array"); + assertNotNull(fieldUpdate); + var valueUpdate = fieldUpdate.getValueUpdate(0); + assertTrue(valueUpdate instanceof MapValueUpdate); + var mvu = (MapValueUpdate)valueUpdate; + assertEquals(mvu.getValue(), new IntegerFieldValue(2)); + assertEquals(mvu.getUpdate(), assignUpdate); + } + + @Test + public void assign_matched_array_of_primitives_element_update_is_preserved() throws ParseException { + var docType = new DocumentType("my_input"); + var arrayType = ArrayDataType.getArray(DataType.INT); + docType.addField(new Field("my_array", arrayType)); + + var upd = new DocumentUpdate(docType, "id:scheme:my_input::"); + // Use an unreasonably large array index to ensure nothing creates an implicit array under the + // hood when processing the update itself. "Ensure" here means "the test will most likely OOM + // and we'll notice it pretty quickly". + var arrayIndex = new IntegerFieldValue(2_000_000_000); + var assignUpdate = ValueUpdate.createAssign(new IntegerFieldValue(12345)); + upd.addFieldUpdate(FieldUpdate.createMap(docType.getField("my_array"), arrayIndex, assignUpdate)); + + upd = Expression.execute(Expression.fromString("input my_array | passthrough my_array"), upd); + + assertEquals(upd.fieldUpdates().size(), 1); + var fieldUpdate = upd.getFieldUpdate("my_array"); + assertNotNull(fieldUpdate); + var valueUpdate = fieldUpdate.getValueUpdate(0); + assertTrue(valueUpdate instanceof MapValueUpdate); + var mvu = (MapValueUpdate)valueUpdate; + assertEquals(mvu.getValue(), arrayIndex); + assertEquals(mvu.getUpdate(), assignUpdate); + } + } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ValueUpdateToDocumentTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ValueUpdateToDocumentTestCase.java index b9be7ddbe50..2468dbe5003 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ValueUpdateToDocumentTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ValueUpdateToDocumentTestCase.java @@ -138,6 +138,32 @@ public class ValueUpdateToDocumentTestCase { } @Test + public void array_of_struct_assign_is_converted() { + DocumentType docType = new DocumentType("my_type"); + StructDataType structType = new StructDataType("my_struct"); + structType.addField(new Field("a", DataType.INT)); + ArrayDataType arrType = DataType.getArray(structType); + Field field = new Field("my_arr", arrType); + docType.addField(field); + + var updatedStruct = new Struct(structType); + updatedStruct.setFieldValue("a", new IntegerFieldValue(42)); + ValueUpdate update = ValueUpdate.createMap(new IntegerFieldValue(2), ValueUpdate.createAssign(updatedStruct)); + + Document doc = FieldUpdateHelper.newPartialDocument(docType, new DocumentId("id:foo:my_type::1"), field, update); + assertNotNull(doc); + // Due to how the roller coaster ride of the partial documents appear to work, we end up creating + // a document with an array that only contains the to-be-updated element, but always at the first + // index rather than the arbitrary updated index (which is good because otherwise it'd have to + // be pre-allocated). + FieldValue obj = doc.getFieldValue("my_arr"); + assertTrue(obj instanceof Array); + Array arr = (Array)obj; + assertEquals(1, arr.size()); + assertEquals(updatedStruct, arr.get(0)); + } + + @Test public void requireThatRemoveIsConverted() { DocumentType docType = new DocumentType("my_type"); ArrayDataType arrType = DataType.getArray(DataType.INT); diff --git a/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp b/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp index 40c54980c11..cd1ad7e6eed 100644 --- a/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp +++ b/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp @@ -2,6 +2,8 @@ #include <vespa/fastos/app.h> #include <vespa/fnet/frt/frt.h> +#include <vespa/vespalib/util/time.h> +#include <thread> class RPCInfo : public FastOS_Application { @@ -85,7 +87,7 @@ public: if (info->GetErrorCode() != FRTE_RPC_CONNECTION) { break; } - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); target->SubRef(); target = supervisor.GetTarget(_argv[1]); } diff --git a/messagebus/src/tests/context/context.cpp b/messagebus/src/tests/context/context.cpp index de9dd1b83a6..c71357d09ad 100644 --- a/messagebus/src/tests/context/context.cpp +++ b/messagebus/src/tests/context/context.cpp @@ -77,7 +77,7 @@ Test::Main() if (queue.size() == 3) { break; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } EXPECT_EQUAL(queue.size(), 3u); { diff --git a/messagebus/src/tests/loadbalance/loadbalance.cpp b/messagebus/src/tests/loadbalance/loadbalance.cpp index 2f510d98ff1..05ea6d78871 100644 --- a/messagebus/src/tests/loadbalance/loadbalance.cpp +++ b/messagebus/src/tests/loadbalance/loadbalance.cpp @@ -78,7 +78,7 @@ Test::Main() if (queue.size() == msgCnt) { break; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } EXPECT_TRUE(queue.size() == msgCnt); EXPECT_TRUE(h1.cnt == msgCnt / 3); diff --git a/messagebus/src/tests/messagebus/messagebus.cpp b/messagebus/src/tests/messagebus/messagebus.cpp index 7434941a900..d9c6e438523 100644 --- a/messagebus/src/tests/messagebus/messagebus.cpp +++ b/messagebus/src/tests/messagebus/messagebus.cpp @@ -43,7 +43,7 @@ struct Base { if (queue.size() == size) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } @@ -270,7 +270,7 @@ Test::testSendToCol() } } client->waitQueueSize(300); - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); client->waitQueueSize(300); while (client->queue.size() > 0) { Routable::UP reply = client->queue.dequeue(0); @@ -347,7 +347,7 @@ Test::testSendToAnyThenCol() } } client->waitQueueSize(300); - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); client->waitQueueSize(300); while (client->queue.size() > 0) { Routable::UP reply = client->queue.dequeue(0); diff --git a/messagebus/src/tests/messageordering/messageordering.cpp b/messagebus/src/tests/messageordering/messageordering.cpp index 520c3d3dea3..481b8bbd270 100644 --- a/messagebus/src/tests/messageordering/messageordering.cpp +++ b/messagebus/src/tests/messageordering/messageordering.cpp @@ -167,7 +167,7 @@ Test::Main() const int messageCount = 5000; for (int i = 0; i < messageCount; ++i) { vespalib::string str(vespalib::make_string("%d", i)); - //FastOS_Thread::Sleep(1); + //std::this_thread::sleep_for(1ms); auto msg = std::make_unique<SimpleMessage>(str, true, commonMessageId); msg->getTrace().setLevel(9); //LOG(debug, "Sending message %p for %d", msg.get(), i); diff --git a/messagebus/src/tests/serviceaddress/serviceaddress.cpp b/messagebus/src/tests/serviceaddress/serviceaddress.cpp index ac43cec3c02..441da5a80ac 100644 --- a/messagebus/src/tests/serviceaddress/serviceaddress.cpp +++ b/messagebus/src/tests/serviceaddress/serviceaddress.cpp @@ -82,7 +82,7 @@ Test::waitSlobrok(RPCNetwork &network, const string &pattern, size_t num) if (res.size() == num) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } diff --git a/messagebus/src/tests/slobrok/slobrok.cpp b/messagebus/src/tests/slobrok/slobrok.cpp index 7e0718283a6..439ee0b23b5 100644 --- a/messagebus/src/tests/slobrok/slobrok.cpp +++ b/messagebus/src/tests/slobrok/slobrok.cpp @@ -51,7 +51,7 @@ compare(const IMirrorAPI &api, const string &pattern, SpecList expect) if (actual == expect) { return true; } - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return false; } diff --git a/messagebus/src/tests/sourcesession/sourcesession.cpp b/messagebus/src/tests/sourcesession/sourcesession.cpp index 5177cf0e799..de04715a060 100644 --- a/messagebus/src/tests/sourcesession/sourcesession.cpp +++ b/messagebus/src/tests/sourcesession/sourcesession.cpp @@ -35,7 +35,7 @@ struct DelayedHandler : public IMessageHandler // this will block the transport thread in the server messagebus, // but that should be ok, as we only want to test the timing in the // client messagebus... - FastOS_Thread::Sleep(delay); + std::this_thread::sleep_for(std::chrono::milliseconds(delay)); session->acknowledge(std::move(msg)); } }; @@ -59,7 +59,7 @@ bool waitQueueSize(RoutableQueue &queue, uint32_t size) { if (queue.size() == size) { return true; } - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } return false; } @@ -99,7 +99,7 @@ Test::testSequencing() EXPECT_TRUE(ss->send(Message::UP(new SimpleMessage("foo", true, 2)), "dst").isAccepted()); EXPECT_TRUE(ss->send(Message::UP(new SimpleMessage("foo", true, 1)), "dst").isAccepted()); EXPECT_TRUE(waitQueueSize(dstQ, 2)); - FastOS_Thread::Sleep(250); + std::this_thread::sleep_for(250ms); EXPECT_TRUE(waitQueueSize(dstQ, 2)); EXPECT_TRUE(waitQueueSize(srcQ, 0)); ds->acknowledge(Message::UP((Message*)dstQ.dequeue(0).release())); diff --git a/messagebus/src/tests/throttling/throttling.cpp b/messagebus/src/tests/throttling/throttling.cpp index 5d3525e8ba6..76bba89b72d 100644 --- a/messagebus/src/tests/throttling/throttling.cpp +++ b/messagebus/src/tests/throttling/throttling.cpp @@ -51,7 +51,7 @@ bool waitQueueSize(RoutableQueue &queue, uint32_t size) if (queue.size() == size) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } @@ -62,7 +62,7 @@ bool waitPending(SourceSession& session, uint32_t size) if (session.getPendingCount() == size) { return true; } - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } return false; } diff --git a/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp b/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp index 5ae6b07c3fa..280250a5119 100644 --- a/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp +++ b/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp @@ -17,6 +17,7 @@ #include <vespa/fnet/scheduler.h> #include <vespa/fnet/transport.h> #include <vespa/fnet/frt/supervisor.h> +#include <vespa/fastos/thread.h> #include <thread> #include <vespa/log/log.h> diff --git a/messagebus/src/vespa/messagebus/testlib/testserver.cpp b/messagebus/src/vespa/messagebus/testlib/testserver.cpp index bbd23d52c0b..e0c6d6a756d 100644 --- a/messagebus/src/vespa/messagebus/testlib/testserver.cpp +++ b/messagebus/src/vespa/messagebus/testlib/testserver.cpp @@ -4,6 +4,8 @@ #include "slobrok.h" #include "slobrokstate.h" #include <vespa/vespalib/component/vtag.h> +#include <vespa/vespalib/util/time.h> +#include <thread> namespace mbus { @@ -59,7 +61,7 @@ TestServer::waitState(const SlobrokState &slobrokState) if (done) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } diff --git a/messagebus_test/src/tests/error/cpp-client.cpp b/messagebus_test/src/tests/error/cpp-client.cpp index f186be68d01..147052e0701 100644 --- a/messagebus_test/src/tests/error/cpp-client.cpp +++ b/messagebus_test/src/tests/error/cpp-client.cpp @@ -7,6 +7,8 @@ #include <vespa/messagebus/rpcmessagebus.h> #include <vespa/messagebus/network/rpcnetworkparams.h> #include <vespa/messagebus/testlib/receptor.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/fastos/app.h> using namespace mbus; @@ -45,7 +47,7 @@ App::Main() break; } } - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } if (reply.get() == 0) { fprintf(stderr, "CPP-CLIENT: no reply\n"); diff --git a/messagebus_test/src/tests/error/cpp-server.cpp b/messagebus_test/src/tests/error/cpp-server.cpp index d5200ed20c1..383f703317e 100644 --- a/messagebus_test/src/tests/error/cpp-server.cpp +++ b/messagebus_test/src/tests/error/cpp-server.cpp @@ -6,6 +6,8 @@ #include <vespa/messagebus/network/rpcnetworkparams.h> #include <vespa/messagebus/emptyreply.h> #include <vespa/messagebus/errorcode.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/fastos/app.h> using namespace mbus; @@ -55,7 +57,7 @@ App::Main() "file:routing.cfg"); Server server(mb.getMessageBus()); while (true) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } return 0; } diff --git a/messagebus_test/src/tests/speed/cpp-client.cpp b/messagebus_test/src/tests/speed/cpp-client.cpp index 43d030b519b..ff00128037a 100644 --- a/messagebus_test/src/tests/speed/cpp-client.cpp +++ b/messagebus_test/src/tests/speed/cpp-client.cpp @@ -7,6 +7,8 @@ #include <vespa/messagebus/testlib/simplemessage.h> #include <vespa/messagebus/testlib/simpleprotocol.h> #include <vespa/messagebus/testlib/simplereply.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/fastos/timestamp.h> #include <vespa/fastos/app.h> @@ -100,7 +102,7 @@ App::Main() Client client(mb.getMessageBus(), SourceSessionParams().setTimeout(30s)); // let the system 'warm up' - FastOS_Thread::Sleep(5000); + std::this_thread::sleep_for(5s); // inject messages into the feedback loop for (uint32_t i = 0; i < 1024; ++i) { @@ -108,7 +110,7 @@ App::Main() } // let the system 'warm up' - FastOS_Thread::Sleep(5000); + std::this_thread::sleep_for(5s); fastos::StopWatch stopWatch; uint32_t okBefore = 0; @@ -117,7 +119,7 @@ App::Main() uint32_t failAfter = 0; client.sample(okBefore, failBefore); - FastOS_Thread::Sleep(10000); // Benchmark time + std::this_thread::sleep_for(10s); // Benchmark time fastos::TimeStamp elapsed = stopWatch.elapsed(); client.sample(okAfter, failAfter); double time = elapsed.ms(); diff --git a/messagebus_test/src/tests/speed/cpp-server.cpp b/messagebus_test/src/tests/speed/cpp-server.cpp index 82b884c46f2..a1aa5a5029c 100644 --- a/messagebus_test/src/tests/speed/cpp-server.cpp +++ b/messagebus_test/src/tests/speed/cpp-server.cpp @@ -6,6 +6,8 @@ #include <vespa/messagebus/testlib/simpleprotocol.h> #include <vespa/messagebus/rpcmessagebus.h> #include <vespa/messagebus/network/rpcnetworkparams.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/fastos/app.h> using namespace mbus; @@ -62,7 +64,7 @@ App::Main() "file:routing.cfg"); Server server(mb.getMessageBus()); while (true) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } return 0; } diff --git a/messagebus_test/src/tests/trace/cpp-server.cpp b/messagebus_test/src/tests/trace/cpp-server.cpp index d6db86070b1..75f4ee3a002 100644 --- a/messagebus_test/src/tests/trace/cpp-server.cpp +++ b/messagebus_test/src/tests/trace/cpp-server.cpp @@ -5,6 +5,8 @@ #include <vespa/messagebus/rpcmessagebus.h> #include <vespa/messagebus/network/rpcnetworkparams.h> #include <vespa/messagebus/emptyreply.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/fastos/app.h> using namespace mbus; @@ -73,7 +75,7 @@ App::Main() "file:routing.cfg"); Server server(mb.getMessageBus(), _argv[1]); while (true) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } return 0; } diff --git a/messagebus_test/src/tests/trace/trace.cpp b/messagebus_test/src/tests/trace/trace.cpp index a804bef6785..48c8d4afab0 100644 --- a/messagebus_test/src/tests/trace/trace.cpp +++ b/messagebus_test/src/tests/trace/trace.cpp @@ -36,7 +36,7 @@ waitSlobrok(RPCMessageBus &mbus, const std::string &pattern) if (res.size() > 0) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } @@ -112,7 +112,7 @@ Test::Main() } } std::cout << "Attempt " << i << " got errors, retrying in 1 second.." << std::endl; - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } EXPECT_TRUE(!reply->hasErrors()); diff --git a/metrics/src/tests/metricmanagertest.cpp b/metrics/src/tests/metricmanagertest.cpp index 1d954a641b6..6407bb73ecb 100644 --- a/metrics/src/tests/metricmanagertest.cpp +++ b/metrics/src/tests/metricmanagertest.cpp @@ -10,8 +10,10 @@ #include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/xmlstream.h> -#include <vespa/log/log.h> +#include <vespa/vespalib/util/time.h> +#include <thread> +#include <vespa/log/log.h> LOG_SETUP(".test.metricmanager"); namespace metrics { @@ -386,7 +388,7 @@ bool waitForTimeProcessed(const MetricManager& mm, while (time(0) < lastchance) { if (mm.getLastProcessedTime() >= processtime) return true; mm.timeChangedNotification(); - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } diff --git a/metrics/src/tests/stresstest.cpp b/metrics/src/tests/stresstest.cpp index 4a6d2f4a2ea..f3e709b4b04 100644 --- a/metrics/src/tests/stresstest.cpp +++ b/metrics/src/tests/stresstest.cpp @@ -4,6 +4,8 @@ #include <vespa/metrics/metricmanager.h> #include <vespa/metrics/metrics.h> #include <vespa/metrics/summetric.hpp> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/log/log.h> @@ -39,7 +41,7 @@ InnerMetricSet::InnerMetricSet(const char* name, const LoadTypeSet& lt, MetricSe _valueSum.addMetricToSum(_value1); _valueSum.addMetricToSum(_value2); } -InnerMetricSet::~InnerMetricSet() { } +InnerMetricSet::~InnerMetricSet() = default; MetricSet* InnerMetricSet::clone(std::vector<Metric::UP> &ownerList, CopyType copyType, @@ -133,11 +135,10 @@ TEST(StressTest, test_stress) FastOS_ThreadPool threadPool(256 * 1024); std::vector<Hammer::UP> hammers; for (uint32_t i=0; i<10; ++i) { - hammers.push_back(Hammer::UP( - new Hammer(metrics, loadTypes, threadPool))); + hammers.push_back(std::make_unique<Hammer>(metrics, loadTypes, threadPool)); } LOG(info, "Waiting to let loadgivers hammer a while"); - FastOS_Thread::Sleep(5 * 1000); + std::this_thread::sleep_for(5s); LOG(info, "Removing loadgivers"); hammers.clear(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index 6c583d960bd..14aa3ebf84e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -70,7 +70,7 @@ public class IntermediateGraph { return operations; } - void optimize() { + public void optimize() { renameDimensions(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index 280fe354149..55f5d979ea8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -3,11 +3,13 @@ package ai.vespa.rankingexpression.importer.onnx; import ai.vespa.rankingexpression.importer.operations.Gemm; +import ai.vespa.rankingexpression.importer.operations.ConcatReduce; import ai.vespa.rankingexpression.importer.operations.OnnxConcat; import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; import ai.vespa.rankingexpression.importer.operations.Softmax; import ai.vespa.rankingexpression.importer.operations.Squeeze; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; @@ -21,6 +23,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import ai.vespa.rankingexpression.importer.operations.NoOp; import ai.vespa.rankingexpression.importer.operations.Reshape; import ai.vespa.rankingexpression.importer.operations.Shape; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; @@ -36,24 +39,37 @@ import java.util.stream.Collectors; */ class GraphImporter { + private static final Value eluAlpha = DoubleValue.frozen(1.0); + private static final Value seluAlpha = DoubleValue.frozen(1.6732632423543772848170429916717); + private static final Value seluGamma = DoubleValue.frozen(1.0507009873554804934193349852946); + private static final Value leakyReluAlpha = DoubleValue.frozen(0.01); + private static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs, IntermediateGraph graph) { + String type = node.getOpType(); String modelName = graph.name(); String nodeName = getNodeName(node); AttributeConverter attributes = AttributeConverter.convert(node); + return mapOperation(type, inputs, modelName, nodeName, attributes); + } - switch (node.getOpType().toLowerCase()) { + static IntermediateOperation mapOperation(String opType, + List<IntermediateOperation> inputs, + String modelName, + String nodeName, + AttributeConverter attributes) { + switch (opType.toLowerCase()) { case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); - case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes); case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); - case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble())); case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); @@ -63,23 +79,31 @@ class GraphImporter { case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less()); case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log()); case "matmul": return new MatMul(modelName, nodeName, inputs); - case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); - case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min()); - case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean()); + case "max": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.max); + case "min": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.min); + case "mean": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.avg); case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); - case "reshape": return new Reshape(modelName, nodeName, inputs); - case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum); + case "reshape": return new Reshape(modelName, nodeName, inputs, attributes); + case "reducel1": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.abs(), null); + case "reducel2": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), ScalarFunctions.sqrt()); + case "reducelogsum":return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, null, ScalarFunctions.log()); + case "reducelogsumexp": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.exp(), ScalarFunctions.log()); + case "reducemax": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.max); case "reducemean": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.avg); + case "reducemin": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.min); + case "reduceprod": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.prod); + case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum); + case "reducesumsquare": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), null); case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); - case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); - case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu(attributes.get("gamma").orElse(seluGamma).asDouble(), attributes.get("alpha").orElse(seluAlpha).asDouble())); + case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu(attributes.get("alpha").orElse(leakyReluAlpha).asDouble())); case "shape": return new Shape(modelName, nodeName, inputs); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); - case "softmax": return new Softmax(modelName, nodeName, inputs); + case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); @@ -90,7 +114,7 @@ class GraphImporter { } IntermediateOperation op = new NoOp(modelName, nodeName, inputs); - op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); + op.warning("Operation '" + opType + "' is currently not implemented"); return op; } @@ -260,5 +284,4 @@ class GraphImporter { "Either no explicit name given or no single output name."); } - } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java new file mode 100644 index 00000000000..ea6bb2eaf99 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java @@ -0,0 +1,76 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +public class ConcatReduce extends IntermediateOperation { + + private final static String tmpDimensionName = "__concat_reduce_tmp_dimension_name__"; + private final Reduce.Aggregator aggregator; + + public ConcatReduce(String modelName, String nodeName, List<IntermediateOperation> inputs, Reduce.Aggregator aggregator) { + super(modelName, nodeName, inputs); + this.aggregator = aggregator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(inputs.size())) return null; + return inputs.get(0).type().get(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(inputs.size())) return null; + + TensorFunction result = inputs.get(0).function().get(); + for (int i = 1; i < inputs.size(); ++i) { + TensorFunction b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat(result, b, tmpDimensionName); + } + return new com.yahoo.tensor.functions.Reduce(result, aggregator, tmpDimensionName); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if ( ! allInputTypesPresent(inputs.size())) return; + + OrderedTensorType a = inputs.get(0).type().get(); + for (int i = 1; i < inputs.size(); ++i) { + OrderedTensorType b = inputs.get(i).type().get(); + + OrderedTensorType largest = largestInput(a, b); + OrderedTensorType smallest = smallestInput(a, b); + + int sizeDifference = largest.rank() - smallest.rank(); + for (int j = 0; j < smallest.rank(); ++j) { + String bDim = smallest.dimensions().get(j).name(); + String aDim = largest.dimensions().get(j + sizeDifference).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); + } + a = b; + } + } + + private OrderedTensorType largestInput(OrderedTensorType a, OrderedTensorType b) { + return a.rank() >= b.rank() ? a : b; + } + + private OrderedTensorType smallestInput(OrderedTensorType a, OrderedTensorType b) { + return a.rank() < b.rank() ? a : b; + } + + @Override + public ConcatReduce withInputs(List<IntermediateOperation> inputs) { + return new ConcatReduce(modelName(), name(), inputs, aggregator); + } + + @Override + public String operationName() { return "ConcatReduce"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index f091ae165d1..3fba8680332 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -92,7 +92,7 @@ public class Gemm extends IntermediateOperation { return null; } - String joinDimension = aType.dimensions().get(1).name(); // TODO: check wrt transpose! + String joinDimension = aType.dimensions().get(1 - transposeA).name(); TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension); TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index bd302afa5c7..efd6f9d3339 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -199,7 +199,9 @@ public abstract class IntermediateOperation { String constantName = "constant(" + vespaName() + ")"; Value result = context.get(constantName); if (result == DoubleValue.NaN) { - if (inputs.size() == 0) { + if (constantValue != null) { + result = constantValue; + } else if (inputs.size() == 0) { if (getConstantValue().isEmpty()) { throw new IllegalArgumentException("Error in evaluating constant for " + name); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java index ded76db60fe..5785621eed3 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java @@ -28,6 +28,9 @@ public class OnnxConcat extends IntermediateOperation { if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null; OrderedTensorType aType = inputs.get(0).type().get(); + if (concatDimensionIndex < 0) { + concatDimensionIndex = aType.dimensions().size() + concatDimensionIndex; + } long concatDimSize = aType.dimensions().get(concatDimensionIndex).size().orElse(-1L); for (int i = 1; i < inputs.size(); ++i) { @@ -92,7 +95,7 @@ public class OnnxConcat extends IntermediateOperation { public void renameDimensions(DimensionRenamer renamer) { super.renameDimensions(renamer); concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName); - } + } @Override public OnnxConcat withInputs(List<IntermediateOperation> inputs) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java index 1b2d9ac090e..7af051484f5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java @@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.function.DoubleUnaryOperator; /** * ONNX Reduce[Sum/Mean/etc] operation @@ -24,6 +25,8 @@ public class Reduce extends IntermediateOperation { private final AttributeMap attributeMap; private final com.yahoo.tensor.functions.Reduce.Aggregator aggregator; + private final DoubleUnaryOperator preOperator; + private final DoubleUnaryOperator postOperator; private List<String> reduceDimensions; @@ -31,11 +34,23 @@ public class Reduce extends IntermediateOperation { List<IntermediateOperation> inputs, AttributeMap attributeMap, com.yahoo.tensor.functions.Reduce.Aggregator aggregator) { + this(modelName, nodeName, inputs, attributeMap, aggregator, null, null); + } + + public Reduce(String modelName, String nodeName, + List<IntermediateOperation> inputs, + AttributeMap attributeMap, + com.yahoo.tensor.functions.Reduce.Aggregator aggregator, + DoubleUnaryOperator preOperator, + DoubleUnaryOperator postOperator) { super(modelName, nodeName, inputs); this.attributeMap = attributeMap; this.aggregator = aggregator; + this.preOperator = preOperator; + this.postOperator = postOperator; } + @Override protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(1)) return null; @@ -48,7 +63,7 @@ public class Reduce extends IntermediateOperation { for (Value i : attributeMap.getList("axes").get()) { int dimensionIndex = (int) i.asDouble(); if (dimensionIndex < 0) { - dimensionIndex = inputType.dimensions().size() - dimensionIndex; + dimensionIndex = inputType.dimensions().size() + dimensionIndex; } reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); } @@ -61,6 +76,9 @@ public class Reduce extends IntermediateOperation { if ( ! allInputTypesPresent(1)) return null; TensorFunction inputFunction = inputs.get(0).function().get(); + if (preOperator != null) { + inputFunction = new com.yahoo.tensor.functions.Map(inputFunction, preOperator); + } TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions @@ -74,6 +92,9 @@ public class Reduce extends IntermediateOperation { new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); } + if (postOperator != null) { + output = new com.yahoo.tensor.functions.Map(output, postOperator); + } return output; } @@ -93,7 +114,7 @@ public class Reduce extends IntermediateOperation { @Override public Reduce withInputs(List<IntermediateOperation> inputs) { - return new Reduce(modelName(), name(), inputs, attributeMap, aggregator); + return new Reduce(modelName(), name(), inputs, attributeMap, aggregator, preOperator, postOperator); } @Override @@ -101,7 +122,7 @@ public class Reduce extends IntermediateOperation { private boolean shouldKeepDimensions() { Optional<Value> keepDims = attributeMap.get("keepdims"); - return keepDims.isPresent() && keepDims.get().asBoolean(); + return keepDims.isEmpty() || keepDims.get().asBoolean(); // default is 1 } private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index c7accd00619..c88fc18e6c6 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; @@ -22,51 +23,97 @@ import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; public class Reshape extends IntermediateOperation { - public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + private final AttributeMap attributeMap; + + public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override protected OrderedTensorType lazyGetType() { - if ( ! allInputTypesPresent(2)) return null; + if (inputs.size() == 2) { + return typeWithShapeAsInput(); + } else if (inputs.size() == 1) { + return typeWithShapeAsAttribute(); + } + throw new IllegalArgumentException("Expected 2 or 3 inputs for '" + name + "', got " + inputs.size()); + } + private OrderedTensorType typeWithShapeAsInput() { IntermediateOperation newShape = inputs.get(1); if (newShape.getConstantValue().isEmpty()) - throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant."); + throw new IllegalArgumentException("Reshape " + name + ": Shape input must be a constant."); + OrderedTensorType inputType = inputs.get(0).type().get(); Tensor shape = newShape.getConstantValue().get().asTensor(); + List<Integer> dimSizes = new ArrayList<>(shape.type().rank()); + shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue())); + + // first pass - set 0 values, meaning that size is retained from input + for (int i = 0; i < dimSizes.size(); ++i) { + if (dimSizes.get(i) == 0) { + if (i >= inputType.dimensions().size()) { + throw new IllegalArgumentException("Reshape " + name + ": 0 value for dimension not found in input"); + } + dimSizes.set(i, inputType.dimensions().get(i).size().get().intValue()); + } + } + + // second pass - set any -1 value, meaning that the dimension size should be expanded to fill the tensor + for (int i = 0; i < dimSizes.size(); ++i) { + if (dimSizes.get(i) < 0) { + int shapeSize = dimSizes.stream().reduce(1, (a, b) -> a * b); + int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue(); + dimSizes.set(i, -1 * tensorSize / (shapeSize == 0 ? -1 : shapeSize)); + } + } + + return buildOutputType(dimSizes); + } + + private OrderedTensorType typeWithShapeAsAttribute() { + if (attributeMap.getList("shape").isEmpty() || attributeMap.getList("shape").get().size() == 0) + throw new IllegalArgumentException("Reshape in " + name + ": Shape attribute is empty."); OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); - int dimensionIndex = 0; - for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int size = cell.getValue().intValue(); + List<Value> shape = attributeMap.getList("shape").get(); + List<Integer> dimSizes = new ArrayList<>(shape.size()); + + for (Value v : shape) { + int size = (int) v.asDouble(); if (size < 0) { - size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / - OrderedTensorType.tensorSize(inputType.type()).intValue(); + int shapeSize = (int) shape.stream().mapToDouble(Value::asDouble).reduce(1, (a, b) -> a * b); + int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue(); + size = -1 * shapeSize / tensorSize; } - outputTypeBuilder.add(TensorType.Dimension.indexed( - String.format("%s_%d", vespaName(), dimensionIndex), size)); - dimensionIndex++; + dimSizes.add(size); + } + return buildOutputType(dimSizes); + } + + private OrderedTensorType buildOutputType(List<Integer> dimSizes) { + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < dimSizes.size(); ++i) { + outputTypeBuilder.add(TensorType.Dimension.indexed(String.format("%s_%d", vespaName(), i), dimSizes.get(i))); } return outputTypeBuilder.build(); } @Override protected TensorFunction lazyGetFunction() { - if ( ! allInputTypesPresent(2)) return null; - if ( ! allInputFunctionsPresent(2)) return null; + if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent) ) return null; + if ( ! inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent) ) return null; OrderedTensorType inputType = inputs.get(0).type().get(); TensorFunction inputFunction = inputs.get(0).function().get(); - return reshape(inputFunction, inputType.type(), type.type()); + return reshape(inputFunction, inputType, type); } @Override @@ -76,11 +123,11 @@ public class Reshape extends IntermediateOperation { @Override public Reshape withInputs(List<IntermediateOperation> inputs) { - return new Reshape(modelName(), name(), inputs); + return new Reshape(modelName(), name(), inputs, attributeMap); } - public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) + public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { + if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type()))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, @@ -89,25 +136,27 @@ public class Reshape extends IntermediateOperation { // the new shape. We have to introduce temporary dimension names and rename back if dimension names // in the new and old tensor type overlap. + // Todo: change this to use tensor generate when available + List<String> from = new ArrayList<>(); List<String> to = new ArrayList<>(); boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType); if (dimensionNamesOverlap) { - TensorType.Builder builder = new TensorType.Builder(outputType.valueType()); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(outputType.type().valueType()); for (int i = 0; i < outputType.rank(); ++i) { TensorType.Dimension dim = outputType.dimensions().get(i); from.add(dim.name()); to.add("temp_" + dim.name()); - builder.dimension(dim.withName("temp_" + dim.name())); + builder.add(dim.withName("temp_" + dim.name())); } outputType = builder.build(); } ExpressionNode unrollFrom = unrollTensorExpression(inputType); ExpressionNode unrollTo = unrollTensorExpression(outputType); - ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, new EmbracedNode(unrollTo)); + ExpressionNode transformExpression = new ComparisonNode(new EmbracedNode(unrollFrom), TruthOperator.EQUAL, new EmbracedNode(unrollTo)); - TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); + TensorType transformationType = new TensorType.Builder(inputType.type(), outputType.type()).build(); Generate transformTensor = new Generate(transformationType, new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); @@ -121,11 +170,11 @@ public class Reshape extends IntermediateOperation { return result; } - private static boolean dimensionNamesOverlap(TensorType a, TensorType b) { - return a.dimensionNames().stream().anyMatch(d -> b.dimension(d).isPresent()); + private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) { + return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent()); } - private static ExpressionNode unrollTensorExpression(TensorType type) { + private static ExpressionNode unrollTensorExpression(OrderedTensorType type) { if (type.rank() == 0) return new ConstantNode(DoubleValue.zero); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index 032ffb88a46..83086926316 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -2,8 +2,13 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Map; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; +import java.util.ArrayList; import java.util.List; /** @@ -13,8 +18,11 @@ import java.util.List; */ public class Softmax extends IntermediateOperation { - public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs) { + private final AttributeMap attributeMap; + + public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override @@ -28,18 +36,30 @@ public class Softmax extends IntermediateOperation { if ( ! allInputFunctionsPresent(1)) return null; OrderedTensorType inputType = inputs.get(0).type().get(); - String dimension = inputType.dimensions().get(0).name(); - if (inputType.rank() == 2) { - dimension = inputType.dimensions().get(1).name(); // assumption: first dimension is batch dimension + + int axis = inputType.rank() == 1 ? 0 : 1; // assumption: first dimension is batch dimension + if (attributeMap.get("axis").isPresent()) { + axis = (int)attributeMap.get("axis").get().asDouble(); + } + if (axis < 0) { + axis = inputType.rank() + axis; } + List<String> reduceDimensions = new ArrayList<>(); + for (int i = axis; i < inputType.rank(); ++i) { + reduceDimensions.add(inputType.dimensions().get(i).name()); // Do softmax over all dimensions except batch dimension + } + + TensorFunction input = inputs.get(0).function().get(); + TensorFunction exp = new Map(input, ScalarFunctions.exp()); + TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions); + TensorFunction div = new Join(exp, sum, ScalarFunctions.divide()); - TensorFunction inputFunction = inputs.get(0).function().get(); - return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension); + return div; } @Override public Softmax withInputs(List<IntermediateOperation> inputs) { - return new Softmax(modelName(), name(), inputs); + return new Softmax(modelName(), name(), inputs, attributeMap); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index 4f656d86929..0d2ba0cc714 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -64,7 +64,7 @@ class GraphImporter { case "identity": return new Identity(modelName, nodeName, inputs); case "placeholder": return new Argument(modelName, nodeName, nodeType); case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs); - case "reshape": return new Reshape(modelName, nodeName, inputs); + case "reshape": return new Reshape(modelName, nodeName, inputs, attributes); case "shape": return new Shape(modelName, nodeName, inputs); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); @@ -113,7 +113,7 @@ class GraphImporter { case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); - case "softmax": return new Softmax(modelName, nodeName, inputs); + case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); // state ops case "variable": return new Constant(modelName, nodeName, nodeType); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java new file mode 100644 index 00000000000..6954abe5157 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -0,0 +1,460 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.onnx; + +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.operations.Constant; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.ConstantTensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static ai.vespa.rankingexpression.importer.onnx.GraphImporter.*; +import static onnx.Onnx.AttributeProto.AttributeType.FLOAT; +import static onnx.Onnx.AttributeProto.AttributeType.INT; +import static onnx.Onnx.AttributeProto.AttributeType.INTS; +import static org.junit.Assert.assertEquals; + +/** + * Unit tests for ONNX operators. The number on the test reflects the minimum + * opset number for the operations tested. + * + * @author lesters + */ +public class OnnxOperationsTestCase { + + private static final String modelName = "test_model"; + + @Test + public void testElementwiseOperators7() throws ParseException { + Tensor x = evaluate("tensor(d0[7]):[-1.0, -0.5, -0.1, 0.0, 0.1, 0.5, 1.0]"); + assertEval("acos", x, evaluate("acos(x)", x)); + assertEval("asin", x, evaluate("asin(x)", x)); + assertEval("atan", x, evaluate("atan(x)", x)); + assertEval("cos", x, evaluate("cos(x)", x)); + assertEval("sin", x, evaluate("sin(x)", x)); + assertEval("tan", x, evaluate("tan(x)", x)); + assertEval("tanh", x, evaluate("tanh(x)", x)); + assertEval("neg", x, evaluate("-x", x)); + assertEval("sigmoid", x, evaluate("sigmoid(x)", x)); + assertEval("exp", x, evaluate("exp(x)", x)); + assertEval("floor", x, evaluate("floor(x)", x)); + assertEval("ceil", x, evaluate("ceil(x)", x)); + assertEval("abs", x, evaluate("abs(x)", x)); + + assertEval("relu", x, evaluate("max(0, x)", x)); + assertEval("elu", x, evaluate("map(x, f(a)(if(a < 0, 1.0 * (exp(a)-1), a)))", x)); + assertEval("elu", x, evaluate("map(x, f(a)(if(a < 0, 0.5 * (exp(a)-1), a)))", x), createAttribute("alpha", 0.5f)); + assertEval("selu", x, evaluate("map(x, f(a)(1.050700987 * if(a >= 0, a, 1.673263242 * (exp(a) - 1))))", x)); + assertEval("selu", x, evaluate("map(x, f(a)(1.0 * if(a >= 0, a, 1.5 * (exp(a) - 1))))", x), createAttributes().attr("gamma", 1.0f).attr("alpha", 1.5f).build()); + assertEval("leakyrelu", x, evaluate("max(0.01 * x, x)", x)); + assertEval("leakyrelu", x, evaluate("max(0.001 * x, x)", x), createAttribute("alpha", 0.001f)); + + x = evaluate("tensor(d0[3]):[0.01, 1.0, 10.0]"); + assertEval("log", x, evaluate("log(x)", x)); + assertEval("sqrt", x, evaluate("sqrt(x)", x)); + assertEval("reciprocal", x, evaluate("map(x, f(a)(1.0 / a))", x)); + } + + @Test + public void testJoinOperators7() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[3, 4]"); + Tensor y = evaluate("tensor(d0[2]):[1, 2]"); + assertEval("add", x, y, evaluate("tensor(d0[2]):[4, 6]")); + assertEval("sub", x, y, evaluate("tensor(d0[2]):[2, 2]")); + assertEval("mul", x, y, evaluate("tensor(d0[2]):[3, 8]")); + assertEval("div", x, y, evaluate("tensor(d0[2]):[3, 2]")); + assertEval("greater", x, y, evaluate("tensor(d0[2]):[1, 1]")); + assertEval("less", x, y, evaluate("tensor(d0[2]):[0, 0]")); + assertEval("equal", x, y, evaluate("tensor(d0[2]):[0, 0]")); + assertEval("pow", x, y, evaluate("tensor(d0[2]):[3, 16]")); + + x = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + y = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + assertEval("add", x, y, evaluate("x + y", x, y)); + assertEval("sub", x, y, evaluate("x - y", x, y)); + assertEval("mul", x, y, evaluate("x * y", x, y)); + assertEval("div", x, y, evaluate("x / y", x, y)); + assertEval("greater", x, y, evaluate("join(x, y, f(a,b)(a > b))", x, y)); + assertEval("less", x, y, evaluate("join(x, y, f(a,b)(a < b))", x, y)); + assertEval("equal", x, y, evaluate("join(x, y, f(a,b)(a == b))", x, y)); + assertEval("pow", x, y, evaluate("join(x, y, f(a,b)(pow(a,b)))", x, y)); + + // broadcasting + x = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + y = evaluate("random(d0[4]) + 1"); + assertEval("add", x, y, evaluate("x + rename(y, d0, d2)", x, y)); + assertEval("sub", x, y, evaluate("x - rename(y, d0, d2)", x, y)); + assertEval("mul", x, y, evaluate("x * rename(y, d0, d2)", x, y)); + assertEval("div", x, y, evaluate("x / rename(y, d0, d2)", x, y)); + assertEval("greater", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a > b))", x, y)); + assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y)); + assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y)); + assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y)); + } + + @Test + public void testConcatReduce8() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[3, 4]"); + Tensor y = evaluate("tensor(d0[2]):[1, 2]"); + Tensor z = evaluate("tensor(d0[2]):[5, 6]"); + assertEval("max", x, y, z, evaluate("tensor(d0[2]):[5, 6]")); + assertEval("min", x, y, z, evaluate("tensor(d0[2]):[1, 2]")); + assertEval("mean", x, y, z, evaluate("tensor(d0[2]):[3, 4]")); + + x = evaluate("random(d0[2],d1[3],d2[4])"); + y = evaluate("random(d0[2],d1[3],d2[4])"); + z = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("max", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), max, tmp)", x, y, z)); + assertEval("min", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), min, tmp)", x, y, z)); + assertEval("mean", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), avg, tmp)", x, y, z)); + + // broadcasting + x = evaluate("random(d0[2],d1[3],d2[4])"); + y = evaluate("random(d0[3],d1[4])"); + z = evaluate("random(d0[4])"); + assertEval("max", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), max, tmp)", x, y, z)); + assertEval("min", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), min, tmp)", x, y, z)); + assertEval("mean", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), avg, tmp)", x, y, z)); + } + + @Test + public void testConcat4() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[1, 2]"); + Tensor y = evaluate("tensor(d0[2]):[3, 4]"); + Tensor expected = evaluate("tensor(d0[4]):[1,2,3,4]"); + assertEval("concat", x, y, expected, createAttribute("axis", 0)); + assertEval("concat", x, y, expected, createAttribute("axis", -1)); + + x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); + assertEval("concat", x, y, evaluate("tensor(d0[4],d1[2]):[1,2,3,4,5,6,7,8]"), createAttribute("axis", 0)); + assertEval("concat", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,5,6,3,4,7,8]"), createAttribute("axis", 1)); + assertEval("concat", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,5,6,3,4,7,8]"), createAttribute("axis", -1)); + assertEval("concat", x, y, evaluate("tensor(d0[4],d1[2]):[1,2,3,4,5,6,7,8]"), createAttribute("axis", -2)); + + x = evaluate("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 5, 6, 7, 8]"); + y = evaluate("tensor(d0[2],d1[2],d2[2]):[9,10,11,12,13,14,15,16]"); + assertEval("concat", x, y, evaluate("concat(x, y, d0)", x, y), createAttribute("axis", 0)); + assertEval("concat", x, y, evaluate("concat(x, y, d1)", x, y), createAttribute("axis", 1)); + assertEval("concat", x, y, evaluate("concat(x, y, d2)", x, y), createAttribute("axis", 2)); + assertEval("concat", x, y, evaluate("concat(x, y, d2)", x, y), createAttribute("axis", -1)); + assertEval("concat", x, y, evaluate("concat(x, y, d1)", x, y), createAttribute("axis", -2)); + assertEval("concat", x, y, evaluate("concat(x, y, d0)", x, y), createAttribute("axis", -3)); + } + + @Test + public void testGemm7() throws ParseException { + Tensor a = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + Tensor b = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); + Tensor c = evaluate("tensor(d0[2],d1[2]):[0.1, 0.2, 0.3, 0.4]"); + + assertEval("gemm", a, b, evaluate("tensor(d0[2],d1[2]):[19, 22, 43, 50]")); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.1, 22.2, 43.3, 50.4]")); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[38.1, 44.2, 86.3, 100.4]"), createAttribute("alpha", 2.0f)); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.2, 22.4, 43.6, 50.8]"), createAttribute("beta", 2.0f)); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[26.1, 30.2, 38.3, 44.4]"), createAttribute("transA", 1)); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[17.1, 23.2, 39.3, 53.4]"), createAttribute("transB", 1)); + + // unidictional broadcasting for c + c = evaluate("tensor(d0[2]):[0.1, 0.2]"); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.1, 22.2, 43.1, 50.2]")); + } + + @Test + public void testIdentity1() throws ParseException { + Tensor x = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("identity", x, x); + } + + @Test + public void testMatMul1() throws ParseException { + Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]"); + Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]")); + } + + @Test + public void testReshape5() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"); + Tensor y = evaluate("tensor(d0[1]):[4]"); + assertEval("reshape", x, y, evaluate("tensor(d0[4]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[2,2]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[3]):[2,1,2]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[1],d2[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[2,-1]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[2,0]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[0,-1]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + x = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + y = evaluate("tensor(d0[2]):[3,2]"); + assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]")); + + y = evaluate("tensor(d0[4]):[3,2,-1,1]"); + assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]")); + } + + @Test + public void testReduceOperators1() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]")); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"), createAttribute("axes", new int[] {0,1})); + assertEval("reducesum", x, evaluate("tensor():[10]"), createAttribute("keepdims", 0)); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"), createAttribute("keepdims", 1)); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[2]):[4, 6]"), createAttribute("axes", new int[]{0})); + assertEval("reducesum", x, evaluate("tensor(d0[2]):[4, 6]"), createAttributes().attr("axes", new int[]{0}).attr("keepdims", 0).build()); + assertEval("reducesum", x, evaluate("tensor(d0[2],d1[1]):[3, 7]"), createAttribute("axes", new int[] {1})); + assertEval("reducesum", x, evaluate("tensor(d0[2]):[3, 7]"), createAttributes().attr("axes", new int[]{1}).attr("keepdims", 0).build()); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[2]):[4, 6]"), createAttribute("axes", new int[] {-2})); + assertEval("reducesum", x, evaluate("tensor(d0[2],d1[1]):[3, 7]"), createAttribute("axes", new int[] {-1})); + assertEval("reducesum", x, evaluate("tensor(d0[2]):[3, 7]"), createAttributes().attr("axes", new int[] {-1}).attr("keepdims", 0).build()); + + assertEval("reduceprod", x, evaluate("tensor(d0[1],d1[1]):[24]")); + assertEval("reduceprod", x, evaluate("tensor(d0[1],d1[2]):[3, 8]"), createAttribute("axes", new int[] {0})); + + assertEval("reducemin", x, evaluate("tensor(d0[1],d1[1]):[1]")); + assertEval("reducemin", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0})); + + assertEval("reducemax", x, evaluate("tensor(d0[1],d1[1]):[4]")); + assertEval("reducemax", x, evaluate("tensor(d0[1],d1[2]):[3, 4]"), createAttribute("axes", new int[] {0})); + + assertEval("reducemean", x, evaluate("tensor():[2.5]"), createAttribute("keepdims", 0)); + assertEval("reducemean", x, evaluate("tensor(d0[2]):[2, 3]"), createAttributes().attr("axes", new int[] {0}).attr("keepdims", 0).build()); + + assertEval("reducelogsum", x, evaluate("tensor():[log(10)]"), createAttribute("keepdims", 0)); + assertEval("reducelogsumexp", x, evaluate("tensor():[log(exp(1)+exp(2)+exp(3)+exp(4))]"), createAttribute("keepdims", 0)); + assertEval("reducesumsquare", x, evaluate("tensor():[1*1+2*2+3*3+4*4]"), createAttribute("keepdims", 0)); + + x = evaluate("tensor(d0[1],d1[5]):[-10, -5, 0, 5, 10]"); + assertEval("reducel1", x, evaluate("tensor():[30]"), createAttribute("keepdims", 0)); + assertEval("reducel2", x, evaluate("tensor():[sqrt(10*10 + 5*5 + 5*5 + 10*10)]"), createAttribute("keepdims", 0)); + } + + @Test + public void testShape1() throws ParseException { + Tensor x = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("shape", x, evaluate("tensor(d0[3]):[2,3,4]")); + } + + @Test + public void testSoftmax1() throws ParseException { + Tensor x = evaluate("tensor(d0[1],d1[3]):[-1, 0, 1]"); + assertEval("softmax", x, evaluate("tensor(d0[1],d1[3]):[0.09003058, 0.24472848, 0.66524094]")); + + x = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 7]"); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1)", x), createAttribute("axis", 0)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x), createAttribute("axis", 1)); // 1 is default + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x), createAttribute("axis", -1)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1)", x), createAttribute("axis", -2)); + + x = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1, d2)", x), createAttribute("axis", 0)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x), createAttribute("axis", 1)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d2)", x), createAttribute("axis", 2)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d2)", x), createAttribute("axis", -1)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x), createAttribute("axis", -2)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1, d2)", x), createAttribute("axis", -3)); + } + + @Test + public void testSqueeze1() throws ParseException { + Tensor x = evaluate("tensor(d0[1],d1[2]):[1, 2]"); + assertEval("squeeze", x, evaluate("tensor(d0[2]):[1, 2]")); + + x = evaluate("tensor(d0[1],d1[2],d2[1],d3[3]):[1,2,3,4,5,6]"); + assertEval("squeeze", x, evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]")); + assertEval("squeeze", x, evaluate("tensor(d0[2],d1[1],d2[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0})); + assertEval("squeeze", x, evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {2})); + assertEval("squeeze", x, evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0, 2})); + } + + @Test + public void testWhere9() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + Tensor y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); + Tensor condition = evaluate("tensor(d0[2],d1[2]):[0, 1, 0, 1]"); + assertEval("where", condition, x, y, evaluate("tensor(d0[2],d1[2]):[5, 2, 7, 4]")); + + assertEval("where", evaluate("tensor():[0]"), x, y, y); + assertEval("where", evaluate("tensor():[1]"), x, y, x); + assertEval("where", evaluate("tensor(d0[1]):[0]"), x, y, y); + assertEval("where", evaluate("tensor(d0[1]):[1]"), x, y, x); + assertEval("where", evaluate("tensor(d0[1],d1[1]):[0]"), x, y, y); + assertEval("where", evaluate("tensor(d0[1],d1[1]):[1]"), x, y, x); + } + + private Tensor evaluate(String expr) throws ParseException { + return evaluate(expr, null, null, null); + } + + private Tensor evaluate(String expr, Tensor x) throws ParseException { + return evaluate(expr, x, null, null); + } + + private Tensor evaluate(String expr, Tensor x, Tensor y) throws ParseException { + return evaluate(expr, x, y, null); + } + + private Tensor evaluate(String expr, Tensor x, Tensor y, Tensor z) throws ParseException { + Context context = new MapContext(DoubleValue.NaN); + if (x != null) context.put("x", new TensorValue(x)); + if (y != null) context.put("y", new TensorValue(y)); + if (z != null) context.put("z", new TensorValue(z)); + return new RankingExpression(expr).evaluate(context).asTensor(); + } + + private Tensor evaluate(IntermediateOperation op) { + Tensor tensor = op.evaluateAsConstant(op.type().get()).asTensor(); + return renameToStandardType(op, tensor); + } + + private void assertEval(String opName, Tensor x, Tensor expected) { + assertEval(opName, x, null, null, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) { + assertEval(opName, x, null, null, expected, attr); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) { + assertEval(opName, x, y, null, expected, attr); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) { + assertEval(opName, x, y, null, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) { + assertEval(opName, x, y, z, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) { + Context context = new MapContext(DoubleValue.NaN); + List<IntermediateOperation> inputs = createInputs(context, x, y, z); + IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build()); + optimizeAndRename(opName, op); + Tensor result = evaluate(op); + assertEquals(expected, result); + assertEquals(expected.type(), result.type()); + } + + private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z) { + List<IntermediateOperation> inputs = new ArrayList<>(); + addInput(inputs, context, x, "x"); + addInput(inputs, context, y, "y"); + addInput(inputs, context, z, "z"); + return inputs; + } + + private void addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) { + if (x == null) return; + context.put(name, new TensorValue(x)); + IntermediateOperation op = new Constant(modelName, name, OrderedTensorType.fromSpec(x.type().toString())); + op.setConstantValueFunction(type -> new TensorValue(convertTypeAfterRename(x, type))); + inputs.add(op); + } + + Tensor convertTypeAfterRename(Tensor tensor, OrderedTensorType type) { + IndexedTensor indexedTensor = (IndexedTensor) tensor; + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); + for (int i = 0; i < indexedTensor.size(); i++) { + builder.cellByDirectIndex(type.toDirectIndex(i), indexedTensor.get(i)); + } + return builder.build(); + } + + private TensorFunction optimizeAndRename(String opName, IntermediateOperation op) { + IntermediateGraph graph = new IntermediateGraph(modelName); + graph.put(opName, op); + graph.outputs(graph.defaultSignature()).put(opName, opName); + graph.optimize(); + return op.function().get(); + } + + private Tensor renameToStandardType(IntermediateOperation op, Tensor tensor) { + OrderedTensorType operationType = op.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); + if ( ! operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + TensorFunction func = new Rename(new ConstantTensor(tensor), renameFrom, renameTo); + return func.evaluate(); + } + return tensor; + } + + static AttributeConverter createAttribute(String name, int val) { + return new Attributes().attr(name, val).build(); + } + + static AttributeConverter createAttribute(String name, float val) { + return new Attributes().attr(name, val).build(); + } + + static AttributeConverter createAttribute(String name, int [] vals) { + return new Attributes().attr(name, vals).build(); + } + + static Attributes createAttributes() { + return new Attributes(); + } + + private static class Attributes { + + Onnx.NodeProto.Builder nodeBuilder; + + Attributes() { + this.nodeBuilder = Onnx.NodeProto.newBuilder(); + } + + Attributes attr(String name, int val) { + nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(INT).setI(val).build()); + return this; + } + + Attributes attr(String name, float val) { + nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(FLOAT).setF(val).build()); + return this; + } + + Attributes attr(String name, int [] vals) { + Onnx.AttributeProto.Builder builder = Onnx.AttributeProto.newBuilder(); + for (int val : vals) { + builder.addInts(val); + } + nodeBuilder.addAttribute(builder.setName(name).setType(INTS).build()); + return this; + } + + AttributeConverter build() { + return AttributeConverter.convert(nodeBuilder.build()); + } + + } + +} diff --git a/searchcore/src/apps/proton/proton.cpp b/searchcore/src/apps/proton/proton.cpp index b37eb5ac0cf..f80558a1bc6 100644 --- a/searchcore/src/apps/proton/proton.cpp +++ b/searchcore/src/apps/proton/proton.cpp @@ -8,6 +8,7 @@ #include <vespa/metrics/metricmanager.h> #include <vespa/vespalib/util/signalhandler.h> #include <vespa/vespalib/util/programoptions.h> +#include <vespa/vespalib/util/time.h> #include <vespa/vespalib/io/fileutil.h> #include <vespa/config/common/exceptions.h> #include <vespa/fastos/app.h> @@ -198,7 +199,7 @@ App::Main() LOG(info, "Sleeping 900 seconds due to proton state"); int sleepLeft = 900; while (!(SIG::INT.check() || SIG::TERM.check()) && sleepLeft > 0) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1000ms); --sleepLeft; } EV_STOPPING("proton", "shutdown after stop on io errors"); @@ -226,7 +227,7 @@ App::Main() } EV_STARTED("proton"); while (!(SIG::INT.check() || SIG::TERM.check() || (spiProton && spiProton->getNode().attemptedStopped()))) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1000ms); if (spiProton && spiProton->configUpdated()) { storage::ResumeGuard guard(spiProton->getNode().pause()); spiProton->updateConfig(); @@ -240,7 +241,7 @@ App::Main() if (spiProton) { // report down state to cluster controller. spiProton->getNode().notifyPartitionDown(0, "proton state string is " + stateString); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1000ms); } EV_STOPPING("proton", "shutdown after new stop on io errors"); return 1; diff --git a/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp b/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp index 3b3b5f412d2..dcd3dce218b 100644 --- a/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp +++ b/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp @@ -6,8 +6,10 @@ #include <vespa/fnet/frt/frt.h> #include <vespa/vespalib/util/host_name.h> #include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/time.h> #include <vespa/fastos/app.h> #include <sys/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("vespa-proton-cmd"); @@ -115,7 +117,7 @@ public: slobrok::api::MirrorAPI sbmirror(_frt->supervisor(), sbcfg); for (int timeout = 1; timeout < 20; timeout++) { if (!sbmirror.ready()) { - FastOS_Thread::Sleep(50*timeout); + std::this_thread::sleep_for(50ms*timeout); } } if (!sbmirror.ready()) { @@ -123,12 +125,9 @@ public: "ERROR: no data from service location broker\n"); exit(1); } - slobrok::api::MirrorAPI::SpecList specs = - sbmirror.lookup(rtcPattern); - slobrok::api::MirrorAPI::SpecList specs2 = - sbmirror.lookup(rtcPattern2); - slobrok::api::MirrorAPI::SpecList specs3 = - sbmirror.lookup(rtcPattern3); + slobrok::api::MirrorAPI::SpecList specs = sbmirror.lookup(rtcPattern); + slobrok::api::MirrorAPI::SpecList specs2 = sbmirror.lookup(rtcPattern2); + slobrok::api::MirrorAPI::SpecList specs3 = sbmirror.lookup(rtcPattern3); int found = 0; std::string service; @@ -167,7 +166,7 @@ public: slobrok::api::MirrorAPI sbmirror(_frt->supervisor(), sbcfg); for (int timeout = 1; timeout < 20; timeout++) { if (!sbmirror.ready()) { - FastOS_Thread::Sleep(50*timeout); + std::this_thread::sleep_for(50ms*timeout); } } if (!sbmirror.ready()) { diff --git a/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp b/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp index f5aa74d85e1..5775c31b205 100644 --- a/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp +++ b/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp @@ -7,6 +7,7 @@ #include <vespa/searchlib/transactionlog/translogserver.h> #include <vespa/vespalib/util/programoptions.h> #include <vespa/vespalib/util/xmlstream.h> +#include <vespa/vespalib/util/time.h> #include <vespa/document/config/config-documenttypes.h> #include <vespa/document/repo/documenttyperepo.h> #include <vespa/document/fieldvalue/document.h> @@ -14,6 +15,7 @@ #include <vespa/config/helper/configgetter.hpp> #include <vespa/fastos/app.h> #include <iostream> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("vespa-transactionlog-inspect"); @@ -491,7 +493,7 @@ protected: return 1; } for (size_t i = 0; !callback.isEof() && (i < 60 * 60); i++ ) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } return 0; } diff --git a/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp b/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp index b519e6a8dc7..839228b79b8 100644 --- a/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp +++ b/searchcore/src/tests/proton/documentdb/feedview/feedview_test.cpp @@ -24,6 +24,7 @@ #include <vespa/searchlib/attribute/attributefactory.h> #include <vespa/document/update/documentupdate.h> #include <vespa/searchlib/index/docbuilder.h> +#include <vespa/vespalib/util/time.h> #include <vespa/log/log.h> LOG_SETUP("feedview_test"); @@ -524,7 +525,7 @@ struct FixtureBase CommitTimeTracker _commitTimeTracker; SerialNum serial; std::shared_ptr<MyGidToLidChangeHandler> _gidToLidChangeHandler; - FixtureBase(TimeStamp visibilityDelay); + FixtureBase(vespalib::duration visibilityDelay); virtual ~FixtureBase(); @@ -689,7 +690,7 @@ struct FixtureBase }; -FixtureBase::FixtureBase(TimeStamp visibilityDelay) +FixtureBase::FixtureBase(vespalib::duration visibilityDelay) : _tracer(), sc(), iw(new MyIndexWriter(_tracer)), @@ -706,12 +707,12 @@ FixtureBase::FixtureBase(TimeStamp visibilityDelay) _writeServiceReal(_sharedExecutor), _writeService(_writeServiceReal), _lidReuseDelayer(_writeService, _dmsc->get()), - _commitTimeTracker(visibilityDelay), + _commitTimeTracker(vespalib::count_ns(visibilityDelay)), serial(0), _gidToLidChangeHandler(std::make_shared<MyGidToLidChangeHandler>()) { _dmsc->constructFreeList(); - _lidReuseDelayer.setImmediateCommit(visibilityDelay == 0); + _lidReuseDelayer.setImmediateCommit(visibilityDelay == vespalib::duration::zero()); } FixtureBase::~FixtureBase() { @@ -728,7 +729,7 @@ FixtureBase::populateBeforeCompactLidSpace() struct SearchableFeedViewFixture : public FixtureBase { SearchableFeedView fv; - SearchableFeedViewFixture(TimeStamp visibilityDelay = 0) : + SearchableFeedViewFixture(vespalib::duration visibilityDelay = 0ms) : FixtureBase(visibilityDelay), fv(StoreOnlyFeedView::Context(sa, sc._schema, @@ -750,7 +751,7 @@ struct SearchableFeedViewFixture : public FixtureBase struct FastAccessFeedViewFixture : public FixtureBase { FastAccessFeedView fv; - FastAccessFeedViewFixture(TimeStamp visibilityDelay = 0) : + FastAccessFeedViewFixture(vespalib::duration visibilityDelay = vespalib::duration::zero()) : FixtureBase(visibilityDelay), fv(StoreOnlyFeedView::Context(sa, sc._schema, @@ -1206,8 +1207,8 @@ TEST_F("require that commit is called if visibility delay is 0", "ack(Result(0, ))"); } -const TimeStamp LONG_DELAY(TimeStamp::Seconds(60.0)); -const TimeStamp SHORT_DELAY(TimeStamp::Seconds(0.5)); +const vespalib::duration LONG_DELAY = 60s; +const vespalib::duration SHORT_DELAY = 500ms; TEST_F("require that commit is not called when inside a commit interval", SearchableFeedViewFixture(LONG_DELAY)) @@ -1232,13 +1233,13 @@ TEST_F("require that commit is not called when inside a commit interval", TEST_F("require that commit is called when crossing a commit interval", SearchableFeedViewFixture(SHORT_DELAY)) { - FastOS_Thread::Sleep(SHORT_DELAY.ms() + 100); + std::this_thread::sleep_for(SHORT_DELAY + 100ms); DocumentContext dc = f.doc1(); f.putAndWait(dc); EXPECT_EQUAL(1u, f.miw._commitCount); EXPECT_EQUAL(1u, f.maw._commitCount); EXPECT_EQUAL(2u, f._docIdLimit.get()); - FastOS_Thread::Sleep(SHORT_DELAY.ms() + 100); + std::this_thread::sleep_for(SHORT_DELAY + 100ms); f.removeAndWait(dc); EXPECT_EQUAL(2u, f.miw._commitCount); EXPECT_EQUAL(2u, f.maw._commitCount); @@ -1257,13 +1258,13 @@ TEST_F("require that commit is not implicitly called after handover to maintenan SearchableFeedViewFixture(SHORT_DELAY)) { f._commitTimeTracker.setReplayDone(); - FastOS_Thread::Sleep(SHORT_DELAY.ms() + 100); + std::this_thread::sleep_for(SHORT_DELAY + 100ms); DocumentContext dc = f.doc1(); f.putAndWait(dc); EXPECT_EQUAL(0u, f.miw._commitCount); EXPECT_EQUAL(0u, f.maw._commitCount); EXPECT_EQUAL(0u, f._docIdLimit.get()); - FastOS_Thread::Sleep(SHORT_DELAY.ms() + 100); + std::this_thread::sleep_for(SHORT_DELAY + 100ms); f.removeAndWait(dc); EXPECT_EQUAL(0u, f.miw._commitCount); EXPECT_EQUAL(0u, f.maw._commitCount); diff --git a/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp b/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp index c732da58dd7..dfac2edad61 100644 --- a/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp +++ b/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp @@ -34,6 +34,7 @@ #include <vespa/vespalib/util/closuretask.h> #include <vespa/vespalib/util/gate.h> #include <vespa/vespalib/util/threadstackexecutor.h> +#include <vespa/fastos/thread.h> #include <unistd.h> #include <vespa/log/log.h> @@ -785,9 +786,7 @@ MyExecutor::MyExecutor() } -MyExecutor::~MyExecutor() -{ -} +MyExecutor::~MyExecutor() = default; bool @@ -1065,7 +1064,7 @@ TEST_F("require that document pruner is active", MyFrozenBucket::UP frozen3(new MyFrozenBucket(f._mc, bucketId3)); f.setPruneConfig(DocumentDBPruneRemovedDocumentsConfig(0.2, 900.0)); for (uint32_t i = 0; i < 6; ++i) { - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); ASSERT_TRUE(f._executor.waitIdle(TIMEOUT_SEC)); if (f._removed.getNumUsedLids() != 10u) break; @@ -1074,7 +1073,7 @@ TEST_F("require that document pruner is active", EXPECT_EQUAL(10u, f._removed.getDocumentCount()); frozen3.reset(); for (uint32_t i = 0; i < 600; ++i) { - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); ASSERT_TRUE(f._executor.waitIdle(TIMEOUT_SEC)); if (f._removed.getNumUsedLids() != 10u) break; @@ -1090,7 +1089,7 @@ TEST_F("require that heartbeats are scheduled", f.startMaintenance(); f.setHeartBeatConfig(DocumentDBHeartBeatConfig(0.2)); for (uint32_t i = 0; i < 600; ++i) { - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); if (f._fh.getHeartBeats() != 0u) break; } @@ -1105,7 +1104,7 @@ TEST_F("require that periodic session prunings are scheduled", f.startMaintenance(); f.setGroupingSessionPruneInterval(0.2); for (uint32_t i = 0; i < 600; ++i) { - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); if (f._gsp.isInvoked) { break; } @@ -1234,7 +1233,7 @@ TEST_F("require that a blocked job is unblocked and executed after thaw bucket", EXPECT_FALSE(myJob2.isBlocked()); bool done1 = myJob1._latch.await(TIMEOUT_MS); EXPECT_TRUE(done1); - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); EXPECT_EQUAL(0u, myJob2._runCnt); } @@ -1246,7 +1245,7 @@ TEST_F("require that blocked jobs are not executed", MaintenanceControllerFixtur f._mc.registerJobInMasterThread(std::move(job)); f._injectDefaultJobs = false; f.startMaintenance(); - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); EXPECT_EQUAL(0u, myJob._runCnt); } diff --git a/searchcore/src/tests/proton/flushengine/flushengine_test.cpp b/searchcore/src/tests/proton/flushengine/flushengine_test.cpp index bfd4450b1f2..c18d393b98f 100644 --- a/searchcore/src/tests/proton/flushengine/flushengine_test.cpp +++ b/searchcore/src/tests/proton/flushengine/flushengine_test.cpp @@ -685,7 +685,7 @@ assertThatHandlersInCurrentSet(FlushEngine & engine, const std::vector<const cha { FlushEngine::FlushMetaSet current1 = engine.getCurrentlyFlushingSet(); while ((current1.size() < targets.size()) || !asserCorrectHandlers(current1, targets)) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); current1 = engine.getCurrentlyFlushingSet(); } } diff --git a/searchcore/src/tests/proton/index/indexmanager_test.cpp b/searchcore/src/tests/proton/index/indexmanager_test.cpp index 2f5d3e353db..51e12e70dda 100644 --- a/searchcore/src/tests/proton/index/indexmanager_test.cpp +++ b/searchcore/src/tests/proton/index/indexmanager_test.cpp @@ -1,7 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/document/fieldvalue/document.h> -#include <vespa/document/fieldvalue/fieldvalue.h> #include <vespa/fastos/file.h> #include <vespa/searchcore/proton/index/indexmanager.h> #include <vespa/searchcore/proton/server/executorthreadingservice.h> @@ -19,12 +17,11 @@ #include <vespa/searchlib/memoryindex/field_inverter.h> #include <vespa/searchlib/queryeval/isourceselector.h> #include <vespa/searchlib/test/index/mock_field_length_inspector.h> -#include <vespa/searchlib/util/dirtraverse.h> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/io/fileutil.h> -#include <vespa/vespalib/util/blockingthreadstackexecutor.h> #include <vespa/vespalib/util/threadstackexecutor.h> #include <set> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("indexmanager_test"); @@ -57,6 +54,7 @@ using vespalib::makeLambdaTask; using namespace proton; using namespace searchcorespi; using namespace searchcorespi::index; +using namespace std::chrono_literals; namespace { @@ -303,7 +301,7 @@ TEST_F(IndexManagerTest, require_that_memory_index_is_flushed) EXPECT_EQ(stat._modifiedTime, target.getLastFlushTime().timeSinceEpoch().time()); // updated serial number & flush time when nothing to flush - FastOS_Thread::Sleep(8000); + std::this_thread::sleep_for(8s); fastos::TimeStamp now = fastos::ClockSystem::now().timeSinceEpoch(); vespalib::Executor::Task::UP task; runAsMaster([&]() { task = target.initFlush(2); }); diff --git a/searchcore/src/tests/proton/matching/request_context/request_context_test.cpp b/searchcore/src/tests/proton/matching/request_context/request_context_test.cpp index be70bacb4b1..3152b737ea7 100644 --- a/searchcore/src/tests/proton/matching/request_context/request_context_test.cpp +++ b/searchcore/src/tests/proton/matching/request_context/request_context_test.cpp @@ -26,12 +26,12 @@ public: class RequestContextTest : public ::testing::Test { private: - vespalib::Clock _clock; - vespalib::Doom _doom; - MyAttributeContext _attr_ctx; - Properties _props; - RequestContext _request_ctx; - Value::UP _query_tensor; + vespalib::Clock _clock; + vespalib::Doom _doom; + MyAttributeContext _attr_ctx; + Properties _props; + RequestContext _request_ctx; + Value::UP _query_tensor; void insert_tensor_in_properties(const vespalib::string& tensor_name, const Value& tensor_value) { vespalib::nbostream stream; @@ -42,7 +42,7 @@ private: public: RequestContextTest() : _clock(), - _doom(_clock, fastos::SteadyTimeStamp()), + _doom(_clock, fastos::SteadyTimeStamp::ZERO, fastos::SteadyTimeStamp::ZERO, false), _attr_ctx(), _props(), _request_ctx(_doom, _attr_ctx, _props), diff --git a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp index d86a750794f..870be2ab409 100644 --- a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp +++ b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp @@ -175,7 +175,7 @@ struct ProtonConfigOwner : public proton::IProtonConfigurer while (timer.elapsed().ms() < timeout) { if (getConfigured()) return true; - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return getConfigured(); } diff --git a/searchcore/src/vespa/searchcore/proton/matching/fakesearchcontext.h b/searchcore/src/vespa/searchcore/proton/matching/fakesearchcontext.h index fe9c20112f4..5a4ccc892b3 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/fakesearchcontext.h +++ b/searchcore/src/vespa/searchcore/proton/matching/fakesearchcontext.h @@ -6,6 +6,7 @@ #include <vespa/searchcorespi/index/fakeindexsearchable.h> #include <vespa/searchcorespi/index/indexcollection.h> #include <vespa/searchlib/attribute/fixedsourceselector.h> +#include <vespa/vespalib/util/doom.h> #include <algorithm> #include <map> #include <vector> diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_thread.cpp b/searchcore/src/vespa/searchcore/proton/matching/match_thread.cpp index afdae5eec2e..a0381af29a8 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_thread.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/match_thread.cpp @@ -71,15 +71,14 @@ LazyValue get_score_feature(const RankProgram &rankProgram) { //----------------------------------------------------------------------------- -MatchThread::Context::Context(double rankDropLimit, MatchTools &tools, HitCollector &hits, - uint32_t num_threads) +MatchThread::Context::Context(double rankDropLimit, MatchTools &tools, HitCollector &hits, uint32_t num_threads) : matches(0), _matches_limit(tools.match_limiter().sample_hits_per_thread(num_threads)), _score_feature(get_score_feature(tools.rank_program())), _ranking(tools.rank_program()), _rankDropLimit(rankDropLimit), _hits(hits), - _softDoom(tools.getSoftDoom()) + _doom(tools.getDoom()) { } @@ -307,7 +306,7 @@ MatchThread::findMatches(MatchTools &tools) auto kept_hits = communicator.selectBest(sorted_hit_seq); select_best_timer.done(); DocumentScorer scorer(tools.rank_program(), tools.search()); - if (tools.getHardDoom().doom()) { + if (tools.getDoom().hard_doom()) { kept_hits.clear(); } uint32_t reRanked = hits.reRank(scorer, std::move(kept_hits)); @@ -330,16 +329,16 @@ MatchThread::findMatches(MatchTools &tools) } void -MatchThread::processResult(const Doom & hardDoom, +MatchThread::processResult(const Doom & doom, search::ResultSet::UP result, ResultProcessor::Context &context) { - if (hardDoom.doom()) return; + if (doom.hard_doom()) return; bool hasGrouping = (context.grouping.get() != 0); if (context.sort->hasSortData() || hasGrouping) { result->mergeWithBitOverflow(fallback_rank_value()); } - if (hardDoom.doom()) return; + if (doom.hard_doom()) return; size_t totalHits = result->getNumHits(); const search::RankedHit *hits = result->getArray(); size_t numHits = result->getArrayUsed(); @@ -347,20 +346,20 @@ MatchThread::processResult(const Doom & hardDoom, if (bits != nullptr && hits != nullptr) { bits->andNotWithT(search::RankedHitIterator(hits, numHits)); } - if (hardDoom.doom()) return; + if (doom.hard_doom()) return; if (hasGrouping) { search::grouping::GroupingManager man(*context.grouping); man.groupUnordered(hits, numHits, bits); } - if (hardDoom.doom()) return; + if (doom.hard_doom()) return; size_t sortLimit = hasGrouping ? numHits : context.result->maxSize(); result->sort(*context.sort->sorter, sortLimit); - if (hardDoom.doom()) return; + if (doom.hard_doom()) return; if (hasGrouping) { search::grouping::GroupingManager man(*context.grouping); man.groupInRelevanceOrder(hits, numHits); } - if (hardDoom.doom()) return; + if (doom.hard_doom()) return; PartialResult &pr = *context.result; pr.totalHits(totalHits); size_t maxHits = std::min(numHits, pr.maxSize()); @@ -432,19 +431,19 @@ MatchThread::run() MatchTools::UP matchTools = matchToolsFactory.createMatchTools(); search::ResultSet::UP result = findMatches(*matchTools); match_time_s = match_time.elapsed().sec(); - resultContext = resultProcessor.createThreadContext(matchTools->getHardDoom(), thread_id, _distributionKey); + resultContext = resultProcessor.createThreadContext(matchTools->getDoom(), thread_id, _distributionKey); { trace->addEvent(5, "Wait for result processing token"); WaitTimer get_token_timer(wait_time_s); QueryLimiter::Token::UP processToken( - matchTools->getQueryLimiter().getToken(matchTools->getHardDoom(), + matchTools->getQueryLimiter().getToken(matchTools->getDoom(), scheduler.total_size(thread_id), result->getNumHits(), resultContext->sort->hasSortData(), resultContext->grouping.get() != 0)); get_token_timer.done(); trace->addEvent(5, "Start result processing"); - processResult(matchTools->getHardDoom(), std::move(result), *resultContext); + processResult(matchTools->getDoom(), std::move(result), *resultContext); } total_time_s = total_time.elapsed().sec(); thread_stats.active_time(total_time_s - wait_time_s).wait_time(wait_time_s); diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_thread.h b/searchcore/src/vespa/searchcore/proton/matching/match_thread.h index dca77f35019..7ecbfef634e 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_thread.h +++ b/searchcore/src/vespa/searchcore/proton/matching/match_thread.h @@ -73,16 +73,16 @@ private: void addHit(uint32_t docId) { _hits.addHit(docId, search::zero_rank_value); } bool isBelowLimit() const { return matches < _matches_limit; } bool isAtLimit() const { return matches == _matches_limit; } - bool atSoftDoom() const { return _softDoom.doom(); } - fastos::TimeStamp timeLeft() const { return _softDoom.left(); } - uint32_t matches; + bool atSoftDoom() const { return _doom.soft_doom(); } + fastos::TimeStamp timeLeft() const { return _doom.soft_left(); } + uint32_t matches; private: - uint32_t _matches_limit; - LazyValue _score_feature; - RankProgram &_ranking; - double _rankDropLimit; - HitCollector &_hits; - const Doom &_softDoom; + uint32_t _matches_limit; + LazyValue _score_feature; + RankProgram &_ranking; + double _rankDropLimit; + HitCollector &_hits; + const Doom &_doom; }; double estimate_match_frequency(uint32_t matches, uint32_t searchedSoFar) __attribute__((noinline)); @@ -106,7 +106,7 @@ private: search::ResultSet::UP findMatches(MatchTools &tools); - void processResult(const Doom & hardDoom, search::ResultSet::UP result, ResultProcessor::Context &context); + void processResult(const Doom & doom, search::ResultSet::UP result, ResultProcessor::Context &context); bool isFirstThread() const { return thread_id == 0; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp b/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp index 67c1fb25b64..05b70cfb8bf 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp @@ -89,8 +89,7 @@ MatchTools::setup(search::fef::RankProgram::UP rank_program, double termwise_lim } MatchTools::MatchTools(QueryLimiter & queryLimiter, - const vespalib::Doom & softDoom, - const vespalib::Doom & hardDoom, + const vespalib::Doom & doom, const Query &query, MaybeMatchPhaseLimiter & match_limiter_in, const QueryEnvironment & queryEnv, @@ -98,8 +97,7 @@ MatchTools::MatchTools(QueryLimiter & queryLimiter, const RankSetup & rankSetup, const Properties & featureOverrides) : _queryLimiter(queryLimiter), - _softDoom(softDoom), - _hardDoom(hardDoom), + _doom(doom), _query(query), _match_limiter(match_limiter_in), _queryEnv(queryEnv), @@ -149,8 +147,7 @@ MatchTools::setup_dump() MatchToolsFactory:: MatchToolsFactory(QueryLimiter & queryLimiter, - const vespalib::Doom & softDoom, - const vespalib::Doom & hardDoom, + const vespalib::Doom & doom, ISearchContext & searchContext, IAttributeContext & attributeContext, vespalib::stringref queryStack, @@ -162,8 +159,7 @@ MatchToolsFactory(QueryLimiter & queryLimiter, const Properties & rankProperties, const Properties & featureOverrides) : _queryLimiter(queryLimiter), - _requestContext(softDoom, attributeContext, rankProperties), - _hardDoom(hardDoom), + _requestContext(doom, attributeContext, rankProperties), _query(), _match_limiter(), _queryEnv(indexEnv, attributeContext, rankProperties, searchContext.getIndexes()), @@ -204,7 +200,7 @@ MatchTools::UP MatchToolsFactory::createMatchTools() const { assert(_valid); - return std::make_unique<MatchTools>(_queryLimiter, _requestContext.getSoftDoom(), _hardDoom, _query, + return std::make_unique<MatchTools>(_queryLimiter, _requestContext.getDoom(), _query, *_match_limiter, _queryEnv, _mdl, _rankSetup, _featureOverrides); } diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_tools.h b/searchcore/src/vespa/searchcore/proton/matching/match_tools.h index 777652c6b89..5cf2919198a 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_tools.h +++ b/searchcore/src/vespa/searchcore/proton/matching/match_tools.h @@ -28,8 +28,7 @@ class MatchTools private: using IRequestContext = search::queryeval::IRequestContext; QueryLimiter &_queryLimiter; - const vespalib::Doom &_softDoom; - const vespalib::Doom &_hardDoom; + const vespalib::Doom &_doom; const Query &_query; MaybeMatchPhaseLimiter &_match_limiter; const QueryEnvironment &_queryEnv; @@ -46,8 +45,7 @@ public: MatchTools(const MatchTools &) = delete; MatchTools & operator = (const MatchTools &) = delete; MatchTools(QueryLimiter & queryLimiter, - const vespalib::Doom & softDoom, - const vespalib::Doom & hardDoom, + const vespalib::Doom & doom, const Query &query, MaybeMatchPhaseLimiter &match_limiter_in, const QueryEnvironment &queryEnv, @@ -55,8 +53,7 @@ public: const search::fef::RankSetup &rankSetup, const search::fef::Properties &featureOverrides); ~MatchTools(); - const vespalib::Doom &getSoftDoom() const { return _softDoom; } - const vespalib::Doom &getHardDoom() const { return _hardDoom; } + const vespalib::Doom &getDoom() const { return _doom; } QueryLimiter & getQueryLimiter() { return _queryLimiter; } MaybeMatchPhaseLimiter &match_limiter() { return _match_limiter; } bool has_second_phase_rank() const; @@ -87,13 +84,12 @@ private: vespalib::string _operation; }; -class MatchToolsFactory : public vespalib::noncopyable +class MatchToolsFactory { private: using IAttributeFunctor = search::attribute::IAttributeFunctor; QueryLimiter & _queryLimiter; RequestContext _requestContext; - const vespalib::Doom _hardDoom; Query _query; MaybeMatchPhaseLimiter::UP _match_limiter; QueryEnvironment _queryEnv; @@ -111,7 +107,6 @@ public: MatchToolsFactory(QueryLimiter & queryLimiter, const vespalib::Doom & softDoom, - const vespalib::Doom & hardDoom, ISearchContext &searchContext, search::attribute::IAttributeContext &attributeContext, vespalib::stringref queryStack, @@ -135,6 +130,7 @@ public: std::unique_ptr<AttributeOperationTask> createOnSummaryTask() const; const Query & query() const { return _query; } + const RequestContext & getRequestContext() const { return _requestContext; } }; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp index 2764bfb96df..426bb353826 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp @@ -128,19 +128,22 @@ Matcher::create_match_tools_factory(const search::engine::Request &request, ISea { const Properties & rankProperties = request.propertiesMap.rankProperties(); bool softTimeoutEnabled = Enabled::lookup(rankProperties, _rankSetup->getSoftTimeoutEnabled()); + bool hasFactorOverride = Factor::isPresent(rankProperties); double factor = softTimeoutEnabled - ? Factor::lookup(rankProperties, _stats.softDoomFactor()) + ? ( hasFactorOverride + ? Factor::lookup(rankProperties, _stats.softDoomFactor()) + : _stats.softDoomFactor()) : 0.95; int64_t safeLeft = request.getTimeLeft() * factor; fastos::SteadyTimeStamp safeDoom(_clock.getTimeNSAssumeRunning() + safeLeft); if (softTimeoutEnabled) { - LOG(debug, "Soft-timeout computed factor=%1.3f, used factor=%1.3f, softTimeout=%" PRId64, - _stats.softDoomFactor(), factor, safeLeft); + LOG(debug, "Soft-timeout computed factor=%1.3f, used factor=%1.3f, userSupplied=%d, softTimeout=%" PRId64, + _stats.softDoomFactor(), factor, hasFactorOverride, safeLeft); } - return std::make_unique<MatchToolsFactory>(_queryLimiter, vespalib::Doom(_clock, safeDoom), - vespalib::Doom(_clock, request.getTimeOfDoom()), searchContext, - attrContext, request.getStackRef(), request.location, _viewResolver, - metaStore, _indexEnv, *_rankSetup, rankProperties, feature_overrides); + vespalib::Doom doom(_clock, safeDoom, request.getTimeOfDoom(), hasFactorOverride); + return std::make_unique<MatchToolsFactory>(_queryLimiter, doom, searchContext, attrContext, request.getStackRef(), + request.location, _viewResolver, metaStore, _indexEnv, *_rankSetup, + rankProperties, feature_overrides); } SearchReply::UP @@ -187,6 +190,7 @@ Matcher::match(const SearchRequest &request, vespalib::ThreadBundle &threadBundl SearchReply::UP reply = std::make_unique<SearchReply>(); size_t covered = 0; uint32_t numActiveLids = 0; + bool isDoomExplicit = false; { // we want to measure full set-up and tear-down time as part of // collateral time GroupingContext groupingContext(_clock, request.getTimeOfDoom(), @@ -212,6 +216,7 @@ Matcher::match(const SearchRequest &request, vespalib::ThreadBundle &threadBundl } MatchToolsFactory::UP mtf = create_match_tools_factory(request, searchContext, attrContext, metaStore, *feature_overrides); + isDoomExplicit = mtf->getRequestContext().getDoom().isExplicitSoftDoom(); traceQuery(6, request.trace(), mtf->query()); if (!mtf->valid()) { return reply; @@ -288,19 +293,21 @@ Matcher::match(const SearchRequest &request, vespalib::ThreadBundle &threadBundl _stats.add(my_stats); if (my_stats.softDoomed()) { double old = _stats.softDoomFactor(); - fastos::TimeStamp softLimit = (1.0 - _rankSetup->getSoftTimeoutTailCost()) * request.getTimeout(); + fastos::TimeStamp overtimeLimit = (1.0 - _rankSetup->getSoftTimeoutTailCost()) * request.getTimeout(); fastos::TimeStamp adjustedDuration = duration - my_stats.doomOvertime(); if (adjustedDuration < 0) { adjustedDuration = 0; } - bool allowedSoftTimeoutFactorAdjustment = (std::chrono::duration_cast<std::chrono::seconds>(my_clock::now() - _startTime).count() > SECONDS_BEFORE_ALLOWING_SOFT_TIMEOUT_FACTOR_ADJUSTMENT); + bool allowedSoftTimeoutFactorAdjustment = (std::chrono::duration_cast<std::chrono::seconds>(my_clock::now() - _startTime).count() > SECONDS_BEFORE_ALLOWING_SOFT_TIMEOUT_FACTOR_ADJUSTMENT) + && ! isDoomExplicit; if (allowedSoftTimeoutFactorAdjustment) { - _stats.updatesoftDoomFactor(request.getTimeout(), softLimit, adjustedDuration); + _stats.updatesoftDoomFactor(request.getTimeout(), overtimeLimit, adjustedDuration); } - LOG(info, "Triggered softtimeout factor adjustment. Coverage = %lu of %u documents. request=%1.3f, doomOvertime=%1.3f, limit=%1.3f and duration=%1.3f, rankprofile=%s" + LOG(info, "Triggered softtimeout %s. Coverage = %lu of %u documents. request=%1.3f, doomOvertime=%1.3f, overtime_limit=%1.3f and duration=%1.3f, rankprofile=%s" ", factor %sadjusted from %1.3f to %1.3f", + isDoomExplicit ? "with query override" : "factor adjustment", covered, numActiveLids, - request.getTimeout().sec(), my_stats.doomOvertime().sec(), softLimit.sec(), duration.sec(), + request.getTimeout().sec(), my_stats.doomOvertime().sec(), overtimeLimit.sec(), duration.sec(), request.ranking.c_str(), (allowedSoftTimeoutFactorAdjustment ? "" : "NOT "), old, _stats.softDoomFactor()); } } diff --git a/searchcore/src/vespa/searchcore/proton/matching/querylimiter.cpp b/searchcore/src/vespa/searchcore/proton/matching/querylimiter.cpp index 0d985496d41..5053cc5fdbe 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/querylimiter.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/querylimiter.cpp @@ -2,8 +2,7 @@ #include "querylimiter.h" #include <chrono> -namespace proton { -namespace matching { +namespace proton:: matching { QueryLimiter::LimitedToken::LimitedToken(const Doom & doom, QueryLimiter & limiter) : _limiter(limiter) @@ -20,8 +19,8 @@ void QueryLimiter::grabToken(const Doom & doom) { std::unique_lock<std::mutex> guard(_lock); - while ((_maxThreads > 0) && (_activeThreads >= _maxThreads) && !doom.doom()) { - int left = doom.left().ms(); + while ((_maxThreads > 0) && (_activeThreads >= _maxThreads) && !doom.hard_doom()) { + int left = doom.hard_left().ms(); if (left > 0) { _cond.wait_for(guard, std::chrono::milliseconds(left)); } @@ -62,13 +61,12 @@ QueryLimiter::getToken(const Doom & doom, uint32_t numDocs, uint32_t numHits, bo if (hasSorting || hasGrouping) { if (numHits > _minHits) { if (numDocs * _coverage < numHits) { - return Token::UP(new LimitedToken(doom, *this)); + return std::make_unique<LimitedToken>(doom, *this); } } } } - return Token::UP(new NoLimitToken()); + return std::make_unique<NoLimitToken>(); } } -} diff --git a/searchcore/src/vespa/searchcore/proton/matching/querylimiter.h b/searchcore/src/vespa/searchcore/proton/matching/querylimiter.h index fbe8526b051..45783959957 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/querylimiter.h +++ b/searchcore/src/vespa/searchcore/proton/matching/querylimiter.h @@ -7,8 +7,7 @@ #include <mutex> #include <condition_variable> -namespace proton { -namespace matching { +namespace proton::matching { class QueryLimiter { @@ -46,6 +45,4 @@ private: volatile uint32_t _minHits; }; -} // namespace matching -} // namespace proton - +} diff --git a/searchcore/src/vespa/searchcore/proton/matching/requestcontext.cpp b/searchcore/src/vespa/searchcore/proton/matching/requestcontext.cpp index 918e4a14649..c30854c051f 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/requestcontext.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/requestcontext.cpp @@ -14,9 +14,9 @@ namespace proton { using search::attribute::IAttributeVector; -RequestContext::RequestContext(const Doom & softDoom, IAttributeContext & attributeContext, +RequestContext::RequestContext(const Doom & doom, IAttributeContext & attributeContext, const search::fef::Properties& rank_properties) - : _softDoom(softDoom), + : _doom(doom), _attributeContext(attributeContext), _rank_properties(rank_properties) { diff --git a/searchcore/src/vespa/searchcore/proton/matching/requestcontext.h b/searchcore/src/vespa/searchcore/proton/matching/requestcontext.h index 0352e28eea2..31d3d573a20 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/requestcontext.h +++ b/searchcore/src/vespa/searchcore/proton/matching/requestcontext.h @@ -5,6 +5,7 @@ #include <vespa/eval/tensor/tensor.h> #include <vespa/searchlib/queryeval/irequestcontext.h> #include <vespa/searchcommon/attribute/iattributecontext.h> +#include <vespa/vespalib/util/doom.h> namespace search::fef { class Properties; } @@ -20,7 +21,7 @@ public: RequestContext(const Doom & softDoom, IAttributeContext & attributeContext, const search::fef::Properties& rank_properties); - const Doom & getSoftDoom() const override { return _softDoom; } + const Doom & getDoom() const override { return _doom; } const search::attribute::IAttributeVector *getAttribute(const vespalib::string &name) const override; void asyncForAttribute(const vespalib::string &name, std::unique_ptr<IAttributeFunctor> func) const override; @@ -31,9 +32,9 @@ public: private: - const Doom _softDoom; - IAttributeContext & _attributeContext; - const search::fef::Properties& _rank_properties; + const Doom _doom; + IAttributeContext & _attributeContext; + const search::fef::Properties & _rank_properties; }; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/result_processor.cpp b/searchcore/src/vespa/searchcore/proton/matching/result_processor.cpp index 445aab310d9..41052978997 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/result_processor.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/result_processor.cpp @@ -85,7 +85,7 @@ ResultProcessor::prepareThreadContextCreation(size_t num_threads) if (num_threads > 1) { _wasMerged = true; } - if (_groupingSession.get() != 0) { + if (_groupingSession) { _groupingSession->prepareThreadContextCreation(num_threads); } } diff --git a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp index 37f1664841a..8a4cb1682a6 100644 --- a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp @@ -1,10 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "disk_mem_usage_sampler.h" -#include <vespa/vespalib/util/timer.h> +#include <vespa/vespalib/util/scheduledexecutor.h> #include <vespa/vespalib/util/lambdatask.h> #include <filesystem> -#include <unistd.h> using vespalib::makeLambdaTask; @@ -32,7 +31,7 @@ DiskMemUsageSampler::setConfig(const Config &config) _filter.setConfig(config.filterConfig); _sampleInterval = config.sampleInterval; sampleUsage(); - _periodicTimer = std::make_unique<vespalib::Timer>(); + _periodicTimer = std::make_unique<vespalib::ScheduledExecutor>(); _periodicTimer->scheduleAtFixedRate(makeLambdaTask([this]() { sampleUsage(); }), _sampleInterval, _sampleInterval); diff --git a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h index 5a439e69003..2ab13f2f48a 100644 --- a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h +++ b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h @@ -4,7 +4,7 @@ #include "disk_mem_usage_filter.h" -namespace vespalib { class Timer; } +namespace vespalib { class ScheduledExecutor; } namespace proton { @@ -15,7 +15,7 @@ class DiskMemUsageSampler { DiskMemUsageFilter _filter; std::filesystem::path _path; double _sampleInterval; - std::unique_ptr<vespalib::Timer> _periodicTimer; + std::unique_ptr<vespalib::ScheduledExecutor> _periodicTimer; void sampleUsage(); void sampleDiskUsage(); diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp index a2672cc7972..893748ae49e 100644 --- a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp @@ -6,7 +6,7 @@ #include "i_blockable_maintenance_job.h" #include <vespa/searchcorespi/index/i_thread_service.h> #include <vespa/vespalib/util/closuretask.h> -#include <vespa/vespalib/util/timer.h> +#include <vespa/vespalib/util/scheduledexecutor.h> #include <vespa/log/log.h> LOG_SETUP(".proton.server.maintenancecontroller"); @@ -167,7 +167,7 @@ MaintenanceController::restart() if (!_started || _stopping || !_readySubDB.valid()) { return; } - _periodicTimer.reset(new vespalib::Timer()); + _periodicTimer = std::make_unique<vespalib::ScheduledExecutor>(); addJobsToPeriodicTimer(); } diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h index 24c1c18959e..3cfdeba4d34 100644 --- a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h +++ b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h @@ -8,6 +8,7 @@ #include "ibucketfreezelistener.h" #include <vespa/searchcore/proton/common/doctypename.h> #include <mutex> +#include <vespa/vespalib/util/scheduledexecutor.h> namespace vespalib { class Timer; @@ -77,7 +78,7 @@ private: MaintenanceDocumentSubDB _readySubDB; MaintenanceDocumentSubDB _remSubDB; MaintenanceDocumentSubDB _notReadySubDB; - std::unique_ptr<vespalib::Timer> _periodicTimer; + std::unique_ptr<vespalib::ScheduledExecutor> _periodicTimer; DocumentDBMaintenanceConfigSP _config; FrozenBuckets _frozenBuckets; bool _started; diff --git a/searchcore/src/vespa/searchcore/proton/server/proton.cpp b/searchcore/src/vespa/searchcore/proton/server/proton.cpp index bbc599a661e..b85a1e00574 100644 --- a/searchcore/src/vespa/searchcore/proton/server/proton.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/proton.cpp @@ -231,7 +231,7 @@ Proton::init() { assert( ! _initStarted && ! _initComplete ); _initStarted = true; - if (_threadPool.NewThread(&_clock, nullptr) == nullptr) { + if (_threadPool.NewThread(_clock.getRunnable(), nullptr) == nullptr) { throw IllegalStateException("Failed starting thread for the cheap clock"); } _protonConfigFetcher.start(); diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index abfae426ad0..98b975546e7 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -348,6 +348,7 @@ "public" ], "methods": [ + "public static com.yahoo.searchlib.rankingexpression.Reference fromIdentifier(java.lang.String)", "public void <init>(java.lang.String, com.yahoo.searchlib.rankingexpression.rule.Arguments, java.lang.String)", "public com.yahoo.searchlib.rankingexpression.rule.Arguments arguments()", "public java.lang.String output()", @@ -1398,13 +1399,11 @@ "public void <init>(java.util.Map)", "public void <init>(java.util.Map, java.util.Map)", "public com.yahoo.searchlib.rankingexpression.ExpressionFunction getFunction(java.lang.String)", - "protected final com.google.common.collect.ImmutableMap functions()", + "protected com.google.common.collect.ImmutableMap functions()", "public java.lang.String getBinding(java.lang.String)", "public com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext withBindings(java.util.Map)" ], - "fields": [ - "public final java.util.Map bindings" - ] + "fields": [] }, "com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode": { "superClass": "com.yahoo.searchlib.rankingexpression.rule.CompositeNode", @@ -1552,7 +1551,7 @@ "public void <init>()", "public void <init>(java.util.Collection)", "public void <init>(java.util.Map)", - "public void <init>(java.util.List, java.util.Map)", + "public void <init>(java.util.Collection, java.util.Map)", "public void <init>(java.util.Collection, java.util.Map, java.util.Map)", "public void <init>(com.google.common.collect.ImmutableMap, java.util.Map, java.util.Map)", "public void addFunctionSerialization(java.lang.String, java.lang.String)", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java index 722520fea08..18f6c6f2ca2 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java @@ -265,8 +265,8 @@ public class RankingExpression implements Serializable { /** * Returns the rank-property name for a given expression name. * - * @param expressionName The expression name to mangle. - * @return The property name. + * @param expressionName the expression name to mangle. + * @return the property name. */ public static String propertyName(String expressionName) { return "rankingExpression(" + expressionName + ").rankingScript"; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java index fa2d0f1ee45..3c537b53e9d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java @@ -27,15 +27,28 @@ public class Reference extends Name { /** The output, or null if none */ private final String output; + /** True if this was created by the "fromIdentifier" method. This lets us separate 'foo()' and 'foo' */ + private final boolean isIdentifier; + + public static Reference fromIdentifier(String identifier) { + return new Reference(identifier, new Arguments(), null, true); + } + public Reference(String name, Arguments arguments, String output) { + this(name, arguments, output, false); + } + + private Reference(String name, Arguments arguments, String output, boolean isIdentifier) { super(name); Objects.requireNonNull(name, "name cannot be null"); Objects.requireNonNull(arguments, "arguments cannot be null"); this.arguments = arguments; this.output = output; - this.hashCode = Objects.hash(name(), arguments, output); + this.hashCode = Objects.hash(name(), arguments, output, isIdentifier); + this.isIdentifier = isIdentifier; } + public Arguments arguments() { return arguments; } public String output() { return output; } @@ -66,12 +79,8 @@ public class Reference extends Name { return Optional.of(simple(featureName, argument)); } - /** - * Returns whether this is a simple identifier - no arguments or output - */ - public boolean isIdentifier() { - return this.arguments.expressions().size() == 0 && output == null; - } + /** Returns true if this was created by fromIdentifier. Identifiers have no arguments or outputs. */ + public boolean isIdentifier() { return isIdentifier; } /** * A <i>simple feature reference</i> is a reference with a single identifier argument @@ -105,11 +114,11 @@ public class Reference extends Name { } public Reference withArguments(Arguments arguments) { - return new Reference(name(), arguments, output); + return new Reference(name(), arguments, output, isIdentifier && arguments.isEmpty()); } public Reference withOutput(String output) { - return new Reference(name(), arguments, output); + return new Reference(name(), arguments, output, isIdentifier && output == null); } @Override @@ -121,6 +130,7 @@ public class Reference extends Name { if (!Objects.equals(other.name(), this.name())) return false; if (!Objects.equals(other.arguments, this.arguments)) return false; if (!Objects.equals(other.output, this.output)) return false; + if (!Objects.equals(other.isIdentifier, this.isIdentifier)) return false; return true; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index f531d77762d..69304a811b1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -69,7 +69,7 @@ public class MapContext extends Context { * Sets the value of a key. The value is frozen by this. */ @Override - public void put(String key,Value value) { + public void put(String key, Value value) { bindings.put(key, value.freeze()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 39e408d27ca..382cbb7ce9a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -76,7 +76,7 @@ public abstract class Value { * @return this for convenience */ public Value freeze() { - frozen=true; + frozen = true; return this; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java index d306e067d16..dd1ef263cba 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java @@ -40,8 +40,6 @@ public final class EmbracedNode extends CompositeNode { @Override public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { - if (value instanceof ReferenceNode) - return value.toString(string, context, path, this); return value.toString(string.append('('), context, path, this).append(')'); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java index 084bfe65e06..83aabada8f0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java @@ -20,8 +20,7 @@ public class FunctionReferenceContext { private final ImmutableMap<String, ExpressionFunction> functions; /** Mapping from argument names to the expressions they resolve to */ - // TODO: Make private - public final Map<String, String> bindings = new HashMap<>(); + private final Map<String, String> bindings = new HashMap<>(); /** Create a context for a single serialization task */ public FunctionReferenceContext() { @@ -56,14 +55,12 @@ public class FunctionReferenceContext { return mapBuilder.build(); } - /** - * Returns a function or null if it isn't defined in this context - */ + /** Returns a function or null if it isn't defined in this context */ public ExpressionFunction getFunction(String name) { return functions.get(name); } - protected final ImmutableMap<String, ExpressionFunction> functions() { return functions; } + protected ImmutableMap<String, ExpressionFunction> functions() { return functions; } - /** Returns the resolution of an argument, or null if it isn't defined in this context */ + /** Returns the resolution of an identifier, or null if it isn't defined in this context */ public String getBinding(String name) { return bindings.get(name); } /** Returns a new context with the bindings replaced by the given bindings */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 7312863fa26..8fec3603f3e 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -13,6 +13,7 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayDeque; import java.util.Deque; import java.util.List; +import java.util.Optional; /** * A node referring either to a value in the context or to a named ranking expression function. @@ -25,7 +26,7 @@ public final class ReferenceNode extends CompositeNode { /* Creates a node with a simple identifier reference */ public ReferenceNode(String name) { - this(name, null, null); + this.reference = Reference.fromIdentifier(name); } public ReferenceNode(String name, List<? extends ExpressionNode> arguments, String output) { @@ -67,7 +68,7 @@ public final class ReferenceNode extends CompositeNode { @Override public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { - // A reference to a function argument? + // A reference to an identifier (function argument or bound variable)? if (reference.isIdentifier() && context.getBinding(getName()) != null) { // a bound identifier: replace by the value it is bound to return string.append(context.getBinding(getName())); @@ -89,6 +90,7 @@ public final class ReferenceNode extends CompositeNode { return string.append("rankingExpression(").append(instance.getName()).append(')'); } + // Not resolved in this context: output as-is return reference.toString(string, context, path, parent); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index 4acc1a85490..d7807caa2b6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -8,7 +8,6 @@ import com.yahoo.tensor.TensorType; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; -import java.util.List; import java.util.Map; /** @@ -37,7 +36,7 @@ public class SerializationContext extends FunctionReferenceContext { } /** Create a context for a single serialization task */ - public SerializationContext(List<ExpressionFunction> functions, Map<String, String> bindings) { + public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings) { this(functions, bindings, new LinkedHashMap<>()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index cec8837abcd..0a67ab5534e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -2,6 +2,8 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableMap; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -16,8 +18,6 @@ import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; -import java.sql.Ref; -import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.LinkedHashMap; @@ -71,7 +71,9 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext<Reference> context) { return function.type(context); } + public TensorType type(TypeContext<Reference> context) { + return function.type(context); + } @Override public Value evaluate(Context context) { @@ -117,9 +119,16 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString(ToStringContext c) { - if (c instanceof ExpressionToStringContext) { - ExpressionToStringContext context = (ExpressionToStringContext) c; - return expression.toString(new StringBuilder(),context.context, context.path, context.parent).toString(); + ToStringContext outermost = c; + while (outermost.parent() != null) + outermost = outermost.parent(); + + if (outermost instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext)outermost; + return expression.toString(new StringBuilder(), + new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), + context.path, + context.parent).toString(); } else { return expression.toString(); @@ -180,9 +189,17 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString(ToStringContext c) { - if (c instanceof ExpressionToStringContext) { - ExpressionToStringContext context = (ExpressionToStringContext) c; - return expression.toString(new StringBuilder(),context.context, context.path, context.parent).toString(); + ToStringContext outermost = c; + while (outermost.parent() != null) + outermost = outermost.parent(); + + if (outermost instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext)outermost; + return expression.toString(new StringBuilder(), + new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), + context.path, + context.parent) + .toString(); } else { return expression.toString(); @@ -191,23 +208,83 @@ public class TensorFunctionNode extends CompositeNode { } - /** Allows passing serialization context arguments through TensorFunctions */ - private static class ExpressionToStringContext implements ToStringContext { + /** + * This is used to pass a full expression serialization context through tensor functions. + * Tensor functions cannot see the full serialization context because it exposes expressions + * (which depends on the tensor module), but they need to be able to recursively add their own + * contexts (binding scopes) due to Generate binding dimension names. + * + * To be able to achieve both passing the serialization context through functions *and* allow them + * to add more context, we need to keep track of both these contexts here separately and map between + * contexts as seen in the toString methods of the function classes above. + */ + private static class ExpressionToStringContext extends SerializationContext implements ToStringContext { - final SerializationContext context; - final Deque<String> path; - final CompositeNode parent; + private final ToStringContext wrappedToStringContext; + private final SerializationContext wrappedSerializationContext; + private final Deque<String> path; + private final CompositeNode parent; public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(), null, null); - public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { - this.context = context; + ExpressionToStringContext(SerializationContext wrappedSerializationContext, Deque<String> path, CompositeNode parent) { + this(wrappedSerializationContext, null, path, parent); + } + + ExpressionToStringContext(SerializationContext wrappedSerializationContext, + ToStringContext wrappedToStringContext, + Deque<String> path, + CompositeNode parent) { + this.wrappedSerializationContext = wrappedSerializationContext; + this.wrappedToStringContext = wrappedToStringContext; this.path = path; this.parent = parent; } + /** Adds the serialization of a function */ + public void addFunctionSerialization(String name, String expressionString) { + wrappedSerializationContext.addFunctionSerialization(name, expressionString); + } + + /** Adds the serialization of the an argument type to a function */ + public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) { + wrappedSerializationContext.addArgumentTypeSerialization(functionName, argumentName, type); + } + + /** Adds the serialization of the return type of a function */ + public void addFunctionTypeSerialization(String functionName, TensorType type) { + wrappedSerializationContext.addFunctionTypeSerialization(functionName, type); + } + + public Map<String, String> serializedFunctions() { + return wrappedSerializationContext.serializedFunctions(); + } + + /** Returns a function or null if it isn't defined in this context */ + public ExpressionFunction getFunction(String name) { return wrappedSerializationContext.getFunction(name); } + + protected ImmutableMap<String, ExpressionFunction> functions() { return wrappedSerializationContext.functions(); } + + public ToStringContext parent() { return wrappedToStringContext; } + + /** Returns the resolution of an identifier, or null if it isn't defined in this context */ + @Override + public String getBinding(String name) { + if (wrappedToStringContext != null && wrappedToStringContext.getBinding(name) != null) + return wrappedToStringContext.getBinding(name); + else + return wrappedSerializationContext.getBinding(name); + } + + /** Returns a new context with the bindings replaced by the given bindings */ + @Override + public ExpressionToStringContext withBindings(Map<String, String> bindings) { + return new ExpressionToStringContext(new SerializationContext(wrappedSerializationContext.functions().values(), bindings), + wrappedToStringContext, path, parent); + } + } /** Turns an EvaluationContext into a Context */ diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index c7870182939..fdad824cd1b 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -275,7 +275,12 @@ ReferenceNode feature() : } { ( name = identifier() [ <LBRACE> args = args() <RBRACE> ] [ <DOT> out = outs() ] ) - { return new ReferenceNode(name, args, out); } + { + if (args == null && out == null) // know the difference between "foo" and "foo()" + return new ReferenceNode(name); + else + return new ReferenceNode(name, args, out); + } } // Rank properties are referenced by $propertyname diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 26fcec9efba..e3d3ac7b2e1 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -159,6 +159,54 @@ public class RankingExpressionTestCase { } @Test + public void testFunctionInTensorSerialization() throws ParseException { + List<ExpressionFunction> functions = new ArrayList<>(); + functions.add(new ExpressionFunction("scalarFunction", List.of(), new RankingExpression("5"))); + functions.add(new ExpressionFunction("tensorFunction", List.of(), new RankingExpression("tensor(x[3]):[1, 2, 3]"))); + + // Getting a value from a tensor supplied by a function, inside a tensor generate function + assertSerialization(List.of("tensor(x[3])(rankingExpression(tensorFunction)[x])"), + "tensor(x[3])(tensorFunction[x])", + functions, false); + + // Getting a value from a tensor supplied by a function, where the value index is supplied by a function, inside a tensor generate function, short form + assertSerialization(List.of("tensor(x[3])(rankingExpression(tensorFunction)[rankingExpression(scalarFunction)])"), + "tensor(x[3])(tensorFunction[scalarFunction()])", + functions, false); + + // 'scalarFunction' is interpreted as a label here since it is the short form of a mapped dimension + assertSerialization(List.of("tensor(x[3])(rankingExpression(tensorFunction){scalarFunction})"), + "tensor(x[3])(tensorFunction{scalarFunction})", + functions, false); + + // Getting a value from a tensor supplied by a function, where the value index is supplied by a function, inside a tensor generate function, long form + assertSerialization(List.of("tensor(x[3])(rankingExpression(tensorFunction){x:rankingExpression(scalarFunction)})"), + "tensor(x[3])(tensorFunction{x:scalarFunction()})", + functions, false); + + // 'scalarFunction' without parentheses is interpreted as a label instead of a reference to the function + assertSerialization(List.of("tensor(x[3])(rankingExpression(tensorFunction){x:scalarFunction})"), + "tensor(x[3])(tensorFunction{x:scalarFunction})", + functions, false); + + // Accessing a function in a dynamic tensor, short form + assertSerialization(List.of("tensor(x[2]):{{x:0}:rankingExpression(scalarFunction),{x:1}:rankingExpression(scalarFunction)}"), + "tensor(x[2]):[scalarFunction(), scalarFunction()]]", + functions, false); + + // Accessing a function in a dynamic tensor, long form + assertSerialization(List.of("tensor(x{}):{{x:foo}:rankingExpression(scalarFunction),{x:bar}:rankingExpression(scalarFunction)}"), + "tensor(x{}):{{x:foo}:scalarFunction(), {x:bar}:scalarFunction()}", + functions, false); + + // Shadowing + assertSerialization(List.of("tensor(scalarFunction[1])(rankingExpression(tensorFunction){x:scalarFunction + rankingExpression(scalarFunction)})"), + "tensor(scalarFunction[1])(tensorFunction{x: scalarFunction + scalarFunction()})", + functions, false); + + } + + @Test public void testBug3464208() throws ParseException { List<ExpressionFunction> functions = new ArrayList<>(); functions.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69"))); @@ -245,7 +293,7 @@ public class RankingExpressionTestCase { @Test public void testNonCanonicalLegalStrings() throws ParseException { - assertParse("a * b + c * d", "a* (b) + \nc*d"); + assertParse("a * (b) + c * d", "a* (b) + \nc*d"); } @Test @@ -323,7 +371,8 @@ public class RankingExpressionTestCase { } } - private void assertSerialization(List<String> expectedSerialization, String expressionString, + private void assertSerialization(List<String> expectedSerialization, + String expressionString, List<ExpressionFunction> functions) { assertSerialization(expectedSerialization, expressionString, functions, false); } @@ -331,13 +380,13 @@ public class RankingExpressionTestCase { List<ExpressionFunction> functions, boolean print) { try { if (print) - System.out.println("Parsing expression '" + expressionString + "'."); + System.out.println("Parsing expression '" + expressionString + "':"); RankingExpression expression = new RankingExpression(expressionString); Map<String, String> rankProperties = expression.getRankProperties(functions); if (print) { for (String key : rankProperties.keySet()) - System.out.println("Property '" + key + "': " + rankProperties.get(key)); + System.out.println(key + ": " + rankProperties.get(key)); } for (int i = 0; i < expectedSerialization.size();) { String val = expectedSerialization.get(i++); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 99047aeb79d..38f152d728c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -316,6 +316,12 @@ public class EvaluationTestCase { tester.assertEvaluates("{ {x:0}:0, {x:1}:1, {x:2}:2 }", "range(x[3])"); tester.assertEvaluates("{ {x:0,y:0,z:0}:1, {x:0,y:0,z:1}:0, {x:0,y:1,z:0}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:0}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:1, }", "diag(x[2],y[2],z[2])"); tester.assertEvaluates("6", "reduce(random(x[2],y[3]), count)"); + tester.assertEvaluates("tensor(x[2]):[0.0, 2.0]", + "tensor(x[2]):{{x:0}:tensor(y[2]):{{y:0}:((0+0)+a)," + + "{y:1}:((0+1)+a)}{y:0}," + + "{x:1}:tensor(y[2]):{{y:0}:((1+0)+a)," + + "{y:1}:((1+1)+a)}{y:1}" + + "}"); // tensor value tester.assertEvaluates("3.0", "tensor0{x:1}", "{ {x:0}:1, {x:1}:3 }"); @@ -367,6 +373,9 @@ public class EvaluationTestCase { tester.assertEvaluates("tensor(j[2]):[6, 5]", "tensor(j[2])(tensor0{key:bar,i:2-j})", "tensor(key{},i[5]):{{key:foo,i:0}:1,{key:foo,i:1}:2,{key:foo,i:2}:2,{key:bar,i:0}:4,{key:bar,i:1}:5,{key:bar,i:2}:6}"); + tester.assertEvaluates("5.5", + "sum(tensor(d0[1])(tensor0{x:mykey}))", + "tensor(x{}):{{x:mykey}:5.5}"); // tensor result dimensions are given from argument dimensions, not the resulting values tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{}):{ {x:1}:1 }"); diff --git a/searchlib/src/tests/postinglistbm/stress_runner.cpp b/searchlib/src/tests/postinglistbm/stress_runner.cpp index 53b683cd7fd..100a4fcd70d 100644 --- a/searchlib/src/tests/postinglistbm/stress_runner.cpp +++ b/searchlib/src/tests/postinglistbm/stress_runner.cpp @@ -8,9 +8,11 @@ #include <vespa/searchlib/test/fakedata/fakeword.h> #include <vespa/searchlib/test/fakedata/fakewordset.h> #include <vespa/searchlib/test/fakedata/fpfactory.h> +#include <vespa/vespalib/util/time.h> #include <condition_variable> #include <mutex> #include <vector> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".stress_runner"); @@ -306,7 +308,7 @@ StressMaster::run() totalTime / _loops, type.c_str()); dropPostings(); } - FastOS_Thread::Sleep(250); + std::this_thread::sleep_for(250ms); } double diff --git a/searchlib/src/tests/transactionlog/translogclient_test.cpp b/searchlib/src/tests/transactionlog/translogclient_test.cpp index 8a515f749f1..c4751af5adb 100644 --- a/searchlib/src/tests/transactionlog/translogclient_test.cpp +++ b/searchlib/src/tests/transactionlog/translogclient_test.cpp @@ -248,7 +248,7 @@ bool Test::partialUpdateTest() TransLogClient::Visitor::UP visitor = tls.createVisitor("test1", ca); ASSERT_TRUE(visitor.get()); ASSERT_TRUE( visitor->visit(5, 7) ); - for (size_t i(0); ! ca._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca._eof ); ASSERT_TRUE( ca.map().size() == 1); ASSERT_TRUE( ca.hasSerial(7) ); @@ -257,7 +257,7 @@ bool Test::partialUpdateTest() TransLogClient::Visitor::UP visitor1 = tls.createVisitor("test1", ca1); ASSERT_TRUE(visitor1.get()); ASSERT_TRUE( visitor1->visit(4, 5) ); - for (size_t i(0); ! ca1._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca1._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca1._eof ); ASSERT_TRUE( ca1.map().size() == 0); @@ -265,7 +265,7 @@ bool Test::partialUpdateTest() TransLogClient::Visitor::UP visitor2 = tls.createVisitor("test1", ca2); ASSERT_TRUE(visitor2.get()); ASSERT_TRUE( visitor2->visit(5, 6) ); - for (size_t i(0); ! ca2._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca2._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca2._eof ); ASSERT_TRUE( ca2.map().size() == 0); @@ -273,7 +273,7 @@ bool Test::partialUpdateTest() TransLogClient::Visitor::UP visitor3 = tls.createVisitor("test1", ca3); ASSERT_TRUE(visitor3.get()); ASSERT_TRUE( visitor3->visit(5, 1000) ); - for (size_t i(0); ! ca3._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca3._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca3._eof ); ASSERT_TRUE( ca3.map().size() == 1); ASSERT_TRUE( ca3.hasSerial(7) ); @@ -437,7 +437,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c TransLogClient::Visitor::UP visitor = tls.createVisitor(name, ca); ASSERT_TRUE(visitor.get()); EXPECT_TRUE( visitor->visit(0, 1) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } EXPECT_TRUE( ca._eof ); EXPECT_TRUE( ! ca.hasSerial(0) ); EXPECT_TRUE( ca.hasSerial(1) ); @@ -447,7 +447,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c visitor = tls.createVisitor(name, ca); ASSERT_TRUE(visitor.get()); EXPECT_TRUE( visitor->visit(1, 2) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } EXPECT_TRUE( ca._eof ); EXPECT_TRUE( ! ca.hasSerial(0) ); EXPECT_TRUE( ! ca.hasSerial(1) ); @@ -458,7 +458,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c visitor = tls.createVisitor(name, ca); EXPECT_TRUE(visitor.get()); EXPECT_TRUE( visitor->visit(0, 3) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } EXPECT_TRUE( ca._eof ); EXPECT_TRUE( ! ca.hasSerial(0) ); EXPECT_TRUE( ca.hasSerial(1) ); @@ -469,7 +469,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c visitor = tls.createVisitor(name, ca); ASSERT_TRUE(visitor.get()); EXPECT_TRUE( visitor->visit(2, 3) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } EXPECT_TRUE( ca._eof ); EXPECT_TRUE( ! ca.hasSerial(0) ); EXPECT_TRUE( !ca.hasSerial(1) ); @@ -575,7 +575,7 @@ assertVisitStats(TransLogClient &tls, const vespalib::string &domain, ASSERT_TRUE(visitor.get()); ASSERT_TRUE( visitor->visit(visitStart, visitEnd) ); for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } ASSERT_TRUE(ca._eof); EXPECT_EQUAL(expFirstSerial, ca._firstSerial); @@ -623,7 +623,7 @@ void Test::testMany() TransLogClient::Visitor::UP visitor = tls.createVisitor("many", ca); ASSERT_TRUE(visitor.get()); ASSERT_TRUE( visitor->visit(2, TOTAL_NUM_ENTRIES) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca._eof ); EXPECT_EQUAL(ca._count, TOTAL_NUM_ENTRIES); EXPECT_EQUAL(ca._value, TOTAL_NUM_ENTRIES); @@ -644,7 +644,7 @@ void Test::testMany() TransLogClient::Visitor::UP visitor = tls.createVisitor("many", ca); ASSERT_TRUE(visitor.get()); ASSERT_TRUE( visitor->visit(2, TOTAL_NUM_ENTRIES) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca._eof ); EXPECT_EQUAL(ca._count, TOTAL_NUM_ENTRIES); EXPECT_EQUAL(ca._value, TOTAL_NUM_ENTRIES); diff --git a/searchlib/src/tests/transactionlogstress/translogstress.cpp b/searchlib/src/tests/transactionlogstress/translogstress.cpp index a047c5e1657..2ec193cfe45 100644 --- a/searchlib/src/tests/transactionlogstress/translogstress.cpp +++ b/searchlib/src/tests/transactionlogstress/translogstress.cpp @@ -11,8 +11,11 @@ #include <iostream> #include <stdexcept> #include <sstream> +#include <thread> #include <vespa/log/log.h> +#include <vespa/vespalib/util/time.h> + LOG_SETUP("translogstress"); using document::ByteBuffer; @@ -267,7 +270,7 @@ FeederThread::doRun() int64_t milliSecsUsed = _timer.elapsed().ms(); if (milliSecsUsed < 1000) { //LOG(info, "FeederThread: sleep %u ms", 1000 - milliSecsUsed); - FastOS_Thread::Sleep(1000 - milliSecsUsed); + std::this_thread::sleep_for(std::chrono::milliseconds(1000 - milliSecsUsed)); } else { LOG(info, "FeederThread: max throughput"); } @@ -457,7 +460,7 @@ private: EntryGenerator _generator; std::vector<std::shared_ptr<VisitorAgent> > _visitors; std::vector<std::shared_ptr<VisitorAgent> > _rndVisitors; - uint64_t _visitorInterval; // in milliseconds + vespalib::duration _visitorInterval; // in milliseconds int64_t _pruneInterval; // in milliseconds fastos::StopWatch _pruneTimer; SerialNum _begin; @@ -481,14 +484,14 @@ ControllerThread::ControllerThread(const std::string & tlsSpec, const std::strin const EntryGenerator & generator, uint32_t numVisitors, uint64_t visitorInterval, uint64_t pruneInterval) : _tlsSpec(tlsSpec), _domain(domain), _client(tlsSpec.c_str()), _session(), - _generator(generator), _visitors(), _rndVisitors(), _visitorInterval(visitorInterval), + _generator(generator), _visitors(), _rndVisitors(), _visitorInterval(std::chrono::milliseconds(visitorInterval)), _pruneInterval(pruneInterval), _pruneTimer(), _begin(0), _end(0), _count(0) { for (uint32_t i = 0; i < numVisitors; ++i) { _visitors.push_back(std::make_shared<VisitorAgent>(tlsSpec, domain, generator, i, true)); } } -ControllerThread::~ControllerThread() {} +ControllerThread::~ControllerThread() = default; void ControllerThread::getStatus() @@ -553,7 +556,7 @@ ControllerThread::doRun() } _pruneTimer.restart(); } - FastOS_Thread::Sleep(_visitorInterval); + std::this_thread::sleep_for(_visitorInterval); } } @@ -569,7 +572,7 @@ private: uint64_t domainPartSize; size_t packetSize; - uint64_t stressTime; + std::chrono::milliseconds stressTime; uint32_t feedRate; uint32_t numVisitors; uint64_t visitorInterval; @@ -598,7 +601,7 @@ void TransLogStress::printConfig() { std::cout << "######## Config ########" << std::endl; - std::cout << "stressTime: " << _cfg.stressTime / 1000 << " s" << std::endl; + std::cout << "stressTime: " << vespalib::to_s(_cfg.stressTime) << " s" << std::endl; std::cout << "feedRate: " << _cfg.feedRate << " per/sec" << std::endl; std::cout << "numVisitors: " << _cfg.numVisitors << std::endl; std::cout << "visitorInterval: " << _cfg.visitorInterval << " ms" << std::endl; @@ -628,7 +631,7 @@ TransLogStress::Main() _cfg.domainPartSize = 8000000; // ~8MB _cfg.packetSize = 0x10000; - _cfg.stressTime = 1000 * 60; + _cfg.stressTime = std::chrono::milliseconds(1000 * 60); _cfg.feedRate = 10000; _cfg.numVisitors = 1; _cfg.visitorInterval = 1000 * 1; @@ -639,7 +642,7 @@ TransLogStress::Main() _cfg.maxStrLen = 80; _cfg.baseSeed = 100; - uint64_t sleepTime = 4000; + vespalib::duration sleepTime = 4s; int idx = 1; char opt; @@ -654,7 +657,7 @@ TransLogStress::Main() _cfg.packetSize = atol(arg); break; case 't': - _cfg.stressTime = 1000 * atol(arg); + _cfg.stressTime = std::chrono::milliseconds(1000 * atol(arg)); break; case 'f': _cfg.feedRate = atoi(arg); @@ -690,7 +693,7 @@ TransLogStress::Main() } printConfig(); - FastOS_Thread::Sleep(sleepTime); + std::this_thread::sleep_for(sleepTime); if (_argc != idx || optError) { usage(); @@ -721,13 +724,13 @@ TransLogStress::Main() FeederThread feeder(tlsSpec, domain, generator, _cfg.feedRate, _cfg.packetSize); threadPool.NewThread(&feeder); - FastOS_Thread::Sleep(sleepTime); + std::this_thread::sleep_for(sleepTime); ControllerThread controller(tlsSpec, domain, generator, _cfg.numVisitors, _cfg.visitorInterval, _cfg.pruneInterval); threadPool.NewThread(&controller); // stop feeder and controller - FastOS_Thread::Sleep(_cfg.stressTime); + std::this_thread::sleep_for(_cfg.stressTime); printConfig(); LOG(info, "Stop feeder..."); feeder.stop(); @@ -735,7 +738,7 @@ TransLogStress::Main() std::cout << "<feeder>" << std::endl; std::cout << " <from>" << feeder.getRange().from() << "</from>" << std::endl; std::cout << " <to>" << feeder.getRange().to() << "</to>" << std::endl; - std::cout << " <rate>" << 1000 * (feeder.getRange().to() - feeder.getRange().from()) / (sleepTime + _cfg.stressTime) + std::cout << " <rate>" << 1000 * (feeder.getRange().to() - feeder.getRange().from()) / vespalib::count_ms(sleepTime + _cfg.stressTime) << "</rate>" << std::endl; std::cout << "</feeder>" << std::endl; @@ -743,7 +746,7 @@ TransLogStress::Main() controller.stop(); controller.join(); - FastOS_Thread::Sleep(sleepTime); + std::this_thread::sleep_for(sleepTime); std::vector<std::shared_ptr<VisitorAgent> > & visitors = controller.getVisitors(); for (size_t i = 0; i < visitors.size(); ++i) { std::cout << "<visitor id='" << i << "'>" << std::endl; diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp index 5da5481ecbb..69d4e6a5ee9 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp @@ -174,10 +174,15 @@ SingleValueNumericAttribute<B>::clearDocs(DocId lidLow, DocId lidLimit) { assert(lidLow <= lidLimit); assert(lidLimit <= this->getNumDocs()); + uint32_t count = 0; + constexpr uint32_t commit_interval = 1000; for (DocId lid = lidLow; lid < lidLimit; ++lid) { if (!attribute::isUndefined(_data[lid])) { this->clearDoc(lid); } + if ((++count % commit_interval) == 0) { + this->commit(); + } } } diff --git a/searchlib/src/vespa/searchlib/common/sortresults.cpp b/searchlib/src/vespa/searchlib/common/sortresults.cpp index 757ba9f3f9a..729f31795c9 100644 --- a/searchlib/src/vespa/searchlib/common/sortresults.cpp +++ b/searchlib/src/vespa/searchlib/common/sortresults.cpp @@ -242,7 +242,7 @@ FastS_SortSpec::initSortData(const RankedHit *hits, uint32_t n) _sortDataArray.resize(n); document::GlobalId gid; - for (uint32_t i(0), idx(0); (i < n) && !_doom.doom(); ++i) { + for (uint32_t i(0), idx(0); (i < n) && !_doom.hard_doom(); ++i) { uint32_t len = 0; for (auto iter = _vectors.begin(); iter != _vectors.end(); ++iter) { int written(0); diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp index ce1bd69cc4c..5e7523f53c5 100644 --- a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp +++ b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp @@ -40,7 +40,7 @@ lookupDouble(const Properties &props, const vespalib::string &name, double defau { Property p = props.lookup(name); if (p.found()) { - return vespalib::locale::c::strtod(p.get().c_str(), NULL); + return vespalib::locale::c::strtod(p.get().c_str(), nullptr); } return defaultValue; } @@ -306,6 +306,10 @@ double Factor::lookup(const Properties &props, double defaultValue) { return lookupDouble(props, NAME, defaultValue); } +bool Factor::isPresent(const Properties &props) { + return props.lookup(NAME).found(); +} + } namespace matchphase { diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.h b/searchlib/src/vespa/searchlib/fef/indexproperties.h index 57aa24222a3..9fa28bfaff2 100644 --- a/searchlib/src/vespa/searchlib/fef/indexproperties.h +++ b/searchlib/src/vespa/searchlib/fef/indexproperties.h @@ -237,6 +237,7 @@ namespace softtimeout { static const double DEFAULT_VALUE; static double lookup(const Properties &props); static double lookup(const Properties &props, double defaultValue); + static bool isPresent(const Properties &props); }; } diff --git a/searchlib/src/vespa/searchlib/queryeval/fake_requestcontext.cpp b/searchlib/src/vespa/searchlib/queryeval/fake_requestcontext.cpp index 7220235dc48..9af6d7024a2 100644 --- a/searchlib/src/vespa/searchlib/queryeval/fake_requestcontext.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/fake_requestcontext.cpp @@ -4,9 +4,9 @@ namespace search::queryeval { -FakeRequestContext::FakeRequestContext(attribute::IAttributeContext * context, fastos::SteadyTimeStamp doom_in) +FakeRequestContext::FakeRequestContext(attribute::IAttributeContext * context, fastos::SteadyTimeStamp softDoom, fastos::SteadyTimeStamp hardDoom) : _clock(), - _doom(_clock, doom_in), + _doom(_clock, softDoom, hardDoom, false), _attributeContext(context), _query_tensor_name(), _query_tensor() diff --git a/searchlib/src/vespa/searchlib/queryeval/fake_requestcontext.h b/searchlib/src/vespa/searchlib/queryeval/fake_requestcontext.h index 50f61a3eb22..184e0f7faf8 100644 --- a/searchlib/src/vespa/searchlib/queryeval/fake_requestcontext.h +++ b/searchlib/src/vespa/searchlib/queryeval/fake_requestcontext.h @@ -8,6 +8,7 @@ #include <vespa/searchcommon/attribute/iattributecontext.h> #include <vespa/searchlib/attribute/attributevector.h> #include <vespa/searchlib/queryeval/irequestcontext.h> +#include <vespa/vespalib/util/doom.h> #include <limits> namespace search::queryeval { @@ -15,9 +16,11 @@ namespace search::queryeval { class FakeRequestContext : public IRequestContext { public: - FakeRequestContext(attribute::IAttributeContext * context = nullptr, fastos::SteadyTimeStamp doom=fastos::SteadyTimeStamp(fastos::TimeStamp::FUTURE)); + FakeRequestContext(attribute::IAttributeContext * context = nullptr, + fastos::SteadyTimeStamp soft=fastos::SteadyTimeStamp::FUTURE, + fastos::SteadyTimeStamp hard=fastos::SteadyTimeStamp::FUTURE); ~FakeRequestContext(); - const vespalib::Doom & getSoftDoom() const override { return _doom; } + const vespalib::Doom & getDoom() const override { return _doom; } const attribute::IAttributeVector *getAttribute(const vespalib::string &name) const override { return _attributeContext ? _attributeContext->getAttribute(name) diff --git a/searchlib/src/vespa/searchlib/queryeval/irequestcontext.h b/searchlib/src/vespa/searchlib/queryeval/irequestcontext.h index 1c197021998..75025b5cbf2 100644 --- a/searchlib/src/vespa/searchlib/queryeval/irequestcontext.h +++ b/searchlib/src/vespa/searchlib/queryeval/irequestcontext.h @@ -2,11 +2,11 @@ #pragma once -#include <vespa/vespalib/util/doom.h> #include <vespa/vespalib/stllike/string.h> namespace search::attribute { class IAttributeVector; } namespace vespalib::eval { class Value; } +namespace vespalib { class Doom; } namespace search::queryeval { @@ -22,7 +22,7 @@ public: * Provides the time of soft doom for the query. Now it is time to start cleaning up and return what you have. * @return time of soft doom. */ - virtual const vespalib::Doom & getSoftDoom() const = 0; + virtual const vespalib::Doom & getDoom() const = 0; /** * Provide access to attributevectors diff --git a/searchlib/src/vespa/searchlib/queryeval/simple_phrase_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/simple_phrase_blueprint.cpp index 7429553d889..edb26fdb296 100644 --- a/searchlib/src/vespa/searchlib/queryeval/simple_phrase_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/simple_phrase_blueprint.cpp @@ -11,7 +11,7 @@ namespace search::queryeval { SimplePhraseBlueprint::SimplePhraseBlueprint(const FieldSpec &field, const IRequestContext & requestContext, bool expensive) : ComplexLeafBlueprint(field), - _doom(requestContext.getSoftDoom()), + _doom(requestContext.getDoom()), _field(field), _estimate(), _layout(), diff --git a/searchlib/src/vespa/searchlib/queryeval/simple_phrase_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/simple_phrase_blueprint.h index c313c0b38ad..a09bc3f6c06 100644 --- a/searchlib/src/vespa/searchlib/queryeval/simple_phrase_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/simple_phrase_blueprint.h @@ -5,6 +5,7 @@ #include "searchable.h" #include "irequestcontext.h" #include <vespa/searchlib/fef/matchdatalayout.h> +#include <vespa/vespalib/util/doom.h> namespace search::fef { class TermFieldMatchData; } diff --git a/searchlib/src/vespa/searchlib/queryeval/simple_phrase_search.h b/searchlib/src/vespa/searchlib/queryeval/simple_phrase_search.h index c9327d678e5..d45e67ed4cb 100644 --- a/searchlib/src/vespa/searchlib/queryeval/simple_phrase_search.h +++ b/searchlib/src/vespa/searchlib/queryeval/simple_phrase_search.h @@ -7,6 +7,7 @@ #include <vespa/searchlib/fef/matchdata.h> #include <vespa/searchlib/fef/termfieldmatchdataarray.h> #include <vespa/searchlib/fef/termfieldmatchdata.h> +#include <vespa/vespalib/util/doom.h> #include <memory> #include <vector> @@ -29,7 +30,7 @@ class SimplePhraseSearch : public AndSearch std::vector<It> _iterators; void phraseSeek(uint32_t doc_id); - bool doom() const { return ((_doom != nullptr) && _doom->doom()); } + bool doom() const { return ((_doom != nullptr) && _doom->soft_doom()); } public: /** diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp index 6d11ab1f5eb..37903bc21f5 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp @@ -2,12 +2,14 @@ #include "translogserver.h" #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/io/fileutil.h> +#include <vespa/vespalib/util/time.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/fnet/frt/supervisor.h> #include <vespa/fnet/frt/rpcrequest.h> #include <vespa/fnet/task.h> #include <vespa/fnet/transport.h> #include <fstream> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".transactionlog.server"); @@ -125,7 +127,7 @@ TransLogServer::TransLogServer(const vespalib::string &name, int listenPort, con listenOk = true; } else { LOG(warning, "Failed listening at port %s trying for %d seconds more.", listenSpec, i); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } } if ( ! listenOk ) { diff --git a/slobrok/src/tests/configure/configure.cpp b/slobrok/src/tests/configure/configure.cpp index bf41b77ab05..fa509c17d0c 100644 --- a/slobrok/src/tests/configure/configure.cpp +++ b/slobrok/src/tests/configure/configure.cpp @@ -85,7 +85,7 @@ compare(MirrorAPI &api, const char *pattern, SpecList expect) if (actual == expect) { return true; } - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } SpecList actual(api.lookup(pattern)); std::cerr << "Actual: " << actual.strVal() << std::endl; @@ -176,7 +176,7 @@ Test::Main() srv2Builder.slobrok[0].connectionspec = createSpec(18525); cfgCtx->reload(); - FastOS_Thread::Sleep(6000); // reconfiguration time + std::this_thread::sleep_for(6s); // reconfiguration time reg1.registerName("A"); reg2.registerName("B"); diff --git a/slobrok/src/tests/mirrorapi/mirrorapi.cpp b/slobrok/src/tests/mirrorapi/mirrorapi.cpp index b25e338533c..53e194fad2d 100644 --- a/slobrok/src/tests/mirrorapi/mirrorapi.cpp +++ b/slobrok/src/tests/mirrorapi/mirrorapi.cpp @@ -112,7 +112,7 @@ compare(MirrorAPI &api, const char *pattern, SpecList expect) if (actual == expect) { return true; } - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return false; } @@ -124,7 +124,7 @@ Test::Main() TEST_INIT("mirrorapi_test"); SlobrokServer mock(18501); - FastOS_Thread::Sleep(300); + std::this_thread::sleep_for(300ms); Server a("A/x/w", 18502, "tcp/localhost:18501"); Server b("B/x", 18503, "tcp/localhost:18501"); @@ -143,7 +143,7 @@ Test::Main() MirrorAPI mirror(supervisor, config::ConfigUri::createFromInstance(specBuilder)); EXPECT_TRUE(!mirror.ready()); transport.Start(&threadPool); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); a.reg(); EXPECT_TRUE(compare(mirror, "A/x/w", SpecList().add("A/x/w", "tcp/localhost:18502"))); diff --git a/slobrok/src/tests/registerapi/registerapi.cpp b/slobrok/src/tests/registerapi/registerapi.cpp index ac7e662c6f2..92f08ee41cb 100644 --- a/slobrok/src/tests/registerapi/registerapi.cpp +++ b/slobrok/src/tests/registerapi/registerapi.cpp @@ -64,7 +64,7 @@ compare(MirrorAPI &api, const char *pattern, SpecList expect) if (actual == expect) { return true; } - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return false; } @@ -75,7 +75,7 @@ Test::Main() TEST_INIT("registerapi_test"); SlobrokServer mock(18548); - FastOS_Thread::Sleep(300); + std::this_thread::sleep_for(300ms); cloud::config::SlobroksConfigBuilder slobrokSpecs; cloud::config::SlobroksConfig::Slobrok sb; @@ -97,7 +97,7 @@ Test::Main() EXPECT_TRUE(compare(mirror, "*/*/*", SpecList().add("A/x/w", myspec.c_str()))); for (int i = 0; i < 30; i++) { - if (reg.busy()) FastOS_Thread::Sleep(100); + if (reg.busy()) std::this_thread::sleep_for(100ms); } EXPECT_TRUE(!reg.busy()); diff --git a/slobrok/src/tests/standalone/standalone.cpp b/slobrok/src/tests/standalone/standalone.cpp index 9d3fd694ee1..65553c57530 100644 --- a/slobrok/src/tests/standalone/standalone.cpp +++ b/slobrok/src/tests/standalone/standalone.cpp @@ -132,7 +132,7 @@ TEST("standalone") { break; } fprintf(stderr, "ping failed [retry %d]\n", retry); - FastOS_Thread::Sleep(200); + std::this_thread::sleep_for(200ms); sb->SubRef(); sb = orb.GetTarget(18541); } @@ -268,7 +268,7 @@ TEST("standalone") { } } - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); // lookup 'B' should give '' req = orb.AllocRPCRequest(req); diff --git a/staging_vespalib/src/tests/clock/clock_benchmark.cpp b/staging_vespalib/src/tests/clock/clock_benchmark.cpp index c5229adea31..a9618d50682 100644 --- a/staging_vespalib/src/tests/clock/clock_benchmark.cpp +++ b/staging_vespalib/src/tests/clock/clock_benchmark.cpp @@ -1,10 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/vespalib/util/clock.h> +#include <vespa/fastos/thread.h> #include <cassert> #include <vector> #include <atomic> #include <cstring> +#include <condition_variable> +#include <mutex> using vespalib::Clock; using fastos::TimeStamp; @@ -134,7 +137,7 @@ main(int , char *argv[]) TestClock nsClock(nsValue, 1.0/frequency); TestClock nsVolatileClock(nsVolatile, 1.0/frequency); TestClock nsAtomicClock(nsAtomic, 1.0/frequency); - assert(pool.NewThread(&clock, nullptr) != nullptr); + assert(pool.NewThread(clock.getRunnable(), nullptr) != nullptr); assert(pool.NewThread(&nsClock, nullptr) != nullptr); assert(pool.NewThread(&nsVolatileClock, nullptr) != nullptr); assert(pool.NewThread(&nsAtomicClock, nullptr) != nullptr); diff --git a/staging_vespalib/src/tests/clock/clock_test.cpp b/staging_vespalib/src/tests/clock/clock_test.cpp index c8c93272e85..b5650244a45 100644 --- a/staging_vespalib/src/tests/clock/clock_test.cpp +++ b/staging_vespalib/src/tests/clock/clock_test.cpp @@ -2,35 +2,27 @@ #include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/util/clock.h> +#include <vespa/vespalib/util/time.h> +#include <vespa/fastos/thread.h> using vespalib::Clock; using fastos::TimeStamp; -class Test : public vespalib::TestApp -{ -public: - int Main() override; -}; - -int -Test::Main() -{ - TEST_INIT("clock_test"); +TEST("Test that clock is ticking forward") { Clock clock(0.050); FastOS_ThreadPool pool(0x10000); - ASSERT_TRUE(pool.NewThread(&clock, nullptr) != nullptr); + ASSERT_TRUE(pool.NewThread(clock.getRunnable(), nullptr) != nullptr); fastos::SteadyTimeStamp start = clock.getTimeNS(); - FastOS_Thread::Sleep(5000); + std::this_thread::sleep_for(5s); fastos::SteadyTimeStamp stop = clock.getTimeNS(); EXPECT_TRUE(stop > start); - FastOS_Thread::Sleep(6000); + std::this_thread::sleep_for(6s); clock.stop(); fastos::SteadyTimeStamp stop2 = clock.getTimeNS(); EXPECT_TRUE(stop2 > stop); EXPECT_TRUE((stop2 - stop)/TimeStamp::MICRO > 1000); - TEST_DONE(); } -TEST_APPHOOK(Test) +TEST_MAIN() { TEST_RUN_ALL(); }
\ No newline at end of file diff --git a/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp b/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp index e6f7bd21750..fbaa5581173 100644 --- a/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp +++ b/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp @@ -13,20 +13,20 @@ Test::Main() { TEST_INIT("shutdownguard_test"); { - ShutdownGuard farFuture(123456789); - FastOS_Thread::Sleep(20); + ShutdownGuard farFuture(1000000s); + std::this_thread::sleep_for(20ms); } EXPECT_TRUE(true); pid_t child = fork(); if (child == 0) { - ShutdownGuard soon(30); + ShutdownGuard soon(30ms); for (int i = 0; i < 1000; ++i) { - FastOS_Thread::Sleep(20); + std::this_thread::sleep_for(20ms); } exit(0); } for (int i = 0; i < 1000; ++i) { - FastOS_Thread::Sleep(20); + std::this_thread::sleep_for(20ms); int stat = 0; if (waitpid(child, &stat, WNOHANG) == child) { EXPECT_TRUE(WIFEXITED(stat)); diff --git a/staging_vespalib/src/tests/timer/timer_test.cpp b/staging_vespalib/src/tests/timer/timer_test.cpp index 309ee873b44..5472ad6e23f 100644 --- a/staging_vespalib/src/tests/timer/timer_test.cpp +++ b/staging_vespalib/src/tests/timer/timer_test.cpp @@ -1,8 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/vespalib/testkit/testapp.h> -#include <vespa/vespalib/util/timer.h> -#include <vespa/vespalib/util/executor.h> +#include <vespa/vespalib/util/scheduledexecutor.h> using namespace vespalib; using vespalib::Executor; @@ -37,7 +36,7 @@ void Test::testScheduling() { vespalib::CountDownLatch latch1(3); vespalib::CountDownLatch latch2(2); - Timer timer; + ScheduledExecutor timer; timer.scheduleAtFixedRate(Task::UP(new TestTask(latch1)), 0.1, 0.2); timer.scheduleAtFixedRate(Task::UP(new TestTask(latch2)), 0.5, 0.5); EXPECT_TRUE(latch1.await(60000)); @@ -47,7 +46,7 @@ void Test::testScheduling() void Test::testReset() { vespalib::CountDownLatch latch1(2); - Timer timer; + ScheduledExecutor timer; timer.scheduleAtFixedRate(Task::UP(new TestTask(latch1)), 2.0, 3.0); timer.reset(); EXPECT_TRUE(!latch1.await(3000)); diff --git a/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt b/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt index 20d47c90453..71364a813f6 100644 --- a/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt +++ b/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt @@ -16,7 +16,7 @@ vespa_add_library(staging_vespalib_vespalib_util OBJECT document_runnable.cpp rusage.cpp shutdownguard.cpp - timer.cpp + scheduledexecutor.cpp xmlserializable.cpp xmlstream.cpp DEPENDS diff --git a/staging_vespalib/src/vespa/vespalib/util/clock.cpp b/staging_vespalib/src/vespa/vespalib/util/clock.cpp index e935a80bd6b..cd2a13029ab 100644 --- a/staging_vespalib/src/vespa/vespalib/util/clock.cpp +++ b/staging_vespalib/src/vespa/vespalib/util/clock.cpp @@ -1,17 +1,70 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "clock.h" +#include <vespa/fastos/thread.h> +#include <mutex> +#include <condition_variable> #include <cassert> #include <chrono> namespace vespalib { +namespace clock::internal { + +class Updater : public FastOS_Runnable +{ +private: + Clock & _clock; + int _timePeriodMS; + std::mutex _lock; + std::condition_variable _cond; + bool _stop; + + + void Run(FastOS_ThreadInterface *thisThread, void *arguments) override; + +public: + Updater(Clock & clock, double timePeriod=0.100); + ~Updater(); + + void stop(); +}; + +Updater::Updater(Clock & clock, double timePeriod) + : _clock(clock), + _timePeriodMS(static_cast<uint32_t>(timePeriod*1000)), + _lock(), + _cond(), + _stop(false) +{ } + +Updater::~Updater() = default; + +void +Updater::Run(FastOS_ThreadInterface *thread, void *) +{ + _clock._running = true; + std::unique_lock<std::mutex> guard(_lock); + while ( ! thread->GetBreakFlag() && !_stop) { + _clock.setTime(); + _cond.wait_for(guard, std::chrono::milliseconds(_timePeriodMS)); + } + _clock._running = false; +} + +void +Updater::stop() +{ + std::lock_guard<std::mutex> guard(_lock); + _stop = true; + _cond.notify_all(); +} + +} + Clock::Clock(double timePeriod) : _timeNS(0u), - _timePeriodMS(static_cast<uint32_t>(timePeriod*1000)), - _lock(), - _cond(), - _stop(false), + _updater(std::make_unique<clock::internal::Updater>(*this, timePeriod)), _running(false) { setTime(); @@ -19,6 +72,9 @@ Clock::Clock(double timePeriod) : Clock::~Clock() { + if (_running) { + _updater->GetThread()->Join(); + } assert(!_running); } @@ -27,24 +83,15 @@ void Clock::setTime() const _timeNS.store(fastos::ClockSteady::now() - fastos::SteadyTimeStamp::ZERO, std::memory_order_relaxed); } -void Clock::Run(FastOS_ThreadInterface *thread, void *arguments) -{ - (void) arguments; - _running = true; - std::unique_lock<std::mutex> guard(_lock); - while ( ! thread->GetBreakFlag() && !_stop) { - setTime(); - _cond.wait_for(guard, std::chrono::milliseconds(_timePeriodMS)); - } - _running = false; -} - void Clock::stop() { - std::lock_guard<std::mutex> guard(_lock); - _stop = true; - _cond.notify_all(); + _updater->stop(); +} + +FastOS_Runnable * +Clock::getRunnable() { + return _updater.get(); } } diff --git a/staging_vespalib/src/vespa/vespalib/util/clock.h b/staging_vespalib/src/vespa/vespalib/util/clock.h index fc8e9eaff23..e9e1ebace3f 100644 --- a/staging_vespalib/src/vespa/vespalib/util/clock.h +++ b/staging_vespalib/src/vespa/vespalib/util/clock.h @@ -1,36 +1,32 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once -#include <vespa/fastos/thread.h> #include <vespa/fastos/timestamp.h> -#include <mutex> -#include <condition_variable> +#include <atomic> +#include <memory> + +class FastOS_Runnable; namespace vespalib { +namespace clock::internal { class Updater; } + /** * Clock is a clock that updates the time at defined intervals. * It is intended used where you want to check the time with low cost, but where * resolution is not that important. */ -class Clock : public FastOS_Runnable +class Clock { private: - Clock(const Clock &); - Clock & operator = (const Clock &); - - mutable std::atomic<int64_t> _timeNS; - int _timePeriodMS; - std::mutex _lock; - std::condition_variable _cond; - bool _stop; - bool _running; + mutable std::atomic<int64_t> _timeNS; + std::unique_ptr<clock::internal::Updater> _updater; + std::atomic<bool> _running; void setTime() const; - - void Run(FastOS_ThreadInterface *thisThread, void *arguments) override; - + void start(); + friend clock::internal::Updater; public: Clock(double timePeriod=0.100); ~Clock(); @@ -41,9 +37,12 @@ public: } return getTimeNSAssumeRunning(); } - fastos::SteadyTimeStamp getTimeNSAssumeRunning() const { return fastos::SteadyTimeStamp(_timeNS.load(std::memory_order_relaxed)); } + fastos::SteadyTimeStamp getTimeNSAssumeRunning() const { + return fastos::SteadyTimeStamp(_timeNS.load(std::memory_order_relaxed)); + } void stop(); + FastOS_Runnable * getRunnable(); }; } diff --git a/staging_vespalib/src/vespa/vespalib/util/doom.cpp b/staging_vespalib/src/vespa/vespalib/util/doom.cpp index df20981c584..87b24799721 100644 --- a/staging_vespalib/src/vespa/vespalib/util/doom.cpp +++ b/staging_vespalib/src/vespa/vespalib/util/doom.cpp @@ -4,10 +4,12 @@ namespace vespalib { -Doom::Doom(const vespalib::Clock &clock, fastos::SteadyTimeStamp timeOfDoom) : - _clock(clock), - _timeOfDoom(timeOfDoom) -{ -} - -} // namespace vespalib
\ No newline at end of file +Doom::Doom(const vespalib::Clock &clock, fastos::SteadyTimeStamp softDoom, + fastos::SteadyTimeStamp hardDoom, bool explicitSoftDoom) + : _clock(clock), + _softDoom(softDoom), + _hardDoom(hardDoom), + _isExplicitSoftDoom(explicitSoftDoom) +{ } + +}
\ No newline at end of file diff --git a/staging_vespalib/src/vespa/vespalib/util/doom.h b/staging_vespalib/src/vespa/vespalib/util/doom.h index ee0c1af3177..d85c3dc9084 100644 --- a/staging_vespalib/src/vespa/vespalib/util/doom.h +++ b/staging_vespalib/src/vespa/vespalib/util/doom.h @@ -6,19 +6,24 @@ namespace vespalib { -class Doom -{ +class Doom { +public: + Doom(const vespalib::Clock &clock, fastos::SteadyTimeStamp doom) + : Doom(clock, doom, doom, false) + {} + Doom(const vespalib::Clock &clock, fastos::SteadyTimeStamp softDoom, + fastos::SteadyTimeStamp hardDoom, bool explicitSoftDoom); + + bool soft_doom() const { return (_clock.getTimeNSAssumeRunning() > _softDoom); } + bool hard_doom() const { return (_clock.getTimeNSAssumeRunning() > _hardDoom); } + fastos::TimeStamp soft_left() const { return _softDoom - _clock.getTimeNS(); } + fastos::TimeStamp hard_left() const { return _hardDoom - _clock.getTimeNS(); } + bool isExplicitSoftDoom() const { return _isExplicitSoftDoom; } private: const vespalib::Clock &_clock; - fastos::SteadyTimeStamp _timeOfDoom; - -public: - Doom(const vespalib::Clock &clock, fastos::SteadyTimeStamp timeOfDoom); - bool doom() const { - return (_clock.getTimeNSAssumeRunning() > _timeOfDoom); - } - fastos::TimeStamp left() const { return _timeOfDoom - _clock.getTimeNS(); } + fastos::SteadyTimeStamp _softDoom; + fastos::SteadyTimeStamp _hardDoom; + bool _isExplicitSoftDoom; }; -} // namespace vespalib - +} diff --git a/staging_vespalib/src/vespa/vespalib/util/timer.cpp b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.cpp index a7acbe67965..61f9666114c 100644 --- a/staging_vespalib/src/vespa/vespalib/util/timer.cpp +++ b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.cpp @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "timer.h" +#include "scheduledexecutor.h" #include <vespa/fnet/scheduler.h> #include <vespa/fnet/task.h> #include <vespa/fnet/transport.h> @@ -34,7 +34,7 @@ public: } }; -Timer::Timer() +ScheduledExecutor::ScheduledExecutor() : _threadPool(128 * 1024), _transport(new FNET_Transport()), _lock(), @@ -43,7 +43,7 @@ Timer::Timer() _transport->Start(&_threadPool); } -Timer::~Timer() +ScheduledExecutor::~ScheduledExecutor() { vespalib::LockGuard guard(_lock); _transport->ShutDown(true); @@ -53,7 +53,7 @@ Timer::~Timer() void -Timer::scheduleAtFixedRate(vespalib::Executor::Task::UP task, double delay, double interval) +ScheduledExecutor::scheduleAtFixedRate(vespalib::Executor::Task::UP task, double delay, double interval) { vespalib::LockGuard guard(_lock); TimerTaskPtr tTask(new TimerTask(_transport->GetScheduler(), std::move(task), interval)); @@ -62,7 +62,7 @@ Timer::scheduleAtFixedRate(vespalib::Executor::Task::UP task, double delay, doub } void -Timer::reset() +ScheduledExecutor::reset() { vespalib::LockGuard guard(_lock); _transport->ShutDown(true); diff --git a/staging_vespalib/src/vespa/vespalib/util/timer.h b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.h index 0f7cde67ee4..d7e56494828 100644 --- a/staging_vespalib/src/vespa/vespalib/util/timer.h +++ b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.h @@ -13,11 +13,11 @@ namespace vespalib { class TimerTask; /** - * Timer is a class capable of running Tasks at a regular + * ScheduledExecutor is a class capable of running Tasks at a regular * interval. The timer can be reset to clear all tasks currently being * scheduled. */ -class Timer +class ScheduledExecutor { private: typedef std::unique_ptr<TimerTask> TimerTaskPtr; @@ -31,13 +31,13 @@ public: /** * Create a new timer, capable of scheduling tasks at fixed intervals. */ - Timer(); + ScheduledExecutor(); /** * Destroys this timer, finishing the current task executing and then * finishing. */ - ~Timer(); + ~ScheduledExecutor(); /** * Schedule new task to be executed at specified intervals. diff --git a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp index 645ffea380d..99857107860 100644 --- a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp +++ b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp @@ -1,7 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "shutdownguard.h" #include <unistd.h> -#include <sys/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".vespalib.shutdownguard"); @@ -10,36 +10,30 @@ namespace vespalib { namespace { enum { STACK_SIZE = (1u << 16) }; - -uint64_t getTimeInMillis() { - struct timeval mytime; - gettimeofday(&mytime, 0); - uint64_t mult = 1000; - return (mytime.tv_sec * mult) + (mytime.tv_usec / mult); -} } void ShutdownGuard::Run(FastOS_ThreadInterface *, void *) { - while (_dieAtTime > getTimeInMillis()) { - FastOS_Thread::Sleep(5); + while (_dieAtTime > steady_clock::now() && ! GetThread()->GetBreakFlag()) { + std::this_thread::sleep_for(5ms); } - if (_dieAtTime != 0) { + if (_dieAtTime <= steady_clock::now()) { LOG(warning, "ShutdownGuard is now forcing an exit of the process."); _exit(EXIT_FAILURE); } } -ShutdownGuard::ShutdownGuard(uint64_t millis) : +ShutdownGuard::ShutdownGuard(duration millis) : FastOS_Runnable(), _pool(STACK_SIZE, 1), - _dieAtTime(getTimeInMillis() + millis) + _dieAtTime(steady_clock::now() + millis) { _pool.NewThread(this); } ShutdownGuard::~ShutdownGuard() { - _dieAtTime = 0; + GetThread()->SetBreakFlag(); + GetThread()->Join(); _pool.Close(); } diff --git a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h index 5a9aad5d4d4..9de9df8bbad 100644 --- a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h +++ b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h @@ -1,8 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once +#include <vespa/vespalib/util/time.h> #include <vespa/fastos/thread.h> -#include <cstdint> namespace vespalib { @@ -16,7 +16,7 @@ namespace vespalib { class ShutdownGuard : public FastOS_Runnable { FastOS_ThreadPool _pool; - volatile uint64_t _dieAtTime; + steady_time _dieAtTime; void Run(FastOS_ThreadInterface *, void *) override; @@ -25,7 +25,7 @@ public: * Construct a shutdown guard with a given lifetime. * @arg millis the number of milliseconds before process automatically exits **/ - ShutdownGuard(uint64_t millis); + ShutdownGuard(duration millis); /** * Destructor that dismisses the guard and collects the shutdown thread. diff --git a/storage/src/tests/common/metricstest.cpp b/storage/src/tests/common/metricstest.cpp index d1421845b81..d698cbb5e05 100644 --- a/storage/src/tests/common/metricstest.cpp +++ b/storage/src/tests/common/metricstest.cpp @@ -1,6 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/document/fieldvalue/document.h> #include <vespa/storageapi/message/persistence.h> #include <vespa/storageframework/defaultimplementation/clock/fakeclock.h> #include <vespa/storage/bucketdb/bucketmanager.h> @@ -14,6 +13,7 @@ #include <vespa/config/common/exceptions.h> #include <vespa/vespalib/stllike/hash_map.hpp> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/time.h> #include <gmock/gmock.h> #include <thread> @@ -202,7 +202,7 @@ void MetricsTest::createFakeLoad() while (uint64_t(_metricManager->getLastProcessedTime()) < _clock->getTimeInSeconds().getTime()) { - FastOS_Thread::Sleep(5); + std::this_thread::sleep_for(5ms); _metricManager->timeChangedNotification(); } } @@ -257,7 +257,7 @@ TEST_F(MetricsTest, snapshot_presenting) { uint64_t(_metricManager->getLastProcessedTime()) < _clock->getTimeInSeconds().getTime()) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } LOG(debug, "5 minute snapshot should have been taken. Adding put count"); diff --git a/storage/src/tests/common/teststorageapp.cpp b/storage/src/tests/common/teststorageapp.cpp index dd89082d3e7..082af954871 100644 --- a/storage/src/tests/common/teststorageapp.cpp +++ b/storage/src/tests/common/teststorageapp.cpp @@ -7,10 +7,11 @@ #include <vespa/config-load-type.h> #include <vespa/config-fleetcontroller.h> #include <vespa/persistence/dummyimpl/dummypersistence.h> -#include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/util/time.h> #include <vespa/config/config.h> #include <vespa/config/helper/configgetter.hpp> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".test.servicelayerapp"); @@ -111,7 +112,7 @@ TestStorageApp::waitUntilInitialized( framework::MilliSecTime endTime( clock.getTimeInMillis() + timeout.getMillis()); while (!isInitialized()) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); framework::MilliSecTime currentTime(clock.getTimeInMillis()); if (currentTime > endTime) { std::ostringstream error; diff --git a/storage/src/tests/distributor/distributortest.cpp b/storage/src/tests/distributor/distributortest.cpp index 8fa8a6bcede..d456401876e 100644 --- a/storage/src/tests/distributor/distributortest.cpp +++ b/storage/src/tests/distributor/distributortest.cpp @@ -11,9 +11,10 @@ #include <vespa/document/test/make_document_bucket.h> #include <vespa/document/test/make_bucket_space.h> #include <vespa/storage/config/config-stor-distributormanager.h> -#include <tests/common/dummystoragelink.h> #include <vespa/storage/distributor/distributor.h> #include <vespa/vespalib/text/stringtokenizer.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/vespalib/gtest/gtest.h> #include <gmock/gmock.h> @@ -383,7 +384,7 @@ TEST_F(DistributorTest, tick_processes_status_requests) { thread, "statustest", tickWaitMs, tickMaxProcessTime, ticksBeforeWait)); while (true) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); framework::TickingLockGuard guard( distributor_thread_pool().freezeCriticalTicks()); if (!distributor_status_todos().empty()) { diff --git a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp index 44cb92071a1..f907d0496e6 100644 --- a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp +++ b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp @@ -20,8 +20,10 @@ #include <vespa/persistence/spi/test.h> #include <vespa/config/common/exceptions.h> #include <vespa/fastos/file.h> +#include <vespa/vespalib/util/time.h> #include <vespa/vespalib/gtest/gtest.h> #include <atomic> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".filestormanagertest"); @@ -556,7 +558,7 @@ public: auto cmd = std::make_shared<api::PutCommand>(makeDocumentBucket(bucket), _doc, 100); _handler.schedule(cmd, 0); - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } _threadDone = true; @@ -589,13 +591,13 @@ public: if (msg.second.get()) { uint32_t originalConfig = _config.load(); _fetchedCount++; - FastOS_Thread::Sleep(5); + std::this_thread::sleep_for(5ms); if (_config.load() != originalConfig) { _failed = true; } } else { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } @@ -634,7 +636,7 @@ TEST_F(FileStorManagerTest, handler_paused_multi_thread) { thread.start(pool); for (uint32_t i = 0; i < 50; ++i) { - FastOS_Thread::Sleep(2); + std::this_thread::sleep_for(2ms); ResumeGuard guard = filestorHandler.pause(); thread._config.fetch_add(1); uint32_t count = thread._fetchedCount; @@ -646,7 +648,7 @@ TEST_F(FileStorManagerTest, handler_paused_multi_thread) { ASSERT_FALSE(thread._failed); while (!pushthread._threadDone || !thread._threadDone) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } @@ -869,7 +871,7 @@ TEST_F(FileStorManagerTest, handler_timeout) { filestorHandler.schedule(cmd, 0); } - FastOS_Thread::Sleep(51); + std::this_thread::sleep_for(51ms); for (;;) { auto lock = filestorHandler.getNextMessage(0, stripeId); if (lock.first.get()) { @@ -944,7 +946,7 @@ TEST_F(FileStorManagerTest, priority) { // Wait until everything is done. int count = 0; while (documents.size() != top.getNumReplies() && count < 10000) { - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); count++; } ASSERT_LT(count, 10000); diff --git a/storage/src/tests/persistence/filestorage/operationabortingtest.cpp b/storage/src/tests/persistence/filestorage/operationabortingtest.cpp index 0d43f8a9020..ba344971c3b 100644 --- a/storage/src/tests/persistence/filestorage/operationabortingtest.cpp +++ b/storage/src/tests/persistence/filestorage/operationabortingtest.cpp @@ -9,6 +9,8 @@ #include <vespa/vespalib/util/thread.h> #include <vespa/vespalib/stllike/hash_set_insert.hpp> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".operationabortingtest"); @@ -53,7 +55,7 @@ public: (void) context; _queueBarrier.await(); // message abort stage with active opertion in disk queue - FastOS_Thread::Sleep(75); + std::this_thread::sleep_for(75ms); _completionBarrier.await(); // test finished return spi::Result(); diff --git a/storage/src/tests/storageserver/bucketintegritycheckertest.cpp b/storage/src/tests/storageserver/bucketintegritycheckertest.cpp index ae466f04734..8a68adf226c 100644 --- a/storage/src/tests/storageserver/bucketintegritycheckertest.cpp +++ b/storage/src/tests/storageserver/bucketintegritycheckertest.cpp @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/storage/bucketdb/bucketmanager.h> -#include <vespa/storage/persistence/filestorage/filestormanager.h> #include <vespa/storage/storageserver/bucketintegritychecker.h> #include <vespa/storageapi/message/persistence.h> #include <tests/common/testhelper.h> @@ -9,6 +8,8 @@ #include <vespa/vespalib/io/fileutil.h> #include <tests/common/teststorageapp.h> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/time.h> +#include <thread> using namespace ::testing; @@ -175,13 +176,13 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) { checker.getSchedulingOptions()._minCycleTime = framework::SecondTime(60 * 60); topLink.open(); // Waiting for system to be initialized - FastOS_Thread::Sleep(10); // Give next message chance to come + std::this_thread::sleep_for(10ms); // Give next message chance to come ASSERT_COMMAND_COUNT(0, *dummyLink); topLink.doneInit(); checker.bump(); // Should have started new run with 2 pending per disk dummyLink->waitForMessages(4, _timeout); - FastOS_Thread::Sleep(10); // Give 5th message chance to come + std::this_thread::sleep_for(10ms); // Give 5th message chance to come ASSERT_COMMAND_COUNT(4, *dummyLink); auto* cmd1 = dynamic_cast<RepairBucketCommand*>(dummyLink->getCommand(0).get()); EXPECT_EQ(230, cmd1->getPriority()); @@ -200,13 +201,13 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) { // Answering a message on disk with no more buckets does not trigger new auto reply1 = std::make_shared<RepairBucketReply>(*cmd3); ASSERT_TRUE(checker.onUp(reply1)); - FastOS_Thread::Sleep(10); // Give next message chance to come + std::this_thread::sleep_for(10ms); // Give next message chance to come ASSERT_COMMAND_COUNT(4, *dummyLink); // Answering a message on disk with more buckets trigger new repair auto reply2 = std::make_shared<RepairBucketReply>(*cmd2); ASSERT_TRUE(checker.onUp(reply2)); dummyLink->waitForMessages(5, _timeout); - FastOS_Thread::Sleep(10); // Give 6th message chance to come + std::this_thread::sleep_for(10ms); // Give 6th message chance to come ASSERT_COMMAND_COUNT(5, *dummyLink); auto* cmd5 = dynamic_cast<RepairBucketCommand*>(dummyLink->getCommand(4).get()); ASSERT_TRUE(cmd5); @@ -217,7 +218,7 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) { reply3->setResult(api::ReturnCode(api::ReturnCode::IGNORED)); ASSERT_TRUE(checker.onUp(reply3)); dummyLink->waitForMessages(6, _timeout); - FastOS_Thread::Sleep(10); // Give 7th message chance to come + std::this_thread::sleep_for(10ms); // Give 7th message chance to come ASSERT_COMMAND_COUNT(6, *dummyLink); auto* cmd6 = dynamic_cast<RepairBucketCommand*>(dummyLink->getCommand(5).get()); ASSERT_TRUE(cmd6); @@ -227,7 +228,7 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) { auto reply4 = std::make_shared<RepairBucketReply>(*cmd4); reply3->setResult(api::ReturnCode(api::ReturnCode::BUCKET_NOT_FOUND)); ASSERT_TRUE(checker.onUp(reply4)); - FastOS_Thread::Sleep(10); // Give 7th message chance to come + std::this_thread::sleep_for(10ms); // Give 7th message chance to come ASSERT_COMMAND_COUNT(6, *dummyLink); // Send a repair reply that actually have corrected the bucket. @@ -247,13 +248,13 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) { EXPECT_EQ(document::BucketId(16, 0x234), cmd7->getBucketId()); auto reply7 = std::make_shared<RepairBucketReply>(*cmd7); ASSERT_TRUE(checker.onUp(reply7)); - FastOS_Thread::Sleep(10); // Give 8th message chance to come + std::this_thread::sleep_for(10ms); // Give 8th message chance to come ASSERT_COMMAND_COUNT(7, *dummyLink); // Still not time for next iteration dummyLink->reset(); _node->getClock().setAbsoluteTimeInSeconds(getDate("week1 sun 00:59:59")); - FastOS_Thread::Sleep(10); // Give new run chance to start + std::this_thread::sleep_for(10ms); // Give new run chance to start ASSERT_COMMAND_COUNT(0, *dummyLink); // Pass time until next cycle should start diff --git a/storage/src/tests/storageserver/communicationmanagertest.cpp b/storage/src/tests/storageserver/communicationmanagertest.cpp index caee6e6ab91..6657a9f1600 100644 --- a/storage/src/tests/storageserver/communicationmanagertest.cpp +++ b/storage/src/tests/storageserver/communicationmanagertest.cpp @@ -15,6 +15,8 @@ #include <vespa/vespalib/util/stringfmt.h> #include <vespa/documentapi/messagebus/messages/removedocumentmessage.h> #include <vespa/documentapi/messagebus/messages/getdocumentreply.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/vespalib/gtest/gtest.h> using document::test::makeDocumentBucket; @@ -65,7 +67,7 @@ TEST_F(CommunicationManagerTest, simple) { distributor.open(); storage.open(); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); // Send a message through from distributor to storage auto cmd = std::make_shared<api::GetCommand>( diff --git a/storage/src/tests/storageserver/statereportertest.cpp b/storage/src/tests/storageserver/statereportertest.cpp index c84f9311c52..dc8094275d1 100644 --- a/storage/src/tests/storageserver/statereportertest.cpp +++ b/storage/src/tests/storageserver/statereportertest.cpp @@ -11,6 +11,8 @@ #include <vespa/config/common/exceptions.h> #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".test.statereporter"); @@ -233,7 +235,7 @@ TEST_F(StateReporterTest, report_metrics) { uint64_t(_metricManager->getLastProcessedTime()) < _clock->getTimeInSeconds().getTime()) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } LOG(debug, "5 minute snapshot should have been taken. Adding put count"); diff --git a/storage/src/vespa/storage/storageserver/fnetlistener.cpp b/storage/src/vespa/storage/storageserver/fnetlistener.cpp index c86e1671033..651686a7c6d 100644 --- a/storage/src/vespa/storage/storageserver/fnetlistener.cpp +++ b/storage/src/vespa/storage/storageserver/fnetlistener.cpp @@ -6,9 +6,11 @@ #include <vespa/storageapi/message/state.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/vespalib/util/host_name.h> +#include <vespa/vespalib/util/time.h> #include <vespa/fnet/frt/supervisor.h> #include <vespa/fnet/transport.h> #include <sstream> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".rpc.listener"); @@ -50,7 +52,7 @@ FNetListener::registerHandle(vespalib::stringref handle) { _slobrokRegister.registerName(handle); while (_slobrokRegister.busy()) { LOG(debug, "Waiting to register in slobrok"); - FastOS_Thread::Sleep(50); + std::this_thread::sleep_for(50ms); } _handle = handle; } diff --git a/storage/src/vespa/storage/storageserver/fnetlistener.h b/storage/src/vespa/storage/storageserver/fnetlistener.h index 205a5af4586..e37727beb44 100644 --- a/storage/src/vespa/storage/storageserver/fnetlistener.h +++ b/storage/src/vespa/storage/storageserver/fnetlistener.h @@ -5,6 +5,7 @@ #include <atomic> class FNET_Transport; +class FastOS_ThreadPool; namespace storage { diff --git a/storage/src/vespa/storage/storageserver/storagenode.cpp b/storage/src/vespa/storage/storageserver/storagenode.cpp index c5a0a031067..e962ee4b1b6 100644 --- a/storage/src/vespa/storage/storageserver/storagenode.cpp +++ b/storage/src/vespa/storage/storageserver/storagenode.cpp @@ -14,6 +14,7 @@ #include <vespa/storage/common/statusmetricconsumer.h> #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/util/time.h> #include <vespa/metrics/metricmanager.h> #include <fcntl.h> @@ -568,7 +569,7 @@ StorageNode::waitUntilInitialized(uint32_t timeout) { lib::NodeState nodeState(*_component->getStateUpdater().getReportedNodeState()); if (nodeState.getState() == lib::State::UP) break; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); if (clock.getTimeInMillis() >= endTime) { std::ostringstream ost; ost << "Storage server not initialized after waiting timeout of " diff --git a/storage/src/vespa/storage/tools/storage-cmd.cpp b/storage/src/vespa/storage/tools/storage-cmd.cpp index daaa890873f..8c0fcc83330 100644 --- a/storage/src/vespa/storage/tools/storage-cmd.cpp +++ b/storage/src/vespa/storage/tools/storage-cmd.cpp @@ -3,6 +3,8 @@ #include <vespa/slobrok/sbmirror.h> #include <vespa/fastos/app.h> #include <vespa/vespalib/locale/c.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("vespa-storage-cmd"); @@ -61,7 +63,7 @@ public: slobrok::api::MirrorAPI mirror(supervisor.supervisor(), sbcfg); while (!mirror.ready()) { - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } slobrok::api::MirrorAPI::SpecList list = mirror.lookup(_argv[1]); diff --git a/storageframework/src/tests/thread/tickingthreadtest.cpp b/storageframework/src/tests/thread/tickingthreadtest.cpp index 97ae08eef3d..c42a9c17283 100644 --- a/storageframework/src/tests/thread/tickingthreadtest.cpp +++ b/storageframework/src/tests/thread/tickingthreadtest.cpp @@ -6,6 +6,8 @@ #include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/util/exception.h> #include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/time.h> +#include <thread> namespace storage::framework::defaultimplementation { @@ -35,7 +37,7 @@ struct MyApp : public TickingThread { Context& c(_context[index]); if (_doCritOverlapTest) { uint32_t oldTick = _critOverlapCounter; - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); _critOverlap |= (_critOverlapCounter != oldTick); ++_critOverlapCounter; } @@ -109,7 +111,7 @@ TEST(TickingThreadTest, test_ticks_before_wait_basic) // and verify time is in right ballpark. int totalSleepMs = 0; while (app.getTotalNonCritTicks() < 20) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); totalSleepMs++; } EXPECT_GT(totalSleepMs, 10); @@ -134,7 +136,7 @@ TEST(TickingThreadTest, test_ticks_before_wait_live_update) // (if live update is broken it will take more than an hour). int maxAttempts = 120000; // a bit more than 120 secs while (app.getTotalNonCritTicks() < ticksBeforeWaitMs && maxAttempts-->0) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } EXPECT_GT(maxAttempts, 0); @@ -158,7 +160,7 @@ TEST(TickingThreadTest, test_verbose_stopping) MyApp app(threadCount, true); app.start(testReg.getThreadPoolImpl()); while (app.getMinCritTick() < 5) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } app._threadPool->stop(); } @@ -171,7 +173,7 @@ TEST(TickingThreadTest, test_stop_on_deletion) MyApp app(threadCount, true); app.start(testReg.getThreadPoolImpl()); while (app.getMinCritTick() < 5) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } @@ -185,7 +187,7 @@ TEST(TickingThreadTest, test_lock_all_ticks) app1.start(testReg.getThreadPoolImpl()); app2.start(testReg.getThreadPoolImpl()); while (std::min(app1.getMinCritTick(), app2.getMinCritTick()) < 5) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } uint64_t ticks1, ticks2; { @@ -194,12 +196,12 @@ TEST(TickingThreadTest, test_lock_all_ticks) ticks2 = app2.getTotalTicks(); while (app2.getMinCritTick() < 2 * ticks2 / threadCount) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } EXPECT_EQ(ticks1, app1.getTotalTicks()); } while (app1.getMinCritTick() < 2 * ticks1 / threadCount) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } @@ -213,7 +215,7 @@ TEST(TickingThreadTest, test_lock_critical_ticks) MyApp app(threadCount, true); app.start(testReg.getThreadPoolImpl()); while (!app.hasCritOverlap()) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); ++app._critOverlapCounter; ++iterationsBeforeOverlap; } @@ -222,7 +224,7 @@ TEST(TickingThreadTest, test_lock_critical_ticks) MyApp app(threadCount, true); app.start(testReg.getThreadPoolImpl()); for (uint64_t i=0; i<iterationsBeforeOverlap * 10; ++i) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); TickingLockGuard guard(app._threadPool->freezeCriticalTicks()); for (int j=0; j<threadCount; ++j) { ++app._context[j]._critTickCount; @@ -318,13 +320,13 @@ TEST(TickingThreadTest, test_broadcast) BroadcastApp app; app.start(testReg.getThreadPoolImpl()); app.doTask("foo"); - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); app.doTask("bar"); - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); app.doTask("baz"); - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); app.doTask("hmm"); - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } diff --git a/storageserver/src/apps/storaged/storage.cpp b/storageserver/src/apps/storaged/storage.cpp index 0748cc3cb1e..5996951e65f 100644 --- a/storageserver/src/apps/storaged/storage.cpp +++ b/storageserver/src/apps/storaged/storage.cpp @@ -198,7 +198,7 @@ int StorageApp::Main() LOG(debug, "Server was attempted stopped, shutting down"); // Create guard that will forcifully kill storage if destruction takes longer // time than given timeout. - vespalib::ShutdownGuard shutdownGuard(_maxShutdownTime); + vespalib::ShutdownGuard shutdownGuard(std::chrono::milliseconds(_maxShutdownTime)); LOG(debug, "Attempting proper shutdown"); _process.reset(); LOG(debug, "Completed controlled shutdown."); diff --git a/vdslib/src/tests/thread/taskschedulertest.cpp b/vdslib/src/tests/thread/taskschedulertest.cpp index 540de722137..1925625172c 100644 --- a/vdslib/src/tests/thread/taskschedulertest.cpp +++ b/vdslib/src/tests/thread/taskschedulertest.cpp @@ -2,6 +2,8 @@ #include <vespa/vdslib/thread/taskscheduler.h> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/time.h> +#include <thread> namespace vdslib { @@ -141,13 +143,13 @@ TEST(TaskSchedulerTest, test_simple) task->registerCallsWithName("", calls); scheduler.addAbsolute(TestTask::UP(task), 50); watch.increment(49); // Not yet time to run - FastOS_Thread::Sleep(5); + std::this_thread::sleep_for(5ms); // Check that it has not run yet.. EXPECT_EQ(counter, scheduler.getTaskCounter()); watch.increment(10); // Now time is enough for it to run scheduler.waitForTaskCounterOfAtLeast(counter + 1); watch.increment(10); - FastOS_Thread::Sleep(5); + std::this_thread::sleep_for(5ms); // Check that it has not run yet.. EXPECT_EQ(counter + 1, scheduler.getTaskCounter()); watch.increment(50); diff --git a/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp b/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp index 9ec9249b671..1b8f31f0d03 100644 --- a/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp +++ b/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp @@ -1,19 +1,20 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/defaults.h> -#include <vespa/document/util/stringutil.h> #include <vespa/fnet/frt/frt.h> #include <vespa/slobrok/sbmirror.h> #include <vespa/vdslib/distribution/distribution.h> #include <vespa/vdslib/state/clusterstate.h> #include <vespa/vespalib/util/programoptions.h> #include <vespa/vespaclient/clusterlist/clusterlist.h> +#include <vespa/vespalib/util/time.h> #include <vespa/vespalib/text/lowercase.h> #include <vespa/config-stor-distribution.h> #include <vespa/config/helper/configgetter.hpp> #include <vespa/fastos/app.h> #include <sstream> #include <iostream> +#include <thread> #include <sys/time.h> #include <vespa/log/log.h> @@ -282,7 +283,7 @@ struct StateApp : public FastOS_Application { } warnTime *= 4; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } if (!slobrok->ready()) { std::cerr << "Slobrok not ready.\n"; diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 59474021de2..cea58d565c2 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1985,6 +1985,7 @@ ], "methods": [ "public void <init>()", + "public void <init>(double)", "public double applyAsDouble(double)", "public java.lang.String toString()" ], @@ -2075,6 +2076,7 @@ ], "methods": [ "public void <init>()", + "public void <init>(double)", "public double applyAsDouble(double)", "public java.lang.String toString()" ], @@ -2271,6 +2273,7 @@ ], "methods": [ "public void <init>()", + "public void <init>(double, double)", "public double applyAsDouble(double)", "public java.lang.String toString()" ], @@ -2437,22 +2440,25 @@ "public static java.util.function.DoubleUnaryOperator atan()", "public static java.util.function.DoubleUnaryOperator ceil()", "public static java.util.function.DoubleUnaryOperator cos()", - "public static java.util.function.DoubleUnaryOperator elu()", "public static java.util.function.DoubleUnaryOperator exp()", "public static java.util.function.DoubleUnaryOperator floor()", "public static java.util.function.DoubleUnaryOperator log()", "public static java.util.function.DoubleUnaryOperator neg()", "public static java.util.function.DoubleUnaryOperator reciprocal()", - "public static java.util.function.DoubleUnaryOperator relu()", "public static java.util.function.DoubleUnaryOperator rsqrt()", - "public static java.util.function.DoubleUnaryOperator selu()", - "public static java.util.function.DoubleUnaryOperator leakyrelu()", "public static java.util.function.DoubleUnaryOperator sin()", "public static java.util.function.DoubleUnaryOperator sigmoid()", "public static java.util.function.DoubleUnaryOperator sqrt()", "public static java.util.function.DoubleUnaryOperator square()", "public static java.util.function.DoubleUnaryOperator tan()", "public static java.util.function.DoubleUnaryOperator tanh()", + "public static java.util.function.DoubleUnaryOperator elu()", + "public static java.util.function.DoubleUnaryOperator elu(double)", + "public static java.util.function.DoubleUnaryOperator leakyrelu()", + "public static java.util.function.DoubleUnaryOperator leakyrelu(double)", + "public static java.util.function.DoubleUnaryOperator relu()", + "public static java.util.function.DoubleUnaryOperator selu()", + "public static java.util.function.DoubleUnaryOperator selu(double, double)", "public static java.util.function.Function random()", "public static java.util.function.Function equal(java.util.List)", "public static java.util.function.Function sum(java.util.List)" @@ -2495,6 +2501,21 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.ToStringContext$EmptyStringContext": { + "superClass": "java.lang.Object", + "interfaces": [ + "com.yahoo.tensor.functions.ToStringContext" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>()", + "public java.lang.String getBinding(java.lang.String)", + "public com.yahoo.tensor.functions.ToStringContext parent()" + ], + "fields": [] + }, "com.yahoo.tensor.functions.ToStringContext": { "superClass": "java.lang.Object", "interfaces": [], @@ -2504,7 +2525,9 @@ "abstract" ], "methods": [ - "public static com.yahoo.tensor.functions.ToStringContext empty()" + "public static com.yahoo.tensor.functions.ToStringContext empty()", + "public abstract java.lang.String getBinding(java.lang.String)", + "public abstract com.yahoo.tensor.functions.ToStringContext parent()" ], "fields": [] }, @@ -2526,7 +2549,8 @@ "public java.util.Optional dimension()", "public java.util.Optional label()", "public java.util.Optional index()", - "public java.lang.String toString()" + "public java.lang.String toString()", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)" ], "fields": [] }, @@ -2544,7 +2568,6 @@ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", - "public java.lang.String toString()", "public bridge synthetic com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)" ], "fields": [] diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 416940a60eb..0a496cda5d9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -83,7 +83,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens String contentToString(ToStringContext context) { if (type().dimensions().isEmpty()) { if (cells.isEmpty()) return "{}"; - return "{" + cells.values().iterator().next() + "}"; + return "{" + cells.values().iterator().next().toString(context) + "}"; } StringBuilder b = new StringBuilder("{"); @@ -124,7 +124,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens String contentToString(ToStringContext context) { if (type().dimensions().isEmpty()) { if (cells.isEmpty()) return "{}"; - return "{" + cells.get(0) + "}"; + return "{" + cells.get(0).toString(context) + "}"; } IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index e5095178be7..fa3d70a4ddf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -91,7 +91,7 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); - GenerateContext generateContext = new GenerateContext(type, context); + GenerateEvaluationContext generateContext = new GenerateEvaluationContext(type, context); for (int i = 0; i < indexes.size(); i++) { indexes.next(); builder.cell(generateContext.apply(indexes), indexes.indexesForReading()); @@ -113,7 +113,7 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM if (freeGenerator != null) return freeGenerator.toString(); else - return boundGenerator.toString(context); + return boundGenerator.toString(new GenerateToStringContext(context)); } /** @@ -121,19 +121,18 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM * This returns all the current index values as variables and falls back to delivering from the given * evaluation context. */ - private class GenerateContext implements EvaluationContext<NAMETYPE> { + private class GenerateEvaluationContext implements EvaluationContext<NAMETYPE> { private final TensorType type; private final EvaluationContext<NAMETYPE> context; private IndexedTensor.Indexes indexes; - GenerateContext(TensorType type, EvaluationContext<NAMETYPE> context) { + GenerateEvaluationContext(TensorType type, EvaluationContext<NAMETYPE> context) { this.type = type; this.context = context; } - @SuppressWarnings("unchecked") double apply(IndexedTensor.Indexes indexes) { if (freeGenerator != null) { return freeGenerator.apply(indexes.toList()); @@ -173,4 +172,26 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } + /** A context which adds the bindings of the generate dimension names to the given context. */ + private class GenerateToStringContext implements ToStringContext { + + private final ToStringContext context; + + public GenerateToStringContext(ToStringContext context) { + this.context = context; + } + + @Override + public String getBinding(String identifier) { + if (type.dimension(identifier).isPresent()) + return identifier; // dimension names are bound but not substituted in the generate context + else + return context.getBinding(identifier); + } + + @Override + public ToStringContext parent() { return context; } + + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java index 07b3658fb58..ec579a90e4f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -16,8 +16,6 @@ public interface ScalarFunction<NAMETYPE extends Name> extends Function<Evaluati @Override Double apply(EvaluationContext<NAMETYPE> context); - default String toString(ToStringContext context) { - return toString(); - } + default String toString(ToStringContext context) { return toString(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index e8e329cd75c..d9204e24d68 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -38,16 +38,12 @@ public class ScalarFunctions { public static DoubleUnaryOperator atan() { return new Atan(); } public static DoubleUnaryOperator ceil() { return new Ceil(); } public static DoubleUnaryOperator cos() { return new Cos(); } - public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator exp() { return new Exp(); } public static DoubleUnaryOperator floor() { return new Floor(); } public static DoubleUnaryOperator log() { return new Log(); } public static DoubleUnaryOperator neg() { return new Neg(); } public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); } - public static DoubleUnaryOperator relu() { return new Relu(); } public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } - public static DoubleUnaryOperator selu() { return new Selu(); } - public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); } public static DoubleUnaryOperator sin() { return new Sin(); } public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } @@ -55,6 +51,14 @@ public class ScalarFunctions { public static DoubleUnaryOperator tan() { return new Tan(); } public static DoubleUnaryOperator tanh() { return new Tanh(); } + public static DoubleUnaryOperator elu() { return new Elu(); } + public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); } + public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); } + public static DoubleUnaryOperator leakyrelu(double alpha) { return new LeakyRelu(alpha); } + public static DoubleUnaryOperator relu() { return new Relu(); } + public static DoubleUnaryOperator selu() { return new Selu(); } + public static DoubleUnaryOperator selu(double scale, double alpha) { return new Selu(scale, alpha); } + public static Function<List<Long>, Double> random() { return new Random(); } public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); } @@ -191,10 +195,17 @@ public class ScalarFunctions { } public static class Elu implements DoubleUnaryOperator { + private final double alpha; + public Elu() { + this(1.0); + } + public Elu(double alpha) { + this.alpha = alpha; + } @Override - public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; } + public double applyAsDouble(double operand) { return operand < 0 ? alpha * (Math.exp(operand) - 1) : operand; } @Override - public String toString() { return "f(a)(if(a < 0, exp(a)-1, a))"; } + public String toString() { return "f(a)(if(a < 0, " + alpha + " * (exp(a)-1), a))"; } } public static class Exp implements DoubleUnaryOperator { @@ -241,8 +252,15 @@ public class ScalarFunctions { public static class Selu implements DoubleUnaryOperator { // See https://arxiv.org/abs/1706.02515 - private static final double scale = 1.0507009873554804934193349852946; - private static final double alpha = 1.6732632423543772848170429916717; + private final double scale; // 1.0507009873554804934193349852946; + private final double alpha; // 1.6732632423543772848170429916717; + public Selu() { + this(1.0507009873554804934193349852946, 1.6732632423543772848170429916717); + } + public Selu(double scale, double alpha) { + this.scale = scale; + this.alpha = alpha; + } @Override public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); } @Override @@ -250,10 +268,17 @@ public class ScalarFunctions { } public static class LeakyRelu implements DoubleUnaryOperator { + private final double alpha; + public LeakyRelu() { + this(0.01); + } + public LeakyRelu(double alpha) { + this.alpha = alpha; + } @Override - public double applyAsDouble(double operand) { return Math.max(0.01 * operand, operand); } + public double applyAsDouble(double operand) { return Math.max(alpha * operand, operand); } @Override - public String toString() { return "f(a)(max(0.01*a, a))"; } + public String toString() { return "f(a)(max(" + alpha + " * a, a))"; } } public static class Sin implements DoubleUnaryOperator { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java index cb7f376c365..634ba4fe6ab 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java @@ -3,13 +3,30 @@ package com.yahoo.tensor.functions; /** * A context which is passed down to all nested functions when returning a string representation. - * The default implementation is empty as this library does not in itself have any need for a - * context. * * @author bratseth */ public interface ToStringContext { - static ToStringContext empty() { return new ToStringContext() {}; } + static ToStringContext empty() { return new EmptyStringContext(); } + + /** Returns the name an identifier is bound to, or null if not bound in this context */ + String getBinding(String name); + + /** + * Returns the parent context of this (the context we're in scope of when this is created), + * or null if this is the root. + */ + ToStringContext parent(); + + class EmptyStringContext implements ToStringContext { + + @Override + public String getBinding(String name) { return null; } + + @Override + public ToStringContext parent() { return null; } + + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java index cb14711c0dd..37a54807673 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java @@ -47,7 +47,7 @@ public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 1) throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); - return new Value<NAMETYPE>(arguments.get(0), cellAddress); + return new Value<>(arguments.get(0), cellAddress); } @Override @@ -78,20 +78,17 @@ public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY @Override public String toString(ToStringContext context) { - return toString(); - } - - @Override - public String toString() { + StringBuilder b = new StringBuilder(argument.toString(context)); if (cellAddress.size() == 1 && cellAddress.get(0).dimension().isEmpty()) { if (cellAddress.get(0).index().isPresent()) - return "[" + cellAddress.get(0).index().get() + "]"; + b.append("[").append(cellAddress.get(0).index().get().toString(context)).append("]"); else - return "{" + cellAddress.get(0).label() + "}"; + b.append("{").append(cellAddress.get(0).label().get()).append("}"); } else { - return "{" + cellAddress.stream().map(i -> i.toString()).collect(Collectors.joining(", ")) + "}"; + b.append("{").append(cellAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}"); } + return b.toString(); } public static class DimensionValue<NAMETYPE extends Name> { @@ -109,11 +106,11 @@ public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY } public DimensionValue(String dimension, int index) { - this(Optional.of(dimension), null, new ConstantScalarFunction<>(index)); + this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index)); } public DimensionValue(int index) { - this(Optional.empty(), null, new ConstantScalarFunction<>(index)); + this(Optional.empty(), null, new ConstantIntegerFunction<>(index)); } public DimensionValue(String label) { @@ -156,30 +153,37 @@ public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY @Override public String toString() { + return toString(null); + } + + public String toString(ToStringContext context) { StringBuilder b = new StringBuilder(); dimension.ifPresent(d -> b.append(d).append(":")); if (label != null) b.append(label); else - b.append(index); + b.append(index.toString(context)); return b.toString(); } } - private static class ConstantScalarFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> { + private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> { - private final Double value; + private final int value; - public ConstantScalarFunction(int value) { - this.value = (double)value; + public ConstantIntegerFunction(int value) { + this.value = value; } @Override public Double apply(EvaluationContext<NAMETYPE> context) { - return value; + return (double)value; } + @Override + public String toString() { return String.valueOf(value); } + } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java index 7127abde016..227fbffbaa8 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java @@ -63,4 +63,13 @@ public class ValueTestCase { } } + @Test + public void testToString() { + Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]"); + assertEquals("tensor(key[3]):[1.1, 2.2, 3.3][2]", + new Value<>(new ConstantTensor<>(input), + List.of(new Value.DimensionValue<>(2))) + .toString()); + } + } diff --git a/vespalib/src/tests/delegatelist/delegatelist.cpp b/vespalib/src/tests/delegatelist/delegatelist.cpp index ba1a2049794..070864dd85a 100644 --- a/vespalib/src/tests/delegatelist/delegatelist.cpp +++ b/vespalib/src/tests/delegatelist/delegatelist.cpp @@ -780,11 +780,11 @@ Test::testWaitSnapshots() ASSERT_TRUE(pool.NewThread(&a1, 0) != 0); s1.reset(new DL::Snapshot(dl)); // create snap 1 a1.doIt(cmd_wait_snap(&dl)); // wait for snaps - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); EXPECT_TRUE(a1.getState() == Actor::STATE_BUSY); // still waiting... s2.reset(new DL::Snapshot(dl)); // create snap 2 s1.reset(); // destroy snap 1 - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); EXPECT_TRUE(a1.getState() == Actor::STATE_IDLE); // wait done! a1.doIt(cmd_exit()); a1.waitState(Actor::STATE_DONE); diff --git a/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp b/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp index c43d0ec1c29..7567e8426ae 100644 --- a/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp +++ b/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp @@ -5,7 +5,12 @@ #include <vespa/vespalib/util/inline.h> #include <vespa/fastos/timestamp.h> -using namespace vespalib; +using vespalib::RightArrayHeap; +using vespalib::RightHeap; +using vespalib::LeftArrayHeap; +using vespalib::LeftHeap; +using vespalib::LeftStdHeap; +using vespalib::make_string; template <typename H> struct IsRight { enum { VALUE = 0 }; }; template <> struct IsRight<RightHeap> { enum { VALUE = 1 }; }; diff --git a/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp b/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp index 7ca4a2eff39..5641d751f34 100644 --- a/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp +++ b/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp @@ -47,7 +47,7 @@ TEST_MT_FF("require that signals can be counted and cancelled", 2, Signal, size_ if (thread_id == 0) { for (size_t i = 0; i < f2; ++i) { f1.send(); - if (i % 128 == 0) { FastOS_Thread::Sleep(1); } + if (i % 128 == 0) { std::this_thread::sleep_for(1ms); } } TEST_BARRIER(); f1.cancel(); diff --git a/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp b/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp index 5b6df6eef4e..d67c417b71a 100644 --- a/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp +++ b/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp @@ -65,7 +65,7 @@ TEST("estimate cost of thread bundle fork/join") { if (time < minTime) { minTime = time; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } fprintf(stderr, "strategy: %s, threads: %zu, fork: %zu, iter: %zu, time: %g, unit: %g\n", strategy_name[strategy].c_str(), threads, fork, iter, minTime, diff --git a/vespalib/src/tests/thread/thread_test.cpp b/vespalib/src/tests/thread/thread_test.cpp index 025a33fa221..bcd38190c7e 100644 --- a/vespalib/src/tests/thread/thread_test.cpp +++ b/vespalib/src/tests/thread/thread_test.cpp @@ -32,7 +32,7 @@ TEST("normal operation") { { Thread thread(agent); thread.start(); - FastOS_Thread::Sleep(20); + std::this_thread::sleep_for(20ms); thread.stop().join(); } EXPECT_TRUE(agent.started); diff --git a/vespalib/src/vespa/vespalib/testkit/test_kit.h b/vespalib/src/vespa/vespalib/testkit/test_kit.h index 7e6b07d71df..17746c5b0fc 100644 --- a/vespalib/src/vespa/vespalib/testkit/test_kit.h +++ b/vespalib/src/vespa/vespalib/testkit/test_kit.h @@ -10,3 +10,4 @@ #include "test_hook.h" #include "test_state_guard.h" #include "time_bomb.h" +#include <vespa/vespalib/util/time.h> diff --git a/vespalib/src/vespa/vespalib/util/thread.cpp b/vespalib/src/vespa/vespalib/util/thread.cpp index 2d0118645ab..4eb436458a2 100644 --- a/vespalib/src/vespa/vespalib/util/thread.cpp +++ b/vespalib/src/vespa/vespalib/util/thread.cpp @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "thread.h" +#include <thread> namespace vespalib { @@ -87,7 +88,7 @@ Thread::currentThread() void Thread::sleep(size_t ms) { - FastOS_Thread::Sleep(ms); + std::this_thread::sleep_for(std::chrono::milliseconds(ms)); } } // namespace vespalib diff --git a/vespamalloc/src/tests/allocfree/allocfree.cpp b/vespamalloc/src/tests/allocfree/allocfree.cpp index 80513579a2f..86050d4aee9 100644 --- a/vespamalloc/src/tests/allocfree/allocfree.cpp +++ b/vespamalloc/src/tests/allocfree/allocfree.cpp @@ -89,7 +89,7 @@ int Test::Main() { for (; duration > 0; --duration) { LOG(info, "%d seconds left...", duration); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } pool.Close(); size_t numFreeOperations(0); diff --git a/vespamalloc/src/tests/allocfree/linklist.cpp b/vespamalloc/src/tests/allocfree/linklist.cpp index 11a8d1ddd11..74af380458a 100644 --- a/vespamalloc/src/tests/allocfree/linklist.cpp +++ b/vespamalloc/src/tests/allocfree/linklist.cpp @@ -163,7 +163,7 @@ int Test::Main() { for (; duration > 0; --duration) { LOG(info, "%d seconds left...", duration); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } pool.Close(); fprintf(stderr, "Did (%lu + %lu) = %lu linkIn operations\n", diff --git a/zkfacade/abi-spec.json b/zkfacade/abi-spec.json index efe6fbdaa08..25b652b7312 100644 --- a/zkfacade/abi-spec.json +++ b/zkfacade/abi-spec.json @@ -68,7 +68,6 @@ "methods": [ "public static com.yahoo.vespa.curator.Curator create(java.lang.String)", "public static com.yahoo.vespa.curator.Curator create(java.lang.String, java.util.Optional)", - "public void <init>(com.yahoo.cloud.config.ConfigserverConfig)", "public void <init>(com.yahoo.cloud.config.ConfigserverConfig, com.yahoo.vespa.zookeeper.VespaZooKeeperServer)", "protected void <init>(java.lang.String, java.lang.String, java.util.function.Function)", "public java.lang.String connectionSpec()", diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java index b76bad5b97b..9d74306d3d5 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java @@ -3,9 +3,12 @@ package com.yahoo.vespa.curator; import com.google.inject.Inject; import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.io.IOUtils; import com.yahoo.net.HostName; import com.yahoo.path.Path; +import com.yahoo.text.Utf8; import com.yahoo.vespa.curator.recipes.CuratorCounter; +import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.zookeeper.VespaZooKeeperServer; import org.apache.curator.RetryPolicy; import org.apache.curator.framework.CuratorFramework; @@ -20,7 +23,9 @@ import org.apache.curator.framework.recipes.locks.InterProcessLock; import org.apache.curator.framework.recipes.locks.InterProcessMutex; import org.apache.curator.retry.ExponentialBackoffRetry; import org.apache.zookeeper.KeeperException; +import org.apache.zookeeper.client.ZKClientConfig; import org.apache.zookeeper.data.Stat; +import org.apache.zookeeper.server.quorum.QuorumPeerConfig; import java.io.File; import java.time.Duration; @@ -63,28 +68,21 @@ public class Curator implements AutoCloseable { /** Creates a curator instance from a comma-separated string of ZooKeeper host:port strings */ public static Curator create(String connectionSpec, Optional<File> clientConfigFile) { - return new Curator(connectionSpec, connectionSpec); - } - - // For testing - public Curator(ConfigserverConfig configserverConfig) { - this(configserverConfig, createConnectionSpec(configserverConfig)); + return new Curator(connectionSpec, connectionSpec, clientConfigFile); } // Depend on ZooKeeperServer to make sure it is started first // TODO: Move zookeeperserver config out of configserverconfig (requires update of controller services.xml as well) @Inject public Curator(ConfigserverConfig configserverConfig, VespaZooKeeperServer server) { - this(configserverConfig, createConnectionSpec(configserverConfig)); + this(configserverConfig, Optional.empty()); } - private Curator(ConfigserverConfig configserverConfig, String zooKeeperEnsembleConnectionSpec) { - this((configserverConfig.zookeeperLocalhostAffinity()) ? - createConnectionSpecForLocalhost(configserverConfig) : zooKeeperEnsembleConnectionSpec, - zooKeeperEnsembleConnectionSpec); + Curator(ConfigserverConfig configserverConfig, Optional<File> clientConfigFile) { + this(createConnectionSpec(configserverConfig), createEnsembleConnectionSpec(configserverConfig), clientConfigFile); } - private Curator(String connectionSpec, String zooKeeperEnsembleConnectionSpec) { + private Curator(String connectionSpec, String zooKeeperEnsembleConnectionSpec, Optional<File> clientConfigFile) { this(connectionSpec, zooKeeperEnsembleConnectionSpec, (retryPolicy) -> CuratorFrameworkFactory @@ -93,7 +91,7 @@ public class Curator implements AutoCloseable { .sessionTimeoutMs(ZK_SESSION_TIMEOUT) .connectionTimeoutMs(ZK_CONNECTION_TIMEOUT) .connectString(connectionSpec) - .zookeeperFactory(new VespaZooKeeperFactory()) + .zookeeperFactory(new VespaZooKeeperFactory(createClientConfig(clientConfigFile))) .dontUseContainerParents() // TODO: Remove when we know ZooKeeper 3.5 works fine, consider waiting until Vespa 8 .build()); } @@ -123,7 +121,29 @@ public class Curator implements AutoCloseable { this.zooKeeperEnsembleCount = zooKeeperEnsembleConnectionSpec.split(",").length; } - private static String createConnectionSpec(ConfigserverConfig config) { + private static String createConnectionSpec(ConfigserverConfig configserverConfig) { + return configserverConfig.zookeeperLocalhostAffinity() + ? createConnectionSpecForLocalhost(configserverConfig) + : createEnsembleConnectionSpec(configserverConfig); + } + + private static ZKClientConfig createClientConfig(Optional<File> file) { + boolean useSecureClient = Boolean.parseBoolean(getEnvironmentVariable("VESPA_USE_TLS_FOR_ZOOKEEPER_CLIENT").orElse("false")); + String config = "zookeeper.client.secure=" + useSecureClient + "\n"; + + File clientConfigFile = + file.orElseGet(() -> new File(Defaults.getDefaults().underVespaHome("conf/zookeeper/zookeeper-client.cfg"))); + clientConfigFile.getParentFile().mkdirs(); + IOUtils.writeFile(clientConfigFile, Utf8.toBytes(config)); + + try { + return new ZKClientConfig(clientConfigFile); + } catch (QuorumPeerConfig.ConfigException e) { + throw new RuntimeException("Unable to create ZooKeeper client config file " + file); + } + } + + private static String createEnsembleConnectionSpec(ConfigserverConfig config) { StringBuilder connectionSpec = new StringBuilder(); for (int i = 0; i < config.zookeeperserver().size(); i++) { if (connectionSpec.length() > 0) { @@ -405,4 +425,10 @@ public class Curator implements AutoCloseable { * TODO: Move method out of this class. */ public int zooKeeperEnsembleCount() { return zooKeeperEnsembleCount; } + + private static Optional<String> getEnvironmentVariable(String variableName) { + return Optional.ofNullable(System.getenv().get(variableName)) + .filter(var -> !var.isEmpty()); + } + } diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java index 7c08168c536..84e2cb65a1a 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java @@ -4,19 +4,24 @@ package com.yahoo.vespa.curator; import org.apache.curator.utils.ZookeeperFactory; import org.apache.zookeeper.Watcher; import org.apache.zookeeper.ZooKeeper; +import org.apache.zookeeper.client.ZKClientConfig; /** * A ZooKeeper factory for creating a ZooKeeper client * * @author hmusum */ -// TODO: add constructor that takes feature flag so that we can write ZooKeeper client config and start -// ZooKeeper client with that config class VespaZooKeeperFactory implements ZookeeperFactory { + private final ZKClientConfig zkClientConfig; + + VespaZooKeeperFactory(ZKClientConfig zkClientConfig) { + this.zkClientConfig = zkClientConfig; + } + @Override public ZooKeeper newZooKeeper(String connectString, int sessionTimeout, Watcher watcher, boolean canBeReadOnly) throws Exception { - return new ZooKeeper(connectString, sessionTimeout, watcher); + return new ZooKeeper(connectString, sessionTimeout, watcher, zkClientConfig); } } diff --git a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java index 9b7e0250f2f..6b85953a1ff 100644 --- a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java +++ b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java @@ -8,7 +8,6 @@ import static org.junit.Assert.assertEquals; /** * @author Ulf Lilleengen - * @date 19.08.13 */ public class CuratorCounterTest { diff --git a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java index a8342dfe5bc..4cd2c708d1a 100644 --- a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java +++ b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java @@ -9,6 +9,8 @@ import org.junit.Before; import org.junit.Test; import java.io.IOException; +import java.nio.file.Files; +import java.util.Optional; import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertThat; @@ -52,7 +54,7 @@ public class CuratorTest { } @Test - public void require_curator_is_created_from_config() { + public void require_curator_is_created_from_config() throws IOException { try (Curator curator = createCurator(createTestConfig())) { assertThat(curator.zooKeeperEnsembleConnectionSpec(), is(spec1 + "," + spec2)); assertThat(curator.zooKeeperEnsembleCount(), is(2)); @@ -60,7 +62,7 @@ public class CuratorTest { } @Test - public void require_that_server_count_is_correct() { + public void require_that_server_count_is_correct() throws IOException { ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder(); builder.zookeeperserver(createZKBuilder(localhost, port1)); try (Curator curator = createCurator(new ConfigserverConfig(builder))) { @@ -98,8 +100,8 @@ public class CuratorTest { return zkBuilder; } - private Curator createCurator(ConfigserverConfig configserverConfig) { - return new Curator(configserverConfig); + private Curator createCurator(ConfigserverConfig configserverConfig) throws IOException { + return new Curator(configserverConfig, Optional.of(Files.createTempFile("zookeeper-client", "cfg").toFile())); } private static class PortAllocator { |