summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2019-12-05 05:32:02 -0800
committerGitHub <noreply@github.com>2019-12-05 05:32:02 -0800
commit15d09c5b5a02bd3c7cf859401e4a613edeebd901 (patch)
tree440e04af619d9f83279fd593ca272dba772ddb5e
parent17a8ab2b1d0600e9651b1aa0748edba8f48707bf (diff)
parent1cfdd7f98bf1a33e7edc6cfb646ee53d5654c7f4 (diff)
Merge branch 'master' into balder/use-duration-in-messagebus-and-storageapi-rebased-1
-rw-r--r--config/src/tests/api/api.cpp2
-rw-r--r--config/src/tests/configfetcher/configfetcher.cpp2
-rw-r--r--config/src/tests/configretriever/configretriever.cpp2
-rw-r--r--config/src/tests/frt/frt.cpp4
-rw-r--r--container-search/src/main/java/com/yahoo/search/Query.java3
-rw-r--r--container-search/src/main/java/com/yahoo/search/federation/FederationSearcher.java11
-rw-r--r--container-search/src/main/java/com/yahoo/search/federation/selection/FederationTarget.java14
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/properties/PropertyAliases.java19
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java6
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java3
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java5
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java1
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java2
-rw-r--r--documentapi/src/tests/policies/policies_test.cpp4
-rw-r--r--documentapi/src/tests/policies/testframe.cpp4
-rw-r--r--documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp7
-rw-r--r--fastos/src/tests/processtest.cpp4
-rw-r--r--fastos/src/tests/thread_bounce_test.cpp2
-rw-r--r--fastos/src/tests/thread_mutex_test.cpp4
-rw-r--r--fastos/src/tests/thread_sleep_test.cpp2
-rw-r--r--fastos/src/tests/thread_stats_test.cpp12
-rw-r--r--fastos/src/tests/thread_test_base.hpp15
-rw-r--r--fastos/src/tests/threadtest.cpp4
-rw-r--r--fastos/src/vespa/fastos/thread.h10
-rw-r--r--fastos/src/vespa/fastos/unix_process.cpp4
-rw-r--r--fastos/src/vespa/fastos/unix_thread.cpp12
-rw-r--r--fastos/src/vespa/fastos/unix_thread.h1
-rw-r--r--fnet/src/examples/timeout/timeout.cpp7
-rw-r--r--fnet/src/tests/frt/rpc/detach_return_invoke.cpp2
-rw-r--r--fnet/src/vespa/fnet/frt/invoker.h1
-rw-r--r--jrt_test/src/tests/mandatory-methods/extract-reflection.cpp4
-rw-r--r--messagebus/src/tests/context/context.cpp2
-rw-r--r--messagebus/src/tests/loadbalance/loadbalance.cpp2
-rw-r--r--messagebus/src/tests/messagebus/messagebus.cpp6
-rw-r--r--messagebus/src/tests/messageordering/messageordering.cpp2
-rw-r--r--messagebus/src/tests/serviceaddress/serviceaddress.cpp2
-rw-r--r--messagebus/src/tests/slobrok/slobrok.cpp2
-rw-r--r--messagebus/src/tests/sourcesession/sourcesession.cpp6
-rw-r--r--messagebus/src/tests/throttling/throttling.cpp4
-rw-r--r--messagebus/src/vespa/messagebus/network/rpcnetwork.cpp1
-rw-r--r--messagebus/src/vespa/messagebus/testlib/testserver.cpp2
-rw-r--r--messagebus_test/src/tests/error/cpp-client.cpp4
-rw-r--r--messagebus_test/src/tests/error/cpp-server.cpp4
-rw-r--r--messagebus_test/src/tests/speed/cpp-client.cpp8
-rw-r--r--messagebus_test/src/tests/speed/cpp-server.cpp4
-rw-r--r--messagebus_test/src/tests/trace/cpp-server.cpp4
-rw-r--r--messagebus_test/src/tests/trace/trace.cpp4
-rw-r--r--metrics/src/tests/metricmanagertest.cpp6
-rw-r--r--metrics/src/tests/stresstest.cpp9
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java49
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java76
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java27
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java103
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java34
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java4
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java460
-rw-r--r--searchcore/src/apps/proton/proton.cpp7
-rw-r--r--searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp15
-rw-r--r--searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp4
-rw-r--r--searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp12
-rw-r--r--searchcore/src/tests/proton/flushengine/flushengine_test.cpp2
-rw-r--r--searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp2
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp5
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h4
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp4
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h3
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java6
-rw-r--r--searchlib/src/tests/postinglistbm/stress_runner.cpp4
-rw-r--r--searchlib/src/tests/transactionlog/translogclient_test.cpp22
-rw-r--r--searchlib/src/tests/transactionlogstress/translogstress.cpp33
-rw-r--r--searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp4
-rw-r--r--slobrok/src/tests/configure/configure.cpp4
-rw-r--r--slobrok/src/tests/mirrorapi/mirrorapi.cpp6
-rw-r--r--slobrok/src/tests/registerapi/registerapi.cpp6
-rw-r--r--slobrok/src/tests/standalone/standalone.cpp4
-rw-r--r--staging_vespalib/src/tests/clock/clock_test.cpp19
-rw-r--r--staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp10
-rw-r--r--staging_vespalib/src/tests/timer/timer_test.cpp7
-rw-r--r--staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt2
-rw-r--r--staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.cpp (renamed from staging_vespalib/src/vespa/vespalib/util/timer.cpp)10
-rw-r--r--staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.h (renamed from staging_vespalib/src/vespa/vespalib/util/timer.h)8
-rw-r--r--staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp22
-rw-r--r--staging_vespalib/src/vespa/vespalib/util/shutdownguard.h6
-rw-r--r--storage/src/tests/common/metricstest.cpp6
-rw-r--r--storage/src/tests/common/teststorageapp.cpp5
-rw-r--r--storage/src/tests/distributor/distributortest.cpp5
-rw-r--r--storage/src/tests/persistence/filestorage/filestormanagertest.cpp16
-rw-r--r--storage/src/tests/persistence/filestorage/operationabortingtest.cpp4
-rw-r--r--storage/src/tests/storageserver/bucketintegritycheckertest.cpp19
-rw-r--r--storage/src/tests/storageserver/communicationmanagertest.cpp4
-rw-r--r--storage/src/tests/storageserver/statereportertest.cpp4
-rw-r--r--storage/src/vespa/storage/storageserver/fnetlistener.cpp4
-rw-r--r--storage/src/vespa/storage/storageserver/fnetlistener.h1
-rw-r--r--storage/src/vespa/storage/storageserver/storagenode.cpp3
-rw-r--r--storage/src/vespa/storage/tools/storage-cmd.cpp4
-rw-r--r--storageframework/src/tests/thread/tickingthreadtest.cpp30
-rw-r--r--storageserver/src/apps/storaged/storage.cpp2
-rw-r--r--vdslib/src/tests/thread/taskschedulertest.cpp6
-rw-r--r--vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp5
-rw-r--r--vespajlib/abi-spec.json14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java45
-rw-r--r--vespalib/src/tests/delegatelist/delegatelist.cpp4
-rw-r--r--vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp7
-rw-r--r--vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp2
-rw-r--r--vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp2
-rw-r--r--vespalib/src/tests/thread/thread_test.cpp2
-rw-r--r--vespalib/src/vespa/vespalib/testkit/test_kit.h1
-rw-r--r--vespalib/src/vespa/vespalib/util/thread.cpp3
-rw-r--r--vespamalloc/src/tests/allocfree/allocfree.cpp2
-rw-r--r--vespamalloc/src/tests/allocfree/linklist.cpp2
-rw-r--r--zkfacade/abi-spec.json1
-rw-r--r--zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java54
-rw-r--r--zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java11
-rw-r--r--zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java1
-rw-r--r--zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java10
119 files changed, 1134 insertions, 387 deletions
diff --git a/config/src/tests/api/api.cpp b/config/src/tests/api/api.cpp
index 4db66761444..0af2b848ea5 100644
--- a/config/src/tests/api/api.cpp
+++ b/config/src/tests/api/api.cpp
@@ -32,7 +32,7 @@ TEST_MT_FFF("require that source may be unable to serve config temporarily", 2,
ASSERT_TRUE(cfg.get() != NULL);
ASSERT_EQUAL("myfoo", cfg->myField);
} else {
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
f3.myField = "myfoo";
f2.addBuilder("myid", &f3);
f1->reload();
diff --git a/config/src/tests/configfetcher/configfetcher.cpp b/config/src/tests/configfetcher/configfetcher.cpp
index 607ab0a29a5..be25e913980 100644
--- a/config/src/tests/configfetcher/configfetcher.cpp
+++ b/config/src/tests/configfetcher/configfetcher.cpp
@@ -69,7 +69,7 @@ TEST("requireThatConfigUpdatesArePerformed") {
while (!cb._configured && timer.elapsed().ms() < 20000.0) {
if (cb._configured)
break;
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
ASSERT_TRUE(cb._configured);
ASSERT_TRUE(cb._config);
diff --git a/config/src/tests/configretriever/configretriever.cpp b/config/src/tests/configretriever/configretriever.cpp
index fc921a324af..87f189ad7d3 100644
--- a/config/src/tests/configretriever/configretriever.cpp
+++ b/config/src/tests/configretriever/configretriever.cpp
@@ -251,7 +251,7 @@ public:
if (configured) {
return true;
}
- FastOS_Thread::Sleep(200);
+ std::this_thread::sleep_for(200ms);
}
return configured;
}
diff --git a/config/src/tests/frt/frt.cpp b/config/src/tests/frt/frt.cpp
index ba8279a1999..28dea82bfe7 100644
--- a/config/src/tests/frt/frt.cpp
+++ b/config/src/tests/frt/frt.cpp
@@ -49,7 +49,7 @@ namespace {
while (timer.elapsed().ms() < timeoutInMillis) {
if (notified)
break;
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
}
return notified;
}
@@ -260,7 +260,7 @@ TEST_FF("require that request is config task is scheduled", SourceFixture(), FRT
f1.conn.scheduler.CheckTasks();
if (f2.result.notified)
break;
- FastOS_Thread::Sleep(500);
+ std::this_thread::sleep_for(500ms);
}
ASSERT_TRUE(f2.result.notified);
f2.src.close();
diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java
index 3dabf9bc649..395d8853603 100644
--- a/container-search/src/main/java/com/yahoo/search/Query.java
+++ b/container-search/src/main/java/com/yahoo/search/Query.java
@@ -226,9 +226,8 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
}
public static QueryProfileType getArgumentType() { return argumentType; }
-
/** The aliases of query properties */
- private static Map<String,CompoundName> propertyAliases;
+ private static Map<String, CompoundName> propertyAliases;
static {
Map<String,CompoundName> propertyAliasesBuilder = new HashMap<>();
addAliases(Query.getArgumentType(), propertyAliasesBuilder);
diff --git a/container-search/src/main/java/com/yahoo/search/federation/FederationSearcher.java b/container-search/src/main/java/com/yahoo/search/federation/FederationSearcher.java
index 499cb634295..6e36881ae63 100644
--- a/container-search/src/main/java/com/yahoo/search/federation/FederationSearcher.java
+++ b/container-search/src/main/java/com/yahoo/search/federation/FederationSearcher.java
@@ -115,7 +115,8 @@ public class FederationSearcher extends ForkingSearcher {
this(searchChainResolver, false, PropagateSourceProperties.ALL, null);
}
- private FederationSearcher(SearchChainResolver searchChainResolver, boolean strictSearchchain,
+ private FederationSearcher(SearchChainResolver searchChainResolver,
+ boolean strictSearchchain,
PropagateSourceProperties.Enum propagateSourceProperties,
TargetSelector targetSelector) {
this.searchChainResolver = searchChainResolver;
@@ -295,9 +296,11 @@ public class FederationSearcher extends ForkingSearcher {
}
}
- private Object getSourceOrProviderProperty(Query query, CompoundName propertyName,
- String sourceName, String providerName,
- Object defaultValue) {
+ private Object getSourceOrProviderProperty(Query query,
+ CompoundName propertyName,
+ String sourceName,
+ String providerName,
+ Object defaultValue) {
Object result = getProperty(query, new SourceKey(sourceName, propertyName.toString()));
if (result == null)
result = getProperty(query, new ProviderKey(providerName, propertyName.toString()));
diff --git a/container-search/src/main/java/com/yahoo/search/federation/selection/FederationTarget.java b/container-search/src/main/java/com/yahoo/search/federation/selection/FederationTarget.java
index 7ade9a0eaf9..8ccbe39cc5a 100644
--- a/container-search/src/main/java/com/yahoo/search/federation/selection/FederationTarget.java
+++ b/container-search/src/main/java/com/yahoo/search/federation/selection/FederationTarget.java
@@ -1,12 +1,11 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.federation.selection;
-import java.util.Optional;
import com.yahoo.component.chain.Chain;
import com.yahoo.search.Searcher;
import com.yahoo.search.searchchain.model.federation.FederationOptions;
-import static com.google.common.base.Preconditions.checkNotNull;
+import java.util.Objects;
/**
* Represents a search chain that the federation searcher should send a query to,
@@ -22,11 +21,8 @@ public final class FederationTarget<T> {
private final T customData;
public FederationTarget(Chain<Searcher> chain, FederationOptions federationOptions, T customData) {
- checkNotNull(chain);
- checkNotNull(federationOptions);
-
- this.chain = chain;
- this.federationOptions = federationOptions;
+ this.chain = Objects.requireNonNull(chain, "chain cannot be null");
+ this.federationOptions = Objects.requireNonNull(federationOptions, "federationOptions cannot be null");
this.customData = customData;
}
@@ -62,9 +58,7 @@ public final class FederationTarget<T> {
@Override
public int hashCode() {
- int result = chain.hashCode();
- result = 31 * result + federationOptions.hashCode();
- return result;
+ return Objects.hash(chain, federationOptions);
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/PropertyAliases.java b/container-search/src/main/java/com/yahoo/search/query/properties/PropertyAliases.java
index ac39e986ff0..a4a82d27f8e 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/PropertyAliases.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/PropertyAliases.java
@@ -20,14 +20,14 @@ import java.util.Map;
public class PropertyAliases extends Properties {
/** A map from aliases to standard names */
- private final Map<String,CompoundName> aliases;
+ private final Map<String, CompoundName> aliases;
/**
* Creates an instance with a set of aliases. The given aliases will be used directly by this class.
* To make this class immutable and thread safe, relinquish ownership of the parameter map.
*/
- public PropertyAliases(Map<String,CompoundName> aliases) {
- this.aliases=aliases;
+ public PropertyAliases(Map<String, CompoundName> aliases) {
+ this.aliases = aliases;
}
/**
@@ -42,20 +42,21 @@ public class PropertyAliases extends Properties {
}
@Override
- public Map<String, Object> listProperties(CompoundName property,Map<String,String> context,
- com.yahoo.processing.request.Properties substitution) {
- return super.listProperties(unalias(property),context,substitution);
+ public Map<String, Object> listProperties(CompoundName property,
+ Map<String,String> context,
+ com.yahoo.processing.request.Properties substitution) {
+ return super.listProperties(unalias(property), context, substitution);
}
@Override
- public Object get(CompoundName name,Map<String,String> context,
+ public Object get(CompoundName name, Map<String,String> context,
com.yahoo.processing.request.Properties substitution) {
return super.get(unalias(name),context,substitution);
}
@Override
- public void set(CompoundName name,Object value,Map<String,String> context) {
- super.set(unalias(name),value,context);
+ public void set(CompoundName name, Object value, Map<String,String> context) {
+ super.set(unalias(name), value, context);
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
index 4cdd4488f7b..c06c84fcc36 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
@@ -38,13 +38,13 @@ public class QueryProperties extends Properties {
}
public void setParentQuery(Query query) {
- this.query=query;
+ this.query = query;
super.setParentQuery(query);
}
- @SuppressWarnings("deprecation")
@Override
- public Object get(CompoundName key, Map<String,String> context,
+ public Object get(CompoundName key,
+ Map<String,String> context,
com.yahoo.processing.request.Properties substitution) {
if (key.size() == 2 && key.first().equals(Model.MODEL)) {
Model model = query.getModel();
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java
index 14cba858857..26bf189dd3d 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java
@@ -33,7 +33,8 @@ enum PathGroup {
"/zone/v2/{*}"),
/** Paths used for creating and reading user resources. */
- user("/application/v4/user",
+ user(Optional.of("/api"),
+ "/application/v4/user",
"/athenz/v1/{*}"),
/** Paths used for creating tenants with proper access control. */
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java
index 26c4bf6292a..d10a4879bf5 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java
@@ -37,6 +37,7 @@ import java.util.logging.Logger;
public class AthenzApiHandler extends LoggingRequestHandler {
private final static Logger log = Logger.getLogger(AthenzApiHandler.class.getName());
+ private static final String OPTIONAL_PREFIX = "/api";
private final AthenzFacade athenz;
private final AthenzDomain sandboxDomain;
@@ -69,7 +70,7 @@ public class AthenzApiHandler extends LoggingRequestHandler {
}
private HttpResponse get(HttpRequest request) {
- Path path = new Path(request.getUri());
+ Path path = new Path(request.getUri(), OPTIONAL_PREFIX);
if (path.matches("/athenz/v1")) return root(request);
if (path.matches("/athenz/v1/domains")) return domainList(request);
if (path.matches("/athenz/v1/properties")) return properties();
@@ -79,7 +80,7 @@ public class AthenzApiHandler extends LoggingRequestHandler {
}
private HttpResponse post(HttpRequest request) {
- Path path = new Path(request.getUri());
+ Path path = new Path(request.getUri(), OPTIONAL_PREFIX);
if (path.matches("/athenz/v1/user")) return signup(request);
return ErrorResponse.notFoundError(String.format("No '%s' handler at '%s'", request.getMethod(),
request.getUri().getPath()));
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java
index d37df2cc313..c263054c808 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java
@@ -127,6 +127,7 @@ public class ControllerContainerTest {
" </handler>\n" +
" <handler id='com.yahoo.vespa.hosted.controller.restapi.athenz.AthenzApiHandler'>\n" +
" <binding>http://*/athenz/v1/*</binding>\n" +
+ " <binding>http://*/api/athenz/v1/*</binding>\n" +
" </handler>\n" +
" <handler id='com.yahoo.vespa.hosted.controller.restapi.zone.v1.ZoneApiHandler'>\n" +
" <binding>http://*/zone/v1</binding>\n" +
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java
index c90dcbf7e2b..34ee160c449 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java
@@ -48,7 +48,7 @@ public class AthenzApiTest extends ControllerContainerTest {
new File("property-list.json"));
// POST user signup
- tester.assertResponse(authenticatedRequest("http://localhost:8080/athenz/v1/user", "", Request.Method.POST),
+ tester.assertResponse(authenticatedRequest("http://localhost:8080/api/athenz/v1/user", "", Request.Method.POST),
"{\"message\":\"User 'bob' added to admin role of 'vespa.vespa.tenants.sandbox'\"}");
}
diff --git a/documentapi/src/tests/policies/policies_test.cpp b/documentapi/src/tests/policies/policies_test.cpp
index 3dbc9dd7e69..0f0b9bd4504 100644
--- a/documentapi/src/tests/policies/policies_test.cpp
+++ b/documentapi/src/tests/policies/policies_test.cpp
@@ -286,7 +286,7 @@ Test::assertMirrorReady(const slobrok::api::IMirrorAPI &mirror)
if (mirror.ready()) {
return;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
ASSERT_TRUE(false);
}
@@ -299,7 +299,7 @@ Test::assertMirrorContains(const slobrok::api::IMirrorAPI &mirror, const string
if (mirror.lookup(pattern).size() == numEntries) {
return;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
ASSERT_TRUE(false);
}
diff --git a/documentapi/src/tests/policies/testframe.cpp b/documentapi/src/tests/policies/testframe.cpp
index 9834e534a56..585885b7c4b 100644
--- a/documentapi/src/tests/policies/testframe.cpp
+++ b/documentapi/src/tests/policies/testframe.cpp
@@ -8,6 +8,8 @@
#include <vespa/messagebus/testlib/simpleprotocol.h>
#include <vespa/messagebus/testlib/simplereply.h>
#include <vespa/messagebus/network/rpcnetworkparams.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP(".testframe");
@@ -297,7 +299,7 @@ TestFrame::waitSlobrok(const string &pattern, uint32_t cnt)
if (res.size() == cnt) {
return true;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
LOG(error, "Slobrok failed to resolve '%s' to %d recipients in time.", pattern.c_str(), cnt);
return false;
diff --git a/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp b/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp
index 18dd525b066..e82a184d8b2 100644
--- a/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp
+++ b/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp
@@ -1,12 +1,13 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "externslobrokpolicy.h"
-#include <vespa/vespalib/text/stringtokenizer.h>
#include <vespa/messagebus/routing/routingcontext.h>
+#include <vespa/vespalib/text/stringtokenizer.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/fnet/frt/frt.h>
#include <vespa/slobrok/sbmirror.h>
#include <vespa/fnet/transport.h>
-#include <vespa/fastos/thread.h>
+#include <thread>
using slobrok::api::IMirrorAPI;
using slobrok::api::MirrorAPI;
@@ -82,7 +83,7 @@ ExternSlobrokPolicy::lookup(mbus::RoutingContext& context, const string& pattern
if (_firstTry) {
int count = 0;
while (entries.empty() && count < 100) {
- FastOS_Thread::Sleep(50);
+ std::this_thread::sleep_for(50ms);
entries = mirror.lookup(pattern);
count++;
}
diff --git a/fastos/src/tests/processtest.cpp b/fastos/src/tests/processtest.cpp
index a6729dbb783..5a78eff1d36 100644
--- a/fastos/src/tests/processtest.cpp
+++ b/fastos/src/tests/processtest.cpp
@@ -3,6 +3,8 @@
#include <vespa/fastos/process.h>
#include <vespa/fastos/timestamp.h>
+using namespace std::chrono_literals;
+
class MyListener : public FastOS_ProcessRedirectListener
{
private:
@@ -119,7 +121,7 @@ public:
xproc->WriteStdin(nullptr, 0);
}
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
if(i == 10)
diff --git a/fastos/src/tests/thread_bounce_test.cpp b/fastos/src/tests/thread_bounce_test.cpp
index f7bb7ee1260..84506938455 100644
--- a/fastos/src/tests/thread_bounce_test.cpp
+++ b/fastos/src/tests/thread_bounce_test.cpp
@@ -43,7 +43,7 @@ class Thread_Bounce_Test : public ThreadTestBase
int left = static_cast<int>(checkTime.elapsed().ms());
while (left < 1000) {
- FastOS_Thread::Sleep(1000 - left);
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000 - left));
left = static_cast<int>(checkTime.elapsed().ms());
}
diff --git a/fastos/src/tests/thread_mutex_test.cpp b/fastos/src/tests/thread_mutex_test.cpp
index d49cf37163d..6d3f8c3c5f0 100644
--- a/fastos/src/tests/thread_mutex_test.cpp
+++ b/fastos/src/tests/thread_mutex_test.cpp
@@ -132,7 +132,7 @@ class Thread_Mutex_Test : public ThreadTestBase
{
bool lockrc;
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
for(int i=0; i<5; i++)
{
@@ -145,7 +145,7 @@ class Thread_Mutex_Test : public ThreadTestBase
}
}
- FastOS_Thread::Sleep(2000);
+ std::this_thread::sleep_for(2s);
lockrc = mtx.try_lock();
Progress(lockrc, "We should get the mutex lock now (%s)",
diff --git a/fastos/src/tests/thread_sleep_test.cpp b/fastos/src/tests/thread_sleep_test.cpp
index 7fd3412b7c3..209b7d3f880 100644
--- a/fastos/src/tests/thread_sleep_test.cpp
+++ b/fastos/src/tests/thread_sleep_test.cpp
@@ -20,7 +20,7 @@ class Thread_Sleep_Test : public ThreadTestBase
Progress(rc, "Creating Thread");
Progress(true, "Sleeping 3 seconds");
- FastOS_Thread::Sleep(3000);
+ std::this_thread::sleep_for(3s);
}
Progress(true, "Closing threadpool...");
diff --git a/fastos/src/tests/thread_stats_test.cpp b/fastos/src/tests/thread_stats_test.cpp
index 3633c12bcaa..a9d304d411f 100644
--- a/fastos/src/tests/thread_stats_test.cpp
+++ b/fastos/src/tests/thread_stats_test.cpp
@@ -31,7 +31,7 @@ class Thread_Stats_Test : public ThreadTestBase
job[0].ownThread = pool.NewThread(this,
static_cast<void *>(&job[0]));
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
inactiveThreads = pool.GetNumInactiveThreads();
Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads);
@@ -44,7 +44,7 @@ class Thread_Stats_Test : public ThreadTestBase
job[1].ownThread = pool.NewThread(this,
static_cast<void *>(&job[1]));
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
inactiveThreads = pool.GetNumInactiveThreads();
Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads);
@@ -57,7 +57,7 @@ class Thread_Stats_Test : public ThreadTestBase
job[0].ownThread->SetBreakFlag();
job[1].ownThread->SetBreakFlag();
- FastOS_Thread::Sleep(3000);
+ std::this_thread::sleep_for(3s);
inactiveThreads = pool.GetNumInactiveThreads();
Progress(inactiveThreads == 2, "Inactive threads = %d", inactiveThreads);
@@ -72,7 +72,7 @@ class Thread_Stats_Test : public ThreadTestBase
job[0].code = WAIT_FOR_BREAK_FLAG;
job[0].ownThread = pool.NewThread(this, static_cast<void *>(&job[0]));
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
inactiveThreads = pool.GetNumInactiveThreads();
Progress(inactiveThreads == 1, "Inactive threads = %d", inactiveThreads);
@@ -84,7 +84,7 @@ class Thread_Stats_Test : public ThreadTestBase
job[1].code = WAIT_FOR_BREAK_FLAG;
job[1].ownThread = pool.NewThread(this, static_cast<void *>(&job[1]));
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
inactiveThreads = pool.GetNumInactiveThreads();
Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads);
@@ -97,7 +97,7 @@ class Thread_Stats_Test : public ThreadTestBase
job[0].ownThread->SetBreakFlag();
job[1].ownThread->SetBreakFlag();
- FastOS_Thread::Sleep(3000);
+ std::this_thread::sleep_for(3s);
inactiveThreads = pool.GetNumInactiveThreads();
Progress(inactiveThreads == 2, "Inactive threads = %d", inactiveThreads);
diff --git a/fastos/src/tests/thread_test_base.hpp b/fastos/src/tests/thread_test_base.hpp
index 7966e95b369..c4f7ed76ea7 100644
--- a/fastos/src/tests/thread_test_base.hpp
+++ b/fastos/src/tests/thread_test_base.hpp
@@ -3,6 +3,7 @@
#pragma once
#include <chrono>
+#include <thread>
static volatile int64_t number;
#define INCREASE_NUMBER_AMOUNT 10000
@@ -47,7 +48,7 @@ public:
}
}
- FastOS_Thread::Sleep(500);
+ std::this_thread::sleep_for(500ms);
if(threadsFinished)
break;
@@ -88,7 +89,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg)
Progress(true, "Thread printing message: [%s]", job->message);
job->result = strlen(job->message);
- FastOS_Thread::Sleep(3000);
+ std::this_thread::sleep_for(3s);
break;
}
@@ -109,7 +110,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg)
number = number + 2;
if(i == sleepOn)
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
guard = std::unique_lock<std::mutex>();
@@ -123,7 +124,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg)
{
for(;;)
{
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
if(thread->GetBreakFlag())
{
@@ -192,7 +193,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg)
case WAIT2SEC_AND_SIGNALCOND:
{
- FastOS_Thread::Sleep(2000);
+ std::this_thread::sleep_for(2s);
job->condition->notify_one();
job->result = 1;
break;
@@ -202,7 +203,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg)
{
{
std::lock_guard<std::mutex> guard(*job->mutex);
- FastOS_Thread::Sleep(2000);
+ std::this_thread::sleep_for(2s);
}
job->result = 1;
break;
@@ -210,7 +211,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg)
case WAIT_2_SEC:
{
- FastOS_Thread::Sleep(2000);
+ std::this_thread::sleep_for(2s);
job->result = 1;
break;
}
diff --git a/fastos/src/tests/threadtest.cpp b/fastos/src/tests/threadtest.cpp
index 9507bb1e5d7..0a8a0d2bf02 100644
--- a/fastos/src/tests/threadtest.cpp
+++ b/fastos/src/tests/threadtest.cpp
@@ -43,7 +43,7 @@ class ThreadTest : public ThreadTestBase
if(waitingThreads == numWait)
break;
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
}
}
@@ -336,7 +336,7 @@ class ThreadTest : public ThreadTestBase
// Threads are not guaranteed to have entered sleep yet,
// as this test only tests for result code
// Wait another second to be sure.
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
void SignalTest ()
diff --git a/fastos/src/vespa/fastos/thread.h b/fastos/src/vespa/fastos/thread.h
index 12866c71b2c..257acbc92d3 100644
--- a/fastos/src/vespa/fastos/thread.h
+++ b/fastos/src/vespa/fastos/thread.h
@@ -347,15 +347,7 @@ public:
/**
* Destructor.
*/
- virtual ~FastOS_ThreadInterface (){}
-
- /**
- * Sleep for x milliseconds. Attempting to sleep for <1 milliseconds
- * will result in failure.
- * @param ms Number of milliseconds to sleep.
- * @return Boolean success/failure
- */
- static bool Sleep(int ms);
+ virtual ~FastOS_ThreadInterface () {}
/**
* Instruct a thread to exit. This could be used in conjunction with
diff --git a/fastos/src/vespa/fastos/unix_process.cpp b/fastos/src/vespa/fastos/unix_process.cpp
index 86d285059b8..4d4197f5354 100644
--- a/fastos/src/vespa/fastos/unix_process.cpp
+++ b/fastos/src/vespa/fastos/unix_process.cpp
@@ -40,6 +40,8 @@ extern char **environ;
#endif
+using namespace std::chrono_literals;
+
static pid_t safe_fork ()
{
pid_t pid;
@@ -1629,7 +1631,7 @@ FastOS_UNIX_ProcessStarter::Wait(FastOS_UNIX_Process *process,
}
}
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
}
return rc;
diff --git a/fastos/src/vespa/fastos/unix_thread.cpp b/fastos/src/vespa/fastos/unix_thread.cpp
index 5218bde2630..9e48727deb3 100644
--- a/fastos/src/vespa/fastos/unix_thread.cpp
+++ b/fastos/src/vespa/fastos/unix_thread.cpp
@@ -83,18 +83,6 @@ FastOS_UNIX_Thread::~FastOS_UNIX_Thread()
}
}
-bool FastOS_UNIX_Thread::Sleep (int ms)
-{
- bool rc=false;
-
- if (ms > 0) {
- usleep(ms*1000);
- rc = true;
- }
-
- return rc;
-}
-
FastOS_ThreadId FastOS_UNIX_Thread::GetThreadId ()
{
return _handle;
diff --git a/fastos/src/vespa/fastos/unix_thread.h b/fastos/src/vespa/fastos/unix_thread.h
index c6e0b040fc7..35df3f5745f 100644
--- a/fastos/src/vespa/fastos/unix_thread.h
+++ b/fastos/src/vespa/fastos/unix_thread.h
@@ -36,7 +36,6 @@ public:
~FastOS_UNIX_Thread();
- static bool Sleep (int ms);
FastOS_ThreadId GetThreadId () override;
static bool CompareThreadIds (FastOS_ThreadId a, FastOS_ThreadId b);
static FastOS_ThreadId GetCurrentThreadId ();
diff --git a/fnet/src/examples/timeout/timeout.cpp b/fnet/src/examples/timeout/timeout.cpp
index 1d6ecc11909..23dfbeb9070 100644
--- a/fnet/src/examples/timeout/timeout.cpp
+++ b/fnet/src/examples/timeout/timeout.cpp
@@ -2,7 +2,8 @@
#include <vespa/fnet/fnet.h>
#include <vespa/fastos/app.h>
-#include <chrono>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP("timeout");
@@ -55,7 +56,7 @@ MyApp::Main()
transport.Start(&pool);
// stable-state operation
- FastOS_Thread::Sleep(500);
+ std::this_thread::sleep_for(500ms);
FNET_Packet *packet;
FNET_Context context;
@@ -64,7 +65,7 @@ MyApp::Main()
t = clock::now();
timeout.Schedule(2.0); // timeout in 2 seconds
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
timeout.Unschedule(); // cancel timeout
ms = (clock::now() - t);
diff --git a/fnet/src/tests/frt/rpc/detach_return_invoke.cpp b/fnet/src/tests/frt/rpc/detach_return_invoke.cpp
index 43a61cd9bcd..95dbe672909 100644
--- a/fnet/src/tests/frt/rpc/detach_return_invoke.cpp
+++ b/fnet/src/tests/frt/rpc/detach_return_invoke.cpp
@@ -54,7 +54,7 @@ TEST("detach return invoke") {
if (receptor.req != 0) {
break;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
req->SubRef();
target->SubRef();
diff --git a/fnet/src/vespa/fnet/frt/invoker.h b/fnet/src/vespa/fnet/frt/invoker.h
index 64adf66688e..0838ef84dd3 100644
--- a/fnet/src/vespa/fnet/frt/invoker.h
+++ b/fnet/src/vespa/fnet/frt/invoker.h
@@ -5,7 +5,6 @@
#include "rpcrequest.h"
#include <vespa/fnet/task.h>
#include <vespa/fnet/ipackethandler.h>
-#include <vespa/fastos/thread.h>
#include <mutex>
#include <condition_variable>
diff --git a/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp b/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp
index 40c54980c11..cd1ad7e6eed 100644
--- a/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp
+++ b/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp
@@ -2,6 +2,8 @@
#include <vespa/fastos/app.h>
#include <vespa/fnet/frt/frt.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
class RPCInfo : public FastOS_Application
{
@@ -85,7 +87,7 @@ public:
if (info->GetErrorCode() != FRTE_RPC_CONNECTION) {
break;
}
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
target->SubRef();
target = supervisor.GetTarget(_argv[1]);
}
diff --git a/messagebus/src/tests/context/context.cpp b/messagebus/src/tests/context/context.cpp
index 12539b16b11..713e01cae73 100644
--- a/messagebus/src/tests/context/context.cpp
+++ b/messagebus/src/tests/context/context.cpp
@@ -73,7 +73,7 @@ Test::Main()
if (queue.size() == 3) {
break;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
EXPECT_EQUAL(queue.size(), 3u);
{
diff --git a/messagebus/src/tests/loadbalance/loadbalance.cpp b/messagebus/src/tests/loadbalance/loadbalance.cpp
index 2f510d98ff1..05ea6d78871 100644
--- a/messagebus/src/tests/loadbalance/loadbalance.cpp
+++ b/messagebus/src/tests/loadbalance/loadbalance.cpp
@@ -78,7 +78,7 @@ Test::Main()
if (queue.size() == msgCnt) {
break;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
EXPECT_TRUE(queue.size() == msgCnt);
EXPECT_TRUE(h1.cnt == msgCnt / 3);
diff --git a/messagebus/src/tests/messagebus/messagebus.cpp b/messagebus/src/tests/messagebus/messagebus.cpp
index cf249e6eaec..86c7bf91f2a 100644
--- a/messagebus/src/tests/messagebus/messagebus.cpp
+++ b/messagebus/src/tests/messagebus/messagebus.cpp
@@ -43,7 +43,7 @@ struct Base {
if (queue.size() == size) {
return true;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
return false;
}
@@ -270,7 +270,7 @@ Test::testSendToCol()
}
}
client->waitQueueSize(300);
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
client->waitQueueSize(300);
while (client->queue.size() > 0) {
Routable::UP reply = client->queue.dequeue();
@@ -347,7 +347,7 @@ Test::testSendToAnyThenCol()
}
}
client->waitQueueSize(300);
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
client->waitQueueSize(300);
while (client->queue.size() > 0) {
Routable::UP reply = client->queue.dequeue();
diff --git a/messagebus/src/tests/messageordering/messageordering.cpp b/messagebus/src/tests/messageordering/messageordering.cpp
index 520c3d3dea3..481b8bbd270 100644
--- a/messagebus/src/tests/messageordering/messageordering.cpp
+++ b/messagebus/src/tests/messageordering/messageordering.cpp
@@ -167,7 +167,7 @@ Test::Main()
const int messageCount = 5000;
for (int i = 0; i < messageCount; ++i) {
vespalib::string str(vespalib::make_string("%d", i));
- //FastOS_Thread::Sleep(1);
+ //std::this_thread::sleep_for(1ms);
auto msg = std::make_unique<SimpleMessage>(str, true, commonMessageId);
msg->getTrace().setLevel(9);
//LOG(debug, "Sending message %p for %d", msg.get(), i);
diff --git a/messagebus/src/tests/serviceaddress/serviceaddress.cpp b/messagebus/src/tests/serviceaddress/serviceaddress.cpp
index ac43cec3c02..441da5a80ac 100644
--- a/messagebus/src/tests/serviceaddress/serviceaddress.cpp
+++ b/messagebus/src/tests/serviceaddress/serviceaddress.cpp
@@ -82,7 +82,7 @@ Test::waitSlobrok(RPCNetwork &network, const string &pattern, size_t num)
if (res.size() == num) {
return true;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
return false;
}
diff --git a/messagebus/src/tests/slobrok/slobrok.cpp b/messagebus/src/tests/slobrok/slobrok.cpp
index 7e0718283a6..439ee0b23b5 100644
--- a/messagebus/src/tests/slobrok/slobrok.cpp
+++ b/messagebus/src/tests/slobrok/slobrok.cpp
@@ -51,7 +51,7 @@ compare(const IMirrorAPI &api, const string &pattern, SpecList expect)
if (actual == expect) {
return true;
}
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
}
return false;
}
diff --git a/messagebus/src/tests/sourcesession/sourcesession.cpp b/messagebus/src/tests/sourcesession/sourcesession.cpp
index c793dd435c8..3ffe0f8e19d 100644
--- a/messagebus/src/tests/sourcesession/sourcesession.cpp
+++ b/messagebus/src/tests/sourcesession/sourcesession.cpp
@@ -35,7 +35,7 @@ struct DelayedHandler : public IMessageHandler
// this will block the transport thread in the server messagebus,
// but that should be ok, as we only want to test the timing in the
// client messagebus...
- FastOS_Thread::Sleep(delay);
+ std::this_thread::sleep_for(std::chrono::milliseconds(delay));
session->acknowledge(std::move(msg));
}
};
@@ -59,7 +59,7 @@ bool waitQueueSize(RoutableQueue &queue, uint32_t size) {
if (queue.size() == size) {
return true;
}
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
return false;
}
@@ -99,7 +99,7 @@ Test::testSequencing()
EXPECT_TRUE(ss->send(Message::UP(new SimpleMessage("foo", true, 2)), "dst").isAccepted());
EXPECT_TRUE(ss->send(Message::UP(new SimpleMessage("foo", true, 1)), "dst").isAccepted());
EXPECT_TRUE(waitQueueSize(dstQ, 2));
- FastOS_Thread::Sleep(250);
+ std::this_thread::sleep_for(250ms);
EXPECT_TRUE(waitQueueSize(dstQ, 2));
EXPECT_TRUE(waitQueueSize(srcQ, 0));
ds->acknowledge(Message::UP((Message*)dstQ.dequeue().release()));
diff --git a/messagebus/src/tests/throttling/throttling.cpp b/messagebus/src/tests/throttling/throttling.cpp
index 6599604bf9a..a23e0b61550 100644
--- a/messagebus/src/tests/throttling/throttling.cpp
+++ b/messagebus/src/tests/throttling/throttling.cpp
@@ -49,7 +49,7 @@ bool waitQueueSize(RoutableQueue &queue, uint32_t size)
if (queue.size() == size) {
return true;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
return false;
}
@@ -60,7 +60,7 @@ bool waitPending(SourceSession& session, uint32_t size)
if (session.getPendingCount() == size) {
return true;
}
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
return false;
}
diff --git a/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp b/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp
index 0bc7f9f3399..c6f61b383bc 100644
--- a/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp
+++ b/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp
@@ -17,6 +17,7 @@
#include <vespa/fnet/scheduler.h>
#include <vespa/fnet/transport.h>
#include <vespa/fnet/frt/supervisor.h>
+#include <vespa/fastos/thread.h>
#include <thread>
#include <vespa/log/log.h>
diff --git a/messagebus/src/vespa/messagebus/testlib/testserver.cpp b/messagebus/src/vespa/messagebus/testlib/testserver.cpp
index fe5baab3e40..8e7f138b886 100644
--- a/messagebus/src/vespa/messagebus/testlib/testserver.cpp
+++ b/messagebus/src/vespa/messagebus/testlib/testserver.cpp
@@ -60,7 +60,7 @@ TestServer::waitState(const SlobrokState &slobrokState)
if (done) {
return true;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
return false;
}
diff --git a/messagebus_test/src/tests/error/cpp-client.cpp b/messagebus_test/src/tests/error/cpp-client.cpp
index 833d941da32..abc7967bfe5 100644
--- a/messagebus_test/src/tests/error/cpp-client.cpp
+++ b/messagebus_test/src/tests/error/cpp-client.cpp
@@ -7,6 +7,8 @@
#include <vespa/messagebus/rpcmessagebus.h>
#include <vespa/messagebus/network/rpcnetworkparams.h>
#include <vespa/messagebus/testlib/receptor.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/fastos/app.h>
using namespace mbus;
@@ -45,7 +47,7 @@ App::Main()
break;
}
}
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
if (reply.get() == 0) {
fprintf(stderr, "CPP-CLIENT: no reply\n");
diff --git a/messagebus_test/src/tests/error/cpp-server.cpp b/messagebus_test/src/tests/error/cpp-server.cpp
index d5200ed20c1..383f703317e 100644
--- a/messagebus_test/src/tests/error/cpp-server.cpp
+++ b/messagebus_test/src/tests/error/cpp-server.cpp
@@ -6,6 +6,8 @@
#include <vespa/messagebus/network/rpcnetworkparams.h>
#include <vespa/messagebus/emptyreply.h>
#include <vespa/messagebus/errorcode.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/fastos/app.h>
using namespace mbus;
@@ -55,7 +57,7 @@ App::Main()
"file:routing.cfg");
Server server(mb.getMessageBus());
while (true) {
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
return 0;
}
diff --git a/messagebus_test/src/tests/speed/cpp-client.cpp b/messagebus_test/src/tests/speed/cpp-client.cpp
index 43d030b519b..ff00128037a 100644
--- a/messagebus_test/src/tests/speed/cpp-client.cpp
+++ b/messagebus_test/src/tests/speed/cpp-client.cpp
@@ -7,6 +7,8 @@
#include <vespa/messagebus/testlib/simplemessage.h>
#include <vespa/messagebus/testlib/simpleprotocol.h>
#include <vespa/messagebus/testlib/simplereply.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/fastos/timestamp.h>
#include <vespa/fastos/app.h>
@@ -100,7 +102,7 @@ App::Main()
Client client(mb.getMessageBus(), SourceSessionParams().setTimeout(30s));
// let the system 'warm up'
- FastOS_Thread::Sleep(5000);
+ std::this_thread::sleep_for(5s);
// inject messages into the feedback loop
for (uint32_t i = 0; i < 1024; ++i) {
@@ -108,7 +110,7 @@ App::Main()
}
// let the system 'warm up'
- FastOS_Thread::Sleep(5000);
+ std::this_thread::sleep_for(5s);
fastos::StopWatch stopWatch;
uint32_t okBefore = 0;
@@ -117,7 +119,7 @@ App::Main()
uint32_t failAfter = 0;
client.sample(okBefore, failBefore);
- FastOS_Thread::Sleep(10000); // Benchmark time
+ std::this_thread::sleep_for(10s); // Benchmark time
fastos::TimeStamp elapsed = stopWatch.elapsed();
client.sample(okAfter, failAfter);
double time = elapsed.ms();
diff --git a/messagebus_test/src/tests/speed/cpp-server.cpp b/messagebus_test/src/tests/speed/cpp-server.cpp
index 82b884c46f2..a1aa5a5029c 100644
--- a/messagebus_test/src/tests/speed/cpp-server.cpp
+++ b/messagebus_test/src/tests/speed/cpp-server.cpp
@@ -6,6 +6,8 @@
#include <vespa/messagebus/testlib/simpleprotocol.h>
#include <vespa/messagebus/rpcmessagebus.h>
#include <vespa/messagebus/network/rpcnetworkparams.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/fastos/app.h>
using namespace mbus;
@@ -62,7 +64,7 @@ App::Main()
"file:routing.cfg");
Server server(mb.getMessageBus());
while (true) {
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
return 0;
}
diff --git a/messagebus_test/src/tests/trace/cpp-server.cpp b/messagebus_test/src/tests/trace/cpp-server.cpp
index d6db86070b1..75f4ee3a002 100644
--- a/messagebus_test/src/tests/trace/cpp-server.cpp
+++ b/messagebus_test/src/tests/trace/cpp-server.cpp
@@ -5,6 +5,8 @@
#include <vespa/messagebus/rpcmessagebus.h>
#include <vespa/messagebus/network/rpcnetworkparams.h>
#include <vespa/messagebus/emptyreply.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/fastos/app.h>
using namespace mbus;
@@ -73,7 +75,7 @@ App::Main()
"file:routing.cfg");
Server server(mb.getMessageBus(), _argv[1]);
while (true) {
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
return 0;
}
diff --git a/messagebus_test/src/tests/trace/trace.cpp b/messagebus_test/src/tests/trace/trace.cpp
index 0d4c622a0df..334f00745da 100644
--- a/messagebus_test/src/tests/trace/trace.cpp
+++ b/messagebus_test/src/tests/trace/trace.cpp
@@ -36,7 +36,7 @@ waitSlobrok(RPCMessageBus &mbus, const std::string &pattern)
if (res.size() > 0) {
return true;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
return false;
}
@@ -112,7 +112,7 @@ Test::Main()
}
}
std::cout << "Attempt " << i << " got errors, retrying in 1 second.." << std::endl;
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
EXPECT_TRUE(!reply->hasErrors());
diff --git a/metrics/src/tests/metricmanagertest.cpp b/metrics/src/tests/metricmanagertest.cpp
index 1d954a641b6..6407bb73ecb 100644
--- a/metrics/src/tests/metricmanagertest.cpp
+++ b/metrics/src/tests/metricmanagertest.cpp
@@ -10,8 +10,10 @@
#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/vespalib/stllike/asciistream.h>
#include <vespa/vespalib/util/xmlstream.h>
-#include <vespa/log/log.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
+#include <vespa/log/log.h>
LOG_SETUP(".test.metricmanager");
namespace metrics {
@@ -386,7 +388,7 @@ bool waitForTimeProcessed(const MetricManager& mm,
while (time(0) < lastchance) {
if (mm.getLastProcessedTime() >= processtime) return true;
mm.timeChangedNotification();
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
return false;
}
diff --git a/metrics/src/tests/stresstest.cpp b/metrics/src/tests/stresstest.cpp
index 4a6d2f4a2ea..f3e709b4b04 100644
--- a/metrics/src/tests/stresstest.cpp
+++ b/metrics/src/tests/stresstest.cpp
@@ -4,6 +4,8 @@
#include <vespa/metrics/metricmanager.h>
#include <vespa/metrics/metrics.h>
#include <vespa/metrics/summetric.hpp>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/log/log.h>
@@ -39,7 +41,7 @@ InnerMetricSet::InnerMetricSet(const char* name, const LoadTypeSet& lt, MetricSe
_valueSum.addMetricToSum(_value1);
_valueSum.addMetricToSum(_value2);
}
-InnerMetricSet::~InnerMetricSet() { }
+InnerMetricSet::~InnerMetricSet() = default;
MetricSet*
InnerMetricSet::clone(std::vector<Metric::UP> &ownerList, CopyType copyType,
@@ -133,11 +135,10 @@ TEST(StressTest, test_stress)
FastOS_ThreadPool threadPool(256 * 1024);
std::vector<Hammer::UP> hammers;
for (uint32_t i=0; i<10; ++i) {
- hammers.push_back(Hammer::UP(
- new Hammer(metrics, loadTypes, threadPool)));
+ hammers.push_back(std::make_unique<Hammer>(metrics, loadTypes, threadPool));
}
LOG(info, "Waiting to let loadgivers hammer a while");
- FastOS_Thread::Sleep(5 * 1000);
+ std::this_thread::sleep_for(5s);
LOG(info, "Removing loadgivers");
hammers.clear();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index 6c583d960bd..14aa3ebf84e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -70,7 +70,7 @@ public class IntermediateGraph {
return operations;
}
- void optimize() {
+ public void optimize() {
renameDimensions();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 280fe354149..55f5d979ea8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -3,11 +3,13 @@
package ai.vespa.rankingexpression.importer.onnx;
import ai.vespa.rankingexpression.importer.operations.Gemm;
+import ai.vespa.rankingexpression.importer.operations.ConcatReduce;
import ai.vespa.rankingexpression.importer.operations.OnnxConcat;
import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Softmax;
import ai.vespa.rankingexpression.importer.operations.Squeeze;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
@@ -21,6 +23,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Shape;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
@@ -36,24 +39,37 @@ import java.util.stream.Collectors;
*/
class GraphImporter {
+ private static final Value eluAlpha = DoubleValue.frozen(1.0);
+ private static final Value seluAlpha = DoubleValue.frozen(1.6732632423543772848170429916717);
+ private static final Value seluGamma = DoubleValue.frozen(1.0507009873554804934193349852946);
+ private static final Value leakyReluAlpha = DoubleValue.frozen(0.01);
+
private static IntermediateOperation mapOperation(Onnx.NodeProto node,
List<IntermediateOperation> inputs,
IntermediateGraph graph) {
+ String type = node.getOpType();
String modelName = graph.name();
String nodeName = getNodeName(node);
AttributeConverter attributes = AttributeConverter.convert(node);
+ return mapOperation(type, inputs, modelName, nodeName, attributes);
+ }
- switch (node.getOpType().toLowerCase()) {
+ static IntermediateOperation mapOperation(String opType,
+ List<IntermediateOperation> inputs,
+ String modelName,
+ String nodeName,
+ AttributeConverter attributes) {
+ switch (opType.toLowerCase()) {
case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
- case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes);
case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
- case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble()));
case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
@@ -63,23 +79,31 @@ class GraphImporter {
case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
case "matmul": return new MatMul(modelName, nodeName, inputs);
- case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
- case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
- case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean());
+ case "max": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.max);
+ case "min": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.min);
+ case "mean": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.avg);
case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
- case "reshape": return new Reshape(modelName, nodeName, inputs);
- case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum);
+ case "reshape": return new Reshape(modelName, nodeName, inputs, attributes);
+ case "reducel1": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.abs(), null);
+ case "reducel2": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), ScalarFunctions.sqrt());
+ case "reducelogsum":return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, null, ScalarFunctions.log());
+ case "reducelogsumexp": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.exp(), ScalarFunctions.log());
+ case "reducemax": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.max);
case "reducemean": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.avg);
+ case "reducemin": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.min);
+ case "reduceprod": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.prod);
+ case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum);
+ case "reducesumsquare": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), null);
case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
- case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu(attributes.get("gamma").orElse(seluGamma).asDouble(), attributes.get("alpha").orElse(seluAlpha).asDouble()));
+ case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu(attributes.get("alpha").orElse(leakyReluAlpha).asDouble()));
case "shape": return new Shape(modelName, nodeName, inputs);
case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
- case "softmax": return new Softmax(modelName, nodeName, inputs);
+ case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
@@ -90,7 +114,7 @@ class GraphImporter {
}
IntermediateOperation op = new NoOp(modelName, nodeName, inputs);
- op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
+ op.warning("Operation '" + opType + "' is currently not implemented");
return op;
}
@@ -260,5 +284,4 @@ class GraphImporter {
"Either no explicit name given or no single output name.");
}
-
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java
new file mode 100644
index 00000000000..ea6bb2eaf99
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java
@@ -0,0 +1,76 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+public class ConcatReduce extends IntermediateOperation {
+
+ private final static String tmpDimensionName = "__concat_reduce_tmp_dimension_name__";
+ private final Reduce.Aggregator aggregator;
+
+ public ConcatReduce(String modelName, String nodeName, List<IntermediateOperation> inputs, Reduce.Aggregator aggregator) {
+ super(modelName, nodeName, inputs);
+ this.aggregator = aggregator;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(inputs.size())) return null;
+ return inputs.get(0).type().get();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputFunctionsPresent(inputs.size())) return null;
+
+ TensorFunction result = inputs.get(0).function().get();
+ for (int i = 1; i < inputs.size(); ++i) {
+ TensorFunction b = inputs.get(i).function().get();
+ result = new com.yahoo.tensor.functions.Concat(result, b, tmpDimensionName);
+ }
+ return new com.yahoo.tensor.functions.Reduce(result, aggregator, tmpDimensionName);
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if ( ! allInputTypesPresent(inputs.size())) return;
+
+ OrderedTensorType a = inputs.get(0).type().get();
+ for (int i = 1; i < inputs.size(); ++i) {
+ OrderedTensorType b = inputs.get(i).type().get();
+
+ OrderedTensorType largest = largestInput(a, b);
+ OrderedTensorType smallest = smallestInput(a, b);
+
+ int sizeDifference = largest.rank() - smallest.rank();
+ for (int j = 0; j < smallest.rank(); ++j) {
+ String bDim = smallest.dimensions().get(j).name();
+ String aDim = largest.dimensions().get(j + sizeDifference).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
+ }
+ a = b;
+ }
+ }
+
+ private OrderedTensorType largestInput(OrderedTensorType a, OrderedTensorType b) {
+ return a.rank() >= b.rank() ? a : b;
+ }
+
+ private OrderedTensorType smallestInput(OrderedTensorType a, OrderedTensorType b) {
+ return a.rank() < b.rank() ? a : b;
+ }
+
+ @Override
+ public ConcatReduce withInputs(List<IntermediateOperation> inputs) {
+ return new ConcatReduce(modelName(), name(), inputs, aggregator);
+ }
+
+ @Override
+ public String operationName() { return "ConcatReduce"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index f091ae165d1..3fba8680332 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -92,7 +92,7 @@ public class Gemm extends IntermediateOperation {
return null;
}
- String joinDimension = aType.dimensions().get(1).name(); // TODO: check wrt transpose!
+ String joinDimension = aType.dimensions().get(1 - transposeA).name();
TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension);
TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index bd302afa5c7..efd6f9d3339 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -199,7 +199,9 @@ public abstract class IntermediateOperation {
String constantName = "constant(" + vespaName() + ")";
Value result = context.get(constantName);
if (result == DoubleValue.NaN) {
- if (inputs.size() == 0) {
+ if (constantValue != null) {
+ result = constantValue;
+ } else if (inputs.size() == 0) {
if (getConstantValue().isEmpty()) {
throw new IllegalArgumentException("Error in evaluating constant for " + name);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
index ded76db60fe..5785621eed3 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
@@ -28,6 +28,9 @@ public class OnnxConcat extends IntermediateOperation {
if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null;
OrderedTensorType aType = inputs.get(0).type().get();
+ if (concatDimensionIndex < 0) {
+ concatDimensionIndex = aType.dimensions().size() + concatDimensionIndex;
+ }
long concatDimSize = aType.dimensions().get(concatDimensionIndex).size().orElse(-1L);
for (int i = 1; i < inputs.size(); ++i) {
@@ -92,7 +95,7 @@ public class OnnxConcat extends IntermediateOperation {
public void renameDimensions(DimensionRenamer renamer) {
super.renameDimensions(renamer);
concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName);
- }
+ }
@Override
public OnnxConcat withInputs(List<IntermediateOperation> inputs) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
index 1b2d9ac090e..7af051484f5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
@@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+import java.util.function.DoubleUnaryOperator;
/**
* ONNX Reduce[Sum/Mean/etc] operation
@@ -24,6 +25,8 @@ public class Reduce extends IntermediateOperation {
private final AttributeMap attributeMap;
private final com.yahoo.tensor.functions.Reduce.Aggregator aggregator;
+ private final DoubleUnaryOperator preOperator;
+ private final DoubleUnaryOperator postOperator;
private List<String> reduceDimensions;
@@ -31,11 +34,23 @@ public class Reduce extends IntermediateOperation {
List<IntermediateOperation> inputs,
AttributeMap attributeMap,
com.yahoo.tensor.functions.Reduce.Aggregator aggregator) {
+ this(modelName, nodeName, inputs, attributeMap, aggregator, null, null);
+ }
+
+ public Reduce(String modelName, String nodeName,
+ List<IntermediateOperation> inputs,
+ AttributeMap attributeMap,
+ com.yahoo.tensor.functions.Reduce.Aggregator aggregator,
+ DoubleUnaryOperator preOperator,
+ DoubleUnaryOperator postOperator) {
super(modelName, nodeName, inputs);
this.attributeMap = attributeMap;
this.aggregator = aggregator;
+ this.preOperator = preOperator;
+ this.postOperator = postOperator;
}
+
@Override
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(1)) return null;
@@ -48,7 +63,7 @@ public class Reduce extends IntermediateOperation {
for (Value i : attributeMap.getList("axes").get()) {
int dimensionIndex = (int) i.asDouble();
if (dimensionIndex < 0) {
- dimensionIndex = inputType.dimensions().size() - dimensionIndex;
+ dimensionIndex = inputType.dimensions().size() + dimensionIndex;
}
reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
}
@@ -61,6 +76,9 @@ public class Reduce extends IntermediateOperation {
if ( ! allInputTypesPresent(1)) return null;
TensorFunction inputFunction = inputs.get(0).function().get();
+ if (preOperator != null) {
+ inputFunction = new com.yahoo.tensor.functions.Map(inputFunction, preOperator);
+ }
TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions);
if (shouldKeepDimensions()) {
// multiply with a generated tensor created from the reduced dimensions
@@ -74,6 +92,9 @@ public class Reduce extends IntermediateOperation {
new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
}
+ if (postOperator != null) {
+ output = new com.yahoo.tensor.functions.Map(output, postOperator);
+ }
return output;
}
@@ -93,7 +114,7 @@ public class Reduce extends IntermediateOperation {
@Override
public Reduce withInputs(List<IntermediateOperation> inputs) {
- return new Reduce(modelName(), name(), inputs, attributeMap, aggregator);
+ return new Reduce(modelName(), name(), inputs, attributeMap, aggregator, preOperator, postOperator);
}
@Override
@@ -101,7 +122,7 @@ public class Reduce extends IntermediateOperation {
private boolean shouldKeepDimensions() {
Optional<Value> keepDims = attributeMap.get("keepdims");
- return keepDims.isPresent() && keepDims.get().asBoolean();
+ return keepDims.isEmpty() || keepDims.get().asBoolean(); // default is 1
}
private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index c7accd00619..c88fc18e6c6 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -22,51 +23,97 @@ import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
-import java.util.Iterator;
import java.util.List;
+import java.util.Optional;
import java.util.stream.Collectors;
public class Reshape extends IntermediateOperation {
- public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ private final AttributeMap attributeMap;
+
+ public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
protected OrderedTensorType lazyGetType() {
- if ( ! allInputTypesPresent(2)) return null;
+ if (inputs.size() == 2) {
+ return typeWithShapeAsInput();
+ } else if (inputs.size() == 1) {
+ return typeWithShapeAsAttribute();
+ }
+ throw new IllegalArgumentException("Expected 2 or 3 inputs for '" + name + "', got " + inputs.size());
+ }
+ private OrderedTensorType typeWithShapeAsInput() {
IntermediateOperation newShape = inputs.get(1);
if (newShape.getConstantValue().isEmpty())
- throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant.");
+ throw new IllegalArgumentException("Reshape " + name + ": Shape input must be a constant.");
+ OrderedTensorType inputType = inputs.get(0).type().get();
Tensor shape = newShape.getConstantValue().get().asTensor();
+ List<Integer> dimSizes = new ArrayList<>(shape.type().rank());
+ shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue()));
+
+ // first pass - set 0 values, meaning that size is retained from input
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ if (dimSizes.get(i) == 0) {
+ if (i >= inputType.dimensions().size()) {
+ throw new IllegalArgumentException("Reshape " + name + ": 0 value for dimension not found in input");
+ }
+ dimSizes.set(i, inputType.dimensions().get(i).size().get().intValue());
+ }
+ }
+
+ // second pass - set any -1 value, meaning that the dimension size should be expanded to fill the tensor
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ if (dimSizes.get(i) < 0) {
+ int shapeSize = dimSizes.stream().reduce(1, (a, b) -> a * b);
+ int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue();
+ dimSizes.set(i, -1 * tensorSize / (shapeSize == 0 ? -1 : shapeSize));
+ }
+ }
+
+ return buildOutputType(dimSizes);
+ }
+
+ private OrderedTensorType typeWithShapeAsAttribute() {
+ if (attributeMap.getList("shape").isEmpty() || attributeMap.getList("shape").get().size() == 0)
+ throw new IllegalArgumentException("Reshape in " + name + ": Shape attribute is empty.");
OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType());
- int dimensionIndex = 0;
- for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
- Tensor.Cell cell = cellIterator.next();
- int size = cell.getValue().intValue();
+ List<Value> shape = attributeMap.getList("shape").get();
+ List<Integer> dimSizes = new ArrayList<>(shape.size());
+
+ for (Value v : shape) {
+ int size = (int) v.asDouble();
if (size < 0) {
- size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() /
- OrderedTensorType.tensorSize(inputType.type()).intValue();
+ int shapeSize = (int) shape.stream().mapToDouble(Value::asDouble).reduce(1, (a, b) -> a * b);
+ int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue();
+ size = -1 * shapeSize / tensorSize;
}
- outputTypeBuilder.add(TensorType.Dimension.indexed(
- String.format("%s_%d", vespaName(), dimensionIndex), size));
- dimensionIndex++;
+ dimSizes.add(size);
+ }
+ return buildOutputType(dimSizes);
+ }
+
+ private OrderedTensorType buildOutputType(List<Integer> dimSizes) {
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ outputTypeBuilder.add(TensorType.Dimension.indexed(String.format("%s_%d", vespaName(), i), dimSizes.get(i)));
}
return outputTypeBuilder.build();
}
@Override
protected TensorFunction lazyGetFunction() {
- if ( ! allInputTypesPresent(2)) return null;
- if ( ! allInputFunctionsPresent(2)) return null;
+ if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent) ) return null;
+ if ( ! inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent) ) return null;
OrderedTensorType inputType = inputs.get(0).type().get();
TensorFunction inputFunction = inputs.get(0).function().get();
- return reshape(inputFunction, inputType.type(), type.type());
+ return reshape(inputFunction, inputType, type);
}
@Override
@@ -76,11 +123,11 @@ public class Reshape extends IntermediateOperation {
@Override
public Reshape withInputs(List<IntermediateOperation> inputs) {
- return new Reshape(modelName(), name(), inputs);
+ return new Reshape(modelName(), name(), inputs, attributeMap);
}
- public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType)))
+ public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
+ if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type())))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
// Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
@@ -89,25 +136,27 @@ public class Reshape extends IntermediateOperation {
// the new shape. We have to introduce temporary dimension names and rename back if dimension names
// in the new and old tensor type overlap.
+ // Todo: change this to use tensor generate when available
+
List<String> from = new ArrayList<>();
List<String> to = new ArrayList<>();
boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType);
if (dimensionNamesOverlap) {
- TensorType.Builder builder = new TensorType.Builder(outputType.valueType());
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(outputType.type().valueType());
for (int i = 0; i < outputType.rank(); ++i) {
TensorType.Dimension dim = outputType.dimensions().get(i);
from.add(dim.name());
to.add("temp_" + dim.name());
- builder.dimension(dim.withName("temp_" + dim.name()));
+ builder.add(dim.withName("temp_" + dim.name()));
}
outputType = builder.build();
}
ExpressionNode unrollFrom = unrollTensorExpression(inputType);
ExpressionNode unrollTo = unrollTensorExpression(outputType);
- ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, new EmbracedNode(unrollTo));
+ ExpressionNode transformExpression = new ComparisonNode(new EmbracedNode(unrollFrom), TruthOperator.EQUAL, new EmbracedNode(unrollTo));
- TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
+ TensorType transformationType = new TensorType.Builder(inputType.type(), outputType.type()).build();
Generate transformTensor = new Generate(transformationType,
new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
@@ -121,11 +170,11 @@ public class Reshape extends IntermediateOperation {
return result;
}
- private static boolean dimensionNamesOverlap(TensorType a, TensorType b) {
- return a.dimensionNames().stream().anyMatch(d -> b.dimension(d).isPresent());
+ private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) {
+ return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent());
}
- private static ExpressionNode unrollTensorExpression(TensorType type) {
+ private static ExpressionNode unrollTensorExpression(OrderedTensorType type) {
if (type.rank() == 0)
return new ConstantNode(DoubleValue.zero);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index 032ffb88a46..83086926316 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -2,8 +2,13 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Map;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
+import java.util.ArrayList;
import java.util.List;
/**
@@ -13,8 +18,11 @@ import java.util.List;
*/
public class Softmax extends IntermediateOperation {
- public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ private final AttributeMap attributeMap;
+
+ public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
@@ -28,18 +36,30 @@ public class Softmax extends IntermediateOperation {
if ( ! allInputFunctionsPresent(1)) return null;
OrderedTensorType inputType = inputs.get(0).type().get();
- String dimension = inputType.dimensions().get(0).name();
- if (inputType.rank() == 2) {
- dimension = inputType.dimensions().get(1).name(); // assumption: first dimension is batch dimension
+
+ int axis = inputType.rank() == 1 ? 0 : 1; // assumption: first dimension is batch dimension
+ if (attributeMap.get("axis").isPresent()) {
+ axis = (int)attributeMap.get("axis").get().asDouble();
+ }
+ if (axis < 0) {
+ axis = inputType.rank() + axis;
}
+ List<String> reduceDimensions = new ArrayList<>();
+ for (int i = axis; i < inputType.rank(); ++i) {
+ reduceDimensions.add(inputType.dimensions().get(i).name()); // Do softmax over all dimensions except batch dimension
+ }
+
+ TensorFunction input = inputs.get(0).function().get();
+ TensorFunction exp = new Map(input, ScalarFunctions.exp());
+ TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions);
+ TensorFunction div = new Join(exp, sum, ScalarFunctions.divide());
- TensorFunction inputFunction = inputs.get(0).function().get();
- return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension);
+ return div;
}
@Override
public Softmax withInputs(List<IntermediateOperation> inputs) {
- return new Softmax(modelName(), name(), inputs);
+ return new Softmax(modelName(), name(), inputs, attributeMap);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index 4f656d86929..0d2ba0cc714 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -64,7 +64,7 @@ class GraphImporter {
case "identity": return new Identity(modelName, nodeName, inputs);
case "placeholder": return new Argument(modelName, nodeName, nodeType);
case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
- case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "reshape": return new Reshape(modelName, nodeName, inputs, attributes);
case "shape": return new Shape(modelName, nodeName, inputs);
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
@@ -113,7 +113,7 @@ class GraphImporter {
case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
- case "softmax": return new Softmax(modelName, nodeName, inputs);
+ case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
// state ops
case "variable": return new Constant(modelName, nodeName, nodeType);
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
new file mode 100644
index 00000000000..6954abe5157
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -0,0 +1,460 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.operations.Constant;
+import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.ConstantTensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static ai.vespa.rankingexpression.importer.onnx.GraphImporter.*;
+import static onnx.Onnx.AttributeProto.AttributeType.FLOAT;
+import static onnx.Onnx.AttributeProto.AttributeType.INT;
+import static onnx.Onnx.AttributeProto.AttributeType.INTS;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Unit tests for ONNX operators. The number on the test reflects the minimum
+ * opset number for the operations tested.
+ *
+ * @author lesters
+ */
+public class OnnxOperationsTestCase {
+
+ private static final String modelName = "test_model";
+
+ @Test
+ public void testElementwiseOperators7() throws ParseException {
+ Tensor x = evaluate("tensor(d0[7]):[-1.0, -0.5, -0.1, 0.0, 0.1, 0.5, 1.0]");
+ assertEval("acos", x, evaluate("acos(x)", x));
+ assertEval("asin", x, evaluate("asin(x)", x));
+ assertEval("atan", x, evaluate("atan(x)", x));
+ assertEval("cos", x, evaluate("cos(x)", x));
+ assertEval("sin", x, evaluate("sin(x)", x));
+ assertEval("tan", x, evaluate("tan(x)", x));
+ assertEval("tanh", x, evaluate("tanh(x)", x));
+ assertEval("neg", x, evaluate("-x", x));
+ assertEval("sigmoid", x, evaluate("sigmoid(x)", x));
+ assertEval("exp", x, evaluate("exp(x)", x));
+ assertEval("floor", x, evaluate("floor(x)", x));
+ assertEval("ceil", x, evaluate("ceil(x)", x));
+ assertEval("abs", x, evaluate("abs(x)", x));
+
+ assertEval("relu", x, evaluate("max(0, x)", x));
+ assertEval("elu", x, evaluate("map(x, f(a)(if(a < 0, 1.0 * (exp(a)-1), a)))", x));
+ assertEval("elu", x, evaluate("map(x, f(a)(if(a < 0, 0.5 * (exp(a)-1), a)))", x), createAttribute("alpha", 0.5f));
+ assertEval("selu", x, evaluate("map(x, f(a)(1.050700987 * if(a >= 0, a, 1.673263242 * (exp(a) - 1))))", x));
+ assertEval("selu", x, evaluate("map(x, f(a)(1.0 * if(a >= 0, a, 1.5 * (exp(a) - 1))))", x), createAttributes().attr("gamma", 1.0f).attr("alpha", 1.5f).build());
+ assertEval("leakyrelu", x, evaluate("max(0.01 * x, x)", x));
+ assertEval("leakyrelu", x, evaluate("max(0.001 * x, x)", x), createAttribute("alpha", 0.001f));
+
+ x = evaluate("tensor(d0[3]):[0.01, 1.0, 10.0]");
+ assertEval("log", x, evaluate("log(x)", x));
+ assertEval("sqrt", x, evaluate("sqrt(x)", x));
+ assertEval("reciprocal", x, evaluate("map(x, f(a)(1.0 / a))", x));
+ }
+
+ @Test
+ public void testJoinOperators7() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2]):[3, 4]");
+ Tensor y = evaluate("tensor(d0[2]):[1, 2]");
+ assertEval("add", x, y, evaluate("tensor(d0[2]):[4, 6]"));
+ assertEval("sub", x, y, evaluate("tensor(d0[2]):[2, 2]"));
+ assertEval("mul", x, y, evaluate("tensor(d0[2]):[3, 8]"));
+ assertEval("div", x, y, evaluate("tensor(d0[2]):[3, 2]"));
+ assertEval("greater", x, y, evaluate("tensor(d0[2]):[1, 1]"));
+ assertEval("less", x, y, evaluate("tensor(d0[2]):[0, 0]"));
+ assertEval("equal", x, y, evaluate("tensor(d0[2]):[0, 0]"));
+ assertEval("pow", x, y, evaluate("tensor(d0[2]):[3, 16]"));
+
+ x = evaluate("random(d0[2],d1[3],d2[4]) + 1");
+ y = evaluate("random(d0[2],d1[3],d2[4]) + 1");
+ assertEval("add", x, y, evaluate("x + y", x, y));
+ assertEval("sub", x, y, evaluate("x - y", x, y));
+ assertEval("mul", x, y, evaluate("x * y", x, y));
+ assertEval("div", x, y, evaluate("x / y", x, y));
+ assertEval("greater", x, y, evaluate("join(x, y, f(a,b)(a > b))", x, y));
+ assertEval("less", x, y, evaluate("join(x, y, f(a,b)(a < b))", x, y));
+ assertEval("equal", x, y, evaluate("join(x, y, f(a,b)(a == b))", x, y));
+ assertEval("pow", x, y, evaluate("join(x, y, f(a,b)(pow(a,b)))", x, y));
+
+ // broadcasting
+ x = evaluate("random(d0[2],d1[3],d2[4]) + 1");
+ y = evaluate("random(d0[4]) + 1");
+ assertEval("add", x, y, evaluate("x + rename(y, d0, d2)", x, y));
+ assertEval("sub", x, y, evaluate("x - rename(y, d0, d2)", x, y));
+ assertEval("mul", x, y, evaluate("x * rename(y, d0, d2)", x, y));
+ assertEval("div", x, y, evaluate("x / rename(y, d0, d2)", x, y));
+ assertEval("greater", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a > b))", x, y));
+ assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y));
+ assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y));
+ assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y));
+ }
+
+ @Test
+ public void testConcatReduce8() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2]):[3, 4]");
+ Tensor y = evaluate("tensor(d0[2]):[1, 2]");
+ Tensor z = evaluate("tensor(d0[2]):[5, 6]");
+ assertEval("max", x, y, z, evaluate("tensor(d0[2]):[5, 6]"));
+ assertEval("min", x, y, z, evaluate("tensor(d0[2]):[1, 2]"));
+ assertEval("mean", x, y, z, evaluate("tensor(d0[2]):[3, 4]"));
+
+ x = evaluate("random(d0[2],d1[3],d2[4])");
+ y = evaluate("random(d0[2],d1[3],d2[4])");
+ z = evaluate("random(d0[2],d1[3],d2[4])");
+ assertEval("max", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), max, tmp)", x, y, z));
+ assertEval("min", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), min, tmp)", x, y, z));
+ assertEval("mean", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), avg, tmp)", x, y, z));
+
+ // broadcasting
+ x = evaluate("random(d0[2],d1[3],d2[4])");
+ y = evaluate("random(d0[3],d1[4])");
+ z = evaluate("random(d0[4])");
+ assertEval("max", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), max, tmp)", x, y, z));
+ assertEval("min", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), min, tmp)", x, y, z));
+ assertEval("mean", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), avg, tmp)", x, y, z));
+ }
+
+ @Test
+ public void testConcat4() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2]):[1, 2]");
+ Tensor y = evaluate("tensor(d0[2]):[3, 4]");
+ Tensor expected = evaluate("tensor(d0[4]):[1,2,3,4]");
+ assertEval("concat", x, y, expected, createAttribute("axis", 0));
+ assertEval("concat", x, y, expected, createAttribute("axis", -1));
+
+ x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]");
+ y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]");
+ assertEval("concat", x, y, evaluate("tensor(d0[4],d1[2]):[1,2,3,4,5,6,7,8]"), createAttribute("axis", 0));
+ assertEval("concat", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,5,6,3,4,7,8]"), createAttribute("axis", 1));
+ assertEval("concat", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,5,6,3,4,7,8]"), createAttribute("axis", -1));
+ assertEval("concat", x, y, evaluate("tensor(d0[4],d1[2]):[1,2,3,4,5,6,7,8]"), createAttribute("axis", -2));
+
+ x = evaluate("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 5, 6, 7, 8]");
+ y = evaluate("tensor(d0[2],d1[2],d2[2]):[9,10,11,12,13,14,15,16]");
+ assertEval("concat", x, y, evaluate("concat(x, y, d0)", x, y), createAttribute("axis", 0));
+ assertEval("concat", x, y, evaluate("concat(x, y, d1)", x, y), createAttribute("axis", 1));
+ assertEval("concat", x, y, evaluate("concat(x, y, d2)", x, y), createAttribute("axis", 2));
+ assertEval("concat", x, y, evaluate("concat(x, y, d2)", x, y), createAttribute("axis", -1));
+ assertEval("concat", x, y, evaluate("concat(x, y, d1)", x, y), createAttribute("axis", -2));
+ assertEval("concat", x, y, evaluate("concat(x, y, d0)", x, y), createAttribute("axis", -3));
+ }
+
+ @Test
+ public void testGemm7() throws ParseException {
+ Tensor a = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]");
+ Tensor b = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]");
+ Tensor c = evaluate("tensor(d0[2],d1[2]):[0.1, 0.2, 0.3, 0.4]");
+
+ assertEval("gemm", a, b, evaluate("tensor(d0[2],d1[2]):[19, 22, 43, 50]"));
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.1, 22.2, 43.3, 50.4]"));
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[38.1, 44.2, 86.3, 100.4]"), createAttribute("alpha", 2.0f));
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.2, 22.4, 43.6, 50.8]"), createAttribute("beta", 2.0f));
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[26.1, 30.2, 38.3, 44.4]"), createAttribute("transA", 1));
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[17.1, 23.2, 39.3, 53.4]"), createAttribute("transB", 1));
+
+ // unidictional broadcasting for c
+ c = evaluate("tensor(d0[2]):[0.1, 0.2]");
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.1, 22.2, 43.1, 50.2]"));
+ }
+
+ @Test
+ public void testIdentity1() throws ParseException {
+ Tensor x = evaluate("random(d0[2],d1[3],d2[4])");
+ assertEval("identity", x, x);
+ }
+
+ @Test
+ public void testMatMul1() throws ParseException {
+ Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]");
+ Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]"));
+ }
+
+ @Test
+ public void testReshape5() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]");
+ Tensor y = evaluate("tensor(d0[1]):[4]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[4]):[1,2,3,4]"));
+
+ y = evaluate("tensor(d0[2]):[2,2]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"));
+
+ y = evaluate("tensor(d0[3]):[2,1,2]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[1],d2[2]):[1,2,3,4]"));
+
+ y = evaluate("tensor(d0[2]):[2,-1]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"));
+
+ y = evaluate("tensor(d0[2]):[2,0]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"));
+
+ y = evaluate("tensor(d0[2]):[0,-1]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"));
+
+ x = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+ y = evaluate("tensor(d0[2]):[3,2]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"));
+
+ y = evaluate("tensor(d0[4]):[3,2,-1,1]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]"));
+ }
+
+ @Test
+ public void testReduceOperators1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]");
+
+ assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"));
+ assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"), createAttribute("axes", new int[] {0,1}));
+ assertEval("reducesum", x, evaluate("tensor():[10]"), createAttribute("keepdims", 0));
+ assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"), createAttribute("keepdims", 1));
+ assertEval("reducesum", x, evaluate("tensor(d0[1],d1[2]):[4, 6]"), createAttribute("axes", new int[]{0}));
+ assertEval("reducesum", x, evaluate("tensor(d0[2]):[4, 6]"), createAttributes().attr("axes", new int[]{0}).attr("keepdims", 0).build());
+ assertEval("reducesum", x, evaluate("tensor(d0[2],d1[1]):[3, 7]"), createAttribute("axes", new int[] {1}));
+ assertEval("reducesum", x, evaluate("tensor(d0[2]):[3, 7]"), createAttributes().attr("axes", new int[]{1}).attr("keepdims", 0).build());
+ assertEval("reducesum", x, evaluate("tensor(d0[1],d1[2]):[4, 6]"), createAttribute("axes", new int[] {-2}));
+ assertEval("reducesum", x, evaluate("tensor(d0[2],d1[1]):[3, 7]"), createAttribute("axes", new int[] {-1}));
+ assertEval("reducesum", x, evaluate("tensor(d0[2]):[3, 7]"), createAttributes().attr("axes", new int[] {-1}).attr("keepdims", 0).build());
+
+ assertEval("reduceprod", x, evaluate("tensor(d0[1],d1[1]):[24]"));
+ assertEval("reduceprod", x, evaluate("tensor(d0[1],d1[2]):[3, 8]"), createAttribute("axes", new int[] {0}));
+
+ assertEval("reducemin", x, evaluate("tensor(d0[1],d1[1]):[1]"));
+ assertEval("reducemin", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0}));
+
+ assertEval("reducemax", x, evaluate("tensor(d0[1],d1[1]):[4]"));
+ assertEval("reducemax", x, evaluate("tensor(d0[1],d1[2]):[3, 4]"), createAttribute("axes", new int[] {0}));
+
+ assertEval("reducemean", x, evaluate("tensor():[2.5]"), createAttribute("keepdims", 0));
+ assertEval("reducemean", x, evaluate("tensor(d0[2]):[2, 3]"), createAttributes().attr("axes", new int[] {0}).attr("keepdims", 0).build());
+
+ assertEval("reducelogsum", x, evaluate("tensor():[log(10)]"), createAttribute("keepdims", 0));
+ assertEval("reducelogsumexp", x, evaluate("tensor():[log(exp(1)+exp(2)+exp(3)+exp(4))]"), createAttribute("keepdims", 0));
+ assertEval("reducesumsquare", x, evaluate("tensor():[1*1+2*2+3*3+4*4]"), createAttribute("keepdims", 0));
+
+ x = evaluate("tensor(d0[1],d1[5]):[-10, -5, 0, 5, 10]");
+ assertEval("reducel1", x, evaluate("tensor():[30]"), createAttribute("keepdims", 0));
+ assertEval("reducel2", x, evaluate("tensor():[sqrt(10*10 + 5*5 + 5*5 + 10*10)]"), createAttribute("keepdims", 0));
+ }
+
+ @Test
+ public void testShape1() throws ParseException {
+ Tensor x = evaluate("random(d0[2],d1[3],d2[4])");
+ assertEval("shape", x, evaluate("tensor(d0[3]):[2,3,4]"));
+ }
+
+ @Test
+ public void testSoftmax1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[1],d1[3]):[-1, 0, 1]");
+ assertEval("softmax", x, evaluate("tensor(d0[1],d1[3]):[0.09003058, 0.24472848, 0.66524094]"));
+
+ x = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 7]");
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1)", x), createAttribute("axis", 0));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x), createAttribute("axis", 1)); // 1 is default
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x), createAttribute("axis", -1));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1)", x), createAttribute("axis", -2));
+
+ x = evaluate("random(d0[2],d1[3],d2[4])");
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1, d2)", x), createAttribute("axis", 0));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x), createAttribute("axis", 1));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d2)", x), createAttribute("axis", 2));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d2)", x), createAttribute("axis", -1));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x), createAttribute("axis", -2));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1, d2)", x), createAttribute("axis", -3));
+ }
+
+ @Test
+ public void testSqueeze1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[1],d1[2]):[1, 2]");
+ assertEval("squeeze", x, evaluate("tensor(d0[2]):[1, 2]"));
+
+ x = evaluate("tensor(d0[1],d1[2],d2[1],d3[3]):[1,2,3,4,5,6]");
+ assertEval("squeeze", x, evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"));
+ assertEval("squeeze", x, evaluate("tensor(d0[2],d1[1],d2[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0}));
+ assertEval("squeeze", x, evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {2}));
+ assertEval("squeeze", x, evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0, 2}));
+ }
+
+ @Test
+ public void testWhere9() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]");
+ Tensor y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]");
+ Tensor condition = evaluate("tensor(d0[2],d1[2]):[0, 1, 0, 1]");
+ assertEval("where", condition, x, y, evaluate("tensor(d0[2],d1[2]):[5, 2, 7, 4]"));
+
+ assertEval("where", evaluate("tensor():[0]"), x, y, y);
+ assertEval("where", evaluate("tensor():[1]"), x, y, x);
+ assertEval("where", evaluate("tensor(d0[1]):[0]"), x, y, y);
+ assertEval("where", evaluate("tensor(d0[1]):[1]"), x, y, x);
+ assertEval("where", evaluate("tensor(d0[1],d1[1]):[0]"), x, y, y);
+ assertEval("where", evaluate("tensor(d0[1],d1[1]):[1]"), x, y, x);
+ }
+
+ private Tensor evaluate(String expr) throws ParseException {
+ return evaluate(expr, null, null, null);
+ }
+
+ private Tensor evaluate(String expr, Tensor x) throws ParseException {
+ return evaluate(expr, x, null, null);
+ }
+
+ private Tensor evaluate(String expr, Tensor x, Tensor y) throws ParseException {
+ return evaluate(expr, x, y, null);
+ }
+
+ private Tensor evaluate(String expr, Tensor x, Tensor y, Tensor z) throws ParseException {
+ Context context = new MapContext(DoubleValue.NaN);
+ if (x != null) context.put("x", new TensorValue(x));
+ if (y != null) context.put("y", new TensorValue(y));
+ if (z != null) context.put("z", new TensorValue(z));
+ return new RankingExpression(expr).evaluate(context).asTensor();
+ }
+
+ private Tensor evaluate(IntermediateOperation op) {
+ Tensor tensor = op.evaluateAsConstant(op.type().get()).asTensor();
+ return renameToStandardType(op, tensor);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor expected) {
+ assertEval(opName, x, null, null, expected, null);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) {
+ assertEval(opName, x, null, null, expected, attr);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) {
+ assertEval(opName, x, y, null, expected, attr);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) {
+ assertEval(opName, x, y, null, expected, null);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) {
+ assertEval(opName, x, y, z, expected, null);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) {
+ Context context = new MapContext(DoubleValue.NaN);
+ List<IntermediateOperation> inputs = createInputs(context, x, y, z);
+ IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build());
+ optimizeAndRename(opName, op);
+ Tensor result = evaluate(op);
+ assertEquals(expected, result);
+ assertEquals(expected.type(), result.type());
+ }
+
+ private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z) {
+ List<IntermediateOperation> inputs = new ArrayList<>();
+ addInput(inputs, context, x, "x");
+ addInput(inputs, context, y, "y");
+ addInput(inputs, context, z, "z");
+ return inputs;
+ }
+
+ private void addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) {
+ if (x == null) return;
+ context.put(name, new TensorValue(x));
+ IntermediateOperation op = new Constant(modelName, name, OrderedTensorType.fromSpec(x.type().toString()));
+ op.setConstantValueFunction(type -> new TensorValue(convertTypeAfterRename(x, type)));
+ inputs.add(op);
+ }
+
+ Tensor convertTypeAfterRename(Tensor tensor, OrderedTensorType type) {
+ IndexedTensor indexedTensor = (IndexedTensor) tensor;
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
+ for (int i = 0; i < indexedTensor.size(); i++) {
+ builder.cellByDirectIndex(type.toDirectIndex(i), indexedTensor.get(i));
+ }
+ return builder.build();
+ }
+
+ private TensorFunction optimizeAndRename(String opName, IntermediateOperation op) {
+ IntermediateGraph graph = new IntermediateGraph(modelName);
+ graph.put(opName, op);
+ graph.outputs(graph.defaultSignature()).put(opName, opName);
+ graph.optimize();
+ return op.function().get();
+ }
+
+ private Tensor renameToStandardType(IntermediateOperation op, Tensor tensor) {
+ OrderedTensorType operationType = op.type().get();
+ OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
+ if ( ! operationType.equals(standardNamingType)) {
+ List<String> renameFrom = operationType.dimensionNames();
+ List<String> renameTo = standardNamingType.dimensionNames();
+ TensorFunction func = new Rename(new ConstantTensor(tensor), renameFrom, renameTo);
+ return func.evaluate();
+ }
+ return tensor;
+ }
+
+ static AttributeConverter createAttribute(String name, int val) {
+ return new Attributes().attr(name, val).build();
+ }
+
+ static AttributeConverter createAttribute(String name, float val) {
+ return new Attributes().attr(name, val).build();
+ }
+
+ static AttributeConverter createAttribute(String name, int [] vals) {
+ return new Attributes().attr(name, vals).build();
+ }
+
+ static Attributes createAttributes() {
+ return new Attributes();
+ }
+
+ private static class Attributes {
+
+ Onnx.NodeProto.Builder nodeBuilder;
+
+ Attributes() {
+ this.nodeBuilder = Onnx.NodeProto.newBuilder();
+ }
+
+ Attributes attr(String name, int val) {
+ nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(INT).setI(val).build());
+ return this;
+ }
+
+ Attributes attr(String name, float val) {
+ nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(FLOAT).setF(val).build());
+ return this;
+ }
+
+ Attributes attr(String name, int [] vals) {
+ Onnx.AttributeProto.Builder builder = Onnx.AttributeProto.newBuilder();
+ for (int val : vals) {
+ builder.addInts(val);
+ }
+ nodeBuilder.addAttribute(builder.setName(name).setType(INTS).build());
+ return this;
+ }
+
+ AttributeConverter build() {
+ return AttributeConverter.convert(nodeBuilder.build());
+ }
+
+ }
+
+}
diff --git a/searchcore/src/apps/proton/proton.cpp b/searchcore/src/apps/proton/proton.cpp
index b37eb5ac0cf..f80558a1bc6 100644
--- a/searchcore/src/apps/proton/proton.cpp
+++ b/searchcore/src/apps/proton/proton.cpp
@@ -8,6 +8,7 @@
#include <vespa/metrics/metricmanager.h>
#include <vespa/vespalib/util/signalhandler.h>
#include <vespa/vespalib/util/programoptions.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/vespalib/io/fileutil.h>
#include <vespa/config/common/exceptions.h>
#include <vespa/fastos/app.h>
@@ -198,7 +199,7 @@ App::Main()
LOG(info, "Sleeping 900 seconds due to proton state");
int sleepLeft = 900;
while (!(SIG::INT.check() || SIG::TERM.check()) && sleepLeft > 0) {
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1000ms);
--sleepLeft;
}
EV_STOPPING("proton", "shutdown after stop on io errors");
@@ -226,7 +227,7 @@ App::Main()
}
EV_STARTED("proton");
while (!(SIG::INT.check() || SIG::TERM.check() || (spiProton && spiProton->getNode().attemptedStopped()))) {
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1000ms);
if (spiProton && spiProton->configUpdated()) {
storage::ResumeGuard guard(spiProton->getNode().pause());
spiProton->updateConfig();
@@ -240,7 +241,7 @@ App::Main()
if (spiProton) {
// report down state to cluster controller.
spiProton->getNode().notifyPartitionDown(0, "proton state string is " + stateString);
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1000ms);
}
EV_STOPPING("proton", "shutdown after new stop on io errors");
return 1;
diff --git a/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp b/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp
index 3b3b5f412d2..dcd3dce218b 100644
--- a/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp
+++ b/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp
@@ -6,8 +6,10 @@
#include <vespa/fnet/frt/frt.h>
#include <vespa/vespalib/util/host_name.h>
#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/fastos/app.h>
#include <sys/time.h>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP("vespa-proton-cmd");
@@ -115,7 +117,7 @@ public:
slobrok::api::MirrorAPI sbmirror(_frt->supervisor(), sbcfg);
for (int timeout = 1; timeout < 20; timeout++) {
if (!sbmirror.ready()) {
- FastOS_Thread::Sleep(50*timeout);
+ std::this_thread::sleep_for(50ms*timeout);
}
}
if (!sbmirror.ready()) {
@@ -123,12 +125,9 @@ public:
"ERROR: no data from service location broker\n");
exit(1);
}
- slobrok::api::MirrorAPI::SpecList specs =
- sbmirror.lookup(rtcPattern);
- slobrok::api::MirrorAPI::SpecList specs2 =
- sbmirror.lookup(rtcPattern2);
- slobrok::api::MirrorAPI::SpecList specs3 =
- sbmirror.lookup(rtcPattern3);
+ slobrok::api::MirrorAPI::SpecList specs = sbmirror.lookup(rtcPattern);
+ slobrok::api::MirrorAPI::SpecList specs2 = sbmirror.lookup(rtcPattern2);
+ slobrok::api::MirrorAPI::SpecList specs3 = sbmirror.lookup(rtcPattern3);
int found = 0;
std::string service;
@@ -167,7 +166,7 @@ public:
slobrok::api::MirrorAPI sbmirror(_frt->supervisor(), sbcfg);
for (int timeout = 1; timeout < 20; timeout++) {
if (!sbmirror.ready()) {
- FastOS_Thread::Sleep(50*timeout);
+ std::this_thread::sleep_for(50ms*timeout);
}
}
if (!sbmirror.ready()) {
diff --git a/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp b/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp
index f5aa74d85e1..5775c31b205 100644
--- a/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp
+++ b/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp
@@ -7,6 +7,7 @@
#include <vespa/searchlib/transactionlog/translogserver.h>
#include <vespa/vespalib/util/programoptions.h>
#include <vespa/vespalib/util/xmlstream.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/document/config/config-documenttypes.h>
#include <vespa/document/repo/documenttyperepo.h>
#include <vespa/document/fieldvalue/document.h>
@@ -14,6 +15,7 @@
#include <vespa/config/helper/configgetter.hpp>
#include <vespa/fastos/app.h>
#include <iostream>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP("vespa-transactionlog-inspect");
@@ -491,7 +493,7 @@ protected:
return 1;
}
for (size_t i = 0; !callback.isEof() && (i < 60 * 60); i++ ) {
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
return 0;
}
diff --git a/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp b/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp
index fdd53d629ad..dfac2edad61 100644
--- a/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp
+++ b/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp
@@ -1064,7 +1064,7 @@ TEST_F("require that document pruner is active",
MyFrozenBucket::UP frozen3(new MyFrozenBucket(f._mc, bucketId3));
f.setPruneConfig(DocumentDBPruneRemovedDocumentsConfig(0.2, 900.0));
for (uint32_t i = 0; i < 6; ++i) {
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
ASSERT_TRUE(f._executor.waitIdle(TIMEOUT_SEC));
if (f._removed.getNumUsedLids() != 10u)
break;
@@ -1073,7 +1073,7 @@ TEST_F("require that document pruner is active",
EXPECT_EQUAL(10u, f._removed.getDocumentCount());
frozen3.reset();
for (uint32_t i = 0; i < 600; ++i) {
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
ASSERT_TRUE(f._executor.waitIdle(TIMEOUT_SEC));
if (f._removed.getNumUsedLids() != 10u)
break;
@@ -1089,7 +1089,7 @@ TEST_F("require that heartbeats are scheduled",
f.startMaintenance();
f.setHeartBeatConfig(DocumentDBHeartBeatConfig(0.2));
for (uint32_t i = 0; i < 600; ++i) {
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
if (f._fh.getHeartBeats() != 0u)
break;
}
@@ -1104,7 +1104,7 @@ TEST_F("require that periodic session prunings are scheduled",
f.startMaintenance();
f.setGroupingSessionPruneInterval(0.2);
for (uint32_t i = 0; i < 600; ++i) {
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
if (f._gsp.isInvoked) {
break;
}
@@ -1233,7 +1233,7 @@ TEST_F("require that a blocked job is unblocked and executed after thaw bucket",
EXPECT_FALSE(myJob2.isBlocked());
bool done1 = myJob1._latch.await(TIMEOUT_MS);
EXPECT_TRUE(done1);
- FastOS_Thread::Sleep(2000);
+ std::this_thread::sleep_for(2s);
EXPECT_EQUAL(0u, myJob2._runCnt);
}
@@ -1245,7 +1245,7 @@ TEST_F("require that blocked jobs are not executed", MaintenanceControllerFixtur
f._mc.registerJobInMasterThread(std::move(job));
f._injectDefaultJobs = false;
f.startMaintenance();
- FastOS_Thread::Sleep(2000);
+ std::this_thread::sleep_for(2s);
EXPECT_EQUAL(0u, myJob._runCnt);
}
diff --git a/searchcore/src/tests/proton/flushengine/flushengine_test.cpp b/searchcore/src/tests/proton/flushengine/flushengine_test.cpp
index bfd4450b1f2..c18d393b98f 100644
--- a/searchcore/src/tests/proton/flushengine/flushengine_test.cpp
+++ b/searchcore/src/tests/proton/flushengine/flushengine_test.cpp
@@ -685,7 +685,7 @@ assertThatHandlersInCurrentSet(FlushEngine & engine, const std::vector<const cha
{
FlushEngine::FlushMetaSet current1 = engine.getCurrentlyFlushingSet();
while ((current1.size() < targets.size()) || !asserCorrectHandlers(current1, targets)) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
current1 = engine.getCurrentlyFlushingSet();
}
}
diff --git a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp
index d86a750794f..870be2ab409 100644
--- a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp
+++ b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp
@@ -175,7 +175,7 @@ struct ProtonConfigOwner : public proton::IProtonConfigurer
while (timer.elapsed().ms() < timeout) {
if (getConfigured())
return true;
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
}
return getConfigured();
}
diff --git a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp
index 37f1664841a..8a4cb1682a6 100644
--- a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp
@@ -1,10 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "disk_mem_usage_sampler.h"
-#include <vespa/vespalib/util/timer.h>
+#include <vespa/vespalib/util/scheduledexecutor.h>
#include <vespa/vespalib/util/lambdatask.h>
#include <filesystem>
-#include <unistd.h>
using vespalib::makeLambdaTask;
@@ -32,7 +31,7 @@ DiskMemUsageSampler::setConfig(const Config &config)
_filter.setConfig(config.filterConfig);
_sampleInterval = config.sampleInterval;
sampleUsage();
- _periodicTimer = std::make_unique<vespalib::Timer>();
+ _periodicTimer = std::make_unique<vespalib::ScheduledExecutor>();
_periodicTimer->scheduleAtFixedRate(makeLambdaTask([this]()
{ sampleUsage(); }),
_sampleInterval, _sampleInterval);
diff --git a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h
index 5a439e69003..2ab13f2f48a 100644
--- a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h
+++ b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h
@@ -4,7 +4,7 @@
#include "disk_mem_usage_filter.h"
-namespace vespalib { class Timer; }
+namespace vespalib { class ScheduledExecutor; }
namespace proton {
@@ -15,7 +15,7 @@ class DiskMemUsageSampler {
DiskMemUsageFilter _filter;
std::filesystem::path _path;
double _sampleInterval;
- std::unique_ptr<vespalib::Timer> _periodicTimer;
+ std::unique_ptr<vespalib::ScheduledExecutor> _periodicTimer;
void sampleUsage();
void sampleDiskUsage();
diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp
index a2672cc7972..893748ae49e 100644
--- a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp
@@ -6,7 +6,7 @@
#include "i_blockable_maintenance_job.h"
#include <vespa/searchcorespi/index/i_thread_service.h>
#include <vespa/vespalib/util/closuretask.h>
-#include <vespa/vespalib/util/timer.h>
+#include <vespa/vespalib/util/scheduledexecutor.h>
#include <vespa/log/log.h>
LOG_SETUP(".proton.server.maintenancecontroller");
@@ -167,7 +167,7 @@ MaintenanceController::restart()
if (!_started || _stopping || !_readySubDB.valid()) {
return;
}
- _periodicTimer.reset(new vespalib::Timer());
+ _periodicTimer = std::make_unique<vespalib::ScheduledExecutor>();
addJobsToPeriodicTimer();
}
diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h
index 24c1c18959e..3cfdeba4d34 100644
--- a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h
+++ b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h
@@ -8,6 +8,7 @@
#include "ibucketfreezelistener.h"
#include <vespa/searchcore/proton/common/doctypename.h>
#include <mutex>
+#include <vespa/vespalib/util/scheduledexecutor.h>
namespace vespalib {
class Timer;
@@ -77,7 +78,7 @@ private:
MaintenanceDocumentSubDB _readySubDB;
MaintenanceDocumentSubDB _remSubDB;
MaintenanceDocumentSubDB _notReadySubDB;
- std::unique_ptr<vespalib::Timer> _periodicTimer;
+ std::unique_ptr<vespalib::ScheduledExecutor> _periodicTimer;
DocumentDBMaintenanceConfigSP _config;
FrozenBuckets _frozenBuckets;
bool _started;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 376fbc85d6d..38f152d728c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -316,6 +316,12 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {x:0}:0, {x:1}:1, {x:2}:2 }", "range(x[3])");
tester.assertEvaluates("{ {x:0,y:0,z:0}:1, {x:0,y:0,z:1}:0, {x:0,y:1,z:0}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:0}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:1, }", "diag(x[2],y[2],z[2])");
tester.assertEvaluates("6", "reduce(random(x[2],y[3]), count)");
+ tester.assertEvaluates("tensor(x[2]):[0.0, 2.0]",
+ "tensor(x[2]):{{x:0}:tensor(y[2]):{{y:0}:((0+0)+a)," +
+ "{y:1}:((0+1)+a)}{y:0}," +
+ "{x:1}:tensor(y[2]):{{y:0}:((1+0)+a)," +
+ "{y:1}:((1+1)+a)}{y:1}" +
+ "}");
// tensor value
tester.assertEvaluates("3.0", "tensor0{x:1}", "{ {x:0}:1, {x:1}:3 }");
diff --git a/searchlib/src/tests/postinglistbm/stress_runner.cpp b/searchlib/src/tests/postinglistbm/stress_runner.cpp
index 53b683cd7fd..100a4fcd70d 100644
--- a/searchlib/src/tests/postinglistbm/stress_runner.cpp
+++ b/searchlib/src/tests/postinglistbm/stress_runner.cpp
@@ -8,9 +8,11 @@
#include <vespa/searchlib/test/fakedata/fakeword.h>
#include <vespa/searchlib/test/fakedata/fakewordset.h>
#include <vespa/searchlib/test/fakedata/fpfactory.h>
+#include <vespa/vespalib/util/time.h>
#include <condition_variable>
#include <mutex>
#include <vector>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP(".stress_runner");
@@ -306,7 +308,7 @@ StressMaster::run()
totalTime / _loops, type.c_str());
dropPostings();
}
- FastOS_Thread::Sleep(250);
+ std::this_thread::sleep_for(250ms);
}
double
diff --git a/searchlib/src/tests/transactionlog/translogclient_test.cpp b/searchlib/src/tests/transactionlog/translogclient_test.cpp
index 8a515f749f1..c4751af5adb 100644
--- a/searchlib/src/tests/transactionlog/translogclient_test.cpp
+++ b/searchlib/src/tests/transactionlog/translogclient_test.cpp
@@ -248,7 +248,7 @@ bool Test::partialUpdateTest()
TransLogClient::Visitor::UP visitor = tls.createVisitor("test1", ca);
ASSERT_TRUE(visitor.get());
ASSERT_TRUE( visitor->visit(5, 7) );
- for (size_t i(0); ! ca._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); }
+ for (size_t i(0); ! ca._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); }
ASSERT_TRUE( ca._eof );
ASSERT_TRUE( ca.map().size() == 1);
ASSERT_TRUE( ca.hasSerial(7) );
@@ -257,7 +257,7 @@ bool Test::partialUpdateTest()
TransLogClient::Visitor::UP visitor1 = tls.createVisitor("test1", ca1);
ASSERT_TRUE(visitor1.get());
ASSERT_TRUE( visitor1->visit(4, 5) );
- for (size_t i(0); ! ca1._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); }
+ for (size_t i(0); ! ca1._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); }
ASSERT_TRUE( ca1._eof );
ASSERT_TRUE( ca1.map().size() == 0);
@@ -265,7 +265,7 @@ bool Test::partialUpdateTest()
TransLogClient::Visitor::UP visitor2 = tls.createVisitor("test1", ca2);
ASSERT_TRUE(visitor2.get());
ASSERT_TRUE( visitor2->visit(5, 6) );
- for (size_t i(0); ! ca2._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); }
+ for (size_t i(0); ! ca2._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); }
ASSERT_TRUE( ca2._eof );
ASSERT_TRUE( ca2.map().size() == 0);
@@ -273,7 +273,7 @@ bool Test::partialUpdateTest()
TransLogClient::Visitor::UP visitor3 = tls.createVisitor("test1", ca3);
ASSERT_TRUE(visitor3.get());
ASSERT_TRUE( visitor3->visit(5, 1000) );
- for (size_t i(0); ! ca3._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); }
+ for (size_t i(0); ! ca3._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); }
ASSERT_TRUE( ca3._eof );
ASSERT_TRUE( ca3.map().size() == 1);
ASSERT_TRUE( ca3.hasSerial(7) );
@@ -437,7 +437,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c
TransLogClient::Visitor::UP visitor = tls.createVisitor(name, ca);
ASSERT_TRUE(visitor.get());
EXPECT_TRUE( visitor->visit(0, 1) );
- for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); }
+ for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); }
EXPECT_TRUE( ca._eof );
EXPECT_TRUE( ! ca.hasSerial(0) );
EXPECT_TRUE( ca.hasSerial(1) );
@@ -447,7 +447,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c
visitor = tls.createVisitor(name, ca);
ASSERT_TRUE(visitor.get());
EXPECT_TRUE( visitor->visit(1, 2) );
- for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); }
+ for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); }
EXPECT_TRUE( ca._eof );
EXPECT_TRUE( ! ca.hasSerial(0) );
EXPECT_TRUE( ! ca.hasSerial(1) );
@@ -458,7 +458,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c
visitor = tls.createVisitor(name, ca);
EXPECT_TRUE(visitor.get());
EXPECT_TRUE( visitor->visit(0, 3) );
- for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); }
+ for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); }
EXPECT_TRUE( ca._eof );
EXPECT_TRUE( ! ca.hasSerial(0) );
EXPECT_TRUE( ca.hasSerial(1) );
@@ -469,7 +469,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c
visitor = tls.createVisitor(name, ca);
ASSERT_TRUE(visitor.get());
EXPECT_TRUE( visitor->visit(2, 3) );
- for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); }
+ for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); }
EXPECT_TRUE( ca._eof );
EXPECT_TRUE( ! ca.hasSerial(0) );
EXPECT_TRUE( !ca.hasSerial(1) );
@@ -575,7 +575,7 @@ assertVisitStats(TransLogClient &tls, const vespalib::string &domain,
ASSERT_TRUE(visitor.get());
ASSERT_TRUE( visitor->visit(visitStart, visitEnd) );
for (size_t i(0); ! ca._eof && (i < 60000); i++ ) {
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
ASSERT_TRUE(ca._eof);
EXPECT_EQUAL(expFirstSerial, ca._firstSerial);
@@ -623,7 +623,7 @@ void Test::testMany()
TransLogClient::Visitor::UP visitor = tls.createVisitor("many", ca);
ASSERT_TRUE(visitor.get());
ASSERT_TRUE( visitor->visit(2, TOTAL_NUM_ENTRIES) );
- for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); }
+ for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); }
ASSERT_TRUE( ca._eof );
EXPECT_EQUAL(ca._count, TOTAL_NUM_ENTRIES);
EXPECT_EQUAL(ca._value, TOTAL_NUM_ENTRIES);
@@ -644,7 +644,7 @@ void Test::testMany()
TransLogClient::Visitor::UP visitor = tls.createVisitor("many", ca);
ASSERT_TRUE(visitor.get());
ASSERT_TRUE( visitor->visit(2, TOTAL_NUM_ENTRIES) );
- for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); }
+ for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); }
ASSERT_TRUE( ca._eof );
EXPECT_EQUAL(ca._count, TOTAL_NUM_ENTRIES);
EXPECT_EQUAL(ca._value, TOTAL_NUM_ENTRIES);
diff --git a/searchlib/src/tests/transactionlogstress/translogstress.cpp b/searchlib/src/tests/transactionlogstress/translogstress.cpp
index a047c5e1657..2ec193cfe45 100644
--- a/searchlib/src/tests/transactionlogstress/translogstress.cpp
+++ b/searchlib/src/tests/transactionlogstress/translogstress.cpp
@@ -11,8 +11,11 @@
#include <iostream>
#include <stdexcept>
#include <sstream>
+#include <thread>
#include <vespa/log/log.h>
+#include <vespa/vespalib/util/time.h>
+
LOG_SETUP("translogstress");
using document::ByteBuffer;
@@ -267,7 +270,7 @@ FeederThread::doRun()
int64_t milliSecsUsed = _timer.elapsed().ms();
if (milliSecsUsed < 1000) {
//LOG(info, "FeederThread: sleep %u ms", 1000 - milliSecsUsed);
- FastOS_Thread::Sleep(1000 - milliSecsUsed);
+ std::this_thread::sleep_for(std::chrono::milliseconds(1000 - milliSecsUsed));
} else {
LOG(info, "FeederThread: max throughput");
}
@@ -457,7 +460,7 @@ private:
EntryGenerator _generator;
std::vector<std::shared_ptr<VisitorAgent> > _visitors;
std::vector<std::shared_ptr<VisitorAgent> > _rndVisitors;
- uint64_t _visitorInterval; // in milliseconds
+ vespalib::duration _visitorInterval; // in milliseconds
int64_t _pruneInterval; // in milliseconds
fastos::StopWatch _pruneTimer;
SerialNum _begin;
@@ -481,14 +484,14 @@ ControllerThread::ControllerThread(const std::string & tlsSpec, const std::strin
const EntryGenerator & generator, uint32_t numVisitors,
uint64_t visitorInterval, uint64_t pruneInterval)
: _tlsSpec(tlsSpec), _domain(domain), _client(tlsSpec.c_str()), _session(),
- _generator(generator), _visitors(), _rndVisitors(), _visitorInterval(visitorInterval),
+ _generator(generator), _visitors(), _rndVisitors(), _visitorInterval(std::chrono::milliseconds(visitorInterval)),
_pruneInterval(pruneInterval), _pruneTimer(), _begin(0), _end(0), _count(0)
{
for (uint32_t i = 0; i < numVisitors; ++i) {
_visitors.push_back(std::make_shared<VisitorAgent>(tlsSpec, domain, generator, i, true));
}
}
-ControllerThread::~ControllerThread() {}
+ControllerThread::~ControllerThread() = default;
void
ControllerThread::getStatus()
@@ -553,7 +556,7 @@ ControllerThread::doRun()
}
_pruneTimer.restart();
}
- FastOS_Thread::Sleep(_visitorInterval);
+ std::this_thread::sleep_for(_visitorInterval);
}
}
@@ -569,7 +572,7 @@ private:
uint64_t domainPartSize;
size_t packetSize;
- uint64_t stressTime;
+ std::chrono::milliseconds stressTime;
uint32_t feedRate;
uint32_t numVisitors;
uint64_t visitorInterval;
@@ -598,7 +601,7 @@ void
TransLogStress::printConfig()
{
std::cout << "######## Config ########" << std::endl;
- std::cout << "stressTime: " << _cfg.stressTime / 1000 << " s" << std::endl;
+ std::cout << "stressTime: " << vespalib::to_s(_cfg.stressTime) << " s" << std::endl;
std::cout << "feedRate: " << _cfg.feedRate << " per/sec" << std::endl;
std::cout << "numVisitors: " << _cfg.numVisitors << std::endl;
std::cout << "visitorInterval: " << _cfg.visitorInterval << " ms" << std::endl;
@@ -628,7 +631,7 @@ TransLogStress::Main()
_cfg.domainPartSize = 8000000; // ~8MB
_cfg.packetSize = 0x10000;
- _cfg.stressTime = 1000 * 60;
+ _cfg.stressTime = std::chrono::milliseconds(1000 * 60);
_cfg.feedRate = 10000;
_cfg.numVisitors = 1;
_cfg.visitorInterval = 1000 * 1;
@@ -639,7 +642,7 @@ TransLogStress::Main()
_cfg.maxStrLen = 80;
_cfg.baseSeed = 100;
- uint64_t sleepTime = 4000;
+ vespalib::duration sleepTime = 4s;
int idx = 1;
char opt;
@@ -654,7 +657,7 @@ TransLogStress::Main()
_cfg.packetSize = atol(arg);
break;
case 't':
- _cfg.stressTime = 1000 * atol(arg);
+ _cfg.stressTime = std::chrono::milliseconds(1000 * atol(arg));
break;
case 'f':
_cfg.feedRate = atoi(arg);
@@ -690,7 +693,7 @@ TransLogStress::Main()
}
printConfig();
- FastOS_Thread::Sleep(sleepTime);
+ std::this_thread::sleep_for(sleepTime);
if (_argc != idx || optError) {
usage();
@@ -721,13 +724,13 @@ TransLogStress::Main()
FeederThread feeder(tlsSpec, domain, generator, _cfg.feedRate, _cfg.packetSize);
threadPool.NewThread(&feeder);
- FastOS_Thread::Sleep(sleepTime);
+ std::this_thread::sleep_for(sleepTime);
ControllerThread controller(tlsSpec, domain, generator, _cfg.numVisitors, _cfg.visitorInterval, _cfg.pruneInterval);
threadPool.NewThread(&controller);
// stop feeder and controller
- FastOS_Thread::Sleep(_cfg.stressTime);
+ std::this_thread::sleep_for(_cfg.stressTime);
printConfig();
LOG(info, "Stop feeder...");
feeder.stop();
@@ -735,7 +738,7 @@ TransLogStress::Main()
std::cout << "<feeder>" << std::endl;
std::cout << " <from>" << feeder.getRange().from() << "</from>" << std::endl;
std::cout << " <to>" << feeder.getRange().to() << "</to>" << std::endl;
- std::cout << " <rate>" << 1000 * (feeder.getRange().to() - feeder.getRange().from()) / (sleepTime + _cfg.stressTime)
+ std::cout << " <rate>" << 1000 * (feeder.getRange().to() - feeder.getRange().from()) / vespalib::count_ms(sleepTime + _cfg.stressTime)
<< "</rate>" << std::endl;
std::cout << "</feeder>" << std::endl;
@@ -743,7 +746,7 @@ TransLogStress::Main()
controller.stop();
controller.join();
- FastOS_Thread::Sleep(sleepTime);
+ std::this_thread::sleep_for(sleepTime);
std::vector<std::shared_ptr<VisitorAgent> > & visitors = controller.getVisitors();
for (size_t i = 0; i < visitors.size(); ++i) {
std::cout << "<visitor id='" << i << "'>" << std::endl;
diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp
index 6d11ab1f5eb..37903bc21f5 100644
--- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp
+++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp
@@ -2,12 +2,14 @@
#include "translogserver.h"
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/io/fileutil.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/vespalib/util/exceptions.h>
#include <vespa/fnet/frt/supervisor.h>
#include <vespa/fnet/frt/rpcrequest.h>
#include <vespa/fnet/task.h>
#include <vespa/fnet/transport.h>
#include <fstream>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP(".transactionlog.server");
@@ -125,7 +127,7 @@ TransLogServer::TransLogServer(const vespalib::string &name, int listenPort, con
listenOk = true;
} else {
LOG(warning, "Failed listening at port %s trying for %d seconds more.", listenSpec, i);
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
}
if ( ! listenOk ) {
diff --git a/slobrok/src/tests/configure/configure.cpp b/slobrok/src/tests/configure/configure.cpp
index bf41b77ab05..fa509c17d0c 100644
--- a/slobrok/src/tests/configure/configure.cpp
+++ b/slobrok/src/tests/configure/configure.cpp
@@ -85,7 +85,7 @@ compare(MirrorAPI &api, const char *pattern, SpecList expect)
if (actual == expect) {
return true;
}
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
}
SpecList actual(api.lookup(pattern));
std::cerr << "Actual: " << actual.strVal() << std::endl;
@@ -176,7 +176,7 @@ Test::Main()
srv2Builder.slobrok[0].connectionspec = createSpec(18525);
cfgCtx->reload();
- FastOS_Thread::Sleep(6000); // reconfiguration time
+ std::this_thread::sleep_for(6s); // reconfiguration time
reg1.registerName("A");
reg2.registerName("B");
diff --git a/slobrok/src/tests/mirrorapi/mirrorapi.cpp b/slobrok/src/tests/mirrorapi/mirrorapi.cpp
index b25e338533c..53e194fad2d 100644
--- a/slobrok/src/tests/mirrorapi/mirrorapi.cpp
+++ b/slobrok/src/tests/mirrorapi/mirrorapi.cpp
@@ -112,7 +112,7 @@ compare(MirrorAPI &api, const char *pattern, SpecList expect)
if (actual == expect) {
return true;
}
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
}
return false;
}
@@ -124,7 +124,7 @@ Test::Main()
TEST_INIT("mirrorapi_test");
SlobrokServer mock(18501);
- FastOS_Thread::Sleep(300);
+ std::this_thread::sleep_for(300ms);
Server a("A/x/w", 18502, "tcp/localhost:18501");
Server b("B/x", 18503, "tcp/localhost:18501");
@@ -143,7 +143,7 @@ Test::Main()
MirrorAPI mirror(supervisor, config::ConfigUri::createFromInstance(specBuilder));
EXPECT_TRUE(!mirror.ready());
transport.Start(&threadPool);
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
a.reg();
EXPECT_TRUE(compare(mirror, "A/x/w", SpecList().add("A/x/w", "tcp/localhost:18502")));
diff --git a/slobrok/src/tests/registerapi/registerapi.cpp b/slobrok/src/tests/registerapi/registerapi.cpp
index ac7e662c6f2..92f08ee41cb 100644
--- a/slobrok/src/tests/registerapi/registerapi.cpp
+++ b/slobrok/src/tests/registerapi/registerapi.cpp
@@ -64,7 +64,7 @@ compare(MirrorAPI &api, const char *pattern, SpecList expect)
if (actual == expect) {
return true;
}
- FastOS_Thread::Sleep(100);
+ std::this_thread::sleep_for(100ms);
}
return false;
}
@@ -75,7 +75,7 @@ Test::Main()
TEST_INIT("registerapi_test");
SlobrokServer mock(18548);
- FastOS_Thread::Sleep(300);
+ std::this_thread::sleep_for(300ms);
cloud::config::SlobroksConfigBuilder slobrokSpecs;
cloud::config::SlobroksConfig::Slobrok sb;
@@ -97,7 +97,7 @@ Test::Main()
EXPECT_TRUE(compare(mirror, "*/*/*", SpecList().add("A/x/w", myspec.c_str())));
for (int i = 0; i < 30; i++) {
- if (reg.busy()) FastOS_Thread::Sleep(100);
+ if (reg.busy()) std::this_thread::sleep_for(100ms);
}
EXPECT_TRUE(!reg.busy());
diff --git a/slobrok/src/tests/standalone/standalone.cpp b/slobrok/src/tests/standalone/standalone.cpp
index 9d3fd694ee1..65553c57530 100644
--- a/slobrok/src/tests/standalone/standalone.cpp
+++ b/slobrok/src/tests/standalone/standalone.cpp
@@ -132,7 +132,7 @@ TEST("standalone") {
break;
}
fprintf(stderr, "ping failed [retry %d]\n", retry);
- FastOS_Thread::Sleep(200);
+ std::this_thread::sleep_for(200ms);
sb->SubRef();
sb = orb.GetTarget(18541);
}
@@ -268,7 +268,7 @@ TEST("standalone") {
}
}
- FastOS_Thread::Sleep(2000);
+ std::this_thread::sleep_for(2s);
// lookup 'B' should give ''
req = orb.AllocRPCRequest(req);
diff --git a/staging_vespalib/src/tests/clock/clock_test.cpp b/staging_vespalib/src/tests/clock/clock_test.cpp
index bf7e3773055..b5650244a45 100644
--- a/staging_vespalib/src/tests/clock/clock_test.cpp
+++ b/staging_vespalib/src/tests/clock/clock_test.cpp
@@ -2,36 +2,27 @@
#include <vespa/vespalib/testkit/testapp.h>
#include <vespa/vespalib/util/clock.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/fastos/thread.h>
using vespalib::Clock;
using fastos::TimeStamp;
-class Test : public vespalib::TestApp
-{
-public:
- int Main() override;
-};
-
-int
-Test::Main()
-{
- TEST_INIT("clock_test");
+TEST("Test that clock is ticking forward") {
Clock clock(0.050);
FastOS_ThreadPool pool(0x10000);
ASSERT_TRUE(pool.NewThread(clock.getRunnable(), nullptr) != nullptr);
fastos::SteadyTimeStamp start = clock.getTimeNS();
- FastOS_Thread::Sleep(5000);
+ std::this_thread::sleep_for(5s);
fastos::SteadyTimeStamp stop = clock.getTimeNS();
EXPECT_TRUE(stop > start);
- FastOS_Thread::Sleep(6000);
+ std::this_thread::sleep_for(6s);
clock.stop();
fastos::SteadyTimeStamp stop2 = clock.getTimeNS();
EXPECT_TRUE(stop2 > stop);
EXPECT_TRUE((stop2 - stop)/TimeStamp::MICRO > 1000);
- TEST_DONE();
}
-TEST_APPHOOK(Test)
+TEST_MAIN() { TEST_RUN_ALL(); } \ No newline at end of file
diff --git a/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp b/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp
index e6f7bd21750..fbaa5581173 100644
--- a/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp
+++ b/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp
@@ -13,20 +13,20 @@ Test::Main()
{
TEST_INIT("shutdownguard_test");
{
- ShutdownGuard farFuture(123456789);
- FastOS_Thread::Sleep(20);
+ ShutdownGuard farFuture(1000000s);
+ std::this_thread::sleep_for(20ms);
}
EXPECT_TRUE(true);
pid_t child = fork();
if (child == 0) {
- ShutdownGuard soon(30);
+ ShutdownGuard soon(30ms);
for (int i = 0; i < 1000; ++i) {
- FastOS_Thread::Sleep(20);
+ std::this_thread::sleep_for(20ms);
}
exit(0);
}
for (int i = 0; i < 1000; ++i) {
- FastOS_Thread::Sleep(20);
+ std::this_thread::sleep_for(20ms);
int stat = 0;
if (waitpid(child, &stat, WNOHANG) == child) {
EXPECT_TRUE(WIFEXITED(stat));
diff --git a/staging_vespalib/src/tests/timer/timer_test.cpp b/staging_vespalib/src/tests/timer/timer_test.cpp
index 309ee873b44..5472ad6e23f 100644
--- a/staging_vespalib/src/tests/timer/timer_test.cpp
+++ b/staging_vespalib/src/tests/timer/timer_test.cpp
@@ -1,8 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/vespalib/testkit/testapp.h>
-#include <vespa/vespalib/util/timer.h>
-#include <vespa/vespalib/util/executor.h>
+#include <vespa/vespalib/util/scheduledexecutor.h>
using namespace vespalib;
using vespalib::Executor;
@@ -37,7 +36,7 @@ void Test::testScheduling()
{
vespalib::CountDownLatch latch1(3);
vespalib::CountDownLatch latch2(2);
- Timer timer;
+ ScheduledExecutor timer;
timer.scheduleAtFixedRate(Task::UP(new TestTask(latch1)), 0.1, 0.2);
timer.scheduleAtFixedRate(Task::UP(new TestTask(latch2)), 0.5, 0.5);
EXPECT_TRUE(latch1.await(60000));
@@ -47,7 +46,7 @@ void Test::testScheduling()
void Test::testReset()
{
vespalib::CountDownLatch latch1(2);
- Timer timer;
+ ScheduledExecutor timer;
timer.scheduleAtFixedRate(Task::UP(new TestTask(latch1)), 2.0, 3.0);
timer.reset();
EXPECT_TRUE(!latch1.await(3000));
diff --git a/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt b/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt
index 20d47c90453..71364a813f6 100644
--- a/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt
+++ b/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt
@@ -16,7 +16,7 @@ vespa_add_library(staging_vespalib_vespalib_util OBJECT
document_runnable.cpp
rusage.cpp
shutdownguard.cpp
- timer.cpp
+ scheduledexecutor.cpp
xmlserializable.cpp
xmlstream.cpp
DEPENDS
diff --git a/staging_vespalib/src/vespa/vespalib/util/timer.cpp b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.cpp
index a7acbe67965..61f9666114c 100644
--- a/staging_vespalib/src/vespa/vespalib/util/timer.cpp
+++ b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.cpp
@@ -1,5 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include "timer.h"
+#include "scheduledexecutor.h"
#include <vespa/fnet/scheduler.h>
#include <vespa/fnet/task.h>
#include <vespa/fnet/transport.h>
@@ -34,7 +34,7 @@ public:
}
};
-Timer::Timer()
+ScheduledExecutor::ScheduledExecutor()
: _threadPool(128 * 1024),
_transport(new FNET_Transport()),
_lock(),
@@ -43,7 +43,7 @@ Timer::Timer()
_transport->Start(&_threadPool);
}
-Timer::~Timer()
+ScheduledExecutor::~ScheduledExecutor()
{
vespalib::LockGuard guard(_lock);
_transport->ShutDown(true);
@@ -53,7 +53,7 @@ Timer::~Timer()
void
-Timer::scheduleAtFixedRate(vespalib::Executor::Task::UP task, double delay, double interval)
+ScheduledExecutor::scheduleAtFixedRate(vespalib::Executor::Task::UP task, double delay, double interval)
{
vespalib::LockGuard guard(_lock);
TimerTaskPtr tTask(new TimerTask(_transport->GetScheduler(), std::move(task), interval));
@@ -62,7 +62,7 @@ Timer::scheduleAtFixedRate(vespalib::Executor::Task::UP task, double delay, doub
}
void
-Timer::reset()
+ScheduledExecutor::reset()
{
vespalib::LockGuard guard(_lock);
_transport->ShutDown(true);
diff --git a/staging_vespalib/src/vespa/vespalib/util/timer.h b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.h
index 0f7cde67ee4..d7e56494828 100644
--- a/staging_vespalib/src/vespa/vespalib/util/timer.h
+++ b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.h
@@ -13,11 +13,11 @@ namespace vespalib {
class TimerTask;
/**
- * Timer is a class capable of running Tasks at a regular
+ * ScheduledExecutor is a class capable of running Tasks at a regular
* interval. The timer can be reset to clear all tasks currently being
* scheduled.
*/
-class Timer
+class ScheduledExecutor
{
private:
typedef std::unique_ptr<TimerTask> TimerTaskPtr;
@@ -31,13 +31,13 @@ public:
/**
* Create a new timer, capable of scheduling tasks at fixed intervals.
*/
- Timer();
+ ScheduledExecutor();
/**
* Destroys this timer, finishing the current task executing and then
* finishing.
*/
- ~Timer();
+ ~ScheduledExecutor();
/**
* Schedule new task to be executed at specified intervals.
diff --git a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp
index 645ffea380d..99857107860 100644
--- a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp
+++ b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp
@@ -1,7 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "shutdownguard.h"
#include <unistd.h>
-#include <sys/time.h>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP(".vespalib.shutdownguard");
@@ -10,36 +10,30 @@ namespace vespalib {
namespace {
enum { STACK_SIZE = (1u << 16) };
-
-uint64_t getTimeInMillis() {
- struct timeval mytime;
- gettimeofday(&mytime, 0);
- uint64_t mult = 1000;
- return (mytime.tv_sec * mult) + (mytime.tv_usec / mult);
-}
}
void ShutdownGuard::Run(FastOS_ThreadInterface *, void *)
{
- while (_dieAtTime > getTimeInMillis()) {
- FastOS_Thread::Sleep(5);
+ while (_dieAtTime > steady_clock::now() && ! GetThread()->GetBreakFlag()) {
+ std::this_thread::sleep_for(5ms);
}
- if (_dieAtTime != 0) {
+ if (_dieAtTime <= steady_clock::now()) {
LOG(warning, "ShutdownGuard is now forcing an exit of the process.");
_exit(EXIT_FAILURE);
}
}
-ShutdownGuard::ShutdownGuard(uint64_t millis) :
+ShutdownGuard::ShutdownGuard(duration millis) :
FastOS_Runnable(),
_pool(STACK_SIZE, 1),
- _dieAtTime(getTimeInMillis() + millis)
+ _dieAtTime(steady_clock::now() + millis)
{
_pool.NewThread(this);
}
ShutdownGuard::~ShutdownGuard()
{
- _dieAtTime = 0;
+ GetThread()->SetBreakFlag();
+ GetThread()->Join();
_pool.Close();
}
diff --git a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h
index 5a9aad5d4d4..9de9df8bbad 100644
--- a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h
+++ b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h
@@ -1,8 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
+#include <vespa/vespalib/util/time.h>
#include <vespa/fastos/thread.h>
-#include <cstdint>
namespace vespalib {
@@ -16,7 +16,7 @@ namespace vespalib {
class ShutdownGuard : public FastOS_Runnable
{
FastOS_ThreadPool _pool;
- volatile uint64_t _dieAtTime;
+ steady_time _dieAtTime;
void Run(FastOS_ThreadInterface *, void *) override;
@@ -25,7 +25,7 @@ public:
* Construct a shutdown guard with a given lifetime.
* @arg millis the number of milliseconds before process automatically exits
**/
- ShutdownGuard(uint64_t millis);
+ ShutdownGuard(duration millis);
/**
* Destructor that dismisses the guard and collects the shutdown thread.
diff --git a/storage/src/tests/common/metricstest.cpp b/storage/src/tests/common/metricstest.cpp
index d1421845b81..d698cbb5e05 100644
--- a/storage/src/tests/common/metricstest.cpp
+++ b/storage/src/tests/common/metricstest.cpp
@@ -1,6 +1,5 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include <vespa/document/fieldvalue/document.h>
#include <vespa/storageapi/message/persistence.h>
#include <vespa/storageframework/defaultimplementation/clock/fakeclock.h>
#include <vespa/storage/bucketdb/bucketmanager.h>
@@ -14,6 +13,7 @@
#include <vespa/config/common/exceptions.h>
#include <vespa/vespalib/stllike/hash_map.hpp>
#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/util/time.h>
#include <gmock/gmock.h>
#include <thread>
@@ -202,7 +202,7 @@ void MetricsTest::createFakeLoad()
while (uint64_t(_metricManager->getLastProcessedTime())
< _clock->getTimeInSeconds().getTime())
{
- FastOS_Thread::Sleep(5);
+ std::this_thread::sleep_for(5ms);
_metricManager->timeChangedNotification();
}
}
@@ -257,7 +257,7 @@ TEST_F(MetricsTest, snapshot_presenting) {
uint64_t(_metricManager->getLastProcessedTime())
< _clock->getTimeInSeconds().getTime())
{
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
}
LOG(debug, "5 minute snapshot should have been taken. Adding put count");
diff --git a/storage/src/tests/common/teststorageapp.cpp b/storage/src/tests/common/teststorageapp.cpp
index dd89082d3e7..082af954871 100644
--- a/storage/src/tests/common/teststorageapp.cpp
+++ b/storage/src/tests/common/teststorageapp.cpp
@@ -7,10 +7,11 @@
#include <vespa/config-load-type.h>
#include <vespa/config-fleetcontroller.h>
#include <vespa/persistence/dummyimpl/dummypersistence.h>
-#include <vespa/vespalib/io/fileutil.h>
#include <vespa/vespalib/util/exceptions.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/config/config.h>
#include <vespa/config/helper/configgetter.hpp>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP(".test.servicelayerapp");
@@ -111,7 +112,7 @@ TestStorageApp::waitUntilInitialized(
framework::MilliSecTime endTime(
clock.getTimeInMillis() + timeout.getMillis());
while (!isInitialized()) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
framework::MilliSecTime currentTime(clock.getTimeInMillis());
if (currentTime > endTime) {
std::ostringstream error;
diff --git a/storage/src/tests/distributor/distributortest.cpp b/storage/src/tests/distributor/distributortest.cpp
index 8fa8a6bcede..d456401876e 100644
--- a/storage/src/tests/distributor/distributortest.cpp
+++ b/storage/src/tests/distributor/distributortest.cpp
@@ -11,9 +11,10 @@
#include <vespa/document/test/make_document_bucket.h>
#include <vespa/document/test/make_bucket_space.h>
#include <vespa/storage/config/config-stor-distributormanager.h>
-#include <tests/common/dummystoragelink.h>
#include <vespa/storage/distributor/distributor.h>
#include <vespa/vespalib/text/stringtokenizer.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/vespalib/gtest/gtest.h>
#include <gmock/gmock.h>
@@ -383,7 +384,7 @@ TEST_F(DistributorTest, tick_processes_status_requests) {
thread, "statustest", tickWaitMs, tickMaxProcessTime, ticksBeforeWait));
while (true) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
framework::TickingLockGuard guard(
distributor_thread_pool().freezeCriticalTicks());
if (!distributor_status_todos().empty()) {
diff --git a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp
index b46c0236150..64306fa7c24 100644
--- a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp
+++ b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp
@@ -20,8 +20,10 @@
#include <vespa/persistence/spi/test.h>
#include <vespa/config/common/exceptions.h>
#include <vespa/fastos/file.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/vespalib/gtest/gtest.h>
#include <atomic>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP(".filestormanagertest");
@@ -556,7 +558,7 @@ public:
auto cmd = std::make_shared<api::PutCommand>(makeDocumentBucket(bucket), _doc, 100);
_handler.schedule(cmd, 0);
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
_threadDone = true;
@@ -589,13 +591,13 @@ public:
if (msg.second.get()) {
uint32_t originalConfig = _config.load();
_fetchedCount++;
- FastOS_Thread::Sleep(5);
+ std::this_thread::sleep_for(5ms);
if (_config.load() != originalConfig) {
_failed = true;
}
} else {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
}
@@ -634,7 +636,7 @@ TEST_F(FileStorManagerTest, handler_paused_multi_thread) {
thread.start(pool);
for (uint32_t i = 0; i < 50; ++i) {
- FastOS_Thread::Sleep(2);
+ std::this_thread::sleep_for(2ms);
ResumeGuard guard = filestorHandler.pause();
thread._config.fetch_add(1);
uint32_t count = thread._fetchedCount;
@@ -646,7 +648,7 @@ TEST_F(FileStorManagerTest, handler_paused_multi_thread) {
ASSERT_FALSE(thread._failed);
while (!pushthread._threadDone || !thread._threadDone) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
}
@@ -869,7 +871,7 @@ TEST_F(FileStorManagerTest, handler_timeout) {
filestorHandler.schedule(cmd, 0);
}
- FastOS_Thread::Sleep(51);
+ std::this_thread::sleep_for(51ms);
for (;;) {
auto lock = filestorHandler.getNextMessage(0, stripeId);
if (lock.first.get()) {
@@ -944,7 +946,7 @@ TEST_F(FileStorManagerTest, priority) {
// Wait until everything is done.
int count = 0;
while (documents.size() != top.getNumReplies() && count < 10000) {
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
count++;
}
ASSERT_LT(count, 10000);
diff --git a/storage/src/tests/persistence/filestorage/operationabortingtest.cpp b/storage/src/tests/persistence/filestorage/operationabortingtest.cpp
index 0d43f8a9020..ba344971c3b 100644
--- a/storage/src/tests/persistence/filestorage/operationabortingtest.cpp
+++ b/storage/src/tests/persistence/filestorage/operationabortingtest.cpp
@@ -9,6 +9,8 @@
#include <vespa/vespalib/util/thread.h>
#include <vespa/vespalib/stllike/hash_set_insert.hpp>
#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP(".operationabortingtest");
@@ -53,7 +55,7 @@ public:
(void) context;
_queueBarrier.await();
// message abort stage with active opertion in disk queue
- FastOS_Thread::Sleep(75);
+ std::this_thread::sleep_for(75ms);
_completionBarrier.await();
// test finished
return spi::Result();
diff --git a/storage/src/tests/storageserver/bucketintegritycheckertest.cpp b/storage/src/tests/storageserver/bucketintegritycheckertest.cpp
index ae466f04734..8a68adf226c 100644
--- a/storage/src/tests/storageserver/bucketintegritycheckertest.cpp
+++ b/storage/src/tests/storageserver/bucketintegritycheckertest.cpp
@@ -1,7 +1,6 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/storage/bucketdb/bucketmanager.h>
-#include <vespa/storage/persistence/filestorage/filestormanager.h>
#include <vespa/storage/storageserver/bucketintegritychecker.h>
#include <vespa/storageapi/message/persistence.h>
#include <tests/common/testhelper.h>
@@ -9,6 +8,8 @@
#include <vespa/vespalib/io/fileutil.h>
#include <tests/common/teststorageapp.h>
#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
using namespace ::testing;
@@ -175,13 +176,13 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) {
checker.getSchedulingOptions()._minCycleTime = framework::SecondTime(60 * 60);
topLink.open();
// Waiting for system to be initialized
- FastOS_Thread::Sleep(10); // Give next message chance to come
+ std::this_thread::sleep_for(10ms); // Give next message chance to come
ASSERT_COMMAND_COUNT(0, *dummyLink);
topLink.doneInit();
checker.bump();
// Should have started new run with 2 pending per disk
dummyLink->waitForMessages(4, _timeout);
- FastOS_Thread::Sleep(10); // Give 5th message chance to come
+ std::this_thread::sleep_for(10ms); // Give 5th message chance to come
ASSERT_COMMAND_COUNT(4, *dummyLink);
auto* cmd1 = dynamic_cast<RepairBucketCommand*>(dummyLink->getCommand(0).get());
EXPECT_EQ(230, cmd1->getPriority());
@@ -200,13 +201,13 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) {
// Answering a message on disk with no more buckets does not trigger new
auto reply1 = std::make_shared<RepairBucketReply>(*cmd3);
ASSERT_TRUE(checker.onUp(reply1));
- FastOS_Thread::Sleep(10); // Give next message chance to come
+ std::this_thread::sleep_for(10ms); // Give next message chance to come
ASSERT_COMMAND_COUNT(4, *dummyLink);
// Answering a message on disk with more buckets trigger new repair
auto reply2 = std::make_shared<RepairBucketReply>(*cmd2);
ASSERT_TRUE(checker.onUp(reply2));
dummyLink->waitForMessages(5, _timeout);
- FastOS_Thread::Sleep(10); // Give 6th message chance to come
+ std::this_thread::sleep_for(10ms); // Give 6th message chance to come
ASSERT_COMMAND_COUNT(5, *dummyLink);
auto* cmd5 = dynamic_cast<RepairBucketCommand*>(dummyLink->getCommand(4).get());
ASSERT_TRUE(cmd5);
@@ -217,7 +218,7 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) {
reply3->setResult(api::ReturnCode(api::ReturnCode::IGNORED));
ASSERT_TRUE(checker.onUp(reply3));
dummyLink->waitForMessages(6, _timeout);
- FastOS_Thread::Sleep(10); // Give 7th message chance to come
+ std::this_thread::sleep_for(10ms); // Give 7th message chance to come
ASSERT_COMMAND_COUNT(6, *dummyLink);
auto* cmd6 = dynamic_cast<RepairBucketCommand*>(dummyLink->getCommand(5).get());
ASSERT_TRUE(cmd6);
@@ -227,7 +228,7 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) {
auto reply4 = std::make_shared<RepairBucketReply>(*cmd4);
reply3->setResult(api::ReturnCode(api::ReturnCode::BUCKET_NOT_FOUND));
ASSERT_TRUE(checker.onUp(reply4));
- FastOS_Thread::Sleep(10); // Give 7th message chance to come
+ std::this_thread::sleep_for(10ms); // Give 7th message chance to come
ASSERT_COMMAND_COUNT(6, *dummyLink);
// Send a repair reply that actually have corrected the bucket.
@@ -247,13 +248,13 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) {
EXPECT_EQ(document::BucketId(16, 0x234), cmd7->getBucketId());
auto reply7 = std::make_shared<RepairBucketReply>(*cmd7);
ASSERT_TRUE(checker.onUp(reply7));
- FastOS_Thread::Sleep(10); // Give 8th message chance to come
+ std::this_thread::sleep_for(10ms); // Give 8th message chance to come
ASSERT_COMMAND_COUNT(7, *dummyLink);
// Still not time for next iteration
dummyLink->reset();
_node->getClock().setAbsoluteTimeInSeconds(getDate("week1 sun 00:59:59"));
- FastOS_Thread::Sleep(10); // Give new run chance to start
+ std::this_thread::sleep_for(10ms); // Give new run chance to start
ASSERT_COMMAND_COUNT(0, *dummyLink);
// Pass time until next cycle should start
diff --git a/storage/src/tests/storageserver/communicationmanagertest.cpp b/storage/src/tests/storageserver/communicationmanagertest.cpp
index caee6e6ab91..6657a9f1600 100644
--- a/storage/src/tests/storageserver/communicationmanagertest.cpp
+++ b/storage/src/tests/storageserver/communicationmanagertest.cpp
@@ -15,6 +15,8 @@
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/documentapi/messagebus/messages/removedocumentmessage.h>
#include <vespa/documentapi/messagebus/messages/getdocumentreply.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/vespalib/gtest/gtest.h>
using document::test::makeDocumentBucket;
@@ -65,7 +67,7 @@ TEST_F(CommunicationManagerTest, simple) {
distributor.open();
storage.open();
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
// Send a message through from distributor to storage
auto cmd = std::make_shared<api::GetCommand>(
diff --git a/storage/src/tests/storageserver/statereportertest.cpp b/storage/src/tests/storageserver/statereportertest.cpp
index c84f9311c52..dc8094275d1 100644
--- a/storage/src/tests/storageserver/statereportertest.cpp
+++ b/storage/src/tests/storageserver/statereportertest.cpp
@@ -11,6 +11,8 @@
#include <vespa/config/common/exceptions.h>
#include <vespa/vespalib/data/slime/slime.h>
#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP(".test.statereporter");
@@ -233,7 +235,7 @@ TEST_F(StateReporterTest, report_metrics) {
uint64_t(_metricManager->getLastProcessedTime())
< _clock->getTimeInSeconds().getTime())
{
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
}
LOG(debug, "5 minute snapshot should have been taken. Adding put count");
diff --git a/storage/src/vespa/storage/storageserver/fnetlistener.cpp b/storage/src/vespa/storage/storageserver/fnetlistener.cpp
index 45bd9c64fac..c5d7880d966 100644
--- a/storage/src/vespa/storage/storageserver/fnetlistener.cpp
+++ b/storage/src/vespa/storage/storageserver/fnetlistener.cpp
@@ -6,9 +6,11 @@
#include <vespa/storageapi/message/state.h>
#include <vespa/vespalib/util/exceptions.h>
#include <vespa/vespalib/util/host_name.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/fnet/frt/supervisor.h>
#include <vespa/fnet/transport.h>
#include <sstream>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP(".rpc.listener");
@@ -50,7 +52,7 @@ FNetListener::registerHandle(vespalib::stringref handle) {
_slobrokRegister.registerName(handle);
while (_slobrokRegister.busy()) {
LOG(debug, "Waiting to register in slobrok");
- FastOS_Thread::Sleep(50);
+ std::this_thread::sleep_for(50ms);
}
_handle = handle;
}
diff --git a/storage/src/vespa/storage/storageserver/fnetlistener.h b/storage/src/vespa/storage/storageserver/fnetlistener.h
index 205a5af4586..e37727beb44 100644
--- a/storage/src/vespa/storage/storageserver/fnetlistener.h
+++ b/storage/src/vespa/storage/storageserver/fnetlistener.h
@@ -5,6 +5,7 @@
#include <atomic>
class FNET_Transport;
+class FastOS_ThreadPool;
namespace storage {
diff --git a/storage/src/vespa/storage/storageserver/storagenode.cpp b/storage/src/vespa/storage/storageserver/storagenode.cpp
index c5a0a031067..e962ee4b1b6 100644
--- a/storage/src/vespa/storage/storageserver/storagenode.cpp
+++ b/storage/src/vespa/storage/storageserver/storagenode.cpp
@@ -14,6 +14,7 @@
#include <vespa/storage/common/statusmetricconsumer.h>
#include <vespa/vespalib/io/fileutil.h>
#include <vespa/vespalib/util/exceptions.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/metrics/metricmanager.h>
#include <fcntl.h>
@@ -568,7 +569,7 @@ StorageNode::waitUntilInitialized(uint32_t timeout) {
lib::NodeState nodeState(*_component->getStateUpdater().getReportedNodeState());
if (nodeState.getState() == lib::State::UP) break;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
if (clock.getTimeInMillis() >= endTime) {
std::ostringstream ost;
ost << "Storage server not initialized after waiting timeout of "
diff --git a/storage/src/vespa/storage/tools/storage-cmd.cpp b/storage/src/vespa/storage/tools/storage-cmd.cpp
index daaa890873f..8c0fcc83330 100644
--- a/storage/src/vespa/storage/tools/storage-cmd.cpp
+++ b/storage/src/vespa/storage/tools/storage-cmd.cpp
@@ -3,6 +3,8 @@
#include <vespa/slobrok/sbmirror.h>
#include <vespa/fastos/app.h>
#include <vespa/vespalib/locale/c.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
#include <vespa/log/log.h>
LOG_SETUP("vespa-storage-cmd");
@@ -61,7 +63,7 @@ public:
slobrok::api::MirrorAPI mirror(supervisor.supervisor(), sbcfg);
while (!mirror.ready()) {
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
slobrok::api::MirrorAPI::SpecList list = mirror.lookup(_argv[1]);
diff --git a/storageframework/src/tests/thread/tickingthreadtest.cpp b/storageframework/src/tests/thread/tickingthreadtest.cpp
index 97ae08eef3d..c42a9c17283 100644
--- a/storageframework/src/tests/thread/tickingthreadtest.cpp
+++ b/storageframework/src/tests/thread/tickingthreadtest.cpp
@@ -6,6 +6,8 @@
#include <vespa/vespalib/gtest/gtest.h>
#include <vespa/vespalib/util/exception.h>
#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
namespace storage::framework::defaultimplementation {
@@ -35,7 +37,7 @@ struct MyApp : public TickingThread {
Context& c(_context[index]);
if (_doCritOverlapTest) {
uint32_t oldTick = _critOverlapCounter;
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
_critOverlap |= (_critOverlapCounter != oldTick);
++_critOverlapCounter;
}
@@ -109,7 +111,7 @@ TEST(TickingThreadTest, test_ticks_before_wait_basic)
// and verify time is in right ballpark.
int totalSleepMs = 0;
while (app.getTotalNonCritTicks() < 20) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
totalSleepMs++;
}
EXPECT_GT(totalSleepMs, 10);
@@ -134,7 +136,7 @@ TEST(TickingThreadTest, test_ticks_before_wait_live_update)
// (if live update is broken it will take more than an hour).
int maxAttempts = 120000; // a bit more than 120 secs
while (app.getTotalNonCritTicks() < ticksBeforeWaitMs && maxAttempts-->0) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
EXPECT_GT(maxAttempts, 0);
@@ -158,7 +160,7 @@ TEST(TickingThreadTest, test_verbose_stopping)
MyApp app(threadCount, true);
app.start(testReg.getThreadPoolImpl());
while (app.getMinCritTick() < 5) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
app._threadPool->stop();
}
@@ -171,7 +173,7 @@ TEST(TickingThreadTest, test_stop_on_deletion)
MyApp app(threadCount, true);
app.start(testReg.getThreadPoolImpl());
while (app.getMinCritTick() < 5) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
}
@@ -185,7 +187,7 @@ TEST(TickingThreadTest, test_lock_all_ticks)
app1.start(testReg.getThreadPoolImpl());
app2.start(testReg.getThreadPoolImpl());
while (std::min(app1.getMinCritTick(), app2.getMinCritTick()) < 5) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
uint64_t ticks1, ticks2;
{
@@ -194,12 +196,12 @@ TEST(TickingThreadTest, test_lock_all_ticks)
ticks2 = app2.getTotalTicks();
while (app2.getMinCritTick() < 2 * ticks2 / threadCount) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
EXPECT_EQ(ticks1, app1.getTotalTicks());
}
while (app1.getMinCritTick() < 2 * ticks1 / threadCount) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
}
@@ -213,7 +215,7 @@ TEST(TickingThreadTest, test_lock_critical_ticks)
MyApp app(threadCount, true);
app.start(testReg.getThreadPoolImpl());
while (!app.hasCritOverlap()) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
++app._critOverlapCounter;
++iterationsBeforeOverlap;
}
@@ -222,7 +224,7 @@ TEST(TickingThreadTest, test_lock_critical_ticks)
MyApp app(threadCount, true);
app.start(testReg.getThreadPoolImpl());
for (uint64_t i=0; i<iterationsBeforeOverlap * 10; ++i) {
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
TickingLockGuard guard(app._threadPool->freezeCriticalTicks());
for (int j=0; j<threadCount; ++j) {
++app._context[j]._critTickCount;
@@ -318,13 +320,13 @@ TEST(TickingThreadTest, test_broadcast)
BroadcastApp app;
app.start(testReg.getThreadPoolImpl());
app.doTask("foo");
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
app.doTask("bar");
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
app.doTask("baz");
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
app.doTask("hmm");
- FastOS_Thread::Sleep(1);
+ std::this_thread::sleep_for(1ms);
}
}
diff --git a/storageserver/src/apps/storaged/storage.cpp b/storageserver/src/apps/storaged/storage.cpp
index 0748cc3cb1e..5996951e65f 100644
--- a/storageserver/src/apps/storaged/storage.cpp
+++ b/storageserver/src/apps/storaged/storage.cpp
@@ -198,7 +198,7 @@ int StorageApp::Main()
LOG(debug, "Server was attempted stopped, shutting down");
// Create guard that will forcifully kill storage if destruction takes longer
// time than given timeout.
- vespalib::ShutdownGuard shutdownGuard(_maxShutdownTime);
+ vespalib::ShutdownGuard shutdownGuard(std::chrono::milliseconds(_maxShutdownTime));
LOG(debug, "Attempting proper shutdown");
_process.reset();
LOG(debug, "Completed controlled shutdown.");
diff --git a/vdslib/src/tests/thread/taskschedulertest.cpp b/vdslib/src/tests/thread/taskschedulertest.cpp
index 540de722137..1925625172c 100644
--- a/vdslib/src/tests/thread/taskschedulertest.cpp
+++ b/vdslib/src/tests/thread/taskschedulertest.cpp
@@ -2,6 +2,8 @@
#include <vespa/vdslib/thread/taskscheduler.h>
#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/util/time.h>
+#include <thread>
namespace vdslib {
@@ -141,13 +143,13 @@ TEST(TaskSchedulerTest, test_simple)
task->registerCallsWithName("", calls);
scheduler.addAbsolute(TestTask::UP(task), 50);
watch.increment(49); // Not yet time to run
- FastOS_Thread::Sleep(5);
+ std::this_thread::sleep_for(5ms);
// Check that it has not run yet..
EXPECT_EQ(counter, scheduler.getTaskCounter());
watch.increment(10); // Now time is enough for it to run
scheduler.waitForTaskCounterOfAtLeast(counter + 1);
watch.increment(10);
- FastOS_Thread::Sleep(5);
+ std::this_thread::sleep_for(5ms);
// Check that it has not run yet..
EXPECT_EQ(counter + 1, scheduler.getTaskCounter());
watch.increment(50);
diff --git a/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp b/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp
index 9ec9249b671..1b8f31f0d03 100644
--- a/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp
+++ b/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp
@@ -1,19 +1,20 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/defaults.h>
-#include <vespa/document/util/stringutil.h>
#include <vespa/fnet/frt/frt.h>
#include <vespa/slobrok/sbmirror.h>
#include <vespa/vdslib/distribution/distribution.h>
#include <vespa/vdslib/state/clusterstate.h>
#include <vespa/vespalib/util/programoptions.h>
#include <vespa/vespaclient/clusterlist/clusterlist.h>
+#include <vespa/vespalib/util/time.h>
#include <vespa/vespalib/text/lowercase.h>
#include <vespa/config-stor-distribution.h>
#include <vespa/config/helper/configgetter.hpp>
#include <vespa/fastos/app.h>
#include <sstream>
#include <iostream>
+#include <thread>
#include <sys/time.h>
#include <vespa/log/log.h>
@@ -282,7 +283,7 @@ struct StateApp : public FastOS_Application {
}
warnTime *= 4;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
if (!slobrok->ready()) {
std::cerr << "Slobrok not ready.\n";
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 63cae3904e3..cea58d565c2 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1985,6 +1985,7 @@
],
"methods": [
"public void <init>()",
+ "public void <init>(double)",
"public double applyAsDouble(double)",
"public java.lang.String toString()"
],
@@ -2075,6 +2076,7 @@
],
"methods": [
"public void <init>()",
+ "public void <init>(double)",
"public double applyAsDouble(double)",
"public java.lang.String toString()"
],
@@ -2271,6 +2273,7 @@
],
"methods": [
"public void <init>()",
+ "public void <init>(double, double)",
"public double applyAsDouble(double)",
"public java.lang.String toString()"
],
@@ -2437,22 +2440,25 @@
"public static java.util.function.DoubleUnaryOperator atan()",
"public static java.util.function.DoubleUnaryOperator ceil()",
"public static java.util.function.DoubleUnaryOperator cos()",
- "public static java.util.function.DoubleUnaryOperator elu()",
"public static java.util.function.DoubleUnaryOperator exp()",
"public static java.util.function.DoubleUnaryOperator floor()",
"public static java.util.function.DoubleUnaryOperator log()",
"public static java.util.function.DoubleUnaryOperator neg()",
"public static java.util.function.DoubleUnaryOperator reciprocal()",
- "public static java.util.function.DoubleUnaryOperator relu()",
"public static java.util.function.DoubleUnaryOperator rsqrt()",
- "public static java.util.function.DoubleUnaryOperator selu()",
- "public static java.util.function.DoubleUnaryOperator leakyrelu()",
"public static java.util.function.DoubleUnaryOperator sin()",
"public static java.util.function.DoubleUnaryOperator sigmoid()",
"public static java.util.function.DoubleUnaryOperator sqrt()",
"public static java.util.function.DoubleUnaryOperator square()",
"public static java.util.function.DoubleUnaryOperator tan()",
"public static java.util.function.DoubleUnaryOperator tanh()",
+ "public static java.util.function.DoubleUnaryOperator elu()",
+ "public static java.util.function.DoubleUnaryOperator elu(double)",
+ "public static java.util.function.DoubleUnaryOperator leakyrelu()",
+ "public static java.util.function.DoubleUnaryOperator leakyrelu(double)",
+ "public static java.util.function.DoubleUnaryOperator relu()",
+ "public static java.util.function.DoubleUnaryOperator selu()",
+ "public static java.util.function.DoubleUnaryOperator selu(double, double)",
"public static java.util.function.Function random()",
"public static java.util.function.Function equal(java.util.List)",
"public static java.util.function.Function sum(java.util.List)"
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index e8e329cd75c..d9204e24d68 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -38,16 +38,12 @@ public class ScalarFunctions {
public static DoubleUnaryOperator atan() { return new Atan(); }
public static DoubleUnaryOperator ceil() { return new Ceil(); }
public static DoubleUnaryOperator cos() { return new Cos(); }
- public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator floor() { return new Floor(); }
public static DoubleUnaryOperator log() { return new Log(); }
public static DoubleUnaryOperator neg() { return new Neg(); }
public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); }
- public static DoubleUnaryOperator relu() { return new Relu(); }
public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
- public static DoubleUnaryOperator selu() { return new Selu(); }
- public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); }
public static DoubleUnaryOperator sin() { return new Sin(); }
public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
@@ -55,6 +51,14 @@ public class ScalarFunctions {
public static DoubleUnaryOperator tan() { return new Tan(); }
public static DoubleUnaryOperator tanh() { return new Tanh(); }
+ public static DoubleUnaryOperator elu() { return new Elu(); }
+ public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); }
+ public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); }
+ public static DoubleUnaryOperator leakyrelu(double alpha) { return new LeakyRelu(alpha); }
+ public static DoubleUnaryOperator relu() { return new Relu(); }
+ public static DoubleUnaryOperator selu() { return new Selu(); }
+ public static DoubleUnaryOperator selu(double scale, double alpha) { return new Selu(scale, alpha); }
+
public static Function<List<Long>, Double> random() { return new Random(); }
public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
@@ -191,10 +195,17 @@ public class ScalarFunctions {
}
public static class Elu implements DoubleUnaryOperator {
+ private final double alpha;
+ public Elu() {
+ this(1.0);
+ }
+ public Elu(double alpha) {
+ this.alpha = alpha;
+ }
@Override
- public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; }
+ public double applyAsDouble(double operand) { return operand < 0 ? alpha * (Math.exp(operand) - 1) : operand; }
@Override
- public String toString() { return "f(a)(if(a < 0, exp(a)-1, a))"; }
+ public String toString() { return "f(a)(if(a < 0, " + alpha + " * (exp(a)-1), a))"; }
}
public static class Exp implements DoubleUnaryOperator {
@@ -241,8 +252,15 @@ public class ScalarFunctions {
public static class Selu implements DoubleUnaryOperator {
// See https://arxiv.org/abs/1706.02515
- private static final double scale = 1.0507009873554804934193349852946;
- private static final double alpha = 1.6732632423543772848170429916717;
+ private final double scale; // 1.0507009873554804934193349852946;
+ private final double alpha; // 1.6732632423543772848170429916717;
+ public Selu() {
+ this(1.0507009873554804934193349852946, 1.6732632423543772848170429916717);
+ }
+ public Selu(double scale, double alpha) {
+ this.scale = scale;
+ this.alpha = alpha;
+ }
@Override
public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); }
@Override
@@ -250,10 +268,17 @@ public class ScalarFunctions {
}
public static class LeakyRelu implements DoubleUnaryOperator {
+ private final double alpha;
+ public LeakyRelu() {
+ this(0.01);
+ }
+ public LeakyRelu(double alpha) {
+ this.alpha = alpha;
+ }
@Override
- public double applyAsDouble(double operand) { return Math.max(0.01 * operand, operand); }
+ public double applyAsDouble(double operand) { return Math.max(alpha * operand, operand); }
@Override
- public String toString() { return "f(a)(max(0.01*a, a))"; }
+ public String toString() { return "f(a)(max(" + alpha + " * a, a))"; }
}
public static class Sin implements DoubleUnaryOperator {
diff --git a/vespalib/src/tests/delegatelist/delegatelist.cpp b/vespalib/src/tests/delegatelist/delegatelist.cpp
index ba1a2049794..070864dd85a 100644
--- a/vespalib/src/tests/delegatelist/delegatelist.cpp
+++ b/vespalib/src/tests/delegatelist/delegatelist.cpp
@@ -780,11 +780,11 @@ Test::testWaitSnapshots()
ASSERT_TRUE(pool.NewThread(&a1, 0) != 0);
s1.reset(new DL::Snapshot(dl)); // create snap 1
a1.doIt(cmd_wait_snap(&dl)); // wait for snaps
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
EXPECT_TRUE(a1.getState() == Actor::STATE_BUSY); // still waiting...
s2.reset(new DL::Snapshot(dl)); // create snap 2
s1.reset(); // destroy snap 1
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
EXPECT_TRUE(a1.getState() == Actor::STATE_IDLE); // wait done!
a1.doIt(cmd_exit());
a1.waitState(Actor::STATE_DONE);
diff --git a/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp b/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp
index c43d0ec1c29..7567e8426ae 100644
--- a/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp
+++ b/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp
@@ -5,7 +5,12 @@
#include <vespa/vespalib/util/inline.h>
#include <vespa/fastos/timestamp.h>
-using namespace vespalib;
+using vespalib::RightArrayHeap;
+using vespalib::RightHeap;
+using vespalib::LeftArrayHeap;
+using vespalib::LeftHeap;
+using vespalib::LeftStdHeap;
+using vespalib::make_string;
template <typename H> struct IsRight { enum { VALUE = 0 }; };
template <> struct IsRight<RightHeap> { enum { VALUE = 1 }; };
diff --git a/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp b/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp
index 7ca4a2eff39..5641d751f34 100644
--- a/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp
+++ b/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp
@@ -47,7 +47,7 @@ TEST_MT_FF("require that signals can be counted and cancelled", 2, Signal, size_
if (thread_id == 0) {
for (size_t i = 0; i < f2; ++i) {
f1.send();
- if (i % 128 == 0) { FastOS_Thread::Sleep(1); }
+ if (i % 128 == 0) { std::this_thread::sleep_for(1ms); }
}
TEST_BARRIER();
f1.cancel();
diff --git a/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp b/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp
index 5b6df6eef4e..d67c417b71a 100644
--- a/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp
+++ b/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp
@@ -65,7 +65,7 @@ TEST("estimate cost of thread bundle fork/join") {
if (time < minTime) {
minTime = time;
}
- FastOS_Thread::Sleep(10);
+ std::this_thread::sleep_for(10ms);
}
fprintf(stderr, "strategy: %s, threads: %zu, fork: %zu, iter: %zu, time: %g, unit: %g\n",
strategy_name[strategy].c_str(), threads, fork, iter, minTime,
diff --git a/vespalib/src/tests/thread/thread_test.cpp b/vespalib/src/tests/thread/thread_test.cpp
index 025a33fa221..bcd38190c7e 100644
--- a/vespalib/src/tests/thread/thread_test.cpp
+++ b/vespalib/src/tests/thread/thread_test.cpp
@@ -32,7 +32,7 @@ TEST("normal operation") {
{
Thread thread(agent);
thread.start();
- FastOS_Thread::Sleep(20);
+ std::this_thread::sleep_for(20ms);
thread.stop().join();
}
EXPECT_TRUE(agent.started);
diff --git a/vespalib/src/vespa/vespalib/testkit/test_kit.h b/vespalib/src/vespa/vespalib/testkit/test_kit.h
index 7e6b07d71df..17746c5b0fc 100644
--- a/vespalib/src/vespa/vespalib/testkit/test_kit.h
+++ b/vespalib/src/vespa/vespalib/testkit/test_kit.h
@@ -10,3 +10,4 @@
#include "test_hook.h"
#include "test_state_guard.h"
#include "time_bomb.h"
+#include <vespa/vespalib/util/time.h>
diff --git a/vespalib/src/vespa/vespalib/util/thread.cpp b/vespalib/src/vespa/vespalib/util/thread.cpp
index 2d0118645ab..4eb436458a2 100644
--- a/vespalib/src/vespa/vespalib/util/thread.cpp
+++ b/vespalib/src/vespa/vespalib/util/thread.cpp
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "thread.h"
+#include <thread>
namespace vespalib {
@@ -87,7 +88,7 @@ Thread::currentThread()
void
Thread::sleep(size_t ms)
{
- FastOS_Thread::Sleep(ms);
+ std::this_thread::sleep_for(std::chrono::milliseconds(ms));
}
} // namespace vespalib
diff --git a/vespamalloc/src/tests/allocfree/allocfree.cpp b/vespamalloc/src/tests/allocfree/allocfree.cpp
index 80513579a2f..86050d4aee9 100644
--- a/vespamalloc/src/tests/allocfree/allocfree.cpp
+++ b/vespamalloc/src/tests/allocfree/allocfree.cpp
@@ -89,7 +89,7 @@ int Test::Main() {
for (; duration > 0; --duration) {
LOG(info, "%d seconds left...", duration);
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
pool.Close();
size_t numFreeOperations(0);
diff --git a/vespamalloc/src/tests/allocfree/linklist.cpp b/vespamalloc/src/tests/allocfree/linklist.cpp
index 11a8d1ddd11..74af380458a 100644
--- a/vespamalloc/src/tests/allocfree/linklist.cpp
+++ b/vespamalloc/src/tests/allocfree/linklist.cpp
@@ -163,7 +163,7 @@ int Test::Main() {
for (; duration > 0; --duration) {
LOG(info, "%d seconds left...", duration);
- FastOS_Thread::Sleep(1000);
+ std::this_thread::sleep_for(1s);
}
pool.Close();
fprintf(stderr, "Did (%lu + %lu) = %lu linkIn operations\n",
diff --git a/zkfacade/abi-spec.json b/zkfacade/abi-spec.json
index efe6fbdaa08..25b652b7312 100644
--- a/zkfacade/abi-spec.json
+++ b/zkfacade/abi-spec.json
@@ -68,7 +68,6 @@
"methods": [
"public static com.yahoo.vespa.curator.Curator create(java.lang.String)",
"public static com.yahoo.vespa.curator.Curator create(java.lang.String, java.util.Optional)",
- "public void <init>(com.yahoo.cloud.config.ConfigserverConfig)",
"public void <init>(com.yahoo.cloud.config.ConfigserverConfig, com.yahoo.vespa.zookeeper.VespaZooKeeperServer)",
"protected void <init>(java.lang.String, java.lang.String, java.util.function.Function)",
"public java.lang.String connectionSpec()",
diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java
index b76bad5b97b..9d74306d3d5 100644
--- a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java
+++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java
@@ -3,9 +3,12 @@ package com.yahoo.vespa.curator;
import com.google.inject.Inject;
import com.yahoo.cloud.config.ConfigserverConfig;
+import com.yahoo.io.IOUtils;
import com.yahoo.net.HostName;
import com.yahoo.path.Path;
+import com.yahoo.text.Utf8;
import com.yahoo.vespa.curator.recipes.CuratorCounter;
+import com.yahoo.vespa.defaults.Defaults;
import com.yahoo.vespa.zookeeper.VespaZooKeeperServer;
import org.apache.curator.RetryPolicy;
import org.apache.curator.framework.CuratorFramework;
@@ -20,7 +23,9 @@ import org.apache.curator.framework.recipes.locks.InterProcessLock;
import org.apache.curator.framework.recipes.locks.InterProcessMutex;
import org.apache.curator.retry.ExponentialBackoffRetry;
import org.apache.zookeeper.KeeperException;
+import org.apache.zookeeper.client.ZKClientConfig;
import org.apache.zookeeper.data.Stat;
+import org.apache.zookeeper.server.quorum.QuorumPeerConfig;
import java.io.File;
import java.time.Duration;
@@ -63,28 +68,21 @@ public class Curator implements AutoCloseable {
/** Creates a curator instance from a comma-separated string of ZooKeeper host:port strings */
public static Curator create(String connectionSpec, Optional<File> clientConfigFile) {
- return new Curator(connectionSpec, connectionSpec);
- }
-
- // For testing
- public Curator(ConfigserverConfig configserverConfig) {
- this(configserverConfig, createConnectionSpec(configserverConfig));
+ return new Curator(connectionSpec, connectionSpec, clientConfigFile);
}
// Depend on ZooKeeperServer to make sure it is started first
// TODO: Move zookeeperserver config out of configserverconfig (requires update of controller services.xml as well)
@Inject
public Curator(ConfigserverConfig configserverConfig, VespaZooKeeperServer server) {
- this(configserverConfig, createConnectionSpec(configserverConfig));
+ this(configserverConfig, Optional.empty());
}
- private Curator(ConfigserverConfig configserverConfig, String zooKeeperEnsembleConnectionSpec) {
- this((configserverConfig.zookeeperLocalhostAffinity()) ?
- createConnectionSpecForLocalhost(configserverConfig) : zooKeeperEnsembleConnectionSpec,
- zooKeeperEnsembleConnectionSpec);
+ Curator(ConfigserverConfig configserverConfig, Optional<File> clientConfigFile) {
+ this(createConnectionSpec(configserverConfig), createEnsembleConnectionSpec(configserverConfig), clientConfigFile);
}
- private Curator(String connectionSpec, String zooKeeperEnsembleConnectionSpec) {
+ private Curator(String connectionSpec, String zooKeeperEnsembleConnectionSpec, Optional<File> clientConfigFile) {
this(connectionSpec,
zooKeeperEnsembleConnectionSpec,
(retryPolicy) -> CuratorFrameworkFactory
@@ -93,7 +91,7 @@ public class Curator implements AutoCloseable {
.sessionTimeoutMs(ZK_SESSION_TIMEOUT)
.connectionTimeoutMs(ZK_CONNECTION_TIMEOUT)
.connectString(connectionSpec)
- .zookeeperFactory(new VespaZooKeeperFactory())
+ .zookeeperFactory(new VespaZooKeeperFactory(createClientConfig(clientConfigFile)))
.dontUseContainerParents() // TODO: Remove when we know ZooKeeper 3.5 works fine, consider waiting until Vespa 8
.build());
}
@@ -123,7 +121,29 @@ public class Curator implements AutoCloseable {
this.zooKeeperEnsembleCount = zooKeeperEnsembleConnectionSpec.split(",").length;
}
- private static String createConnectionSpec(ConfigserverConfig config) {
+ private static String createConnectionSpec(ConfigserverConfig configserverConfig) {
+ return configserverConfig.zookeeperLocalhostAffinity()
+ ? createConnectionSpecForLocalhost(configserverConfig)
+ : createEnsembleConnectionSpec(configserverConfig);
+ }
+
+ private static ZKClientConfig createClientConfig(Optional<File> file) {
+ boolean useSecureClient = Boolean.parseBoolean(getEnvironmentVariable("VESPA_USE_TLS_FOR_ZOOKEEPER_CLIENT").orElse("false"));
+ String config = "zookeeper.client.secure=" + useSecureClient + "\n";
+
+ File clientConfigFile =
+ file.orElseGet(() -> new File(Defaults.getDefaults().underVespaHome("conf/zookeeper/zookeeper-client.cfg")));
+ clientConfigFile.getParentFile().mkdirs();
+ IOUtils.writeFile(clientConfigFile, Utf8.toBytes(config));
+
+ try {
+ return new ZKClientConfig(clientConfigFile);
+ } catch (QuorumPeerConfig.ConfigException e) {
+ throw new RuntimeException("Unable to create ZooKeeper client config file " + file);
+ }
+ }
+
+ private static String createEnsembleConnectionSpec(ConfigserverConfig config) {
StringBuilder connectionSpec = new StringBuilder();
for (int i = 0; i < config.zookeeperserver().size(); i++) {
if (connectionSpec.length() > 0) {
@@ -405,4 +425,10 @@ public class Curator implements AutoCloseable {
* TODO: Move method out of this class.
*/
public int zooKeeperEnsembleCount() { return zooKeeperEnsembleCount; }
+
+ private static Optional<String> getEnvironmentVariable(String variableName) {
+ return Optional.ofNullable(System.getenv().get(variableName))
+ .filter(var -> !var.isEmpty());
+ }
+
}
diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java
index 7c08168c536..84e2cb65a1a 100644
--- a/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java
+++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java
@@ -4,19 +4,24 @@ package com.yahoo.vespa.curator;
import org.apache.curator.utils.ZookeeperFactory;
import org.apache.zookeeper.Watcher;
import org.apache.zookeeper.ZooKeeper;
+import org.apache.zookeeper.client.ZKClientConfig;
/**
* A ZooKeeper factory for creating a ZooKeeper client
*
* @author hmusum
*/
-// TODO: add constructor that takes feature flag so that we can write ZooKeeper client config and start
-// ZooKeeper client with that config
class VespaZooKeeperFactory implements ZookeeperFactory {
+ private final ZKClientConfig zkClientConfig;
+
+ VespaZooKeeperFactory(ZKClientConfig zkClientConfig) {
+ this.zkClientConfig = zkClientConfig;
+ }
+
@Override
public ZooKeeper newZooKeeper(String connectString, int sessionTimeout, Watcher watcher, boolean canBeReadOnly) throws Exception {
- return new ZooKeeper(connectString, sessionTimeout, watcher);
+ return new ZooKeeper(connectString, sessionTimeout, watcher, zkClientConfig);
}
}
diff --git a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java
index 9b7e0250f2f..6b85953a1ff 100644
--- a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java
+++ b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java
@@ -8,7 +8,6 @@ import static org.junit.Assert.assertEquals;
/**
* @author Ulf Lilleengen
- * @date 19.08.13
*/
public class CuratorCounterTest {
diff --git a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java
index a8342dfe5bc..4cd2c708d1a 100644
--- a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java
+++ b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java
@@ -9,6 +9,8 @@ import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
+import java.nio.file.Files;
+import java.util.Optional;
import static org.hamcrest.core.Is.is;
import static org.junit.Assert.assertThat;
@@ -52,7 +54,7 @@ public class CuratorTest {
}
@Test
- public void require_curator_is_created_from_config() {
+ public void require_curator_is_created_from_config() throws IOException {
try (Curator curator = createCurator(createTestConfig())) {
assertThat(curator.zooKeeperEnsembleConnectionSpec(), is(spec1 + "," + spec2));
assertThat(curator.zooKeeperEnsembleCount(), is(2));
@@ -60,7 +62,7 @@ public class CuratorTest {
}
@Test
- public void require_that_server_count_is_correct() {
+ public void require_that_server_count_is_correct() throws IOException {
ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder();
builder.zookeeperserver(createZKBuilder(localhost, port1));
try (Curator curator = createCurator(new ConfigserverConfig(builder))) {
@@ -98,8 +100,8 @@ public class CuratorTest {
return zkBuilder;
}
- private Curator createCurator(ConfigserverConfig configserverConfig) {
- return new Curator(configserverConfig);
+ private Curator createCurator(ConfigserverConfig configserverConfig) throws IOException {
+ return new Curator(configserverConfig, Optional.of(Files.createTempFile("zookeeper-client", "cfg").toFile()));
}
private static class PortAllocator {