diff options
163 files changed, 2589 insertions, 1097 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java b/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java index 0624028732f..a4737c9f54c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java @@ -101,12 +101,12 @@ public class ContainerDocumentApi { .collect(Collectors.toList()); // We can only use host resource for calculation if all container nodes in the cluster are homogeneous (in terms of vcpu) if (vcpus.size() != 1 || vcpus.get(0) == 0) return FALLBACK_MAX_POOL_SIZE; - return (int)Math.ceil(vcpus.get(0)); + return Math.max(2, (int)Math.ceil(vcpus.get(0))); } private static int corePoolSize(int maxPoolSize, Options options) { if (maxPoolSize == FALLBACK_MAX_POOL_SIZE) return FALLBACK_CORE_POOL_SIZE; - return (int) Math.ceil(options.feedCoreThreadPoolSizeFactor * maxPoolSize); + return Math.max(1, (int)Math.ceil(options.feedCoreThreadPoolSizeFactor * maxPoolSize)); } public static final class Options { diff --git a/config/src/tests/failover/failover.cpp b/config/src/tests/failover/failover.cpp index 17fa264fd32..0ca09b228f3 100644 --- a/config/src/tests/failover/failover.cpp +++ b/config/src/tests/failover/failover.cpp @@ -7,6 +7,7 @@ #include <vespa/fnet/frt/frt.h> #include "config-my.h" #include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <vespa/log/log.h> LOG_SETUP("failover"); diff --git a/config/src/tests/frt/frt.cpp b/config/src/tests/frt/frt.cpp index cf1ff9eca37..85b9789821d 100644 --- a/config/src/tests/frt/frt.cpp +++ b/config/src/tests/frt/frt.cpp @@ -10,6 +10,7 @@ #include <vespa/config/frt/frtconfigresponsev3.h> #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/data/slime/json_format.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <vespa/fnet/fnet.h> #include 
<vespa/fnet/frt/frt.h> #include <vespa/fnet/frt/error.h> diff --git a/config/src/vespa/config/common/trace.cpp b/config/src/vespa/config/common/trace.cpp index 76310d08c7d..4edc9df60c3 100644 --- a/config/src/vespa/config/common/trace.cpp +++ b/config/src/vespa/config/common/trace.cpp @@ -3,6 +3,7 @@ #include <vespa/vespalib/trace/slime_trace_serializer.h> #include <vespa/vespalib/trace/slime_trace_deserializer.h> #include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/data/simple_buffer.h> using namespace vespalib; using namespace vespalib::slime; diff --git a/config/src/vespa/config/frt/frtconfigresponsev3.cpp b/config/src/vespa/config/frt/frtconfigresponsev3.cpp index 405391d99b6..b983c63c6a5 100644 --- a/config/src/vespa/config/frt/frtconfigresponsev3.cpp +++ b/config/src/vespa/config/frt/frtconfigresponsev3.cpp @@ -2,6 +2,7 @@ #include "frtconfigresponsev3.h" #include "compressioninfo.h" #include <vespa/fnet/frt/frt.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <vespa/log/log.h> LOG_SETUP(".config.frt.frtconfigresponsev3"); diff --git a/config/src/vespa/config/frt/slimeconfigrequest.cpp b/config/src/vespa/config/frt/slimeconfigrequest.cpp index 07626c1e274..696789f74c1 100644 --- a/config/src/vespa/config/frt/slimeconfigrequest.cpp +++ b/config/src/vespa/config/frt/slimeconfigrequest.cpp @@ -7,6 +7,8 @@ #include <vespa/config/common/configdefinition.h> #include <vespa/config/common/trace.h> #include <vespa/config/common/vespa_version.h> +#include <vespa/vespalib/data/simple_buffer.h> + using namespace vespalib; using namespace vespalib::slime; diff --git a/config/src/vespa/config/print/fileconfigformatter.cpp b/config/src/vespa/config/print/fileconfigformatter.cpp index 85e938dee8f..628a9daa530 100644 --- a/config/src/vespa/config/print/fileconfigformatter.cpp +++ b/config/src/vespa/config/print/fileconfigformatter.cpp @@ -4,6 +4,7 @@ #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/exceptions.h> 
#include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <cmath> #include <vector> diff --git a/configdefinitions/src/vespa/configserver.def b/configdefinitions/src/vespa/configserver.def index cde539f25f4..7405f5f2d05 100644 --- a/configdefinitions/src/vespa/configserver.def +++ b/configdefinitions/src/vespa/configserver.def @@ -51,9 +51,7 @@ ztsUrl string default="" # Maintainers maintainerIntervalMinutes int default=30 -# TODO: Default set to a high value (1 year) => maintainer will not run, change when maintainer verified out in prod -tenantsMaintainerIntervalMinutes int default=525600 -keepUnusedFileReferencesHours int default=4 +keepUnusedFileReferencesHours int default=2 # Bootstrapping # How long bootstrapping can take before giving up (in seconds) diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java index 4e44b9cae33..9f2d6d178be 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ConfigServerMaintenance.java @@ -35,8 +35,8 @@ public class ConfigServerMaintenance extends AbstractComponent { DefaultTimes defaults = new DefaultTimes(configserverConfig); tenantsMaintainer = new TenantsMaintainer(applicationRepository, curator, flagSource, defaults.defaultInterval, Clock.systemUTC()); fileDistributionMaintainer = new FileDistributionMaintainer(applicationRepository, curator, defaults.defaultInterval, flagSource); - sessionsMaintainer = new SessionsMaintainer(applicationRepository, curator, Duration.ofMinutes(1), flagSource); - applicationPackageMaintainer = new ApplicationPackageMaintainer(applicationRepository, curator, Duration.ofMinutes(1), flagSource); + sessionsMaintainer = new SessionsMaintainer(applicationRepository, 
curator, Duration.ofSeconds(30), flagSource); + applicationPackageMaintainer = new ApplicationPackageMaintainer(applicationRepository, curator, Duration.ofSeconds(30), flagSource); } @Override @@ -61,8 +61,8 @@ public class ConfigServerMaintenance extends AbstractComponent { } public void runBeforeBootstrap() { - fileDistributionMaintainer.maintain(); - sessionsMaintainer.maintain(); + fileDistributionMaintainer.lockAndMaintain(); + sessionsMaintainer.lockAndMaintain(); } } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java index 2e6180aeb81..cbfa59b26e4 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java @@ -171,11 +171,11 @@ public class SessionRepository { } public void deleteExpiredSessions(Map<ApplicationId, Long> activeSessions) { - log.log(Level.FINE, "Purging old sessions for tenant '" + tenantName + "'"); + log.log(Level.FINE, () -> "Purging old sessions for tenant '" + tenantName + "'"); try { for (LocalSession candidate : localSessionCache.getSessions()) { Instant createTime = candidate.getCreateTime(); - log.log(Level.FINE, "Candidate session for deletion: " + candidate.getSessionId() + ", created: " + createTime); + log.log(Level.FINE, () -> "Candidate session for deletion: " + candidate.getSessionId() + ", created: " + createTime); // Sessions with state other than ACTIVATE if (hasExpired(candidate) && !isActiveSession(candidate)) { @@ -196,7 +196,7 @@ public class SessionRepository { } catch (Throwable e) { log.log(Level.WARNING, "Error when purging old sessions ", e); } - log.log(Level.FINE, "Done purging old sessions"); + log.log(Level.FINE, () -> "Done purging old sessions"); } private boolean hasExpired(LocalSession candidate) { @@ -210,7 +210,7 @@ public class 
SessionRepository { public void deleteLocalSession(LocalSession session) { long sessionId = session.getSessionId(); try (Lock lock = lock(sessionId)) { - log.log(Level.FINE, "Deleting local session " + sessionId); + log.log(Level.FINE, () -> "Deleting local session " + sessionId); SessionStateWatcher watcher = sessionStateWatchers.remove(sessionId); if (watcher != null) watcher.close(); localSessionCache.removeSession(sessionId); @@ -274,7 +274,7 @@ public class SessionRepository { if (session == null) continue; // Internal sessions not in synch with zk, continue if (session.getStatus() == Session.Status.ACTIVATE) continue; if (sessionHasExpired(session.getCreateTime(), expiryTime, clock)) { - log.log(Level.FINE, "Remote session " + sessionId + " for " + tenantName + " has expired, deleting it"); + log.log(Level.FINE, () -> "Remote session " + sessionId + " for " + tenantName + " has expired, deleting it"); session.delete(); deleted++; } @@ -287,7 +287,7 @@ public class SessionRepository { for (var lock : curator.getChildren(locksPath)) { Path path = locksPath.append(lock); if (zooKeeperNodeCreated(path).orElse(clock.instant()).isBefore(clock.instant().minus(expiryTime))) { - log.log(Level.FINE, "Lock " + path + " has expired, deleting it"); + log.log(Level.FINE, () -> "Lock " + path + " has expired, deleting it"); curator.delete(path); deleted++; } @@ -485,7 +485,7 @@ public class SessionRepository { long sessionId, TimeoutBudget timeoutBudget, Clock clock) { - log.log(Level.FINE, TenantRepository.logPre(tenantName) + "Creating session " + sessionId + " in ZooKeeper"); + log.log(Level.FINE, () -> TenantRepository.logPre(tenantName) + "Creating session " + sessionId + " in ZooKeeper"); SessionZooKeeperClient sessionZKClient = createSessionZooKeeperClient(sessionId); sessionZKClient.createNewSession(clock.instant()); Curator.CompletionWaiter waiter = sessionZKClient.getUploadWaiter(); @@ -605,13 +605,13 @@ public class SessionRepository { */ public 
Optional<LocalSession> createLocalSessionUsingDistributedApplicationPackage(long sessionId) { if (applicationRepo.hasLocalSession(sessionId)) { - log.log(Level.FINE, "Local session for session id " + sessionId + " already exists"); + log.log(Level.FINE, () -> "Local session for session id " + sessionId + " already exists"); return Optional.of(createSessionFromId(sessionId)); } SessionZooKeeperClient sessionZKClient = createSessionZooKeeperClient(sessionId); FileReference fileReference = sessionZKClient.readApplicationPackageReference(); - log.log(Level.FINE, "File reference for session id " + sessionId + ": " + fileReference); + log.log(Level.FINE, () -> "File reference for session id " + sessionId + ": " + fileReference); if (fileReference != null) { File rootDir = new File(Defaults.getDefaults().underVespaHome(componentRegistry.getConfigserverConfig().fileReferencesDir())); File sessionDir; @@ -626,7 +626,7 @@ public class SessionRepository { } ApplicationId applicationId = sessionZKClient.readApplicationId() .orElseThrow(() -> new RuntimeException("Could not find application id for session " + sessionId)); - log.log(Level.INFO, "Creating local session for tenant '" + tenantName + "' with session id " + sessionId); + log.log(Level.FINE, () -> "Creating local session for tenant '" + tenantName + "' with session id " + sessionId); LocalSession localSession = createLocalSession(sessionDir, applicationId, sessionId); addLocalSession(localSession); return Optional.of(localSession); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java index 6216e8ebfd6..ecbcb513c03 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java @@ -97,17 +97,6 @@ public class TenantRepository { */ @Inject public 
TenantRepository(GlobalComponentRegistry componentRegistry) { - this(componentRegistry, true); - } - - /** - * Creates a new tenant repository - * - * @param componentRegistry a {@link com.yahoo.vespa.config.server.GlobalComponentRegistry} - * @param useZooKeeperWatchForTenantChanges set to false for tests where you want to control adding and deleting - * tenants yourself - */ - public TenantRepository(GlobalComponentRegistry componentRegistry, boolean useZooKeeperWatchForTenantChanges) { this.componentRegistry = componentRegistry; ConfigserverConfig configserverConfig = componentRegistry.getConfigserverConfig(); this.bootstrapExecutor = Executors.newFixedThreadPool(configserverConfig.numParallelTenantLoaders()); @@ -124,13 +113,9 @@ public class TenantRepository { createSystemTenants(configserverConfig); curator.create(vespaPath); - if (useZooKeeperWatchForTenantChanges) { - this.directoryCache = Optional.of(curator.createDirectoryCache(tenantsPath.getAbsolute(), false, false, zkCacheExecutor)); - this.directoryCache.get().addListener(this::childEvent); - this.directoryCache.get().start(); - } else { - this.directoryCache = Optional.empty(); - } + this.directoryCache = Optional.of(curator.createDirectoryCache(tenantsPath.getAbsolute(), false, false, zkCacheExecutor)); + this.directoryCache.get().addListener(this::childEvent); + this.directoryCache.get().start(); bootstrapTenants(); notifyTenantsLoaded(); checkForRemovedApplicationsService.scheduleWithFixedDelay(this::removeUnusedApplications, diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java index cde115eec40..a1249838324 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java @@ -75,7 +75,7 @@ public 
class TenantApplicationsTest { .modelFactoryRegistry(createRegistry()) .reloadListener(listener) .build(); - TenantRepository tenantRepository = new TenantRepository(componentRegistry, false); + TenantRepository tenantRepository = new TenantRepository(componentRegistry); tenantRepository.addTenant(TenantRepository.HOSTED_VESPA_TENANT); tenantRepository.addTenant(tenantName); applications = TenantApplications.create(componentRegistry, tenantName); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java index 469168cedd4..67ac0b02133 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java @@ -50,7 +50,7 @@ public class ApplicationContentHandlerTest extends ContentHandlerTestBase { @Before public void setupHandler() { - TenantRepository tenantRepository = new TenantRepository(componentRegistry, false); + TenantRepository tenantRepository = new TenantRepository(componentRegistry); tenantRepository.addTenant(tenantName1); tenantRepository.addTenant(tenantName2); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java index f43154242fb..3a33d326c48 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java @@ -80,7 +80,7 @@ public class ApplicationHandlerTest { .provisioner(provisioner) .modelFactoryRegistry(new ModelFactoryRegistry(modelFactories)) .build(); - tenantRepository = new TenantRepository(componentRegistry, false); + tenantRepository = new 
TenantRepository(componentRegistry); tenantRepository.addTenant(mytenantName); provisioner = new SessionHandlerTest.MockProvisioner(); orchestrator = new OrchestratorMock(); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java index bef6369beb7..80a0b9edba6 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ListApplicationsHandlerTest.java @@ -41,7 +41,7 @@ public class ListApplicationsHandlerTest { @Before public void setup() { - TenantRepository tenantRepository = new TenantRepository(componentRegistry, false); + TenantRepository tenantRepository = new TenantRepository(componentRegistry); tenantRepository.addTenant(mytenant); tenantRepository.addTenant(foobar); applicationRepo = tenantRepository.getTenant(mytenant).getApplicationRepo(); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java index 135a5ef45c4..c3a7e82dff5 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionActiveHandlerTest.java @@ -72,7 +72,7 @@ public class SessionActiveHandlerTest { .curator(new MockCurator()) .modelFactoryRegistry(new ModelFactoryRegistry(List.of((modelFactory)))) .build(); - TenantRepository tenantRepository = new TenantRepository(componentRegistry, false); + TenantRepository tenantRepository = new TenantRepository(componentRegistry); tenantRepository.addTenant(tenantName); applicationRepository = new ApplicationRepository.Builder() .withTenantRepository(tenantRepository) diff --git 
a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java index f0362db3b8a..d28404d8d72 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionContentHandlerTest.java @@ -45,7 +45,7 @@ public class SessionContentHandlerTest extends ContentHandlerTestBase { @Before public void setupHandler() { - tenantRepository = new TenantRepository(componentRegistry, false); + tenantRepository = new TenantRepository(componentRegistry); tenantRepository.addTenant(tenantName); ApplicationRepository applicationRepository = new ApplicationRepository.Builder() diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java index 0f1b1543c83..513bf6352e8 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionCreateHandlerTest.java @@ -61,7 +61,7 @@ public class SessionCreateHandlerTest extends SessionHandlerTest { @Before public void setupRepo() { - TenantRepository tenantRepository = new TenantRepository(componentRegistry, false); + TenantRepository tenantRepository = new TenantRepository(componentRegistry); applicationRepository = new ApplicationRepository.Builder() .withTenantRepository(tenantRepository) .withProvisioner(new SessionHandlerTest.MockProvisioner()) diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java index 297fee94e5b..cc4f39b0789 100644 --- 
a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java @@ -65,7 +65,7 @@ public class SessionPrepareHandlerTest extends SessionHandlerTest { @Before public void setupRepo() { - tenantRepository = new TenantRepository(componentRegistry, false); + tenantRepository = new TenantRepository(componentRegistry); tenantRepository.addTenant(tenant); applicationRepository = new ApplicationRepository.Builder() .withTenantRepository(tenantRepository) diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java index 0f087e3f006..eb06f2f7017 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java @@ -88,7 +88,7 @@ public class RpcTester implements AutoCloseable { .configDefinitionRepo(new TestConfigDefinitionRepo()) .configServerConfig(configserverConfig) .build(); - tenantRepository = new TenantRepository(componentRegistry, false); + tenantRepository = new TenantRepository(componentRegistry); tenantRepository.addTenant(tenantName); applicationRepository = new ApplicationRepository.Builder() .withTenantRepository(tenantRepository) diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java index bb87233f979..b89b63aed46 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java @@ -65,7 +65,7 @@ public class SessionRepositoryTest { .build()) .flagSource(flagSource) .build(); - tenantRepository = new TenantRepository(globalComponentRegistry, 
false); + tenantRepository = new TenantRepository(globalComponentRegistry); tenantRepository.addTenant(SessionRepositoryTest.tenantName); applicationRepository = new ApplicationRepository.Builder() .withTenantRepository(tenantRepository) diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java index a31b06bbebb..8678c42eab4 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java @@ -202,7 +202,7 @@ public class TenantRepositoryTest { private static class FailingDuringBootstrapTenantRepository extends TenantRepository { public FailingDuringBootstrapTenantRepository(GlobalComponentRegistry globalComponentRegistry) { - super(globalComponentRegistry, false); + super(globalComponentRegistry); } @Override diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java index 74833da6d66..ac596198fe5 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantTest.java @@ -32,7 +32,7 @@ public class TenantTest { } private Tenant createTenant(String name) { - TenantRepository tenantRepository = new TenantRepository(componentRegistry, false); + TenantRepository tenantRepository = new TenantRepository(componentRegistry); TenantName tenantName = TenantName.from(name); tenantRepository.addTenant(tenantName); return tenantRepository.getTenant(tenantName); diff --git a/container-core/abi-spec.json b/container-core/abi-spec.json index aa2e5ccfa5f..9292a946e82 100644 --- a/container-core/abi-spec.json +++ b/container-core/abi-spec.json @@ -114,7 +114,8 @@ ], "methods": [ "public void 
<init>(java.util.concurrent.Executor, com.yahoo.container.core.LogHandlerConfig)", - "public com.yahoo.container.jdisc.HttpResponse handle(com.yahoo.container.jdisc.HttpRequest)" + "public com.yahoo.container.jdisc.AsyncHttpResponse handle(com.yahoo.container.jdisc.HttpRequest)", + "public bridge synthetic com.yahoo.container.jdisc.HttpResponse handle(com.yahoo.container.jdisc.HttpRequest)" ], "fields": [] }, diff --git a/container-core/src/main/java/com/yahoo/container/core/config/HandlersConfigurerDi.java b/container-core/src/main/java/com/yahoo/container/core/config/HandlersConfigurerDi.java index 13d50b9b30f..25299978ecd 100644 --- a/container-core/src/main/java/com/yahoo/container/core/config/HandlersConfigurerDi.java +++ b/container-core/src/main/java/com/yahoo/container/core/config/HandlersConfigurerDi.java @@ -107,9 +107,8 @@ public class HandlersConfigurerDi { super(osgiFramework); this.osgiFramework = osgiFramework; - OsgiImpl osgi = new OsgiImpl(osgiFramework); - applicationBundleLoader = new ApplicationBundleLoader(osgi, new FileAcquirerBundleInstaller(fileAcquirer)); - platformBundleLoader = new PlatformBundleLoader(osgi); + applicationBundleLoader = new ApplicationBundleLoader(this, new FileAcquirerBundleInstaller(fileAcquirer)); + platformBundleLoader = new PlatformBundleLoader(this); } diff --git a/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java b/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java index b2a156862eb..4b23eafaa9c 100644 --- a/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java +++ b/container-core/src/main/java/com/yahoo/container/handler/LogHandler.java @@ -3,13 +3,19 @@ package com.yahoo.container.handler; import com.google.inject.Inject; import com.yahoo.container.core.LogHandlerConfig; +import com.yahoo.container.jdisc.AsyncHttpResponse; +import com.yahoo.container.jdisc.ContentChannelOutputStream; import com.yahoo.container.jdisc.HttpRequest; -import 
com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; +import com.yahoo.jdisc.handler.CompletionHandler; +import com.yahoo.jdisc.handler.ContentChannel; +import java.io.IOException; import java.io.OutputStream; +import java.nio.ByteBuffer; import java.time.Instant; import java.util.Optional; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; import java.util.logging.Level; @@ -28,25 +34,58 @@ public class LogHandler extends ThreadedHttpRequestHandler { } @Override - public HttpResponse handle(HttpRequest request) { - + public AsyncHttpResponse handle(HttpRequest request) { Instant from = Optional.ofNullable(request.getProperty("from")) .map(Long::valueOf).map(Instant::ofEpochMilli).orElse(Instant.MIN); Instant to = Optional.ofNullable(request.getProperty("to")) .map(Long::valueOf).map(Instant::ofEpochMilli).orElse(Instant.MAX); - Optional<String> hostname = Optional.ofNullable(request.getProperty("hostname")); - return new HttpResponse(200) { + return new AsyncHttpResponse(200) { @Override - public void render(OutputStream outputStream) { + public void render(OutputStream output, ContentChannel networkChannel, CompletionHandler handler) { try { - logReader.writeLogs(outputStream, from, to, hostname); + OutputStream blockingOutput = new BlockingFlushContentChannelOutputStream(networkChannel); + logReader.writeLogs(blockingOutput, from, to, hostname); + blockingOutput.close(); } catch (Throwable t) { log.log(Level.WARNING, "Failed reading logs from " + from + " to " + to, t); } + finally { + networkChannel.close(handler); + } } }; } + + + private static class BlockingFlushContentChannelOutputStream extends ContentChannelOutputStream { + + private final ContentChannel channel; + + public BlockingFlushContentChannelOutputStream(ContentChannel endpoint) { + super(endpoint); + this.channel = endpoint; + } + + @Override + public void flush() throws IOException { + super.flush(); + 
CountDownLatch latch = new CountDownLatch(1); + channel.write(ByteBuffer.allocate(0), // :'( + new CompletionHandler() { + @Override public void completed() { latch.countDown(); } + @Override public void failed(Throwable t) { latch.countDown(); } + }); + try { + latch.await(); + } + catch (InterruptedException e) { + throw new RuntimeException("Interrupted waiting for underlying IO to complete", e); + } + } + + } + } diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/ContentChannelOutputStream.java b/container-core/src/main/java/com/yahoo/container/jdisc/ContentChannelOutputStream.java index 9c48955bf4c..329889e70c0 100644 --- a/container-core/src/main/java/com/yahoo/container/jdisc/ContentChannelOutputStream.java +++ b/container-core/src/main/java/com/yahoo/container/jdisc/ContentChannelOutputStream.java @@ -53,11 +53,7 @@ public class ContentChannelOutputStream extends OutputStream implements Writable public void close() throws IOException { // the endpoint is closed in a finally{} block inside AbstractHttpRequestHandler // this class should be possible to close willynilly as it is exposed to plug-ins - try { - buffer.flush(); - } catch (RuntimeException e) { - throw new IOException(Exceptions.toMessageString(e), e); - } + flush(); } /** diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/HttpRequest.java b/container-core/src/main/java/com/yahoo/container/jdisc/HttpRequest.java index e02ae152b3f..edd24fed515 100644 --- a/container-core/src/main/java/com/yahoo/container/jdisc/HttpRequest.java +++ b/container-core/src/main/java/com/yahoo/container/jdisc/HttpRequest.java @@ -517,11 +517,9 @@ public class HttpRequest { } /** - * Access an HTTP header in the request. Multi-value headers are not - * supported. + * Access an HTTP header in the request. Multi-value headers are not supported. 
* - * @param name - * the name of an HTTP header + * @param name the name of an HTTP header * @return the first pertinent value */ public String getHeader(String name) { @@ -530,20 +528,12 @@ public class HttpRequest { return parentRequest.headers().get(name).get(0); } - /** - * Get the host segment of the URI of this request. - * - * @return the host name from the URI - */ + /** Get the host segment of the URI of this request. */ public String getHost() { return getUri().getHost(); } - /** - * The port of the URI of this request. - * - * @return the port number of the URI - */ + /** The port of the URI of this request. */ public int getPort() { return getUri().getPort(); } diff --git a/container-core/src/test/java/com/yahoo/container/handler/LogHandlerTest.java b/container-core/src/test/java/com/yahoo/container/handler/LogHandlerTest.java index 97aa8864eae..afe57579a97 100644 --- a/container-core/src/test/java/com/yahoo/container/handler/LogHandlerTest.java +++ b/container-core/src/test/java/com/yahoo/container/handler/LogHandlerTest.java @@ -1,17 +1,19 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.container.handler; +import com.yahoo.container.jdisc.AsyncHttpResponse; import com.yahoo.container.jdisc.HttpRequest; -import com.yahoo.container.jdisc.HttpResponse; +import com.yahoo.jdisc.handler.ReadableContentChannel; +import com.yahoo.yolean.Exceptions; import org.junit.Test; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; import java.time.Instant; import java.util.Optional; import java.util.concurrent.Executor; +import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.mock; @@ -24,20 +26,20 @@ public class LogHandlerTest { { String uri = "http://myhost.com:1111/logs?from=1000&to=2000"; - HttpResponse response = logHandler.handle(HttpRequest.createTestRequest(uri, com.yahoo.jdisc.http.HttpRequest.Method.GET)); - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - response.render(bos); + AsyncHttpResponse response = logHandler.handle(HttpRequest.createTestRequest(uri, com.yahoo.jdisc.http.HttpRequest.Method.GET)); + ReadableContentChannel out = new ReadableContentChannel(); + new Thread(() -> Exceptions.uncheck(() -> response.render(null, out, null))).start(); String expectedResponse = "newer log"; - assertEquals(expectedResponse, bos.toString()); + assertEquals(expectedResponse, new String(out.toStream().readAllBytes(), UTF_8)); } { String uri = "http://myhost.com:1111/logs?from=0&to=1000"; - HttpResponse response = logHandler.handle(HttpRequest.createTestRequest(uri, com.yahoo.jdisc.http.HttpRequest.Method.GET)); - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - response.render(bos); + AsyncHttpResponse response = logHandler.handle(HttpRequest.createTestRequest(uri, com.yahoo.jdisc.http.HttpRequest.Method.GET)); + ReadableContentChannel out = new ReadableContentChannel(); + new Thread(() -> Exceptions.uncheck(() -> response.render(null, out, null))).start(); String expectedResponse = "older log"; - 
assertEquals(expectedResponse, bos.toString()); + assertEquals(expectedResponse, new String(out.toStream().readAllBytes(), UTF_8)); } } diff --git a/default_build_settings.cmake b/default_build_settings.cmake index 07a70c38d71..75399069619 100644 --- a/default_build_settings.cmake +++ b/default_build_settings.cmake @@ -79,7 +79,7 @@ endfunction() function(setup_vespa_default_build_settings_fedora_33) message("-- Setting up default build settings for fedora 33") set(DEFAULT_EXTRA_INCLUDE_DIRECTORY "${VESPA_DEPS}/include" "/usr/include/openblas" PARENT_SCOPE) - set(DEFAULT_VESPA_LLVM_VERSION "10" PARENT_SCOPE) + set(DEFAULT_VESPA_LLVM_VERSION "11" PARENT_SCOPE) endfunction() function(setup_vespa_default_build_settings_ubuntu_19_10) diff --git a/dist/vespa.spec b/dist/vespa.spec index e49bd74e545..8dae9ee17bc 100644 --- a/dist/vespa.spec +++ b/dist/vespa.spec @@ -97,8 +97,8 @@ BuildRequires: gmock-devel %endif %if 0%{?fc33} BuildRequires: protobuf-devel -BuildRequires: llvm-devel >= 10.0.0 -BuildRequires: boost-devel >= 1.69 +BuildRequires: llvm-devel >= 11.0.0 +BuildRequires: boost-devel >= 1.73 BuildRequires: gtest-devel BuildRequires: gmock-devel %endif @@ -200,8 +200,8 @@ Requires: llvm-libs >= 10.0.0 %endif %if 0%{?fc33} Requires: protobuf -Requires: llvm-libs >= 10.0.0 -%define _vespa_llvm_version 10 +Requires: llvm-libs >= 11.0.0 +%define _vespa_llvm_version 11 %endif %define _extra_link_directory %{_vespa_deps_prefix}/lib64 %define _extra_include_directory %{_vespa_deps_prefix}/include;/usr/include/openblas diff --git a/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp b/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp index 3f4641ed2ee..08333fa30f3 100644 --- a/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp +++ b/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp @@ -174,4 +174,20 @@ TEST("require that dense 
tensor cells iterator works for 2d tensor") { EXPECT_FALSE(itr.valid()); } +TEST("require that memory used count is reasonable") { + Tensor::UP full = build2DTensor(); + const DenseTensorView &full_view = dynamic_cast<const DenseTensorView &>(*full); + DenseTensorView ref_view(full_view.fast_type(), full_view.cellsRef()); + + size_t full_sz = full->count_memory_used(); + size_t view_sz = full_view.count_memory_used(); + size_t ref_sz = ref_view.count_memory_used(); + + EXPECT_EQUAL(ref_sz, sizeof(DenseTensorView)); + EXPECT_LESS(ref_sz, full_sz); + EXPECT_EQUAL(full_sz, view_sz); + EXPECT_LESS(full_sz, 10000u); + EXPECT_GREATER(full_sz, sizeof(DenseTensor<double>)); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/tensor/direct_sparse_tensor_builder/direct_sparse_tensor_builder_test.cpp b/eval/src/tests/tensor/direct_sparse_tensor_builder/direct_sparse_tensor_builder_test.cpp index 86b6abedd39..f901b7775fd 100644 --- a/eval/src/tests/tensor/direct_sparse_tensor_builder/direct_sparse_tensor_builder_test.cpp +++ b/eval/src/tests/tensor/direct_sparse_tensor_builder/direct_sparse_tensor_builder_test.cpp @@ -99,6 +99,10 @@ TEST("Test essential object sizes") { EXPECT_EQUAL(16u, sizeof(SparseTensorAddressRef)); EXPECT_EQUAL(24u, sizeof(std::pair<SparseTensorAddressRef, double>)); EXPECT_EQUAL(32u, sizeof(vespalib::hash_node<std::pair<SparseTensorAddressRef, double>>)); + Tensor::UP tensor = buildTensor(); + size_t used = tensor->count_memory_used(); + EXPECT_GREATER(used, sizeof(SparseTensor)); + EXPECT_LESS(used, 10000u); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/tensor/onnx_wrapper/.gitattributes b/eval/src/tests/tensor/onnx_wrapper/.gitattributes new file mode 100644 index 00000000000..62e8ad1e0a0 --- /dev/null +++ b/eval/src/tests/tensor/onnx_wrapper/.gitattributes @@ -0,0 +1 @@ +/*.onnx binary diff --git a/eval/src/tests/tensor/onnx_wrapper/dynamic.onnx b/eval/src/tests/tensor/onnx_wrapper/dynamic.onnx new file mode 100644 index 
00000000000..95bbf36885a --- /dev/null +++ b/eval/src/tests/tensor/onnx_wrapper/dynamic.onnx @@ -0,0 +1,27 @@ + +dynamic.py:¦ +0 +query_tensor +attribute_tensormatmul"MatMul +- +bias_tensorreduce" ReduceSum* +axes@ + +matmul +reduceoutput"Adddynamic_scoringZ# +query_tensor +
+batch +Z" +attribute_tensor + + +Z+ +bias_tensor + +batch +ÿÿÿÿÿÿÿÿÿb +output +
+batch +B
\ No newline at end of file diff --git a/eval/src/tests/tensor/onnx_wrapper/dynamic.py b/eval/src/tests/tensor/onnx_wrapper/dynamic.py new file mode 100755 index 00000000000..d098324fae8 --- /dev/null +++ b/eval/src/tests/tensor/onnx_wrapper/dynamic.py @@ -0,0 +1,39 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +import onnx +from onnx import helper, TensorProto + +QUERY_TENSOR = helper.make_tensor_value_info('query_tensor', TensorProto.FLOAT, ['batch', 4]) +ATTRIBUTE_TENSOR = helper.make_tensor_value_info('attribute_tensor', TensorProto.FLOAT, [4, 1]) +BIAS_TENSOR = helper.make_tensor_value_info('bias_tensor', TensorProto.FLOAT, ['batch', -1]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, ['batch', 1]) + +nodes = [ + helper.make_node( + 'MatMul', + ['query_tensor', 'attribute_tensor'], + ['matmul'], + ), + helper.make_node( + 'ReduceSum', + ['bias_tensor'], + ['reduce'], + axes=[1] + ), + helper.make_node( + 'Add', + ['matmul', 'reduce'], + ['output'], + ), +] +graph_def = helper.make_graph( + nodes, + 'dynamic_scoring', + [ + QUERY_TENSOR, + ATTRIBUTE_TENSOR, + BIAS_TENSOR, + ], + [OUTPUT], +) +model_def = helper.make_model(graph_def, producer_name='dynamic.py') +onnx.save(model_def, 'dynamic.onnx') diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp index 28a4a34b2e4..db2415e9969 100644 --- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp +++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp @@ -10,83 +10,224 @@ using namespace vespalib::eval; using namespace vespalib::tensor; using vespalib::make_string_short::fmt; +using TensorInfo = Onnx::TensorInfo; +using DZ = Onnx::DimSize; std::string get_source_dir() { const char *dir = getenv("SOURCE_DIRECTORY"); return (dir ? 
dir : "."); } std::string source_dir = get_source_dir(); -std::string vespa_dir = source_dir + "/" + "../../../../.."; -std::string simple_model = vespa_dir + "/" + "model-integration/src/test/models/onnx/simple/simple.onnx"; +std::string simple_model = source_dir + "/simple.onnx"; +std::string dynamic_model = source_dir + "/dynamic.onnx"; -void dump_info(const char *ctx, const std::vector<OnnxWrapper::TensorInfo> &info) { +void dump_info(const char *ctx, const std::vector<TensorInfo> &info) { fprintf(stderr, "%s:\n", ctx); for (size_t i = 0; i < info.size(); ++i) { fprintf(stderr, " %s[%zu]: '%s' %s\n", ctx, i, info[i].name.c_str(), info[i].type_as_string().c_str()); } } -TEST(OnnxWrapperTest, onnx_model_can_be_inspected) +TEST(WirePlannerTest, element_types_must_match) { + Onnx::WirePlanner planner; + ValueType type1 = ValueType::from_spec("tensor<float>(a[5])"); + ValueType type2 = ValueType::from_spec("tensor<double>(a[5])"); + TensorInfo info1 = TensorInfo{"info", {DZ(5)}, TensorInfo::ElementType::FLOAT}; + TensorInfo info2 = TensorInfo{"info", {DZ(5)}, TensorInfo::ElementType::DOUBLE}; + EXPECT_TRUE(planner.bind_input_type(type1, info1)); + EXPECT_FALSE(planner.bind_input_type(type2, info1)); + EXPECT_FALSE(planner.bind_input_type(type1, info2)); + EXPECT_TRUE(planner.bind_input_type(type2, info2)); +} + +TEST(WirePlannerTest, known_dimension_sizes_must_match) { + Onnx::WirePlanner planner; + ValueType type1 = ValueType::from_spec("tensor<float>(a[5],b[10])"); + ValueType type2 = ValueType::from_spec("tensor<float>(a[10],b[5])"); + ValueType type3 = ValueType::from_spec("tensor<float>(a[5],b[5])"); + TensorInfo info = TensorInfo{"info", {DZ(5),DZ(5)}, TensorInfo::ElementType::FLOAT}; + EXPECT_FALSE(planner.bind_input_type(type1, info)); + EXPECT_FALSE(planner.bind_input_type(type2, info)); + EXPECT_TRUE(planner.bind_input_type(type3, info)); +} + +TEST(WirePlannerTest, symbolic_dimension_sizes_must_match) { + Onnx::WirePlanner planner; + ValueType type1 = 
ValueType::from_spec("tensor<float>(a[5])"); + ValueType type2 = ValueType::from_spec("tensor<float>(a[10])"); + TensorInfo info = TensorInfo{"info", {DZ("dim")}, TensorInfo::ElementType::FLOAT}; + EXPECT_TRUE(planner.bind_input_type(type1, info)); // binds 'dim' to 5 + EXPECT_FALSE(planner.bind_input_type(type2, info)); + EXPECT_TRUE(planner.bind_input_type(type1, info)); +} + +TEST(WirePlannerTest, unknown_dimension_sizes_match_anything) { + Onnx::WirePlanner planner; + ValueType type1 = ValueType::from_spec("tensor<float>(a[5])"); + ValueType type2 = ValueType::from_spec("tensor<float>(a[10])"); + TensorInfo info = TensorInfo{"info", {DZ()}, TensorInfo::ElementType::FLOAT}; + EXPECT_TRUE(planner.bind_input_type(type1, info)); + EXPECT_TRUE(planner.bind_input_type(type2, info)); +} + +TEST(WirePlannerTest, all_output_dimensions_must_be_bound) { + Onnx::WirePlanner planner; + ValueType type = ValueType::from_spec("tensor<float>(a[5],b[10])"); + TensorInfo info1 = TensorInfo{"info", {DZ()}, TensorInfo::ElementType::FLOAT}; + TensorInfo info2 = TensorInfo{"info", {DZ("dim")}, TensorInfo::ElementType::FLOAT}; + TensorInfo info3 = TensorInfo{"info", {DZ("dim"),DZ()}, TensorInfo::ElementType::FLOAT}; + EXPECT_TRUE(planner.make_output_type(info1).is_error()); + EXPECT_TRUE(planner.make_output_type(info2).is_error()); + EXPECT_TRUE(planner.make_output_type(info3).is_error()); + EXPECT_TRUE(planner.bind_input_type(type, info3)); // binds 'dim' to 5 + EXPECT_TRUE(planner.make_output_type(info1).is_error()); + EXPECT_EQ(planner.make_output_type(info2).to_spec(), "tensor<float>(d0[5])"); + EXPECT_TRUE(planner.make_output_type(info3).is_error()); +} + +TEST(WirePlannerTest, dimensions_resolve_left_to_right) { + Onnx::WirePlanner planner; + ValueType type1 = ValueType::from_spec("tensor<float>(a[5],b[10])"); + ValueType type2 = ValueType::from_spec("tensor<float>(a[10],b[10])"); + ValueType type3 = ValueType::from_spec("tensor<float>(a[5],b[5])"); + TensorInfo info = 
TensorInfo{"info", {DZ("dim"),DZ("dim")}, TensorInfo::ElementType::FLOAT}; + EXPECT_FALSE(planner.bind_input_type(type1, info)); // binds 'dim' to 5, then fails (5 != 10) + EXPECT_FALSE(planner.bind_input_type(type2, info)); + EXPECT_TRUE(planner.bind_input_type(type3, info)); +} + +TEST(OnnxTest, simple_onnx_model_can_be_inspected) { - OnnxWrapper wrapper(simple_model, OnnxWrapper::Optimize::DISABLE); - dump_info("inputs", wrapper.inputs()); - dump_info("outputs", wrapper.outputs()); - ASSERT_EQ(wrapper.inputs().size(), 3); - ASSERT_EQ(wrapper.outputs().size(), 1); + Onnx model(simple_model, Onnx::Optimize::DISABLE); + dump_info("inputs", model.inputs()); + dump_info("outputs", model.outputs()); + ASSERT_EQ(model.inputs().size(), 3); + ASSERT_EQ(model.outputs().size(), 1); //------------------------------------------------------------------------- - EXPECT_EQ(wrapper.inputs()[0].name, "query_tensor"); - EXPECT_EQ(wrapper.inputs()[0].type_as_string(), "float[1][4]"); + EXPECT_EQ(model.inputs()[0].name, "query_tensor"); + EXPECT_EQ(model.inputs()[0].type_as_string(), "float[1][4]"); //------------------------------------------------------------------------- - EXPECT_EQ(wrapper.inputs()[1].name, "attribute_tensor"); - EXPECT_EQ(wrapper.inputs()[1].type_as_string(), "float[4][1]"); + EXPECT_EQ(model.inputs()[1].name, "attribute_tensor"); + EXPECT_EQ(model.inputs()[1].type_as_string(), "float[4][1]"); //------------------------------------------------------------------------- - EXPECT_EQ(wrapper.inputs()[2].name, "bias_tensor"); - EXPECT_EQ(wrapper.inputs()[2].type_as_string(), "float[1][1]"); + EXPECT_EQ(model.inputs()[2].name, "bias_tensor"); + EXPECT_EQ(model.inputs()[2].type_as_string(), "float[1][1]"); //------------------------------------------------------------------------- - EXPECT_EQ(wrapper.outputs()[0].name, "output"); - EXPECT_EQ(wrapper.outputs()[0].type_as_string(), "float[1][1]"); + EXPECT_EQ(model.outputs()[0].name, "output"); + 
EXPECT_EQ(model.outputs()[0].type_as_string(), "float[1][1]"); } -TEST(OnnxWrapperTest, onnx_model_can_be_evaluated) +TEST(OnnxTest, dynamic_onnx_model_can_be_inspected) { - OnnxWrapper wrapper(simple_model, OnnxWrapper::Optimize::ENABLE); + Onnx model(dynamic_model, Onnx::Optimize::DISABLE); + dump_info("inputs", model.inputs()); + dump_info("outputs", model.outputs()); + ASSERT_EQ(model.inputs().size(), 3); + ASSERT_EQ(model.outputs().size(), 1); + //------------------------------------------------------------------------- + EXPECT_EQ(model.inputs()[0].name, "query_tensor"); + EXPECT_EQ(model.inputs()[0].type_as_string(), "float[batch][4]"); + //------------------------------------------------------------------------- + EXPECT_EQ(model.inputs()[1].name, "attribute_tensor"); + EXPECT_EQ(model.inputs()[1].type_as_string(), "float[4][1]"); + //------------------------------------------------------------------------- + EXPECT_EQ(model.inputs()[2].name, "bias_tensor"); + EXPECT_EQ(model.inputs()[2].type_as_string(), "float[batch][]"); + //------------------------------------------------------------------------- + EXPECT_EQ(model.outputs()[0].name, "output"); + EXPECT_EQ(model.outputs()[0].type_as_string(), "float[batch][1]"); +} + +TEST(OnnxTest, simple_onnx_model_can_be_evaluated) +{ + Onnx model(simple_model, Onnx::Optimize::ENABLE); + Onnx::WirePlanner planner; ValueType query_type = ValueType::from_spec("tensor<float>(a[1],b[4])"); std::vector<float> query_values({1.0, 2.0, 3.0, 4.0}); DenseTensorView query(query_type, TypedCells(query_values)); - EXPECT_TRUE(wrapper.inputs()[0].is_compatible(query_type)); - EXPECT_FALSE(wrapper.inputs()[1].is_compatible(query_type)); - EXPECT_FALSE(wrapper.inputs()[2].is_compatible(query_type)); + EXPECT_TRUE(planner.bind_input_type(query_type, model.inputs()[0])); ValueType attribute_type = ValueType::from_spec("tensor<float>(a[4],b[1])"); std::vector<float> attribute_values({5.0, 6.0, 7.0, 8.0}); DenseTensorView 
attribute(attribute_type, TypedCells(attribute_values)); - EXPECT_FALSE(wrapper.inputs()[0].is_compatible(attribute_type)); - EXPECT_TRUE(wrapper.inputs()[1].is_compatible(attribute_type)); - EXPECT_FALSE(wrapper.inputs()[2].is_compatible(attribute_type)); + EXPECT_TRUE(planner.bind_input_type(attribute_type, model.inputs()[1])); ValueType bias_type = ValueType::from_spec("tensor<float>(a[1],b[1])"); std::vector<float> bias_values({9.0}); DenseTensorView bias(bias_type, TypedCells(bias_values)); - EXPECT_FALSE(wrapper.inputs()[0].is_compatible(bias_type)); - EXPECT_FALSE(wrapper.inputs()[1].is_compatible(bias_type)); - EXPECT_TRUE(wrapper.inputs()[2].is_compatible(bias_type)); - - MutableDenseTensorView output(wrapper.outputs()[0].make_compatible_type()); - EXPECT_EQ(output.fast_type().to_spec(), "tensor<float>(d0[1],d1[1])"); - - OnnxWrapper::Params params; - params.bind(0, query); - params.bind(1, attribute); - params.bind(2, bias); - auto result = wrapper.eval(params); - - EXPECT_EQ(result.num_values(), 1); - result.get(0, output); - auto cells = output.cellsRef(); + EXPECT_TRUE(planner.bind_input_type(bias_type, model.inputs()[2])); + + EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), + "tensor<float>(d0[1],d1[1])"); + + Onnx::WireInfo wire_info = planner.get_wire_info(model); + Onnx::EvalContext ctx(model, wire_info); + + const Value &output = ctx.get_result(0); + EXPECT_EQ(output.type().to_spec(), "tensor<float>(d0[1],d1[1])"); + //------------------------------------------------------------------------- + ctx.bind_param(0, query); + ctx.bind_param(1, attribute); + ctx.bind_param(2, bias); + ctx.eval(); + auto cells = static_cast<const DenseTensorView&>(output).cellsRef(); EXPECT_EQ(cells.type, ValueType::CellType::FLOAT); EXPECT_EQ(cells.size, 1); EXPECT_EQ(cells.get(0), 79.0); + //------------------------------------------------------------------------- + std::vector<float> new_bias_values({10.0}); + DenseTensorView new_bias(bias_type, 
TypedCells(new_bias_values)); + ctx.bind_param(2, new_bias); + ctx.eval(); + EXPECT_EQ(static_cast<const DenseTensorView&>(output).cellsRef().get(0), 80.0); + //------------------------------------------------------------------------- +} + +TEST(OnnxTest, dynamic_onnx_model_can_be_evaluated) +{ + Onnx model(dynamic_model, Onnx::Optimize::ENABLE); + Onnx::WirePlanner planner; + + ValueType query_type = ValueType::from_spec("tensor<float>(a[1],b[4])"); + std::vector<float> query_values({1.0, 2.0, 3.0, 4.0}); + DenseTensorView query(query_type, TypedCells(query_values)); + EXPECT_TRUE(planner.bind_input_type(query_type, model.inputs()[0])); + + ValueType attribute_type = ValueType::from_spec("tensor<float>(a[4],b[1])"); + std::vector<float> attribute_values({5.0, 6.0, 7.0, 8.0}); + DenseTensorView attribute(attribute_type, TypedCells(attribute_values)); + EXPECT_TRUE(planner.bind_input_type(attribute_type, model.inputs()[1])); + + ValueType bias_type = ValueType::from_spec("tensor<float>(a[1],b[2])"); + std::vector<float> bias_values({4.0, 5.0}); + DenseTensorView bias(bias_type, TypedCells(bias_values)); + EXPECT_TRUE(planner.bind_input_type(bias_type, model.inputs()[2])); + + EXPECT_EQ(planner.make_output_type(model.outputs()[0]).to_spec(), + "tensor<float>(d0[1],d1[1])"); + + Onnx::WireInfo wire_info = planner.get_wire_info(model); + Onnx::EvalContext ctx(model, wire_info); + + const Value &output = ctx.get_result(0); + EXPECT_EQ(output.type().to_spec(), "tensor<float>(d0[1],d1[1])"); + //------------------------------------------------------------------------- + ctx.bind_param(0, query); + ctx.bind_param(1, attribute); + ctx.bind_param(2, bias); + ctx.eval(); + auto cells = static_cast<const DenseTensorView&>(output).cellsRef(); + EXPECT_EQ(cells.type, ValueType::CellType::FLOAT); + EXPECT_EQ(cells.size, 1); + EXPECT_EQ(cells.get(0), 79.0); + //------------------------------------------------------------------------- + std::vector<float> 
new_bias_values({5.0,6.0}); + DenseTensorView new_bias(bias_type, TypedCells(new_bias_values)); + ctx.bind_param(2, new_bias); + ctx.eval(); + EXPECT_EQ(static_cast<const DenseTensorView&>(output).cellsRef().get(0), 81.0); + //------------------------------------------------------------------------- } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/tests/tensor/onnx_wrapper/simple.onnx b/eval/src/tests/tensor/onnx_wrapper/simple.onnx new file mode 100644 index 00000000000..88ed0ef23f0 --- /dev/null +++ b/eval/src/tests/tensor/onnx_wrapper/simple.onnx @@ -0,0 +1,23 @@ + simple.py:ã +0 +query_tensor +attribute_tensormatmul"MatMul +" +matmul +bias_tensoroutput"Addsimple_scoringZ +query_tensor + + +Z" +attribute_tensor + + +Z +bias_tensor + + +b +output + + +B
\ No newline at end of file diff --git a/eval/src/tests/tensor/onnx_wrapper/simple.py b/eval/src/tests/tensor/onnx_wrapper/simple.py new file mode 100755 index 00000000000..a3cd2425d58 --- /dev/null +++ b/eval/src/tests/tensor/onnx_wrapper/simple.py @@ -0,0 +1,33 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +import onnx +from onnx import helper, TensorProto + +QUERY_TENSOR = helper.make_tensor_value_info('query_tensor', TensorProto.FLOAT, [1, 4]) +ATTRIBUTE_TENSOR = helper.make_tensor_value_info('attribute_tensor', TensorProto.FLOAT, [4, 1]) +BIAS_TENSOR = helper.make_tensor_value_info('bias_tensor', TensorProto.FLOAT, [1, 1]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 1]) + +nodes = [ + helper.make_node( + 'MatMul', + ['query_tensor', 'attribute_tensor'], + ['matmul'], + ), + helper.make_node( + 'Add', + ['matmul', 'bias_tensor'], + ['output'], + ), +] +graph_def = helper.make_graph( + nodes, + 'simple_scoring', + [ + QUERY_TENSOR, + ATTRIBUTE_TENSOR, + BIAS_TENSOR, + ], + [OUTPUT], +) +model_def = helper.make_model(graph_def, producer_name='simple.py') +onnx.save(model_def, 'simple.onnx') diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 6f9bee025c9..ad182115054 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -200,7 +200,8 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { } assert(pass_params == PassParams::LAZY); assert(params.size() == 2); - return builder.CreateCall(params[0], {params[1], builder.getInt64(idx)}, "resolve_param"); + return builder.CreateCall(llvm::cast<llvm::FunctionType>(params[0]->getType()->getPointerElementType()), + params[0], {params[1], builder.getInt64(idx)}, "resolve_param"); } //------------------------------------------------------------------------- @@ -252,12 +253,14 @@ struct 
FunctionBuilder : public NodeVisitor, public NodeTraverser { llvm::Value *eval_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)eval_ptr), eval_funptr_t, "inject_eval"); llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)forest), builder.getVoidTy()->getPointerTo(), "inject_ctx"); if (pass_params == PassParams::ARRAY) { - push(builder.CreateCall(eval_fun, {ctx, params[0]}, "call_eval")); + push(builder.CreateCall(llvm::cast<llvm::FunctionType>(eval_fun->getType()->getPointerElementType()), + eval_fun, {ctx, params[0]}, "call_eval")); } else { assert(pass_params == PassParams::LAZY); llvm::PointerType *proxy_funptr_t = make_eval_forest_proxy_funptr_t(); llvm::Value *proxy_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)vespalib_eval_forest_proxy), proxy_funptr_t, "inject_eval_proxy"); - push(builder.CreateCall(proxy_fun, {eval_fun, ctx, params[0], params[1], builder.getInt64(stats.num_params)})); + push(builder.CreateCall(llvm::cast<llvm::FunctionType>(proxy_fun->getType()->getPointerElementType()), + proxy_fun, {eval_fun, ctx, params[0], params[1], builder.getInt64(stats.num_params)})); } return true; } @@ -411,7 +414,8 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { llvm::PointerType *funptr_t = make_check_membership_funptr_t(); llvm::Value *call_fun = builder.CreateIntToPtr(builder.getInt64((uint64_t)call_ptr), funptr_t, "inject_call_addr"); llvm::Value *ctx = builder.CreateIntToPtr(builder.getInt64((uint64_t)state), builder.getVoidTy()->getPointerTo(), "inject_ctx"); - push(builder.CreateCall(call_fun, {ctx, lhs}, "call_check_membership")); + push(builder.CreateCall(llvm::cast<llvm::FunctionType>(call_fun->getType()->getPointerElementType()), + call_fun, {ctx, lhs}, "call_check_membership")); } else { // build explicit code to check all set members llvm::Value *found = builder.getFalse(); diff --git a/eval/src/vespa/eval/eval/simple_tensor.cpp b/eval/src/vespa/eval/eval/simple_tensor.cpp index 
a72a24be211..9ed28d87fee 100644 --- a/eval/src/vespa/eval/eval/simple_tensor.cpp +++ b/eval/src/vespa/eval/eval/simple_tensor.cpp @@ -769,5 +769,14 @@ SimpleTensor::decode(nbostream &input) return builder.build(); } +size_t +SimpleTensor::count_memory_used() const { + size_t result = sizeof(SimpleTensor); + size_t addr_size = sizeof(Label) * _type.dimensions().size(); + size_t cell_size = sizeof(Cell) + addr_size; + result += _cells.size() * cell_size; + return result; +} + } // namespace vespalib::eval } // namespace vespalib diff --git a/eval/src/vespa/eval/eval/simple_tensor.h b/eval/src/vespa/eval/eval/simple_tensor.h index cbf1ac99e05..052d7cb70bd 100644 --- a/eval/src/vespa/eval/eval/simple_tensor.h +++ b/eval/src/vespa/eval/eval/simple_tensor.h @@ -93,6 +93,7 @@ public: static std::unique_ptr<SimpleTensor> concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension); static void encode(const SimpleTensor &tensor, nbostream &output); static std::unique_ptr<SimpleTensor> decode(nbostream &input); + size_t count_memory_used() const; }; } // namespace vespalib::eval diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor.h b/eval/src/vespa/eval/tensor/dense/dense_tensor.h index d0246fef635..4114661a074 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor.h @@ -21,6 +21,11 @@ public: // for unit tests template <typename RCT> bool operator==(const DenseTensor<RCT> &rhs) const; + + size_t count_memory_used() const override { + return sizeof(DenseTensor) + (sizeof(CT) * _cells.size()); + } + private: eval::ValueType _type; std::vector<CT> _cells; diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h index 93dd2dbedeb..a07a3eede77 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h @@ -18,6 +18,7 @@ public: using CellsIterator = 
DenseTensorCellsIterator; using Address = std::vector<eval::ValueType::Dimension::size_type>; + DenseTensorView(const DenseTensorView &rhs) : DenseTensorView(rhs._typeRef, rhs._cellsRef) {} DenseTensorView(const eval::ValueType &type_in, TypedCells cells_in) : _typeRef(type_in), _cellsRef(cells_in) @@ -43,6 +44,9 @@ public: Tensor::UP clone() const override; eval::TensorSpec toSpec() const override; void accept(TensorVisitor &visitor) const override; + size_t count_memory_used() const override { + return sizeof(DenseTensorView); + } template <typename T> static ConstArrayRef<T> typify_cells(const eval::Value &self) { return static_cast<const DenseTensorView &>(self).cellsRef().typify<T>(); @@ -55,7 +59,6 @@ protected: : _typeRef(type_in), _cellsRef() {} - DenseTensorView(const DenseTensorView &rhs) : DenseTensorView(rhs._typeRef, rhs._cellsRef) {} void initCellsRef(TypedCells cells_in) { assert(_typeRef.cell_type() == cells_in.type); diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp index 125095ff23e..88346213901 100644 --- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp @@ -18,31 +18,31 @@ namespace vespalib::tensor { namespace { -vespalib::string to_str(OnnxWrapper::TensorInfo::ElementType element_type) { - if (element_type == OnnxWrapper::TensorInfo::ElementType::FLOAT) { +vespalib::string to_str(Onnx::TensorInfo::ElementType element_type) { + if (element_type == Onnx::TensorInfo::ElementType::FLOAT) { return "float"; } - if (element_type == OnnxWrapper::TensorInfo::ElementType::DOUBLE) { + if (element_type == Onnx::TensorInfo::ElementType::DOUBLE) { return "double"; } return "???"; } -ValueType::CellType as_cell_type(OnnxWrapper::TensorInfo::ElementType type) { - if (type == OnnxWrapper::TensorInfo::ElementType::FLOAT) { +ValueType::CellType as_cell_type(Onnx::TensorInfo::ElementType type) { + if (type == 
Onnx::TensorInfo::ElementType::FLOAT) { return ValueType::CellType::FLOAT; } - if (type == OnnxWrapper::TensorInfo::ElementType::DOUBLE) { + if (type == Onnx::TensorInfo::ElementType::DOUBLE) { return ValueType::CellType::DOUBLE; } abort(); } -auto convert_optimize(OnnxWrapper::Optimize optimize) { - if (optimize == OnnxWrapper::Optimize::ENABLE) { +auto convert_optimize(Onnx::Optimize optimize) { + if (optimize == Onnx::Optimize::ENABLE) { return ORT_ENABLE_ALL; } else { - assert(optimize == OnnxWrapper::Optimize::DISABLE); + assert(optimize == Onnx::Optimize::DISABLE); return ORT_DISABLE_ALL; } } @@ -81,37 +81,77 @@ public: }; Ort::AllocatorWithDefaultOptions OnnxString::_alloc; -std::vector<size_t> make_dimensions(const std::vector<int64_t> &shape) { - std::vector<size_t> result; - for (int64_t size: shape) { - result.push_back(std::max(size, 0L)); - } +std::vector<Onnx::DimSize> make_dimensions(const Ort::TensorTypeAndShapeInfo &tensor_info) { + std::vector<const char *> symbolic_sizes(tensor_info.GetDimensionsCount(), nullptr); + tensor_info.GetSymbolicDimensions(symbolic_sizes.data(), symbolic_sizes.size()); + auto shape = tensor_info.GetShape(); + std::vector<Onnx::DimSize> result; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] > 0) { + result.emplace_back(shape[i]); + } else if (symbolic_sizes[i] != nullptr) { + result.emplace_back(vespalib::string(symbolic_sizes[i])); + } else { + result.emplace_back(); + } + } return result; } -OnnxWrapper::TensorInfo::ElementType make_element_type(ONNXTensorElementDataType element_type) { +Onnx::TensorInfo::ElementType make_element_type(ONNXTensorElementDataType element_type) { if (element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - return OnnxWrapper::TensorInfo::ElementType::FLOAT; + return Onnx::TensorInfo::ElementType::FLOAT; } else if (element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - return OnnxWrapper::TensorInfo::ElementType::DOUBLE; + return Onnx::TensorInfo::ElementType::DOUBLE; } 
else { - return OnnxWrapper::TensorInfo::ElementType::UNKNOWN; + return Onnx::TensorInfo::ElementType::UNKNOWN; } } -OnnxWrapper::TensorInfo make_tensor_info(const OnnxString &name, const Ort::TypeInfo &type_info) { +Onnx::TensorInfo make_tensor_info(const OnnxString &name, const Ort::TypeInfo &type_info) { auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); - auto shape = tensor_info.GetShape(); auto element_type = tensor_info.GetElementType(); - return OnnxWrapper::TensorInfo{vespalib::string(name.get()), make_dimensions(shape), make_element_type(element_type)}; + return Onnx::TensorInfo{vespalib::string(name.get()), make_dimensions(tensor_info), make_element_type(element_type)}; } } +vespalib::string +Onnx::DimSize::as_string() const +{ + if (is_known()) { + return fmt("[%zu]", value); + } else if (is_symbolic()) { + return fmt("[%s]", name.c_str()); + } else { + return "[]"; + } +} + +vespalib::string +Onnx::TensorInfo::type_as_string() const +{ + vespalib::string res = to_str(elements); + for (const auto &dim: dimensions) { + res += dim.as_string(); + } + return res; +} + +Onnx::TensorInfo::~TensorInfo() = default; + +//----------------------------------------------------------------------------- + +Onnx::WirePlanner::~WirePlanner() = default; + bool -OnnxWrapper::TensorInfo::is_compatible(const eval::ValueType &type) const +Onnx::WirePlanner::bind_input_type(const eval::ValueType &vespa_in, const TensorInfo &onnx_in) { - if ((elements == ElementType::UNKNOWN) || dimensions.empty()) { + const auto &type = vespa_in; + const auto &name = onnx_in.name; + const auto &dimensions = onnx_in.dimensions; + const auto &elements = onnx_in.elements; + if ((elements == TensorInfo::ElementType::UNKNOWN) || dimensions.empty()) { return false; } if (type.cell_type() != as_cell_type(elements)) { @@ -121,21 +161,41 @@ OnnxWrapper::TensorInfo::is_compatible(const eval::ValueType &type) const return false; } for (size_t i = 0; i < dimensions.size(); ++i) { - if 
(type.dimensions()[i].size != dimensions[i]) { - return false; + if (dimensions[i].is_known()) { + if (dimensions[i].value != type.dimensions()[i].size) { + return false; + } + } else if (dimensions[i].is_symbolic()) { + auto &bound_size = _symbolic_sizes[dimensions[i].name]; + if (bound_size == 0) { + bound_size = type.dimensions()[i].size; + } else if (bound_size != type.dimensions()[i].size) { + return false; + } + } else { + _unknown_sizes[std::make_pair(name,i)] = type.dimensions()[i].size; } } return true; } eval::ValueType -OnnxWrapper::TensorInfo::make_compatible_type() const +Onnx::WirePlanner::make_output_type(const TensorInfo &onnx_out) const { - if ((elements == ElementType::UNKNOWN) || dimensions.empty()) { + const auto &dimensions = onnx_out.dimensions; + const auto &elements = onnx_out.elements; + if ((elements == TensorInfo::ElementType::UNKNOWN) || dimensions.empty()) { return ValueType::error_type(); } std::vector<ValueType::Dimension> dim_list; - for (size_t dim_size: dimensions) { + for (const auto &dim: dimensions) { + size_t dim_size = dim.value; + if (dim.is_symbolic()) { + auto pos = _symbolic_sizes.find(dim.name); + if (pos != _symbolic_sizes.end()) { + dim_size = pos->second; + } + } if ((dim_size == 0) || (dim_list.size() > 9)) { return ValueType::error_type(); } @@ -144,71 +204,131 @@ OnnxWrapper::TensorInfo::make_compatible_type() const return ValueType::tensor_type(std::move(dim_list), as_cell_type(elements)); } -vespalib::string -OnnxWrapper::TensorInfo::type_as_string() const +Onnx::WireInfo +Onnx::WirePlanner::get_wire_info(const Onnx &model) const { - vespalib::string res = to_str(elements); - for (size_t dim_size: dimensions) { - if (dim_size == 0) { - res += "[]"; - } else { - res += fmt("[%zu]", dim_size); + WireInfo info; + for (const auto &input: model.inputs()) { + size_t input_idx = 0; + std::vector<int64_t> sizes; + for (const auto &dim: input.dimensions) { + if (dim.is_known()) { + sizes.push_back(dim.value); + } else if 
(dim.is_symbolic()) { + const auto &pos = _symbolic_sizes.find(dim.name); + assert(pos != _symbolic_sizes.end()); + sizes.push_back(pos->second); + } else { + const auto &pos = _unknown_sizes.find(std::make_pair(input.name, input_idx)); + assert(pos != _unknown_sizes.end()); + sizes.push_back(pos->second); + } + ++input_idx; } + info.input_sizes.push_back(sizes); } - return res; + for (const auto &output: model.outputs()) { + info.output_types.push_back(make_output_type(output)); + } + return info; } -OnnxWrapper::TensorInfo::~TensorInfo() = default; +//----------------------------------------------------------------------------- -OnnxWrapper::Shared::Shared() - : _env(ORT_LOGGING_LEVEL_WARNING, "vespa-onnx-wrapper") +Ort::AllocatorWithDefaultOptions Onnx::EvalContext::_alloc; + +Onnx::EvalContext::EvalContext(const Onnx &model, const WireInfo &wire_info) + : _model(model), + _wire_info(wire_info), + _cpu_memory(Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault)), + _param_values(), + _result_values(), + _result_views() { + assert(_wire_info.input_sizes.size() == _model.inputs().size()); + assert(_wire_info.output_types.size() == _model.outputs().size()); + for (const auto &input: _wire_info.input_sizes) { + (void) input; + _param_values.push_back(Ort::Value(nullptr)); + } + std::vector<int64_t> dim_sizes; + size_t num_cells; + dim_sizes.reserve(16); + // NB: output type must be reference inside vector since the view does not copy it + for (const auto &output: _wire_info.output_types) { + num_cells = 1; + dim_sizes.clear(); + for (const auto &dim: output.dimensions()) { + dim_sizes.push_back(dim.size); + num_cells *= dim.size; + } + if (output.cell_type() == ValueType::CellType::FLOAT) { + _result_values.push_back(Ort::Value::CreateTensor<float>(_alloc, dim_sizes.data(), dim_sizes.size())); + ConstArrayRef<float> cells(_result_values.back().GetTensorMutableData<float>(), num_cells); + _result_views.emplace_back(output, TypedCells(cells)); + } else { + 
assert(output.cell_type() == ValueType::CellType::DOUBLE); + _result_values.push_back(Ort::Value::CreateTensor<double>(_alloc, dim_sizes.data(), dim_sizes.size())); + ConstArrayRef<double> cells(_result_values.back().GetTensorMutableData<double>(), num_cells); + _result_views.emplace_back(output, TypedCells(cells)); + } + } } +Onnx::EvalContext::~EvalContext() = default; + void -OnnxWrapper::Params::bind(size_t idx, const DenseTensorView &src) +Onnx::EvalContext::bind_param(size_t i, const eval::Value ¶m) { - assert(idx == values.size()); - std::vector<int64_t> dim_sizes; - for (const auto &dim: src.fast_type().dimensions()) { - dim_sizes.push_back(dim.size); - } - auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); - if (src.fast_type().cell_type() == ValueType::CellType::FLOAT) { + // NB: dense tensors are always (sub)classes of DenseTensorView + const auto &cells_ref = static_cast<const DenseTensorView &>(param).cellsRef(); + const auto &input_sizes = _wire_info.input_sizes; + if (cells_ref.type == ValueType::CellType::FLOAT) { // NB: create requires non-const input - auto cells = unconstify(src.cellsRef().typify<float>()); - values.push_back(Ort::Value::CreateTensor<float>(memory_info, cells.begin(), cells.size(), dim_sizes.data(), dim_sizes.size())); - } else if (src.fast_type().cell_type() == ValueType::CellType::DOUBLE) { + auto cells = unconstify(cells_ref.typify<float>()); + _param_values[i] = Ort::Value::CreateTensor<float>(_cpu_memory, cells.begin(), cells.size(), input_sizes[i].data(), input_sizes[i].size()); + } else { + assert(cells_ref.type == ValueType::CellType::DOUBLE); // NB: create requires non-const input - auto cells = unconstify(src.cellsRef().typify<double>()); - values.push_back(Ort::Value::CreateTensor<double>(memory_info, cells.begin(), cells.size(), dim_sizes.data(), dim_sizes.size())); + auto cells = unconstify(cells_ref.typify<double>()); + _param_values[i] = 
Ort::Value::CreateTensor<double>(_cpu_memory, cells.begin(), cells.size(), input_sizes[i].data(), input_sizes[i].size()); } } void -OnnxWrapper::Result::get(size_t idx, MutableDenseTensorView &dst) +Onnx::EvalContext::eval() { - assert(values[idx].IsTensor()); - auto meta = values[idx].GetTensorTypeAndShapeInfo(); - if (dst.fast_type().cell_type() == ValueType::CellType::FLOAT) { - assert(meta.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); - ConstArrayRef<float> cells(values[idx].GetTensorMutableData<float>(), meta.GetElementCount()); - dst.setCells(TypedCells(cells)); - } else if (dst.fast_type().cell_type() == ValueType::CellType::DOUBLE) { - assert(meta.GetElementType() == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE); - ConstArrayRef<double> cells(values[idx].GetTensorMutableData<double>(), meta.GetElementCount()); - dst.setCells(TypedCells(cells)); - } + // NB: Run requires non-const session + Ort::Session &session = const_cast<Ort::Session&>(_model._session); + Ort::RunOptions run_opts(nullptr); + session.Run(run_opts, + _model._input_name_refs.data(), _param_values.data(), _param_values.size(), + _model._output_name_refs.data(), _result_values.data(), _result_values.size()); } -OnnxWrapper::Shared & -OnnxWrapper::Shared::get() { +const eval::Value & +Onnx::EvalContext::get_result(size_t i) const +{ + return _result_views[i]; +} + +//----------------------------------------------------------------------------- + +Onnx::Shared::Shared() + : _env(ORT_LOGGING_LEVEL_WARNING, "vespa-onnx-wrapper") +{ +} + +Onnx::Shared & +Onnx::Shared::get() { static Shared shared; return shared; } +//----------------------------------------------------------------------------- + void -OnnxWrapper::extract_meta_data() +Onnx::extract_meta_data() { Ort::AllocatorWithDefaultOptions allocator; size_t num_inputs = _session.GetInputCount(); @@ -227,7 +347,7 @@ OnnxWrapper::extract_meta_data() } } -OnnxWrapper::OnnxWrapper(const vespalib::string &model_file, Optimize optimize) 
+Onnx::Onnx(const vespalib::string &model_file, Optimize optimize) : _shared(Shared::get()), _options(), _session(nullptr), @@ -243,17 +363,6 @@ OnnxWrapper::OnnxWrapper(const vespalib::string &model_file, Optimize optimize) extract_meta_data(); } -OnnxWrapper::~OnnxWrapper() = default; - -OnnxWrapper::Result -OnnxWrapper::eval(const Params ¶ms) const -{ - assert(params.values.size() == _inputs.size()); - Ort::RunOptions run_opts(nullptr); - // NB: Run requires non-const session - Ort::Session &session = const_cast<Ort::Session&>(_session); - return Result(session.Run(run_opts, _input_name_refs.data(), params.values.data(), _inputs.size(), - _output_name_refs.data(), _outputs.size())); -} +Onnx::~Onnx() = default; } diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h index abe1da252c7..23ddbcb8885 100644 --- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h +++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h @@ -2,56 +2,101 @@ #pragma once +#include "dense_tensor_view.h" #include <onnxruntime/onnxruntime_cxx_api.h> #include <vespa/vespalib/stllike/string.h> #include <vespa/eval/eval/value_type.h> #include <vector> +#include <map> -namespace vespalib::tensor { +namespace vespalib::eval { struct Value; } -class DenseTensorView; -class MutableDenseTensorView; +namespace vespalib::tensor { /** * Wrapper around an ONNX model handeled by onnxruntime. + * + * Create an Onnx object that will load your model and extract + * information about inputs and outputs. Use an Onnx::WirePlanner to + * bind vespa value types to each of the onnx model inputs. Ask the + * wire planner about the vespa value types corresponding to each of + * the model outputs for external wiring. Use the wire planner to make + * a WireInfo object which is a simple struct indicating the concrete + * onnx and vespa types to be used when converting inputs and + * outputs. Create an Onnx::EvalContex based on the model and the wire + * plan. 
Bind actual vespa values to the model inputs, invoke eval and + * inspect the results. See the unit test (tests/tensor/onnx_wrapper) + * for some examples. **/ -class OnnxWrapper { +class Onnx { public: // model optimization enum class Optimize { ENABLE, DISABLE }; + // the size of a dimension + struct DimSize { + size_t value; + vespalib::string name; + DimSize() : value(0), name() {} + DimSize(size_t size) : value(size), name() {} + DimSize(const vespalib::string &symbol) : value(0), name(symbol) {} + bool is_known() const { return (value > 0); } + bool is_symbolic() const { return !name.empty(); } + vespalib::string as_string() const; + }; + // information about a single input or output tensor struct TensorInfo { enum class ElementType { FLOAT, DOUBLE, UNKNOWN }; vespalib::string name; - std::vector<size_t> dimensions; + std::vector<DimSize> dimensions; ElementType elements; - bool is_compatible(const eval::ValueType &type) const; - eval::ValueType make_compatible_type() const; vespalib::string type_as_string() const; ~TensorInfo(); }; - // used to build model parameters - class Params { - friend class OnnxWrapper; + // how the model should be wired with inputs/outputs + struct WireInfo { + std::vector<std::vector<int64_t>> input_sizes; + std::vector<eval::ValueType> output_types; + WireInfo() : input_sizes(), output_types() {} + }; + + // planning how we should wire the model based on input types + class WirePlanner { private: - std::vector<Ort::Value> values; + std::map<vespalib::string,size_t> _symbolic_sizes; + std::map<std::pair<vespalib::string,size_t>,size_t> _unknown_sizes; public: - Params() : values() {} - void bind(size_t idx, const DenseTensorView &src); + WirePlanner() : _symbolic_sizes(), _unknown_sizes() {} + ~WirePlanner(); + bool bind_input_type(const eval::ValueType &vespa_in, const TensorInfo &onnx_in); + eval::ValueType make_output_type(const TensorInfo &onnx_out) const; + WireInfo get_wire_info(const Onnx &model) const; }; - // used to 
inspect model results - class Result { - friend class OnnxWrapper; + // evaluation context; use one per thread and keep model/wire_info alive + // all parameter values are expected to be bound per evaluation + // output values are pre-allocated and will not change + class EvalContext { private: - std::vector<Ort::Value> values; - Result(std::vector<Ort::Value> values_in) : values(std::move(values_in)) {} + static Ort::AllocatorWithDefaultOptions _alloc; + + const Onnx &_model; + const WireInfo &_wire_info; + Ort::MemoryInfo _cpu_memory; + std::vector<Ort::Value> _param_values; + std::vector<Ort::Value> _result_values; + std::vector<DenseTensorView> _result_views; + public: - static Result make_empty() { return Result({}); } - size_t num_values() const { return values.size(); } - void get(size_t idx, MutableDenseTensorView &dst); + EvalContext(const Onnx &model, const WireInfo &wire_info); + ~EvalContext(); + size_t num_params() const { return _param_values.size(); } + size_t num_results() const { return _result_values.size(); } + void bind_param(size_t i, const eval::Value ¶m); + void eval(); + const eval::Value &get_result(size_t i) const; }; private: @@ -76,11 +121,10 @@ private: void extract_meta_data(); public: - OnnxWrapper(const vespalib::string &model_file, Optimize optimize); - ~OnnxWrapper(); + Onnx(const vespalib::string &model_file, Optimize optimize); + ~Onnx(); const std::vector<TensorInfo> &inputs() const { return _inputs; } const std::vector<TensorInfo> &outputs() const { return _outputs; } - Result eval(const Params ¶ms) const; }; } diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp index d183c33f5cd..db35de6786d 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp @@ -243,6 +243,16 @@ SparseTensor::remove(const CellValues &cellAddresses) const return remover.build(); } +size_t +SparseTensor::count_memory_used() 
const +{ + size_t result = sizeof(SparseTensor) + _cells.getMemoryConsumption(); + for (const auto &cell : _cells) { + result += cell.first.size(); + } + return result; +} + } VESPALIB_HASH_MAP_INSTANTIATE(vespalib::tensor::SparseTensorAddressRef, double); diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h index e5ea639b460..6bd181e1895 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h @@ -53,6 +53,7 @@ public: Tensor::UP clone() const override; eval::TensorSpec toSpec() const override; void accept(TensorVisitor &visitor) const override; + size_t count_memory_used() const override; }; } diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h index d822c99a6d8..bef7309c609 100644 --- a/eval/src/vespa/eval/tensor/tensor.h +++ b/eval/src/vespa/eval/tensor/tensor.h @@ -59,6 +59,7 @@ public: virtual Tensor::UP clone() const = 0; // want to remove, but needed by document virtual eval::TensorSpec toSpec() const = 0; virtual void accept(TensorVisitor &visitor) const = 0; + virtual size_t count_memory_used() const = 0; using TypeList = std::initializer_list<std::reference_wrapper<const eval::ValueType>>; static bool supported(TypeList types); diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp index 7c09bc4e4ab..fe73bf92063 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp @@ -54,6 +54,16 @@ WrappedSimpleTensor::accept(TensorVisitor &visitor) const } } +size_t +WrappedSimpleTensor::count_memory_used() const +{ + size_t result = sizeof(WrappedSimpleTensor); + if (_space) { + result += _space->count_memory_used(); + } + return result; +} + Tensor::UP WrappedSimpleTensor::clone() const { diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h 
b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h index 12ee1237d67..6b549718a29 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h @@ -33,6 +33,7 @@ public: eval::TensorSpec toSpec() const override; double as_double() const override; void accept(TensorVisitor &visitor) const override; + size_t count_memory_used() const override; Tensor::UP clone() const override; // functions below should not be used for this implementation Tensor::UP apply(const CellFunction &) const override; diff --git a/metrics/src/tests/metricmanagertest.cpp b/metrics/src/tests/metricmanagertest.cpp index 6407bb73ecb..47515d1bc4c 100644 --- a/metrics/src/tests/metricmanagertest.cpp +++ b/metrics/src/tests/metricmanagertest.cpp @@ -11,6 +11,7 @@ #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/xmlstream.h> #include <vespa/vespalib/util/time.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <thread> #include <vespa/log/log.h> diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java index cbc5a44ae94..ad85235fc69 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java @@ -96,6 +96,11 @@ public class NodeList extends AbstractFilteringList<Node, NodeList> { .orElse(Version.emptyVersion))); } + /** Returns the subset of nodes that are currently on a lower version than the given version */ + public NodeList osVersionIsBefore(Version version) { + return matching(node -> node.status().osVersion().isBefore(version)); + } + /** Returns the subset of nodes that are currently on the given OS version */ public NodeList onOsVersion(Version version) { return matching(node -> node.status().osVersion().matches(version)); diff --git 
a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/OsUpgradeActivator.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/OsUpgradeActivator.java index be1190ccff4..0e3b6715ff1 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/OsUpgradeActivator.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/OsUpgradeActivator.java @@ -26,8 +26,8 @@ public class OsUpgradeActivator extends NodeRepositoryMaintainer { protected boolean maintain() { for (var nodeType : NodeType.values()) { if (!nodeType.isHost()) continue; - var active = canUpgradeOsOf(nodeType); - nodeRepository().osVersions().resumeUpgradeOf(nodeType, active); + boolean resume = canUpgradeOsOf(nodeType); + nodeRepository().osVersions().resumeUpgradeOf(nodeType, resume); } return true; } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/OsVersion.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/OsVersion.java index 1216c060181..0385e2e3df6 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/OsVersion.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/OsVersion.java @@ -43,6 +43,11 @@ public class OsVersion { return wanted.isPresent() && !current.equals(wanted); } + /** Returns whether this is before the given version */ + public boolean isBefore(Version version) { + return current.isEmpty() || current.get().isBefore(version); + } + /** Returns whether current version matches given version */ public boolean matches(Version version) { return current.isPresent() && current.get().equals(version); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingUpgrader.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingUpgrader.java index 03d04a5f6cf..74b288d77c5 100644 --- 
a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingUpgrader.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/DelegatingUpgrader.java @@ -38,7 +38,7 @@ public class DelegatingUpgrader implements Upgrader { NodeList activeNodes = nodeRepository.list().nodeType(target.nodeType()).state(Node.State.active); int numberToUpgrade = Math.max(0, maxActiveUpgrades - activeNodes.changingOsVersionTo(target.version()).size()); NodeList nodesToUpgrade = activeNodes.not().changingOsVersionTo(target.version()) - .not().onOsVersion(target.version()) + .osVersionIsBefore(target.version()) .byIncreasingOsVersion() .first(numberToUpgrade); if (nodesToUpgrade.size() == 0) return; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringUpgrader.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringUpgrader.java index aebf14ab13f..b4e21b22cd2 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringUpgrader.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/os/RetiringUpgrader.java @@ -46,7 +46,7 @@ public class RetiringUpgrader implements Upgrader { Instant retiredAt = target.lastRetiredAt().orElse(Instant.EPOCH); if (now.isBefore(retiredAt.plus(nodeBudget))) return; // Budget has not been spent yet - activeNodes.not().onOsVersion(target.version()) + activeNodes.osVersionIsBefore(target.version()) .not().deprovisioning() .byIncreasingOsVersion() .first(1) diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java index 914008af227..6a41e766ace 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/os/OsVersionsTest.java @@ -38,7 +38,7 @@ public class OsVersionsTest { private final 
ApplicationId infraApplication = ApplicationId.from("hosted-vespa", "infra", "default"); @Test - public void versions() { + public void upgrade() { var versions = new OsVersions(tester.nodeRepository(), new DelegatingUpgrader(tester.nodeRepository(), Integer.MAX_VALUE)); provisionInfraApplication(10); Supplier<List<Node>> hostNodes = () -> tester.nodeRepository().getNodes(NodeType.host); @@ -50,18 +50,28 @@ public class OsVersionsTest { assertEquals(version1, versions.targetFor(NodeType.host).get()); assertTrue("Per-node wanted OS version remains unset", hostNodes.get().stream().allMatch(node -> node.status().osVersion().wanted().isEmpty())); + // One host upgrades to a later version outside the control of orchestration + Node hostOnLaterVersion = hostNodes.get().get(0); + setCurrentVersion(List.of(hostOnLaterVersion), Version.fromString("8.1")); + // Upgrade OS again var version2 = Version.fromString("7.2"); versions.setTarget(NodeType.host, version2, Optional.empty(), false); assertEquals(version2, versions.targetFor(NodeType.host).get()); - // Target can be (de)activated + // Resume upgrade versions.resumeUpgradeOf(NodeType.host, true); - assertTrue("Target version activated", hostNodes.get().stream() - .allMatch(node -> node.status().osVersion().wanted().isPresent())); + List<Node> allHosts = hostNodes.get(); + assertTrue("Wanted version is set", allHosts.stream() + .filter(node -> !node.equals(hostOnLaterVersion)) + .allMatch(node -> node.status().osVersion().wanted().isPresent())); + assertTrue("Wanted version is not set for host on later version", + allHosts.get(0).status().osVersion().wanted().isEmpty()); + + // Halt upgrade versions.resumeUpgradeOf(NodeType.host, false); - assertTrue("Target version deactivated", hostNodes.get().stream() - .allMatch(node -> node.status().osVersion().wanted().isEmpty())); + assertTrue("Wanted version is unset", hostNodes.get().stream() + .allMatch(node -> node.status().osVersion().wanted().isEmpty())); // Downgrading fails 
try { diff --git a/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp b/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp index 575033ad19a..3b42a399888 100644 --- a/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp +++ b/searchcore/src/tests/proton/common/pendinglidtracker_test.cpp @@ -12,6 +12,8 @@ constexpr uint32_t LID_1 = 1u; const std::vector<uint32_t> LIDV_2_1_3({2u, LID_1, 3u}); const std::vector<uint32_t> LIDV_2_3({2u, 3u}); +namespace proton { + std::ostream & operator << (std::ostream & os, ILidCommitState::State state) { switch (state) { @@ -28,6 +30,8 @@ operator << (std::ostream & os, ILidCommitState::State state) { return os; } +} + void verifyPhase1ProduceAndNeedCommit(PendingLidTrackerBase & tracker, ILidCommitState::State expected) { EXPECT_EQUAL(ILidCommitState::State::COMPLETED, tracker.getState()); diff --git a/searchcore/src/tests/proton/docsummary/docsummary.cpp b/searchcore/src/tests/proton/docsummary/docsummary.cpp index 7d27c3b21f4..92117e174e9 100644 --- a/searchcore/src/tests/proton/docsummary/docsummary.cpp +++ b/searchcore/src/tests/proton/docsummary/docsummary.cpp @@ -31,6 +31,7 @@ #include <vespa/searchlib/transactionlog/nosyncproxy.h> #include <vespa/searchlib/transactionlog/translogserver.h> #include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <vespa/vespalib/encoding/base64.h> #include <vespa/config-bucketspaces.h> #include <vespa/vespalib/testkit/testapp.h> diff --git a/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp b/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp index 4a01d7ae3e8..18b3a5c5d8e 100644 --- a/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp +++ b/searchcore/src/tests/proton/documentdb/feedhandler/feedhandler_test.cpp @@ -23,7 +23,6 @@ #include <vespa/searchcore/proton/server/i_feed_handler_owner.h> #include <vespa/searchcore/proton/server/ireplayconfig.h> 
#include <vespa/searchcore/proton/test/dummy_feed_view.h> -#include <vespa/searchlib/common/idestructorcallback.h> #include <vespa/searchlib/index/docbuilder.h> #include <vespa/searchlib/index/dummyfileheadercontext.h> #include <vespa/searchlib/transactionlog/translogserver.h> diff --git a/searchcore/src/tests/proton/server/feedstates_test.cpp b/searchcore/src/tests/proton/server/feedstates_test.cpp index fd1e24c1f17..15083975824 100644 --- a/searchcore/src/tests/proton/server/feedstates_test.cpp +++ b/searchcore/src/tests/proton/server/feedstates_test.cpp @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. // Unit tests for feedstates. - #include <vespa/document/base/documentid.h> #include <vespa/document/base/testdocrepo.h> #include <vespa/document/bucket/bucketid.h> @@ -102,11 +101,10 @@ struct RemoveOperationContext RemoveOperationContext::RemoveOperationContext(search::SerialNum serial) : doc_id("id:ns:doctypename::bar"), op(BucketFactory::getBucketId(doc_id), Timestamp(10), doc_id), - str(), packet() + str(), packet(std::make_unique<Packet>(0xf000)) { op.serialize(str); ConstBufferRef buf(str.data(), str.wp()); - packet = std::make_unique<Packet>(); packet->add(Packet::Entry(serial, FeedOperation::REMOVE, buf)); } RemoveOperationContext::~RemoveOperationContext() = default; diff --git a/searchcore/src/tests/proton/summaryengine/summaryengine.cpp b/searchcore/src/tests/proton/summaryengine/summaryengine.cpp index 7cdd8d767c6..7e5e3527b1d 100644 --- a/searchcore/src/tests/proton/summaryengine/summaryengine.cpp +++ b/searchcore/src/tests/proton/summaryengine/summaryengine.cpp @@ -6,6 +6,7 @@ #include <vespa/searchlib/util/rawbuf.h> #include <vespa/searchlib/util/slime_output_raw_buf_adapter.h> #include <vespa/vespalib/data/databuffer.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <vespa/vespalib/util/compressor.h> #include 
<vespa/searchsummary/docsummary/docsumwriter.h> #include <vespa/metrics/metricset.h> diff --git a/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp b/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp index 96a93f0ac16..028b5d38ae9 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentdb.cpp @@ -564,6 +564,9 @@ DocumentDB::close() // Abort any ongoing maintenance stopMaintenance(); + _visibility.commit(); + _writeService.sync(); + // The attributes in the ready sub db is also the total set of attributes. DocumentDBTaggedMetrics &metrics = getMetrics(); _metricsWireService.cleanAttributes(metrics.ready.attributes); @@ -905,6 +908,11 @@ DocumentDB::syncFeedView() return; IFeedView::SP oldFeedView(_feedView.get()); IFeedView::SP newFeedView(_subDBs.getFeedView()); + + _writeService.sync(); + _visibility.commit(); + _writeService.sync(); + _feedView.set(newFeedView); _feedHandler.setActiveFeedView(newFeedView.get()); _subDBs.createRetrievers(); @@ -980,6 +988,7 @@ void DocumentDB::stopMaintenance() { _maintenanceController.stop(); + _writeService.sync(); } void diff --git a/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp b/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp index c14ed3bb1d9..d01c25d9c1e 100644 --- a/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/feedstates.cpp @@ -82,6 +82,8 @@ public: _commitTimeTracker(100ms) { } + ~TransactionLogReplayPacketHandler() override = default; + void replay(const PutOperation &op) override { _feed_view_ptr->handlePut(FeedToken(), op); } @@ -153,6 +155,8 @@ ReplayTransactionLogState::ReplayTransactionLogState( _packet_handler(std::make_unique<TransactionLogReplayPacketHandler>(feed_view_ptr, bucketDBHandler, replay_config, config_store)) { } +ReplayTransactionLogState::~ReplayTransactionLogState() = default; + void 
ReplayTransactionLogState::receive(const PacketWrapper::SP &wrap, Executor &executor) { EntryHandler closure = makeClosure(&startDispatch, _packet_handler.get()); diff --git a/searchcore/src/vespa/searchcore/proton/server/feedstates.h b/searchcore/src/vespa/searchcore/proton/server/feedstates.h index bf376bb8065..2cf0ee1a4dd 100644 --- a/searchcore/src/vespa/searchcore/proton/server/feedstates.h +++ b/searchcore/src/vespa/searchcore/proton/server/feedstates.h @@ -55,6 +55,7 @@ public: IReplayConfig &replay_config, FeedConfigStore &config_store); + ~ReplayTransactionLogState() override; void handleOperation(FeedToken, FeedOperationUP op) override { throwExceptionInHandleOperation(_doc_type_name, *op); } diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp index f29e54ba725..e822b1de33e 100644 --- a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.cpp @@ -6,6 +6,7 @@ #include "i_blockable_maintenance_job.h" #include <vespa/searchcorespi/index/i_thread_service.h> #include <vespa/vespalib/util/closuretask.h> +#include <vespa/vespalib/util/lambdatask.h> #include <vespa/vespalib/util/scheduledexecutor.h> #include <vespa/log/log.h> @@ -15,6 +16,7 @@ using document::BucketId; using vespalib::Executor; using vespalib::makeClosure; using vespalib::makeTask; +using vespalib::makeLambdaTask; namespace proton { @@ -84,8 +86,8 @@ MaintenanceController::registerJob(Executor & executor, IMaintenanceJob::UP job) void MaintenanceController::killJobs() { - // Called by master write thread during start/reconfig - // Called by other thread during stop + // Called by master write thread + assert(_masterThread.isCurrentThread()); LOG(debug, "killJobs(): threadId=%zu", (size_t)FastOS_Thread::GetCurrentThreadId()); _periodicTimer.reset(); // No need to take _jobsLock as modification of 
_jobs also happens in master write thread. @@ -94,24 +96,13 @@ MaintenanceController::killJobs() } _defaultExecutor.sync(); _defaultExecutor.sync(); - if (_masterThread.isCurrentThread()) { - JobList tmpJobs = _jobs; - { - Guard guard(_jobsLock); - _jobs.clear(); - } - // Hold jobs until existing tasks have been drained - _masterThread.execute(makeTask(makeClosure(this, &MaintenanceController::performHoldJobs, tmpJobs))); - } else { - // Wait for all tasks to be finished. - // NOTE: We must sync 2 times as a task currently being executed can add a new - // task to the executor as it might not see the new value of the stopped flag. - _masterThread.sync(); - _masterThread.sync(); - // Clear jobs in master write thread, to avoid races - _masterThread.execute(makeTask(makeClosure(this, &MaintenanceController::performClearJobs))); - _masterThread.sync(); + JobList tmpJobs = _jobs; + { + Guard guard(_jobsLock); + _jobs.clear(); } + // Hold jobs until existing tasks have been drained + _masterThread.execute(makeTask(makeClosure(this, &MaintenanceController::performHoldJobs, tmpJobs))); } void @@ -123,21 +114,12 @@ MaintenanceController::performHoldJobs(JobList jobs) } void -MaintenanceController::performClearJobs() -{ - // Called by master write thread - LOG(debug, "performClearJobs(): threadId=%zu", (size_t)FastOS_Thread::GetCurrentThreadId()); - Guard guard(_jobsLock); - _jobs.clear(); -} - - -void MaintenanceController::stop() { assert(!_masterThread.isCurrentThread()); - _stopping = true; - killJobs(); + _masterThread.execute(makeLambdaTask([this]() { _stopping = true; killJobs(); })); + _masterThread.sync(); // Wait for killJobs() + _masterThread.sync(); // Wait for already scheduled maintenance jobs and performHoldJobs } void diff --git a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h index 3cfdeba4d34..ece92adebd0 100644 --- 
a/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h +++ b/searchcore/src/vespa/searchcore/proton/server/maintenancecontroller.h @@ -90,7 +90,6 @@ private: void addJobsToPeriodicTimer(); void restart(); void notifyThawedBucket(const document::BucketId &bucket) override; - void performClearJobs(); void performHoldJobs(JobList jobs); void registerJob(vespalib::Executor & executor, IMaintenanceJob::UP job); }; diff --git a/searchcore/src/vespa/searchcore/proton/server/tlcproxy.cpp b/searchcore/src/vespa/searchcore/proton/server/tlcproxy.cpp index bbd02d7efce..baba74c482c 100644 --- a/searchcore/src/vespa/searchcore/proton/server/tlcproxy.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/tlcproxy.cpp @@ -15,9 +15,8 @@ void TlcProxy::commit(search::SerialNum serialNum, search::transactionlog::Type const vespalib::nbostream &buf, DoneCallback onDone) { Packet::Entry entry(serialNum, type, vespalib::ConstBufferRef(buf.data(), buf.size())); - Packet packet; + Packet packet(entry.serializedSize()); packet.add(entry); - packet.close(); _tlsDirectWriter.commit(_domain, packet, std::move(onDone)); } diff --git a/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp b/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp index a3524ae79f3..3a44af517ee 100644 --- a/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/visibilityhandler.cpp @@ -2,10 +2,9 @@ #include "visibilityhandler.h" #include <vespa/vespalib/util/isequencedtaskexecutor.h> -#include <vespa/vespalib/util/closuretask.h> +#include <vespa/vespalib/util/lambdatask.h> -using vespalib::makeTask; -using vespalib::makeClosure; +using vespalib::makeLambdaTask; namespace proton { @@ -81,8 +80,7 @@ VisibilityHandler::startCommit(const std::lock_guard<std::mutex> &unused, bool f (void) unused; SerialNum current = _serial.getSerialNum(); if ((current > _lastCommitSerialNum) || force) { - 
_writeService.master().execute(makeTask(makeClosure(this, - &VisibilityHandler::performCommit, force))); + _writeService.master().execute(makeLambdaTask([this, force]() { performCommit(force);})); return true; } return false; @@ -95,8 +93,10 @@ VisibilityHandler::performCommit(bool force) SerialNum current = _serial.getSerialNum(); if ((current > _lastCommitSerialNum) || force) { IFeedView::SP feedView(_feedView.get()); - feedView->forceCommit(current); - _lastCommitSerialNum = current; + if (feedView) { + feedView->forceCommit(current); + _lastCommitSerialNum = current; + } } } diff --git a/searchcore/src/vespa/searchcore/proton/summaryengine/docsum_by_slime.cpp b/searchcore/src/vespa/searchcore/proton/summaryengine/docsum_by_slime.cpp index 5bb36f9c828..edc3b86d9d3 100644 --- a/searchcore/src/vespa/searchcore/proton/summaryengine/docsum_by_slime.cpp +++ b/searchcore/src/vespa/searchcore/proton/summaryengine/docsum_by_slime.cpp @@ -20,7 +20,6 @@ using vespalib::Memory; using vespalib::slime::Symbol; using vespalib::slime::BinaryFormat; using vespalib::slime::ArrayTraverser; -using vespalib::SimpleBuffer; using vespalib::DataBuffer; using vespalib::ConstBufferRef; using vespalib::compression::CompressionConfig; diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt index c8bfc0f1926..4d1a8a82211 100644 --- a/searchlib/CMakeLists.txt +++ b/searchlib/CMakeLists.txt @@ -217,6 +217,7 @@ vespa_define_module( src/tests/sortspec src/tests/stringenum src/tests/tensor/dense_tensor_store + src/tests/tensor/direct_tensor_store src/tests/tensor/distance_functions src/tests/tensor/hnsw_index src/tests/tensor/hnsw_saver diff --git a/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp b/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp index 24919fb2341..04d2dfe4d52 100644 --- a/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp +++ b/searchlib/src/tests/attribute/bitvector/bitvector_test.cpp @@ -491,7 +491,7 @@ BitVectorTest::test(BasicType bt, 
v->asDocumentWeightAttribute(); if (dwa != NULL) { search::IDocumentWeightAttribute::LookupResult lres = - dwa->lookup(getSearchStr<VectorType>()); + dwa->lookup(getSearchStr<VectorType>(), dwa->get_dictionary_snapshot()); typedef search::queryeval::DocumentWeightSearchIterator DWSI; typedef search::queryeval::SearchIterator SI; TermFieldMatchData md; diff --git a/searchlib/src/tests/attribute/document_weight_iterator/document_weight_iterator_test.cpp b/searchlib/src/tests/attribute/document_weight_iterator/document_weight_iterator_test.cpp index cf1506a9118..d8a1d03f1a8 100644 --- a/searchlib/src/tests/attribute/document_weight_iterator/document_weight_iterator_test.cpp +++ b/searchlib/src/tests/attribute/document_weight_iterator/document_weight_iterator_test.cpp @@ -3,6 +3,7 @@ #include <vespa/searchlib/attribute/attribute.h> #include <vespa/searchlib/attribute/attributefactory.h> #include <vespa/searchlib/attribute/attributeguard.h> +#include <vespa/searchlib/attribute/attribute_read_guard.h> #include <vespa/searchlib/attribute/attributememorysavetarget.h> #include <vespa/searchlib/attribute/attributevector.h> #include <vespa/searchlib/attribute/attributevector.hpp> @@ -22,6 +23,7 @@ #include <vespa/searchlib/test/searchiteratorverifier.h> #include <vespa/searchlib/util/randomgenerator.h> #include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/test/insertion_operators.h> #include <vespa/log/log.h> LOG_SETUP("document_weight_iterator_test"); @@ -124,17 +126,17 @@ void verify_invalid_lookup(IDocumentWeightAttribute::LookupResult result) { } TEST_F("require that integer lookup works correctly", LongFixture) { - verify_valid_lookup(f1.api->lookup("111")); - verify_invalid_lookup(f1.api->lookup("222")); + verify_valid_lookup(f1.api->lookup("111", f1.api->get_dictionary_snapshot())); + verify_invalid_lookup(f1.api->lookup("222", f1.api->get_dictionary_snapshot())); } TEST_F("require string lookup works correctly", StringFixture) { - 
verify_valid_lookup(f1.api->lookup("foo")); - verify_invalid_lookup(f1.api->lookup("bar")); + verify_valid_lookup(f1.api->lookup("foo", f1.api->get_dictionary_snapshot())); + verify_invalid_lookup(f1.api->lookup("bar", f1.api->get_dictionary_snapshot())); } void verify_posting(const IDocumentWeightAttribute &api, const char *term) { - auto result = api.lookup(term); + auto result = api.lookup(term, api.get_dictionary_snapshot()); ASSERT_TRUE(result.posting_idx.valid()); std::vector<DocumentWeightIterator> itr_store; api.create(result.posting_idx, itr_store); @@ -168,6 +170,53 @@ TEST_F("require that string iterators are created correctly", StringFixture) { verify_posting(*f1.api, "foo"); } +TEST_F("require that dictionary snapshot works", LongFixture) +{ + auto read_guard = f1.attr->makeReadGuard(false); + auto dictionary_snapshot = f1.api->get_dictionary_snapshot(); + auto lookup1 = f1.api->lookup("111", dictionary_snapshot); + EXPECT_TRUE(lookup1.enum_idx.valid()); + f1.attr->clearDoc(1); + f1.attr->clearDoc(5); + f1.attr->clearDoc(7); + f1.attr->commit(); + auto lookup2 = f1.api->lookup("111", f1.api->get_dictionary_snapshot()); + EXPECT_FALSE(lookup2.enum_idx.valid()); + auto lookup3 = f1.api->lookup("111", dictionary_snapshot); + EXPECT_TRUE(lookup3.enum_idx.valid()); + EXPECT_EQUAL(lookup1.enum_idx.ref(), lookup3.enum_idx.ref()); +} + +TEST_F("require that collect_folded works for string", StringFixture) +{ + StringAttribute *attr = static_cast<StringAttribute *>(f1.attr.get()); + set_doc(attr, 2, "bar", 30); + attr->commit(); + set_doc(attr, 3, "FOO", 30); + attr->commit(); + auto dictionary_snapshot = f1.api->get_dictionary_snapshot(); + auto lookup1 = f1.api->lookup("foo", dictionary_snapshot); + std::vector<vespalib::string> folded; + std::function<void(vespalib::datastore::EntryRef)> save_folded = [&folded,attr](vespalib::datastore::EntryRef enum_idx) { folded.emplace_back(attr->getFromEnum(enum_idx.ref())); }; + f1.api->collect_folded(lookup1.enum_idx, 
dictionary_snapshot, save_folded); + std::vector<vespalib::string> expected_folded{"FOO", "foo"}; + EXPECT_EQUAL(expected_folded, folded); +} + +TEST_F("require that collect_folded works for integers", LongFixture) +{ + IntegerAttributeTemplate<int64_t> *attr = dynamic_cast<IntegerAttributeTemplate<int64_t> *>(f1.attr.get()); + set_doc(attr, 2, int64_t(112), 30); + attr->commit(); + auto dictionary_snapshot = f1.api->get_dictionary_snapshot(); + auto lookup1 = f1.api->lookup("111", dictionary_snapshot); + std::vector<int64_t> folded; + std::function<void(vespalib::datastore::EntryRef)> save_folded = [&folded,attr](vespalib::datastore::EntryRef enum_idx) { folded.emplace_back(attr->getFromEnum(enum_idx.ref())); }; + f1.api->collect_folded(lookup1.enum_idx, dictionary_snapshot, save_folded); + std::vector<int64_t> expected_folded{int64_t(111)}; + EXPECT_EQUAL(expected_folded, folded); +} + class Verifier : public search::test::SearchIteratorVerifier { public: Verifier(); @@ -176,7 +225,7 @@ public: (void) strict; const IDocumentWeightAttribute *api(_attr->asDocumentWeightAttribute()); ASSERT_TRUE(api != nullptr); - auto dict_entry = api->lookup("123"); + auto dict_entry = api->lookup("123", api->get_dictionary_snapshot()); ASSERT_TRUE(dict_entry.posting_idx.valid()); return std::make_unique<queryeval::DocumentWeightSearchIterator>(_tfmd, *api, dict_entry); } diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp index cc6b8e0ce29..7a200a46ab2 100644 --- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -25,7 +25,8 @@ std::string get_source_dir() { } std::string source_dir = get_source_dir(); std::string vespa_dir = source_dir + "/" + "../../../../.."; -std::string simple_model = vespa_dir + "/" + "model-integration/src/test/models/onnx/simple/simple.onnx"; +std::string simple_model = vespa_dir 
+ "/" + "eval/src/tests/tensor/onnx_wrapper/simple.onnx"; +std::string dynamic_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/dynamic.onnx"; uint32_t default_docid = 1; @@ -97,4 +98,16 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) { EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0)); } +TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) { + add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]"); + add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); + add_expr("bias_tensor", "tensor<float>(a[1],b[2]):[[4,5]]"); + add_onnx("dynamic", dynamic_model); + compile(onnx_feature("dynamic")); + EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); + EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); + EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0)); + EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0)); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp index 9761b0da2d7..f2c02d02080 100644 --- a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp +++ b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp @@ -674,7 +674,7 @@ private: MatchParams match_params(_dummy_heap, _dummy_heap.getMinScore(), 1.0, 1); std::vector<IDocumentWeightAttribute::LookupResult> dict_entries; for (size_t i = 0; i < _num_children; ++i) { - dict_entries.push_back(_helper.dwa().lookup(vespalib::make_string("%zu", i).c_str())); + dict_entries.push_back(_helper.dwa().lookup(vespalib::make_string("%zu", i).c_str(), _helper.dwa().get_dictionary_snapshot())); } return create_wand(_use_dwa, _tfmd, match_params, 
_weights, dict_entries, _helper.dwa(), strict); } diff --git a/searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt b/searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt new file mode 100644 index 00000000000..14a70f25e3c --- /dev/null +++ b/searchlib/src/tests/tensor/direct_tensor_store/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(searchlib_direct_tensor_store_test_app TEST + SOURCES + direct_tensor_store_test.cpp + DEPENDS + searchlib + GTest::GTest +) +vespa_add_test(NAME searchlib_direct_tensor_store_test_app COMMAND searchlib_direct_tensor_store_test_app) diff --git a/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp b/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp new file mode 100644 index 00000000000..1003e461676 --- /dev/null +++ b/searchlib/src/tests/tensor/direct_tensor_store/direct_tensor_store_test.cpp @@ -0,0 +1,89 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/searchlib/tensor/direct_tensor_store.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/eval/tensor/tensor.h> + +using namespace search::tensor; + +using vespalib::datastore::EntryRef; +using vespalib::eval::TensorSpec; +using vespalib::tensor::DefaultTensorEngine; +using vespalib::tensor::Tensor; + +vespalib::string tensor_spec("tensor(x{})"); + +Tensor::UP +make_tensor(const TensorSpec& spec) +{ + auto value = DefaultTensorEngine::ref().from_spec(spec); + auto* tensor = dynamic_cast<Tensor*>(value.get()); + assert(tensor != nullptr); + value.release(); + return Tensor::UP(tensor); +} + +Tensor::UP +make_tensor(double value) +{ + return make_tensor(TensorSpec(tensor_spec).add({{"x", "a"}}, value)); +} + +class DirectTensorStoreTest : public ::testing::Test { +public: + DirectTensorStore store; + + DirectTensorStoreTest() : store() {} + + virtual ~DirectTensorStoreTest() { + store.clearHoldLists(); + } + + void expect_tensor(const Tensor* exp, EntryRef ref) { + const auto* act = store.get_tensor(ref); + ASSERT_TRUE(act); + EXPECT_EQ(exp, act); + } +}; + +TEST_F(DirectTensorStoreTest, can_set_and_get_tensor) +{ + auto t = make_tensor(5); + auto* exp = t.get(); + auto ref = store.set_tensor(std::move(t)); + expect_tensor(exp, ref); +} + +TEST_F(DirectTensorStoreTest, invalid_ref_returns_nullptr) +{ + const auto* t = store.get_tensor(EntryRef()); + EXPECT_FALSE(t); +} + +TEST_F(DirectTensorStoreTest, hold_adds_entry_to_hold_list) +{ + auto ref = store.set_tensor(make_tensor(5)); + auto mem_1 = store.getMemoryUsage(); + store.holdTensor(ref); + auto mem_2 = store.getMemoryUsage(); + EXPECT_GT(mem_2.allocatedBytesOnHold(), mem_1.allocatedBytesOnHold()); +} + +TEST_F(DirectTensorStoreTest, move_allocates_new_entry_and_puts_old_entry_on_hold) +{ + auto t = make_tensor(5); + auto* exp = t.get(); + auto ref_1 = store.set_tensor(std::move(t)); + auto mem_1 = store.getMemoryUsage(); + + auto 
ref_2 = store.move(ref_1); + auto mem_2 = store.getMemoryUsage(); + EXPECT_NE(ref_1, ref_2); + expect_tensor(exp, ref_1); + expect_tensor(exp, ref_2); + EXPECT_GT(mem_2.allocatedBytesOnHold(), mem_1.allocatedBytesOnHold()); +} + +GTEST_MAIN_RUN_ALL_TESTS() + diff --git a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp index 9c896396de3..a5e0e1e2b6a 100644 --- a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp @@ -1,9 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <sys/types.h> -#include <sys/stat.h> #include <fcntl.h> -#include <stdio.h> +#include <cstdio> #include <unistd.h> #include <chrono> #include <cstdlib> @@ -24,6 +22,7 @@ #include <vespa/vespalib/util/blockingthreadstackexecutor.h> #include <vespa/vespalib/util/generationhandler.h> #include <vespa/vespalib/util/lambdatask.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <vespa/log/log.h> LOG_SETUP("stress_hnsw_mt"); diff --git a/searchlib/src/tests/transactionlogstress/translogstress.cpp b/searchlib/src/tests/transactionlogstress/translogstress.cpp index 81a3006dbff..013ca81dcc9 100644 --- a/searchlib/src/tests/transactionlogstress/translogstress.cpp +++ b/searchlib/src/tests/transactionlogstress/translogstress.cpp @@ -8,7 +8,6 @@ #include <vespa/searchlib/index/dummyfileheadercontext.h> #include <vespa/fastos/app.h> #include <iostream> -#include <stdexcept> #include <sstream> #include <thread> @@ -223,7 +222,6 @@ FeederThread::~FeederThread() = default; void FeederThread::commitPacket() { - _packet.close(); const vespalib::nbostream& stream = _packet.getHandle(); if (!_session->commit(ConstBufferRef(stream.data(), stream.size()))) { throw std::runtime_error(vespalib::make_string @@ -238,8 +236,9 @@ FeederThread::commitPacket() bool FeederThread::addEntry(const Packet::Entry & 
e) { - //LOG(info, "FeederThread: add %s", EntryPrinter::toStr(e).c_str()); - return _packet.add(e); + if (_packet.sizeBytes() > 0xf000) return false; + _packet.add(e); + return true; } void diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index b9e4bf565ef..4ab80ebce7d 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -322,6 +322,7 @@ private: std::vector<int32_t> _weights; std::vector<IDocumentWeightAttribute::LookupResult> _terms; const IDocumentWeightAttribute &_attr; + vespalib::datastore::EntryRef _dictionary_snapshot; public: DirectWeightedSetBlueprint(const FieldSpec &field, const IDocumentWeightAttribute &attr, size_t size_hint) @@ -329,7 +330,8 @@ public: _estimate(), _weights(), _terms(), - _attr(attr) + _attr(attr), + _dictionary_snapshot(_attr.get_dictionary_snapshot()) { set_allow_termwise_eval(true); _weights.reserve(size_hint); @@ -337,7 +339,7 @@ public: } void addTerm(const vespalib::string &term, int32_t weight) { - IDocumentWeightAttribute::LookupResult result = _attr.lookup(term); + IDocumentWeightAttribute::LookupResult result = _attr.lookup(term, _dictionary_snapshot); HitEstimate childEst(result.posting_size, (result.posting_size == 0)); if (!childEst.empty) { if (_estimate.empty) { @@ -394,6 +396,7 @@ private: std::vector<int32_t> _weights; std::vector<IDocumentWeightAttribute::LookupResult> _terms; const IDocumentWeightAttribute &_attr; + vespalib::datastore::EntryRef _dictionary_snapshot; public: DirectWandBlueprint(const FieldSpec &field, const IDocumentWeightAttribute &attr, uint32_t scoresToTrack, @@ -406,14 +409,16 @@ public: _scoresAdjustFrequency(queryeval::DEFAULT_PARALLEL_WAND_SCORES_ADJUST_FREQUENCY), _weights(), _terms(), - _attr(attr) + _attr(attr), + _dictionary_snapshot(_attr.get_dictionary_snapshot()) + 
{ _weights.reserve(size_hint); _terms.reserve(size_hint); } void addTerm(const vespalib::string &term, int32_t weight) { - IDocumentWeightAttribute::LookupResult result = _attr.lookup(term); + IDocumentWeightAttribute::LookupResult result = _attr.lookup(term, _dictionary_snapshot); HitEstimate childEst(result.posting_size, (result.posting_size == 0)); if (!childEst.empty) { if (_estimate.empty) { @@ -464,6 +469,7 @@ class DirectAttributeBlueprint : public queryeval::SimpleLeafBlueprint private: vespalib::string _attrName; const IDocumentWeightAttribute &_attr; + vespalib::datastore::EntryRef _dictionary_snapshot; IDocumentWeightAttribute::LookupResult _dict_entry; public: @@ -472,7 +478,8 @@ public: : SimpleLeafBlueprint(field), _attrName(name), _attr(attr), - _dict_entry(_attr.lookup(term)) + _dictionary_snapshot(_attr.get_dictionary_snapshot()), + _dict_entry(_attr.lookup(term, _dictionary_snapshot)) { setEstimate(HitEstimate(_dict_entry.posting_size, (_dict_entry.posting_size == 0))); } diff --git a/searchlib/src/vespa/searchlib/attribute/i_document_weight_attribute.h b/searchlib/src/vespa/searchlib/attribute/i_document_weight_attribute.h index ed184a7370e..e0cfd446da5 100644 --- a/searchlib/src/vespa/searchlib/attribute/i_document_weight_attribute.h +++ b/searchlib/src/vespa/searchlib/attribute/i_document_weight_attribute.h @@ -4,6 +4,8 @@ #include "postinglisttraits.h" +#include <functional> + namespace search { namespace query { class Node; } @@ -17,11 +19,18 @@ struct IDocumentWeightAttribute const uint32_t posting_size; const int32_t min_weight; const int32_t max_weight; - LookupResult() : posting_idx(), posting_size(0), min_weight(0), max_weight(0) {} - LookupResult(vespalib::datastore::EntryRef posting_idx_in, uint32_t posting_size_in, int32_t min_weight_in, int32_t max_weight_in) - : posting_idx(posting_idx_in), posting_size(posting_size_in), min_weight(min_weight_in), max_weight(max_weight_in) {} + const vespalib::datastore::EntryRef enum_idx; + 
LookupResult() : posting_idx(), posting_size(0), min_weight(0), max_weight(0), enum_idx() {} + LookupResult(vespalib::datastore::EntryRef posting_idx_in, uint32_t posting_size_in, int32_t min_weight_in, int32_t max_weight_in, vespalib::datastore::EntryRef enum_idx_in) + : posting_idx(posting_idx_in), posting_size(posting_size_in), min_weight(min_weight_in), max_weight(max_weight_in), enum_idx(enum_idx_in) {} }; - virtual LookupResult lookup(const vespalib::string &term) const = 0; + virtual vespalib::datastore::EntryRef get_dictionary_snapshot() const = 0; + virtual LookupResult lookup(const vespalib::string &term, vespalib::datastore::EntryRef dictionary_snapshot) const = 0; + /* + * Collect enum indexes (via callback) where folded + * (e.g. lowercased) value equals the folded value for enum_idx. + */ + virtual void collect_folded(vespalib::datastore::EntryRef enum_idx, vespalib::datastore::EntryRef dictionary_snapshot, const std::function<void(vespalib::datastore::EntryRef)>& callback) const = 0; virtual void create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const = 0; virtual DocumentWeightIterator create(vespalib::datastore::EntryRef idx) const = 0; virtual ~IDocumentWeightAttribute() {} diff --git a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.h b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.h index fa962f8d469..c09366cdaea 100644 --- a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.h +++ b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.h @@ -32,12 +32,14 @@ public: using EnumStoreBatchUpdater = typename EnumStore::BatchUpdater; private: - struct DocumentWeightAttributeAdapter : IDocumentWeightAttribute { + struct DocumentWeightAttributeAdapter final : IDocumentWeightAttribute { const MultiValueNumericPostingAttribute &self; DocumentWeightAttributeAdapter(const MultiValueNumericPostingAttribute &self_in) : self(self_in) {} - virtual LookupResult 
lookup(const vespalib::string &term) const override final; - virtual void create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const override final; - virtual DocumentWeightIterator create(vespalib::datastore::EntryRef idx) const override final; + vespalib::datastore::EntryRef get_dictionary_snapshot() const override; + LookupResult lookup(const vespalib::string &term, vespalib::datastore::EntryRef dictionary_snapshot) const override; + void collect_folded(vespalib::datastore::EntryRef enum_idx, vespalib::datastore::EntryRef dictionary_snapshot, const std::function<void(vespalib::datastore::EntryRef)>& callback) const override; + void create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const override; + DocumentWeightIterator create(vespalib::datastore::EntryRef idx) const override; }; DocumentWeightAttributeAdapter _document_weight_attribute_adapter; diff --git a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp index 283c3da00b1..1fd1cd09bea 100644 --- a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp @@ -83,11 +83,18 @@ MultiValueNumericPostingAttribute<B, M>::getSearch(QueryTermSimpleUP qTerm, } template <typename B, typename M> +vespalib::datastore::EntryRef +MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::get_dictionary_snapshot() const +{ + const Dictionary &dictionary = self._enumStore.get_posting_dictionary(); + return dictionary.getFrozenView().getRoot(); +} + +template <typename B, typename M> IDocumentWeightAttribute::LookupResult -MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::lookup(const vespalib::string &term) const +MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::lookup(const vespalib::string &term, 
vespalib::datastore::EntryRef dictionary_snapshot) const { const Dictionary &dictionary = self._enumStore.get_posting_dictionary(); - const FrozenDictionary frozenDictionary(dictionary.getFrozenView()); DictionaryConstIterator dictItr(vespalib::btree::BTreeNode::Ref(), dictionary.getAllocator()); char *end = nullptr; @@ -95,13 +102,13 @@ MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::lookup( if (*end == '\0') { auto comp = self._enumStore.make_comparator(int_term); - dictItr.lower_bound(frozenDictionary.getRoot(), EnumIndex(), comp); + dictItr.lower_bound(dictionary_snapshot, EnumIndex(), comp); if (dictItr.valid() && !comp(EnumIndex(), dictItr.getKey())) { vespalib::datastore::EntryRef pidx(dictItr.getData()); if (pidx.valid()) { const PostingList &plist = self.getPostingList(); auto minmax = plist.getAggregated(pidx); - return LookupResult(pidx, plist.frozenSize(pidx), minmax.getMin(), minmax.getMax()); + return LookupResult(pidx, plist.frozenSize(pidx), minmax.getMin(), minmax.getMax(), dictItr.getKey()); } } } @@ -110,6 +117,14 @@ MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::lookup( template <typename B, typename M> void +MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::collect_folded(vespalib::datastore::EntryRef enum_idx, vespalib::datastore::EntryRef dictionary_snapshot, const std::function<void(vespalib::datastore::EntryRef)>& callback)const +{ + (void) dictionary_snapshot; + callback(enum_idx); +} + +template <typename B, typename M> +void MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const { assert(idx.valid()); diff --git a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.h b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.h index c755c5cb649..142879f4578 100644 --- a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.h 
+++ b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.h @@ -30,12 +30,14 @@ public: using EnumStoreBatchUpdater = typename EnumStore::BatchUpdater; private: - struct DocumentWeightAttributeAdapter : IDocumentWeightAttribute { + struct DocumentWeightAttributeAdapter final : IDocumentWeightAttribute { const MultiValueStringPostingAttributeT &self; DocumentWeightAttributeAdapter(const MultiValueStringPostingAttributeT &self_in) : self(self_in) {} - virtual LookupResult lookup(const vespalib::string &term) const override final; - virtual void create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const override final; - virtual DocumentWeightIterator create(vespalib::datastore::EntryRef idx) const override final; + vespalib::datastore::EntryRef get_dictionary_snapshot() const override; + LookupResult lookup(const vespalib::string &term, vespalib::datastore::EntryRef dictionary_snapshot) const override; + void collect_folded(vespalib::datastore::EntryRef enum_idx, vespalib::datastore::EntryRef dictionary_snapshot, const std::function<void(vespalib::datastore::EntryRef)>& callback) const override; + void create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const override; + DocumentWeightIterator create(vespalib::datastore::EntryRef idx) const override; }; DocumentWeightAttributeAdapter _document_weight_attribute_adapter; diff --git a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp index 7bc62169b3c..4263eacfa52 100644 --- a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp @@ -99,21 +99,28 @@ MultiValueStringPostingAttributeT<B, T>::getSearch(QueryTermSimpleUP qTerm, template <typename B, typename T> +vespalib::datastore::EntryRef +MultiValueStringPostingAttributeT<B, 
T>::DocumentWeightAttributeAdapter::get_dictionary_snapshot() const +{ + const Dictionary &dictionary = self._enumStore.get_posting_dictionary(); + return dictionary.getFrozenView().getRoot(); +} + +template <typename B, typename T> IDocumentWeightAttribute::LookupResult -MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::lookup(const vespalib::string &term) const +MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::lookup(const vespalib::string &term, vespalib::datastore::EntryRef dictionary_snapshot) const { const Dictionary &dictionary = self._enumStore.get_posting_dictionary(); - const FrozenDictionary frozenDictionary(dictionary.getFrozenView()); DictionaryConstIterator dictItr(vespalib::btree::BTreeNode::Ref(), dictionary.getAllocator()); auto comp = self._enumStore.make_folded_comparator(term.c_str()); - dictItr.lower_bound(frozenDictionary.getRoot(), EnumIndex(), comp); + dictItr.lower_bound(dictionary_snapshot, EnumIndex(), comp); if (dictItr.valid() && !comp(EnumIndex(), dictItr.getKey())) { vespalib::datastore::EntryRef pidx(dictItr.getData()); if (pidx.valid()) { const PostingList &plist = self.getPostingList(); auto minmax = plist.getAggregated(pidx); - return LookupResult(pidx, plist.frozenSize(pidx), minmax.getMin(), minmax.getMax()); + return LookupResult(pidx, plist.frozenSize(pidx), minmax.getMin(), minmax.getMax(), dictItr.getKey()); } } return LookupResult(); @@ -121,6 +128,20 @@ MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::lookup( template <typename B, typename T> void +MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::collect_folded(vespalib::datastore::EntryRef enum_idx, vespalib::datastore::EntryRef dictionary_snapshot, const std::function<void(vespalib::datastore::EntryRef)>& callback) const +{ + const Dictionary &dictionary = self._enumStore.get_posting_dictionary(); + DictionaryConstIterator dictItr(vespalib::btree::BTreeNode::Ref(), 
dictionary.getAllocator()); + auto comp = self._enumStore.make_folded_comparator(); + dictItr.lower_bound(dictionary_snapshot, enum_idx, comp); + while (dictItr.valid() && !comp(enum_idx, dictItr.getKey())) { + callback(dictItr.getKey()); + ++dictItr; + } +} + +template <typename B, typename T> +void MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::create(vespalib::datastore::EntryRef idx, std::vector<DocumentWeightIterator> &dst) const { assert(idx.valid()); diff --git a/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp b/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp index cd53253ad0a..53f88246a0a 100644 --- a/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp +++ b/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp @@ -112,13 +112,14 @@ ReferenceAttribute::buildReverseMapping(EntryRef newRef, const std::vector<Rever void ReferenceAttribute::buildReverseMapping() { - std::vector<std::pair<EntryRef, uint32_t>> indices; + using EntryPair = std::pair<EntryRef, uint32_t>; + std::vector<EntryPair, vespalib::allocator_large<EntryPair>> indices; uint32_t numDocs = _indices.size(); indices.reserve(numDocs); for (uint32_t lid = 0; lid < numDocs; ++lid) { EntryRef ref = _indices[lid]; if (ref.valid()) { - indices.push_back(std::make_pair(ref, lid)); + indices.emplace_back(ref, lid); } } std::sort(indices.begin(), indices.end()); @@ -200,8 +201,7 @@ ReferenceAttribute::onUpdateStat() std::unique_ptr<AttributeSaver> ReferenceAttribute::onInitSave(vespalib::stringref fileName) { - vespalib::GenerationHandler::Guard guard(this->getGenerationHandler(). 
- takeGuard()); + vespalib::GenerationHandler::Guard guard(this->getGenerationHandler().takeGuard()); return std::make_unique<ReferenceAttributeSaver> (std::move(guard), createAttributeHeader(fileName), @@ -221,8 +221,7 @@ ReferenceAttribute::onLoad() assert(attrReader.getEnumerated()); assert(!attrReader.hasIdx()); size_t numDocs(0); - uint64_t numValues(0); - numValues = attrReader.getEnumCount(); + uint64_t numValues = attrReader.getEnumCount(); numDocs = numValues; auto udatBuffer = attribute::LoadUtils::loadUDAT(*this); const GenericHeader &header = udatBuffer->getHeader(); @@ -367,13 +366,13 @@ class TargetLidPopulator : public IGidToLidMapperVisitor { ReferenceAttribute &_attr; public: - TargetLidPopulator(ReferenceAttribute &attr) + explicit TargetLidPopulator(ReferenceAttribute &attr) : IGidToLidMapperVisitor(), _attr(attr) { } - virtual ~TargetLidPopulator() override { } - virtual void visit(const document::GlobalId &gid, uint32_t lid) const override { + ~TargetLidPopulator() override = default; + void visit(const document::GlobalId &gid, uint32_t lid) const override { _attr.notifyReferencedPutNoCommit(gid, lid); } }; diff --git a/searchlib/src/vespa/searchlib/attribute/reference_attribute.h b/searchlib/src/vespa/searchlib/attribute/reference_attribute.h index 706abc53819..1c138abf989 100644 --- a/searchlib/src/vespa/searchlib/attribute/reference_attribute.h +++ b/searchlib/src/vespa/searchlib/attribute/reference_attribute.h @@ -7,6 +7,7 @@ #include "reference_mappings.h" #include <vespa/vespalib/datastore/unique_store.h> #include <vespa/vespalib/util/rcuvector.h> +#include <vespa/vespalib/stllike/allocator.h> namespace search { class IGidToLidMapperFactory; } @@ -28,7 +29,7 @@ public: using GlobalId = document::GlobalId; using ReferenceStore = vespalib::datastore::UniqueStore<Reference>; using ReferenceStoreIndices = vespalib::RcuVectorBase<EntryRef>; - using IndicesCopyVector = vespalib::Array<EntryRef>; + using IndicesCopyVector = std::vector<EntryRef, 
vespalib::allocator_large<EntryRef>>; // Class used to map from target lid to source lids using ReverseMapping = vespalib::btree::BTreeStore<uint32_t, vespalib::btree::BTreeNoLeafData, vespalib::btree::NoAggregated, @@ -45,14 +46,14 @@ private: std::shared_ptr<IGidToLidMapperFactory> _gidToLidMapperFactory; ReferenceMappings _referenceMappings; - virtual void onAddDocs(DocId docIdLimit) override; - virtual void removeOldGenerations(generation_t firstUsed) override; - virtual void onGenerationChange(generation_t generation) override; - virtual void onCommit() override; - virtual void onUpdateStat() override; - virtual std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override; - virtual bool onLoad() override; - virtual uint64_t getUniqueValueCount() const override; + void onAddDocs(DocId docIdLimit) override; + void removeOldGenerations(generation_t firstUsed) override; + void onGenerationChange(generation_t generation) override; + void onCommit() override; + void onUpdateStat() override; + std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override; + bool onLoad() override; + uint64_t getUniqueValueCount() const override; bool considerCompact(const CompactionStrategy &compactionStrategy); void compactWorst(); diff --git a/searchlib/src/vespa/searchlib/diskindex/bitvectorfile.h b/searchlib/src/vespa/searchlib/diskindex/bitvectorfile.h index 00645810d62..e8341901585 100644 --- a/searchlib/src/vespa/searchlib/diskindex/bitvectorfile.h +++ b/searchlib/src/vespa/searchlib/diskindex/bitvectorfile.h @@ -6,6 +6,7 @@ #include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/common/tunefileinfo.h> #include <vespa/vespalib/stllike/string.h> +#include <vespa/vespalib/stllike/allocator.h> #include "bitvectoridxfile.h" namespace search::diskindex { @@ -49,7 +50,7 @@ public: class BitVectorCandidate { private: - std::vector<uint32_t> _array; + std::vector<uint32_t, vespalib::allocator_large<uint32_t>> _array; uint64_t 
_numDocs; uint32_t _bitVectorLimit; BitVector::UP _bv; diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp index f6d5c37b61d..7433021b9b6 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp @@ -3,7 +3,6 @@ #include "onnx_feature.h" #include <vespa/searchlib/fef/properties.h> #include <vespa/searchlib/fef/featureexecutor.h> -#include <vespa/eval/tensor/dense/onnx_wrapper.h> #include <vespa/eval/tensor/dense/dense_tensor_view.h> #include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h> #include <vespa/vespalib/util/stringfmt.h> @@ -23,7 +22,7 @@ using vespalib::eval::ValueType; using vespalib::make_string_short::fmt; using vespalib::tensor::DenseTensorView; using vespalib::tensor::MutableDenseTensorView; -using vespalib::tensor::OnnxWrapper; +using vespalib::tensor::Onnx; namespace search::features { @@ -33,37 +32,28 @@ namespace search::features { class OnnxFeatureExecutor : public FeatureExecutor { private: - const OnnxWrapper &_model; - OnnxWrapper::Params _params; - OnnxWrapper::Result _result; - std::vector<MutableDenseTensorView> _views; - + Onnx::EvalContext _eval_context; public: - OnnxFeatureExecutor(const OnnxWrapper &model) - : _model(model), _params(), _result(OnnxWrapper::Result::make_empty()), _views() - { - _views.reserve(_model.outputs().size()); - for (const auto &output: _model.outputs()) { - _views.emplace_back(output.make_compatible_type()); - } - } + OnnxFeatureExecutor(const Onnx &model, const Onnx::WireInfo &wire_info) + : _eval_context(model, wire_info) {} bool isPure() override { return true; } - void execute(uint32_t) override { - _params = OnnxWrapper::Params(); - for (size_t i = 0; i < _model.inputs().size(); ++i) { - _params.bind(i, static_cast<const DenseTensorView&>(inputs().get_object(i).get())); + void handle_bind_outputs(vespalib::ArrayRef<fef::NumberOrObject>) override { + for 
(size_t i = 0; i < _eval_context.num_results(); ++i) { + outputs().set_object(i, _eval_context.get_result(i)); } - _result = _model.eval(_params); - for (size_t i = 0; i < _model.outputs().size(); ++i) { - _result.get(i, _views[i]); - outputs().set_object(i, _views[i]); + } + void execute(uint32_t) override { + for (size_t i = 0; i < _eval_context.num_params(); ++i) { + _eval_context.bind_param(i, inputs().get_object(i).get()); } + _eval_context.eval(); } }; OnnxBlueprint::OnnxBlueprint() : Blueprint("onnxModel"), - _model(nullptr) + _model(nullptr), + _wire_info() { } @@ -74,24 +64,25 @@ OnnxBlueprint::setup(const IIndexEnvironment &env, const ParameterList &params) { auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) - ? OnnxWrapper::Optimize::DISABLE - : OnnxWrapper::Optimize::ENABLE; + ? Onnx::Optimize::DISABLE + : Onnx::Optimize::ENABLE; // Note: Using the fileref property with the model name as // fallback to get a file name. This needs to be replaced with an // actual file reference obtained through config when available. 
vespalib::string file_name = env.getProperties().lookup(getName(), "fileref").get(params[0].getValue()); try { - _model = std::make_unique<OnnxWrapper>(file_name, optimize); + _model = std::make_unique<Onnx>(file_name, optimize); } catch (std::exception &ex) { return fail("Model setup failed: %s", ex.what()); } + Onnx::WirePlanner planner; for (size_t i = 0; i < _model->inputs().size(); ++i) { const auto &model_input = _model->inputs()[i]; if (auto maybe_input = defineInput(fmt("rankingExpression(\"%s\")", model_input.name.c_str()), AcceptInput::OBJECT)) { const FeatureType &feature_input = maybe_input.value(); assert(feature_input.is_object()); - if (!model_input.is_compatible(feature_input.type())) { + if (!planner.bind_input_type(feature_input.type(), model_input)) { return fail("incompatible type for input '%s': %s -> %s", model_input.name.c_str(), feature_input.type().to_spec().c_str(), model_input.type_as_string().c_str()); } @@ -99,13 +90,14 @@ OnnxBlueprint::setup(const IIndexEnvironment &env, } for (size_t i = 0; i < _model->outputs().size(); ++i) { const auto &model_output = _model->outputs()[i]; - ValueType output_type = model_output.make_compatible_type(); + ValueType output_type = planner.make_output_type(model_output); if (output_type.is_error()) { return fail("unable to make compatible type for output '%s': %s -> error", model_output.name.c_str(), model_output.type_as_string().c_str()); } describeOutput(model_output.name, "output from onnx model", FeatureType::object(output_type)); } + _wire_info = planner.get_wire_info(*_model); return true; } @@ -113,7 +105,7 @@ FeatureExecutor & OnnxBlueprint::createExecutor(const IQueryEnvironment &, Stash &stash) const { assert(_model); - return stash.create<OnnxFeatureExecutor>(*_model); + return stash.create<OnnxFeatureExecutor>(*_model, _wire_info); } } diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h index eb6e368ffbd..19c6338d2ee 
100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.h +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h @@ -3,8 +3,7 @@ #pragma once #include <vespa/searchlib/fef/blueprint.h> - -namespace vespalib::tensor { class OnnxWrapper; } +#include <vespa/eval/tensor/dense/onnx_wrapper.h> namespace search::features { @@ -13,7 +12,9 @@ namespace search::features { **/ class OnnxBlueprint : public fef::Blueprint { private: - std::unique_ptr<vespalib::tensor::OnnxWrapper> _model; + using Onnx = vespalib::tensor::Onnx; + std::unique_ptr<Onnx> _model; + Onnx::WireInfo _wire_info; public: OnnxBlueprint(); ~OnnxBlueprint() override; diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt index 35615b255c0..851400e4806 100644 --- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt @@ -5,6 +5,8 @@ vespa_add_library(searchlib_tensor OBJECT dense_tensor_attribute.cpp dense_tensor_attribute_saver.cpp dense_tensor_store.cpp + direct_tensor_attribute.cpp + direct_tensor_store.cpp distance_function_factory.cpp distance_functions.cpp generic_tensor_attribute.cpp @@ -20,6 +22,7 @@ vespa_add_library(searchlib_tensor OBJECT nearest_neighbor_index.cpp nearest_neighbor_index_saver.cpp tensor_attribute.cpp + tensor_deserialize.cpp tensor_store.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h b/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h new file mode 100644 index 00000000000..7c34b60e93d --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/blob_sequence_reader.h @@ -0,0 +1,26 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/fastlib/io/bufferedfile.h> +#include <vespa/searchlib/attribute/readerbase.h> +#include <vespa/searchlib/util/fileutil.h> + +namespace search::tensor { + +/** + * Utility for reading an attribute data file where + * the format is a sequence of blobs (size, byte[size]). + **/ +class BlobSequenceReader : public ReaderBase +{ +private: + FileReader<uint32_t> _sizeReader; +public: + BlobSequenceReader(AttributeVector &attr) + : ReaderBase(attr), + _sizeReader(*_datFile) + { } + uint32_t getNextSize() { return _sizeReader.readHostOrder(); } + void readBlob(void *buf, size_t len) { _datFile->ReadBuf(buf, len); } +}; + +} // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp index 76533839de7..37a042d4e7f 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp @@ -30,26 +30,26 @@ namespace { constexpr uint32_t DENSE_TENSOR_ATTRIBUTE_VERSION = 1; const vespalib::string tensorTypeTag("tensortype"); -class TensorReader : public ReaderBase +class BlobSequenceReader : public ReaderBase { private: static constexpr uint8_t tensorIsNotPresent = 0; static constexpr uint8_t tensorIsPresent = 1; public: - TensorReader(AttributeVector &attr); - ~TensorReader(); + BlobSequenceReader(AttributeVector &attr); + ~BlobSequenceReader(); bool is_present(); void readTensor(void *buf, size_t len) { _datFile->ReadBuf(buf, len); } }; -TensorReader::TensorReader(AttributeVector &attr) +BlobSequenceReader::BlobSequenceReader(AttributeVector &attr) : ReaderBase(attr) { } -TensorReader::~TensorReader() = default; +BlobSequenceReader::~BlobSequenceReader() = default; bool -TensorReader::is_present() { +BlobSequenceReader::is_present() { unsigned char detect; _datFile->ReadBuf(&detect, sizeof(detect)); if (detect == tensorIsNotPresent) { @@ -190,7 +190,7 @@ 
DenseTensorAttribute::getTensor(DocId docId, MutableDenseTensorView &tensor) con bool DenseTensorAttribute::onLoad() { - TensorReader tensorReader(*this); + BlobSequenceReader tensorReader(*this); if (!tensorReader.hasData()) { return false; } diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp new file mode 100644 index 00000000000..f53d42442ba --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.cpp @@ -0,0 +1,52 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "direct_tensor_attribute.h" + +#include <vespa/eval/tensor/tensor.h> +#include <vespa/fastlib/io/bufferedfile.h> +#include <vespa/searchlib/attribute/readerbase.h> +#include <vespa/searchlib/util/fileutil.h> +#include <vespa/vespalib/util/array.h> + +#include "blob_sequence_reader.h" +#include "tensor_deserialize.h" + +using vespalib::tensor::Tensor; + +namespace search::tensor { + +constexpr uint32_t TENSOR_ATTRIBUTE_VERSION = 0; + +bool +DirectTensorAttribute::onLoad() +{ + BlobSequenceReader tensorReader(*this); + if (!tensorReader.hasData()) { + return false; + } + setCreateSerialNum(tensorReader.getCreateSerialNum()); + assert(tensorReader.getVersion() == TENSOR_ATTRIBUTE_VERSION); + uint32_t numDocs = tensorReader.getDocIdLimit(); + vespalib::Array<char> buffer(1024); + for (uint32_t lid = 0; lid < numDocs; ++lid) { + uint32_t tensorSize = tensorReader.getNextSize(); + if (tensorSize != 0) { + if (tensorSize > buffer.size()) { + buffer.resize(tensorSize + 1024); + } + tensorReader.readBlob(&buffer[0], tensorSize); + setTensor(lid, deserialize_tensor(&buffer[0], tensorSize)); + } + } + setNumDocs(numDocs); + setCommittedDocIdLimit(numDocs); + return true; +} + +void +DirectTensorAttribute::setTensor(DocId , std::unique_ptr<Tensor> ) +{ + // XXX missing +} + +} // namespace diff --git 
a/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h new file mode 100644 index 00000000000..ae3cb222dba --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_attribute.h @@ -0,0 +1,25 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "tensor_attribute.h" + +namespace search::tensor { + +class DirectTensorAttribute : public TensorAttribute +{ + // XXX must have some sort of TensorStore here +public: + DirectTensorAttribute(vespalib::stringref baseFileName, const Config &cfg); + virtual ~DirectTensorAttribute(); + virtual void setTensor(DocId docId, const Tensor &tensor) override; + virtual std::unique_ptr<Tensor> getTensor(DocId docId) const override; + virtual void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override; + virtual bool onLoad() override; + virtual std::unique_ptr<AttributeSaver> onInitSave(vespalib::stringref fileName) override; + virtual void compactWorst() override; + + void setTensor(DocId docId, std::unique_ptr<Tensor> tensor); +}; + +} // namespace search::tensor diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp new file mode 100644 index 00000000000..4e79315d4b1 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp @@ -0,0 +1,62 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include "direct_tensor_store.h" +#include <vespa/eval/tensor/tensor.h> +#include <vespa/vespalib/datastore/datastore.hpp> + +using vespalib::datastore::EntryRef; + +namespace search::tensor { + +constexpr size_t MIN_BUFFER_ARRAYS = 8192; + +DirectTensorStore::DirectTensorStore() + : TensorStore(_concrete_store), + _concrete_store(MIN_BUFFER_ARRAYS) +{ +} + +const vespalib::tensor::Tensor* +DirectTensorStore::get_tensor(EntryRef ref) const +{ + if (!ref.valid()) { + return nullptr; + } + auto entry = _concrete_store.getEntry(ref); + assert(entry); + return entry.get(); +} + +EntryRef +DirectTensorStore::set_tensor(std::unique_ptr<Tensor> tensor) +{ + assert(tensor); + // TODO: Account for heap allocated memory + return _concrete_store.addEntry(TensorSP(tensor.release())); +} + +void +DirectTensorStore::holdTensor(EntryRef ref) +{ + if (!ref.valid()) { + return; + } + // TODO: Account for heap allocated memory + _concrete_store.holdElem(ref, 1); +} + +EntryRef +DirectTensorStore::move(EntryRef ref) +{ + if (!ref.valid()) { + return EntryRef(); + } + auto old_tensor = _concrete_store.getEntry(ref); + assert(old_tensor); + // TODO: Account for heap allocated memory (regular + hold) + auto new_ref = _concrete_store.addEntry(old_tensor); + _concrete_store.holdElem(ref, 1); + return new_ref; +} + +} diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h new file mode 100644 index 00000000000..1073780a313 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.h @@ -0,0 +1,34 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "tensor_store.h" +#include <memory> + +namespace search::tensor { + +/** + * Class for storing heap allocated tensors, referenced by EntryRefs. + * + * Shared pointers to the tensors are stored in an underlying data store. 
+ */ +class DirectTensorStore : public TensorStore { +private: + // Note: Must use SP (instead of UP) because of fallbackCopy() and initializeReservedElements() in BufferType, + // and implementation of move(). + using TensorSP = std::shared_ptr<Tensor>; + using DataStoreType = vespalib::datastore::DataStore<TensorSP>; + + DataStoreType _concrete_store; + +public: + DirectTensorStore(); + + const Tensor* get_tensor(EntryRef ref) const; + EntryRef set_tensor(std::unique_ptr<Tensor> tensor); + + void holdTensor(EntryRef ref) override; + EntryRef move(EntryRef ref) override; +}; + +} diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp index aac199ae818..6864fb52120 100644 --- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_attribute.cpp @@ -3,6 +3,7 @@ #include "generic_tensor_attribute.h" #include "generic_tensor_attribute_saver.h" #include "tensor_attribute.hpp" +#include "blob_sequence_reader.h" #include <vespa/eval/tensor/tensor.h> #include <vespa/fastlib/io/bufferedfile.h> #include <vespa/searchlib/attribute/readerbase.h> @@ -18,19 +19,6 @@ namespace { constexpr uint32_t TENSOR_ATTRIBUTE_VERSION = 0; -class TensorReader : public ReaderBase -{ -private: - FileReader<uint32_t> _tensorSizeReader; -public: - TensorReader(AttributeVector &attr) - : ReaderBase(attr), - _tensorSizeReader(*_datFile) - { } - uint32_t getNextTensorSize() { return _tensorSizeReader.readHostOrder(); } - void readTensor(void *buf, size_t len) { _datFile->ReadBuf(buf, len); } -}; - } GenericTensorAttribute::GenericTensorAttribute(stringref name, const Config &cfg) @@ -76,7 +64,7 @@ GenericTensorAttribute::getTensor(DocId, vespalib::tensor::MutableDenseTensorVie bool GenericTensorAttribute::onLoad() { - TensorReader tensorReader(*this); + BlobSequenceReader tensorReader(*this); if (!tensorReader.hasData()) { return false; } 
@@ -86,10 +74,10 @@ GenericTensorAttribute::onLoad() _refVector.reset(); _refVector.unsafe_reserve(numDocs); for (uint32_t lid = 0; lid < numDocs; ++lid) { - uint32_t tensorSize = tensorReader.getNextTensorSize(); + uint32_t tensorSize = tensorReader.getNextSize(); auto raw = _genericTensorStore.allocRawBuffer(tensorSize); if (tensorSize != 0) { - tensorReader.readTensor(raw.data, tensorSize); + tensorReader.readBlob(raw.data, tensorSize); } _refVector.push_back(raw.ref); } diff --git a/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp index f19bef3ff21..8c695c32719 100644 --- a/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/generic_tensor_store.cpp @@ -1,15 +1,14 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "generic_tensor_store.h" +#include "tensor_deserialize.h" #include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/serialization/typed_binary_format.h> -#include <vespa/document/util/serializableexceptions.h> #include <vespa/vespalib/datastore/datastore.hpp> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/macro.h> -using document::DeserializeException; using vespalib::datastore::Handle; using vespalib::tensor::Tensor; using vespalib::tensor::TypedBinaryFormat; @@ -95,14 +94,7 @@ GenericTensorStore::getTensor(EntryRef ref) const if (raw.second == 0u) { return std::unique_ptr<Tensor>(); } - vespalib::nbostream wrapStream(raw.first, raw.second); - auto tensor = TypedBinaryFormat::deserialize(wrapStream); - if (wrapStream.size() != 0) { - throw DeserializeException("Leftover bytes deserializing " - "tensor attribute value.", - VESPA_STRLOC); - } - return tensor; + return deserialize_tensor(raw.first, raw.second); } TensorStore::EntryRef diff --git 
a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp new file mode 100644 index 00000000000..7998fba5941 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.cpp @@ -0,0 +1,24 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/document/util/serializableexceptions.h> +#include <vespa/eval/tensor/serialization/typed_binary_format.h> +#include <vespa/eval/tensor/tensor.h> +#include <vespa/vespalib/objects/nbostream.h> + +using document::DeserializeException; +using vespalib::tensor::Tensor; +using vespalib::tensor::TypedBinaryFormat; + +namespace search::tensor { + +std::unique_ptr<Tensor> deserialize_tensor(const void *data, size_t size) +{ + vespalib::nbostream wrapStream(data, size); + auto tensor = TypedBinaryFormat::deserialize(wrapStream); + if (wrapStream.size() != 0) { + throw DeserializeException("Leftover bytes deserializing tensor attribute value.", VESPA_STRLOC); + } + return tensor; +} + +} // namespace diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h new file mode 100644 index 00000000000..f1dfa1ca173 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/tensor_deserialize.h @@ -0,0 +1,10 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/eval/tensor/tensor.h> + +namespace search::tensor { + +extern std::unique_ptr<vespalib::tensor::Tensor> +deserialize_tensor(const void *data, size_t size); + +} // namespace diff --git a/searchlib/src/vespa/searchlib/test/weightedchildrenverifiers.h b/searchlib/src/vespa/searchlib/test/weightedchildrenverifiers.h index cabb108d2e1..2ef03ba97ef 100644 --- a/searchlib/src/vespa/searchlib/test/weightedchildrenverifiers.h +++ b/searchlib/src/vespa/searchlib/test/weightedchildrenverifiers.h @@ -63,7 +63,7 @@ public: (void) strict; std::vector<DocumentWeightIterator> children; for (size_t i = 0; i < _num_children; ++i) { - auto dict_entry = _helper.dwa().lookup(vespalib::make_string("%zu", i).c_str()); + auto dict_entry = _helper.dwa().lookup(vespalib::make_string("%zu", i).c_str(), _helper.dwa().get_dictionary_snapshot()); _helper.dwa().create(dict_entry.posting_idx, children); } return create(std::move(children)); diff --git a/searchlib/src/vespa/searchlib/transactionlog/common.cpp b/searchlib/src/vespa/searchlib/transactionlog/common.cpp index a5eaa61af12..ee7d265427c 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/common.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/common.cpp @@ -22,7 +22,8 @@ int makeDirectory(const char * dir) return retval; } -int64_t SerialNumRange::cmp(const SerialNumRange & b) const +int64_t +SerialNumRange::cmp(const SerialNumRange & b) const { int64_t diff(0); if ( ! 
(contains(b) || b.contains(*this)) ) { @@ -71,7 +72,8 @@ nbostream & Packet::Entry::deserialize(nbostream & os) return os; } -nbostream & Packet::Entry::serialize(nbostream & os) const +nbostream & +Packet::Entry::serialize(nbostream & os) const { os << _unique << _type << static_cast<uint32_t>(_data.size()); os.write(_data.c_str(), _data.size()); diff --git a/searchlib/src/vespa/searchlib/transactionlog/domain.cpp b/searchlib/src/vespa/searchlib/transactionlog/domain.cpp index 5a64d829183..5e7cfc74199 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/domain.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/domain.cpp @@ -131,7 +131,7 @@ Domain::begin(const LockGuard & guard) const assert(guard.locks(_lock)); SerialNum s(0); if ( ! _parts.empty() ) { - s = _parts.begin()->second->range().from(); + s = _parts.cbegin()->second->range().from(); } return s; } @@ -149,7 +149,7 @@ Domain::end(const LockGuard & guard) const assert(guard.locks(_lock)); SerialNum s(0); if ( ! _parts.empty() ) { - s = _parts.rbegin()->second->range().to(); + s = _parts.crbegin()->second->range().to(); } return s; } @@ -203,7 +203,8 @@ Domain::triggerSyncNow() } } -DomainPart::SP Domain::findPart(SerialNum s) +DomainPart::SP +Domain::findPart(SerialNum s) { LockGuard guard(_lock); DomainPartList::iterator it(_parts.upper_bound(s)); @@ -220,12 +221,14 @@ DomainPart::SP Domain::findPart(SerialNum s) return DomainPart::SP(); } -uint64_t Domain::size() const +uint64_t +Domain::size() const { return size(LockGuard(_lock)); } -uint64_t Domain::size(const LockGuard & guard) const +uint64_t +Domain::size(const LockGuard & guard) const { (void) guard; assert(guard.locks(_lock)); @@ -236,7 +239,8 @@ uint64_t Domain::size(const LockGuard & guard) const return sz; } -SerialNum Domain::findOldestActiveVisit() const +SerialNum +Domain::findOldestActiveVisit() const { SerialNum oldestActive(std::numeric_limits<SerialNum>::max()); LockGuard guard(_sessionLock); @@ -249,7 +253,8 @@ SerialNum 
Domain::findOldestActiveVisit() const return oldestActive; } -void Domain::cleanSessions() +void +Domain::cleanSessions() { if ( _sessions.empty()) { return; @@ -269,7 +274,8 @@ void Domain::cleanSessions() namespace { -void waitPendingSync(vespalib::Monitor &syncMonitor, bool &pendingSync) +void +waitPendingSync(vespalib::Monitor &syncMonitor, bool &pendingSync) { MonitorGuard guard(syncMonitor); while (pendingSync) { @@ -302,7 +308,8 @@ void Domain::commit(const Packet & packet) cleanSessions(); } -bool Domain::erase(SerialNum to) +bool +Domain::erase(SerialNum to) { bool retval(true); /// Do not erase the last element @@ -321,8 +328,9 @@ bool Domain::erase(SerialNum to) return retval; } -int Domain::visit(const Domain::SP & domain, SerialNum from, SerialNum to, - std::unique_ptr<Session::Destination> dest) +int +Domain::visit(const Domain::SP & domain, SerialNum from, SerialNum to, + std::unique_ptr<Session::Destination> dest) { assert(this == domain.get()); cleanSessions(); @@ -334,7 +342,8 @@ int Domain::visit(const Domain::SP & domain, SerialNum from, SerialNum to, return id; } -int Domain::startSession(int sessionId) +int +Domain::startSession(int sessionId) { int retval(-1); LockGuard guard(_sessionLock); @@ -350,7 +359,8 @@ int Domain::startSession(int sessionId) return retval; } -int Domain::closeSession(int sessionId) +int +Domain::closeSession(int sessionId) { _commitExecutor.sync(); int retval(-1); diff --git a/searchlib/src/vespa/searchlib/transactionlog/session.cpp b/searchlib/src/vespa/searchlib/transactionlog/session.cpp index e703c32484f..dda840808ce 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/session.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/session.cpp @@ -31,7 +31,7 @@ Session::VisitTask::run() bool Session::visit(FastOS_FileInterface & file, DomainPart & dp) { - Packet packet; + Packet packet(size_t(-1)); bool more(false); if (dp.isClosed()) { more = dp.visit(file, _range, packet); diff --git 
a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp index a3528c4f615..caef792704a 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp @@ -18,6 +18,9 @@ using vespalib::make_string; using vespalib::stringref; using vespalib::IllegalArgumentException; using search::common::FileHeaderContext; +using std::make_shared; +using std::runtime_error; +using namespace std::chrono_literals; namespace search::transactionlog { @@ -31,10 +34,10 @@ class SyncHandler : public FNET_Task SerialNum _syncTo; public: - SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req,const Domain::SP &domain, + SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req, const Domain::SP &domain, const TransLogServer::Session::SP &session, SerialNum syncTo); - ~SyncHandler(); + ~SyncHandler() override; void PerformTask() override; }; @@ -157,17 +160,17 @@ bool TransLogServer::onStop() { LOG(info, "Stopping TLS"); - _reqQ.push(NULL); + _reqQ.push(nullptr); return true; } void TransLogServer::run() { - FRT_RPCRequest *req(NULL); + FRT_RPCRequest *req(nullptr); bool hasPacket(false); do { - for (req = NULL; (hasPacket = _reqQ.pop(req, 60000)) && (req != NULL); req = NULL) { + for (req = nullptr; (hasPacket = _reqQ.pop(req, 60000)) && (req != nullptr); req = nullptr) { bool immediate = true; if (strcmp(req->GetMethodName(), "domainSessionClose") == 0) { domainSessionClose(req); @@ -675,7 +678,7 @@ TransLogServer::finiSession(FRT_RPCRequest *req) { FNET_Connection *conn = req->GetConnection(); void *vctx = conn->GetContext()._value.VOIDP; - conn->GetContextPT()->_value.VOIDP = NULL; + conn->GetContextPT()->_value.VOIDP = nullptr; Session::SP *sessionspp = static_cast<Session::SP *>(vctx); delete sessionspp; } @@ -696,7 +699,7 @@ TransLogServer::domainSync(FRT_RPCRequest *req) Domain::SP domain(findDomain(domainName)); 
Session::SP session(getSession(req)); - if (domain.get() == nullptr) { + if ( ! domain) { FRT_Values &rvals = *req->GetReturn(); rvals.AddInt32(0); rvals.AddInt64(0); diff --git a/staging_vespalib/src/vespa/vespalib/net/generic_state_handler.cpp b/staging_vespalib/src/vespa/vespalib/net/generic_state_handler.cpp index a18c032fdff..b8fc432e95b 100644 --- a/staging_vespalib/src/vespa/vespalib/net/generic_state_handler.cpp +++ b/staging_vespalib/src/vespa/vespalib/net/generic_state_handler.cpp @@ -2,6 +2,7 @@ #include "generic_state_handler.h" #include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/data/simple_buffer.h> namespace vespalib { diff --git a/storage/src/tests/storageserver/statereportertest.cpp b/storage/src/tests/storageserver/statereportertest.cpp index dc8094275d1..a7d18b21516 100644 --- a/storage/src/tests/storageserver/statereportertest.cpp +++ b/storage/src/tests/storageserver/statereportertest.cpp @@ -10,8 +10,8 @@ #include <tests/common/dummystoragelink.h> #include <vespa/config/common/exceptions.h> #include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <vespa/vespalib/gtest/gtest.h> -#include <vespa/vespalib/util/time.h> #include <thread> #include <vespa/log/log.h> diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java index 4d50905da7b..62c15fcea27 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/FeedClientFactory.java @@ -5,6 +5,7 @@ package com.yahoo.vespa.http.client; import com.yahoo.vespa.http.client.config.SessionParams; import com.yahoo.vespa.http.client.core.api.FeedClientImpl; +import java.time.Clock; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadFactory; @@ 
-24,7 +25,7 @@ public class FeedClientFactory { * @return newly created FeedClient API object. */ public static FeedClient create(SessionParams sessionParams, FeedClient.ResultCallback resultCallback) { - return new FeedClientImpl(sessionParams, resultCallback, createTimeoutExecutor()); + return new FeedClientImpl(sessionParams, resultCallback, createTimeoutExecutor(), Clock.systemUTC()); } static ScheduledThreadPoolExecutor createTimeoutExecutor() { diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java index 473b9494ba4..16374ec07cc 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/Result.java @@ -78,7 +78,6 @@ public class Result { private final Endpoint endpoint; private final Exception exception; private final String traceMessage; - private final long timeStampMillis = System.currentTimeMillis(); public Detail(Endpoint endpoint, ResultType resultType, String traceMessage, Exception e) { this.endpoint = endpoint; @@ -133,7 +132,6 @@ public class Result { b.append(" trace='").append(traceMessage).append("'"); if (endpoint != null) b.append(" endpoint=").append(endpoint); - b.append(" resultTimeLocally=").append(timeStampMillis).append("\n"); return b.toString(); } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/SessionFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/SessionFactory.java index b03a2541cd0..b7423f75c87 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/SessionFactory.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/SessionFactory.java @@ -5,6 +5,7 @@ import com.yahoo.vespa.http.client.config.Cluster; import com.yahoo.vespa.http.client.config.Endpoint; import com.yahoo.vespa.http.client.config.SessionParams; +import java.time.Clock; import 
java.util.concurrent.Executors; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.ThreadFactory; @@ -30,7 +31,7 @@ public final class SessionFactory { @SuppressWarnings("deprecation") static Session createInternal(SessionParams params) { - return new com.yahoo.vespa.http.client.core.api.SessionImpl(params, createTimeoutExecutor()); + return new com.yahoo.vespa.http.client.core.api.SessionImpl(params, createTimeoutExecutor(), Clock.systemUTC()); } static ScheduledThreadPoolExecutor createTimeoutExecutor() { diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java index 1accbd51ac7..2417a4acf71 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java @@ -42,6 +42,7 @@ public final class ConnectionParams { private int maxRetries = 100; private long minTimeBetweenRetriesMs = 700; private boolean dryRun = false; + private boolean runThreads = true; private int traceLevel = 0; private int traceEveryXOperation = 0; private boolean printTraceToStdErr = true; @@ -191,10 +192,8 @@ public final class ConnectionParams { } /** - * Don't send data to gateway, just pretend that everything is fine. - * - * @param dryRun true if enabled. - * @return pointer to builder. + * Set to true to skip making network connections and instead + * let requests complete successfully with no effect. */ public Builder setDryRun(boolean dryRun) { this.dryRun = dryRun; @@ -202,6 +201,15 @@ public final class ConnectionParams { } /** + * Set to false to skip starting io threads, such that any operation must be driven by a calling thread. + * Useful for testing. 
+ */ + public Builder setRunThreads(boolean runThreads) { + this.runThreads = runThreads; + return this; + } + + /** * Set the min time between retries when temporarily failing against a gateway. * * @param minTimeBetweenRetries the min time value @@ -274,6 +282,7 @@ public final class ConnectionParams { maxRetries, minTimeBetweenRetriesMs, dryRun, + runThreads, traceLevel, traceEveryXOperation, printTraceToStdErr, @@ -293,6 +302,8 @@ public final class ConnectionParams { return dryRun; } + public boolean runThreads() { return runThreads; } + public int getMaxRetries() { return maxRetries; } @@ -330,6 +341,7 @@ public final class ConnectionParams { public Path getCertificate() { return certificate; } public Path getCaCertificates() { return caCertificates; } } + private final SSLContext sslContext; private final Path privateKey; private final Path certificate; @@ -344,6 +356,7 @@ public final class ConnectionParams { private final int maxRetries; private final long minTimeBetweenRetriesMs; private final boolean dryRun; + private final boolean runThreads; private final int traceLevel; private final int traceEveryXOperation; private final boolean printTraceToStdErr; @@ -363,6 +376,7 @@ public final class ConnectionParams { int maxRetries, long minTimeBetweenRetriesMs, boolean dryRun, + boolean runThreads, int traceLevel, int traceEveryXOperation, boolean printTraceToStdErr, @@ -384,6 +398,7 @@ public final class ConnectionParams { this.maxRetries = maxRetries; this.minTimeBetweenRetriesMs = minTimeBetweenRetriesMs; this.dryRun = dryRun; + this.runThreads = runThreads; this.traceLevel = traceLevel; this.traceEveryXOperation = traceEveryXOperation; this.printTraceToStdErr = printTraceToStdErr; @@ -435,6 +450,8 @@ public final class ConnectionParams { return dryRun; } + public boolean runThreads() { return runThreads; } + public int getTraceLevel() { return traceLevel; } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/FeedParams.java 
b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/FeedParams.java index d623db3834c..200bedb90da 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/FeedParams.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/FeedParams.java @@ -172,6 +172,10 @@ public final class FeedParams { return this; } + /** + * Sets the number of milliseconds until we respond with a timeout for a document operation + * if we still have not received a response. + */ public Builder setLocalQueueTimeOut(long timeOutMs) { this.localQueueTimeOut = timeOutMs; return this; diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/SessionParams.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/SessionParams.java index 3131206f148..bf07e3ea634 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/SessionParams.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/SessionParams.java @@ -141,13 +141,12 @@ public final class SessionParams { private final ErrorReporter errorReport; private int throttlerMinSize; - private SessionParams( - Collection<Cluster> clusters, - FeedParams feedParams, - ConnectionParams connectionParams, - int clientQueueSize, - ErrorReporter errorReporter, - int throttlerMinSize) { + private SessionParams(Collection<Cluster> clusters, + FeedParams feedParams, + ConnectionParams connectionParams, + int clientQueueSize, + ErrorReporter errorReporter, + int throttlerMinSize) { this.clusters = Collections.unmodifiableList(new ArrayList<>(clusters)); this.feedParams = feedParams; this.connectionParams = connectionParams; diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/Document.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/Document.java index bc38155d07a..98fd2f9da84 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/Document.java +++ 
b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/Document.java @@ -7,53 +7,53 @@ import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.charset.CharacterCodingException; import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.util.Objects; import java.util.concurrent.ThreadLocalRandom; /** + * A document operation + * * @author Einar M R Rosenvinge */ final public class Document { private final String documentId; private final ByteBuffer data; - private final long createTimeMillis = System.currentTimeMillis(); - // This is initialized lazily to reduce work on calling thread (which is the thread calling the API). + private final Instant createTime; + // This is initialized lazily to reduce work on calling thread (which is the thread calling the API) private String operationId = null; private final Object context; - private long queueInsertTimestampMillis; + private Instant queueInsertTime; - public Document(String documentId, byte[] data, Object context) { - this.documentId = documentId; - this.context = context; - this.data = ByteBuffer.wrap(data); + public Document(String documentId, byte[] data, Object context, Instant createTime) { + this(documentId, null, ByteBuffer.wrap(data), context, createTime); } - public Document(String documentId, String operationId, CharSequence data, Object context) { + public Document(String documentId, String operationId, CharSequence data, Object context, Instant createTime) { + this(documentId, operationId, encode(data, documentId), context, createTime); + } + + private Document(String documentId, String operationId, ByteBuffer data, Object context, Instant createTime) { this.documentId = documentId; this.operationId = operationId; + this.data = data; this.context = context; - try { - this.data = StandardCharsets.UTF_8.newEncoder().encode(CharBuffer.wrap(data)); - } catch (CharacterCodingException e) { - throw new RuntimeException("Error encoding document data into UTF8 " 
+ documentId, e); - } + this.createTime = Objects.requireNonNull(createTime, "createTime cannot be null"); + this.queueInsertTime = createTime; } - public void resetQueueTime() { - queueInsertTimestampMillis = System.currentTimeMillis(); + public void setQueueInsertTime(Instant queueInsertTime) { + this.queueInsertTime = queueInsertTime; } - public long timeInQueueMillis() { - return System.currentTimeMillis() - queueInsertTimestampMillis; - } + public Instant getQueueInsertTime() { return queueInsertTime; } public CharSequence getDataAsString() { return StandardCharsets.UTF_8.decode(data.asReadOnlyBuffer()); } - public Object getContext() { - return context; - } + public Object getContext() { return context; } public static class DocumentException extends IOException { private static final long serialVersionUID = 29832833292L; @@ -63,9 +63,7 @@ final public class Document { } } - public String getDocumentId() { - return documentId; - } + public String getDocumentId() { return documentId; } public ByteBuffer getData() { return data.asReadOnlyBuffer(); @@ -75,9 +73,7 @@ final public class Document { return data.remaining(); } - public long createTimeMillis() { - return createTimeMillis; - } + public Instant createTime() { return createTime; } public String getOperationId() { if (operationId == null) { @@ -89,4 +85,12 @@ final public class Document { @Override public String toString() { return "document '" + documentId + "'"; } + private static ByteBuffer encode(CharSequence data, String documentId) { + try { + return StandardCharsets.UTF_8.newEncoder().encode(CharBuffer.wrap(data)); + } catch (CharacterCodingException e) { + throw new RuntimeException("Error encoding document data into UTF8 " + documentId, e); + } + } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java index 7238a0c4ba7..a950cb545de 100644 --- 
a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/FeedClientImpl.java @@ -11,6 +11,7 @@ import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor; import java.nio.charset.CharsetEncoder; import java.nio.charset.CodingErrorAction; import java.nio.charset.StandardCharsets; +import java.time.Clock; import java.time.Instant; import java.util.Optional; import java.util.concurrent.ScheduledThreadPoolExecutor; @@ -23,25 +24,28 @@ import java.util.concurrent.TimeUnit; */ public class FeedClientImpl implements FeedClient { + private final Clock clock; private final OperationProcessor operationProcessor; private final long closeTimeoutMs; private final long sleepTimeMs = 500; public FeedClientImpl(SessionParams sessionParams, ResultCallback resultCallback, - ScheduledThreadPoolExecutor timeoutExecutor) { - this.closeTimeoutMs = (10 + 3 * sessionParams.getConnectionParams().getMaxRetries()) * ( - sessionParams.getFeedParams().getServerTimeout(TimeUnit.MILLISECONDS) + - sessionParams.getFeedParams().getClientTimeout(TimeUnit.MILLISECONDS)); + ScheduledThreadPoolExecutor timeoutExecutor, + Clock clock) { + this.clock = clock; + this.closeTimeoutMs = (10 + 3 * sessionParams.getConnectionParams().getMaxRetries()) * + (sessionParams.getFeedParams().getServerTimeout(TimeUnit.MILLISECONDS) + + sessionParams.getFeedParams().getClientTimeout(TimeUnit.MILLISECONDS)); this.operationProcessor = new OperationProcessor( - new IncompleteResultsThrottler( - sessionParams.getThrottlerMinSize(), - sessionParams.getClientQueueSize(), - ()->System.currentTimeMillis(), - new ThrottlePolicy()), + new IncompleteResultsThrottler(sessionParams.getThrottlerMinSize(), + sessionParams.getClientQueueSize(), + clock, + new ThrottlePolicy()), resultCallback, sessionParams, - timeoutExecutor); + timeoutExecutor, + clock); } @Override @@ -50,7 +54,7 @@ public class 
FeedClientImpl implements FeedClient { charsetEncoder.onMalformedInput(CodingErrorAction.REPORT); charsetEncoder.onUnmappableCharacter(CodingErrorAction.REPORT); - Document document = new Document(documentId, operationId, documentData, context); + Document document = new Document(documentId, operationId, documentData, context, clock.instant()); operationProcessor.sendDocument(document); } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/MultiClusterSessionOutputStream.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/MultiClusterSessionOutputStream.java index bf55a46277d..e09cecf7161 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/MultiClusterSessionOutputStream.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/MultiClusterSessionOutputStream.java @@ -6,6 +6,7 @@ import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor; import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.time.Clock; /** * Class for wiring up the Session API. It is the return value of stream() in the Session API. 
@@ -17,19 +18,21 @@ class MultiClusterSessionOutputStream extends ByteArrayOutputStream { private final CharSequence documentId; private final OperationProcessor operationProcessor; private final Object context; + private final Clock clock; - public MultiClusterSessionOutputStream( - CharSequence documentId, - OperationProcessor operationProcessor, - Object context) { + public MultiClusterSessionOutputStream(CharSequence documentId, + OperationProcessor operationProcessor, + Object context, + Clock clock) { this.documentId = documentId; this.context = context; this.operationProcessor = operationProcessor; + this.clock = clock; } @Override public void close() throws IOException { - Document document = new Document(documentId.toString(), toByteArray(), context); + Document document = new Document(documentId.toString(), toByteArray(), context, clock.instant()); operationProcessor.sendDocument(document); super.close(); } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java index a5c97351347..a68d7eb7524 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/api/SessionImpl.java @@ -9,6 +9,7 @@ import com.yahoo.vespa.http.client.core.operationProcessor.IncompleteResultsThro import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor; import java.io.OutputStream; +import java.time.Clock; import java.util.concurrent.BlockingQueue; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.ScheduledThreadPoolExecutor; @@ -23,14 +24,15 @@ public class SessionImpl implements com.yahoo.vespa.http.client.Session { private final OperationProcessor operationProcessor; private final BlockingQueue<Result> resultQueue = new LinkedBlockingQueue<>(); + private final Clock clock; - - public 
SessionImpl(SessionParams sessionParams, ScheduledThreadPoolExecutor timeoutExecutor) { + public SessionImpl(SessionParams sessionParams, ScheduledThreadPoolExecutor timeoutExecutor, Clock clock) { + this.clock = clock; this.operationProcessor = new OperationProcessor( new IncompleteResultsThrottler( sessionParams.getThrottlerMinSize(), sessionParams.getClientQueueSize(), - ()->System.currentTimeMillis(), + clock, new ThrottlePolicy()), new FeedClient.ResultCallback() { @Override @@ -39,12 +41,13 @@ public class SessionImpl implements com.yahoo.vespa.http.client.Session { } }, sessionParams, - timeoutExecutor); + timeoutExecutor, + clock); } @Override public OutputStream stream(CharSequence documentId) { - return new MultiClusterSessionOutputStream(documentId, operationProcessor, null); + return new MultiClusterSessionOutputStream(documentId, operationProcessor, null, clock); } @Override diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java index d510ce4b7ea..a46b2e67fe1 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java @@ -20,6 +20,7 @@ import org.apache.http.client.HttpClient; import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.InputStreamEntity; +import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.message.BasicHeader; @@ -28,10 +29,10 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; -import java.net.InetAddress; -import java.net.UnknownHostException; import 
java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.Instant; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -49,84 +50,89 @@ import java.util.zip.GZIPOutputStream; */ class ApacheGatewayConnection implements GatewayConnection { - private static Logger log = Logger.getLogger(ApacheGatewayConnection.class.getName()); + private static final Logger log = Logger.getLogger(ApacheGatewayConnection.class.getName()); private static final ObjectMapper mapper = new ObjectMapper(); private static final String PATH = "/reserved-for-internal-use/feedapi?"; - private final List<Integer> SUPPORTED_VERSIONS = new ArrayList<>(); private static final byte[] START_OF_FEED_XML = "<vespafeed>\n".getBytes(StandardCharsets.UTF_8); private static final byte[] END_OF_FEED_XML = "\n</vespafeed>\n".getBytes(StandardCharsets.UTF_8); private static final byte[] START_OF_FEED_JSON = "[".getBytes(StandardCharsets.UTF_8); private static final byte[] END_OF_FEED_JSON = "]".getBytes(StandardCharsets.UTF_8); + + private final List<Integer> supportedVersions = new ArrayList<>(); private final byte[] startOfFeed; private final byte[] endOfFeed; private final Endpoint endpoint; private final FeedParams feedParams; private final String clusterSpecificRoute; private final ConnectionParams connectionParams; - private HttpClient httpClient; + private CloseableHttpClient httpClient; + private Instant connectionTime = null; + private Instant lastPollTime = null; private String sessionId; private final String clientId; private int negotiatedVersion = -1; private final HttpClientFactory httpClientFactory; private final String shardingKey = UUID.randomUUID().toString().substring(0, 5); - - ApacheGatewayConnection( - Endpoint endpoint, - FeedParams feedParams, - String clusterSpecificRoute, - ConnectionParams connectionParams, - HttpClientFactory httpClientFactory, - String clientId) { - SUPPORTED_VERSIONS.add(3); - 
this.endpoint = validate(endpoint); + private final Clock clock; + + ApacheGatewayConnection(Endpoint endpoint, + FeedParams feedParams, + String clusterSpecificRoute, + ConnectionParams connectionParams, + HttpClientFactory httpClientFactory, + String clientId, + Clock clock) { + supportedVersions.add(3); + this.endpoint = endpoint; this.feedParams = feedParams; this.clusterSpecificRoute = clusterSpecificRoute; this.httpClientFactory = httpClientFactory; this.connectionParams = connectionParams; this.httpClient = null; - boolean isJson = feedParams.getDataFormat() == FeedParams.DataFormat.JSON_UTF8; - if (isJson) { + this.clientId = clientId; + this.clock = clock; + + if (feedParams.getDataFormat() == FeedParams.DataFormat.JSON_UTF8) { startOfFeed = START_OF_FEED_JSON; endOfFeed = END_OF_FEED_JSON; } else { startOfFeed = START_OF_FEED_XML; endOfFeed = END_OF_FEED_XML; } - this.clientId = clientId; - if (this.clientId == null) - throw new IllegalArgumentException("Got no client Id."); } - private static Endpoint validate(Endpoint endpoint) { - try { - InetAddress.getByName(endpoint.getHostname()); - return endpoint; - } - catch (UnknownHostException e) { - throw new IllegalArgumentException("Unknown host: " + endpoint); - } + @Override + public InputStream write(List<Document> docs) throws ServerResponseException, IOException { + return write(docs, false, connectionParams.getUseCompression()); } @Override - public InputStream writeOperations(List<Document> docs) throws ServerResponseException, IOException { - return write(docs, false, connectionParams.getUseCompression()); + public InputStream poll() throws ServerResponseException, IOException { + lastPollTime = clock.instant(); + return write(Collections.<Document>emptyList(), false, false); } @Override + public Instant lastPollTime() { return lastPollTime; } + + @Override public InputStream drain() throws ServerResponseException, IOException { - return write(Collections.<Document>emptyList(), true /* drain */, 
false /* use compression */); + return write(Collections.<Document>emptyList(), true, false); } @Override public boolean connect() { - log.fine("Attempting to connect to " + endpoint); - if (httpClient != null) { + log.fine(() -> "Attempting to connect to " + endpoint); + if (httpClient != null) log.log(Level.WARNING, "Previous httpClient still exists."); - } httpClient = httpClientFactory.createClient(); + connectionTime = clock.instant(); return httpClient != null; } + @Override + public Instant connectionTime() { return connectionTime; } + // Protected for easier testing only. protected static InputStreamEntity zipAndCreateEntity(final InputStream inputStream) throws IOException { byte[] buffer = new byte[4096]; @@ -184,7 +190,7 @@ class ApacheGatewayConnection implements GatewayConnection { private HttpPost createPost(boolean drain, boolean useCompression, boolean isHandshake) { HttpPost httpPost = new HttpPost(createUri()); - for (int v : SUPPORTED_VERSIONS) { + for (int v : supportedVersions) { httpPost.addHeader(Headers.VERSION, "" + v); } if (sessionId != null) { @@ -194,11 +200,7 @@ class ApacheGatewayConnection implements GatewayConnection { httpPost.setHeader(Headers.CLIENT_ID, clientId); } httpPost.setHeader(Headers.SHARDING_KEY, shardingKey); - if (drain) { - httpPost.setHeader(Headers.DRAIN, "true"); - } else { - httpPost.setHeader(Headers.DRAIN, "false"); - } + httpPost.setHeader(Headers.DRAIN, drain ? 
"true" : "false"); if (clusterSpecificRoute != null) { httpPost.setHeader(Headers.ROUTE, feedParams.getRoute()); } else { @@ -246,13 +248,9 @@ class ApacheGatewayConnection implements GatewayConnection { private InputStream executePost(HttpPost httpPost) throws ServerResponseException, IOException { HttpResponse response; try { - if (httpClient == null) { + if (httpClient == null) throw new IOException("Trying to executePost while not having a connection/http client"); - } response = httpClient.execute(httpPost); - } catch (IOException e) { - httpPost.abort(); - throw e; } catch (Exception e) { httpPost.abort(); throw e; @@ -270,18 +268,14 @@ class ApacheGatewayConnection implements GatewayConnection { private void verifyServerResponseCode(HttpResponse response) throws ServerResponseException { StatusLine statusLine = response.getStatusLine(); + int statusCode = statusLine.getStatusCode(); + // We use code 261-299 to report errors related to internal transitive errors that the tenants should not care // about to avoid masking more serious errors. - int statusCode = statusLine.getStatusCode(); - if (statusCode > 199 && statusCode < 260) { - return; - } - if (statusCode == 299) { - throw new ServerResponseException(429, "Too many requests."); - } - String message = tryGetDetailedErrorMessage(response) - .orElseGet(statusLine::getReasonPhrase); - throw new ServerResponseException(statusLine.getStatusCode(), message); + if (statusCode > 199 && statusCode < 260) return; + if (statusCode == 299) throw new ServerResponseException(429, "Too many requests."); + throw new ServerResponseException(statusCode, + tryGetDetailedErrorMessage(response).orElseGet(statusLine::getReasonPhrase)); } private static Optional<String> tryGetDetailedErrorMessage(HttpResponse response) { @@ -305,7 +299,7 @@ class ApacheGatewayConnection implements GatewayConnection { if (negotiatedVersion == 3) { if (clientId == null || !clientId.equals(serverHeaderVal)) { String message = "Running using v3. 
However, server responds with different session " + - "than client has set; " + serverHeaderVal + " vs client code " + clientId; + "than client has set; " + serverHeaderVal + " vs client code " + clientId; log.severe(message); throw new ServerResponseException(message); } @@ -314,14 +308,12 @@ class ApacheGatewayConnection implements GatewayConnection { if (sessionId == null) { //this must be the first request log.finer("Got session ID from server: " + serverHeaderVal); this.sessionId = serverHeaderVal; - return; } else { if (!sessionId.equals(serverHeaderVal)) { - log.info("Request has been routed to a server which does not recognize the client session." - + " Most likely cause is upgrading of cluster, transitive error."); - throw new ServerResponseException( - "Session ID received from server ('" + serverHeaderVal - + "') does not match cached session ID ('" + sessionId + "')"); + log.info("Request has been routed to a server which does not recognize the client session." + + " Most likely cause is upgrading of cluster, transitive error."); + throw new ServerResponseException("Session ID received from server ('" + serverHeaderVal + + "') does not match cached session ID ('" + sessionId + "')"); } } } @@ -336,9 +328,9 @@ class ApacheGatewayConnection implements GatewayConnection { } catch (NumberFormatException nfe) { throw new ServerResponseException("Got bad protocol version from server: " + nfe.getMessage()); } - if (!SUPPORTED_VERSIONS.contains(serverVersion)) { + if (!supportedVersions.contains(serverVersion)) { throw new ServerResponseException("Unsupported version: " + serverVersion - + ". Supported versions: " + SUPPORTED_VERSIONS); + + ". 
Supported versions: " + supportedVersions); } if (negotiatedVersion == -1) { if (log.isLoggable(Level.FINE)) { @@ -387,6 +379,13 @@ class ApacheGatewayConnection implements GatewayConnection { @Override public void close() { + try { + if (httpClient != null) + httpClient.close(); + } + catch (IOException e) { + log.log(Level.WARNING, "Failed closing HTTP client", e); + } httpClient = null; } @@ -403,7 +402,7 @@ class ApacheGatewayConnection implements GatewayConnection { this.useSsl = useSsl; } - public HttpClient createClient() { + public CloseableHttpClient createClient() { HttpClientBuilder clientBuilder; if (connectionParams.useTlsConfigFromEnvironment()) { clientBuilder = VespaHttpClientBuilder.create(); @@ -428,12 +427,9 @@ class ApacheGatewayConnection implements GatewayConnection { } clientBuilder.setMaxConnPerRoute(1); clientBuilder.setMaxConnTotal(1); - clientBuilder.setConnectionTimeToLive(connectionParams.getConnectionTimeToLive().getSeconds(), TimeUnit.SECONDS); clientBuilder.setUserAgent(String.format("vespa-http-client (%s)", Vtag.currentVersion)); clientBuilder.setDefaultHeaders(Collections.singletonList(new BasicHeader(Headers.CLIENT_VERSION, Vtag.currentVersion))); clientBuilder.disableContentCompression(); - // Try to disable the disabling to see if system tests become stable again. 
- // clientBuilder.disableAutomaticRetries(); RequestConfig.Builder requestConfigBuilder = RequestConfig.custom(); requestConfigBuilder.setSocketTimeout(0); if (connectionParams.getProxyHost() != null) { @@ -441,17 +437,16 @@ class ApacheGatewayConnection implements GatewayConnection { } clientBuilder.setDefaultRequestConfig(requestConfigBuilder.build()); - log.fine("Creating HttpClient: " + " ConnectionTimeout " - + " SocketTimeout 0 secs " - + " proxyhost (can be null) " + connectionParams.getProxyHost() - + ":" + connectionParams.getProxyPort() + log.fine(() -> "Creating HttpClient:" + + " ConnectionTimeout " + connectionParams.getConnectionTimeToLive().getSeconds() + " seconds" + + " proxyhost (can be null) " + connectionParams.getProxyHost() + ":" + connectionParams.getProxyPort() + (useSsl ? " using ssl " : " not using ssl") ); return clientBuilder.build(); } } - // Note: Using deprecated setSslcontext() to allow httpclient 4.4 on classpath (e.g unexpected Maven dependency resolution for test classpath) + // Note: Using deprecated setSslContext() to allow httpclient 4.4 on classpath (e.g unexpected Maven dependency resolution for test classpath) @SuppressWarnings("deprecation") private static void setSslContext(HttpClientBuilder builder, SSLContext sslContext) { builder.setSslcontext(sslContext); diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionFactory.java new file mode 100644 index 00000000000..31ec8aa06a2 --- /dev/null +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionFactory.java @@ -0,0 +1,63 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.http.client.core.communication; + +import com.yahoo.vespa.http.client.config.ConnectionParams; +import com.yahoo.vespa.http.client.config.Endpoint; +import com.yahoo.vespa.http.client.config.FeedParams; + +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.time.Clock; +import java.util.Objects; + +/** + * @author bratseth + */ +public class ApacheGatewayConnectionFactory implements GatewayConnectionFactory { + + private final Endpoint endpoint; + private final FeedParams feedParams; + private final String clusterSpecificRoute; + private final ConnectionParams connectionParams; + private final ApacheGatewayConnection.HttpClientFactory httpClientFactory; + private final String clientId; + private final Clock clock; + + public ApacheGatewayConnectionFactory(Endpoint endpoint, + FeedParams feedParams, + String clusterSpecificRoute, + ConnectionParams connectionParams, + ApacheGatewayConnection.HttpClientFactory httpClientFactory, + String clientId, + Clock clock) { + this.endpoint = validate(endpoint); + this.feedParams = feedParams; + this.clusterSpecificRoute = clusterSpecificRoute; + this.httpClientFactory = httpClientFactory; + this.connectionParams = connectionParams; + this.clientId = Objects.requireNonNull(clientId, "clientId cannot be null"); + this.clock = clock; + } + + private static Endpoint validate(Endpoint endpoint) { + try { + InetAddress.getByName(endpoint.getHostname()); + return endpoint; + } + catch (UnknownHostException e) { + throw new IllegalArgumentException("Unknown host: " + endpoint); + } + } + + @Override + public GatewayConnection newConnection() { + return new ApacheGatewayConnection(endpoint, + feedParams, + clusterSpecificRoute, + connectionParams, + httpClientFactory, + clientId, + clock); + } + +} diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ClusterConnection.java 
b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ClusterConnection.java index d254cd0bab8..8e55e59b3f4 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ClusterConnection.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ClusterConnection.java @@ -14,7 +14,10 @@ import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor; import java.io.IOException; import java.io.StringWriter; +import java.time.Clock; +import java.time.Duration; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.concurrent.TimeUnit; @@ -43,7 +46,8 @@ public class ClusterConnection implements AutoCloseable { Cluster cluster, int clusterId, int clientQueueSizePerCluster, - ScheduledThreadPoolExecutor timeoutExecutor) { + ScheduledThreadPoolExecutor timeoutExecutor, + Clock clock) { if (cluster.getEndpoints().isEmpty()) throw new IllegalArgumentException("At least a single endpoint is required in " + cluster); @@ -53,7 +57,7 @@ public class ClusterConnection implements AutoCloseable { throw new IllegalArgumentException("At least 1 persistent connection per endpoint is required in " + cluster); int maxInFlightPerSession = Math.max(1, feedParams.getMaxInFlightRequests() / totalNumberOfEndpointsInThisCluster); - documentQueue = new DocumentQueue(clientQueueSizePerCluster); + documentQueue = new DocumentQueue(clientQueueSizePerCluster, clock); ioThreadGroup = operationProcessor.getIoThreadGroup(); singleEndpoint = cluster.getEndpoints().size() == 1 ? 
cluster.getEndpoints().get(0) : null; Double idlePollFrequency = feedParams.getIdlePollFrequency(); @@ -66,28 +70,33 @@ public class ClusterConnection implements AutoCloseable { timeoutExecutor, feedParams.getServerTimeout(TimeUnit.MILLISECONDS) + feedParams.getClientTimeout(TimeUnit.MILLISECONDS)); for (int i = 0; i < connectionParams.getNumPersistentConnectionsPerEndpoint(); i++) { - GatewayConnection gatewayConnection; + GatewayConnectionFactory connectionFactory; if (connectionParams.isDryRun()) { - gatewayConnection = new DryRunGatewayConnection(endpoint); + connectionFactory = new DryRunGatewayConnectionFactory(endpoint, clock); } else { - gatewayConnection = new ApacheGatewayConnection(endpoint, - feedParams, - cluster.getRoute(), - connectionParams, - new ApacheGatewayConnection.HttpClientFactory(connectionParams, endpoint.isUseSsl()), - operationProcessor.getClientId() + connectionFactory = new ApacheGatewayConnectionFactory(endpoint, + feedParams, + cluster.getRoute(), + connectionParams, + new ApacheGatewayConnection.HttpClientFactory(connectionParams, endpoint.isUseSsl()), + operationProcessor.getClientId(), + clock ); } IOThread ioThread = new IOThread(operationProcessor.getIoThreadGroup(), + endpoint, endpointResultQueue, - gatewayConnection, + connectionFactory, clusterId, feedParams.getMaxChunkSizeBytes(), maxInFlightPerSession, - feedParams.getLocalQueueTimeOut(), + Duration.ofMillis(feedParams.getLocalQueueTimeOut()), documentQueue, feedParams.getMaxSleepTimeMs(), - idlePollFrequency); + connectionParams.getConnectionTimeToLive(), + connectionParams.runThreads(), + idlePollFrequency, + clock); ioThreads.add(ioThread); } } @@ -160,6 +169,10 @@ public class ClusterConnection implements AutoCloseable { return stringWriter.toString(); } + public List<IOThread> ioThreads() { + return Collections.unmodifiableList(ioThreads); + } + @Override public boolean equals(Object o) { return (this == o) || (o instanceof ClusterConnection && clusterId == 
((ClusterConnection) o).clusterId); diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java index 16bf881963f..3536013e043 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DocumentQueue.java @@ -3,6 +3,8 @@ package com.yahoo.vespa.http.client.core.communication; import com.yahoo.vespa.http.client.core.Document; +import java.time.Clock; +import java.time.Duration; import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Deque; @@ -11,8 +13,8 @@ import java.util.Optional; import java.util.concurrent.TimeUnit; /** - * Document queue that only gives you document operations on documents for which there are no - * already in flight operations for. + * Shared document queue that gives clients operations on documents which do not have operations already in flight. + * This is multithread safe. 
* * @author dybis */ @@ -21,10 +23,12 @@ class DocumentQueue { private final Deque<Document> queue; private final int maxSize; private boolean closed = false; + private final Clock clock; - DocumentQueue(int maxSize) { + DocumentQueue(int maxSize, Clock clock) { this.maxSize = maxSize; this.queue = new ArrayDeque<>(maxSize); + this.clock = clock; } List<Document> removeAllDocuments() { @@ -39,7 +43,7 @@ class DocumentQueue { } void put(Document document, boolean calledFromIoThreadGroup) throws InterruptedException { - document.resetQueueTime(); + document.setQueueInsertTime(clock.instant()); synchronized (queue) { while (!closed && (queue.size() >= maxSize) && !calledFromIoThreadGroup) { queue.wait(); @@ -56,9 +60,9 @@ class DocumentQueue { synchronized (queue) { long remainingToWait = unit.toMillis(timeout); while (queue.isEmpty()) { - long startTime = System.currentTimeMillis(); + long startTime = clock.millis(); queue.wait(remainingToWait); - remainingToWait -= (System.currentTimeMillis() - startTime); + remainingToWait -= (clock.millis() - startTime); if (remainingToWait <= 0) { break; } @@ -106,16 +110,15 @@ class DocumentQueue { return previousState; } - Optional<Document> pollDocumentIfTimedoutInQueue(long localQueueTimeOut) { + Optional<Document> pollDocumentIfTimedoutInQueue(Duration localQueueTimeOut) { synchronized (queue) { - if (queue.isEmpty()) { - return Optional.empty(); - } + if (queue.isEmpty()) return Optional.empty(); + Document document = queue.peek(); - if (document.timeInQueueMillis() > localQueueTimeOut) { - return Optional.of(queue.poll()); - } - return Optional.empty(); + if (document.getQueueInsertTime().plus(localQueueTimeOut).isBefore(clock.instant())) + return Optional.ofNullable(queue.poll()); + else + return Optional.empty(); } } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnection.java 
b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnection.java index 23ab5e36e14..129fc000271 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnection.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnection.java @@ -5,13 +5,14 @@ import com.yahoo.vespa.http.client.config.Endpoint; import com.yahoo.vespa.http.client.core.Document; import com.yahoo.vespa.http.client.core.ErrorCode; import com.yahoo.vespa.http.client.core.OperationStatus; -import com.yahoo.vespa.http.client.core.ServerResponseException; import java.io.ByteArrayInputStream; -import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.List; /** @@ -22,40 +23,78 @@ import java.util.List; public class DryRunGatewayConnection implements GatewayConnection { private final Endpoint endpoint; + private final Clock clock; + private Instant connectionTime = null; + private Instant lastPollTime = null; - public DryRunGatewayConnection(Endpoint endpoint) { + /** Set to true to hold off responding with a result to any incoming operations until this is set false */ + private boolean hold = false; + private List<Document> held = new ArrayList<>(); + + public DryRunGatewayConnection(Endpoint endpoint, Clock clock) { this.endpoint = endpoint; + this.clock = clock; } @Override - public InputStream writeOperations(List<Document> docs) throws ServerResponseException, IOException { + public InputStream write(List<Document> docs) { StringBuilder result = new StringBuilder(); - for (Document doc : docs) { - OperationStatus operationStatus = new OperationStatus("ok", doc.getOperationId(), ErrorCode.OK, false, ""); - result.append(operationStatus.render()); + if (hold) { + held.addAll(docs); + } + else 
{ + for (Document doc : held) + result.append(okResponse(doc).render()); + held.clear(); + for (Document doc : docs) + result.append(okResponse(doc).render()); } return new ByteArrayInputStream(result.toString().getBytes(StandardCharsets.UTF_8)); } + public void hold(boolean hold) { + this.hold = hold; + } + + @Override + public InputStream poll() { + lastPollTime = clock.instant(); + return write(new ArrayList<>()); + } + @Override - public InputStream drain() throws ServerResponseException, IOException { - return writeOperations(new ArrayList<Document>()); + public Instant lastPollTime() { return lastPollTime; } + + @Override + public InputStream drain() { + return write(new ArrayList<>()); } @Override public boolean connect() { + connectionTime = clock.instant(); return true; } @Override + public Instant connectionTime() { return connectionTime; } + + @Override public Endpoint getEndpoint() { return endpoint; } @Override - public void handshake() throws ServerResponseException, IOException { } + public void handshake() { } @Override public void close() { } + /** Returns the document currently held in this */ + public List<Document> held() { return Collections.unmodifiableList(held); } + + private OperationStatus okResponse(Document document) { + return new OperationStatus("ok", document.getOperationId(), ErrorCode.OK, false, ""); + } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnectionFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnectionFactory.java new file mode 100644 index 00000000000..a234dba6b8e --- /dev/null +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/DryRunGatewayConnectionFactory.java @@ -0,0 +1,26 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.http.client.core.communication; + +import com.yahoo.vespa.http.client.config.Endpoint; + +import java.time.Clock; + +/** + * @author bratseth + */ +public class DryRunGatewayConnectionFactory implements GatewayConnectionFactory { + + private final Endpoint endpoint; + private final Clock clock; + + public DryRunGatewayConnectionFactory(Endpoint endpoint, Clock clock) { + this.endpoint = endpoint; + this.clock = clock; + } + + @Override + public GatewayConnection newConnection() { + return new DryRunGatewayConnection(endpoint, clock); + } + +} diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueue.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueue.java index cd146cf0e87..1dd8b3bf3ec 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueue.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueue.java @@ -15,24 +15,29 @@ import java.util.concurrent.TimeUnit; import java.util.logging.Logger; /** + * The shared queue of operation results. + * This is multithread safe. 
+ * * @author Einar M R Rosenvinge */ class EndpointResultQueue { - private static Logger log = Logger.getLogger(EndpointResultQueue.class.getName()); + private static final Logger log = Logger.getLogger(EndpointResultQueue.class.getName()); private final OperationProcessor operationProcessor; + + /** The currently in flight operations */ private final Map<String, TimerFuture> futureByOperation = new HashMap<>(); + private final Endpoint endpoint; private final int clusterId; private final ScheduledThreadPoolExecutor timer; private final long totalTimeoutMs; - EndpointResultQueue( - OperationProcessor operationProcessor, - Endpoint endpoint, - int clusterId, - ScheduledThreadPoolExecutor timer, - long totalTimeoutMs) { + EndpointResultQueue(OperationProcessor operationProcessor, + Endpoint endpoint, + int clusterId, + ScheduledThreadPoolExecutor timer, + long totalTimeoutMs) { this.operationProcessor = operationProcessor; this.endpoint = endpoint; this.clusterId = clusterId; @@ -64,25 +69,23 @@ class EndpointResultQueue { TimerFuture timerFuture = futureByOperation.remove(result.getOperationId()); if (timerFuture == null) { if (duplicateGivesWarning) { - log.warning( - "Result for ID '" + result.getOperationId() + "' received from '" + endpoint - + "', but we have no record of a sent operation. Either something is wrong on the server side " - + "(bad VIP usage?), or we have somehow received duplicate results, " - + "or operation was received _after_ client-side timeout."); + log.warning("Result for ID '" + result.getOperationId() + "' received from '" + endpoint + + "', but we have no record of a sent operation. 
Either something is wrong on the server side " + + "(bad VIP usage?), or we have somehow received duplicate results, " + + "or operation was received _after_ client-side timeout."); } return; } timerFuture.getFuture().cancel(false); } - //Called only from ScheduledThreadPoolExecutor thread in DocumentTimerTask.run(), see below + /** Called only from ScheduledThreadPoolExecutor thread in DocumentTimerTask.run(), see below */ private synchronized void timeout(String operationId) { TimerFuture timerFuture = futureByOperation.remove(operationId); if (timerFuture == null) { - log.finer( - "Timeout of operation '" + operationId + "', but operation " - + "not found in map. Result was probably received just-in-time from server, while timeout " - + "task could not be cancelled."); + log.finer("Timeout of operation '" + operationId + "', but operation " + + "not found in map. Result was probably received just-in-time from server, while timeout " + + "task could not be cancelled."); return; } EndpointResult endpointResult = EndPointResultFactory.createTransientError( @@ -108,6 +111,7 @@ class EndpointResultQueue { } private class DocumentTimerTask implements Runnable { + private final String operationId; private DocumentTimerTask(String operationId) { @@ -118,17 +122,21 @@ class EndpointResultQueue { public void run() { timeout(operationId); } + } - private class TimerFuture { + private static class TimerFuture { + private final ScheduledFuture<?> future; public TimerFuture(ScheduledFuture<?> future) { this.future = future; } + private ScheduledFuture<?> getFuture() { return future; } + } } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnection.java index 3e5bdfe3056..ce1edb83fa2 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnection.java +++ 
b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnection.java @@ -6,12 +6,23 @@ import com.yahoo.vespa.http.client.core.Document; import com.yahoo.vespa.http.client.core.ServerResponseException; import java.io.IOException; import java.io.InputStream; +import java.time.Instant; import java.util.List; public interface GatewayConnection { - InputStream writeOperations(List<Document> docs) throws ServerResponseException, IOException; + /** Returns the time this connected over the network, or null if not connected yet */ + Instant connectionTime(); + /** Returns the last time poll was called on this, or null if never */ + Instant lastPollTime(); + + InputStream write(List<Document> docs) throws ServerResponseException, IOException; + + /** Returns any operation results that are ready now */ + InputStream poll() throws ServerResponseException, IOException; + + /** Attempt to drain all outstanding operations, even if this leads to blocking */ InputStream drain() throws ServerResponseException, IOException; boolean connect(); diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnectionFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnectionFactory.java new file mode 100644 index 00000000000..d27aa850995 --- /dev/null +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/GatewayConnectionFactory.java @@ -0,0 +1,13 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.http.client.core.communication; + +/** + * Creates gateway connections on request + * + * @author bratseth + */ +public interface GatewayConnectionFactory { + + GatewayConnection newConnection(); + +} diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java index 0d916002964..2417208fba3 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/IOThread.java @@ -13,8 +13,13 @@ import com.yahoo.vespa.http.client.core.ServerResponseException; import java.io.IOException; import java.io.InputStream; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.Random; @@ -31,23 +36,40 @@ import java.util.logging.Logger; */ class IOThread implements Runnable, AutoCloseable { - private static Logger log = Logger.getLogger(IOThread.class.getName()); + private static final Logger log = Logger.getLogger(IOThread.class.getName()); + private final Endpoint endpoint; - private final GatewayConnection client; + private final GatewayConnectionFactory connectionFactory; private final DocumentQueue documentQueue; private final EndpointResultQueue resultQueue; + + /** The thread running this, or null if it does not run a thread (meaning tick() must be called from the outside) */ private final Thread thread; private final int clusterId; private final CountDownLatch running = new CountDownLatch(1); private final CountDownLatch stopSignal = new CountDownLatch(1); private final int maxChunkSizeBytes; private final int maxInFlightRequests; - private final long localQueueTimeOut; + private 
final Duration localQueueTimeOut; + private final Duration maxOldConnectionPollInterval; private final GatewayThrottler gatewayThrottler; + private final Duration connectionTimeToLive; private final long pollIntervalUS; + private final Clock clock; private final Random random = new Random(); - private enum ThreadState { DISCONNECTED, CONNECTED, SESSION_SYNCED }; + private GatewayConnection currentConnection; + private ConnectionState connectionState = ConnectionState.DISCONNECTED; + + /** + * Previous connections on which we have sent operations and are still waiting for the result + * (so all connections in this are in state SESSION_SYNCED). + * We need to drain results on the connection where they were sent to make sure we request results on + * the node which received the operation also when going through a VIP. + */ + private final List<GatewayConnection> oldConnections = new ArrayList<>(); + + private enum ConnectionState { DISCONNECTED, CONNECTED, SESSION_SYNCED }; private final AtomicInteger wrongSessionDetectedCounter = new AtomicInteger(0); private final AtomicInteger wrongVersionDetectedCounter = new AtomicInteger(0); private final AtomicInteger problemStatusCodeFromServerCounter = new AtomicInteger(0); @@ -59,70 +81,49 @@ class IOThread implements Runnable, AutoCloseable { private final AtomicInteger lastGatewayProcessTimeMillis = new AtomicInteger(0); IOThread(ThreadGroup ioThreadGroup, + Endpoint endpoint, EndpointResultQueue endpointResultQueue, - GatewayConnection client, + GatewayConnectionFactory connectionFactory, int clusterId, int maxChunkSizeBytes, int maxInFlightRequests, - long localQueueTimeOut, + Duration localQueueTimeOut, DocumentQueue documentQueue, long maxSleepTimeMs, - double idlePollFrequency) { + Duration connectionTimeToLive, + boolean runThreads, + double idlePollFrequency, + Clock clock) { + this.endpoint = endpoint; this.documentQueue = documentQueue; - this.endpoint = client.getEndpoint(); - this.client = client; + 
this.connectionFactory = connectionFactory; + this.currentConnection = connectionFactory.newConnection(); this.resultQueue = endpointResultQueue; this.clusterId = clusterId; this.maxChunkSizeBytes = maxChunkSizeBytes; this.maxInFlightRequests = maxInFlightRequests; + this.connectionTimeToLive = connectionTimeToLive; this.gatewayThrottler = new GatewayThrottler(maxSleepTimeMs); - //Ensure that pollInterval is in the range [1us, 10s] - this.pollIntervalUS = Math.max(1, (long)(1000000.0/Math.max(0.1, idlePollFrequency))); - this.thread = new Thread(ioThreadGroup, this, "IOThread " + endpoint); - thread.setDaemon(true); + this.pollIntervalUS = Math.max(1, (long)(1000000.0/Math.max(0.1, idlePollFrequency))); // ensure range [1us, 10s] + this.clock = clock; this.localQueueTimeOut = localQueueTimeOut; - thread.start(); + this.maxOldConnectionPollInterval = localQueueTimeOut.dividedBy(10).toMillis() > pollIntervalUS / 1000 + ? localQueueTimeOut.dividedBy(10) + : Duration.ofMillis(pollIntervalUS / 1000); + if (runThreads) { + this.thread = new Thread(ioThreadGroup, this, "IOThread " + endpoint); + thread.setDaemon(true); + thread.start(); + } + else { + this.thread = null; + } } public Endpoint getEndpoint() { return endpoint; } - public static class ConnectionStats { - - // NOTE: These fields are accessed by reflection in JSON serialization - - public final int wrongSessionDetectedCounter; - public final int wrongVersionDetectedCounter; - public final int problemStatusCodeFromServerCounter; - public final int executeProblemsCounter; - public final int docsReceivedCounter; - public final int statusReceivedCounter; - public final int pendingDocumentStatusCount; - public final int successfullHandshakes; - public final int lastGatewayProcessTimeMillis; - - ConnectionStats(int wrongSessionDetectedCounter, - int wrongVersionDetectedCounter, - int problemStatusCodeFromServerCounter, - int executeProblemsCounter, - int docsReceivedCounter, - int statusReceivedCounter, - int 
pendingDocumentStatusCount, - int successfullHandshakes, - int lastGatewayProcessTimeMillis) { - this.wrongSessionDetectedCounter = wrongSessionDetectedCounter; - this.wrongVersionDetectedCounter = wrongVersionDetectedCounter; - this.problemStatusCodeFromServerCounter = problemStatusCodeFromServerCounter; - this.executeProblemsCounter = executeProblemsCounter; - this.docsReceivedCounter = docsReceivedCounter; - this.statusReceivedCounter = statusReceivedCounter; - this.pendingDocumentStatusCount = pendingDocumentStatusCount; - this.successfullHandshakes = successfullHandshakes; - this.lastGatewayProcessTimeMillis = lastGatewayProcessTimeMillis; - } - } - /** * Returns a snapshot of counters. Threadsafe. */ @@ -152,18 +153,21 @@ class IOThread implements Runnable, AutoCloseable { if (size > 0) { log.info("We have outstanding operations (" + size + ") , trying to fetch responses."); try { - processResponse(client.drain()); + for (GatewayConnection oldConnection : oldConnections) + processResponse(oldConnection.drain()); + processResponse(currentConnection.drain()); } catch (Throwable e) { log.log(Level.SEVERE, "Some failures while trying to get latest responses from vespa.", e); } } try { - client.close(); + for (GatewayConnection oldConnection : oldConnections) + oldConnection.close(); + currentConnection.close(); } finally { // If there is still documents in the queue, fail them. - drainDocumentQueueWhenFailingPermanently(new Exception( - "Closed call, did not manage to process everything so failing this document.")); + drainDocumentQueueWhenFailingPermanently(new Exception("Closed call, did not manage to process everything so failing this document.")); } log.fine("Session to " + endpoint + " closed."); @@ -184,7 +188,7 @@ class IOThread implements Runnable, AutoCloseable { int chunkSizeBytes = 0; try { drainFirstDocumentsInQueueIfOld(); - Document doc = documentQueue.poll(maxWaitUnits, timeUnit); + Document doc = thread != null ? 
documentQueue.poll(maxWaitUnits, timeUnit) : documentQueue.poll(); if (doc != null) { docsForSendChunk.add(doc); chunkSizeBytes = doc.size(); @@ -236,12 +240,12 @@ class IOThread implements Runnable, AutoCloseable { private InputStream sendAndReceive(List<Document> docs) throws IOException, ServerResponseException { try { // Post the new docs and get async responses for other posts. - return client.writeOperations(docs); + return currentConnection.write(docs); } catch (ServerResponseException ser) { markDocumentAsFailed(docs, ser); throw ser; } catch (Exception e) { - markDocumentAsFailed(docs, new ServerResponseException(e.getMessage())); + markDocumentAsFailed(docs, new ServerResponseException(Exceptions.toMessageString(e))); throw e; } } @@ -274,11 +278,11 @@ class IOThread implements Runnable, AutoCloseable { private ProcessResponse feedDocumentAndProcessResults(List<Document> docs) throws ServerResponseException, IOException { addDocumentsToResultQueue(docs); - long startTime = System.currentTimeMillis(); + long startTime = clock.millis(); InputStream serverResponse = sendAndReceive(docs); ProcessResponse processResponse = processResponse(serverResponse); - lastGatewayProcessTimeMillis.set((int) (System.currentTimeMillis() - startTime)); + lastGatewayProcessTimeMillis.set((int) (clock.millis() - startTime)); return processResponse; } @@ -309,28 +313,30 @@ class IOThread implements Runnable, AutoCloseable { return processResponse; } - /** Given a current thread state, take the appropriate action and return the resulting new thread state */ - private ThreadState cycle(ThreadState threadState) { - switch(threadState) { + /** Given a current connection state, take the appropriate action and return the resulting new connection state */ + private ConnectionState cycle(ConnectionState connectionState) { + switch(connectionState) { case DISCONNECTED: try { - if (! client.connect()) { + if (! 
currentConnection.connect()) { log.log(Level.WARNING, "Could not connect to endpoint: '" + endpoint + "'. Will re-try."); drainFirstDocumentsInQueueIfOld(); - return ThreadState.DISCONNECTED; + return ConnectionState.DISCONNECTED; } - return ThreadState.CONNECTED; + return ConnectionState.CONNECTED; } catch (Throwable throwable1) { drainFirstDocumentsInQueueIfOld(); log.log(Level.INFO, "Failed connecting to endpoint: '" + endpoint + "'. Will re-try connecting. Failed with '" + Exceptions.toMessageString(throwable1) + "'",throwable1); executeProblemsCounter.incrementAndGet(); - return ThreadState.DISCONNECTED; + return ConnectionState.DISCONNECTED; } case CONNECTED: try { - client.handshake(); + if (isStale(currentConnection)) + return refreshConnection(connectionState); + currentConnection.handshake(); successfulHandshakes.getAndIncrement(); } catch (ServerResponseException ser) { @@ -340,46 +346,49 @@ class IOThread implements Runnable, AutoCloseable { drainFirstDocumentsInQueueIfOld(); resultQueue.onEndpointError(new FeedProtocolException(ser.getResponseCode(), ser.getResponseString(), ser, endpoint)); - return ThreadState.CONNECTED; + return ConnectionState.CONNECTED; } catch (Throwable throwable) { // This cover IOException as well executeProblemsCounter.incrementAndGet(); resultQueue.onEndpointError(new FeedConnectException(throwable, endpoint)); log.log(Level.INFO, "Failed talking to endpoint. Handshake with server endpoint '" + endpoint + "' failed. Will re-try handshake. 
Failed with '" + Exceptions.toMessageString(throwable) + "'",throwable); drainFirstDocumentsInQueueIfOld(); - client.close(); - return ThreadState.DISCONNECTED; + currentConnection.close(); + return ConnectionState.DISCONNECTED; } - return ThreadState.SESSION_SYNCED; + return ConnectionState.SESSION_SYNCED; case SESSION_SYNCED: try { + if (isStale(currentConnection)) + return refreshConnection(connectionState); ProcessResponse processResponse = pullAndProcessData(pollIntervalUS); gatewayThrottler.handleCall(processResponse.transitiveErrorCount); } catch (ServerResponseException ser) { - log.log(Level.INFO, "Problems while handing data over to endpoint '" + endpoint - + "'. Will re-try. Endpoint responded with an unexpected HTTP response code. '" - + Exceptions.toMessageString(ser) + "'",ser); - return ThreadState.CONNECTED; + log.log(Level.INFO, "Problems while handing data over to endpoint '" + endpoint + + "'. Will re-try. Endpoint responded with an unexpected HTTP response code. '" + + Exceptions.toMessageString(ser) + "'",ser); + return ConnectionState.CONNECTED; } - catch (Throwable e) { // Covers IOException as well - log.log(Level.INFO, "Problems while handing data over to endpoint '" + endpoint - + "'. Will re-try. Connection level error. Failed with '" + Exceptions.toMessageString(e) + "'", e); - client.close(); - return ThreadState.DISCONNECTED; + catch (Throwable e) { + log.log(Level.INFO, "Problems while handing data over to endpoint '" + endpoint + + "'. Will re-try. Connection level error. 
Failed with '" + + Exceptions.toMessageString(e) + "'", e); + currentConnection.close(); + return ConnectionState.DISCONNECTED; } - return ThreadState.SESSION_SYNCED; + return ConnectionState.SESSION_SYNCED; default: { log.severe("Should never get here."); - client.close(); - return ThreadState.DISCONNECTED; + currentConnection.close(); + return ConnectionState.DISCONNECTED; } } } - private void sleepIfProblemsGettingSyncedConnection(ThreadState newState, ThreadState oldState) { - if (newState == ThreadState.SESSION_SYNCED) return; - if (newState == ThreadState.CONNECTED && oldState == ThreadState.DISCONNECTED) return; + private void sleepIfProblemsGettingSyncedConnection(ConnectionState newState, ConnectionState oldState) { + if (newState == ConnectionState.SESSION_SYNCED) return; + if (newState == ConnectionState.CONNECTED && oldState == ConnectionState.DISCONNECTED) return; try { // Take it easy we have problems getting a connection up. if (stopSignal.getCount() > 0 || !documentQueue.isEmpty()) { @@ -391,16 +400,19 @@ class IOThread implements Runnable, AutoCloseable { @Override public void run() { - ThreadState threadState = ThreadState.DISCONNECTED; - while (stopSignal.getCount() > 0 || !documentQueue.isEmpty()) { - ThreadState oldState = threadState; - threadState = cycle(threadState); - sleepIfProblemsGettingSyncedConnection(threadState, oldState); - - } + while (stopSignal.getCount() > 0 || !documentQueue.isEmpty()) + tick(); log.finer(toString() + " exiting, documentQueue.size()=" + documentQueue.size()); running.countDown(); + } + /** Do one iteration of work. Should be called from the single worker thread of this. 
*/ + public void tick() { + ConnectionState oldState = connectionState; + connectionState = cycle(connectionState); + checkOldConnections(); + if (thread != null) + sleepIfProblemsGettingSyncedConnection(connectionState, oldState); } private void drainFirstDocumentsInQueueIfOld() { @@ -410,14 +422,14 @@ class IOThread implements Runnable, AutoCloseable { EndpointResult endpointResult = EndPointResultFactory.createTransientError( endpoint, document.get().getOperationId(), - new Exception("Not sending document operation, timed out in queue after " - + document.get().timeInQueueMillis() + " ms.")); + new Exception("Not sending document operation, timed out in queue after " + + (clock.millis() - document.get().getQueueInsertTime().toEpochMilli()) + " ms.")); resultQueue.failOperation(endpointResult, clusterId); } } private void drainDocumentQueueWhenFailingPermanently(Exception exception) { - //first, clear sentOperations: + // first, clear sentOperations: resultQueue.failPending(exception); for (Document document : documentQueue.removeAllDocuments()) { @@ -427,4 +439,92 @@ class IOThread implements Runnable, AutoCloseable { } } + private boolean isStale(GatewayConnection connection) { + return connection.connectionTime() != null + && connection.connectionTime().plus(connectionTimeToLive).isBefore(clock.instant()); + } + + private ConnectionState refreshConnection(ConnectionState currentConnectionState) { + if (currentConnectionState == ConnectionState.SESSION_SYNCED) + oldConnections.add(currentConnection); + currentConnection = connectionFactory.newConnection(); + return ConnectionState.DISCONNECTED; + } + + private void checkOldConnections() { + for (Iterator<GatewayConnection> i = oldConnections.iterator(); i.hasNext(); ) { + GatewayConnection connection = i.next(); + if (closingTime(connection).isBefore(clock.instant())) { + connection.close(); + i.remove(); + } + else if (timeToPoll(connection)) { + try { + processResponse(connection.poll()); + } + catch 
(Exception e) { + // Old connection; best effort + } + } + } + } + + private Instant closingTime(GatewayConnection connection) { + return connection.connectionTime().plus(connectionTimeToLive).plus(localQueueTimeOut); + } + + private boolean timeToPoll(GatewayConnection connection) { + if (connection.lastPollTime() == null) return true; + + // Poll less the closer the connection comes to closing time + double newness = ( closingTime(connection).toEpochMilli() - clock.millis() ) / + (double)localQueueTimeOut.toMillis(); + if (newness < 0) return true; // connection retired prematurely + if (newness > 1) return false; // closing time reached + Duration pollInterval = Duration.ofMillis(pollIntervalUS / 1000 + + (long)((1 - newness) * ( maxOldConnectionPollInterval.toMillis() - pollIntervalUS / 1000))); + return connection.lastPollTime().plus(pollInterval).isBefore(clock.instant()); + } + + public static class ConnectionStats { + + // NOTE: These fields are accessed by reflection in JSON serialization + + public final int wrongSessionDetectedCounter; + public final int wrongVersionDetectedCounter; + public final int problemStatusCodeFromServerCounter; + public final int executeProblemsCounter; + public final int docsReceivedCounter; + public final int statusReceivedCounter; + public final int pendingDocumentStatusCount; + public final int successfullHandshakes; + public final int lastGatewayProcessTimeMillis; + + ConnectionStats(int wrongSessionDetectedCounter, + int wrongVersionDetectedCounter, + int problemStatusCodeFromServerCounter, + int executeProblemsCounter, + int docsReceivedCounter, + int statusReceivedCounter, + int pendingDocumentStatusCount, + int successfullHandshakes, + int lastGatewayProcessTimeMillis) { + this.wrongSessionDetectedCounter = wrongSessionDetectedCounter; + this.wrongVersionDetectedCounter = wrongVersionDetectedCounter; + this.problemStatusCodeFromServerCounter = problemStatusCodeFromServerCounter; + this.executeProblemsCounter = 
executeProblemsCounter; + this.docsReceivedCounter = docsReceivedCounter; + this.statusReceivedCounter = statusReceivedCounter; + this.pendingDocumentStatusCount = pendingDocumentStatusCount; + this.successfullHandshakes = successfullHandshakes; + this.lastGatewayProcessTimeMillis = lastGatewayProcessTimeMillis; + } + } + + /** For testing. Returns the current connection of this. Not thread safe. */ + public GatewayConnection currentConnection() { return currentConnection; } + + /** For testing. Returns a snapshot of the old connections of this. Not thread safe. */ + public List<GatewayConnection> oldConnections() { return new ArrayList<>(oldConnections); } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/DocumentSendInfo.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/DocumentSendInfo.java index 883cea7e6f0..27ad88c123e 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/DocumentSendInfo.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/DocumentSendInfo.java @@ -4,6 +4,7 @@ package com.yahoo.vespa.http.client.core.operationProcessor; import com.yahoo.vespa.http.client.Result; import com.yahoo.vespa.http.client.core.Document; +import java.time.Clock; import java.util.HashMap; import java.util.Map; @@ -18,24 +19,25 @@ class DocumentSendInfo { // This is lazily populated as normal cases does not require retries. private Map<Integer, Integer> attemptedRetriesByClusterId = null; private final StringBuilder localTrace; + private final Clock clock; - DocumentSendInfo(Document document, boolean traceThisDoc) { + DocumentSendInfo(Document document, boolean traceThisDoc, Clock clock) { this.document = document; - localTrace = traceThisDoc - ? new StringBuilder("\n" + document.createTimeMillis() + " Trace starting " + "\n") - : null; + localTrace = traceThisDoc ? 
new StringBuilder("\n" + document.createTime() + " Trace starting " + "\n") + : null; + this.clock = clock; } boolean addIfNotAlreadyThere(Result.Detail detail, int clusterId) { if (detailByClusterId.containsKey(clusterId)) { if (localTrace != null) { - localTrace.append(System.currentTimeMillis() + " Got duplicate detail, ignoring this: " - + detail.toString() + "\n"); + localTrace.append(clock.millis() + " Got duplicate detail, ignoring this: " + + detail.toString() + "\n"); } return false; } if (localTrace != null) { - localTrace.append(System.currentTimeMillis() + " Got detail: " + detail.toString() + "\n"); + localTrace.append(clock.millis() + " Got detail: " + detail.toString() + "\n"); } detailByClusterId.put(clusterId, detail); return true; @@ -60,7 +62,7 @@ class DocumentSendInfo { retries++; attemptedRetriesByClusterId.put(clusterId, retries); if (localTrace != null) { - localTrace.append(System.currentTimeMillis() + " Asked about retrying for cluster ID " + localTrace.append(clock.millis() + " Asked about retrying for cluster ID " + clusterId + ", number of retries is " + retries + " Detail:\n" + detail.toString()); } return retries; diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/EndPointResultFactory.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/EndPointResultFactory.java index 205153a7a00..3d662eca3e7 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/EndPointResultFactory.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/EndPointResultFactory.java @@ -21,12 +21,11 @@ import java.util.logging.Logger; */ public final class EndPointResultFactory { - private static Logger log = Logger.getLogger(EndPointResultFactory.class.getName()); - + private static final Logger log = Logger.getLogger(EndPointResultFactory.class.getName()); private static final String EMPTY_MESSAGE = 
"-"; - public static Collection<EndpointResult> createResult( - Endpoint endpoint, InputStream inputStream) throws IOException { + public static Collection<EndpointResult> createResult(Endpoint endpoint, + InputStream inputStream) throws IOException { List<EndpointResult> results = new ArrayList<>(); try (BufferedReader reader = new BufferedReader( new InputStreamReader(inputStream, StandardCharsets.US_ASCII))) { @@ -82,9 +81,9 @@ public final class EndPointResultFactory { return new EndpointResult( reply.operationId, new Result.Detail(endpoint, - replyToResultType(reply), - reply.traceMessage, - exception)); + replyToResultType(reply), + reply.traceMessage, + exception)); } catch (Throwable t) { throw new IllegalArgumentException("Bad result line from server: '" + line + "'", t); } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottler.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottler.java index 7cf4e32a880..ebeee802303 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottler.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottler.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.http.client.core.operationProcessor; import com.yahoo.vespa.http.client.core.ThrottlePolicy; +import java.time.Clock; import java.util.concurrent.ThreadLocalRandom; /** @@ -57,6 +58,7 @@ public class IncompleteResultsThrottler { /** * Creates the throttler. + * * @param minInFlightValue the throttler will never throttle beyond this limit. * @param maxInFlightValue the throttler will never throttle above this limit. If zero, no limit. * @param clock use to calculate window size. Can be null if minWindowSize and maxInFlightValue are equal. 
@@ -68,7 +70,7 @@ public class IncompleteResultsThrottler { this.policy = policy; this.clock = clock; if (minInFlightValue != maxInFlightValue) { - this.sampleStartTimeMs = clock.getTimeMillis(); + this.sampleStartTimeMs = clock.millis(); } setNewSemaphoreSize(INITIAL_MAX_IN_FLIGHT_VALUE); } @@ -96,10 +98,6 @@ public class IncompleteResultsThrottler { } } - public interface Clock { - long getTimeMillis(); - } - public void resultReady(boolean success) { blocker.operationDone(); if (!success) { @@ -147,9 +145,8 @@ public class IncompleteResultsThrottler { } private void adjustThrottling() { - if (clock.getTimeMillis() < sampleStartTimeMs + phaseSizeMs) { - return; - } + if (clock.millis() < sampleStartTimeMs + phaseSizeMs) return; + sampleStartTimeMs += phaseSizeMs; if (stabilizingPhasesLeft-- == 0) { diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java index 692d90abe50..90d07104fef 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessor.java @@ -15,7 +15,9 @@ import com.yahoo.vespa.http.client.core.communication.ClusterConnection; import java.math.BigInteger; import java.security.SecureRandom; +import java.time.Clock; import java.util.ArrayList; +import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; @@ -55,16 +57,19 @@ public class OperationProcessor { private final boolean traceToStderr; private final ThreadGroup ioThreadGroup; private final String clientId = new BigInteger(130, random).toString(32); + private final Clock clock; public OperationProcessor(IncompleteResultsThrottler incompleteResultsThrottler, FeedClient.ResultCallback resultCallback, 
SessionParams sessionParams, - ScheduledThreadPoolExecutor timeoutExecutor) { + ScheduledThreadPoolExecutor timeoutExecutor, + Clock clock) { this.numDestinations = sessionParams.getClusters().size(); this.resultCallback = resultCallback; this.incompleteResultsThrottler = incompleteResultsThrottler; this.timeoutExecutor = timeoutExecutor; this.ioThreadGroup = new ThreadGroup("operationprocessor"); + this.clock = clock; if (sessionParams.getClusters().isEmpty()) throw new IllegalArgumentException("Cannot feed to 0 clusters."); @@ -82,7 +87,8 @@ public class OperationProcessor { cluster, i, sessionParams.getClientQueueSize() / sessionParams.getClusters().size(), - timeoutExecutor)); + timeoutExecutor, + clock)); } operationStats = new OperationStats(sessionParams, clusters, incompleteResultsThrottler); maxRetries = sessionParams.getConnectionParams().getMaxRetries(); @@ -181,7 +187,7 @@ public class OperationProcessor { } } if (blockedDocumentToSend != null) { - sendToClusters(blockedDocumentToSend); + sendToClusters(blockedDocumentToSend, clock); } return result; } @@ -225,13 +231,13 @@ public class OperationProcessor { inflightDocumentIds.add(document.getDocumentId()); } - sendToClusters(document); + sendToClusters(document, clock); } - private void sendToClusters(Document document) { + private void sendToClusters(Document document, Clock clock) { synchronized (monitor) { boolean traceThisDoc = traceEveryXOperation > 0 && traceCounter++ % traceEveryXOperation == 0; - docSendInfoByOperationId.put(document.getOperationId(), new DocumentSendInfo(document, traceThisDoc)); + docSendInfoByOperationId.put(document.getOperationId(), new DocumentSendInfo(document, traceThisDoc, clock)); } for (ClusterConnection clusterConnection : clusters) { @@ -250,6 +256,8 @@ public class OperationProcessor { } } + public List<ClusterConnection> clusters() { return Collections.unmodifiableList(clusters); } + public String getStatsAsJson() { return operationStats.getStatsAsJson(); } diff 
--git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/Runner.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/Runner.java index 7c034cab75f..926b4cf8c79 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/Runner.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/runner/Runner.java @@ -10,6 +10,7 @@ import com.yahoo.vespa.http.client.core.XmlFeedReader; import java.io.File; import java.io.IOException; import java.io.InputStream; +import java.time.Clock; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; @@ -34,11 +35,11 @@ public class Runner { boolean isJson, AtomicInteger numSent, boolean verbose) { - + Clock clock = Clock.systemUTC(); if (verbose) System.err.println("Now sending data."); - long sendStartTime = System.currentTimeMillis(); + long sendStartTime = clock.millis(); if (isJson) { JsonReader.read(inputStream, feedClient, numSent); } else { @@ -49,7 +50,7 @@ public class Runner { } } - long sendTotalTime = System.currentTimeMillis() - sendStartTime; + long sendTotalTime = clock.millis() - sendStartTime; if (verbose) System.err.println("Waiting for all results, sent " + numSent.get() + " docs."); diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/FeedClientTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/FeedClientTest.java index aa47128f436..b70fbaf3096 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/FeedClientTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/FeedClientTest.java @@ -11,6 +11,7 @@ import org.junit.Test; import java.io.ByteArrayInputStream; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.time.Clock; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; @@ -41,7 +42,7 @@ public class FeedClientTest { resultsReceived.incrementAndGet(); }; - FeedClient 
feedClient = new FeedClientImpl(sessionParams, resultCallback, FeedClientFactory.createTimeoutExecutor()); + FeedClient feedClient = new FeedClientImpl(sessionParams, resultCallback, FeedClientFactory.createTimeoutExecutor(), Clock.systemUTC()); @Test public void testStreamAndClose() { diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/ManualClock.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/ManualClock.java new file mode 100644 index 00000000000..b32d1eaa859 --- /dev/null +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/ManualClock.java @@ -0,0 +1,55 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.http.client; + +import java.time.Clock; +import java.time.Instant; +import java.time.LocalDateTime; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.time.temporal.TemporalAmount; + +/** + * A clock which initially has the time of its creation but can only be advanced by calling advance + * + * @author bratseth + */ +public class ManualClock extends Clock { + + private Instant currentTime = Instant.now(); + + public ManualClock() {} + + public ManualClock(String utcIsoTime) { + this(at(utcIsoTime)); + } + + public ManualClock(Instant currentTime) { + this.currentTime = currentTime; + } + + public void advance(TemporalAmount temporal) { + currentTime = currentTime.plus(temporal); + } + + public void setInstant(Instant time) { + currentTime = time; + } + + @Override + public Instant instant() { return currentTime; } + + @Override + public ZoneId getZone() { return null; } + + @Override + public Clock withZone(ZoneId zone) { return null; } + + @Override + public long millis() { return currentTime.toEpochMilli(); } + + public static Instant at(String utcIsoTime) { + return LocalDateTime.parse(utcIsoTime, 
DateTimeFormatter.ISO_DATE_TIME).atZone(ZoneOffset.UTC).toInstant(); + } + +} diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/QueueBoundsTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/QueueBoundsTest.java index 1f875e0dd72..0813cb36078 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/QueueBoundsTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/QueueBoundsTest.java @@ -12,6 +12,7 @@ import org.junit.Test; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.time.Clock; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -24,6 +25,7 @@ import static com.yahoo.vespa.http.client.TestUtils.writeDocument; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; @@ -78,7 +80,8 @@ public class QueueBoundsTest { .build()) .setClientQueueSize(2) .build(), - SessionFactory.createTimeoutExecutor())) { + SessionFactory.createTimeoutExecutor(), + Clock.systemUTC())) { FeederThread feeder = new FeederThread(session); try { feeder.start(); @@ -122,7 +125,8 @@ public class QueueBoundsTest { .setNumPersistentConnectionsPerEndpoint(1) .build()) .setClientQueueSize(6) //3 per cluster - .build(), SessionFactory.createTimeoutExecutor())) { + .build(), SessionFactory.createTimeoutExecutor(), + Clock.systemUTC())) { FeederThread feeder = new FeederThread(session); try { @@ -210,22 +214,23 @@ public class QueueBoundsTest { .build()) .setClientQueueSize(1) .build(), - SessionFactory.createTimeoutExecutor())) { + SessionFactory.createTimeoutExecutor(), + Clock.systemUTC())) { FeederThread feeder = new FeederThread(session); feeder.start(); try { { System.out.println("We start with 
failed connection, post a document."); assertFeedNotBlocking(feeder, 0); - assertThat(session.results().size(), is(0)); + assertEquals(0, session.results().size()); CountDownLatch lastPostFeed = assertFeedBlocking(feeder, 1); System.out.println("No result so far."); - assertThat(session.results().size(), is(0)); + assertEquals(0, session.results().size()); System.out.println("Make connection ok."); mockXmlParsingRequestHandler.setScenario(V3MockParsingRequestHandler.Scenario.ALL_OK); assert(lastPostFeed.await(120, TimeUnit.SECONDS)); - assertThat(lastPostFeed.getCount(), equalTo(0L)); + assertEquals(0L, lastPostFeed.getCount()); assertResultQueueSize(session, 2, 120, TimeUnit.SECONDS); } @@ -235,7 +240,7 @@ public class QueueBoundsTest { { assertFeedNotBlocking(feeder, 2); System.out.println("Fed one document, fit in queue."); - assertThat(session.results().size(), is(2)); + assertEquals(2, session.results().size()); System.out.println("Fed one document more, wait for failure."); assertFeedNotBlocking(feeder, 3); @@ -249,12 +254,12 @@ public class QueueBoundsTest { } int errors = 0; for (Result result : session.results()) { - assertThat(result.getDetails().size(), is(1)); + assertEquals(1, result.getDetails().size()); if (! 
result.isSuccess()) { errors++; } } - assertThat(errors, is(1)); + assertEquals(1, errors); } finally { feeder.stop(); } diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/Server.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/Server.java index 0821fa55e06..79a91d0b5f3 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/Server.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/Server.java @@ -5,8 +5,7 @@ import org.eclipse.jetty.server.ServerConnector; import org.eclipse.jetty.server.handler.AbstractHandler; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> - * @since 5.1.20 + * @author Einar M R Rosenvinge */ public final class Server implements AutoCloseable { diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java index 1d70ce953e4..780de3e695c 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java @@ -18,10 +18,10 @@ import java.util.Map; import java.util.concurrent.TimeUnit; import static com.yahoo.vespa.http.client.TestUtils.getResults; -import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.CoreMatchers.not; -import static org.hamcrest.CoreMatchers.nullValue; -import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * @@ -79,34 +79,33 @@ public class V3HttpAPITest { writeDocument(session); Map<String, Result> results = getResults(session, 1); - assertThat(results.size(), is(1)); + assertEquals(1, results.size()); TestDocument document = documents.get(0); Result r = results.remove(document.getDocumentId()); - assertThat(r, 
not(nullValue())); - if (conditionNotMet) { - assertThat(r.getDetails().iterator().next().getResultType(), is(Result.ResultType.CONDITION_NOT_MET)); - } - assertThat(r.getDetails().toString(), r.isSuccess(), is(false)); - assertThat(results.isEmpty(), is(true)); + assertNotNull(r); + if (conditionNotMet) + assertEquals(Result.ResultType.CONDITION_NOT_MET, r.getDetails().iterator().next().getResultType()); + assertFalse(r.getDetails().toString(), r.isSuccess()); + assertTrue(results.isEmpty()); } } @Test - public void requireThatSingleDestinationWorks() throws Exception { + public void testSingleDestination() throws Exception { try (Server server = new Server(new V3MockParsingRequestHandler(), 0); - Session session = SessionFactory.create(Endpoint.create("localhost", server.getPort(), false))) { + Session session = SessionFactory.create(Endpoint.create("localhost", server.getPort(), false))) { writeDocuments(session); Map<String, Result> results = getResults(session, documents.size()); - assertThat(results.size(), is(documents.size())); + assertEquals(documents.size(), results.size()); for (TestDocument document : documents) { Result r = results.remove(document.getDocumentId()); - assertThat(r, not(nullValue())); - assertThat(r.getDetails().toString(), r.isSuccess(), is(true)); + assertNotNull(r); + assertTrue(r.getDetails().toString(), r.isSuccess()); } - assertThat(results.isEmpty(), is(true)); + assertTrue(results.isEmpty()); } } @@ -169,15 +168,15 @@ public class V3HttpAPITest { writeDocuments(session); Map<String, Result> results = getResults(session, documents.size()); - assertThat(results.size(), is(documents.size())); + assertEquals(documents.size(), results.size()); for (TestDocument document : documents) { Result r = results.remove(document.getDocumentId()); - assertThat(r, not(nullValue())); - assertThat(r.getDetails().toString(), r.isSuccess(), is(false)); - assertThat(r.getDetails().iterator().next().getResultType(), 
is(Result.ResultType.TRANSITIVE_ERROR)); + assertNotNull(r); + assertFalse(r.getDetails().toString(), r.isSuccess()); + assertEquals(Result.ResultType.TRANSITIVE_ERROR, r.getDetails().iterator().next().getResultType()); } - assertThat(results.isEmpty(), is(true)); + assertTrue(results.isEmpty()); } } @@ -197,4 +196,5 @@ public class V3HttpAPITest { testServerWithMock(new V3MockParsingRequestHandler( 200, V3MockParsingRequestHandler.Scenario.CONDITON_NOT_MET), false, true); } + } diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/DocumentTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/DocumentTest.java index b5c03eade51..ee2f021df6a 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/DocumentTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/DocumentTest.java @@ -5,9 +5,9 @@ import org.junit.Test; import java.nio.ByteBuffer; import java.nio.ReadOnlyBufferException; +import java.time.Clock; -import static org.hamcrest.core.Is.is; -import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertEquals; public class DocumentTest { @@ -15,25 +15,25 @@ public class DocumentTest { public void simpleCaseOk() { String docId = "doc id"; String docContent = "foo"; - Document document = new Document(docId, docContent.getBytes(), null); - assertThat(document.getDocumentId(), is(docId)); - assertThat(document.getData(), is(ByteBuffer.wrap(docContent.getBytes()))); - assertThat(document.getDataAsString().toString(), is(docContent)); + Document document = new Document(docId, docContent.getBytes(), null, Clock.systemUTC().instant()); + assertEquals(docId, document.getDocumentId()); + assertEquals(ByteBuffer.wrap(docContent.getBytes()), document.getData()); + assertEquals(docContent, document.getDataAsString().toString()); // Make sure that data is not modified on retrieval. 
- assertThat(document.getDataAsString().toString(), is(docContent)); - assertThat(document.getData(), is(ByteBuffer.wrap(docContent.getBytes()))); - assertThat(document.getDocumentId(), is(docId)); + assertEquals(docContent, document.getDataAsString().toString()); + assertEquals(ByteBuffer.wrap(docContent.getBytes()), document.getData()); + assertEquals(docId, document.getDocumentId()); } @Test(expected = ReadOnlyBufferException.class) public void notMutablePutTest() { - Document document = new Document("id", null, "data", null /* context */); + Document document = new Document("id", null, "data", null, Clock.systemUTC().instant()); document.getData().put("a".getBytes()); } @Test(expected = ReadOnlyBufferException.class) public void notMutableCompactTest() { - Document document = new Document("id", null, "data", null /* context */); + Document document = new Document("id", null, "data", null, Clock.systemUTC().instant()); document.getData().compact(); } diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java index 59a8b613e67..511e40c1c88 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java @@ -14,9 +14,10 @@ import org.apache.http.HttpEntity; import org.apache.http.HttpResponse; import org.apache.http.ParseException; import org.apache.http.StatusLine; -import org.apache.http.client.HttpClient; +import org.apache.http.client.methods.CloseableHttpResponse; import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.InputStreamEntity; +import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.message.BasicHeader; import org.junit.Rule; import org.junit.Test; @@ 
-27,13 +28,12 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.time.Clock; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.core.Is.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.mockito.Mockito.any; @@ -42,7 +42,6 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; - public class ApacheGatewayConnectionTest { @Rule @@ -50,20 +49,18 @@ public class ApacheGatewayConnectionTest { @Test public void testProtocolV3() throws Exception { - final Endpoint endpoint = Endpoint.create("localhost", 666, false); - final FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build(); - final String clusterSpecificRoute = ""; - final ConnectionParams connectionParams = new ConnectionParams.Builder() - .build(); - final List<Document> documents = new ArrayList<>(); + Endpoint endpoint = Endpoint.create("localhost", 666, false); + FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build(); + String clusterSpecificRoute = ""; + ConnectionParams connectionParams = new ConnectionParams.Builder().build(); + List<Document> documents = new ArrayList<>(); - final String vespaDocContent = "Hello, I a JSON doc."; - final String docId = "42"; + String vespaDocContent = "Hello, I a JSON doc."; + String docId = "42"; - final AtomicInteger requestsReceived = new AtomicInteger(0); // This is the fake server, takes header client ID and uses this as session Id. 
ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> { - final Header clientIdHeader = post.getFirstHeader(Headers.CLIENT_ID); + Header clientIdHeader = post.getFirstHeader(Headers.CLIENT_ID); return httpResponse(clientIdHeader.getValue(), "3"); }); @@ -74,21 +71,21 @@ public class ApacheGatewayConnectionTest { clusterSpecificRoute, connectionParams, mockFactory, - "clientId"); + "clientId", + Clock.systemUTC()); apacheGatewayConnection.connect(); apacheGatewayConnection.handshake(); documents.add(createDoc(docId, vespaDocContent, true)); - apacheGatewayConnection.writeOperations(documents); + apacheGatewayConnection.write(documents); } @Test(expected=IllegalArgumentException.class) public void testServerReturnsBadSessionInV3() throws Exception { - final Endpoint endpoint = Endpoint.create("localhost", 666, false); - final FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build(); - final String clusterSpecificRoute = ""; - final ConnectionParams connectionParams = new ConnectionParams.Builder() - .build(); + Endpoint endpoint = Endpoint.create("localhost", 666, false); + FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build(); + String clusterSpecificRoute = ""; + ConnectionParams connectionParams = new ConnectionParams.Builder().build(); // This is the fake server, returns wrong session Id. 
ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> httpResponse("Wrong Id from server", "3")); @@ -100,57 +97,36 @@ public class ApacheGatewayConnectionTest { clusterSpecificRoute, connectionParams, mockFactory, - "clientId"); + "clientId", + Clock.systemUTC()); apacheGatewayConnection.connect(); - final List<Document> documents = new ArrayList<>(); - apacheGatewayConnection.writeOperations(documents); - } - - @Test(expected=RuntimeException.class) - public void testBadConfigParameters() throws Exception { - final Endpoint endpoint = Endpoint.create("localhost", 666, false); - final FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build(); - final String clusterSpecificRoute = ""; - final ConnectionParams connectionParams = new ConnectionParams.Builder() - .build(); - - final ApacheGatewayConnection.HttpClientFactory mockFactory = - mock(ApacheGatewayConnection.HttpClientFactory.class); - - new ApacheGatewayConnection( - endpoint, - feedParams, - clusterSpecificRoute, - connectionParams, - mockFactory, - null); + List<Document> documents = new ArrayList<>(); + apacheGatewayConnection.write(documents); } @Test public void testJsonDocumentHeader() throws Exception { - final Endpoint endpoint = Endpoint.create("localhost", 666, false); - final FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build(); - final String clusterSpecificRoute = ""; - final ConnectionParams connectionParams = new ConnectionParams.Builder() - .setUseCompression(true) - .build(); - final List<Document> documents = new ArrayList<>(); + Endpoint endpoint = Endpoint.create("localhost", 666, false); + FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.JSON_UTF8).build(); + String clusterSpecificRoute = ""; + ConnectionParams connectionParams = new ConnectionParams.Builder().setUseCompression(true).build(); + List<Document> documents = new 
ArrayList<>(); - final String vespaDocContent ="Hello, I a JSON doc."; - final String docId = "42"; + String vespaDocContent ="Hello, I a JSON doc."; + String docId = "42"; - final AtomicInteger requestsReceived = new AtomicInteger(0); + AtomicInteger requestsReceived = new AtomicInteger(0); // This is the fake server, checks that DATA_FORMAT header is set properly. ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> { - final Header header = post.getFirstHeader(Headers.DATA_FORMAT); + Header header = post.getFirstHeader(Headers.DATA_FORMAT); if (requestsReceived.incrementAndGet() == 1) { // This is handshake, it is not json. assert (header == null); return httpResponse("clientId", "3"); } assertNotNull(header); - assertThat(header.getValue(), is(FeedParams.DataFormat.JSON_UTF8.name())); + assertEquals(FeedParams.DataFormat.JSON_UTF8.name(), header.getValue()); // Test is done. return httpResponse("clientId", "3"); }); @@ -162,24 +138,25 @@ public class ApacheGatewayConnectionTest { clusterSpecificRoute, connectionParams, mockFactory, - "clientId"); + "clientId", + Clock.systemUTC()); apacheGatewayConnection.connect(); apacheGatewayConnection.handshake(); documents.add(createDoc(docId, vespaDocContent, true)); - apacheGatewayConnection.writeOperations(documents); + apacheGatewayConnection.write(documents); } @Test public void testZipAndCreateEntity() throws IOException { - final String testString = "Hello world"; + String testString = "Hello world"; InputStream stream = new ByteArrayInputStream(testString.getBytes(StandardCharsets.UTF_8)); // Send in test data to method. InputStreamEntity inputStreamEntity = ApacheGatewayConnection.zipAndCreateEntity(stream); // Verify zipped data by comparing unzipped data with test data. 
- final String rawContent = TestUtils.zipStreamToString(inputStreamEntity.getContent()); - assert(testString.equals(rawContent)); + String rawContent = TestUtils.zipStreamToString(inputStreamEntity.getContent()); + assertEquals(testString, rawContent); } /** @@ -187,32 +164,28 @@ public class ApacheGatewayConnectionTest { */ @Test public void testCompressedWriteOperations() throws Exception { - final Endpoint endpoint = Endpoint.create("localhost", 666, false); - final FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.XML_UTF8).build(); - final String clusterSpecificRoute = ""; - final ConnectionParams connectionParams = new ConnectionParams.Builder() - .setUseCompression(true) - .build(); - final List<Document> documents = new ArrayList<>(); + Endpoint endpoint = Endpoint.create("localhost", 666, false); + FeedParams feedParams = new FeedParams.Builder().setDataFormat(FeedParams.DataFormat.XML_UTF8).build(); + String clusterSpecificRoute = ""; + ConnectionParams connectionParams = new ConnectionParams.Builder().setUseCompression(true).build(); + List<Document> documents = new ArrayList<>(); - final String vespaDocContent ="Hello, I am the document data."; - final String docId = "42"; + String vespaDocContent ="Hello, I am the document data."; + String docId = "42"; - final Document doc = createDoc(docId, vespaDocContent, false); + Document doc = createDoc(docId, vespaDocContent, false); // When sending data on http client, check if it is compressed. If compressed, unzip, check result, // and count down latch. 
ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> { - final Header header = post.getFirstHeader("Content-Encoding"); + Header header = post.getFirstHeader("Content-Encoding"); if (header != null && header.getValue().equals("gzip")) { final String rawContent = TestUtils.zipStreamToString(post.getEntity().getContent()); final String vespaHeaderText = "<vespafeed>\n"; final String vespaFooterText = "</vespafeed>\n"; - assertThat(rawContent, is( - doc.getOperationId() + " 38\n" + vespaHeaderText + vespaDocContent + "\n" - + vespaFooterText)); - + assertEquals(doc.getOperationId() + " 38\n" + vespaHeaderText + vespaDocContent + "\n" + vespaFooterText, + rawContent); } return httpResponse("clientId", "3"); }); @@ -227,13 +200,14 @@ public class ApacheGatewayConnectionTest { clusterSpecificRoute, connectionParams, mockFactory, - "clientId"); + "clientId", + Clock.systemUTC()); apacheGatewayConnection.connect(); apacheGatewayConnection.handshake(); documents.add(doc); - apacheGatewayConnection.writeOperations(documents); + apacheGatewayConnection.write(documents); } @Test @@ -265,14 +239,15 @@ public class ApacheGatewayConnectionTest { "", connectionParams, mockFactory, - "clientId"); + "clientId", + Clock.systemUTC()); apacheGatewayConnection.connect(); apacheGatewayConnection.handshake(); List<Document> documents = new ArrayList<>(); documents.add(createDoc("42", "content", true)); - apacheGatewayConnection.writeOperations(documents); - apacheGatewayConnection.writeOperations(documents); + apacheGatewayConnection.write(documents); + apacheGatewayConnection.write(documents); verify(headerProvider, times(3)).getHeaderValue(); // 1x connect(), 2x writeOperations() } @@ -293,17 +268,18 @@ public class ApacheGatewayConnectionTest { "", new ConnectionParams.Builder().build(), mockFactory, - "clientId"); + "clientId", + Clock.systemUTC()); apacheGatewayConnection.connect(); apacheGatewayConnection.handshake(); - 
apacheGatewayConnection.writeOperations(Collections.singletonList(createDoc("42", "content", true))); + apacheGatewayConnection.write(Collections.singletonList(createDoc("42", "content", true))); } private static ApacheGatewayConnection.HttpClientFactory mockHttpClientFactory(HttpExecuteMock httpExecuteMock) throws IOException { ApacheGatewayConnection.HttpClientFactory mockFactory = mock(ApacheGatewayConnection.HttpClientFactory.class); - HttpClient httpClientMock = mock(HttpClient.class); + CloseableHttpClient httpClientMock = mock(CloseableHttpClient.class); when(mockFactory.createClient()).thenReturn(httpClientMock); when(httpClientMock.execute(any())).thenAnswer((Answer) invocation -> { Object[] args = invocation.getArguments(); @@ -317,16 +293,12 @@ public class ApacheGatewayConnectionTest { HttpResponse execute(HttpPost httpPost) throws IOException; } - private Document createDoc(final String docId, final String content, boolean useJson) throws IOException { - return new Document(docId, content.getBytes(), null /* context */); + private Document createDoc(String docId, String content, boolean useJson) { + return new Document(docId, content.getBytes(), null, Clock.systemUTC().instant()); } - private void addMockedHeader( - final HttpResponse httpResponseMock, - final String name, - final String value, - HeaderElement[] elements) { - final Header header = new Header() { + private void addMockedHeader(HttpResponse httpResponseMock, String name, String value, HeaderElement[] elements) { + Header header = new Header() { @Override public String getName() { return name; @@ -344,7 +316,7 @@ public class ApacheGatewayConnectionTest { } private HttpResponse httpResponse(String sessionIdInResult, String version) throws IOException { - final HttpResponse httpResponseMock = mock(HttpResponse.class); + CloseableHttpResponse httpResponseMock = mock(CloseableHttpResponse.class); StatusLine statusLineMock = mock(StatusLine.class); 
when(httpResponseMock.getStatusLine()).thenReturn(statusLineMock); @@ -365,7 +337,7 @@ public class ApacheGatewayConnectionTest { } private static HttpResponse createErrorHttpResponse(int statusCode, String reasonPhrase, String message) throws IOException { - HttpResponse response = mock(HttpResponse.class); + CloseableHttpResponse response = mock(CloseableHttpResponse.class); StatusLine statusLine = mock(StatusLine.class); when(statusLine.getStatusCode()).thenReturn(statusCode); diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/CloseableQTestCase.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/CloseableQTestCase.java index 35a06258f86..af354b8feea 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/CloseableQTestCase.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/CloseableQTestCase.java @@ -4,21 +4,24 @@ package com.yahoo.vespa.http.client.core.communication; import com.yahoo.vespa.http.client.core.Document; import org.junit.Test; +import java.time.Clock; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; public class CloseableQTestCase { + @Test public void requestThatPutIsInterruptedOnClose() throws InterruptedException { - final DocumentQueue q = new DocumentQueue(1); - q.put(new Document("id", null, "data", null), false); + Clock clock = Clock.systemUTC(); + DocumentQueue q = new DocumentQueue(1, clock); + q.put(new Document("id", null, "data", null, clock.instant()), false); Thread t = new Thread(new Runnable() { @Override public void run() { try { Thread.sleep(3000); } catch (InterruptedException e) { - } q.close(); q.clear(); @@ -26,7 +29,7 @@ public class CloseableQTestCase { }); t.start(); try { - q.put(new Document("id2", null, "data2", null), false); + q.put(new Document("id2", null, "data2", null, Clock.systemUTC().instant()), false); fail("This shouldn't have 
worked."); } catch (IllegalStateException ise) { // ok! @@ -39,10 +42,11 @@ public class CloseableQTestCase { @Test public void requireThatSelfIsUnbounded() throws InterruptedException { - DocumentQueue q = new DocumentQueue(1); - q.put(new Document("1", null, "data", null), true); - q.put(new Document("2", null, "data", null), true); - q.put(new Document("3", null, "data", null), true); + DocumentQueue q = new DocumentQueue(1, Clock.systemUTC()); + q.put(new Document("1", null, "data", null, Clock.systemUTC().instant()), true); + q.put(new Document("2", null, "data", null, Clock.systemUTC().instant()), true); + q.put(new Document("3", null, "data", null, Clock.systemUTC().instant()), true); assertEquals(3, q.size()); } + } diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueueTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueueTest.java index 0005bddeb73..da82079e992 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueueTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/EndpointResultQueueTest.java @@ -20,8 +20,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> - * @since 5.1.22 + * @author Einar M R Rosenvinge */ public class EndpointResultQueueTest { diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/IOThreadTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/IOThreadTest.java index e81638ded1c..59fb968906f 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/IOThreadTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/IOThreadTest.java @@ -4,6 +4,7 @@ package 
com.yahoo.vespa.http.client.core.communication; import com.yahoo.vespa.http.client.FeedConnectException; import com.yahoo.vespa.http.client.FeedEndpointException; import com.yahoo.vespa.http.client.FeedProtocolException; +import com.yahoo.vespa.http.client.ManualClock; import com.yahoo.vespa.http.client.Result; import com.yahoo.vespa.http.client.V3HttpAPITest; import com.yahoo.vespa.http.client.config.Endpoint; @@ -16,6 +17,8 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.time.Clock; +import java.time.Duration; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; @@ -35,21 +38,27 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +// DO NOT ADD TESTS HERE, add to NewIOThreadTest public class IOThreadTest { private static final Endpoint ENDPOINT = Endpoint.create("myhost"); + final Clock clock = Clock.systemUTC(); final EndpointResultQueue endpointResultQueue = mock(EndpointResultQueue.class); final ApacheGatewayConnection apacheGatewayConnection = mock(ApacheGatewayConnection.class); final String exceptionMessage = "SOME EXCEPTION FOO"; CountDownLatch latch = new CountDownLatch(1); String docId1 = V3HttpAPITest.documents.get(0).getDocumentId(); Document doc1 = new Document(V3HttpAPITest.documents.get(0).getDocumentId(), - V3HttpAPITest.documents.get(0).getContents(), null /* context */); + V3HttpAPITest.documents.get(0).getContents(), + null, + clock.instant()); String docId2 = V3HttpAPITest.documents.get(1).getDocumentId(); Document doc2 = new Document(V3HttpAPITest.documents.get(1).getDocumentId(), - V3HttpAPITest.documents.get(1).getContents(), null /* context */); - DocumentQueue documentQueue = new DocumentQueue(4); + V3HttpAPITest.documents.get(1).getContents(), + null, + clock.instant()); + 
DocumentQueue documentQueue = new DocumentQueue(4, clock); public IOThreadTest() { when(apacheGatewayConnection.getEndpoint()).thenReturn(ENDPOINT); @@ -57,20 +66,18 @@ public class IOThreadTest { /** * Set up mock so that it can handle both failDocument() and resultReceived(). + * * @param expectedDocIdFail on failure, this has to be the doc id, or the mock will fail. * @param expectedDocIdOk on ok, this has to be the doc id, or the mock will fail. * @param isTransient checked on failure, if different, the mock will fail. * @param expectedException checked on failure, if exception toString is different, the mock will fail. */ - void setupEndpointResultQueueMock(String expectedDocIdFail, String expectedDocIdOk,boolean isTransient, String expectedException) { - + void setupEndpointResultQueueMock(String expectedDocIdFail, String expectedDocIdOk, boolean isTransient, String expectedException) { doAnswer(invocation -> { EndpointResult endpointResult = (EndpointResult) invocation.getArguments()[0]; assertThat(endpointResult.getOperationId(), is(expectedDocIdFail)); - assertThat(endpointResult.getDetail().getException().toString(), - containsString(expectedException)); - assertThat(endpointResult.getDetail().getResultType(), is( - isTransient ? Result.ResultType.TRANSITIVE_ERROR : Result.ResultType.FATAL_ERROR)); + assertThat(endpointResult.getDetail().getException().toString(), containsString(expectedException)); + assertThat(endpointResult.getDetail().getResultType(), is(isTransient ? 
Result.ResultType.TRANSITIVE_ERROR : Result.ResultType.FATAL_ERROR)); latch.countDown(); return null; @@ -86,7 +93,20 @@ public class IOThreadTest { } private IOThread createIOThread(int maxInFlightRequests, long localQueueTimeOut) { - return new IOThread(null, endpointResultQueue, apacheGatewayConnection, 0, 0, maxInFlightRequests, localQueueTimeOut, documentQueue, 0, 10); + return new IOThread(null, + ENDPOINT, + endpointResultQueue, + new SingletonGatewayConnectionFactory(apacheGatewayConnection), + 0, + 0, + maxInFlightRequests, + Duration.ofMillis(localQueueTimeOut), + documentQueue, + 0, + Duration.ofSeconds(15), + true, + 10, + clock); } @Test @@ -94,7 +114,7 @@ public class IOThreadTest { when(apacheGatewayConnection.connect()).thenReturn(true); InputStream serverResponse = new ByteArrayInputStream( (docId1 + " OK Doc{20}fed").getBytes(StandardCharsets.UTF_8)); - when(apacheGatewayConnection.writeOperations(any())).thenReturn(serverResponse); + when(apacheGatewayConnection.write(any())).thenReturn(serverResponse); setupEndpointResultQueueMock( "nope", docId1, true, exceptionMessage); try (IOThread ioThread = createIOThread(10000, 10000)) { ioThread.post(doc1); @@ -103,9 +123,9 @@ public class IOThreadTest { } @Test - public void requireThatSingleDocumentWriteErrorIsHandledProperly() throws Exception { + public void testDocumentWriteError() throws Exception { when(apacheGatewayConnection.connect()).thenReturn(true); - when(apacheGatewayConnection.writeOperations(any())).thenThrow(new IOException(exceptionMessage)); + when(apacheGatewayConnection.write(any())).thenThrow(new IOException(exceptionMessage)); setupEndpointResultQueueMock(doc1.getOperationId(), "nope", true, exceptionMessage); try (IOThread ioThread = createIOThread(10000, 10000)) { ioThread.post(doc1); @@ -114,11 +134,11 @@ public class IOThreadTest { } @Test - public void requireThatTwoDocumentsFirstWriteErrorSecondOkIsHandledProperly() throws Exception { + public void 
testTwoDocumentsFirstWriteErrorSecondOk() throws Exception { when(apacheGatewayConnection.connect()).thenReturn(true); InputStream serverResponse = new ByteArrayInputStream( (docId2 + " OK Doc{20}fed").getBytes(StandardCharsets.UTF_8)); - when(apacheGatewayConnection.writeOperations(any())) + when(apacheGatewayConnection.write(any())) .thenThrow(new IOException(exceptionMessage)) .thenReturn(serverResponse); latch = new CountDownLatch(2); @@ -134,10 +154,8 @@ public class IOThreadTest { @Test public void testQueueTimeOutNoNoConnectionToServer() throws Exception { when(apacheGatewayConnection.connect()).thenReturn(false); - InputStream serverResponse = new ByteArrayInputStream( - ("").getBytes(StandardCharsets.UTF_8)); - when(apacheGatewayConnection.writeOperations(any())) - .thenReturn(serverResponse); + InputStream serverResponse = new ByteArrayInputStream(("").getBytes(StandardCharsets.UTF_8)); + when(apacheGatewayConnection.write(any())).thenReturn(serverResponse); setupEndpointResultQueueMock(doc1.getOperationId(), "nope", true, "java.lang.Exception: Not sending document operation, timed out in queue after"); try (IOThread ioThread = createIOThread(10, 10)) { @@ -147,7 +165,7 @@ public class IOThreadTest { } @Test - public void requireThatEndpointProtocolExceptionsArePropagated() + public void testEndpointProtocolExceptionPropagation() throws IOException, ServerResponseException, InterruptedException, TimeoutException, ExecutionException { when(apacheGatewayConnection.connect()).thenReturn(true); int errorCode = 403; @@ -168,7 +186,7 @@ public class IOThreadTest { } @Test - public void requireThatEndpointConnectExceptionsArePropagated() + public void testEndpointConnectExceptionsPropagation() throws IOException, ServerResponseException, InterruptedException, TimeoutException, ExecutionException { when(apacheGatewayConnection.connect()).thenReturn(true); String errorMessage = "generic error message"; @@ -198,4 +216,17 @@ public class IOThreadTest { return 
futureResult; } + private static final class SingletonGatewayConnectionFactory implements GatewayConnectionFactory { + + private final GatewayConnection singletonConnection; + + SingletonGatewayConnectionFactory(GatewayConnection singletonConnection) { + this.singletonConnection = singletonConnection; + } + + @Override + public GatewayConnection newConnection() { return singletonConnection; } + + } + } diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/NewIOThreadTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/NewIOThreadTest.java new file mode 100644 index 00000000000..615fa22a6cf --- /dev/null +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/NewIOThreadTest.java @@ -0,0 +1,192 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.http.client.core.communication; + +import com.yahoo.vespa.http.client.FeedClient; +import com.yahoo.vespa.http.client.FeedEndpointException; +import com.yahoo.vespa.http.client.ManualClock; +import com.yahoo.vespa.http.client.Result; +import com.yahoo.vespa.http.client.config.Cluster; +import com.yahoo.vespa.http.client.config.ConnectionParams; +import com.yahoo.vespa.http.client.config.Endpoint; +import com.yahoo.vespa.http.client.config.SessionParams; +import com.yahoo.vespa.http.client.core.Document; +import com.yahoo.vespa.http.client.core.EndpointResult; +import com.yahoo.vespa.http.client.core.ThrottlePolicy; +import com.yahoo.vespa.http.client.core.operationProcessor.IncompleteResultsThrottler; +import com.yahoo.vespa.http.client.core.operationProcessor.OperationProcessor; +import org.junit.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.concurrent.ScheduledThreadPoolExecutor; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; 
+ +/** + * TODO: Migrate IOThreadTests here. + * + * @author bratseth + */ +public class NewIOThreadTest { + + @Test + public void testBasics() { + OperationProcessorTester tester = new OperationProcessorTester(); + assertEquals(0, tester.inflight()); + assertEquals(0, tester.success()); + assertEquals(0, tester.failures()); + tester.send("doc1"); + tester.send("doc2"); + tester.send("doc3"); + assertEquals(3, tester.inflight()); + assertEquals(0, tester.success()); + assertEquals(0, tester.failures()); + tester.success("doc1"); + tester.success("doc2"); + tester.success("doc3"); + assertEquals(0, tester.inflight()); + assertEquals(3, tester.success()); + assertEquals(0, tester.failures()); + } + + @Test + public void testPollingOldConnections() { + OperationProcessorTester tester = new OperationProcessorTester(); + tester.tick(3); + + assertEquals(1, tester.clusterConnections().size()); + assertEquals(1, tester.clusterConnections().get(0).ioThreads().size()); + IOThread ioThread = tester.clusterConnections().get(0).ioThreads().get(0); + DryRunGatewayConnection firstConnection = (DryRunGatewayConnection)ioThread.currentConnection(); + assertEquals(0, ioThread.oldConnections().size()); + + firstConnection.hold(true); + tester.send("doc1"); + tester.tick(1); + + tester.clock().advance(Duration.ofSeconds(20)); // Default connection ttl is 15 + tester.tick(3); + + assertEquals(1, ioThread.oldConnections().size()); + assertEquals(firstConnection, ioThread.oldConnections().get(0)); + assertNotSame(firstConnection, ioThread.currentConnection()); + assertEquals(20, firstConnection.lastPollTime().toEpochMilli() / 1000); + + // Check old connection poll pattern (linear backoff) + assertLastPollTimeWhenAdvancing(21, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(22, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(23, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(24, 1, firstConnection, tester); + 
assertLastPollTimeWhenAdvancing(24, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(26, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(26, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(28, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(28, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(30, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(30, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(32, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(32, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(34, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(34, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(34, 1, firstConnection, tester); + assertLastPollTimeWhenAdvancing(37, 1, firstConnection, tester); + + tester.clock().advance(Duration.ofSeconds(200)); + tester.tick(1); + assertEquals("Old connection is eventually removed", 0, ioThread.oldConnections().size()); + } + + private void assertLastPollTimeWhenAdvancing(int lastPollTimeSeconds, + int advanceSeconds, + DryRunGatewayConnection connection, + OperationProcessorTester tester) { + tester.clock().advance(Duration.ofSeconds(advanceSeconds)); + tester.tick(1); + assertEquals(lastPollTimeSeconds, connection.lastPollTime().toEpochMilli() / 1000); + } + + private static class OperationProcessorTester { + + private final Endpoint endpoint; + private final int clusterId = 0; + private final ManualClock clock; + private final TestResultCallback resultCallback; + private final OperationProcessor operationProcessor; + + public OperationProcessorTester() { + endpoint = Endpoint.create("test-endpoint"); + SessionParams.Builder params = new SessionParams.Builder(); + Cluster.Builder clusterParams = new Cluster.Builder(); + clusterParams.addEndpoint(endpoint); + params.addCluster(clusterParams.build()); + ConnectionParams.Builder connectionParams = new 
ConnectionParams.Builder(); + connectionParams.setDryRun(true); + connectionParams.setRunThreads(false); + params.setConnectionParams(connectionParams.build()); + + clock = new ManualClock(Instant.ofEpochMilli(0)); + resultCallback = new TestResultCallback(); + operationProcessor = new OperationProcessor(new IncompleteResultsThrottler(1, 100, clock, new ThrottlePolicy()), + resultCallback, + params.build(), + new ScheduledThreadPoolExecutor(1), + clock); + } + + public ManualClock clock() { return clock; } + + /** Do n iteration of work in all io threads of this */ + public void tick(int n) { + for (int i = 0; i < n; i++) + for (ClusterConnection cluster : operationProcessor.clusters()) + for (IOThread thread : cluster.ioThreads()) + thread.tick(); + } + + public void send(String documentId) { + operationProcessor.sendDocument(new Document(documentId, documentId, "data of " + documentId, null, clock.instant())); + } + + public void success(String documentId) { + operationProcessor.resultReceived(new EndpointResult(documentId, new Result.Detail(endpoint)), clusterId); + } + + public int inflight() { + return operationProcessor.getIncompleteResultQueueSize(); + } + + public int success() { + return resultCallback.successes; + } + + public List<ClusterConnection> clusterConnections() { + return operationProcessor.clusters(); + } + + public int failures() { + return resultCallback.failures; + } + + } + + private static class TestResultCallback implements FeedClient.ResultCallback { + + private int successes = 0; + private int failures = 0; + + @Override + public void onCompletion(String docId, Result documentResult) { + successes++; + } + + @Override + public void onEndpointException(FeedEndpointException exception) { + failures++; + } + + + } + +} diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottlerTest.java 
b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottlerTest.java index baf6e2f2df3..ec929d68efb 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottlerTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/IncompleteResultsThrottlerTest.java @@ -1,9 +1,12 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.http.client.core.operationProcessor; +import com.yahoo.vespa.http.client.ManualClock; import com.yahoo.vespa.http.client.core.ThrottlePolicy; import org.junit.Test; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.LinkedList; @@ -12,6 +15,7 @@ import java.util.Random; import java.util.concurrent.atomic.AtomicLong; import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.anyDouble; @@ -42,14 +46,14 @@ public class IncompleteResultsThrottlerTest { * @return median queue length. 
*/ int getAverageQueue(int clientCount, int breakPoint, int simulationTimeMs) { - final AtomicLong timeMs = new AtomicLong(0); + ManualClock clock = new ManualClock(Instant.ofEpochMilli(0)); ArrayList<IncompleteResultsThrottler> incompleteResultsThrottlers = new ArrayList<>(); MockServer mockServer = new MockServer(breakPoint); for (int x = 0; x < clientCount; x++) { IncompleteResultsThrottler incompleteResultsThrottler = - new IncompleteResultsThrottler(10, 50000, () -> timeMs.get(), new ThrottlePolicy()); + new IncompleteResultsThrottler(10, 50000, clock, new ThrottlePolicy()); incompleteResultsThrottlers.add(incompleteResultsThrottler); } long sum = 0; @@ -68,8 +72,8 @@ public class IncompleteResultsThrottlerTest { if (fastForward) { time = mockServer.nextRequestFinished(); } - timeMs.set(time); - mockServer.moveTime(timeMs.get()); + clock.setInstant(Instant.ofEpochMilli(time)); + mockServer.moveTime(clock.instant().toEpochMilli()); for (int y = 0; y < clientCount; y++) { // Fill up, but don't block as that would stop the simulation. while (incompleteResultsThrottlers.get(y).availableCapacity() > 0) { @@ -140,45 +144,46 @@ public class IncompleteResultsThrottlerTest { } } - private void moveToNextCycle(final IncompleteResultsThrottler throttler, AtomicLong timeMs) + private void moveToNextCycle(final IncompleteResultsThrottler throttler, ManualClock clock) throws InterruptedException { waitForThreads(); // Enter an adaption phase, we don't care about this phase. - timeMs.addAndGet(throttler.phaseSizeMs); + clock.advance(Duration.ofMillis(throttler.phaseSizeMs)); throttler.operationStart(); throttler.resultReady(false); // Now enter the real next phase. 
- timeMs.addAndGet(throttler.phaseSizeMs); + clock.advance(Duration.ofMillis(throttler.phaseSizeMs)); throttler.operationStart(); throttler.resultReady(false); } @Test public void testInteractionWithPolicyByMockingPolicy() throws InterruptedException { + ManualClock clock = new ManualClock(Instant.ofEpochMilli(0)); final int MAX_SIZE = 1000; final int MORE_THAN_MAX_SIZE = MAX_SIZE + 20; final int SIZE_AFTER_CYCLE_FIRST = 30; final int SIZE_AFTER_CYCLE_SECOND = 5000; ThrottlePolicy policy = mock(ThrottlePolicy.class); - final AtomicLong timeMs = new AtomicLong(0); IncompleteResultsThrottler incompleteResultsThrottler = - new IncompleteResultsThrottler(2, MAX_SIZE, ()->timeMs.get(), policy); + new IncompleteResultsThrottler(2, MAX_SIZE, clock, policy); long bucketSizeMs = incompleteResultsThrottler.phaseSizeMs; // Cycle 1 - Algorithm has fixed value for max-in-flight: INITIAL_MAX_IN_FLIGHT_VALUE. // We post a few operations, not all finishing in this cycle. We explicitly do not fill the window // size to test the argument about any requests blocked. - assertThat(incompleteResultsThrottler.availableCapacity(), - is(IncompleteResultsThrottler.INITIAL_MAX_IN_FLIGHT_VALUE)); + assertEquals(IncompleteResultsThrottler.INITIAL_MAX_IN_FLIGHT_VALUE, + incompleteResultsThrottler.availableCapacity()); postOperations(20, incompleteResultsThrottler); postSuccesses(15, incompleteResultsThrottler); - moveToNextCycle(incompleteResultsThrottler, timeMs); + moveToNextCycle(incompleteResultsThrottler, clock); // Cycle 2 - Algorithm has fixed value also for second iteration: SECOND_MAX_IN_FLIGHT_VALUE. // Test verifies that this value is used, and insert a value to be used for next phase SIZE_AFTER_CYCLE_FIRST. 
- assertThat(incompleteResultsThrottler.availableCapacity(), - is(IncompleteResultsThrottler.SECOND_MAX_IN_FLIGHT_VALUE - 5)); // 5 slots already taken earlier + assertEquals("5 slots already taken earlier", + IncompleteResultsThrottler.SECOND_MAX_IN_FLIGHT_VALUE - 5, + incompleteResultsThrottler.availableCapacity()); postSuccesses(5, incompleteResultsThrottler); when(policy.calcNewMaxInFlight( anyDouble(), // Max performance change @@ -188,12 +193,11 @@ public class IncompleteResultsThrottlerTest { eq(IncompleteResultsThrottler.SECOND_MAX_IN_FLIGHT_VALUE), // current size eq(false))) // is any request blocked, should be false since we only posted 20 docs. .thenReturn(SIZE_AFTER_CYCLE_FIRST); - moveToNextCycle(incompleteResultsThrottler, timeMs); + moveToNextCycle(incompleteResultsThrottler, clock); // Cycle 3 - Test that value set in previous phase is used. Now return a very large number. // However, this number should be cropped by the system (tested in next cycle). - assertThat(incompleteResultsThrottler.availableCapacity(), - is(SIZE_AFTER_CYCLE_FIRST)); + assertEquals(SIZE_AFTER_CYCLE_FIRST, incompleteResultsThrottler.availableCapacity()); postOperations(MORE_THAN_MAX_SIZE, incompleteResultsThrottler); postSuccesses(MORE_THAN_MAX_SIZE, incompleteResultsThrottler); when(policy.calcNewMaxInFlight( @@ -204,11 +208,10 @@ public class IncompleteResultsThrottlerTest { eq(SIZE_AFTER_CYCLE_FIRST),// current size eq(true))) // is any request blocked, should be true since we posted MORE_THAN_MAX_SIZE docs. .thenReturn(SIZE_AFTER_CYCLE_SECOND); - moveToNextCycle(incompleteResultsThrottler, timeMs); + moveToNextCycle(incompleteResultsThrottler, clock); // Cycle 4 - Test that the large number from previous cycle is cropped and that max value is used instead. 
- assertThat(incompleteResultsThrottler.availableCapacity(), - is(MAX_SIZE)); + assertEquals(MAX_SIZE, incompleteResultsThrottler.availableCapacity()); } private long inversesU(int size, int sweetSpot) { diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessorTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessorTest.java index 9753a180618..e4ae138054d 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessorTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/operationProcessor/OperationProcessorTest.java @@ -10,6 +10,7 @@ import com.yahoo.vespa.http.client.core.Document; import com.yahoo.vespa.http.client.core.EndpointResult; import org.junit.Test; +import java.time.Clock; import java.util.ArrayDeque; import java.util.Queue; import java.util.concurrent.CountDownLatch; @@ -32,10 +33,10 @@ import static org.mockito.Mockito.when; public class OperationProcessorTest { final Queue<Result> queue = new ArrayDeque<>(); - final Document doc1 = new Document("id:a:type::b", null, "data doc 1", null); - final Document doc1b = new Document("id:a:type::b", null, "data doc 1b", null); - final Document doc2 = new Document("id:a:type::b2", null, "data doc 2", null); - final Document doc3 = new Document("id:a:type::b3", null, "data doc 3", null); + final Document doc1 = new Document("id:a:type::b", null, "data doc 1", null, Clock.systemUTC().instant()); + final Document doc1b = new Document("id:a:type::b", null, "data doc 1b", null, Clock.systemUTC().instant()); + final Document doc2 = new Document("id:a:type::b2", null, "data doc 2", null, Clock.systemUTC().instant()); + final Document doc3 = new Document("id:a:type::b3", null, "data doc 3", null, Clock.systemUTC().instant()); @Test public void testBasic() { @@ -49,7 +50,7 @@ public class OperationProcessorTest { 
OperationProcessor q = new OperationProcessor( new IncompleteResultsThrottler(1000, 1000, null, null), (docId, documentResult) -> queue.add(documentResult), - sessionParams, null); + sessionParams, null, Clock.systemUTC()); q.resultReceived(new EndpointResult("foo", new Result.Detail(null)), 0); @@ -127,7 +128,7 @@ public class OperationProcessorTest { OperationProcessor operationProcessor = new OperationProcessor( new IncompleteResultsThrottler(1000, 1000, null, null), (docId, documentResult) -> queue.add(documentResult), - sessionParams, null); + sessionParams, null, Clock.systemUTC()); operationProcessor.sendDocument(doc1); operationProcessor.sendDocument(doc1b); @@ -165,7 +166,7 @@ public class OperationProcessorTest { OperationProcessor operationProcessor = new OperationProcessor( new IncompleteResultsThrottler(1000, 1000, null, null), (docId, documentResult) -> queue.add(documentResult), - sessionParams, null); + sessionParams, null, Clock.systemUTC()); operationProcessor.sendDocument(doc1); operationProcessor.sendDocument(doc1b); @@ -198,11 +199,11 @@ public class OperationProcessorTest { OperationProcessor operationProcessor = new OperationProcessor( new IncompleteResultsThrottler(1000, 1000, null, null), (docId, documentResult) -> queue.add(documentResult), - sessionParams, null); + sessionParams, null, Clock.systemUTC()); Queue<Document> documentQueue = new ArrayDeque<>(); for (int x = 0; x < 100; x++) { - Document document = new Document("id:a:type::b", null, String.valueOf(x), null); + Document document = new Document("id:a:type::b", null, String.valueOf(x), null, Clock.systemUTC().instant()); operationProcessor.sendDocument(document); documentQueue.add(document); } @@ -233,7 +234,7 @@ public class OperationProcessorTest { OperationProcessor operationProcessor = new OperationProcessor( new IncompleteResultsThrottler(1000, 1000, null, null), (docId, documentResult) -> queue.add(documentResult), - sessionParams, null); + sessionParams, null, 
Clock.systemUTC()); operationProcessor.sendDocument(doc1); operationProcessor.sendDocument(doc1b); // Blocked @@ -273,7 +274,7 @@ public class OperationProcessorTest { OperationProcessor q = new OperationProcessor( new IncompleteResultsThrottler(1000, 1000, null, null), (docId, documentResult) -> queue.add(documentResult), - sessionParams, null); + sessionParams, null, Clock.systemUTC()); q.sendDocument(doc1); assertEquals(0, queue.size()); @@ -299,7 +300,7 @@ public class OperationProcessorTest { OperationProcessor q = new OperationProcessor( new IncompleteResultsThrottler(1000, 1000, null, null), (docId, documentResult) -> queue.add(documentResult), - sessionParams, null); + sessionParams, null, Clock.systemUTC()); q.sendDocument(doc1); assertEquals(0, queue.size()); @@ -358,7 +359,7 @@ public class OperationProcessorTest { OperationProcessor operationProcessor = new OperationProcessor( new IncompleteResultsThrottler(1, 1, null, null), (docId, documentResult) -> {}, - sessionParams, null); + sessionParams, null, Clock.systemUTC()); operationProcessor.sendDocument(doc1); @@ -397,7 +398,7 @@ public class OperationProcessorTest { (docId, documentResult) -> { countDownLatch.countDown(); }, - sessionParams, executor); + sessionParams, executor, Clock.systemUTC()); // Will fail due to bogus host name, but will be retried. 
operationProcessor.sendDocument(doc1); @@ -425,7 +426,7 @@ public class OperationProcessorTest { (docId, documentResult) -> { countDownLatch.countDown(); }, - sessionParams, executor); + sessionParams, executor, Clock.systemUTC()); fail("Expected exception"); } diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/ClientFeederV3.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/ClientFeederV3.java index db1c2471752..5548f8fbc1f 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/ClientFeederV3.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/ClientFeederV3.java @@ -156,45 +156,38 @@ class ClientFeederV3 { } private int getOverloadReturnCode(HttpRequest request) { - if (request.getHeader(Headers.SILENTUPGRADE) != null ) { - return 299; - } + if (request.getHeader(Headers.SILENTUPGRADE) != null ) return 299; return 429; } - private Optional<DocumentOperationMessageV3> pullMessageFromRequest( - FeederSettings settings, InputStream requestInputStream, BlockingQueue<OperationStatus> repliesFromOldMessages) { + private Optional<DocumentOperationMessageV3> pullMessageFromRequest(FeederSettings settings, + InputStream requestInputStream, + BlockingQueue<OperationStatus> repliesFromOldMessages) { while (true) { Optional<String> operationId; try { operationId = streamReaderV3.getNextOperationId(requestInputStream); + if (operationId.isEmpty()) return Optional.empty(); } catch (IOException ioe) { - if (log.isLoggable(Level.FINE)) { - log.log(Level.FINE, Exceptions.toMessageString(ioe), ioe); - } - return Optional.empty(); - } - if (! 
operationId.isPresent()) { + log.log(Level.FINE, () -> Exceptions.toMessageString(ioe)); return Optional.empty(); } - DocumentOperationMessageV3 message; try { - message = getNextMessage(operationId.get(), requestInputStream, settings); + DocumentOperationMessageV3 message = getNextMessage(operationId.get(), requestInputStream, settings); + if (message != null) + setRoute(message, settings); + return Optional.ofNullable(message); } catch (Exception e) { - if (log.isLoggable(Level.WARNING)) { - log.log(Level.WARNING, Exceptions.toMessageString(e)); - } + log.log(Level.WARNING, () -> Exceptions.toMessageString(e)); metric.add(MetricNames.PARSE_ERROR, 1, null); - repliesFromOldMessages.add(new OperationStatus( - Exceptions.toMessageString(e), operationId.get(), ErrorCode.ERROR, false, "")); - - continue; + repliesFromOldMessages.add(new OperationStatus(Exceptions.toMessageString(e), + operationId.get(), + ErrorCode.ERROR, + false, + "")); } - if (message != null) - setRoute(message, settings); - return Optional.ofNullable(message); } } @@ -223,47 +216,45 @@ class ClientFeederV3 { BlockingQueue<OperationStatus> repliesFromOldMessages, AtomicInteger threadsAvailableForFeeding) throws InterruptedException { while (true) { - Optional<DocumentOperationMessageV3> msg = pullMessageFromRequest(settings, requestInputStream, repliesFromOldMessages); + Optional<DocumentOperationMessageV3> message = pullMessageFromRequest(settings, + requestInputStream, + repliesFromOldMessages); - if (! 
msg.isPresent()) { - break; - } - setMessageParameters(msg.get(), settings); + if (message.isEmpty()) break; + setMessageParameters(message.get(), settings); Result result; try { - result = sendMessage(settings, msg.get(), threadsAvailableForFeeding); + result = sendMessage(settings, message.get(), threadsAvailableForFeeding); } catch (RuntimeException e) { - repliesFromOldMessages.add(createOperationStatus(msg.get().getOperationId(), + repliesFromOldMessages.add(createOperationStatus(message.get().getOperationId(), Exceptions.toMessageString(e), ErrorCode.ERROR, false, - msg.get().getMessage())); + message.get().getMessage())); continue; } if (result.isAccepted()) { outstandingOperations.incrementAndGet(); updateOpsPerSec(); - log(Level.FINE, "Sent message successfully, document id: ", msg.get().getOperationId()); + log(Level.FINE, "Sent message successfully, document id: ", message.get().getOperationId()); } else if (!result.getError().isFatal()) { - repliesFromOldMessages.add(createOperationStatus(msg.get().getOperationId(), + repliesFromOldMessages.add(createOperationStatus(message.get().getOperationId(), result.getError().getMessage(), ErrorCode.TRANSIENT_ERROR, false, - msg.get().getMessage())); - continue; + message.get().getMessage())); } else { // should probably not happen, but everybody knows stuff that // shouldn't happen, happens all the time boolean isConditionNotMet = result.getError().getCode() == DocumentProtocol.ERROR_TEST_AND_SET_CONDITION_FAILED; - repliesFromOldMessages.add(createOperationStatus(msg.get().getOperationId(), + repliesFromOldMessages.add(createOperationStatus(message.get().getOperationId(), result.getError().getMessage(), ErrorCode.ERROR, isConditionNotMet, - msg.get().getMessage())); - continue; + message.get().getMessage())); } } } @@ -326,17 +317,11 @@ class ClientFeederV3 { } protected final void log(Level level, Object... 
msgParts) { - StringBuilder s; + if (!log.isLoggable(level)) return; - if (!log.isLoggable(level)) { - return; - } - - s = new StringBuilder(); - for (Object part : msgParts) { + StringBuilder s = new StringBuilder(); + for (Object part : msgParts) s.append(part.toString()); - } - log.log(level, s.toString()); } diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeederSettings.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeederSettings.java index e6d8a88d10b..909c643a006 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeederSettings.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeederSettings.java @@ -6,6 +6,8 @@ import com.yahoo.messagebus.routing.Route; import com.yahoo.vespa.http.client.config.FeedParams.DataFormat; import com.yahoo.vespa.http.client.core.Headers; +import java.util.Optional; + /** * Wrapper for the feed feederSettings read from HTTP request. 
* @@ -14,7 +16,7 @@ import com.yahoo.vespa.http.client.core.Headers; public class FeederSettings { private static final Route DEFAULT_ROUTE = Route.parse("default"); - public final boolean drain; + public final boolean drain; // TODO: Implement drain=true public final Route route; public final boolean denyIfBusy; public final DataFormat dataFormat; @@ -22,55 +24,13 @@ public class FeederSettings { public final Integer traceLevel; public FeederSettings(HttpRequest request) { - { - String tmpDrain = request.getHeader(Headers.DRAIN); - if (tmpDrain != null) { - drain = Boolean.parseBoolean(tmpDrain); - } else { - drain = false; - } - } - { - String tmpRoute = request.getHeader(Headers.ROUTE); - if (tmpRoute != null) { - route = Route.parse(tmpRoute); - } else { - route = DEFAULT_ROUTE; - } - } - { - String tmpDenyIfBusy = request.getHeader(Headers.DENY_IF_BUSY); - if (tmpDenyIfBusy != null) { - denyIfBusy = Boolean.parseBoolean(tmpDenyIfBusy); - } else { - denyIfBusy = false; - } - } - { - // TODO: Change default to JSON on Vespa 8 - String tmpDataFormat = request.getHeader(Headers.DATA_FORMAT); - if (tmpDataFormat != null) { - dataFormat = DataFormat.valueOf(tmpDataFormat); - } else { - dataFormat = DataFormat.XML_UTF8; - } - } - { - String tmpDataFormat = request.getHeader(Headers.PRIORITY); - if (tmpDataFormat != null) { - priority = tmpDataFormat; - } else { - priority = null; - } - } - { - String tmpDataFormat = request.getHeader(Headers.TRACE_LEVEL); - if (tmpDataFormat != null) { - traceLevel = Integer.valueOf(tmpDataFormat); - } else { - traceLevel = null; - } - } + this.drain = Optional.ofNullable(request.getHeader(Headers.DRAIN)).map(Boolean::parseBoolean).orElse(false); + this.route = Optional.ofNullable(request.getHeader(Headers.ROUTE)).map(Route::parse).orElse(DEFAULT_ROUTE); + this.denyIfBusy = Optional.ofNullable(request.getHeader(Headers.DENY_IF_BUSY)).map(Boolean::parseBoolean).orElse(false); + // TODO: Change default to JSON on Vespa 8: + 
this.dataFormat = Optional.ofNullable(request.getHeader(Headers.DATA_FORMAT)).map(DataFormat::valueOf).orElse(DataFormat.XML_UTF8); + this.priority = request.getHeader(Headers.PRIORITY); + this.traceLevel = Optional.ofNullable(request.getHeader(Headers.TRACE_LEVEL)).map(Integer::valueOf).orElse(null); } } diff --git a/vespalib/src/tests/slime/json_slime_benchmark.cpp b/vespalib/src/tests/slime/json_slime_benchmark.cpp index 3c006bb89f7..36987843492 100644 --- a/vespalib/src/tests/slime/json_slime_benchmark.cpp +++ b/vespalib/src/tests/slime/json_slime_benchmark.cpp @@ -1,9 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <vespa/vespalib/testkit/test_kit.h> #include <iostream> #include <fstream> -#include <sstream> using namespace vespalib::slime::convenience; diff --git a/vespalib/src/tests/slime/slime_binary_format_test.cpp b/vespalib/src/tests/slime/slime_binary_format_test.cpp index e6661cbf554..37ce6d5dfdf 100644 --- a/vespalib/src/tests/slime/slime_binary_format_test.cpp +++ b/vespalib/src/tests/slime/slime_binary_format_test.cpp @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
#include <vespa/vespalib/testkit/test_kit.h> #include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/data/simple_buffer.h> #include "type_traits.h" #include <vespa/vespalib/util/stringfmt.h> diff --git a/vespalib/src/tests/slime/slime_json_format_test.cpp b/vespalib/src/tests/slime/slime_json_format_test.cpp index d1f77f09af1..df2f8b2e30b 100644 --- a/vespalib/src/tests/slime/slime_json_format_test.cpp +++ b/vespalib/src/tests/slime/slime_json_format_test.cpp @@ -3,6 +3,7 @@ #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/data/input.h> #include <vespa/vespalib/data/memory_input.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <iostream> #include <fstream> diff --git a/vespalib/src/tests/slime/slime_test.cpp b/vespalib/src/tests/slime/slime_test.cpp index 7e70dc3538e..e58b1599b8f 100644 --- a/vespalib/src/tests/slime/slime_test.cpp +++ b/vespalib/src/tests/slime/slime_test.cpp @@ -1,11 +1,14 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-#include <vespa/log/log.h> -LOG_SETUP("slime_test"); + #include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/data/slime/strfmt.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <type_traits> +#include <vespa/log/log.h> +LOG_SETUP("slime_test"); + using namespace vespalib::slime::convenience; TEST("print sizes") { diff --git a/vespalib/src/tests/trace/trace_serialization.cpp b/vespalib/src/tests/trace/trace_serialization.cpp index 7658fe7f163..3182e46061a 100644 --- a/vespalib/src/tests/trace/trace_serialization.cpp +++ b/vespalib/src/tests/trace/trace_serialization.cpp @@ -3,6 +3,7 @@ #include <vespa/vespalib/trace/tracenode.h> #include <vespa/vespalib/trace/slime_trace_serializer.h> #include <vespa/vespalib/trace/slime_trace_deserializer.h> +#include <vespa/vespalib/data/simple_buffer.h> #include <vespa/log/log.h> LOG_SETUP("trace_test"); diff --git a/vespalib/src/vespa/vespalib/data/memory.h b/vespalib/src/vespa/vespalib/data/memory.h index 07767180b57..eee0a1a3e4f 100644 --- a/vespalib/src/vespa/vespalib/data/memory.h +++ b/vespalib/src/vespa/vespalib/data/memory.h @@ -15,14 +15,14 @@ struct Memory const char *data; size_t size; - Memory() : data(nullptr), size(0) {} - Memory(const char *d, size_t s) : data(d), size(s) {} - Memory(const char *str) : data(str), size(strlen(str)) {} - Memory(const std::string &str) + Memory() noexcept : data(nullptr), size(0) {} + Memory(const char *d, size_t s) noexcept : data(d), size(s) {} + Memory(const char *str) noexcept : data(str), size(strlen(str)) {} + Memory(const std::string &str) noexcept : data(str.data()), size(str.size()) {} - Memory(const vespalib::string &str) + Memory(const vespalib::string &str) noexcept : data(str.data()), size(str.size()) {} - Memory(vespalib::stringref str_ref) + Memory(vespalib::stringref str_ref) noexcept : data(str_ref.data()), size(str_ref.size()) {} vespalib::string make_string() const; vespalib::stringref 
make_stringref() const { return stringref(data, size); } diff --git a/vespalib/src/vespa/vespalib/data/simple_buffer.cpp b/vespalib/src/vespa/vespalib/data/simple_buffer.cpp index 09ac4a4b830..7e3c5022fc5 100644 --- a/vespalib/src/vespa/vespalib/data/simple_buffer.cpp +++ b/vespalib/src/vespa/vespalib/data/simple_buffer.cpp @@ -11,7 +11,7 @@ SimpleBuffer::SimpleBuffer() { } -SimpleBuffer::~SimpleBuffer() { } +SimpleBuffer::~SimpleBuffer() = default; Memory SimpleBuffer::obtain() diff --git a/vespalib/src/vespa/vespalib/data/simple_buffer.h b/vespalib/src/vespa/vespalib/data/simple_buffer.h index f7d9543440f..3bcb43a3856 100644 --- a/vespalib/src/vespa/vespalib/data/simple_buffer.h +++ b/vespalib/src/vespa/vespalib/data/simple_buffer.h @@ -4,6 +4,7 @@ #include "input.h" #include "output.h" +#include <vespa/vespalib/stllike/allocator.h> #include <iosfwd> #include <vector> @@ -20,7 +21,7 @@ class SimpleBuffer : public Input, public Output { private: - std::vector<char> _data; + std::vector<char, allocator_large<char>> _data; size_t _used; public: diff --git a/vespalib/src/vespa/vespalib/data/slime/slime.h b/vespalib/src/vespa/vespalib/data/slime/slime.h index aa44b38b353..6523cd1dac0 100644 --- a/vespalib/src/vespa/vespalib/data/slime/slime.h +++ b/vespalib/src/vespa/vespalib/data/slime/slime.h @@ -31,7 +31,6 @@ #include "external_data_value_factory.h" #include <vespa/vespalib/data/input_reader.h> #include <vespa/vespalib/data/output_writer.h> -#include <vespa/vespalib/data/simple_buffer.h> #include <vespa/vespalib/data/output.h> namespace vespalib { diff --git a/vespalib/src/vespa/vespalib/datastore/datastore.h b/vespalib/src/vespa/vespalib/datastore/datastore.h index 193869a5591..4b908452b32 100644 --- a/vespalib/src/vespa/vespalib/datastore/datastore.h +++ b/vespalib/src/vespa/vespalib/datastore/datastore.h @@ -106,6 +106,7 @@ public: DataStore(const DataStore &rhs) = delete; DataStore &operator=(const DataStore &rhs) = delete; DataStore(); + DataStore(uint32_t 
min_arrays); ~DataStore(); EntryRef addEntry(const EntryType &e); diff --git a/vespalib/src/vespa/vespalib/datastore/datastore.hpp b/vespalib/src/vespa/vespalib/datastore/datastore.hpp index 6549425b022..ad35f3c7383 100644 --- a/vespalib/src/vespa/vespalib/datastore/datastore.hpp +++ b/vespalib/src/vespa/vespalib/datastore/datastore.hpp @@ -133,8 +133,14 @@ DataStoreT<RefT>::freeListRawAllocator(uint32_t typeId) template <typename EntryType, typename RefT> DataStore<EntryType, RefT>::DataStore() + : DataStore(RefType::offsetSize()) +{ +} + +template <typename EntryType, typename RefT> +DataStore<EntryType, RefT>::DataStore(uint32_t min_arrays) : ParentType(), - _type(1, RefType::offsetSize(), RefType::offsetSize()) + _type(1, min_arrays, RefType::offsetSize()) { addType(&_type); initActiveBuffers(); diff --git a/vespalib/src/vespa/vespalib/datastore/unique_store_enumerator.h b/vespalib/src/vespa/vespalib/datastore/unique_store_enumerator.h index db545451a30..78597a53dc8 100644 --- a/vespalib/src/vespa/vespalib/datastore/unique_store_enumerator.h +++ b/vespalib/src/vespa/vespalib/datastore/unique_store_enumerator.h @@ -3,6 +3,7 @@ #pragma once #include "i_unique_store_dictionary.h" +#include <vespa/vespalib/stllike/allocator.h> namespace vespalib::datastore { @@ -18,9 +19,10 @@ template <typename RefT> class UniqueStoreEnumerator { public: using RefType = RefT; - using EnumValues = std::vector<std::vector<uint32_t>>; private: + using UInt32Vector = std::vector<uint32_t, vespalib::allocator_large<uint32_t>>; + using EnumValues = std::vector<UInt32Vector>; IUniqueStoreDictionary::ReadSnapshot::UP _dict_snapshot; const DataStoreBase &_store; EnumValues _enumValues; diff --git a/vespalib/src/vespa/vespalib/objects/nbostream.h b/vespalib/src/vespa/vespalib/objects/nbostream.h index daaea981b5a..c5b26d786b3 100644 --- a/vespalib/src/vespa/vespalib/objects/nbostream.h +++ b/vespalib/src/vespa/vespalib/objects/nbostream.h @@ -20,7 +20,7 @@ class nbostream public: using 
Buffer = Array<char>; using Alloc = alloc::Alloc; - enum State { ok=0, eof=0x01}; + enum State { ok=0, eof=0x01, oob=0x02}; nbostream(size_t initialSize=1024); protected: nbostream(const void * buf, size_t sz, bool longLivedBuffer); @@ -145,6 +145,7 @@ public: const char * peek() const { return &_rbuf[_rp]; } size_t rp() const { return _rp; } nbostream & rp(size_t pos) { if (pos > _wp) fail(eof); _rp = pos; return *this; } + nbostream & wp(size_t pos) { if (pos > _wbuf.size()) fail(oob); _wp = pos; return *this; } size_t wp() const { return _wp; } State state() const { return _state; } bool good() const { return _state == ok; } diff --git a/vespalib/src/vespa/vespalib/util/arrayref.h b/vespalib/src/vespa/vespalib/util/arrayref.h index 749395ff574..03634a7a094 100644 --- a/vespalib/src/vespa/vespalib/util/arrayref.h +++ b/vespalib/src/vespa/vespalib/util/arrayref.h @@ -13,11 +13,11 @@ namespace vespalib { template <typename T> class ArrayRef { public: - ArrayRef() : _v(nullptr), _sz(0) { } - ArrayRef(T * v, size_t sz) : _v(v), _sz(sz) { } + ArrayRef() noexcept : _v(nullptr), _sz(0) { } + ArrayRef(T * v, size_t sz) noexcept : _v(v), _sz(sz) { } template<typename A=std::allocator<T>> - ArrayRef(std::vector<T, A> & v) : _v(&v[0]), _sz(v.size()) { } - ArrayRef(Array<T> &v) : _v(&v[0]), _sz(v.size()) { } + ArrayRef(std::vector<T, A> & v) noexcept : _v(&v[0]), _sz(v.size()) { } + ArrayRef(Array<T> &v) noexcept : _v(&v[0]), _sz(v.size()) { } T & operator [] (size_t i) { return _v[i]; } const T & operator [] (size_t i) const { return _v[i]; } size_t size() const { return _sz; } @@ -32,12 +32,12 @@ private: template <typename T> class ConstArrayRef { public: - ConstArrayRef(const T *v, size_t sz) : _v(v), _sz(sz) { } + ConstArrayRef(const T *v, size_t sz) noexcept : _v(v), _sz(sz) { } template<typename A=std::allocator<T>> - ConstArrayRef(const std::vector<T, A> & v) : _v(&v[0]), _sz(v.size()) { } - ConstArrayRef(const ArrayRef<T> & v) : _v(&v[0]), _sz(v.size()) { } - 
ConstArrayRef(const Array<T> &v) : _v(&v[0]), _sz(v.size()) { } - ConstArrayRef() : _v(nullptr), _sz(0) {} + ConstArrayRef(const std::vector<T, A> & v) noexcept : _v(&v[0]), _sz(v.size()) { } + ConstArrayRef(const ArrayRef<T> & v) noexcept : _v(&v[0]), _sz(v.size()) { } + ConstArrayRef(const Array<T> &v) noexcept : _v(&v[0]), _sz(v.size()) { } + ConstArrayRef() noexcept : _v(nullptr), _sz(0) {} const T & operator [] (size_t i) const { return _v[i]; } size_t size() const { return _sz; } bool empty() const { return _sz == 0; } |