diff options
149 files changed, 1305 insertions, 550 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/ConfiguredFilebasedSslProvider.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/ConfiguredFilebasedSslProvider.java index 4f84a01ff94..4a331718985 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/ConfiguredFilebasedSslProvider.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/ConfiguredFilebasedSslProvider.java @@ -8,6 +8,7 @@ import com.yahoo.jdisc.http.ssl.impl.ConfiguredSslContextFactoryProvider; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.vespa.model.container.component.SimpleComponent; +import java.util.List; import java.util.Optional; import static com.yahoo.component.ComponentSpecification.fromString; @@ -16,6 +17,7 @@ import static com.yahoo.component.ComponentSpecification.fromString; * Configure SSL using file references * * @author mortent + * @author bjorncs */ public class ConfiguredFilebasedSslProvider extends SimpleComponent implements ConnectorConfig.Producer { public static final String COMPONENT_ID_PREFIX = "configured-ssl-provider@"; @@ -26,8 +28,16 @@ public class ConfiguredFilebasedSslProvider extends SimpleComponent implements C private final String certificatePath; private final String caCertificatePath; private final ConnectorConfig.Ssl.ClientAuth.Enum clientAuthentication; + private final List<String> cipherSuites; + private final List<String> protocolVersions; - public ConfiguredFilebasedSslProvider(String servername, String privateKeyPath, String certificatePath, String caCertificatePath, String clientAuthentication) { + public ConfiguredFilebasedSslProvider(String servername, + String privateKeyPath, + String certificatePath, + String caCertificatePath, + String clientAuthentication, + List<String> cipherSuites, + List<String> protocolVersions) { super(new ComponentModel( new BundleInstantiationSpecification(new ComponentId(COMPONENT_ID_PREFIX+servername), fromString(COMPONENT_CLASS), @@ -36,15 +46,21 @@ public class ConfiguredFilebasedSslProvider extends SimpleComponent implements C this.certificatePath = certificatePath; this.caCertificatePath = caCertificatePath; this.clientAuthentication = mapToConfigEnum(clientAuthentication); + this.cipherSuites = cipherSuites; + this.protocolVersions = protocolVersions; } @Override public void getConfig(ConnectorConfig.Builder builder) { - builder.ssl.enabled(true); - builder.ssl.privateKeyFile(privateKeyPath); - builder.ssl.certificateFile(certificatePath); - builder.ssl.caCertificateFile(Optional.ofNullable(caCertificatePath).orElse("")); - builder.ssl.clientAuth(clientAuthentication); + builder.ssl( + new ConnectorConfig.Ssl.Builder() + .enabled(true) + .privateKeyFile(privateKeyPath) + .certificateFile(certificatePath) + .caCertificateFile(Optional.ofNullable(caCertificatePath).orElse("")) + .clientAuth(clientAuthentication) + .enabledCipherSuites(cipherSuites) + .enabledProtocols(protocolVersions)); } public SimpleComponent getComponent() { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/JettyConnectorBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/JettyConnectorBuilder.java index 1b457b1250a..499268929b7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/JettyConnectorBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/xml/JettyConnectorBuilder.java @@ -8,13 +8,17 @@ import com.yahoo.text.XML; import com.yahoo.vespa.model.builder.xml.dom.VespaDomBuilder; import com.yahoo.vespa.model.container.component.SimpleComponent; import com.yahoo.vespa.model.container.http.ConnectorFactory; -import com.yahoo.vespa.model.container.http.ssl.CustomSslProvider; import com.yahoo.vespa.model.container.http.ssl.ConfiguredFilebasedSslProvider; +import com.yahoo.vespa.model.container.http.ssl.CustomSslProvider; import com.yahoo.vespa.model.container.http.ssl.DefaultSslProvider; import org.w3c.dom.Element; +import java.util.Arrays; +import java.util.List; import java.util.Optional; +import static java.util.stream.Collectors.toList; + /** * @author Einar M R Rosenvinge * @author mortent @@ -39,12 +43,16 @@ public class JettyConnectorBuilder extends VespaDomBuilder.DomConfigProducerBuil String certificateFile = XML.getValue(XML.getChild(sslConfigurator, "certificate-file")); Optional<String> caCertificateFile = XmlHelper.getOptionalChildValue(sslConfigurator, "ca-certificates-file"); Optional<String> clientAuthentication = XmlHelper.getOptionalChildValue(sslConfigurator, "client-authentication"); + List<String> cipherSuites = extractOptionalCommaSeparatedList(sslConfigurator, "cipher-suites"); + List<String> protocols = extractOptionalCommaSeparatedList(sslConfigurator, "protocols"); return new ConfiguredFilebasedSslProvider( serverName, privateKeyFile, certificateFile, caCertificateFile.orElse(null), - clientAuthentication.orElse(null)); + clientAuthentication.orElse(null), + cipherSuites, + protocols); } else if (sslProviderConfigurator != null) { String className = sslProviderConfigurator.getAttribute("class"); String bundle = sslProviderConfigurator.getAttribute("bundle"); @@ -53,4 +61,14 @@ public class JettyConnectorBuilder extends VespaDomBuilder.DomConfigProducerBuil return new DefaultSslProvider(serverName); } } + + private static List<String> extractOptionalCommaSeparatedList(Element sslElement, String listElementName) { + return XmlHelper.getOptionalChildValue(sslElement, listElementName) + .map(element -> + Arrays.stream(element.split(",")) + .filter(listEntry -> !listEntry.isBlank()) + .map(String::trim) + .collect(toList())) + .orElse(List.of()); + } } diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc index 142abb5c63b..a8228a233b3 100644 --- a/config-model/src/main/resources/schema/containercluster.rnc +++ b/config-model/src/main/resources/schema/containercluster.rnc @@ -95,7 +95,9 @@ Ssl = element ssl { element private-key-file { string } & element certificate-file { string } & element ca-certificates-file { string }? & - element client-authentication { string "disabled" | string "want" | string "need" }? + element client-authentication { string "disabled" | string "want" | string "need" }? & + element cipher-suites { string }? & + element protocols { string }? } SslProvider = element ssl-provider { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JettyContainerModelBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JettyContainerModelBuilderTest.java index 929e520f984..4679377ce94 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JettyContainerModelBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JettyContainerModelBuilderTest.java @@ -152,6 +152,14 @@ public class JettyContainerModelBuilderTest extends ContainerModelBuilderTestBas " <client-authentication>need</client-authentication>", " </ssl>", " </server>", + " <server port='9003' id='with-ciphers-and-protocols'>", + " <ssl>", + " <private-key-file>/foo/key</private-key-file>", + " <certificate-file>/foo/cert</certificate-file>", + " <cipher-suites>TLS_AES_128_GCM_SHA256,TLS_AES_256_GCM_SHA384</cipher-suites>", + " <protocols>TLSv1.3</protocols>", + " </ssl>", + " </server>", " </http>", nodesXml, "", @@ -179,6 +187,13 @@ public class JettyContainerModelBuilderTest extends ContainerModelBuilderTestBas assertThat(needClientAuth.ssl().caCertificateFile(), is(equalTo(""))); assertThat(needClientAuth.ssl().clientAuth(), is(equalTo(ConnectorConfig.Ssl.ClientAuth.Enum.NEED_AUTH))); + ConnectorConfig withCiphersAndProtocols = root.getConfig(ConnectorConfig.class, "default/http/jdisc-jetty/with-ciphers-and-protocols/configured-ssl-provider@with-ciphers-and-protocols"); + assertTrue(withCiphersAndProtocols.ssl().enabled()); + assertThat(withCiphersAndProtocols.ssl().privateKeyFile(), is(equalTo("/foo/key"))); + assertThat(withCiphersAndProtocols.ssl().certificateFile(), is(equalTo("/foo/cert"))); + assertThat(withCiphersAndProtocols.ssl().enabledCipherSuites(), is(equalTo(List.of("TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384")))); + assertThat(withCiphersAndProtocols.ssl().enabledProtocols(), is(equalTo(List.of("TLSv1.3")))); + ContainerCluster cluster = (ContainerCluster) root.getChildren().get("default"); List<ConnectorFactory> connectorFactories = cluster.getChildrenByTypeRecursive(ConnectorFactory.class); connectorFactories.forEach(connectorFactory -> assertChildComponentExists(connectorFactory, ConfiguredFilebasedSslProvider.COMPONENT_CLASS)); diff --git a/config-model/src/test/schema-test-files/services.xml b/config-model/src/test/schema-test-files/services.xml index 2bbd98f72ac..1bf42650123 100644 --- a/config-model/src/test/schema-test-files/services.xml +++ b/config-model/src/test/schema-test-files/services.xml @@ -119,6 +119,13 @@ <certificate-file>/foo/cert</certificate-file> <ca-certificates-file>/foo/cacerts</ca-certificates-file> <client-authentication>want</client-authentication> + <cipher-suites> + TLS_AES_128_GCM_SHA256, + TLS_AES_256_GCM_SHA384, + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + </cipher-suites> + <protocols>TLSv1.2,TLSv1.3</protocols> </ssl> </server> <server port="4083" id="sslProvider"> diff --git a/config/src/tests/api/api.cpp b/config/src/tests/api/api.cpp index 4db66761444..0af2b848ea5 100644 --- a/config/src/tests/api/api.cpp +++ b/config/src/tests/api/api.cpp @@ -32,7 +32,7 @@ TEST_MT_FFF("require that source may be unable to serve config temporarily", 2, ASSERT_TRUE(cfg.get() != NULL); ASSERT_EQUAL("myfoo", cfg->myField); } else { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); f3.myField = "myfoo"; f2.addBuilder("myid", &f3); f1->reload(); diff --git a/config/src/tests/configfetcher/configfetcher.cpp b/config/src/tests/configfetcher/configfetcher.cpp index 607ab0a29a5..be25e913980 100644 --- a/config/src/tests/configfetcher/configfetcher.cpp +++ b/config/src/tests/configfetcher/configfetcher.cpp @@ -69,7 +69,7 @@ TEST("requireThatConfigUpdatesArePerformed") { while (!cb._configured && timer.elapsed().ms() < 20000.0) { if (cb._configured) break; - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } ASSERT_TRUE(cb._configured); ASSERT_TRUE(cb._config); diff --git a/config/src/tests/configretriever/configretriever.cpp b/config/src/tests/configretriever/configretriever.cpp index fc921a324af..87f189ad7d3 100644 --- a/config/src/tests/configretriever/configretriever.cpp +++ b/config/src/tests/configretriever/configretriever.cpp @@ -251,7 +251,7 @@ public: if (configured) { return true; } - FastOS_Thread::Sleep(200); + std::this_thread::sleep_for(200ms); } return configured; } diff --git a/config/src/tests/frt/frt.cpp b/config/src/tests/frt/frt.cpp index ba8279a1999..28dea82bfe7 100644 --- a/config/src/tests/frt/frt.cpp +++ b/config/src/tests/frt/frt.cpp @@ -49,7 +49,7 @@ namespace { while (timer.elapsed().ms() < timeoutInMillis) { if (notified) break; - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return notified; } @@ -260,7 +260,7 @@ TEST_FF("require that request is config task is scheduled", SourceFixture(), FRT f1.conn.scheduler.CheckTasks(); if (f2.result.notified) break; - FastOS_Thread::Sleep(500); + std::this_thread::sleep_for(500ms); } ASSERT_TRUE(f2.result.notified); f2.src.close(); diff --git a/config/src/tests/trace/trace.cpp b/config/src/tests/trace/trace.cpp index 33e25fa7ba2..fdb40d40893 100644 --- a/config/src/tests/trace/trace.cpp +++ b/config/src/tests/trace/trace.cpp @@ -2,7 +2,6 @@ #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/config/common/trace.h> -#include <vespa/vespalib/trace/tracenode.h> using namespace config; @@ -11,9 +10,9 @@ using namespace vespalib::slime; struct FixedClock : public Clock { - FixedClock() : currentTime(0) { } - int64_t currentTime; - int64_t currentTimeMillis() const override { return currentTime; } + FixedClock() : _currentTime(duration::zero()) { } + vespalib::system_time _currentTime; + vespalib::system_time currentTime() const override { return _currentTime; } }; TEST("that trace can be serialized and deserialized") { @@ -38,7 +37,7 @@ TEST("that trace can be serialized and deserialized") { } TEST_F("that trace level is taken into account", FixedClock) { - f1.currentTime = 3; + f1._currentTime = vespalib::system_time(3ms); Trace trace(4, f1); trace.trace(4, "foo"); trace.trace(5, "bar"); @@ -58,11 +57,13 @@ TEST("that trace can be copied") { EXPECT_EQUAL(trace.toString(), trace2.toString()); } +constexpr vespalib::system_time epoch(duration::zero()); + TEST("ensure that system clock is used by default") { Trace trace(2); trace.trace(1, "foo"); TraceNode child(trace.getRoot().getChild(0)); - EXPECT_TRUE(child.getTimestamp() > 0); + EXPECT_TRUE(child.getTimestamp() > epoch); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/config/src/vespa/config/common/trace.cpp b/config/src/vespa/config/common/trace.cpp index e6183a9fec1..d1bb154eda9 100644 --- a/config/src/vespa/config/common/trace.cpp +++ b/config/src/vespa/config/common/trace.cpp @@ -2,7 +2,6 @@ #include "trace.h" #include <vespa/vespalib/trace/slime_trace_serializer.h> #include <vespa/vespalib/trace/slime_trace_deserializer.h> -#include <vespa/fastos/timestamp.h> using namespace vespalib; using namespace vespalib::slime; @@ -11,8 +10,8 @@ namespace config { struct SystemClock : public Clock { - int64_t currentTimeMillis() const override { - return fastos::ClockSystem::now().timeSinceEpoch().ms(); + vespalib::system_time currentTime() const override { + return vespalib::system_clock::now(); } }; @@ -73,7 +72,7 @@ void Trace::trace(uint32_t level, const vespalib::string & message) { if (shouldTrace(level)) { - _root.addChild(message, _clock.currentTimeMillis()); + _root.addChild(message, _clock.currentTime()); } } diff --git a/config/src/vespa/config/common/trace.h b/config/src/vespa/config/common/trace.h index 9cfd2b1f88e..772cdb6f31e 100644 --- a/config/src/vespa/config/common/trace.h +++ b/config/src/vespa/config/common/trace.h @@ -12,7 +12,7 @@ namespace config { * Clock interface for acquiring time. */ struct Clock { - virtual int64_t currentTimeMillis() const = 0; + virtual vespalib::system_time currentTime() const = 0; virtual ~Clock() {} }; diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java index 14cba858857..26bf189dd3d 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java @@ -33,7 +33,8 @@ enum PathGroup { "/zone/v2/{*}"), /** Paths used for creating and reading user resources. */ - user("/application/v4/user", + user(Optional.of("/api"), + "/application/v4/user", "/athenz/v1/{*}"), /** Paths used for creating tenants with proper access control. */ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java index 26c4bf6292a..d10a4879bf5 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiHandler.java @@ -37,6 +37,7 @@ import java.util.logging.Logger; public class AthenzApiHandler extends LoggingRequestHandler { private final static Logger log = Logger.getLogger(AthenzApiHandler.class.getName()); + private static final String OPTIONAL_PREFIX = "/api"; private final AthenzFacade athenz; private final AthenzDomain sandboxDomain; @@ -69,7 +70,7 @@ public class AthenzApiHandler extends LoggingRequestHandler { } private HttpResponse get(HttpRequest request) { - Path path = new Path(request.getUri()); + Path path = new Path(request.getUri(), OPTIONAL_PREFIX); if (path.matches("/athenz/v1")) return root(request); if (path.matches("/athenz/v1/domains")) return domainList(request); if (path.matches("/athenz/v1/properties")) return properties(); @@ -79,7 +80,7 @@ public class AthenzApiHandler extends LoggingRequestHandler { } private HttpResponse post(HttpRequest request) { - Path path = new Path(request.getUri()); + Path path = new Path(request.getUri(), OPTIONAL_PREFIX); if (path.matches("/athenz/v1/user")) return signup(request); return ErrorResponse.notFoundError(String.format("No '%s' handler at '%s'", request.getMethod(), request.getUri().getPath())); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java index d37df2cc313..c263054c808 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/ControllerContainerTest.java @@ -127,6 +127,7 @@ public class ControllerContainerTest { " </handler>\n" + " <handler id='com.yahoo.vespa.hosted.controller.restapi.athenz.AthenzApiHandler'>\n" + " <binding>http://*/athenz/v1/*</binding>\n" + + " <binding>http://*/api/athenz/v1/*</binding>\n" + " </handler>\n" + " <handler id='com.yahoo.vespa.hosted.controller.restapi.zone.v1.ZoneApiHandler'>\n" + " <binding>http://*/zone/v1</binding>\n" + diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java index c90dcbf7e2b..34ee160c449 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/athenz/AthenzApiTest.java @@ -48,7 +48,7 @@ public class AthenzApiTest extends ControllerContainerTest { new File("property-list.json")); // POST user signup - tester.assertResponse(authenticatedRequest("http://localhost:8080/athenz/v1/user", "", Request.Method.POST), + tester.assertResponse(authenticatedRequest("http://localhost:8080/api/athenz/v1/user", "", Request.Method.POST), "{\"message\":\"User 'bob' added to admin role of 'vespa.vespa.tenants.sandbox'\"}"); } diff --git a/documentapi/src/tests/policies/policies_test.cpp b/documentapi/src/tests/policies/policies_test.cpp index 93c5d51fef5..0c3a39186a7 100644 --- a/documentapi/src/tests/policies/policies_test.cpp +++ b/documentapi/src/tests/policies/policies_test.cpp @@ -284,7 +284,7 @@ Test::assertMirrorReady(const slobrok::api::IMirrorAPI &mirror) if (mirror.ready()) { return; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } ASSERT_TRUE(false); } @@ -297,7 +297,7 @@ Test::assertMirrorContains(const slobrok::api::IMirrorAPI &mirror, const string if (mirror.lookup(pattern).size() == numEntries) { return; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } ASSERT_TRUE(false); } diff --git a/documentapi/src/tests/policies/testframe.cpp b/documentapi/src/tests/policies/testframe.cpp index 4cdc5d4ba14..1ca449816d9 100644 --- a/documentapi/src/tests/policies/testframe.cpp +++ b/documentapi/src/tests/policies/testframe.cpp @@ -8,6 +8,8 @@ #include <vespa/messagebus/testlib/simpleprotocol.h> #include <vespa/messagebus/testlib/simplereply.h> #include <vespa/messagebus/network/rpcnetworkparams.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".testframe"); @@ -297,7 +299,7 @@ TestFrame::waitSlobrok(const string &pattern, uint32_t cnt) if (res.size() == cnt) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } LOG(error, "Slobrok failed to resolve '%s' to %d recipients in time.", pattern.c_str(), cnt); return false; diff --git a/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp b/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp index 18dd525b066..e82a184d8b2 100644 --- a/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp +++ b/documentapi/src/vespa/documentapi/messagebus/policies/externslobrokpolicy.cpp @@ -1,12 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "externslobrokpolicy.h" -#include <vespa/vespalib/text/stringtokenizer.h> #include <vespa/messagebus/routing/routingcontext.h> +#include <vespa/vespalib/text/stringtokenizer.h> +#include <vespa/vespalib/util/time.h> #include <vespa/fnet/frt/frt.h> #include <vespa/slobrok/sbmirror.h> #include <vespa/fnet/transport.h> -#include <vespa/fastos/thread.h> +#include <thread> using slobrok::api::IMirrorAPI; using slobrok::api::MirrorAPI; @@ -82,7 +83,7 @@ ExternSlobrokPolicy::lookup(mbus::RoutingContext& context, const string& pattern if (_firstTry) { int count = 0; while (entries.empty() && count < 100) { - FastOS_Thread::Sleep(50); + std::this_thread::sleep_for(50ms); entries = mirror.lookup(pattern); count++; } diff --git a/fastos/src/tests/processtest.cpp b/fastos/src/tests/processtest.cpp index a6729dbb783..5a78eff1d36 100644 --- a/fastos/src/tests/processtest.cpp +++ b/fastos/src/tests/processtest.cpp @@ -3,6 +3,8 @@ #include <vespa/fastos/process.h> #include <vespa/fastos/timestamp.h> +using namespace std::chrono_literals; + class MyListener : public FastOS_ProcessRedirectListener { private: @@ -119,7 +121,7 @@ public: xproc->WriteStdin(nullptr, 0); } - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } if(i == 10) diff --git a/fastos/src/tests/thread_bounce_test.cpp b/fastos/src/tests/thread_bounce_test.cpp index f7bb7ee1260..84506938455 100644 --- a/fastos/src/tests/thread_bounce_test.cpp +++ b/fastos/src/tests/thread_bounce_test.cpp @@ -43,7 +43,7 @@ class Thread_Bounce_Test : public ThreadTestBase int left = static_cast<int>(checkTime.elapsed().ms()); while (left < 1000) { - FastOS_Thread::Sleep(1000 - left); + std::this_thread::sleep_for(std::chrono::milliseconds(1000 - left)); left = static_cast<int>(checkTime.elapsed().ms()); } diff --git a/fastos/src/tests/thread_mutex_test.cpp b/fastos/src/tests/thread_mutex_test.cpp index d49cf37163d..6d3f8c3c5f0 100644 --- a/fastos/src/tests/thread_mutex_test.cpp +++ b/fastos/src/tests/thread_mutex_test.cpp @@ -132,7 +132,7 @@ class Thread_Mutex_Test : public ThreadTestBase { bool lockrc; - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); for(int i=0; i<5; i++) { @@ -145,7 +145,7 @@ class Thread_Mutex_Test : public ThreadTestBase } } - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); lockrc = mtx.try_lock(); Progress(lockrc, "We should get the mutex lock now (%s)", diff --git a/fastos/src/tests/thread_sleep_test.cpp b/fastos/src/tests/thread_sleep_test.cpp index 7fd3412b7c3..209b7d3f880 100644 --- a/fastos/src/tests/thread_sleep_test.cpp +++ b/fastos/src/tests/thread_sleep_test.cpp @@ -20,7 +20,7 @@ class Thread_Sleep_Test : public ThreadTestBase Progress(rc, "Creating Thread"); Progress(true, "Sleeping 3 seconds"); - FastOS_Thread::Sleep(3000); + std::this_thread::sleep_for(3s); } Progress(true, "Closing threadpool..."); diff --git a/fastos/src/tests/thread_stats_test.cpp b/fastos/src/tests/thread_stats_test.cpp index 3633c12bcaa..a9d304d411f 100644 --- a/fastos/src/tests/thread_stats_test.cpp +++ b/fastos/src/tests/thread_stats_test.cpp @@ -31,7 +31,7 @@ class Thread_Stats_Test : public ThreadTestBase job[0].ownThread = pool.NewThread(this, static_cast<void *>(&job[0])); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads); @@ -44,7 +44,7 @@ class Thread_Stats_Test : public ThreadTestBase job[1].ownThread = pool.NewThread(this, static_cast<void *>(&job[1])); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads); @@ -57,7 +57,7 @@ class Thread_Stats_Test : public ThreadTestBase job[0].ownThread->SetBreakFlag(); job[1].ownThread->SetBreakFlag(); - FastOS_Thread::Sleep(3000); + std::this_thread::sleep_for(3s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 2, "Inactive threads = %d", inactiveThreads); @@ -72,7 +72,7 @@ class Thread_Stats_Test : public ThreadTestBase job[0].code = WAIT_FOR_BREAK_FLAG; job[0].ownThread = pool.NewThread(this, static_cast<void *>(&job[0])); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 1, "Inactive threads = %d", inactiveThreads); @@ -84,7 +84,7 @@ class Thread_Stats_Test : public ThreadTestBase job[1].code = WAIT_FOR_BREAK_FLAG; job[1].ownThread = pool.NewThread(this, static_cast<void *>(&job[1])); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads); @@ -97,7 +97,7 @@ class Thread_Stats_Test : public ThreadTestBase job[0].ownThread->SetBreakFlag(); job[1].ownThread->SetBreakFlag(); - FastOS_Thread::Sleep(3000); + std::this_thread::sleep_for(3s); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 2, "Inactive threads = %d", inactiveThreads); diff --git a/fastos/src/tests/thread_test_base.hpp b/fastos/src/tests/thread_test_base.hpp index 7966e95b369..c4f7ed76ea7 100644 --- a/fastos/src/tests/thread_test_base.hpp +++ b/fastos/src/tests/thread_test_base.hpp @@ -3,6 +3,7 @@ #pragma once #include <chrono> +#include <thread> static volatile int64_t number; #define INCREASE_NUMBER_AMOUNT 10000 @@ -47,7 +48,7 @@ public: } } - FastOS_Thread::Sleep(500); + std::this_thread::sleep_for(500ms); if(threadsFinished) break; @@ -88,7 +89,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) Progress(true, "Thread printing message: [%s]", job->message); job->result = strlen(job->message); - FastOS_Thread::Sleep(3000); + std::this_thread::sleep_for(3s); break; } @@ -109,7 +110,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) number = number + 2; if(i == sleepOn) - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } guard = std::unique_lock<std::mutex>(); @@ -123,7 +124,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) { for(;;) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); if(thread->GetBreakFlag()) { @@ -192,7 +193,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) case WAIT2SEC_AND_SIGNALCOND: { - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); job->condition->notify_one(); job->result = 1; break; @@ -202,7 +203,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) { { std::lock_guard<std::mutex> guard(*job->mutex); - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); } job->result = 1; break; @@ -210,7 +211,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) case WAIT_2_SEC: { - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); job->result = 1; break; } diff --git a/fastos/src/tests/threadtest.cpp b/fastos/src/tests/threadtest.cpp index 9507bb1e5d7..0a8a0d2bf02 100644 --- a/fastos/src/tests/threadtest.cpp +++ b/fastos/src/tests/threadtest.cpp @@ -43,7 +43,7 @@ class ThreadTest : public ThreadTestBase if(waitingThreads == numWait) break; - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } } @@ -336,7 +336,7 @@ class ThreadTest : public ThreadTestBase // Threads are not guaranteed to have entered sleep yet, // as this test only tests for result code // Wait another second to be sure. - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } void SignalTest () diff --git a/fastos/src/vespa/fastos/thread.h b/fastos/src/vespa/fastos/thread.h index 12866c71b2c..257acbc92d3 100644 --- a/fastos/src/vespa/fastos/thread.h +++ b/fastos/src/vespa/fastos/thread.h @@ -347,15 +347,7 @@ public: /** * Destructor. */ - virtual ~FastOS_ThreadInterface (){} - - /** - * Sleep for x milliseconds. Attempting to sleep for <1 milliseconds - * will result in failure. - * @param ms Number of milliseconds to sleep. - * @return Boolean success/failure - */ - static bool Sleep(int ms); + virtual ~FastOS_ThreadInterface () {} /** * Instruct a thread to exit. This could be used in conjunction with diff --git a/fastos/src/vespa/fastos/timestamp.cpp b/fastos/src/vespa/fastos/timestamp.cpp index deceaee4c65..977af69049c 100644 --- a/fastos/src/vespa/fastos/timestamp.cpp +++ b/fastos/src/vespa/fastos/timestamp.cpp @@ -76,7 +76,6 @@ ClockSteady::now() const SteadyTimeStamp SteadyTimeStamp::ZERO; const SteadyTimeStamp SteadyTimeStamp::FUTURE(TimeStamp::FUTURE); const UTCTimeStamp UTCTimeStamp::ZERO; -const UTCTimeStamp UTCTimeStamp::FUTURE(TimeStamp::FUTURE); UTCTimeStamp SteadyTimeStamp::toUTC() const { diff --git a/fastos/src/vespa/fastos/timestamp.h b/fastos/src/vespa/fastos/timestamp.h index 79d6ef5eed6..f1a40272938 100644 --- a/fastos/src/vespa/fastos/timestamp.h +++ b/fastos/src/vespa/fastos/timestamp.h @@ -61,7 +61,6 @@ inline TimeStamp operator *(double a, TimeStamp b) { return TimeStamp(static_cas class UTCTimeStamp { public: static const UTCTimeStamp ZERO; - static const UTCTimeStamp FUTURE; UTCTimeStamp() : _timeStamp() { } explicit UTCTimeStamp(TimeStamp timeStamp) : _timeStamp(timeStamp) { } @@ -92,7 +91,7 @@ public: friend bool operator >= (UTCTimeStamp a, UTCTimeStamp b) { return a._timeStamp >= b._timeStamp; } - TimeStamp timeSinceEpoch() const { return _timeStamp - ZERO._timeStamp; } + TimeStamp time_since_epoch() const { return _timeStamp - ZERO._timeStamp; } std::string toString() const { return _timeStamp.toString(); }; private: TimeStamp _timeStamp; diff --git a/fastos/src/vespa/fastos/unix_process.cpp b/fastos/src/vespa/fastos/unix_process.cpp index 86d285059b8..4d4197f5354 100644 --- a/fastos/src/vespa/fastos/unix_process.cpp +++ b/fastos/src/vespa/fastos/unix_process.cpp @@ -40,6 +40,8 @@ extern char **environ; #endif +using namespace std::chrono_literals; + static pid_t safe_fork () { pid_t pid; @@ -1629,7 +1631,7 @@ FastOS_UNIX_ProcessStarter::Wait(FastOS_UNIX_Process *process, } } - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return rc; diff --git a/fastos/src/vespa/fastos/unix_thread.cpp b/fastos/src/vespa/fastos/unix_thread.cpp index 5218bde2630..9e48727deb3 100644 --- a/fastos/src/vespa/fastos/unix_thread.cpp +++ b/fastos/src/vespa/fastos/unix_thread.cpp @@ -83,18 +83,6 @@ FastOS_UNIX_Thread::~FastOS_UNIX_Thread() } } -bool FastOS_UNIX_Thread::Sleep (int ms) -{ - bool rc=false; - - if (ms > 0) { - usleep(ms*1000); - rc = true; - } - - return rc; -} - FastOS_ThreadId FastOS_UNIX_Thread::GetThreadId () { return _handle; diff --git a/fastos/src/vespa/fastos/unix_thread.h b/fastos/src/vespa/fastos/unix_thread.h index c6e0b040fc7..35df3f5745f 100644 --- a/fastos/src/vespa/fastos/unix_thread.h +++ b/fastos/src/vespa/fastos/unix_thread.h @@ -36,7 +36,6 @@ public: ~FastOS_UNIX_Thread(); - static bool Sleep (int ms); FastOS_ThreadId GetThreadId () override; static bool CompareThreadIds (FastOS_ThreadId a, FastOS_ThreadId b); static FastOS_ThreadId GetCurrentThreadId (); diff --git a/fnet/src/examples/frt/rpc/CMakeLists.txt b/fnet/src/examples/frt/rpc/CMakeLists.txt index 914525a2a57..109284fa222 100644 --- a/fnet/src/examples/frt/rpc/CMakeLists.txt +++ b/fnet/src/examples/frt/rpc/CMakeLists.txt @@ -48,7 +48,7 @@ vespa_add_executable(fnet_rpc_callback_client_app vespa_add_executable(fnet_rpc_invoke_app SOURCES rpc_invoke.cpp - OUTPUT_NAME vespa-rpc-invoke + OUTPUT_NAME vespa-rpc-invoke-bin INSTALL bin DEPENDS fnet diff --git a/fnet/src/examples/timeout/timeout.cpp b/fnet/src/examples/timeout/timeout.cpp index 1d6ecc11909..23dfbeb9070 100644 --- a/fnet/src/examples/timeout/timeout.cpp +++ b/fnet/src/examples/timeout/timeout.cpp @@ -2,7 +2,8 @@ #include <vespa/fnet/fnet.h> #include <vespa/fastos/app.h> -#include <chrono> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("timeout"); @@ -55,7 +56,7 @@ MyApp::Main() transport.Start(&pool); // stable-state operation - FastOS_Thread::Sleep(500); + std::this_thread::sleep_for(500ms); FNET_Packet *packet; FNET_Context context; @@ -64,7 +65,7 @@ MyApp::Main() t = clock::now(); timeout.Schedule(2.0); // timeout in 2 seconds - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); timeout.Unschedule(); // cancel timeout ms = (clock::now() - t); diff --git a/fnet/src/tests/examples/examples_test.cpp b/fnet/src/tests/examples/examples_test.cpp index 61a2408e39d..c704c58abc9 100644 --- a/fnet/src/tests/examples/examples_test.cpp +++ b/fnet/src/tests/examples/examples_test.cpp @@ -75,7 +75,7 @@ TEST("usage") { EXPECT_FALSE(runProc(proc, done)); } { - SlaveProc proc("exec ../../examples/frt/rpc/vespa-rpc-invoke"); + SlaveProc proc("exec ../../examples/frt/rpc/vespa-rpc-invoke-bin"); EXPECT_FALSE(runProc(proc, done)); } { @@ -197,7 +197,7 @@ TEST_MT_F("rpc invoke", 2, std::atomic<bool>()) { EXPECT_TRUE(runProc(proc, f1)); } else { TEST_BARRIER(); - EXPECT_TRUE(runProc(vespalib::make_string("exec ../../examples/frt/rpc/vespa-rpc-invoke tcp/localhost:%d frt.rpc.echo " + EXPECT_TRUE(runProc(vespalib::make_string("exec ../../examples/frt/rpc/vespa-rpc-invoke-bin tcp/localhost:%d frt.rpc.echo " "b:1 h:2 i:4 l:8 f:0.5 d:0.25 s:foo", PORT0).c_str())); f1 = true; diff --git a/fnet/src/tests/frt/rpc/detach_return_invoke.cpp b/fnet/src/tests/frt/rpc/detach_return_invoke.cpp index 43a61cd9bcd..95dbe672909 100644 --- a/fnet/src/tests/frt/rpc/detach_return_invoke.cpp +++ b/fnet/src/tests/frt/rpc/detach_return_invoke.cpp @@ -54,7 +54,7 @@ TEST("detach return invoke") { if (receptor.req != 0) { break; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } req->SubRef(); target->SubRef(); diff --git a/fnet/src/vespa/fnet/frt/invoker.h b/fnet/src/vespa/fnet/frt/invoker.h index 64adf66688e..0838ef84dd3 100644 --- a/fnet/src/vespa/fnet/frt/invoker.h +++ b/fnet/src/vespa/fnet/frt/invoker.h @@ -5,7 +5,6 @@ #include "rpcrequest.h" #include <vespa/fnet/task.h> #include <vespa/fnet/ipackethandler.h> -#include <vespa/fastos/thread.h> #include <mutex> #include <condition_variable> diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java index 7a683b74656..140feb75026 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java @@ -23,6 +23,7 @@ import org.eclipse.jetty.server.Handler; import org.eclipse.jetty.server.Server; import org.eclipse.jetty.server.ServerConnectionStatistics; import org.eclipse.jetty.server.ServerConnector; +import org.eclipse.jetty.server.SslConnectionFactory; import org.eclipse.jetty.server.handler.AbstractHandlerContainer; import org.eclipse.jetty.server.handler.HandlerCollection; import org.eclipse.jetty.server.handler.StatisticsHandler; @@ -316,6 +317,7 @@ public class JettyHttpServer extends AbstractServerProvider { public void start() { try { server.start(); + logEffectiveSslConfiguration(); } catch (final Exception e) { if (e instanceof IOException && e.getCause() instanceof BindException) { throw new RuntimeException("Failed to start server due to BindExecption. ListenPorts = " + listenedPorts.toString(), e.getCause()); @@ -324,6 +326,22 @@ public class JettyHttpServer extends AbstractServerProvider { } } + private void logEffectiveSslConfiguration() { + if (!server.isStarted()) throw new IllegalStateException(); + for (Connector connector : server.getConnectors()) { + ServerConnector serverConnector = (ServerConnector) connector; + int localPort = serverConnector.getLocalPort(); + var sslConnectionFactory = serverConnector.getConnectionFactory(SslConnectionFactory.class); + if (sslConnectionFactory != null) { + var sslContextFactory = sslConnectionFactory.getSslContextFactory(); + log.info(String.format("Enabled SSL cipher suites for port '%d': %s", + localPort, Arrays.toString(sslContextFactory.getSelectedCipherSuites()))); + log.info(String.format("Enabled SSL protocols for port '%d': %s", + localPort, Arrays.toString(sslContextFactory.getSelectedProtocols()))); + } + } + } + @Override public void close() { try { diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/impl/ConfiguredSslContextFactoryProvider.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/impl/ConfiguredSslContextFactoryProvider.java index b2e7ba1be67..90848f1dfd4 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/impl/ConfiguredSslContextFactoryProvider.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/impl/ConfiguredSslContextFactoryProvider.java @@ -70,12 +70,12 @@ public class ConfiguredSslContextFactoryProvider implements SslContextFactoryPro List<String> protocols = !sslConfig.enabledProtocols().isEmpty() ? sslConfig.enabledProtocols() - : new ArrayList<>(TlsContext.ALLOWED_PROTOCOLS); + : new ArrayList<>(TlsContext.getAllowedProtocols(sslContext)); setEnabledProtocols(factory, sslContext, protocols); List<String> ciphers = !sslConfig.enabledCipherSuites().isEmpty() ? sslConfig.enabledCipherSuites() - : new ArrayList<>(TlsContext.ALLOWED_CIPHER_SUITES); + : new ArrayList<>(TlsContext.getAllowedCipherSuites(sslContext)); setEnabledCipherSuites(factory, sslContext, ciphers); return factory; diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/impl/JDiscSslContextFactory.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/impl/JDiscSslContextFactory.java deleted file mode 100644 index 4d3bb4a280a..00000000000 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/ssl/impl/JDiscSslContextFactory.java +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.jdisc.http.ssl.impl; - -import org.eclipse.jetty.util.resource.Resource; -import org.eclipse.jetty.util.security.CertificateUtils; -import org.eclipse.jetty.util.ssl.SslContextFactory; - -import java.security.KeyStore; -import java.util.Objects; - -/** - * A modified {@link SslContextFactory} that allows passwordless truststore in combination with password protected keystore. - * - * @author bjorncs - */ -class JDiscSslContextFactory extends SslContextFactory.Server { - - private String trustStorePassword; - - @Override - public void setTrustStorePassword(String password) { - super.setTrustStorePassword(password); - this.trustStorePassword = password; - } - - - // Overriden to stop Jetty from using the keystore password if no truststore password is specified. - @Override - protected KeyStore loadTrustStore(Resource resource) throws Exception { - return CertificateUtils.getKeyStore( - resource != null ? resource : getKeyStoreResource(), - Objects.toString(getTrustStoreType(), getKeyStoreType()), - Objects.toString(getTrustStoreProvider(), getKeyStoreProvider()), - trustStorePassword); - } -} diff --git a/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp b/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp index 40c54980c11..cd1ad7e6eed 100644 --- a/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp +++ b/jrt_test/src/tests/mandatory-methods/extract-reflection.cpp @@ -2,6 +2,8 @@ #include <vespa/fastos/app.h> #include <vespa/fnet/frt/frt.h> +#include <vespa/vespalib/util/time.h> +#include <thread> class RPCInfo : public FastOS_Application { @@ -85,7 +87,7 @@ public: if (info->GetErrorCode() != FRTE_RPC_CONNECTION) { break; } - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); target->SubRef(); target = supervisor.GetTarget(_argv[1]); } diff --git a/messagebus/src/tests/context/context.cpp b/messagebus/src/tests/context/context.cpp index de9dd1b83a6..c71357d09ad 100644 --- a/messagebus/src/tests/context/context.cpp +++ b/messagebus/src/tests/context/context.cpp @@ -77,7 +77,7 @@ Test::Main() if (queue.size() == 3) { break; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } EXPECT_EQUAL(queue.size(), 3u); { diff --git a/messagebus/src/tests/loadbalance/loadbalance.cpp b/messagebus/src/tests/loadbalance/loadbalance.cpp index 2f510d98ff1..05ea6d78871 100644 --- a/messagebus/src/tests/loadbalance/loadbalance.cpp +++ b/messagebus/src/tests/loadbalance/loadbalance.cpp @@ -78,7 +78,7 @@ Test::Main() if (queue.size() == msgCnt) { break; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } EXPECT_TRUE(queue.size() == msgCnt); EXPECT_TRUE(h1.cnt == msgCnt / 3); diff --git a/messagebus/src/tests/messagebus/messagebus.cpp b/messagebus/src/tests/messagebus/messagebus.cpp index 7434941a900..d9c6e438523 100644 --- a/messagebus/src/tests/messagebus/messagebus.cpp +++ b/messagebus/src/tests/messagebus/messagebus.cpp @@ -43,7 +43,7 @@ struct Base { if (queue.size() == size) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } @@ -270,7 +270,7 @@ Test::testSendToCol() } } client->waitQueueSize(300); - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); client->waitQueueSize(300); while (client->queue.size() > 0) { Routable::UP reply = client->queue.dequeue(0); @@ -347,7 +347,7 @@ Test::testSendToAnyThenCol() } } client->waitQueueSize(300); - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); client->waitQueueSize(300); while (client->queue.size() > 0) { Routable::UP reply = client->queue.dequeue(0); diff --git a/messagebus/src/tests/messageordering/messageordering.cpp b/messagebus/src/tests/messageordering/messageordering.cpp index 520c3d3dea3..481b8bbd270 100644 --- a/messagebus/src/tests/messageordering/messageordering.cpp +++ b/messagebus/src/tests/messageordering/messageordering.cpp @@ -167,7 +167,7 @@ Test::Main() const int messageCount = 5000; for (int i = 0; i < messageCount; ++i) { vespalib::string str(vespalib::make_string("%d", i)); - //FastOS_Thread::Sleep(1); + //std::this_thread::sleep_for(1ms); auto msg = std::make_unique<SimpleMessage>(str, true, commonMessageId); msg->getTrace().setLevel(9); //LOG(debug, "Sending message %p for %d", msg.get(), i); diff --git a/messagebus/src/tests/serviceaddress/serviceaddress.cpp b/messagebus/src/tests/serviceaddress/serviceaddress.cpp index ac43cec3c02..441da5a80ac 100644 --- a/messagebus/src/tests/serviceaddress/serviceaddress.cpp +++ b/messagebus/src/tests/serviceaddress/serviceaddress.cpp @@ -82,7 +82,7 @@ Test::waitSlobrok(RPCNetwork &network, const string &pattern, size_t num) if (res.size() == num) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } diff --git a/messagebus/src/tests/slobrok/slobrok.cpp b/messagebus/src/tests/slobrok/slobrok.cpp index 7e0718283a6..439ee0b23b5 100644 --- a/messagebus/src/tests/slobrok/slobrok.cpp +++ b/messagebus/src/tests/slobrok/slobrok.cpp @@ -51,7 +51,7 @@ compare(const IMirrorAPI &api, const string &pattern, SpecList expect) if (actual == expect) { return true; } - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return false; } diff --git a/messagebus/src/tests/sourcesession/sourcesession.cpp b/messagebus/src/tests/sourcesession/sourcesession.cpp index 5177cf0e799..de04715a060 100644 --- a/messagebus/src/tests/sourcesession/sourcesession.cpp +++ b/messagebus/src/tests/sourcesession/sourcesession.cpp @@ -35,7 +35,7 @@ struct DelayedHandler : public IMessageHandler // this will block the transport thread in the server messagebus, // but that should be ok, as we only want to test the timing in the // client messagebus... - FastOS_Thread::Sleep(delay); + std::this_thread::sleep_for(std::chrono::milliseconds(delay)); session->acknowledge(std::move(msg)); } }; @@ -59,7 +59,7 @@ bool waitQueueSize(RoutableQueue &queue, uint32_t size) { if (queue.size() == size) { return true; } - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } return false; } @@ -99,7 +99,7 @@ Test::testSequencing() EXPECT_TRUE(ss->send(Message::UP(new SimpleMessage("foo", true, 2)), "dst").isAccepted()); EXPECT_TRUE(ss->send(Message::UP(new SimpleMessage("foo", true, 1)), "dst").isAccepted()); EXPECT_TRUE(waitQueueSize(dstQ, 2)); - FastOS_Thread::Sleep(250); + std::this_thread::sleep_for(250ms); EXPECT_TRUE(waitQueueSize(dstQ, 2)); EXPECT_TRUE(waitQueueSize(srcQ, 0)); ds->acknowledge(Message::UP((Message*)dstQ.dequeue(0).release())); diff --git a/messagebus/src/tests/throttling/throttling.cpp b/messagebus/src/tests/throttling/throttling.cpp index 5d3525e8ba6..76bba89b72d 100644 --- a/messagebus/src/tests/throttling/throttling.cpp +++ b/messagebus/src/tests/throttling/throttling.cpp @@ -51,7 +51,7 @@ bool waitQueueSize(RoutableQueue &queue, uint32_t size) if (queue.size() == size) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } @@ -62,7 +62,7 @@ bool waitPending(SourceSession& session, uint32_t size) if (session.getPendingCount() == size) { return true; } - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } return false; } diff --git a/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp b/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp index 5ae6b07c3fa..280250a5119 100644 --- a/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp +++ b/messagebus/src/vespa/messagebus/network/rpcnetwork.cpp @@ -17,6 +17,7 @@ #include <vespa/fnet/scheduler.h> #include <vespa/fnet/transport.h> #include <vespa/fnet/frt/supervisor.h> +#include <vespa/fastos/thread.h> #include <thread> #include <vespa/log/log.h> diff --git a/messagebus/src/vespa/messagebus/testlib/testserver.cpp b/messagebus/src/vespa/messagebus/testlib/testserver.cpp index bbd23d52c0b..e0c6d6a756d 100644 --- a/messagebus/src/vespa/messagebus/testlib/testserver.cpp +++ b/messagebus/src/vespa/messagebus/testlib/testserver.cpp @@ -4,6 +4,8 @@ #include "slobrok.h" #include "slobrokstate.h" #include <vespa/vespalib/component/vtag.h> +#include <vespa/vespalib/util/time.h> +#include <thread> namespace mbus { @@ -59,7 +61,7 @@ TestServer::waitState(const SlobrokState &slobrokState) if (done) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } diff --git a/messagebus_test/src/tests/error/cpp-client.cpp b/messagebus_test/src/tests/error/cpp-client.cpp index f186be68d01..147052e0701 100644 --- a/messagebus_test/src/tests/error/cpp-client.cpp +++ b/messagebus_test/src/tests/error/cpp-client.cpp @@ -7,6 +7,8 @@ #include <vespa/messagebus/rpcmessagebus.h> #include <vespa/messagebus/network/rpcnetworkparams.h> #include <vespa/messagebus/testlib/receptor.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/fastos/app.h> using namespace mbus; @@ -45,7 +47,7 @@ App::Main() break; } } - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } if (reply.get() == 0) { fprintf(stderr, "CPP-CLIENT: no reply\n"); diff --git a/messagebus_test/src/tests/error/cpp-server.cpp b/messagebus_test/src/tests/error/cpp-server.cpp index d5200ed20c1..383f703317e 100644 --- a/messagebus_test/src/tests/error/cpp-server.cpp +++ b/messagebus_test/src/tests/error/cpp-server.cpp @@ -6,6 +6,8 @@ #include <vespa/messagebus/network/rpcnetworkparams.h> #include <vespa/messagebus/emptyreply.h> #include <vespa/messagebus/errorcode.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/fastos/app.h> using namespace mbus; @@ -55,7 +57,7 @@ App::Main() "file:routing.cfg"); Server server(mb.getMessageBus()); while (true) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } return 0; } diff --git a/messagebus_test/src/tests/speed/cpp-client.cpp b/messagebus_test/src/tests/speed/cpp-client.cpp index 43d030b519b..ff00128037a 100644 --- a/messagebus_test/src/tests/speed/cpp-client.cpp +++ b/messagebus_test/src/tests/speed/cpp-client.cpp @@ -7,6 +7,8 @@ #include <vespa/messagebus/testlib/simplemessage.h> #include <vespa/messagebus/testlib/simpleprotocol.h> #include <vespa/messagebus/testlib/simplereply.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/fastos/timestamp.h> #include <vespa/fastos/app.h> @@ -100,7 +102,7 @@ App::Main() Client client(mb.getMessageBus(), SourceSessionParams().setTimeout(30s)); // let the system 'warm up' - FastOS_Thread::Sleep(5000); + std::this_thread::sleep_for(5s); // inject messages into the feedback loop for (uint32_t i = 0; i < 1024; ++i) { @@ -108,7 +110,7 @@ App::Main() } // let the system 'warm up' - FastOS_Thread::Sleep(5000); + std::this_thread::sleep_for(5s); fastos::StopWatch stopWatch; uint32_t okBefore = 0; @@ -117,7 +119,7 @@ App::Main() uint32_t failAfter = 0; client.sample(okBefore, failBefore); - FastOS_Thread::Sleep(10000); // Benchmark time + std::this_thread::sleep_for(10s); // Benchmark time fastos::TimeStamp elapsed = stopWatch.elapsed(); client.sample(okAfter, failAfter); double time = elapsed.ms(); diff --git a/messagebus_test/src/tests/speed/cpp-server.cpp b/messagebus_test/src/tests/speed/cpp-server.cpp index 82b884c46f2..a1aa5a5029c 100644 --- a/messagebus_test/src/tests/speed/cpp-server.cpp +++ b/messagebus_test/src/tests/speed/cpp-server.cpp @@ -6,6 +6,8 @@ #include <vespa/messagebus/testlib/simpleprotocol.h> #include <vespa/messagebus/rpcmessagebus.h> #include <vespa/messagebus/network/rpcnetworkparams.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/fastos/app.h> using namespace mbus; @@ -62,7 +64,7 @@ App::Main() "file:routing.cfg"); Server server(mb.getMessageBus()); while (true) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } return 0; } diff --git a/messagebus_test/src/tests/trace/cpp-server.cpp b/messagebus_test/src/tests/trace/cpp-server.cpp index d6db86070b1..75f4ee3a002 100644 --- a/messagebus_test/src/tests/trace/cpp-server.cpp +++ b/messagebus_test/src/tests/trace/cpp-server.cpp @@ -5,6 +5,8 @@ #include <vespa/messagebus/rpcmessagebus.h> #include <vespa/messagebus/network/rpcnetworkparams.h> #include <vespa/messagebus/emptyreply.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/fastos/app.h> using namespace mbus; @@ -73,7 +75,7 @@ App::Main() "file:routing.cfg"); Server server(mb.getMessageBus(), _argv[1]); while (true) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } return 0; } diff --git a/messagebus_test/src/tests/trace/trace.cpp b/messagebus_test/src/tests/trace/trace.cpp index a804bef6785..48c8d4afab0 100644 --- a/messagebus_test/src/tests/trace/trace.cpp +++ b/messagebus_test/src/tests/trace/trace.cpp @@ -36,7 +36,7 @@ waitSlobrok(RPCMessageBus &mbus, const std::string &pattern) if (res.size() > 0) { return true; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } @@ -112,7 +112,7 @@ Test::Main() } } std::cout << "Attempt " << i << " got errors, retrying in 1 second.." << std::endl; - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } EXPECT_TRUE(!reply->hasErrors()); diff --git a/metrics/src/tests/metricmanagertest.cpp b/metrics/src/tests/metricmanagertest.cpp index 1d954a641b6..6407bb73ecb 100644 --- a/metrics/src/tests/metricmanagertest.cpp +++ b/metrics/src/tests/metricmanagertest.cpp @@ -10,8 +10,10 @@ #include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/xmlstream.h> -#include <vespa/log/log.h> +#include <vespa/vespalib/util/time.h> +#include <thread> +#include <vespa/log/log.h> LOG_SETUP(".test.metricmanager"); namespace metrics { @@ -386,7 +388,7 @@ bool waitForTimeProcessed(const MetricManager& mm, while (time(0) < lastchance) { if (mm.getLastProcessedTime() >= processtime) return true; mm.timeChangedNotification(); - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } return false; } diff --git a/metrics/src/tests/stresstest.cpp b/metrics/src/tests/stresstest.cpp index 4a6d2f4a2ea..f3e709b4b04 100644 --- a/metrics/src/tests/stresstest.cpp +++ b/metrics/src/tests/stresstest.cpp @@ -4,6 +4,8 @@ #include <vespa/metrics/metricmanager.h> #include <vespa/metrics/metrics.h> #include <vespa/metrics/summetric.hpp> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/log/log.h> @@ -39,7 +41,7 @@ InnerMetricSet::InnerMetricSet(const char* name, const LoadTypeSet& lt, MetricSe _valueSum.addMetricToSum(_value1); _valueSum.addMetricToSum(_value2); } -InnerMetricSet::~InnerMetricSet() { } +InnerMetricSet::~InnerMetricSet() = default; MetricSet* InnerMetricSet::clone(std::vector<Metric::UP> &ownerList, CopyType copyType, @@ -133,11 +135,10 @@ TEST(StressTest, test_stress) FastOS_ThreadPool threadPool(256 * 1024); std::vector<Hammer::UP> hammers; for (uint32_t i=0; i<10; ++i) { - hammers.push_back(Hammer::UP( - new Hammer(metrics, loadTypes, threadPool))); + hammers.push_back(std::make_unique<Hammer>(metrics, loadTypes, threadPool)); } LOG(info, "Waiting to let loadgivers hammer a while"); - FastOS_Thread::Sleep(5 * 1000); + std::this_thread::sleep_for(5s); LOG(info, "Removing loadgivers"); hammers.clear(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index 6c583d960bd..14aa3ebf84e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -70,7 +70,7 @@ public class IntermediateGraph { return operations; } - void optimize() { + public void optimize() { renameDimensions(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index 280fe354149..55f5d979ea8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -3,11 +3,13 @@ package ai.vespa.rankingexpression.importer.onnx; import ai.vespa.rankingexpression.importer.operations.Gemm; +import ai.vespa.rankingexpression.importer.operations.ConcatReduce; import ai.vespa.rankingexpression.importer.operations.OnnxConcat; import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; import ai.vespa.rankingexpression.importer.operations.Softmax; import ai.vespa.rankingexpression.importer.operations.Squeeze; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; @@ -21,6 +23,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import ai.vespa.rankingexpression.importer.operations.NoOp; import ai.vespa.rankingexpression.importer.operations.Reshape; import ai.vespa.rankingexpression.importer.operations.Shape; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; @@ -36,24 +39,37 @@ import java.util.stream.Collectors; */ class GraphImporter { + private static final Value eluAlpha = DoubleValue.frozen(1.0); + private static final Value seluAlpha = DoubleValue.frozen(1.6732632423543772848170429916717); + private static final Value seluGamma = DoubleValue.frozen(1.0507009873554804934193349852946); + private static final Value leakyReluAlpha = DoubleValue.frozen(0.01); + private static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs, IntermediateGraph graph) { + String type = node.getOpType(); String modelName = graph.name(); String nodeName = getNodeName(node); AttributeConverter attributes = AttributeConverter.convert(node); + return mapOperation(type, inputs, modelName, nodeName, attributes); + } - switch (node.getOpType().toLowerCase()) { + static IntermediateOperation mapOperation(String opType, + List<IntermediateOperation> inputs, + String modelName, + String nodeName, + AttributeConverter attributes) { + switch (opType.toLowerCase()) { case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); - case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes); case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); - case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble())); case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); @@ -63,23 +79,31 @@ class GraphImporter { case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less()); case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log()); case "matmul": return new MatMul(modelName, nodeName, inputs); - case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); - case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min()); - case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean()); + case "max": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.max); + case "min": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.min); + case "mean": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.avg); case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); - case "reshape": return new Reshape(modelName, nodeName, inputs); - case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum); + case "reshape": return new Reshape(modelName, nodeName, inputs, attributes); + case "reducel1": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.abs(), null); + case "reducel2": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), ScalarFunctions.sqrt()); + case "reducelogsum":return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, null, ScalarFunctions.log()); + case "reducelogsumexp": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.exp(), ScalarFunctions.log()); + case "reducemax": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.max); case "reducemean": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.avg); + case "reducemin": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.min); + case "reduceprod": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.prod); + case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum); + case "reducesumsquare": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), null); case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); - case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); - case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu(attributes.get("gamma").orElse(seluGamma).asDouble(), attributes.get("alpha").orElse(seluAlpha).asDouble())); + case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu(attributes.get("alpha").orElse(leakyReluAlpha).asDouble())); case "shape": return new Shape(modelName, nodeName, inputs); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); - case "softmax": return new Softmax(modelName, nodeName, inputs); + case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); @@ -90,7 +114,7 @@ class GraphImporter { } IntermediateOperation op = new NoOp(modelName, nodeName, inputs); - op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); + op.warning("Operation '" + opType + "' is currently not implemented"); return op; } @@ -260,5 +284,4 @@ class GraphImporter { "Either no explicit name given or no single output name."); } - } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java new file mode 100644 index 00000000000..ea6bb2eaf99 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java @@ -0,0 +1,76 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +public class ConcatReduce extends IntermediateOperation { + + private final static String tmpDimensionName = "__concat_reduce_tmp_dimension_name__"; + private final Reduce.Aggregator aggregator; + + public ConcatReduce(String modelName, String nodeName, List<IntermediateOperation> inputs, Reduce.Aggregator aggregator) { + super(modelName, nodeName, inputs); + this.aggregator = aggregator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(inputs.size())) return null; + return inputs.get(0).type().get(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(inputs.size())) return null; + + TensorFunction result = inputs.get(0).function().get(); + for (int i = 1; i < inputs.size(); ++i) { + TensorFunction b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat(result, b, tmpDimensionName); + } + return new com.yahoo.tensor.functions.Reduce(result, aggregator, tmpDimensionName); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if ( ! allInputTypesPresent(inputs.size())) return; + + OrderedTensorType a = inputs.get(0).type().get(); + for (int i = 1; i < inputs.size(); ++i) { + OrderedTensorType b = inputs.get(i).type().get(); + + OrderedTensorType largest = largestInput(a, b); + OrderedTensorType smallest = smallestInput(a, b); + + int sizeDifference = largest.rank() - smallest.rank(); + for (int j = 0; j < smallest.rank(); ++j) { + String bDim = smallest.dimensions().get(j).name(); + String aDim = largest.dimensions().get(j + sizeDifference).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); + } + a = b; + } + } + + private OrderedTensorType largestInput(OrderedTensorType a, OrderedTensorType b) { + return a.rank() >= b.rank() ? a : b; + } + + private OrderedTensorType smallestInput(OrderedTensorType a, OrderedTensorType b) { + return a.rank() < b.rank() ? a : b; + } + + @Override + public ConcatReduce withInputs(List<IntermediateOperation> inputs) { + return new ConcatReduce(modelName(), name(), inputs, aggregator); + } + + @Override + public String operationName() { return "ConcatReduce"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index f091ae165d1..3fba8680332 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -92,7 +92,7 @@ public class Gemm extends IntermediateOperation { return null; } - String joinDimension = aType.dimensions().get(1).name(); // TODO: check wrt transpose! + String joinDimension = aType.dimensions().get(1 - transposeA).name(); TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension); TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index bd302afa5c7..efd6f9d3339 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -199,7 +199,9 @@ public abstract class IntermediateOperation { String constantName = "constant(" + vespaName() + ")"; Value result = context.get(constantName); if (result == DoubleValue.NaN) { - if (inputs.size() == 0) { + if (constantValue != null) { + result = constantValue; + } else if (inputs.size() == 0) { if (getConstantValue().isEmpty()) { throw new IllegalArgumentException("Error in evaluating constant for " + name); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java index ded76db60fe..5785621eed3 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java @@ -28,6 +28,9 @@ public class OnnxConcat extends IntermediateOperation { if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null; OrderedTensorType aType = inputs.get(0).type().get(); + if (concatDimensionIndex < 0) { + concatDimensionIndex = aType.dimensions().size() + concatDimensionIndex; + } long concatDimSize = aType.dimensions().get(concatDimensionIndex).size().orElse(-1L); for (int i = 1; i < inputs.size(); ++i) { @@ -92,7 +95,7 @@ public class OnnxConcat extends IntermediateOperation { public void renameDimensions(DimensionRenamer renamer) { super.renameDimensions(renamer); concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName); - } + } @Override public OnnxConcat withInputs(List<IntermediateOperation> inputs) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java index 1b2d9ac090e..7af051484f5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java @@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.function.DoubleUnaryOperator; /** * ONNX Reduce[Sum/Mean/etc] operation @@ -24,6 +25,8 @@ public class Reduce extends IntermediateOperation { private final AttributeMap attributeMap; private final com.yahoo.tensor.functions.Reduce.Aggregator aggregator; + private final DoubleUnaryOperator preOperator; + private final DoubleUnaryOperator postOperator; private List<String> reduceDimensions; @@ -31,11 +34,23 @@ public class Reduce extends IntermediateOperation { List<IntermediateOperation> inputs, AttributeMap attributeMap, com.yahoo.tensor.functions.Reduce.Aggregator aggregator) { + this(modelName, nodeName, inputs, attributeMap, aggregator, null, null); + } + + public Reduce(String modelName, String nodeName, + List<IntermediateOperation> inputs, + AttributeMap attributeMap, + com.yahoo.tensor.functions.Reduce.Aggregator aggregator, + DoubleUnaryOperator preOperator, + DoubleUnaryOperator postOperator) { super(modelName, nodeName, inputs); this.attributeMap = attributeMap; this.aggregator = aggregator; + this.preOperator = preOperator; + this.postOperator = postOperator; } + @Override protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(1)) return null; @@ -48,7 +63,7 @@ public class Reduce extends IntermediateOperation { for (Value i : attributeMap.getList("axes").get()) { int dimensionIndex = (int) i.asDouble(); if (dimensionIndex < 0) { - dimensionIndex = inputType.dimensions().size() - dimensionIndex; + dimensionIndex = inputType.dimensions().size() + dimensionIndex; } reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); } @@ -61,6 +76,9 @@ public class Reduce extends IntermediateOperation { if ( ! allInputTypesPresent(1)) return null; TensorFunction inputFunction = inputs.get(0).function().get(); + if (preOperator != null) { + inputFunction = new com.yahoo.tensor.functions.Map(inputFunction, preOperator); + } TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions @@ -74,6 +92,9 @@ public class Reduce extends IntermediateOperation { new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); } + if (postOperator != null) { + output = new com.yahoo.tensor.functions.Map(output, postOperator); + } return output; } @@ -93,7 +114,7 @@ public class Reduce extends IntermediateOperation { @Override public Reduce withInputs(List<IntermediateOperation> inputs) { - return new Reduce(modelName(), name(), inputs, attributeMap, aggregator); + return new Reduce(modelName(), name(), inputs, attributeMap, aggregator, preOperator, postOperator); } @Override @@ -101,7 +122,7 @@ public class Reduce extends IntermediateOperation { private boolean shouldKeepDimensions() { Optional<Value> keepDims = attributeMap.get("keepdims"); - return keepDims.isPresent() && keepDims.get().asBoolean(); + return keepDims.isEmpty() || keepDims.get().asBoolean(); // default is 1 } private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index c7accd00619..c88fc18e6c6 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; @@ -22,51 +23,97 @@ import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; public class Reshape extends IntermediateOperation { - public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + private final AttributeMap attributeMap; + + public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override protected OrderedTensorType lazyGetType() { - if ( ! allInputTypesPresent(2)) return null; + if (inputs.size() == 2) { + return typeWithShapeAsInput(); + } else if (inputs.size() == 1) { + return typeWithShapeAsAttribute(); + } + throw new IllegalArgumentException("Expected 2 or 3 inputs for '" + name + "', got " + inputs.size()); + } + private OrderedTensorType typeWithShapeAsInput() { IntermediateOperation newShape = inputs.get(1); if (newShape.getConstantValue().isEmpty()) - throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant."); + throw new IllegalArgumentException("Reshape " + name + ": Shape input must be a constant."); + OrderedTensorType inputType = inputs.get(0).type().get(); Tensor shape = newShape.getConstantValue().get().asTensor(); + List<Integer> dimSizes = new ArrayList<>(shape.type().rank()); + shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue())); + + // first pass - set 0 values, meaning that size is retained from input + for (int i = 0; i < dimSizes.size(); ++i) { + if (dimSizes.get(i) == 0) { + if (i >= inputType.dimensions().size()) { + throw new IllegalArgumentException("Reshape " + name + ": 0 value for dimension not found in input"); + } + dimSizes.set(i, inputType.dimensions().get(i).size().get().intValue()); + } + } + + // second pass - set any -1 value, meaning that the dimension size should be expanded to fill the tensor + for (int i = 0; i < dimSizes.size(); ++i) { + if (dimSizes.get(i) < 0) { + int shapeSize = dimSizes.stream().reduce(1, (a, b) -> a * b); + int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue(); + dimSizes.set(i, -1 * tensorSize / (shapeSize == 0 ? -1 : shapeSize)); + } + } + + return buildOutputType(dimSizes); + } + + private OrderedTensorType typeWithShapeAsAttribute() { + if (attributeMap.getList("shape").isEmpty() || attributeMap.getList("shape").get().size() == 0) + throw new IllegalArgumentException("Reshape in " + name + ": Shape attribute is empty."); OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); - int dimensionIndex = 0; - for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int size = cell.getValue().intValue(); + List<Value> shape = attributeMap.getList("shape").get(); + List<Integer> dimSizes = new ArrayList<>(shape.size()); + + for (Value v : shape) { + int size = (int) v.asDouble(); if (size < 0) { - size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / - OrderedTensorType.tensorSize(inputType.type()).intValue(); + int shapeSize = (int) shape.stream().mapToDouble(Value::asDouble).reduce(1, (a, b) -> a * b); + int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue(); + size = -1 * shapeSize / tensorSize; } - outputTypeBuilder.add(TensorType.Dimension.indexed( - String.format("%s_%d", vespaName(), dimensionIndex), size)); - dimensionIndex++; + dimSizes.add(size); + } + return buildOutputType(dimSizes); + } + + private OrderedTensorType buildOutputType(List<Integer> dimSizes) { + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < dimSizes.size(); ++i) { + outputTypeBuilder.add(TensorType.Dimension.indexed(String.format("%s_%d", vespaName(), i), dimSizes.get(i))); } return outputTypeBuilder.build(); } @Override protected TensorFunction lazyGetFunction() { - if ( ! allInputTypesPresent(2)) return null; - if ( ! allInputFunctionsPresent(2)) return null; + if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent) ) return null; + if ( ! inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent) ) return null; OrderedTensorType inputType = inputs.get(0).type().get(); TensorFunction inputFunction = inputs.get(0).function().get(); - return reshape(inputFunction, inputType.type(), type.type()); + return reshape(inputFunction, inputType, type); } @Override @@ -76,11 +123,11 @@ public class Reshape extends IntermediateOperation { @Override public Reshape withInputs(List<IntermediateOperation> inputs) { - return new Reshape(modelName(), name(), inputs); + return new Reshape(modelName(), name(), inputs, attributeMap); } - public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) + public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { + if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type()))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, @@ -89,25 +136,27 @@ public class Reshape extends IntermediateOperation { // the new shape. We have to introduce temporary dimension names and rename back if dimension names // in the new and old tensor type overlap. + // Todo: change this to use tensor generate when available + List<String> from = new ArrayList<>(); List<String> to = new ArrayList<>(); boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType); if (dimensionNamesOverlap) { - TensorType.Builder builder = new TensorType.Builder(outputType.valueType()); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(outputType.type().valueType()); for (int i = 0; i < outputType.rank(); ++i) { TensorType.Dimension dim = outputType.dimensions().get(i); from.add(dim.name()); to.add("temp_" + dim.name()); - builder.dimension(dim.withName("temp_" + dim.name())); + builder.add(dim.withName("temp_" + dim.name())); } outputType = builder.build(); } ExpressionNode unrollFrom = unrollTensorExpression(inputType); ExpressionNode unrollTo = unrollTensorExpression(outputType); - ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, new EmbracedNode(unrollTo)); + ExpressionNode transformExpression = new ComparisonNode(new EmbracedNode(unrollFrom), TruthOperator.EQUAL, new EmbracedNode(unrollTo)); - TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); + TensorType transformationType = new TensorType.Builder(inputType.type(), outputType.type()).build(); Generate transformTensor = new Generate(transformationType, new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); @@ -121,11 +170,11 @@ public class Reshape extends IntermediateOperation { return result; } - private static boolean dimensionNamesOverlap(TensorType a, TensorType b) { - return a.dimensionNames().stream().anyMatch(d -> b.dimension(d).isPresent()); + private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) { + return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent()); } - private static ExpressionNode unrollTensorExpression(TensorType type) { + private static ExpressionNode unrollTensorExpression(OrderedTensorType type) { if (type.rank() == 0) return new ConstantNode(DoubleValue.zero); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index 032ffb88a46..83086926316 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -2,8 +2,13 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Map; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; +import java.util.ArrayList; import java.util.List; /** @@ -13,8 +18,11 @@ import java.util.List; */ public class Softmax extends IntermediateOperation { - public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs) { + private final AttributeMap attributeMap; + + public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override @@ -28,18 +36,30 @@ public class Softmax extends IntermediateOperation { if ( ! allInputFunctionsPresent(1)) return null; OrderedTensorType inputType = inputs.get(0).type().get(); - String dimension = inputType.dimensions().get(0).name(); - if (inputType.rank() == 2) { - dimension = inputType.dimensions().get(1).name(); // assumption: first dimension is batch dimension + + int axis = inputType.rank() == 1 ? 0 : 1; // assumption: first dimension is batch dimension + if (attributeMap.get("axis").isPresent()) { + axis = (int)attributeMap.get("axis").get().asDouble(); + } + if (axis < 0) { + axis = inputType.rank() + axis; } + List<String> reduceDimensions = new ArrayList<>(); + for (int i = axis; i < inputType.rank(); ++i) { + reduceDimensions.add(inputType.dimensions().get(i).name()); // Do softmax over all dimensions except batch dimension + } + + TensorFunction input = inputs.get(0).function().get(); + TensorFunction exp = new Map(input, ScalarFunctions.exp()); + TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions); + TensorFunction div = new Join(exp, sum, ScalarFunctions.divide()); - TensorFunction inputFunction = inputs.get(0).function().get(); - return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension); + return div; } @Override public Softmax withInputs(List<IntermediateOperation> inputs) { - return new Softmax(modelName(), name(), inputs); + return new Softmax(modelName(), name(), inputs, attributeMap); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index 4f656d86929..0d2ba0cc714 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -64,7 +64,7 @@ class GraphImporter { case "identity": return new Identity(modelName, nodeName, inputs); case "placeholder": return new Argument(modelName, nodeName, nodeType); case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs); - case "reshape": return new Reshape(modelName, nodeName, inputs); + case "reshape": return new Reshape(modelName, nodeName, inputs, attributes); case "shape": return new Shape(modelName, nodeName, inputs); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); @@ -113,7 +113,7 @@ class GraphImporter { case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); - case "softmax": return new Softmax(modelName, nodeName, inputs); + case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); // state ops case "variable": return new Constant(modelName, nodeName, nodeType); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java new file mode 100644 index 00000000000..6954abe5157 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -0,0 +1,460 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.onnx; + +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.operations.Constant; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.ConstantTensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static ai.vespa.rankingexpression.importer.onnx.GraphImporter.*; +import static onnx.Onnx.AttributeProto.AttributeType.FLOAT; +import static onnx.Onnx.AttributeProto.AttributeType.INT; +import static onnx.Onnx.AttributeProto.AttributeType.INTS; +import static org.junit.Assert.assertEquals; + +/** + * Unit tests for ONNX operators. The number on the test reflects the minimum + * opset number for the operations tested. + * + * @author lesters + */ +public class OnnxOperationsTestCase { + + private static final String modelName = "test_model"; + + @Test + public void testElementwiseOperators7() throws ParseException { + Tensor x = evaluate("tensor(d0[7]):[-1.0, -0.5, -0.1, 0.0, 0.1, 0.5, 1.0]"); + assertEval("acos", x, evaluate("acos(x)", x)); + assertEval("asin", x, evaluate("asin(x)", x)); + assertEval("atan", x, evaluate("atan(x)", x)); + assertEval("cos", x, evaluate("cos(x)", x)); + assertEval("sin", x, evaluate("sin(x)", x)); + assertEval("tan", x, evaluate("tan(x)", x)); + assertEval("tanh", x, evaluate("tanh(x)", x)); + assertEval("neg", x, evaluate("-x", x)); + assertEval("sigmoid", x, evaluate("sigmoid(x)", x)); + assertEval("exp", x, evaluate("exp(x)", x)); + assertEval("floor", x, evaluate("floor(x)", x)); + assertEval("ceil", x, evaluate("ceil(x)", x)); + assertEval("abs", x, evaluate("abs(x)", x)); + + assertEval("relu", x, evaluate("max(0, x)", x)); + assertEval("elu", x, evaluate("map(x, f(a)(if(a < 0, 1.0 * (exp(a)-1), a)))", x)); + assertEval("elu", x, evaluate("map(x, f(a)(if(a < 0, 0.5 * (exp(a)-1), a)))", x), createAttribute("alpha", 0.5f)); + assertEval("selu", x, evaluate("map(x, f(a)(1.050700987 * if(a >= 0, a, 1.673263242 * (exp(a) - 1))))", x)); + assertEval("selu", x, evaluate("map(x, f(a)(1.0 * if(a >= 0, a, 1.5 * (exp(a) - 1))))", x), createAttributes().attr("gamma", 1.0f).attr("alpha", 1.5f).build()); + assertEval("leakyrelu", x, evaluate("max(0.01 * x, x)", x)); + assertEval("leakyrelu", x, evaluate("max(0.001 * x, x)", x), createAttribute("alpha", 0.001f)); + + x = evaluate("tensor(d0[3]):[0.01, 1.0, 10.0]"); + assertEval("log", x, evaluate("log(x)", x)); + assertEval("sqrt", x, evaluate("sqrt(x)", x)); + assertEval("reciprocal", x, evaluate("map(x, f(a)(1.0 / a))", x)); + } + + @Test + public void testJoinOperators7() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[3, 4]"); + Tensor y = evaluate("tensor(d0[2]):[1, 2]"); + assertEval("add", x, y, evaluate("tensor(d0[2]):[4, 6]")); + assertEval("sub", x, y, evaluate("tensor(d0[2]):[2, 2]")); + assertEval("mul", x, y, evaluate("tensor(d0[2]):[3, 8]")); + assertEval("div", x, y, evaluate("tensor(d0[2]):[3, 2]")); + assertEval("greater", x, y, evaluate("tensor(d0[2]):[1, 1]")); + assertEval("less", x, y, evaluate("tensor(d0[2]):[0, 0]")); + assertEval("equal", x, y, evaluate("tensor(d0[2]):[0, 0]")); + assertEval("pow", x, y, evaluate("tensor(d0[2]):[3, 16]")); + + x = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + y = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + assertEval("add", x, y, evaluate("x + y", x, y)); + assertEval("sub", x, y, evaluate("x - y", x, y)); + assertEval("mul", x, y, evaluate("x * y", x, y)); + assertEval("div", x, y, evaluate("x / y", x, y)); + assertEval("greater", x, y, evaluate("join(x, y, f(a,b)(a > b))", x, y)); + assertEval("less", x, y, evaluate("join(x, y, f(a,b)(a < b))", x, y)); + assertEval("equal", x, y, evaluate("join(x, y, f(a,b)(a == b))", x, y)); + assertEval("pow", x, y, evaluate("join(x, y, f(a,b)(pow(a,b)))", x, y)); + + // broadcasting + x = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + y = evaluate("random(d0[4]) + 1"); + assertEval("add", x, y, evaluate("x + rename(y, d0, d2)", x, y)); + assertEval("sub", x, y, evaluate("x - rename(y, d0, d2)", x, y)); + assertEval("mul", x, y, evaluate("x * rename(y, d0, d2)", x, y)); + assertEval("div", x, y, evaluate("x / rename(y, d0, d2)", x, y)); + assertEval("greater", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a > b))", x, y)); + assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y)); + assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y)); + assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y)); + } + + @Test + public void testConcatReduce8() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[3, 4]"); + Tensor y = evaluate("tensor(d0[2]):[1, 2]"); + Tensor z = evaluate("tensor(d0[2]):[5, 6]"); + assertEval("max", x, y, z, evaluate("tensor(d0[2]):[5, 6]")); + assertEval("min", x, y, z, evaluate("tensor(d0[2]):[1, 2]")); + assertEval("mean", x, y, z, evaluate("tensor(d0[2]):[3, 4]")); + + x = evaluate("random(d0[2],d1[3],d2[4])"); + y = evaluate("random(d0[2],d1[3],d2[4])"); + z = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("max", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), max, tmp)", x, y, z)); + assertEval("min", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), min, tmp)", x, y, z)); + assertEval("mean", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), avg, tmp)", x, y, z)); + + // broadcasting + x = evaluate("random(d0[2],d1[3],d2[4])"); + y = evaluate("random(d0[3],d1[4])"); + z = evaluate("random(d0[4])"); + assertEval("max", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), max, tmp)", x, y, z)); + assertEval("min", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), min, tmp)", x, y, z)); + assertEval("mean", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), avg, tmp)", x, y, z)); + } + + @Test + public void testConcat4() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[1, 2]"); + Tensor y = evaluate("tensor(d0[2]):[3, 4]"); + Tensor expected = evaluate("tensor(d0[4]):[1,2,3,4]"); + assertEval("concat", x, y, expected, createAttribute("axis", 0)); + assertEval("concat", x, y, expected, createAttribute("axis", -1)); + + x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); + assertEval("concat", x, y, evaluate("tensor(d0[4],d1[2]):[1,2,3,4,5,6,7,8]"), createAttribute("axis", 0)); + assertEval("concat", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,5,6,3,4,7,8]"), createAttribute("axis", 1)); + assertEval("concat", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,5,6,3,4,7,8]"), createAttribute("axis", -1)); + assertEval("concat", x, y, evaluate("tensor(d0[4],d1[2]):[1,2,3,4,5,6,7,8]"), createAttribute("axis", -2)); + + x = evaluate("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 5, 6, 7, 8]"); + y = evaluate("tensor(d0[2],d1[2],d2[2]):[9,10,11,12,13,14,15,16]"); + assertEval("concat", x, y, evaluate("concat(x, y, d0)", x, y), createAttribute("axis", 0)); + assertEval("concat", x, y, evaluate("concat(x, y, d1)", x, y), createAttribute("axis", 1)); + assertEval("concat", x, y, evaluate("concat(x, y, d2)", x, y), createAttribute("axis", 2)); + assertEval("concat", x, y, evaluate("concat(x, y, d2)", x, y), createAttribute("axis", -1)); + assertEval("concat", x, y, evaluate("concat(x, y, d1)", x, y), createAttribute("axis", -2)); + assertEval("concat", x, y, evaluate("concat(x, y, d0)", x, y), createAttribute("axis", -3)); + } + + @Test + public void testGemm7() throws ParseException { + Tensor a = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + Tensor b = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); + Tensor c = evaluate("tensor(d0[2],d1[2]):[0.1, 0.2, 0.3, 0.4]"); + + assertEval("gemm", a, b, evaluate("tensor(d0[2],d1[2]):[19, 22, 43, 50]")); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.1, 22.2, 43.3, 50.4]")); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[38.1, 44.2, 86.3, 100.4]"), createAttribute("alpha", 2.0f)); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.2, 22.4, 43.6, 50.8]"), createAttribute("beta", 2.0f)); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[26.1, 30.2, 38.3, 44.4]"), createAttribute("transA", 1)); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[17.1, 23.2, 39.3, 53.4]"), createAttribute("transB", 1)); + + // unidictional broadcasting for c + c = evaluate("tensor(d0[2]):[0.1, 0.2]"); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.1, 22.2, 43.1, 50.2]")); + } + + @Test + public void testIdentity1() throws ParseException { + Tensor x = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("identity", x, x); + } + + @Test + public void testMatMul1() throws ParseException { + Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]"); + Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]")); + } + + @Test + public void testReshape5() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"); + Tensor y = evaluate("tensor(d0[1]):[4]"); + assertEval("reshape", x, y, evaluate("tensor(d0[4]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[2,2]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[3]):[2,1,2]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[1],d2[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[2,-1]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[2,0]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[0,-1]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + x = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + y = evaluate("tensor(d0[2]):[3,2]"); + assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]")); + + y = evaluate("tensor(d0[4]):[3,2,-1,1]"); + assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]")); + } + + @Test + public void testReduceOperators1() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]")); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"), createAttribute("axes", new int[] {0,1})); + assertEval("reducesum", x, evaluate("tensor():[10]"), createAttribute("keepdims", 0)); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"), createAttribute("keepdims", 1)); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[2]):[4, 6]"), createAttribute("axes", new int[]{0})); + assertEval("reducesum", x, evaluate("tensor(d0[2]):[4, 6]"), createAttributes().attr("axes", new int[]{0}).attr("keepdims", 0).build()); + assertEval("reducesum", x, evaluate("tensor(d0[2],d1[1]):[3, 7]"), createAttribute("axes", new int[] {1})); + assertEval("reducesum", x, evaluate("tensor(d0[2]):[3, 7]"), createAttributes().attr("axes", new int[]{1}).attr("keepdims", 0).build()); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[2]):[4, 6]"), createAttribute("axes", new int[] {-2})); + assertEval("reducesum", x, evaluate("tensor(d0[2],d1[1]):[3, 7]"), createAttribute("axes", new int[] {-1})); + assertEval("reducesum", x, evaluate("tensor(d0[2]):[3, 7]"), createAttributes().attr("axes", new int[] {-1}).attr("keepdims", 0).build()); + + assertEval("reduceprod", x, evaluate("tensor(d0[1],d1[1]):[24]")); + assertEval("reduceprod", x, evaluate("tensor(d0[1],d1[2]):[3, 8]"), createAttribute("axes", new int[] {0})); + + assertEval("reducemin", x, evaluate("tensor(d0[1],d1[1]):[1]")); + assertEval("reducemin", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0})); + + assertEval("reducemax", x, evaluate("tensor(d0[1],d1[1]):[4]")); + assertEval("reducemax", x, evaluate("tensor(d0[1],d1[2]):[3, 4]"), createAttribute("axes", new int[] {0})); + + assertEval("reducemean", x, evaluate("tensor():[2.5]"), createAttribute("keepdims", 0)); + assertEval("reducemean", x, evaluate("tensor(d0[2]):[2, 3]"), createAttributes().attr("axes", new int[] {0}).attr("keepdims", 0).build()); + + assertEval("reducelogsum", x, evaluate("tensor():[log(10)]"), createAttribute("keepdims", 0)); + assertEval("reducelogsumexp", x, evaluate("tensor():[log(exp(1)+exp(2)+exp(3)+exp(4))]"), createAttribute("keepdims", 0)); + assertEval("reducesumsquare", x, evaluate("tensor():[1*1+2*2+3*3+4*4]"), createAttribute("keepdims", 0)); + + x = evaluate("tensor(d0[1],d1[5]):[-10, -5, 0, 5, 10]"); + assertEval("reducel1", x, evaluate("tensor():[30]"), createAttribute("keepdims", 0)); + assertEval("reducel2", x, evaluate("tensor():[sqrt(10*10 + 5*5 + 5*5 + 10*10)]"), createAttribute("keepdims", 0)); + } + + @Test + public void testShape1() throws ParseException { + Tensor x = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("shape", x, evaluate("tensor(d0[3]):[2,3,4]")); + } + + @Test + public void testSoftmax1() throws ParseException { + Tensor x = evaluate("tensor(d0[1],d1[3]):[-1, 0, 1]"); + assertEval("softmax", x, evaluate("tensor(d0[1],d1[3]):[0.09003058, 0.24472848, 0.66524094]")); + + x = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 7]"); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1)", x), createAttribute("axis", 0)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x), createAttribute("axis", 1)); // 1 is default + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x), createAttribute("axis", -1)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1)", x), createAttribute("axis", -2)); + + x = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1, d2)", x), createAttribute("axis", 0)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x), createAttribute("axis", 1)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d2)", x), createAttribute("axis", 2)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d2)", x), createAttribute("axis", -1)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x), createAttribute("axis", -2)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1, d2)", x), createAttribute("axis", -3)); + } + + @Test + public void testSqueeze1() throws ParseException { + Tensor x = evaluate("tensor(d0[1],d1[2]):[1, 2]"); + assertEval("squeeze", x, evaluate("tensor(d0[2]):[1, 2]")); + + x = evaluate("tensor(d0[1],d1[2],d2[1],d3[3]):[1,2,3,4,5,6]"); + assertEval("squeeze", x, evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]")); + assertEval("squeeze", x, evaluate("tensor(d0[2],d1[1],d2[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0})); + assertEval("squeeze", x, evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {2})); + assertEval("squeeze", x, evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0, 2})); + } + + @Test + public void testWhere9() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + Tensor y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); + Tensor condition = evaluate("tensor(d0[2],d1[2]):[0, 1, 0, 1]"); + assertEval("where", condition, x, y, evaluate("tensor(d0[2],d1[2]):[5, 2, 7, 4]")); + + assertEval("where", evaluate("tensor():[0]"), x, y, y); + assertEval("where", evaluate("tensor():[1]"), x, y, x); + assertEval("where", evaluate("tensor(d0[1]):[0]"), x, y, y); + assertEval("where", evaluate("tensor(d0[1]):[1]"), x, y, x); + assertEval("where", evaluate("tensor(d0[1],d1[1]):[0]"), x, y, y); + assertEval("where", evaluate("tensor(d0[1],d1[1]):[1]"), x, y, x); + } + + private Tensor evaluate(String expr) throws ParseException { + return evaluate(expr, null, null, null); + } + + private Tensor evaluate(String expr, Tensor x) throws ParseException { + return evaluate(expr, x, null, null); + } + + private Tensor evaluate(String expr, Tensor x, Tensor y) throws ParseException { + return evaluate(expr, x, y, null); + } + + private Tensor evaluate(String expr, Tensor x, Tensor y, Tensor z) throws ParseException { + Context context = new MapContext(DoubleValue.NaN); + if (x != null) context.put("x", new TensorValue(x)); + if (y != null) context.put("y", new TensorValue(y)); + if (z != null) context.put("z", new TensorValue(z)); + return new RankingExpression(expr).evaluate(context).asTensor(); + } + + private Tensor evaluate(IntermediateOperation op) { + Tensor tensor = op.evaluateAsConstant(op.type().get()).asTensor(); + return renameToStandardType(op, tensor); + } + + private void assertEval(String opName, Tensor x, Tensor expected) { + assertEval(opName, x, null, null, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) { + assertEval(opName, x, null, null, expected, attr); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) { + assertEval(opName, x, y, null, expected, attr); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) { + assertEval(opName, x, y, null, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) { + assertEval(opName, x, y, z, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) { + Context context = new MapContext(DoubleValue.NaN); + List<IntermediateOperation> inputs = createInputs(context, x, y, z); + IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build()); + optimizeAndRename(opName, op); + Tensor result = evaluate(op); + assertEquals(expected, result); + assertEquals(expected.type(), result.type()); + } + + private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z) { + List<IntermediateOperation> inputs = new ArrayList<>(); + addInput(inputs, context, x, "x"); + addInput(inputs, context, y, "y"); + addInput(inputs, context, z, "z"); + return inputs; + } + + private void addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) { + if (x == null) return; + context.put(name, new TensorValue(x)); + IntermediateOperation op = new Constant(modelName, name, OrderedTensorType.fromSpec(x.type().toString())); + op.setConstantValueFunction(type -> new TensorValue(convertTypeAfterRename(x, type))); + inputs.add(op); + } + + Tensor convertTypeAfterRename(Tensor tensor, OrderedTensorType type) { + IndexedTensor indexedTensor = (IndexedTensor) tensor; + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); + for (int i = 0; i < indexedTensor.size(); i++) { + builder.cellByDirectIndex(type.toDirectIndex(i), indexedTensor.get(i)); + } + return builder.build(); + } + + private TensorFunction optimizeAndRename(String opName, IntermediateOperation op) { + IntermediateGraph graph = new IntermediateGraph(modelName); + graph.put(opName, op); + graph.outputs(graph.defaultSignature()).put(opName, opName); + graph.optimize(); + return op.function().get(); + } + + private Tensor renameToStandardType(IntermediateOperation op, Tensor tensor) { + OrderedTensorType operationType = op.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); + if ( ! operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + TensorFunction func = new Rename(new ConstantTensor(tensor), renameFrom, renameTo); + return func.evaluate(); + } + return tensor; + } + + static AttributeConverter createAttribute(String name, int val) { + return new Attributes().attr(name, val).build(); + } + + static AttributeConverter createAttribute(String name, float val) { + return new Attributes().attr(name, val).build(); + } + + static AttributeConverter createAttribute(String name, int [] vals) { + return new Attributes().attr(name, vals).build(); + } + + static Attributes createAttributes() { + return new Attributes(); + } + + private static class Attributes { + + Onnx.NodeProto.Builder nodeBuilder; + + Attributes() { + this.nodeBuilder = Onnx.NodeProto.newBuilder(); + } + + Attributes attr(String name, int val) { + nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(INT).setI(val).build()); + return this; + } + + Attributes attr(String name, float val) { + nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(FLOAT).setF(val).build()); + return this; + } + + Attributes attr(String name, int [] vals) { + Onnx.AttributeProto.Builder builder = Onnx.AttributeProto.newBuilder(); + for (int val : vals) { + builder.addInts(val); + } + nodeBuilder.addAttribute(builder.setName(name).setType(INTS).build()); + return this; + } + + AttributeConverter build() { + return AttributeConverter.convert(nodeBuilder.build()); + } + + } + +} diff --git a/searchcore/src/apps/proton/proton.cpp b/searchcore/src/apps/proton/proton.cpp index b37eb5ac0cf..f80558a1bc6 100644 --- a/searchcore/src/apps/proton/proton.cpp +++ b/searchcore/src/apps/proton/proton.cpp @@ -8,6 +8,7 @@ #include <vespa/metrics/metricmanager.h> #include <vespa/vespalib/util/signalhandler.h> #include <vespa/vespalib/util/programoptions.h> +#include <vespa/vespalib/util/time.h> #include <vespa/vespalib/io/fileutil.h> #include <vespa/config/common/exceptions.h> #include <vespa/fastos/app.h> @@ -198,7 +199,7 @@ App::Main() LOG(info, "Sleeping 900 seconds due to proton state"); int sleepLeft = 900; while (!(SIG::INT.check() || SIG::TERM.check()) && sleepLeft > 0) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1000ms); --sleepLeft; } EV_STOPPING("proton", "shutdown after stop on io errors"); @@ -226,7 +227,7 @@ App::Main() } EV_STARTED("proton"); while (!(SIG::INT.check() || SIG::TERM.check() || (spiProton && spiProton->getNode().attemptedStopped()))) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1000ms); if (spiProton && spiProton->configUpdated()) { storage::ResumeGuard guard(spiProton->getNode().pause()); spiProton->updateConfig(); @@ -240,7 +241,7 @@ App::Main() if (spiProton) { // report down state to cluster controller. spiProton->getNode().notifyPartitionDown(0, "proton state string is " + stateString); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1000ms); } EV_STOPPING("proton", "shutdown after new stop on io errors"); return 1; diff --git a/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp b/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp index 3b3b5f412d2..dcd3dce218b 100644 --- a/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp +++ b/searchcore/src/apps/vespa-proton-cmd/vespa-proton-cmd.cpp @@ -6,8 +6,10 @@ #include <vespa/fnet/frt/frt.h> #include <vespa/vespalib/util/host_name.h> #include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/time.h> #include <vespa/fastos/app.h> #include <sys/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("vespa-proton-cmd"); @@ -115,7 +117,7 @@ public: slobrok::api::MirrorAPI sbmirror(_frt->supervisor(), sbcfg); for (int timeout = 1; timeout < 20; timeout++) { if (!sbmirror.ready()) { - FastOS_Thread::Sleep(50*timeout); + std::this_thread::sleep_for(50ms*timeout); } } if (!sbmirror.ready()) { @@ -123,12 +125,9 @@ public: "ERROR: no data from service location broker\n"); exit(1); } - slobrok::api::MirrorAPI::SpecList specs = - sbmirror.lookup(rtcPattern); - slobrok::api::MirrorAPI::SpecList specs2 = - sbmirror.lookup(rtcPattern2); - slobrok::api::MirrorAPI::SpecList specs3 = - sbmirror.lookup(rtcPattern3); + slobrok::api::MirrorAPI::SpecList specs = sbmirror.lookup(rtcPattern); + slobrok::api::MirrorAPI::SpecList specs2 = sbmirror.lookup(rtcPattern2); + slobrok::api::MirrorAPI::SpecList specs3 = sbmirror.lookup(rtcPattern3); int found = 0; std::string service; @@ -167,7 +166,7 @@ public: slobrok::api::MirrorAPI sbmirror(_frt->supervisor(), sbcfg); for (int timeout = 1; timeout < 20; timeout++) { if (!sbmirror.ready()) { - FastOS_Thread::Sleep(50*timeout); + std::this_thread::sleep_for(50ms*timeout); } } if (!sbmirror.ready()) { diff --git a/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp b/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp index f5aa74d85e1..5775c31b205 100644 --- a/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp +++ b/searchcore/src/apps/vespa-transactionlog-inspect/vespa-transactionlog-inspect.cpp @@ -7,6 +7,7 @@ #include <vespa/searchlib/transactionlog/translogserver.h> #include <vespa/vespalib/util/programoptions.h> #include <vespa/vespalib/util/xmlstream.h> +#include <vespa/vespalib/util/time.h> #include <vespa/document/config/config-documenttypes.h> #include <vespa/document/repo/documenttyperepo.h> #include <vespa/document/fieldvalue/document.h> @@ -14,6 +15,7 @@ #include <vespa/config/helper/configgetter.hpp> #include <vespa/fastos/app.h> #include <iostream> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("vespa-transactionlog-inspect"); @@ -491,7 +493,7 @@ protected: return 1; } for (size_t i = 0; !callback.isEof() && (i < 60 * 60); i++ ) { - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } return 0; } diff --git a/searchcore/src/tests/proton/attribute/attributeflush_test.cpp b/searchcore/src/tests/proton/attribute/attributeflush_test.cpp index c41ac3a3a1c..ea3d2c335fb 100644 --- a/searchcore/src/tests/proton/attribute/attributeflush_test.cpp +++ b/searchcore/src/tests/proton/attribute/attributeflush_test.cpp @@ -516,21 +516,21 @@ Test::requireThatLastFlushTimeIsReported() EXPECT_EQUAL(fastos::UTCTimeStamp::ZERO, ft->getLastFlushTime()); ft->initFlush(200)->run(); EXPECT_TRUE(FastOS_File::Stat("flush/a9/snapshot-200", &stat)); - EXPECT_EQUAL(stat._modifiedTime, ft->getLastFlushTime().timeSinceEpoch().time()); + EXPECT_EQUAL(stat._modifiedTime, ft->getLastFlushTime().time_since_epoch().time()); } { // snapshot flushed AttributeManagerFixture amf(f); AttributeManager &am = amf._m; amf.addAttribute("a9"); IFlushTarget::SP ft = am.getFlushable("a9"); - EXPECT_EQUAL(stat._modifiedTime, ft->getLastFlushTime().timeSinceEpoch().time()); + EXPECT_EQUAL(stat._modifiedTime, ft->getLastFlushTime().time_since_epoch().time()); { // updated flush time after nothing to flush std::this_thread::sleep_for(8000ms); - fastos::TimeStamp now = fastos::ClockSystem::now().timeSinceEpoch(); + fastos::TimeStamp now = fastos::ClockSystem::now().time_since_epoch(); Executor::Task::UP task = ft->initFlush(200); EXPECT_TRUE(task.get() == NULL); - EXPECT_LESS(stat._modifiedTime, ft->getLastFlushTime().timeSinceEpoch().time()); - EXPECT_APPROX(now.time(), ft->getLastFlushTime().timeSinceEpoch().time(), 8); + EXPECT_LESS(stat._modifiedTime, ft->getLastFlushTime().time_since_epoch().time()); + EXPECT_APPROX(now.time(), ft->getLastFlushTime().time_since_epoch().time(), 8); } } } diff --git a/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp b/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp index fdd53d629ad..dfac2edad61 100644 --- a/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp +++ b/searchcore/src/tests/proton/documentdb/maintenancecontroller/maintenancecontroller_test.cpp @@ -1064,7 +1064,7 @@ TEST_F("require that document pruner is active", MyFrozenBucket::UP frozen3(new MyFrozenBucket(f._mc, bucketId3)); f.setPruneConfig(DocumentDBPruneRemovedDocumentsConfig(0.2, 900.0)); for (uint32_t i = 0; i < 6; ++i) { - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); ASSERT_TRUE(f._executor.waitIdle(TIMEOUT_SEC)); if (f._removed.getNumUsedLids() != 10u) break; @@ -1073,7 +1073,7 @@ TEST_F("require that document pruner is active", EXPECT_EQUAL(10u, f._removed.getDocumentCount()); frozen3.reset(); for (uint32_t i = 0; i < 600; ++i) { - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); ASSERT_TRUE(f._executor.waitIdle(TIMEOUT_SEC)); if (f._removed.getNumUsedLids() != 10u) break; @@ -1089,7 +1089,7 @@ TEST_F("require that heartbeats are scheduled", f.startMaintenance(); f.setHeartBeatConfig(DocumentDBHeartBeatConfig(0.2)); for (uint32_t i = 0; i < 600; ++i) { - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); if (f._fh.getHeartBeats() != 0u) break; } @@ -1104,7 +1104,7 @@ TEST_F("require that periodic session prunings are scheduled", f.startMaintenance(); f.setGroupingSessionPruneInterval(0.2); for (uint32_t i = 0; i < 600; ++i) { - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); if (f._gsp.isInvoked) { break; } @@ -1233,7 +1233,7 @@ TEST_F("require that a blocked job is unblocked and executed after thaw bucket", EXPECT_FALSE(myJob2.isBlocked()); bool done1 = myJob1._latch.await(TIMEOUT_MS); EXPECT_TRUE(done1); - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); EXPECT_EQUAL(0u, myJob2._runCnt); } @@ -1245,7 +1245,7 @@ TEST_F("require that blocked jobs are not executed", MaintenanceControllerFixtur f._mc.registerJobInMasterThread(std::move(job)); f._injectDefaultJobs = false; f.startMaintenance(); - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); EXPECT_EQUAL(0u, myJob._runCnt); } diff --git a/searchcore/src/tests/proton/flushengine/flushengine_test.cpp b/searchcore/src/tests/proton/flushengine/flushengine_test.cpp index bfd4450b1f2..c18d393b98f 100644 --- a/searchcore/src/tests/proton/flushengine/flushengine_test.cpp +++ b/searchcore/src/tests/proton/flushengine/flushengine_test.cpp @@ -685,7 +685,7 @@ assertThatHandlersInCurrentSet(FlushEngine & engine, const std::vector<const cha { FlushEngine::FlushMetaSet current1 = engine.getCurrentlyFlushingSet(); while ((current1.size() < targets.size()) || !asserCorrectHandlers(current1, targets)) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); current1 = engine.getCurrentlyFlushingSet(); } } diff --git a/searchcore/src/tests/proton/index/indexmanager_test.cpp b/searchcore/src/tests/proton/index/indexmanager_test.cpp index 51e12e70dda..1941e82a2db 100644 --- a/searchcore/src/tests/proton/index/indexmanager_test.cpp +++ b/searchcore/src/tests/proton/index/indexmanager_test.cpp @@ -280,7 +280,7 @@ TEST_F(IndexManagerTest, require_that_memory_index_is_flushed) runAsMaster([&]() { flushTask = target.initFlush(1); }); flushTask->run(); EXPECT_TRUE(FastOS_File::Stat("test_data/index.flush.1", &stat)); - EXPECT_EQ(stat._modifiedTime, target.getLastFlushTime().timeSinceEpoch().time()); + EXPECT_EQ(stat._modifiedTime, target.getLastFlushTime().time_since_epoch().time()); sources = get_source_collection(); EXPECT_EQ(2u, sources->getSourceCount()); @@ -298,17 +298,17 @@ TEST_F(IndexManagerTest, require_that_memory_index_is_flushed) { // verify last flush time when loading disk index resetIndexManager(); IndexFlushTarget target(_index_manager->getMaintainer()); - EXPECT_EQ(stat._modifiedTime, target.getLastFlushTime().timeSinceEpoch().time()); + EXPECT_EQ(stat._modifiedTime, target.getLastFlushTime().time_since_epoch().time()); // updated serial number & flush time when nothing to flush std::this_thread::sleep_for(8s); - fastos::TimeStamp now = fastos::ClockSystem::now().timeSinceEpoch(); + fastos::TimeStamp now = fastos::ClockSystem::now().time_since_epoch(); vespalib::Executor::Task::UP task; runAsMaster([&]() { task = target.initFlush(2); }); EXPECT_TRUE(task.get() == nullptr); EXPECT_EQ(2u, target.getFlushedSerialNum()); - EXPECT_LT(stat._modifiedTime, target.getLastFlushTime().timeSinceEpoch().time()); - EXPECT_NEAR(now.time(), target.getLastFlushTime().timeSinceEpoch().time(), 8); + EXPECT_LT(stat._modifiedTime, target.getLastFlushTime().time_since_epoch().time()); + EXPECT_NEAR(now.time(), target.getLastFlushTime().time_since_epoch().time(), 8); } } diff --git a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp index d86a750794f..870be2ab409 100644 --- a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp +++ b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp @@ -175,7 +175,7 @@ struct ProtonConfigOwner : public proton::IProtonConfigurer while (timer.elapsed().ms() < timeout) { if (getConfigured()) return true; - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return getConfigured(); } diff --git a/searchcore/src/vespa/searchcore/proton/attribute/flushableattribute.cpp b/searchcore/src/vespa/searchcore/proton/attribute/flushableattribute.cpp index 66d1c27a5b5..0e0e96aa871 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/flushableattribute.cpp +++ b/searchcore/src/vespa/searchcore/proton/attribute/flushableattribute.cpp @@ -211,7 +211,7 @@ FlushableAttribute::internalInitFlush(SerialNum currentSerial) if (syncToken <= getFlushedSerialNum()) { writer->setLastFlushTime(fastos::ClockSystem::now()); LOG(debug,"No attribute vector to flush. Update flush time to current: lastFlushTime(%f)", - getLastFlushTime().timeSinceEpoch().sec()); + getLastFlushTime().time_since_epoch().sec()); return Task::UP(); } return std::make_unique<Flusher>(*this, syncToken, *writer); diff --git a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp index fcca1c2a737..eaf76ebdc8f 100644 --- a/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp +++ b/searchcore/src/vespa/searchcore/proton/common/attribute_updater.cpp @@ -22,7 +22,6 @@ #include <vespa/searchlib/attribute/changevector.hpp> #include <vespa/searchlib/attribute/predicate_attribute.h> #include <vespa/searchlib/attribute/reference_attribute.h> -#include <vespa/searchlib/common/base.h> #include <vespa/searchlib/tensor/tensor_attribute.h> #include <vespa/vespalib/util/stringfmt.h> #include <sstream> diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreflushtarget.cpp b/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreflushtarget.cpp index 2cadedd0f59..71274bf444e 100644 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreflushtarget.cpp +++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/documentmetastoreflushtarget.cpp @@ -223,7 +223,7 @@ DocumentMetaStoreFlushTarget::initFlush(SerialNum currentSerial) if (syncToken <= getFlushedSerialNum()) { writer->setLastFlushTime(fastos::ClockSystem::now()); LOG(debug, "No document meta store to flush. Update flush time to current: lastFlushTime(%f)", - getLastFlushTime().timeSinceEpoch().sec()); + getLastFlushTime().time_since_epoch().sec()); return Task::UP(); } return std::make_unique<Flusher>(*this, syncToken, *writer); diff --git a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp index 37f1664841a..8a4cb1682a6 100644 --- a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.cpp @@ -1,10 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "disk_mem_usage_sampler.h" -#include <vespa/vespalib/util/timer.h> +#include <vespa/vespalib/util/scheduledexecutor.h> #include <vespa/vespalib/util/lambdatask.h> #include <filesystem> -#include <unistd.h> using vespalib::makeLambdaTask; @@ -32,7 +31,7 @@ DiskMemUsageSampler::setConfig(const Config &config) _filter.setConfig(config.filterConfig); _sampleInterval = config.sampleInterval; sampleUsage(); - _periodicTimer = std::make_unique<vespalib::Timer>(); + _periodicTimer = std::make_unique<vespalib::ScheduledExecutor>(); _periodicTimer->scheduleAtFixedRate(makeLambdaTask([this]() { sampleUsage(); }), _sampleInterval, _sampleInterval); diff --git a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h index 5a439e69003..2ab13f2f48a 100644 --- a/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h +++ b/searchcore/src/vespa/searchcore/proton/server/disk_mem_usage_sampler.h @@ -4,7 +4,7 @@ #include "disk_mem_usage_filter.h" -namespace vespalib { class Timer; } +namespace vespalib { class ScheduledExecutor; } namespace proton { @@ -15,7 +15,7 @@ class DiskMemUsageSampler { DiskMemUsageFilter _filter; std::filesystem::path _path; double _sampleInterval; - std::unique_ptr<vespalib::Timer> _periodicTimer; + std::unique_ptr<vespalib::ScheduledExecutor> _periodicTimer; void sampleUsage(); void sampleDiskUsage(); diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp index a2672cc7972..893748ae49e 100644 --- a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp @@ -6,7 +6,7 @@ #include "i_blockable_maintenance_job.h" #include <vespa/searchcorespi/index/i_thread_service.h> #include <vespa/vespalib/util/closuretask.h> -#include <vespa/vespalib/util/timer.h> +#include <vespa/vespalib/util/scheduledexecutor.h> #include <vespa/log/log.h> LOG_SETUP(".proton.server.maintenancecontroller"); @@ -167,7 +167,7 @@ MaintenanceController::restart() if (!_started || _stopping || !_readySubDB.valid()) { return; } - _periodicTimer.reset(new vespalib::Timer()); + _periodicTimer = std::make_unique<vespalib::ScheduledExecutor>(); addJobsToPeriodicTimer(); } diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h index 24c1c18959e..3cfdeba4d34 100644 --- a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h +++ b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h @@ -8,6 +8,7 @@ #include "ibucketfreezelistener.h" #include <vespa/searchcore/proton/common/doctypename.h> #include <mutex> +#include <vespa/vespalib/util/scheduledexecutor.h> namespace vespalib { class Timer; @@ -77,7 +78,7 @@ private: MaintenanceDocumentSubDB _readySubDB; MaintenanceDocumentSubDB _remSubDB; MaintenanceDocumentSubDB _notReadySubDB; - std::unique_ptr<vespalib::Timer> _periodicTimer; + std::unique_ptr<vespalib::ScheduledExecutor> _periodicTimer; DocumentDBMaintenanceConfigSP _config; FrozenBuckets _frozenBuckets; bool _started; diff --git a/searchcore/src/vespa/searchcore/proton/server/memoryflush.cpp b/searchcore/src/vespa/searchcore/proton/server/memoryflush.cpp index 8ea9e095385..86e818e0fe3 100644 --- a/searchcore/src/vespa/searchcore/proton/server/memoryflush.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/memoryflush.cpp @@ -140,7 +140,7 @@ MemoryFlush::getFlushTargets(const FlushContext::List &targetList, config.maxGlobalMemory, config.maxGlobalTlsSize, config.globalDiskBloatFactor, config.maxMemoryGain, config.diskBloatFactor, config.maxTimeGain.sec(), - _startTime.timeSinceEpoch().sec()); + _startTime.time_since_epoch().sec()); for (size_t i(0), m(targetList.size()); i < m; i++) { const IFlushTarget & target(*targetList[i]->getTarget()); const IFlushHandler & handler(*targetList[i]->getHandler()); @@ -183,8 +183,8 @@ MemoryFlush::getFlushTargets(const FlushContext::List &targetList, target.getFlushedSerialNum(), localLastSerial, serialDiff, - lastFlushTime.timeSinceEpoch().sec(), - now.timeSinceEpoch().sec(), + lastFlushTime.time_since_epoch().sec(), + now.time_since_epoch().sec(), timeDiff.sec(), getOrderName(order).c_str()); } diff --git a/searchcore/src/vespa/searchcore/proton/server/pruneremoveddocumentsjob.cpp b/searchcore/src/vespa/searchcore/proton/server/pruneremoveddocumentsjob.cpp index b5a430b3bcb..8bf96b4d370 100644 --- a/searchcore/src/vespa/searchcore/proton/server/pruneremoveddocumentsjob.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/pruneremoveddocumentsjob.cpp @@ -67,7 +67,7 @@ PruneRemovedDocumentsJob::run() uint64_t tshz = 1000000; fastos::UTCTimeStamp now = fastos::ClockSystem::now(); const Timestamp ageLimit(static_cast<Timestamp::Type> - ((now.timeSinceEpoch().sec() - _cfgAgeLimit) * tshz)); + ((now.time_since_epoch().sec() - _cfgAgeLimit) * tshz)); DocId lid(_nextLid); const DocId olid(lid); const DocId docIdLimit(_metaStore.getCommittedDocIdLimit()); diff --git a/searchcorespi/src/vespa/searchcorespi/index/indexmaintainer.cpp b/searchcorespi/src/vespa/searchcorespi/index/indexmaintainer.cpp index da3058e9bbd..e2ab5f5ff1f 100644 --- a/searchcorespi/src/vespa/searchcorespi/index/indexmaintainer.cpp +++ b/searchcorespi/src/vespa/searchcorespi/index/indexmaintainer.cpp @@ -944,7 +944,7 @@ IndexMaintainer::initFlush(SerialNum serialNum, searchcorespi::FlushStats * stat _lastFlushTime = fastos::ClockSystem::now(); LOG(debug, "No memory index to flush. Update serial number and flush time to current: " "flushSerialNum(%" PRIu64 "), lastFlushTime(%f)", - _flush_serial_num, _lastFlushTime.timeSinceEpoch().sec()); + _flush_serial_num, _lastFlushTime.time_since_epoch().sec()); return FlushTask::UP(); } SerialNum realSerialNum = args.flush_serial_num; diff --git a/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp b/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp index 4e2450dccc9..9ec5cc8857b 100644 --- a/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp +++ b/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp @@ -764,7 +764,7 @@ TEST("requireThatFlushTimeIsAvailableAfterFlush") { MyTlSyncer tlSyncer; LogDataStore store(executor, testDir.getDir(), config, GrowStrategy(), TuneFileSummary(), fileHeaderContext, tlSyncer, nullptr); - EXPECT_EQUAL(0, store.getLastFlushTime().timeSinceEpoch().time()); + EXPECT_EQUAL(0, store.getLastFlushTime().time_since_epoch().time()); uint64_t flushToken = store.initFlush(5); EXPECT_EQUAL(5u, flushToken); store.flush(flushToken); diff --git a/searchlib/src/tests/postinglistbm/stress_runner.cpp b/searchlib/src/tests/postinglistbm/stress_runner.cpp index 53b683cd7fd..100a4fcd70d 100644 --- a/searchlib/src/tests/postinglistbm/stress_runner.cpp +++ b/searchlib/src/tests/postinglistbm/stress_runner.cpp @@ -8,9 +8,11 @@ #include <vespa/searchlib/test/fakedata/fakeword.h> #include <vespa/searchlib/test/fakedata/fakewordset.h> #include <vespa/searchlib/test/fakedata/fpfactory.h> +#include <vespa/vespalib/util/time.h> #include <condition_variable> #include <mutex> #include <vector> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".stress_runner"); @@ -306,7 +308,7 @@ StressMaster::run() totalTime / _loops, type.c_str()); dropPostings(); } - FastOS_Thread::Sleep(250); + std::this_thread::sleep_for(250ms); } double diff --git a/searchlib/src/tests/transactionlog/translogclient_test.cpp b/searchlib/src/tests/transactionlog/translogclient_test.cpp index 8a515f749f1..c4751af5adb 100644 --- a/searchlib/src/tests/transactionlog/translogclient_test.cpp +++ b/searchlib/src/tests/transactionlog/translogclient_test.cpp @@ -248,7 +248,7 @@ bool Test::partialUpdateTest() TransLogClient::Visitor::UP visitor = tls.createVisitor("test1", ca); ASSERT_TRUE(visitor.get()); ASSERT_TRUE( visitor->visit(5, 7) ); - for (size_t i(0); ! ca._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca._eof ); ASSERT_TRUE( ca.map().size() == 1); ASSERT_TRUE( ca.hasSerial(7) ); @@ -257,7 +257,7 @@ bool Test::partialUpdateTest() TransLogClient::Visitor::UP visitor1 = tls.createVisitor("test1", ca1); ASSERT_TRUE(visitor1.get()); ASSERT_TRUE( visitor1->visit(4, 5) ); - for (size_t i(0); ! ca1._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca1._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca1._eof ); ASSERT_TRUE( ca1.map().size() == 0); @@ -265,7 +265,7 @@ bool Test::partialUpdateTest() TransLogClient::Visitor::UP visitor2 = tls.createVisitor("test1", ca2); ASSERT_TRUE(visitor2.get()); ASSERT_TRUE( visitor2->visit(5, 6) ); - for (size_t i(0); ! ca2._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca2._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca2._eof ); ASSERT_TRUE( ca2.map().size() == 0); @@ -273,7 +273,7 @@ bool Test::partialUpdateTest() TransLogClient::Visitor::UP visitor3 = tls.createVisitor("test1", ca3); ASSERT_TRUE(visitor3.get()); ASSERT_TRUE( visitor3->visit(5, 1000) ); - for (size_t i(0); ! ca3._eof && (i < 1000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca3._eof && (i < 1000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca3._eof ); ASSERT_TRUE( ca3.map().size() == 1); ASSERT_TRUE( ca3.hasSerial(7) ); @@ -437,7 +437,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c TransLogClient::Visitor::UP visitor = tls.createVisitor(name, ca); ASSERT_TRUE(visitor.get()); EXPECT_TRUE( visitor->visit(0, 1) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } EXPECT_TRUE( ca._eof ); EXPECT_TRUE( ! ca.hasSerial(0) ); EXPECT_TRUE( ca.hasSerial(1) ); @@ -447,7 +447,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c visitor = tls.createVisitor(name, ca); ASSERT_TRUE(visitor.get()); EXPECT_TRUE( visitor->visit(1, 2) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } EXPECT_TRUE( ca._eof ); EXPECT_TRUE( ! ca.hasSerial(0) ); EXPECT_TRUE( ! ca.hasSerial(1) ); @@ -458,7 +458,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c visitor = tls.createVisitor(name, ca); EXPECT_TRUE(visitor.get()); EXPECT_TRUE( visitor->visit(0, 3) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } EXPECT_TRUE( ca._eof ); EXPECT_TRUE( ! ca.hasSerial(0) ); EXPECT_TRUE( ca.hasSerial(1) ); @@ -469,7 +469,7 @@ bool Test::visitDomainTest(TransLogClient & tls, TransLogClient::Session * s1, c visitor = tls.createVisitor(name, ca); ASSERT_TRUE(visitor.get()); EXPECT_TRUE( visitor->visit(2, 3) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } EXPECT_TRUE( ca._eof ); EXPECT_TRUE( ! ca.hasSerial(0) ); EXPECT_TRUE( !ca.hasSerial(1) ); @@ -575,7 +575,7 @@ assertVisitStats(TransLogClient &tls, const vespalib::string &domain, ASSERT_TRUE(visitor.get()); ASSERT_TRUE( visitor->visit(visitStart, visitEnd) ); for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } ASSERT_TRUE(ca._eof); EXPECT_EQUAL(expFirstSerial, ca._firstSerial); @@ -623,7 +623,7 @@ void Test::testMany() TransLogClient::Visitor::UP visitor = tls.createVisitor("many", ca); ASSERT_TRUE(visitor.get()); ASSERT_TRUE( visitor->visit(2, TOTAL_NUM_ENTRIES) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca._eof ); EXPECT_EQUAL(ca._count, TOTAL_NUM_ENTRIES); EXPECT_EQUAL(ca._value, TOTAL_NUM_ENTRIES); @@ -644,7 +644,7 @@ void Test::testMany() TransLogClient::Visitor::UP visitor = tls.createVisitor("many", ca); ASSERT_TRUE(visitor.get()); ASSERT_TRUE( visitor->visit(2, TOTAL_NUM_ENTRIES) ); - for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { FastOS_Thread::Sleep(10); } + for (size_t i(0); ! ca._eof && (i < 60000); i++ ) { std::this_thread::sleep_for(10ms); } ASSERT_TRUE( ca._eof ); EXPECT_EQUAL(ca._count, TOTAL_NUM_ENTRIES); EXPECT_EQUAL(ca._value, TOTAL_NUM_ENTRIES); diff --git a/searchlib/src/tests/transactionlogstress/translogstress.cpp b/searchlib/src/tests/transactionlogstress/translogstress.cpp index a047c5e1657..2ec193cfe45 100644 --- a/searchlib/src/tests/transactionlogstress/translogstress.cpp +++ b/searchlib/src/tests/transactionlogstress/translogstress.cpp @@ -11,8 +11,11 @@ #include <iostream> #include <stdexcept> #include <sstream> +#include <thread> #include <vespa/log/log.h> +#include <vespa/vespalib/util/time.h> + LOG_SETUP("translogstress"); using document::ByteBuffer; @@ -267,7 +270,7 @@ FeederThread::doRun() int64_t milliSecsUsed = _timer.elapsed().ms(); if (milliSecsUsed < 1000) { //LOG(info, "FeederThread: sleep %u ms", 1000 - milliSecsUsed); - FastOS_Thread::Sleep(1000 - milliSecsUsed); + std::this_thread::sleep_for(std::chrono::milliseconds(1000 - milliSecsUsed)); } else { LOG(info, "FeederThread: max throughput"); } @@ -457,7 +460,7 @@ private: EntryGenerator _generator; std::vector<std::shared_ptr<VisitorAgent> > _visitors; std::vector<std::shared_ptr<VisitorAgent> > _rndVisitors; - uint64_t _visitorInterval; // in milliseconds + vespalib::duration _visitorInterval; // in milliseconds int64_t _pruneInterval; // in milliseconds fastos::StopWatch _pruneTimer; SerialNum _begin; @@ -481,14 +484,14 @@ ControllerThread::ControllerThread(const std::string & tlsSpec, const std::strin const EntryGenerator & generator, uint32_t numVisitors, uint64_t visitorInterval, uint64_t pruneInterval) : _tlsSpec(tlsSpec), _domain(domain), _client(tlsSpec.c_str()), _session(), - _generator(generator), _visitors(), _rndVisitors(), _visitorInterval(visitorInterval), + _generator(generator), _visitors(), _rndVisitors(), _visitorInterval(std::chrono::milliseconds(visitorInterval)), _pruneInterval(pruneInterval), _pruneTimer(), _begin(0), _end(0), _count(0) { for (uint32_t i = 0; i < numVisitors; ++i) { _visitors.push_back(std::make_shared<VisitorAgent>(tlsSpec, domain, generator, i, true)); } } -ControllerThread::~ControllerThread() {} +ControllerThread::~ControllerThread() = default; void ControllerThread::getStatus() @@ -553,7 +556,7 @@ ControllerThread::doRun() } _pruneTimer.restart(); } - FastOS_Thread::Sleep(_visitorInterval); + std::this_thread::sleep_for(_visitorInterval); } } @@ -569,7 +572,7 @@ private: uint64_t domainPartSize; size_t packetSize; - uint64_t stressTime; + std::chrono::milliseconds stressTime; uint32_t feedRate; uint32_t numVisitors; uint64_t visitorInterval; @@ -598,7 +601,7 @@ void TransLogStress::printConfig() { std::cout << "######## Config ########" << std::endl; - std::cout << "stressTime: " << _cfg.stressTime / 1000 << " s" << std::endl; + std::cout << "stressTime: " << vespalib::to_s(_cfg.stressTime) << " s" << std::endl; std::cout << "feedRate: " << _cfg.feedRate << " per/sec" << std::endl; std::cout << "numVisitors: " << _cfg.numVisitors << std::endl; std::cout << "visitorInterval: " << _cfg.visitorInterval << " ms" << std::endl; @@ -628,7 +631,7 @@ TransLogStress::Main() _cfg.domainPartSize = 8000000; // ~8MB _cfg.packetSize = 0x10000; - _cfg.stressTime = 1000 * 60; + _cfg.stressTime = std::chrono::milliseconds(1000 * 60); _cfg.feedRate = 10000; _cfg.numVisitors = 1; _cfg.visitorInterval = 1000 * 1; @@ -639,7 +642,7 @@ TransLogStress::Main() _cfg.maxStrLen = 80; _cfg.baseSeed = 100; - uint64_t sleepTime = 4000; + vespalib::duration sleepTime = 4s; int idx = 1; char opt; @@ -654,7 +657,7 @@ TransLogStress::Main() _cfg.packetSize = atol(arg); break; case 't': - _cfg.stressTime = 1000 * atol(arg); + _cfg.stressTime = std::chrono::milliseconds(1000 * atol(arg)); break; case 'f': _cfg.feedRate = atoi(arg); @@ -690,7 +693,7 @@ TransLogStress::Main() } printConfig(); - FastOS_Thread::Sleep(sleepTime); + std::this_thread::sleep_for(sleepTime); if (_argc != idx || optError) { usage(); @@ -721,13 +724,13 @@ TransLogStress::Main() FeederThread feeder(tlsSpec, domain, generator, _cfg.feedRate, _cfg.packetSize); threadPool.NewThread(&feeder); - FastOS_Thread::Sleep(sleepTime); + std::this_thread::sleep_for(sleepTime); ControllerThread controller(tlsSpec, domain, generator, _cfg.numVisitors, _cfg.visitorInterval, _cfg.pruneInterval); threadPool.NewThread(&controller); // stop feeder and controller - FastOS_Thread::Sleep(_cfg.stressTime); + std::this_thread::sleep_for(_cfg.stressTime); printConfig(); LOG(info, "Stop feeder..."); feeder.stop(); @@ -735,7 +738,7 @@ TransLogStress::Main() std::cout << "<feeder>" << std::endl; std::cout << " <from>" << feeder.getRange().from() << "</from>" << std::endl; std::cout << " <to>" << feeder.getRange().to() << "</to>" << std::endl; - std::cout << " <rate>" << 1000 * (feeder.getRange().to() - feeder.getRange().from()) / (sleepTime + _cfg.stressTime) + std::cout << " <rate>" << 1000 * (feeder.getRange().to() - feeder.getRange().from()) / vespalib::count_ms(sleepTime + _cfg.stressTime) << "</rate>" << std::endl; std::cout << "</feeder>" << std::endl; @@ -743,7 +746,7 @@ TransLogStress::Main() controller.stop(); controller.join(); - FastOS_Thread::Sleep(sleepTime); + std::this_thread::sleep_for(sleepTime); std::vector<std::shared_ptr<VisitorAgent> > & visitors = controller.getVisitors(); for (size_t i = 0; i < visitors.size(); ++i) { std::cout << "<visitor id='" << i << "'>" << std::endl; diff --git a/searchlib/src/vespa/searchlib/docstore/logdatastore.cpp b/searchlib/src/vespa/searchlib/docstore/logdatastore.cpp index 19b945ddd9c..95c02e57dd1 100644 --- a/searchlib/src/vespa/searchlib/docstore/logdatastore.cpp +++ b/searchlib/src/vespa/searchlib/docstore/logdatastore.cpp @@ -656,7 +656,7 @@ LogDataStore::createWritableFile(FileId fileId, SerialNum serialNum, NameId name FileChunk::UP LogDataStore::createWritableFile(FileId fileId, SerialNum serialNum) { - return createWritableFile(fileId, serialNum, NameId(fastos::ClockSystem::now().timeSinceEpoch())); + return createWritableFile(fileId, serialNum, NameId(fastos::ClockSystem::now().time_since_epoch())); } namespace { diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp index 6d11ab1f5eb..37903bc21f5 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp @@ -2,12 +2,14 @@ #include "translogserver.h" #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/io/fileutil.h> +#include <vespa/vespalib/util/time.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/fnet/frt/supervisor.h> #include <vespa/fnet/frt/rpcrequest.h> #include <vespa/fnet/task.h> #include <vespa/fnet/transport.h> #include <fstream> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".transactionlog.server"); @@ -125,7 +127,7 @@ TransLogServer::TransLogServer(const vespalib::string &name, int listenPort, con listenOk = true; } else { LOG(warning, "Failed listening at port %s trying for %d seconds more.", listenSpec, i); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } } if ( ! listenOk ) { diff --git a/slobrok/src/tests/configure/configure.cpp b/slobrok/src/tests/configure/configure.cpp index bf41b77ab05..fa509c17d0c 100644 --- a/slobrok/src/tests/configure/configure.cpp +++ b/slobrok/src/tests/configure/configure.cpp @@ -85,7 +85,7 @@ compare(MirrorAPI &api, const char *pattern, SpecList expect) if (actual == expect) { return true; } - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } SpecList actual(api.lookup(pattern)); std::cerr << "Actual: " << actual.strVal() << std::endl; @@ -176,7 +176,7 @@ Test::Main() srv2Builder.slobrok[0].connectionspec = createSpec(18525); cfgCtx->reload(); - FastOS_Thread::Sleep(6000); // reconfiguration time + std::this_thread::sleep_for(6s); // reconfiguration time reg1.registerName("A"); reg2.registerName("B"); diff --git a/slobrok/src/tests/mirrorapi/mirrorapi.cpp b/slobrok/src/tests/mirrorapi/mirrorapi.cpp index b25e338533c..53e194fad2d 100644 --- a/slobrok/src/tests/mirrorapi/mirrorapi.cpp +++ b/slobrok/src/tests/mirrorapi/mirrorapi.cpp @@ -112,7 +112,7 @@ compare(MirrorAPI &api, const char *pattern, SpecList expect) if (actual == expect) { return true; } - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return false; } @@ -124,7 +124,7 @@ Test::Main() TEST_INIT("mirrorapi_test"); SlobrokServer mock(18501); - FastOS_Thread::Sleep(300); + std::this_thread::sleep_for(300ms); Server a("A/x/w", 18502, "tcp/localhost:18501"); Server b("B/x", 18503, "tcp/localhost:18501"); @@ -143,7 +143,7 @@ Test::Main() MirrorAPI mirror(supervisor, config::ConfigUri::createFromInstance(specBuilder)); EXPECT_TRUE(!mirror.ready()); transport.Start(&threadPool); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); a.reg(); EXPECT_TRUE(compare(mirror, "A/x/w", SpecList().add("A/x/w", "tcp/localhost:18502"))); diff --git a/slobrok/src/tests/registerapi/registerapi.cpp b/slobrok/src/tests/registerapi/registerapi.cpp index ac7e662c6f2..92f08ee41cb 100644 --- a/slobrok/src/tests/registerapi/registerapi.cpp +++ b/slobrok/src/tests/registerapi/registerapi.cpp @@ -64,7 +64,7 @@ compare(MirrorAPI &api, const char *pattern, SpecList expect) if (actual == expect) { return true; } - FastOS_Thread::Sleep(100); + std::this_thread::sleep_for(100ms); } return false; } @@ -75,7 +75,7 @@ Test::Main() TEST_INIT("registerapi_test"); SlobrokServer mock(18548); - FastOS_Thread::Sleep(300); + std::this_thread::sleep_for(300ms); cloud::config::SlobroksConfigBuilder slobrokSpecs; cloud::config::SlobroksConfig::Slobrok sb; @@ -97,7 +97,7 @@ Test::Main() EXPECT_TRUE(compare(mirror, "*/*/*", SpecList().add("A/x/w", myspec.c_str()))); for (int i = 0; i < 30; i++) { - if (reg.busy()) FastOS_Thread::Sleep(100); + if (reg.busy()) std::this_thread::sleep_for(100ms); } EXPECT_TRUE(!reg.busy()); diff --git a/slobrok/src/tests/standalone/standalone.cpp b/slobrok/src/tests/standalone/standalone.cpp index 9d3fd694ee1..65553c57530 100644 --- a/slobrok/src/tests/standalone/standalone.cpp +++ b/slobrok/src/tests/standalone/standalone.cpp @@ -132,7 +132,7 @@ TEST("standalone") { break; } fprintf(stderr, "ping failed [retry %d]\n", retry); - FastOS_Thread::Sleep(200); + std::this_thread::sleep_for(200ms); sb->SubRef(); sb = orb.GetTarget(18541); } @@ -268,7 +268,7 @@ TEST("standalone") { } } - FastOS_Thread::Sleep(2000); + std::this_thread::sleep_for(2s); // lookup 'B' should give '' req = orb.AllocRPCRequest(req); diff --git a/staging_vespalib/src/tests/clock/clock_test.cpp b/staging_vespalib/src/tests/clock/clock_test.cpp index bf7e3773055..b5650244a45 100644 --- a/staging_vespalib/src/tests/clock/clock_test.cpp +++ b/staging_vespalib/src/tests/clock/clock_test.cpp @@ -2,36 +2,27 @@ #include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/util/clock.h> +#include <vespa/vespalib/util/time.h> #include <vespa/fastos/thread.h> using vespalib::Clock; using fastos::TimeStamp; -class Test : public vespalib::TestApp -{ -public: - int Main() override; -}; - -int -Test::Main() -{ - TEST_INIT("clock_test"); +TEST("Test that clock is ticking forward") { Clock clock(0.050); FastOS_ThreadPool pool(0x10000); ASSERT_TRUE(pool.NewThread(clock.getRunnable(), nullptr) != nullptr); fastos::SteadyTimeStamp start = clock.getTimeNS(); - FastOS_Thread::Sleep(5000); + std::this_thread::sleep_for(5s); fastos::SteadyTimeStamp stop = clock.getTimeNS(); EXPECT_TRUE(stop > start); - FastOS_Thread::Sleep(6000); + std::this_thread::sleep_for(6s); clock.stop(); fastos::SteadyTimeStamp stop2 = clock.getTimeNS(); EXPECT_TRUE(stop2 > stop); EXPECT_TRUE((stop2 - stop)/TimeStamp::MICRO > 1000); - TEST_DONE(); } -TEST_APPHOOK(Test) +TEST_MAIN() { TEST_RUN_ALL(); }
\ No newline at end of file diff --git a/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp b/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp index e6f7bd21750..fbaa5581173 100644 --- a/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp +++ b/staging_vespalib/src/tests/shutdownguard/shutdownguard_test.cpp @@ -13,20 +13,20 @@ Test::Main() { TEST_INIT("shutdownguard_test"); { - ShutdownGuard farFuture(123456789); - FastOS_Thread::Sleep(20); + ShutdownGuard farFuture(1000000s); + std::this_thread::sleep_for(20ms); } EXPECT_TRUE(true); pid_t child = fork(); if (child == 0) { - ShutdownGuard soon(30); + ShutdownGuard soon(30ms); for (int i = 0; i < 1000; ++i) { - FastOS_Thread::Sleep(20); + std::this_thread::sleep_for(20ms); } exit(0); } for (int i = 0; i < 1000; ++i) { - FastOS_Thread::Sleep(20); + std::this_thread::sleep_for(20ms); int stat = 0; if (waitpid(child, &stat, WNOHANG) == child) { EXPECT_TRUE(WIFEXITED(stat)); diff --git a/staging_vespalib/src/tests/timer/timer_test.cpp b/staging_vespalib/src/tests/timer/timer_test.cpp index 309ee873b44..5472ad6e23f 100644 --- a/staging_vespalib/src/tests/timer/timer_test.cpp +++ b/staging_vespalib/src/tests/timer/timer_test.cpp @@ -1,8 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/vespalib/testkit/testapp.h> -#include <vespa/vespalib/util/timer.h> -#include <vespa/vespalib/util/executor.h> +#include <vespa/vespalib/util/scheduledexecutor.h> using namespace vespalib; using vespalib::Executor; @@ -37,7 +36,7 @@ void Test::testScheduling() { vespalib::CountDownLatch latch1(3); vespalib::CountDownLatch latch2(2); - Timer timer; + ScheduledExecutor timer; timer.scheduleAtFixedRate(Task::UP(new TestTask(latch1)), 0.1, 0.2); timer.scheduleAtFixedRate(Task::UP(new TestTask(latch2)), 0.5, 0.5); EXPECT_TRUE(latch1.await(60000)); @@ -47,7 +46,7 @@ void Test::testScheduling() void Test::testReset() { vespalib::CountDownLatch latch1(2); - Timer timer; + ScheduledExecutor timer; timer.scheduleAtFixedRate(Task::UP(new TestTask(latch1)), 2.0, 3.0); timer.reset(); EXPECT_TRUE(!latch1.await(3000)); diff --git a/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt b/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt index 20d47c90453..71364a813f6 100644 --- a/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt +++ b/staging_vespalib/src/vespa/vespalib/util/CMakeLists.txt @@ -16,7 +16,7 @@ vespa_add_library(staging_vespalib_vespalib_util OBJECT document_runnable.cpp rusage.cpp shutdownguard.cpp - timer.cpp + scheduledexecutor.cpp xmlserializable.cpp xmlstream.cpp DEPENDS diff --git a/staging_vespalib/src/vespa/vespalib/util/timer.cpp b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.cpp index a7acbe67965..61f9666114c 100644 --- a/staging_vespalib/src/vespa/vespalib/util/timer.cpp +++ b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.cpp @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "timer.h" +#include "scheduledexecutor.h" #include <vespa/fnet/scheduler.h> #include <vespa/fnet/task.h> #include <vespa/fnet/transport.h> @@ -34,7 +34,7 @@ public: } }; -Timer::Timer() +ScheduledExecutor::ScheduledExecutor() : _threadPool(128 * 1024), _transport(new FNET_Transport()), _lock(), @@ -43,7 +43,7 @@ Timer::Timer() _transport->Start(&_threadPool); } -Timer::~Timer() +ScheduledExecutor::~ScheduledExecutor() { vespalib::LockGuard guard(_lock); _transport->ShutDown(true); @@ -53,7 +53,7 @@ Timer::~Timer() void -Timer::scheduleAtFixedRate(vespalib::Executor::Task::UP task, double delay, double interval) +ScheduledExecutor::scheduleAtFixedRate(vespalib::Executor::Task::UP task, double delay, double interval) { vespalib::LockGuard guard(_lock); TimerTaskPtr tTask(new TimerTask(_transport->GetScheduler(), std::move(task), interval)); @@ -62,7 +62,7 @@ Timer::scheduleAtFixedRate(vespalib::Executor::Task::UP task, double delay, doub } void -Timer::reset() +ScheduledExecutor::reset() { vespalib::LockGuard guard(_lock); _transport->ShutDown(true); diff --git a/staging_vespalib/src/vespa/vespalib/util/timer.h b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.h index 0f7cde67ee4..d7e56494828 100644 --- a/staging_vespalib/src/vespa/vespalib/util/timer.h +++ b/staging_vespalib/src/vespa/vespalib/util/scheduledexecutor.h @@ -13,11 +13,11 @@ namespace vespalib { class TimerTask; /** - * Timer is a class capable of running Tasks at a regular + * ScheduledExecutor is a class capable of running Tasks at a regular * interval. The timer can be reset to clear all tasks currently being * scheduled. */ -class Timer +class ScheduledExecutor { private: typedef std::unique_ptr<TimerTask> TimerTaskPtr; @@ -31,13 +31,13 @@ public: /** * Create a new timer, capable of scheduling tasks at fixed intervals. */ - Timer(); + ScheduledExecutor(); /** * Destroys this timer, finishing the current task executing and then * finishing. */ - ~Timer(); + ~ScheduledExecutor(); /** * Schedule new task to be executed at specified intervals. diff --git a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp index 645ffea380d..99857107860 100644 --- a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp +++ b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.cpp @@ -1,7 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "shutdownguard.h" #include <unistd.h> -#include <sys/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".vespalib.shutdownguard"); @@ -10,36 +10,30 @@ namespace vespalib { namespace { enum { STACK_SIZE = (1u << 16) }; - -uint64_t getTimeInMillis() { - struct timeval mytime; - gettimeofday(&mytime, 0); - uint64_t mult = 1000; - return (mytime.tv_sec * mult) + (mytime.tv_usec / mult); -} } void ShutdownGuard::Run(FastOS_ThreadInterface *, void *) { - while (_dieAtTime > getTimeInMillis()) { - FastOS_Thread::Sleep(5); + while (_dieAtTime > steady_clock::now() && ! GetThread()->GetBreakFlag()) { + std::this_thread::sleep_for(5ms); } - if (_dieAtTime != 0) { + if (_dieAtTime <= steady_clock::now()) { LOG(warning, "ShutdownGuard is now forcing an exit of the process."); _exit(EXIT_FAILURE); } } -ShutdownGuard::ShutdownGuard(uint64_t millis) : +ShutdownGuard::ShutdownGuard(duration millis) : FastOS_Runnable(), _pool(STACK_SIZE, 1), - _dieAtTime(getTimeInMillis() + millis) + _dieAtTime(steady_clock::now() + millis) { _pool.NewThread(this); } ShutdownGuard::~ShutdownGuard() { - _dieAtTime = 0; + GetThread()->SetBreakFlag(); + GetThread()->Join(); _pool.Close(); } diff --git a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h index 5a9aad5d4d4..9de9df8bbad 100644 --- a/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h +++ b/staging_vespalib/src/vespa/vespalib/util/shutdownguard.h @@ -1,8 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once +#include <vespa/vespalib/util/time.h> #include <vespa/fastos/thread.h> -#include <cstdint> namespace vespalib { @@ -16,7 +16,7 @@ namespace vespalib { class ShutdownGuard : public FastOS_Runnable { FastOS_ThreadPool _pool; - volatile uint64_t _dieAtTime; + steady_time _dieAtTime; void Run(FastOS_ThreadInterface *, void *) override; @@ -25,7 +25,7 @@ public: * Construct a shutdown guard with a given lifetime. * @arg millis the number of milliseconds before process automatically exits **/ - ShutdownGuard(uint64_t millis); + ShutdownGuard(duration millis); /** * Destructor that dismisses the guard and collects the shutdown thread. diff --git a/storage/src/tests/common/metricstest.cpp b/storage/src/tests/common/metricstest.cpp index d1421845b81..d698cbb5e05 100644 --- a/storage/src/tests/common/metricstest.cpp +++ b/storage/src/tests/common/metricstest.cpp @@ -1,6 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/document/fieldvalue/document.h> #include <vespa/storageapi/message/persistence.h> #include <vespa/storageframework/defaultimplementation/clock/fakeclock.h> #include <vespa/storage/bucketdb/bucketmanager.h> @@ -14,6 +13,7 @@ #include <vespa/config/common/exceptions.h> #include <vespa/vespalib/stllike/hash_map.hpp> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/time.h> #include <gmock/gmock.h> #include <thread> @@ -202,7 +202,7 @@ void MetricsTest::createFakeLoad() while (uint64_t(_metricManager->getLastProcessedTime()) < _clock->getTimeInSeconds().getTime()) { - FastOS_Thread::Sleep(5); + std::this_thread::sleep_for(5ms); _metricManager->timeChangedNotification(); } } @@ -257,7 +257,7 @@ TEST_F(MetricsTest, snapshot_presenting) { uint64_t(_metricManager->getLastProcessedTime()) < _clock->getTimeInSeconds().getTime()) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } LOG(debug, "5 minute snapshot should have been taken. Adding put count"); diff --git a/storage/src/tests/common/teststorageapp.cpp b/storage/src/tests/common/teststorageapp.cpp index dd89082d3e7..082af954871 100644 --- a/storage/src/tests/common/teststorageapp.cpp +++ b/storage/src/tests/common/teststorageapp.cpp @@ -7,10 +7,11 @@ #include <vespa/config-load-type.h> #include <vespa/config-fleetcontroller.h> #include <vespa/persistence/dummyimpl/dummypersistence.h> -#include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/util/time.h> #include <vespa/config/config.h> #include <vespa/config/helper/configgetter.hpp> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".test.servicelayerapp"); @@ -111,7 +112,7 @@ TestStorageApp::waitUntilInitialized( framework::MilliSecTime endTime( clock.getTimeInMillis() + timeout.getMillis()); while (!isInitialized()) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); framework::MilliSecTime currentTime(clock.getTimeInMillis()); if (currentTime > endTime) { std::ostringstream error; diff --git a/storage/src/tests/distributor/distributortest.cpp b/storage/src/tests/distributor/distributortest.cpp index 8fa8a6bcede..d456401876e 100644 --- a/storage/src/tests/distributor/distributortest.cpp +++ b/storage/src/tests/distributor/distributortest.cpp @@ -11,9 +11,10 @@ #include <vespa/document/test/make_document_bucket.h> #include <vespa/document/test/make_bucket_space.h> #include <vespa/storage/config/config-stor-distributormanager.h> -#include <tests/common/dummystoragelink.h> #include <vespa/storage/distributor/distributor.h> #include <vespa/vespalib/text/stringtokenizer.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/vespalib/gtest/gtest.h> #include <gmock/gmock.h> @@ -383,7 +384,7 @@ TEST_F(DistributorTest, tick_processes_status_requests) { thread, "statustest", tickWaitMs, tickMaxProcessTime, ticksBeforeWait)); while (true) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); framework::TickingLockGuard guard( distributor_thread_pool().freezeCriticalTicks()); if (!distributor_status_todos().empty()) { diff --git a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp index 44cb92071a1..f907d0496e6 100644 --- a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp +++ b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp @@ -20,8 +20,10 @@ #include <vespa/persistence/spi/test.h> #include <vespa/config/common/exceptions.h> #include <vespa/fastos/file.h> +#include <vespa/vespalib/util/time.h> #include <vespa/vespalib/gtest/gtest.h> #include <atomic> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".filestormanagertest"); @@ -556,7 +558,7 @@ public: auto cmd = std::make_shared<api::PutCommand>(makeDocumentBucket(bucket), _doc, 100); _handler.schedule(cmd, 0); - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } _threadDone = true; @@ -589,13 +591,13 @@ public: if (msg.second.get()) { uint32_t originalConfig = _config.load(); _fetchedCount++; - FastOS_Thread::Sleep(5); + std::this_thread::sleep_for(5ms); if (_config.load() != originalConfig) { _failed = true; } } else { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } @@ -634,7 +636,7 @@ TEST_F(FileStorManagerTest, handler_paused_multi_thread) { thread.start(pool); for (uint32_t i = 0; i < 50; ++i) { - FastOS_Thread::Sleep(2); + std::this_thread::sleep_for(2ms); ResumeGuard guard = filestorHandler.pause(); thread._config.fetch_add(1); uint32_t count = thread._fetchedCount; @@ -646,7 +648,7 @@ TEST_F(FileStorManagerTest, handler_paused_multi_thread) { ASSERT_FALSE(thread._failed); while (!pushthread._threadDone || !thread._threadDone) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } @@ -869,7 +871,7 @@ TEST_F(FileStorManagerTest, handler_timeout) { filestorHandler.schedule(cmd, 0); } - FastOS_Thread::Sleep(51); + std::this_thread::sleep_for(51ms); for (;;) { auto lock = filestorHandler.getNextMessage(0, stripeId); if (lock.first.get()) { @@ -944,7 +946,7 @@ TEST_F(FileStorManagerTest, priority) { // Wait until everything is done. int count = 0; while (documents.size() != top.getNumReplies() && count < 10000) { - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); count++; } ASSERT_LT(count, 10000); diff --git a/storage/src/tests/persistence/filestorage/operationabortingtest.cpp b/storage/src/tests/persistence/filestorage/operationabortingtest.cpp index 0d43f8a9020..ba344971c3b 100644 --- a/storage/src/tests/persistence/filestorage/operationabortingtest.cpp +++ b/storage/src/tests/persistence/filestorage/operationabortingtest.cpp @@ -9,6 +9,8 @@ #include <vespa/vespalib/util/thread.h> #include <vespa/vespalib/stllike/hash_set_insert.hpp> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".operationabortingtest"); @@ -53,7 +55,7 @@ public: (void) context; _queueBarrier.await(); // message abort stage with active opertion in disk queue - FastOS_Thread::Sleep(75); + std::this_thread::sleep_for(75ms); _completionBarrier.await(); // test finished return spi::Result(); diff --git a/storage/src/tests/storageserver/bucketintegritycheckertest.cpp b/storage/src/tests/storageserver/bucketintegritycheckertest.cpp index ae466f04734..8a68adf226c 100644 --- a/storage/src/tests/storageserver/bucketintegritycheckertest.cpp +++ b/storage/src/tests/storageserver/bucketintegritycheckertest.cpp @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/storage/bucketdb/bucketmanager.h> -#include <vespa/storage/persistence/filestorage/filestormanager.h> #include <vespa/storage/storageserver/bucketintegritychecker.h> #include <vespa/storageapi/message/persistence.h> #include <tests/common/testhelper.h> @@ -9,6 +8,8 @@ #include <vespa/vespalib/io/fileutil.h> #include <tests/common/teststorageapp.h> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/time.h> +#include <thread> using namespace ::testing; @@ -175,13 +176,13 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) { checker.getSchedulingOptions()._minCycleTime = framework::SecondTime(60 * 60); topLink.open(); // Waiting for system to be initialized - FastOS_Thread::Sleep(10); // Give next message chance to come + std::this_thread::sleep_for(10ms); // Give next message chance to come ASSERT_COMMAND_COUNT(0, *dummyLink); topLink.doneInit(); checker.bump(); // Should have started new run with 2 pending per disk dummyLink->waitForMessages(4, _timeout); - FastOS_Thread::Sleep(10); // Give 5th message chance to come + std::this_thread::sleep_for(10ms); // Give 5th message chance to come ASSERT_COMMAND_COUNT(4, *dummyLink); auto* cmd1 = dynamic_cast<RepairBucketCommand*>(dummyLink->getCommand(0).get()); EXPECT_EQ(230, cmd1->getPriority()); @@ -200,13 +201,13 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) { // Answering a message on disk with no more buckets does not trigger new auto reply1 = std::make_shared<RepairBucketReply>(*cmd3); ASSERT_TRUE(checker.onUp(reply1)); - FastOS_Thread::Sleep(10); // Give next message chance to come + std::this_thread::sleep_for(10ms); // Give next message chance to come ASSERT_COMMAND_COUNT(4, *dummyLink); // Answering a message on disk with more buckets trigger new repair auto reply2 = std::make_shared<RepairBucketReply>(*cmd2); ASSERT_TRUE(checker.onUp(reply2)); dummyLink->waitForMessages(5, _timeout); - FastOS_Thread::Sleep(10); // Give 6th message chance to come + std::this_thread::sleep_for(10ms); // Give 6th message chance to come ASSERT_COMMAND_COUNT(5, *dummyLink); auto* cmd5 = dynamic_cast<RepairBucketCommand*>(dummyLink->getCommand(4).get()); ASSERT_TRUE(cmd5); @@ -217,7 +218,7 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) { reply3->setResult(api::ReturnCode(api::ReturnCode::IGNORED)); ASSERT_TRUE(checker.onUp(reply3)); dummyLink->waitForMessages(6, _timeout); - FastOS_Thread::Sleep(10); // Give 7th message chance to come + std::this_thread::sleep_for(10ms); // Give 7th message chance to come ASSERT_COMMAND_COUNT(6, *dummyLink); auto* cmd6 = dynamic_cast<RepairBucketCommand*>(dummyLink->getCommand(5).get()); ASSERT_TRUE(cmd6); @@ -227,7 +228,7 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) { auto reply4 = std::make_shared<RepairBucketReply>(*cmd4); reply3->setResult(api::ReturnCode(api::ReturnCode::BUCKET_NOT_FOUND)); ASSERT_TRUE(checker.onUp(reply4)); - FastOS_Thread::Sleep(10); // Give 7th message chance to come + std::this_thread::sleep_for(10ms); // Give 7th message chance to come ASSERT_COMMAND_COUNT(6, *dummyLink); // Send a repair reply that actually have corrected the bucket. @@ -247,13 +248,13 @@ TEST_F(BucketIntegrityCheckerTest, basic_functionality) { EXPECT_EQ(document::BucketId(16, 0x234), cmd7->getBucketId()); auto reply7 = std::make_shared<RepairBucketReply>(*cmd7); ASSERT_TRUE(checker.onUp(reply7)); - FastOS_Thread::Sleep(10); // Give 8th message chance to come + std::this_thread::sleep_for(10ms); // Give 8th message chance to come ASSERT_COMMAND_COUNT(7, *dummyLink); // Still not time for next iteration dummyLink->reset(); _node->getClock().setAbsoluteTimeInSeconds(getDate("week1 sun 00:59:59")); - FastOS_Thread::Sleep(10); // Give new run chance to start + std::this_thread::sleep_for(10ms); // Give new run chance to start ASSERT_COMMAND_COUNT(0, *dummyLink); // Pass time until next cycle should start diff --git a/storage/src/tests/storageserver/communicationmanagertest.cpp b/storage/src/tests/storageserver/communicationmanagertest.cpp index caee6e6ab91..6657a9f1600 100644 --- a/storage/src/tests/storageserver/communicationmanagertest.cpp +++ b/storage/src/tests/storageserver/communicationmanagertest.cpp @@ -15,6 +15,8 @@ #include <vespa/vespalib/util/stringfmt.h> #include <vespa/documentapi/messagebus/messages/removedocumentmessage.h> #include <vespa/documentapi/messagebus/messages/getdocumentreply.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/vespalib/gtest/gtest.h> using document::test::makeDocumentBucket; @@ -65,7 +67,7 @@ TEST_F(CommunicationManagerTest, simple) { distributor.open(); storage.open(); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); // Send a message through from distributor to storage auto cmd = std::make_shared<api::GetCommand>( diff --git a/storage/src/tests/storageserver/statereportertest.cpp b/storage/src/tests/storageserver/statereportertest.cpp index c84f9311c52..dc8094275d1 100644 --- a/storage/src/tests/storageserver/statereportertest.cpp +++ b/storage/src/tests/storageserver/statereportertest.cpp @@ -11,6 +11,8 @@ #include <vespa/config/common/exceptions.h> #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".test.statereporter"); @@ -233,7 +235,7 @@ TEST_F(StateReporterTest, report_metrics) { uint64_t(_metricManager->getLastProcessedTime()) < _clock->getTimeInSeconds().getTime()) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } LOG(debug, "5 minute snapshot should have been taken. Adding put count"); diff --git a/storage/src/tests/visiting/memory_bounded_trace_test.cpp b/storage/src/tests/visiting/memory_bounded_trace_test.cpp index 0cfd3dad4c3..124730cf093 100644 --- a/storage/src/tests/visiting/memory_bounded_trace_test.cpp +++ b/storage/src/tests/visiting/memory_bounded_trace_test.cpp @@ -7,6 +7,8 @@ using namespace ::testing; namespace storage { +constexpr vespalib::system_time epoch(vespalib::duration::zero()); + TEST(MemoryBoundedTraceTest, no_memory_reported_used_when_empty) { MemoryBoundedTrace trace(100); EXPECT_EQ(0, trace.getApproxMemoryUsed()); @@ -14,7 +16,7 @@ TEST(MemoryBoundedTraceTest, no_memory_reported_used_when_empty) { TEST(MemoryBoundedTraceTest, memory_used_is_string_length_for_leaf_node) { MemoryBoundedTrace trace(100); - EXPECT_TRUE(trace.add(mbus::TraceNode("hello world", 0))); + EXPECT_TRUE(trace.add(mbus::TraceNode("hello world", epoch))); EXPECT_EQ(11, trace.getApproxMemoryUsed()); } @@ -29,7 +31,7 @@ TEST(MemoryBoundedTraceTest, memory_used_is_accumulated_recursively_for_non_leaf TEST(MemoryBoundedTraceTest, trace_nodes_can_be_moved_and_implicitly_cleared) { MemoryBoundedTrace trace(100); - EXPECT_TRUE(trace.add(mbus::TraceNode("hello world", 0))); + EXPECT_TRUE(trace.add(mbus::TraceNode("hello world", epoch))); mbus::TraceNode target; trace.moveTraceTo(target); EXPECT_EQ(1, target.getNumChildren()); @@ -49,7 +51,7 @@ TEST(MemoryBoundedTraceTest, trace_nodes_can_be_moved_and_implicitly_cleared) { */ TEST(MemoryBoundedTraceTest, moved_trace_tree_is_marked_as_strict) { MemoryBoundedTrace trace(100); - EXPECT_TRUE(trace.add(mbus::TraceNode("hello world", 0))); + EXPECT_TRUE(trace.add(mbus::TraceNode("hello world", epoch))); mbus::TraceNode target; trace.moveTraceTo(target); EXPECT_EQ(1, target.getNumChildren()); @@ -60,11 +62,11 @@ TEST(MemoryBoundedTraceTest, can_not_add_more_nodes_when_memory_used_exceeds_upp // Note: we allow one complete node tree to exceed the bounds, but as soon // as the bound is exceeded no further nodes can be added. MemoryBoundedTrace trace(10); - EXPECT_TRUE(trace.add(mbus::TraceNode("hello world", 0))); + EXPECT_TRUE(trace.add(mbus::TraceNode("hello world", epoch))); EXPECT_EQ(11, trace.getApproxMemoryUsed()); EXPECT_FALSE(trace.add(mbus::TraceNode("the quick red fox runs across " - "the freeway", 0))); + "the freeway", epoch))); EXPECT_EQ(11, trace.getApproxMemoryUsed()); mbus::TraceNode target; @@ -77,8 +79,8 @@ TEST(MemoryBoundedTraceTest, can_not_add_more_nodes_when_memory_used_exceeds_upp TEST(MemoryBoundedTraceTest, moved_tree_includes_stats_node_when_nodes_omitted) { MemoryBoundedTrace trace(5); - EXPECT_TRUE(trace.add(mbus::TraceNode("abcdef", 0))); - EXPECT_FALSE(trace.add(mbus::TraceNode("ghijkjlmn", 0))); + EXPECT_TRUE(trace.add(mbus::TraceNode("abcdef", epoch))); + EXPECT_FALSE(trace.add(mbus::TraceNode("ghijkjlmn", epoch))); mbus::TraceNode target; trace.moveTraceTo(target); diff --git a/storage/src/vespa/storage/storageserver/fnetlistener.cpp b/storage/src/vespa/storage/storageserver/fnetlistener.cpp index c86e1671033..651686a7c6d 100644 --- a/storage/src/vespa/storage/storageserver/fnetlistener.cpp +++ b/storage/src/vespa/storage/storageserver/fnetlistener.cpp @@ -6,9 +6,11 @@ #include <vespa/storageapi/message/state.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/vespalib/util/host_name.h> +#include <vespa/vespalib/util/time.h> #include <vespa/fnet/frt/supervisor.h> #include <vespa/fnet/transport.h> #include <sstream> +#include <thread> #include <vespa/log/log.h> LOG_SETUP(".rpc.listener"); @@ -50,7 +52,7 @@ FNetListener::registerHandle(vespalib::stringref handle) { _slobrokRegister.registerName(handle); while (_slobrokRegister.busy()) { LOG(debug, "Waiting to register in slobrok"); - FastOS_Thread::Sleep(50); + std::this_thread::sleep_for(50ms); } _handle = handle; } diff --git a/storage/src/vespa/storage/storageserver/fnetlistener.h b/storage/src/vespa/storage/storageserver/fnetlistener.h index 205a5af4586..e37727beb44 100644 --- a/storage/src/vespa/storage/storageserver/fnetlistener.h +++ b/storage/src/vespa/storage/storageserver/fnetlistener.h @@ -5,6 +5,7 @@ #include <atomic> class FNET_Transport; +class FastOS_ThreadPool; namespace storage { diff --git a/storage/src/vespa/storage/storageserver/storagenode.cpp b/storage/src/vespa/storage/storageserver/storagenode.cpp index c5a0a031067..e962ee4b1b6 100644 --- a/storage/src/vespa/storage/storageserver/storagenode.cpp +++ b/storage/src/vespa/storage/storageserver/storagenode.cpp @@ -14,6 +14,7 @@ #include <vespa/storage/common/statusmetricconsumer.h> #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/util/time.h> #include <vespa/metrics/metricmanager.h> #include <fcntl.h> @@ -568,7 +569,7 @@ StorageNode::waitUntilInitialized(uint32_t timeout) { lib::NodeState nodeState(*_component->getStateUpdater().getReportedNodeState()); if (nodeState.getState() == lib::State::UP) break; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); if (clock.getTimeInMillis() >= endTime) { std::ostringstream ost; ost << "Storage server not initialized after waiting timeout of " diff --git a/storage/src/vespa/storage/tools/storage-cmd.cpp b/storage/src/vespa/storage/tools/storage-cmd.cpp index daaa890873f..8c0fcc83330 100644 --- a/storage/src/vespa/storage/tools/storage-cmd.cpp +++ b/storage/src/vespa/storage/tools/storage-cmd.cpp @@ -3,6 +3,8 @@ #include <vespa/slobrok/sbmirror.h> #include <vespa/fastos/app.h> #include <vespa/vespalib/locale/c.h> +#include <vespa/vespalib/util/time.h> +#include <thread> #include <vespa/log/log.h> LOG_SETUP("vespa-storage-cmd"); @@ -61,7 +63,7 @@ public: slobrok::api::MirrorAPI mirror(supervisor.supervisor(), sbcfg); while (!mirror.ready()) { - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } slobrok::api::MirrorAPI::SpecList list = mirror.lookup(_argv[1]); diff --git a/storageframework/src/tests/thread/tickingthreadtest.cpp b/storageframework/src/tests/thread/tickingthreadtest.cpp index 97ae08eef3d..c42a9c17283 100644 --- a/storageframework/src/tests/thread/tickingthreadtest.cpp +++ b/storageframework/src/tests/thread/tickingthreadtest.cpp @@ -6,6 +6,8 @@ #include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/util/exception.h> #include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/time.h> +#include <thread> namespace storage::framework::defaultimplementation { @@ -35,7 +37,7 @@ struct MyApp : public TickingThread { Context& c(_context[index]); if (_doCritOverlapTest) { uint32_t oldTick = _critOverlapCounter; - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); _critOverlap |= (_critOverlapCounter != oldTick); ++_critOverlapCounter; } @@ -109,7 +111,7 @@ TEST(TickingThreadTest, test_ticks_before_wait_basic) // and verify time is in right ballpark. int totalSleepMs = 0; while (app.getTotalNonCritTicks() < 20) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); totalSleepMs++; } EXPECT_GT(totalSleepMs, 10); @@ -134,7 +136,7 @@ TEST(TickingThreadTest, test_ticks_before_wait_live_update) // (if live update is broken it will take more than an hour). int maxAttempts = 120000; // a bit more than 120 secs while (app.getTotalNonCritTicks() < ticksBeforeWaitMs && maxAttempts-->0) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } EXPECT_GT(maxAttempts, 0); @@ -158,7 +160,7 @@ TEST(TickingThreadTest, test_verbose_stopping) MyApp app(threadCount, true); app.start(testReg.getThreadPoolImpl()); while (app.getMinCritTick() < 5) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } app._threadPool->stop(); } @@ -171,7 +173,7 @@ TEST(TickingThreadTest, test_stop_on_deletion) MyApp app(threadCount, true); app.start(testReg.getThreadPoolImpl()); while (app.getMinCritTick() < 5) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } @@ -185,7 +187,7 @@ TEST(TickingThreadTest, test_lock_all_ticks) app1.start(testReg.getThreadPoolImpl()); app2.start(testReg.getThreadPoolImpl()); while (std::min(app1.getMinCritTick(), app2.getMinCritTick()) < 5) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } uint64_t ticks1, ticks2; { @@ -194,12 +196,12 @@ TEST(TickingThreadTest, test_lock_all_ticks) ticks2 = app2.getTotalTicks(); while (app2.getMinCritTick() < 2 * ticks2 / threadCount) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } EXPECT_EQ(ticks1, app1.getTotalTicks()); } while (app1.getMinCritTick() < 2 * ticks1 / threadCount) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } @@ -213,7 +215,7 @@ TEST(TickingThreadTest, test_lock_critical_ticks) MyApp app(threadCount, true); app.start(testReg.getThreadPoolImpl()); while (!app.hasCritOverlap()) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); ++app._critOverlapCounter; ++iterationsBeforeOverlap; } @@ -222,7 +224,7 @@ TEST(TickingThreadTest, test_lock_critical_ticks) MyApp app(threadCount, true); app.start(testReg.getThreadPoolImpl()); for (uint64_t i=0; i<iterationsBeforeOverlap * 10; ++i) { - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); TickingLockGuard guard(app._threadPool->freezeCriticalTicks()); for (int j=0; j<threadCount; ++j) { ++app._context[j]._critTickCount; @@ -318,13 +320,13 @@ TEST(TickingThreadTest, test_broadcast) BroadcastApp app; app.start(testReg.getThreadPoolImpl()); app.doTask("foo"); - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); app.doTask("bar"); - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); app.doTask("baz"); - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); app.doTask("hmm"); - FastOS_Thread::Sleep(1); + std::this_thread::sleep_for(1ms); } } diff --git a/storageserver/src/apps/storaged/storage.cpp b/storageserver/src/apps/storaged/storage.cpp index 0748cc3cb1e..5996951e65f 100644 --- a/storageserver/src/apps/storaged/storage.cpp +++ b/storageserver/src/apps/storaged/storage.cpp @@ -198,7 +198,7 @@ int StorageApp::Main() LOG(debug, "Server was attempted stopped, shutting down"); // Create guard that will forcifully kill storage if destruction takes longer // time than given timeout. - vespalib::ShutdownGuard shutdownGuard(_maxShutdownTime); + vespalib::ShutdownGuard shutdownGuard(std::chrono::milliseconds(_maxShutdownTime)); LOG(debug, "Attempting proper shutdown"); _process.reset(); LOG(debug, "Completed controlled shutdown."); diff --git a/vdslib/src/tests/thread/taskschedulertest.cpp b/vdslib/src/tests/thread/taskschedulertest.cpp index 540de722137..1925625172c 100644 --- a/vdslib/src/tests/thread/taskschedulertest.cpp +++ b/vdslib/src/tests/thread/taskschedulertest.cpp @@ -2,6 +2,8 @@ #include <vespa/vdslib/thread/taskscheduler.h> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/time.h> +#include <thread> namespace vdslib { @@ -141,13 +143,13 @@ TEST(TaskSchedulerTest, test_simple) task->registerCallsWithName("", calls); scheduler.addAbsolute(TestTask::UP(task), 50); watch.increment(49); // Not yet time to run - FastOS_Thread::Sleep(5); + std::this_thread::sleep_for(5ms); // Check that it has not run yet.. EXPECT_EQ(counter, scheduler.getTaskCounter()); watch.increment(10); // Now time is enough for it to run scheduler.waitForTaskCounterOfAtLeast(counter + 1); watch.increment(10); - FastOS_Thread::Sleep(5); + std::this_thread::sleep_for(5ms); // Check that it has not run yet.. EXPECT_EQ(counter + 1, scheduler.getTaskCounter()); watch.increment(50); diff --git a/vespabase/CMakeLists.txt b/vespabase/CMakeLists.txt index a7247d882f9..8faf58dd070 100644 --- a/vespabase/CMakeLists.txt +++ b/vespabase/CMakeLists.txt @@ -5,6 +5,7 @@ vespa_install_script(src/start-cbinaries.sh vespa-config-status bin) vespa_install_script(src/start-cbinaries.sh vespa-doclocator bin) vespa_install_script(src/start-cbinaries.sh vespa-model-inspect bin) vespa_install_script(src/start-cbinaries.sh vespa-proton-cmd bin) +vespa_install_script(src/start-cbinaries.sh vespa-rpc-invoke bin) vespa_install_script(src/start-cbinaries.sh vespa-sentinel-cmd bin) vespa_install_script(src/start-cbinaries.sh vespa-route bin) vespa_install_script(src/start-cbinaries.sh vespa-transactionlog-inspect bin) diff --git a/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp b/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp index 9ec9249b671..1b8f31f0d03 100644 --- a/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp +++ b/vespaclient/src/vespa/vespaclient/vdsstates/statesapp.cpp @@ -1,19 +1,20 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/defaults.h> -#include <vespa/document/util/stringutil.h> #include <vespa/fnet/frt/frt.h> #include <vespa/slobrok/sbmirror.h> #include <vespa/vdslib/distribution/distribution.h> #include <vespa/vdslib/state/clusterstate.h> #include <vespa/vespalib/util/programoptions.h> #include <vespa/vespaclient/clusterlist/clusterlist.h> +#include <vespa/vespalib/util/time.h> #include <vespa/vespalib/text/lowercase.h> #include <vespa/config-stor-distribution.h> #include <vespa/config/helper/configgetter.hpp> #include <vespa/fastos/app.h> #include <sstream> #include <iostream> +#include <thread> #include <sys/time.h> #include <vespa/log/log.h> @@ -282,7 +283,7 @@ struct StateApp : public FastOS_Application { } warnTime *= 4; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } if (!slobrok->ready()) { std::cerr << "Slobrok not ready.\n"; diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 63cae3904e3..cea58d565c2 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1985,6 +1985,7 @@ ], "methods": [ "public void <init>()", + "public void <init>(double)", "public double applyAsDouble(double)", "public java.lang.String toString()" ], @@ -2075,6 +2076,7 @@ ], "methods": [ "public void <init>()", + "public void <init>(double)", "public double applyAsDouble(double)", "public java.lang.String toString()" ], @@ -2271,6 +2273,7 @@ ], "methods": [ "public void <init>()", + "public void <init>(double, double)", "public double applyAsDouble(double)", "public java.lang.String toString()" ], @@ -2437,22 +2440,25 @@ "public static java.util.function.DoubleUnaryOperator atan()", "public static java.util.function.DoubleUnaryOperator ceil()", "public static java.util.function.DoubleUnaryOperator cos()", - "public static java.util.function.DoubleUnaryOperator elu()", "public static java.util.function.DoubleUnaryOperator exp()", "public static java.util.function.DoubleUnaryOperator floor()", "public static java.util.function.DoubleUnaryOperator log()", "public static java.util.function.DoubleUnaryOperator neg()", "public static java.util.function.DoubleUnaryOperator reciprocal()", - "public static java.util.function.DoubleUnaryOperator relu()", "public static java.util.function.DoubleUnaryOperator rsqrt()", - "public static java.util.function.DoubleUnaryOperator selu()", - "public static java.util.function.DoubleUnaryOperator leakyrelu()", "public static java.util.function.DoubleUnaryOperator sin()", "public static java.util.function.DoubleUnaryOperator sigmoid()", "public static java.util.function.DoubleUnaryOperator sqrt()", "public static java.util.function.DoubleUnaryOperator square()", "public static java.util.function.DoubleUnaryOperator tan()", "public static java.util.function.DoubleUnaryOperator tanh()", + "public static java.util.function.DoubleUnaryOperator elu()", + "public static java.util.function.DoubleUnaryOperator elu(double)", + "public static java.util.function.DoubleUnaryOperator leakyrelu()", + "public static java.util.function.DoubleUnaryOperator leakyrelu(double)", + "public static java.util.function.DoubleUnaryOperator relu()", + "public static java.util.function.DoubleUnaryOperator selu()", + "public static java.util.function.DoubleUnaryOperator selu(double, double)", "public static java.util.function.Function random()", "public static java.util.function.Function equal(java.util.List)", "public static java.util.function.Function sum(java.util.List)" diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index e8e329cd75c..d9204e24d68 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -38,16 +38,12 @@ public class ScalarFunctions { public static DoubleUnaryOperator atan() { return new Atan(); } public static DoubleUnaryOperator ceil() { return new Ceil(); } public static DoubleUnaryOperator cos() { return new Cos(); } - public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator exp() { return new Exp(); } public static DoubleUnaryOperator floor() { return new Floor(); } public static DoubleUnaryOperator log() { return new Log(); } public static DoubleUnaryOperator neg() { return new Neg(); } public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); } - public static DoubleUnaryOperator relu() { return new Relu(); } public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } - public static DoubleUnaryOperator selu() { return new Selu(); } - public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); } public static DoubleUnaryOperator sin() { return new Sin(); } public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } @@ -55,6 +51,14 @@ public class ScalarFunctions { public static DoubleUnaryOperator tan() { return new Tan(); } public static DoubleUnaryOperator tanh() { return new Tanh(); } + public static DoubleUnaryOperator elu() { return new Elu(); } + public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); } + public static DoubleUnaryOperator leakyrelu() { return new LeakyRelu(); } + public static DoubleUnaryOperator leakyrelu(double alpha) { return new LeakyRelu(alpha); } + public static DoubleUnaryOperator relu() { return new Relu(); } + public static DoubleUnaryOperator selu() { return new Selu(); } + public static DoubleUnaryOperator selu(double scale, double alpha) { return new Selu(scale, alpha); } + public static Function<List<Long>, Double> random() { return new Random(); } public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); } @@ -191,10 +195,17 @@ public class ScalarFunctions { } public static class Elu implements DoubleUnaryOperator { + private final double alpha; + public Elu() { + this(1.0); + } + public Elu(double alpha) { + this.alpha = alpha; + } @Override - public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; } + public double applyAsDouble(double operand) { return operand < 0 ? alpha * (Math.exp(operand) - 1) : operand; } @Override - public String toString() { return "f(a)(if(a < 0, exp(a)-1, a))"; } + public String toString() { return "f(a)(if(a < 0, " + alpha + " * (exp(a)-1), a))"; } } public static class Exp implements DoubleUnaryOperator { @@ -241,8 +252,15 @@ public class ScalarFunctions { public static class Selu implements DoubleUnaryOperator { // See https://arxiv.org/abs/1706.02515 - private static final double scale = 1.0507009873554804934193349852946; - private static final double alpha = 1.6732632423543772848170429916717; + private final double scale; // 1.0507009873554804934193349852946; + private final double alpha; // 1.6732632423543772848170429916717; + public Selu() { + this(1.0507009873554804934193349852946, 1.6732632423543772848170429916717); + } + public Selu(double scale, double alpha) { + this.scale = scale; + this.alpha = alpha; + } @Override public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); } @Override @@ -250,10 +268,17 @@ public class ScalarFunctions { } public static class LeakyRelu implements DoubleUnaryOperator { + private final double alpha; + public LeakyRelu() { + this(0.01); + } + public LeakyRelu(double alpha) { + this.alpha = alpha; + } @Override - public double applyAsDouble(double operand) { return Math.max(0.01 * operand, operand); } + public double applyAsDouble(double operand) { return Math.max(alpha * operand, operand); } @Override - public String toString() { return "f(a)(max(0.01*a, a))"; } + public String toString() { return "f(a)(max(" + alpha + " * a, a))"; } } public static class Sin implements DoubleUnaryOperator { diff --git a/vespalib/src/tests/delegatelist/delegatelist.cpp b/vespalib/src/tests/delegatelist/delegatelist.cpp index ba1a2049794..070864dd85a 100644 --- a/vespalib/src/tests/delegatelist/delegatelist.cpp +++ b/vespalib/src/tests/delegatelist/delegatelist.cpp @@ -780,11 +780,11 @@ Test::testWaitSnapshots() ASSERT_TRUE(pool.NewThread(&a1, 0) != 0); s1.reset(new DL::Snapshot(dl)); // create snap 1 a1.doIt(cmd_wait_snap(&dl)); // wait for snaps - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); EXPECT_TRUE(a1.getState() == Actor::STATE_BUSY); // still waiting... s2.reset(new DL::Snapshot(dl)); // create snap 2 s1.reset(); // destroy snap 1 - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); EXPECT_TRUE(a1.getState() == Actor::STATE_IDLE); // wait done! a1.doIt(cmd_exit()); a1.waitState(Actor::STATE_DONE); diff --git a/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp b/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp index c43d0ec1c29..7567e8426ae 100644 --- a/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp +++ b/vespalib/src/tests/left_right_heap/left_right_heap_bench.cpp @@ -5,7 +5,12 @@ #include <vespa/vespalib/util/inline.h> #include <vespa/fastos/timestamp.h> -using namespace vespalib; +using vespalib::RightArrayHeap; +using vespalib::RightHeap; +using vespalib::LeftArrayHeap; +using vespalib::LeftHeap; +using vespalib::LeftStdHeap; +using vespalib::make_string; template <typename H> struct IsRight { enum { VALUE = 0 }; }; template <> struct IsRight<RightHeap> { enum { VALUE = 1 }; }; diff --git a/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp b/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp index 7ca4a2eff39..5641d751f34 100644 --- a/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp +++ b/vespalib/src/tests/simple_thread_bundle/simple_thread_bundle_test.cpp @@ -47,7 +47,7 @@ TEST_MT_FF("require that signals can be counted and cancelled", 2, Signal, size_ if (thread_id == 0) { for (size_t i = 0; i < f2; ++i) { f1.send(); - if (i % 128 == 0) { FastOS_Thread::Sleep(1); } + if (i % 128 == 0) { std::this_thread::sleep_for(1ms); } } TEST_BARRIER(); f1.cancel(); diff --git a/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp b/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp index 5b6df6eef4e..d67c417b71a 100644 --- a/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp +++ b/vespalib/src/tests/simple_thread_bundle/threading_speed_test.cpp @@ -65,7 +65,7 @@ TEST("estimate cost of thread bundle fork/join") { if (time < minTime) { minTime = time; } - FastOS_Thread::Sleep(10); + std::this_thread::sleep_for(10ms); } fprintf(stderr, "strategy: %s, threads: %zu, fork: %zu, iter: %zu, time: %g, unit: %g\n", strategy_name[strategy].c_str(), threads, fork, iter, minTime, diff --git a/vespalib/src/tests/thread/thread_test.cpp b/vespalib/src/tests/thread/thread_test.cpp index 025a33fa221..bcd38190c7e 100644 --- a/vespalib/src/tests/thread/thread_test.cpp +++ b/vespalib/src/tests/thread/thread_test.cpp @@ -32,7 +32,7 @@ TEST("normal operation") { { Thread thread(agent); thread.start(); - FastOS_Thread::Sleep(20); + std::this_thread::sleep_for(20ms); thread.stop().join(); } EXPECT_TRUE(agent.started); diff --git a/vespalib/src/tests/trace/trace.cpp b/vespalib/src/tests/trace/trace.cpp index f4ed0b15a27..f2141fbf995 100644 --- a/vespalib/src/tests/trace/trace.cpp +++ b/vespalib/src/tests/trace/trace.cpp @@ -1,54 +1,14 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/log/log.h> -LOG_SETUP("trace_test"); #include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/trace/trace.h> #include <vespa/vespalib/trace/tracevisitor.h> -using namespace vespalib; - -class Test : public vespalib::TestApp { -private: - void testEncodeDecode(); - void testReservedChars(); - void testConstruct(); - void testAdd(); - void testSort(); - void testStrict(); - void testTraceLevel(); - void testCompact(); - void testNormalize(); - void testTraceDump(); - void testVisiting(); - void testTimestamp(); - -public: - int Main() override; -}; - -TEST_APPHOOK(Test); +#include <vespa/log/log.h> +LOG_SETUP("trace_test"); -int -Test::Main() -{ - TEST_INIT("trace_test"); - testEncodeDecode(); - testReservedChars(); - testAdd(); - testConstruct(); - testSort(); - testStrict(); - testTraceLevel(); - testCompact(); - testNormalize(); - testTraceDump(); - testVisiting(); - testTimestamp(); - TEST_DONE(); -} +using namespace vespalib; -void -Test::testEncodeDecode() +TEST("testEncodeDecode") { EXPECT_EQUAL("()", TraceNode::decode("").encode()); EXPECT_EQUAL("()", TraceNode::decode("[xyz").encode()); @@ -134,8 +94,7 @@ Test::testEncodeDecode() } } -void -Test::testReservedChars() +TEST("testReservedChars") { TraceNode t; t.addChild("abc(){}[]\\xyz"); @@ -154,8 +113,7 @@ Test::testReservedChars() } } -void -Test::testAdd() +TEST("testAdd") { TraceNode t1 = TraceNode::decode("([x])"); TraceNode t2 = TraceNode::decode("([y])"); @@ -175,16 +133,14 @@ Test::testAdd() EXPECT_EQUAL("([y]([y])([y]([y])))", t2.encode()); } -void -Test::testStrict() +TEST("testStrict") { EXPECT_EQUAL("{}", TraceNode::decode("()").setStrict(false).encode()); EXPECT_EQUAL("{[x]}", TraceNode::decode("([x])").setStrict(false).encode()); EXPECT_EQUAL("{[x][y]}", TraceNode::decode("([x][y])").setStrict(false).encode()); } -void -Test::testTraceLevel() +TEST("testTraceLevel") { Trace t; t.setLevel(4); @@ -211,8 +167,7 @@ Test::testTraceLevel() EXPECT_EQUAL(5u, t.getRoot().getNumChildren()); } -void -Test::testCompact() +TEST("testCompact") { EXPECT_EQUAL("()", TraceNode::decode("()").compact().encode()); EXPECT_EQUAL("()", TraceNode::decode("(())").compact().encode()); @@ -242,8 +197,7 @@ Test::testCompact() EXPECT_EQUAL("({[a][b][c][d][e][f]})", TraceNode::decode("({({[a][b]})({[c][d]})({[e][f]})})").compact().encode()); } -void -Test::testSort() +TEST("testSort") { EXPECT_EQUAL("([b][a][c])", TraceNode::decode("([b][a][c])").sort().encode()); EXPECT_EQUAL("({[a][b][c]})", TraceNode::decode("({[b][a][c]})").sort().encode()); @@ -253,8 +207,7 @@ Test::testSort() EXPECT_EQUAL("({([b]){[a][c]}})", TraceNode::decode("({{[c][a]}([b])})").sort().encode()); } -void -Test::testNormalize() +TEST("testNormalize") { TraceNode t1 = TraceNode::decode("({([a][b]{[x][y]([p][q])})([c][d])([e][f])})"); TraceNode t2 = TraceNode::decode("({([a][b]{[y][x]([p][q])})([c][d])([e][f])})"); @@ -295,8 +248,7 @@ Test::testNormalize() EXPECT_EQUAL("({([c][d])([e][f])([a][b]{[x][y]([p][q])})})", t1.normalize().encode()); } -void -Test::testTraceDump() +TEST("testTraceDump") { { Trace big; @@ -363,8 +315,7 @@ struct EncoderVisitor : public TraceVisitor } }; -void -Test::testVisiting() +TEST("testVisiting") { TraceNode b1; TraceNode b2; @@ -383,27 +334,30 @@ Test::testVisiting() EXPECT_EQUAL(encoder.str, b1.encode()); } -void -Test::testTimestamp() +constexpr system_time zero(duration::zero()); +constexpr system_time as_ms(long ms) { return system_time(std::chrono::milliseconds(ms)); } + +TEST("testTimestamp") { TraceNode root; - root.addChild("foo", 1234); + root.addChild("foo", as_ms(1234)); root.addChild("bar"); - EXPECT_EQUAL(root.getTimestamp(), 0); - EXPECT_EQUAL(root.getChild(0).getTimestamp(), 1234); - EXPECT_EQUAL(root.getChild(1).getTimestamp(), 0); + EXPECT_EQUAL(root.getTimestamp(), zero); + EXPECT_EQUAL(root.getChild(0).getTimestamp(), as_ms(1234)); + EXPECT_EQUAL(root.getChild(1).getTimestamp(), zero); } -void -Test::testConstruct() +TEST("testConstruct") { - TraceNode leaf1("foo", 123); + TraceNode leaf1("foo", as_ms(123)); EXPECT_TRUE(leaf1.hasNote()); EXPECT_EQUAL("foo", leaf1.getNote()); - EXPECT_EQUAL(123, leaf1.getTimestamp()); + EXPECT_EQUAL(as_ms(123), leaf1.getTimestamp()); - TraceNode leaf2(124); + TraceNode leaf2(as_ms(124)); EXPECT_FALSE(leaf2.hasNote()); EXPECT_EQUAL("", leaf2.getNote()); - EXPECT_EQUAL(124, leaf2.getTimestamp()); + EXPECT_EQUAL(as_ms(124), leaf2.getTimestamp()); } + +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/tests/trace/trace_serialization.cpp b/vespalib/src/tests/trace/trace_serialization.cpp index c176ca9fcf9..9ba6cdb512b 100644 --- a/vespalib/src/tests/trace/trace_serialization.cpp +++ b/vespalib/src/tests/trace/trace_serialization.cpp @@ -1,11 +1,12 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/log/log.h> -LOG_SETUP("trace_test"); #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/vespalib/trace/tracenode.h> #include <vespa/vespalib/trace/slime_trace_serializer.h> #include <vespa/vespalib/trace/slime_trace_deserializer.h> +#include <vespa/log/log.h> +LOG_SETUP("trace_test"); + using namespace vespalib; using namespace vespalib::slime; @@ -20,10 +21,14 @@ TEST("that a single trace node is serialized") { EXPECT_FALSE(i["payload"].valid()); } +constexpr system_time zero_system_time(duration::zero()); +constexpr system_time as_ms(long ms) { return system_time(std::chrono::milliseconds(ms)); } + + TEST("that a trace node with children is serialized") { TraceNode node; - node.addChild("foo", 1234); - node.addChild("bar", 1235); + node.addChild("foo", as_ms(1234)); + node.addChild("bar", as_ms(1235)); Slime slime; SlimeTraceSerializer serializer(slime.setObject()); node.accept(serializer); @@ -47,7 +52,7 @@ TEST("that an empty root trace node can be deserialized") { SlimeTraceDeserializer deserializer(root); TraceNode node(deserializer.deserialize()); EXPECT_FALSE(node.hasNote()); - EXPECT_EQUAL(0, node.getTimestamp()); + EXPECT_EQUAL(zero_system_time, node.getTimestamp()); } @@ -58,7 +63,7 @@ TEST("that a single trace node can be deserialized") { root.setString("payload", "hello"); SlimeTraceDeserializer deserializer(root); TraceNode node(deserializer.deserialize()); - EXPECT_EQUAL(1234, node.getTimestamp()); + EXPECT_EQUAL(as_ms(1234), node.getTimestamp()); EXPECT_TRUE(node.hasNote()); EXPECT_EQUAL("hello", node.getNote()); } @@ -95,7 +100,7 @@ TEST("that a trace node with children can be deserialized") { TEST("test serialization and deserialization") { TraceNode root; - root.addChild("foo", 45); + root.addChild("foo", as_ms(45)); root.addChild("bar"); root.addChild(TraceNode()); Slime slime; diff --git a/vespalib/src/vespa/vespalib/testkit/test_comparators.cpp b/vespalib/src/vespa/vespalib/testkit/test_comparators.cpp index 92ec51268d2..d00ad8d954c 100644 --- a/vespalib/src/vespa/vespalib/testkit/test_comparators.cpp +++ b/vespalib/src/vespa/vespalib/testkit/test_comparators.cpp @@ -2,6 +2,13 @@ #include "test_comparators.h" +namespace std::chrono { + +ostream & operator << (ostream & os, system_clock::time_point ts) { + return os << ts.time_since_epoch() << "ns"; +} + +} namespace vespalib { } // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/testkit/test_comparators.h b/vespalib/src/vespa/vespalib/testkit/test_comparators.h index 5119aaebea6..161c125757b 100644 --- a/vespalib/src/vespa/vespalib/testkit/test_comparators.h +++ b/vespalib/src/vespa/vespalib/testkit/test_comparators.h @@ -13,6 +13,9 @@ ostream & operator << (ostream & os, duration<rep, period> ts) { return os << ts.count(); } +ostream & operator << (ostream & os, system_clock::time_point ts); + + } namespace vespalib { diff --git a/vespalib/src/vespa/vespalib/testkit/test_kit.h b/vespalib/src/vespa/vespalib/testkit/test_kit.h index 7e6b07d71df..17746c5b0fc 100644 --- a/vespalib/src/vespa/vespalib/testkit/test_kit.h +++ b/vespalib/src/vespa/vespalib/testkit/test_kit.h @@ -10,3 +10,4 @@ #include "test_hook.h" #include "test_state_guard.h" #include "time_bomb.h" +#include <vespa/vespalib/util/time.h> diff --git a/vespalib/src/vespa/vespalib/trace/slime_trace_deserializer.cpp b/vespalib/src/vespa/vespalib/trace/slime_trace_deserializer.cpp index 44e79de5248..eea59f2164f 100644 --- a/vespalib/src/vespa/vespalib/trace/slime_trace_deserializer.cpp +++ b/vespalib/src/vespa/vespalib/trace/slime_trace_deserializer.cpp @@ -29,7 +29,7 @@ SlimeTraceDeserializer::deserialize(const Inspector & inspector) TraceNode SlimeTraceDeserializer::deserializeTraceNode(const Inspector & inspector) { - int64_t timestamp(decodeTimestamp(inspector)); + system_time timestamp(std::chrono::milliseconds(decodeTimestamp(inspector))); if (hasPayload(inspector)) { std::string note(decodePayload(inspector)); return TraceNode(note, timestamp); diff --git a/vespalib/src/vespa/vespalib/trace/slime_trace_serializer.cpp b/vespalib/src/vespa/vespalib/trace/slime_trace_serializer.cpp index cdc952464dd..ccf5d50517a 100644 --- a/vespalib/src/vespa/vespalib/trace/slime_trace_serializer.cpp +++ b/vespalib/src/vespa/vespalib/trace/slime_trace_serializer.cpp @@ -32,7 +32,7 @@ SlimeTraceSerializer::visit(const TraceNode & node) void SlimeTraceSerializer::addTimestamp(Cursor & current, const TraceNode & node) { - current.setLong(TIMESTAMP, node.getTimestamp()); + current.setLong(TIMESTAMP, count_ms(node.getTimestamp().time_since_epoch())); } void diff --git a/vespalib/src/vespa/vespalib/trace/slime_trace_serializer.h b/vespalib/src/vespa/vespalib/trace/slime_trace_serializer.h index d2ca00ba81a..1d5b5638f48 100644 --- a/vespalib/src/vespa/vespalib/trace/slime_trace_serializer.h +++ b/vespalib/src/vespa/vespalib/trace/slime_trace_serializer.h @@ -22,8 +22,8 @@ public: static const Memory PAYLOAD; static const Memory CHILDREN; private: - void addTimestamp(slime::Cursor & current, const TraceNode & node); - void addPayload(slime::Cursor & current, const TraceNode & node); + static void addTimestamp(slime::Cursor & current, const TraceNode & node); + static void addPayload(slime::Cursor & current, const TraceNode & node); void addChildrenCursors(slime::Cursor & current, const TraceNode & node); void addChildrenCursorsToStack(slime::Cursor & childrenArray, const TraceNode & node); std::stack<slime::Cursor *> _cursors; diff --git a/vespalib/src/vespa/vespalib/trace/tracenode.cpp b/vespalib/src/vespa/vespalib/trace/tracenode.cpp index bce184312f4..d34b45025d3 100644 --- a/vespalib/src/vespa/vespalib/trace/tracenode.cpp +++ b/vespalib/src/vespa/vespalib/trace/tracenode.cpp @@ -46,7 +46,7 @@ TraceNode::TraceNode() : _hasNote(false), _note(""), _children(), - _timestamp(0) + _timestamp() { } TraceNode::TraceNode(const TraceNode &rhs) : @@ -65,7 +65,7 @@ TraceNode & TraceNode::operator =(const TraceNode &) = default; TraceNode::~TraceNode() = default; -TraceNode::TraceNode(const string ¬e, int64_t timestamp) : +TraceNode::TraceNode(const string ¬e, system_time timestamp) : _parent(nullptr), _strict(true), _hasNote(true), @@ -74,7 +74,7 @@ TraceNode::TraceNode(const string ¬e, int64_t timestamp) : _timestamp(timestamp) { } -TraceNode::TraceNode(int64_t timestamp) : +TraceNode::TraceNode(system_time timestamp) : _parent(nullptr), _strict(true), _hasNote(false), @@ -109,7 +109,7 @@ TraceNode::clear() _hasNote = false; _note.clear(); _children.clear(); - _timestamp = 0; + _timestamp = system_time(); return *this; } @@ -177,11 +177,11 @@ TraceNode::normalize() TraceNode & TraceNode::addChild(const string ¬e) { - return addChild(TraceNode(note, 0)); + return addChild(TraceNode(note, system_time())); } TraceNode & -TraceNode::addChild(const string ¬e, int64_t timestamp) +TraceNode::addChild(const string ¬e, system_time timestamp) { return addChild(TraceNode(note, timestamp)); } @@ -245,8 +245,7 @@ TraceNode::encode() const string ret = ""; if (_hasNote) { ret.append("["); - for (uint32_t i = 0, len = _note.size(); i < len; ++i) { - char c = _note[i]; + for (char c : _note) { if (c == '\\' || c == ']') { ret.append("\\"); } @@ -296,7 +295,7 @@ TraceNode::decode(const string &str) node = &node->_children.back(); node->setStrict(c == '('); } else if (c == ')' || c == '}') { - if (node == NULL) { + if (node == nullptr) { LOG(warning, "Unexpected closing brace in trace '%s' at " "position %d.", str.c_str(), i); return TraceNode(); diff --git a/vespalib/src/vespa/vespalib/trace/tracenode.h b/vespalib/src/vespa/vespalib/trace/tracenode.h index b44881b99b4..63e7bcd6dc0 100644 --- a/vespalib/src/vespa/vespalib/trace/tracenode.h +++ b/vespalib/src/vespa/vespalib/trace/tracenode.h @@ -2,6 +2,7 @@ #pragma once #include <vespa/vespalib/stllike/string.h> +#include <vespa/vespalib/util/time.h> #include <vector> namespace vespalib { @@ -27,7 +28,7 @@ private: bool _hasNote; string _note; std::vector<TraceNode> _children; - int64_t _timestamp; + system_time _timestamp; public: /** @@ -40,13 +41,13 @@ public: * @param note The note for this node. * @param timestamp The timestamp to give to node. */ - explicit TraceNode(const string ¬e, int64_t timestamp); + explicit TraceNode(const string ¬e, system_time timestamp); /** * Create a leaf node with no note and a time stamp. * @param timestamp The timestamp to give to node. */ - explicit TraceNode(int64_t timestamp); + explicit TraceNode(system_time timestamp); TraceNode & operator =(const TraceNode &); TraceNode(TraceNode &&) noexcept; @@ -104,7 +105,7 @@ public: * * @return True if this has no parent. */ - bool isRoot() const { return _parent == NULL; } + bool isRoot() const { return _parent == nullptr; } /** * Check whether or not this is a leaf node. @@ -155,7 +156,7 @@ public: * * @return The timestamp. */ - int64_t getTimestamp() const { return _timestamp; } + system_time getTimestamp() const { return _timestamp; } /** @@ -189,7 +190,7 @@ public: * @param timestamp The timestamp to give this child. * @return This, to allow chaining. */ - TraceNode &addChild(const string ¬e, int64_t timestamp); + TraceNode &addChild(const string ¬e, system_time timestamp); /** * Adds a child node to this. diff --git a/vespalib/src/vespa/vespalib/util/thread.cpp b/vespalib/src/vespa/vespalib/util/thread.cpp index 2d0118645ab..4eb436458a2 100644 --- a/vespalib/src/vespa/vespalib/util/thread.cpp +++ b/vespalib/src/vespa/vespalib/util/thread.cpp @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "thread.h" +#include <thread> namespace vespalib { @@ -87,7 +88,7 @@ Thread::currentThread() void Thread::sleep(size_t ms) { - FastOS_Thread::Sleep(ms); + std::this_thread::sleep_for(std::chrono::milliseconds(ms)); } } // namespace vespalib diff --git a/vespamalloc/src/tests/allocfree/allocfree.cpp b/vespamalloc/src/tests/allocfree/allocfree.cpp index 80513579a2f..86050d4aee9 100644 --- a/vespamalloc/src/tests/allocfree/allocfree.cpp +++ b/vespamalloc/src/tests/allocfree/allocfree.cpp @@ -89,7 +89,7 @@ int Test::Main() { for (; duration > 0; --duration) { LOG(info, "%d seconds left...", duration); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } pool.Close(); size_t numFreeOperations(0); diff --git a/vespamalloc/src/tests/allocfree/linklist.cpp b/vespamalloc/src/tests/allocfree/linklist.cpp index 11a8d1ddd11..74af380458a 100644 --- a/vespamalloc/src/tests/allocfree/linklist.cpp +++ b/vespamalloc/src/tests/allocfree/linklist.cpp @@ -163,7 +163,7 @@ int Test::Main() { for (; duration > 0; --duration) { LOG(info, "%d seconds left...", duration); - FastOS_Thread::Sleep(1000); + std::this_thread::sleep_for(1s); } pool.Close(); fprintf(stderr, "Did (%lu + %lu) = %lu linkIn operations\n", diff --git a/zkfacade/abi-spec.json b/zkfacade/abi-spec.json index efe6fbdaa08..25b652b7312 100644 --- a/zkfacade/abi-spec.json +++ b/zkfacade/abi-spec.json @@ -68,7 +68,6 @@ "methods": [ "public static com.yahoo.vespa.curator.Curator create(java.lang.String)", "public static com.yahoo.vespa.curator.Curator create(java.lang.String, java.util.Optional)", - "public void <init>(com.yahoo.cloud.config.ConfigserverConfig)", "public void <init>(com.yahoo.cloud.config.ConfigserverConfig, com.yahoo.vespa.zookeeper.VespaZooKeeperServer)", "protected void <init>(java.lang.String, java.lang.String, java.util.function.Function)", "public java.lang.String connectionSpec()", diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java index b76bad5b97b..9d74306d3d5 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java @@ -3,9 +3,12 @@ package com.yahoo.vespa.curator; import com.google.inject.Inject; import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.io.IOUtils; import com.yahoo.net.HostName; import com.yahoo.path.Path; +import com.yahoo.text.Utf8; import com.yahoo.vespa.curator.recipes.CuratorCounter; +import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.zookeeper.VespaZooKeeperServer; import org.apache.curator.RetryPolicy; import org.apache.curator.framework.CuratorFramework; @@ -20,7 +23,9 @@ import org.apache.curator.framework.recipes.locks.InterProcessLock; import org.apache.curator.framework.recipes.locks.InterProcessMutex; import org.apache.curator.retry.ExponentialBackoffRetry; import org.apache.zookeeper.KeeperException; +import org.apache.zookeeper.client.ZKClientConfig; import org.apache.zookeeper.data.Stat; +import org.apache.zookeeper.server.quorum.QuorumPeerConfig; import java.io.File; import java.time.Duration; @@ -63,28 +68,21 @@ public class Curator implements AutoCloseable { /** Creates a curator instance from a comma-separated string of ZooKeeper host:port strings */ public static Curator create(String connectionSpec, Optional<File> clientConfigFile) { - return new Curator(connectionSpec, connectionSpec); - } - - // For testing - public Curator(ConfigserverConfig configserverConfig) { - this(configserverConfig, createConnectionSpec(configserverConfig)); + return new Curator(connectionSpec, connectionSpec, clientConfigFile); } // Depend on ZooKeeperServer to make sure it is started first // TODO: Move zookeeperserver config out of configserverconfig (requires update of controller services.xml as well) @Inject public Curator(ConfigserverConfig configserverConfig, VespaZooKeeperServer server) { - this(configserverConfig, createConnectionSpec(configserverConfig)); + this(configserverConfig, Optional.empty()); } - private Curator(ConfigserverConfig configserverConfig, String zooKeeperEnsembleConnectionSpec) { - this((configserverConfig.zookeeperLocalhostAffinity()) ? - createConnectionSpecForLocalhost(configserverConfig) : zooKeeperEnsembleConnectionSpec, - zooKeeperEnsembleConnectionSpec); + Curator(ConfigserverConfig configserverConfig, Optional<File> clientConfigFile) { + this(createConnectionSpec(configserverConfig), createEnsembleConnectionSpec(configserverConfig), clientConfigFile); } - private Curator(String connectionSpec, String zooKeeperEnsembleConnectionSpec) { + private Curator(String connectionSpec, String zooKeeperEnsembleConnectionSpec, Optional<File> clientConfigFile) { this(connectionSpec, zooKeeperEnsembleConnectionSpec, (retryPolicy) -> CuratorFrameworkFactory @@ -93,7 +91,7 @@ public class Curator implements AutoCloseable { .sessionTimeoutMs(ZK_SESSION_TIMEOUT) .connectionTimeoutMs(ZK_CONNECTION_TIMEOUT) .connectString(connectionSpec) - .zookeeperFactory(new VespaZooKeeperFactory()) + .zookeeperFactory(new VespaZooKeeperFactory(createClientConfig(clientConfigFile))) .dontUseContainerParents() // TODO: Remove when we know ZooKeeper 3.5 works fine, consider waiting until Vespa 8 .build()); } @@ -123,7 +121,29 @@ public class Curator implements AutoCloseable { this.zooKeeperEnsembleCount = zooKeeperEnsembleConnectionSpec.split(",").length; } - private static String createConnectionSpec(ConfigserverConfig config) { + private static String createConnectionSpec(ConfigserverConfig configserverConfig) { + return configserverConfig.zookeeperLocalhostAffinity() + ? createConnectionSpecForLocalhost(configserverConfig) + : createEnsembleConnectionSpec(configserverConfig); + } + + private static ZKClientConfig createClientConfig(Optional<File> file) { + boolean useSecureClient = Boolean.parseBoolean(getEnvironmentVariable("VESPA_USE_TLS_FOR_ZOOKEEPER_CLIENT").orElse("false")); + String config = "zookeeper.client.secure=" + useSecureClient + "\n"; + + File clientConfigFile = + file.orElseGet(() -> new File(Defaults.getDefaults().underVespaHome("conf/zookeeper/zookeeper-client.cfg"))); + clientConfigFile.getParentFile().mkdirs(); + IOUtils.writeFile(clientConfigFile, Utf8.toBytes(config)); + + try { + return new ZKClientConfig(clientConfigFile); + } catch (QuorumPeerConfig.ConfigException e) { + throw new RuntimeException("Unable to create ZooKeeper client config file " + file); + } + } + + private static String createEnsembleConnectionSpec(ConfigserverConfig config) { StringBuilder connectionSpec = new StringBuilder(); for (int i = 0; i < config.zookeeperserver().size(); i++) { if (connectionSpec.length() > 0) { @@ -405,4 +425,10 @@ public class Curator implements AutoCloseable { * TODO: Move method out of this class. */ public int zooKeeperEnsembleCount() { return zooKeeperEnsembleCount; } + + private static Optional<String> getEnvironmentVariable(String variableName) { + return Optional.ofNullable(System.getenv().get(variableName)) + .filter(var -> !var.isEmpty()); + } + } diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java index 7c08168c536..84e2cb65a1a 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/VespaZooKeeperFactory.java @@ -4,19 +4,24 @@ package com.yahoo.vespa.curator; import org.apache.curator.utils.ZookeeperFactory; import org.apache.zookeeper.Watcher; import org.apache.zookeeper.ZooKeeper; +import org.apache.zookeeper.client.ZKClientConfig; /** * A ZooKeeper factory for creating a ZooKeeper client * * @author hmusum */ -// TODO: add constructor that takes feature flag so that we can write ZooKeeper client config and start -// ZooKeeper client with that config class VespaZooKeeperFactory implements ZookeeperFactory { + private final ZKClientConfig zkClientConfig; + + VespaZooKeeperFactory(ZKClientConfig zkClientConfig) { + this.zkClientConfig = zkClientConfig; + } + @Override public ZooKeeper newZooKeeper(String connectString, int sessionTimeout, Watcher watcher, boolean canBeReadOnly) throws Exception { - return new ZooKeeper(connectString, sessionTimeout, watcher); + return new ZooKeeper(connectString, sessionTimeout, watcher, zkClientConfig); } } diff --git a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java index 9b7e0250f2f..6b85953a1ff 100644 --- a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java +++ b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorCounterTest.java @@ -8,7 +8,6 @@ import static org.junit.Assert.assertEquals; /** * @author Ulf Lilleengen - * @date 19.08.13 */ public class CuratorCounterTest { diff --git a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java index a8342dfe5bc..4cd2c708d1a 100644 --- a/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java +++ b/zkfacade/src/test/java/com/yahoo/vespa/curator/CuratorTest.java @@ -9,6 +9,8 @@ import org.junit.Before; import org.junit.Test; import java.io.IOException; +import java.nio.file.Files; +import java.util.Optional; import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertThat; @@ -52,7 +54,7 @@ public class CuratorTest { } @Test - public void require_curator_is_created_from_config() { + public void require_curator_is_created_from_config() throws IOException { try (Curator curator = createCurator(createTestConfig())) { assertThat(curator.zooKeeperEnsembleConnectionSpec(), is(spec1 + "," + spec2)); assertThat(curator.zooKeeperEnsembleCount(), is(2)); @@ -60,7 +62,7 @@ public class CuratorTest { } @Test - public void require_that_server_count_is_correct() { + public void require_that_server_count_is_correct() throws IOException { ConfigserverConfig.Builder builder = new ConfigserverConfig.Builder(); builder.zookeeperserver(createZKBuilder(localhost, port1)); try (Curator curator = createCurator(new ConfigserverConfig(builder))) { @@ -98,8 +100,8 @@ public class CuratorTest { return zkBuilder; } - private Curator createCurator(ConfigserverConfig configserverConfig) { - return new Curator(configserverConfig); + private Curator createCurator(ConfigserverConfig configserverConfig) throws IOException { + return new Curator(configserverConfig, Optional.of(Files.createTempFile("zookeeper-client", "cfg").toFile())); } private static class PortAllocator { |