diff options
author | Harald Musum <musum@verizonmedia.com> | 2020-06-11 15:01:38 +0200 |
---|---|---|
committer | Harald Musum <musum@verizonmedia.com> | 2020-06-11 15:01:38 +0200 |
commit | 62728231905e839bf3b6199a3165407cd0cfa7e1 (patch) | |
tree | 80b185e0b8de0bcdc45628d9c4856c2fce31b04b | |
parent | 05d3317e9860aff8fa3005ac214e2c4890457e29 (diff) | |
parent | ad47760e1a20586d07842c3f1996d07a2da45283 (diff) |
Merge branch 'master' into hmusum/configserver-refactoring-9
44 files changed, 1138 insertions, 423 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JvmOptionsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JvmOptionsTest.java index 1bc9b65bdad..294df42bd77 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JvmOptionsTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/JvmOptionsTest.java @@ -133,6 +133,7 @@ public class JvmOptionsTest extends ContainerModelBuilderTestBase { public void requireThatJvmGCOptionsIsHonoured() throws IOException, SAXException { verifyJvmGCOptions(false, null,null, ContainerCluster.G1GC); verifyJvmGCOptions(true, null,null, ContainerCluster.CMS); + verifyJvmGCOptions(true, "",null, ContainerCluster.CMS); verifyJvmGCOptions(false, "-XX:+UseConcMarkSweepGC",null, "-XX:+UseConcMarkSweepGC"); verifyJvmGCOptions(true, "-XX:+UseConcMarkSweepGC",null, "-XX:+UseConcMarkSweepGC"); verifyJvmGCOptions(false, null,"-XX:+UseG1GC", "-XX:+UseG1GC"); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java index 9ed1e8ee88c..54d06deeaeb 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java @@ -420,6 +420,10 @@ public class TenantApplications implements RequestHandler, ReloadHandler, HostVa reloadListener.verifyHostsAreAvailable(tenant, newHosts); } + public HostValidator<ApplicationId> getHostValidator() { + return this; + } + public HostRegistry<ApplicationId> getApplicationHostRegistry() { return hostRegistry; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java index 3ae508f9fc0..b01ddb04d48 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/LocalSession.java @@ -4,7 +4,6 @@ package com.yahoo.vespa.config.server.session; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationMetaData; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.TenantName; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; @@ -13,7 +12,6 @@ import com.yahoo.transaction.NestedTransaction; import com.yahoo.transaction.Transaction; import com.yahoo.vespa.config.server.TimeoutBudget; import com.yahoo.vespa.config.server.application.TenantApplications; -import com.yahoo.vespa.config.server.host.HostValidator; import java.io.File; @@ -31,7 +29,6 @@ public class LocalSession extends Session { protected final ApplicationPackage applicationPackage; private final TenantApplications applicationRepo; private final File serverDBSessionDir; - private final HostValidator<ApplicationId> hostValidator; /** * Creates a session. This involves loading the application, validating it and distributing it. @@ -40,12 +37,11 @@ public class LocalSession extends Session { */ public LocalSession(TenantName tenant, long sessionId, ApplicationPackage applicationPackage, SessionZooKeeperClient sessionZooKeeperClient, File serverDBSessionDir, - TenantApplications applicationRepo, HostValidator<ApplicationId> hostValidator) { + TenantApplications applicationRepo) { super(tenant, sessionId, sessionZooKeeperClient); this.serverDBSessionDir = serverDBSessionDir; this.applicationPackage = applicationPackage; this.applicationRepo = applicationRepo; - this.hostValidator = hostValidator; } public ApplicationFile getApplicationFile(Path relativePath, Mode mode) { @@ -100,8 +96,6 @@ public class LocalSession extends Session { public ApplicationPackage getApplicationPackage() { return applicationPackage; } - public HostValidator<ApplicationId> getHostValidator() { return hostValidator; } - // The rest of this class should be moved elsewhere ... private static class FileTransaction extends AbstractTransaction { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java index 150100077b4..2a61c621224 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java @@ -22,7 +22,6 @@ import com.yahoo.vespa.config.server.application.TenantApplications; import com.yahoo.vespa.config.server.configchange.ConfigChangeActions; import com.yahoo.vespa.config.server.deploy.TenantFileSystemDirs; import com.yahoo.vespa.config.server.filedistribution.FileDirectory; -import com.yahoo.vespa.config.server.host.HostValidator; import com.yahoo.vespa.config.server.monitoring.MetricUpdater; import com.yahoo.vespa.config.server.monitoring.Metrics; import com.yahoo.vespa.config.server.tenant.TenantRepository; @@ -83,7 +82,6 @@ public class SessionRepository { private final MetricUpdater metrics; private final Curator.DirectoryCache directoryCache; private final TenantApplications applicationRepo; - private final HostValidator<ApplicationId> hostRegistry; private final SessionPreparer sessionPreparer; private final Path sessionsPath; private final TenantName tenantName; @@ -94,7 +92,6 @@ public class SessionRepository { TenantApplications applicationRepo, ReloadHandler reloadHandler, FlagSource flagSource, - HostValidator<ApplicationId> hostRegistry, SessionPreparer sessionPreparer) { this.tenantName = tenantName; this.componentRegistry = componentRegistry; @@ -105,7 +102,6 @@ public class SessionRepository { this.zkWatcherExecutor = command -> componentRegistry.getZkWatcherExecutor().execute(tenantName, command); this.tenantFileSystemDirs = new TenantFileSystemDirs(componentRegistry.getConfigServerDB(), tenantName); this.applicationRepo = applicationRepo; - this.hostRegistry = hostRegistry; this.sessionPreparer = sessionPreparer; this.distributeApplicationPackage = Flags.CONFIGSERVER_DISTRIBUTE_APPLICATION_PACKAGE.bindTo(flagSource); this.reloadHandler = reloadHandler; @@ -162,7 +158,7 @@ public class SessionRepository { long sessionId = session.getSessionId(); SessionZooKeeperClient sessionZooKeeperClient = createSessionZooKeeperClient(sessionId); Curator.CompletionWaiter waiter = sessionZooKeeperClient.createPrepareWaiter(); - ConfigChangeActions actions = sessionPreparer.prepare(session.getHostValidator(), logger, params, + ConfigChangeActions actions = sessionPreparer.prepare(applicationRepo.getHostValidator(), logger, params, currentActiveApplicationSet, tenantPath, now, getSessionAppDir(sessionId), session.getApplicationPackage(), sessionZooKeeperClient); @@ -429,7 +425,7 @@ public class SessionRepository { sessionZKClient.createNewSession(clock.instant()); Curator.CompletionWaiter waiter = sessionZKClient.getUploadWaiter(); LocalSession session = new LocalSession(tenantName, sessionId, applicationPackage, sessionZKClient, - getSessionAppDir(sessionId), applicationRepo, hostRegistry); + getSessionAppDir(sessionId), applicationRepo); waiter.awaitCompletion(timeoutBudget.timeLeft()); return session; } @@ -488,7 +484,7 @@ public class SessionRepository { sessionId, currentlyActiveSessionId, false); SessionZooKeeperClient sessionZooKeeperClient = createSessionZooKeeperClient(sessionId); return new LocalSession(tenantName, sessionId, applicationPackage, sessionZooKeeperClient, - getSessionAppDir(sessionId), applicationRepo, hostRegistry); + getSessionAppDir(sessionId), applicationRepo); } catch (Exception e) { throw new RuntimeException("Error creating session " + sessionId, e); } @@ -517,7 +513,7 @@ public class SessionRepository { ApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(sessionDir); SessionZooKeeperClient sessionZKClient = createSessionZooKeeperClient(sessionId); return new LocalSession(tenantName, sessionId, applicationPackage, sessionZKClient, - getSessionAppDir(sessionId), applicationRepo, hostRegistry); + getSessionAppDir(sessionId), applicationRepo); } /** diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClient.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClient.java index 33001d2996c..807629a2148 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClient.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClient.java @@ -14,7 +14,6 @@ import com.yahoo.config.provision.DockerImage; import com.yahoo.config.provision.NodeFlavors; import com.yahoo.path.Path; import com.yahoo.text.Utf8; -import com.yahoo.transaction.NestedTransaction; import com.yahoo.transaction.Transaction; import com.yahoo.vespa.config.server.UserConfigDefinitionRepo; import com.yahoo.vespa.config.server.deploy.ZooKeeperClient; @@ -129,18 +128,6 @@ public class SessionZooKeeperClient { return curator.getCompletionWaiter(path, getNumberOfMembers(), serverId); } - public void delete(NestedTransaction transaction ) { - try { - log.log(Level.FINE, "Deleting " + sessionPath.getAbsolute()); - CuratorTransaction curatorTransaction = new CuratorTransaction(curator); - CuratorOperations.deleteAll(sessionPath.getAbsolute(), curator).forEach(curatorTransaction::add); - transaction.add(curatorTransaction); - transaction.commit(); - } catch (RuntimeException e) { - log.log(Level.INFO, "Error deleting session (" + sessionPath.getAbsolute() + ") from zookeeper", e); - } - } - /** Returns a transaction deleting this session on commit */ public CuratorTransaction deleteTransaction() { return CuratorTransaction.from(CuratorOperations.deleteAll(sessionPath.getAbsolute(), curator), curator); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java index 32e9f694027..304fbb6786a 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java @@ -224,7 +224,6 @@ public class TenantRepository { SessionRepository sessionRepository = new SessionRepository(tenantName, componentRegistry, applicationRepo, reloadHandler, componentRegistry.getFlagSource(), - applicationRepo, componentRegistry.getSessionPreparer()); log.log(Level.INFO, "Creating tenant '" + tenantName + "'"); Tenant tenant = new Tenant(tenantName, sessionRepository, requestHandler, diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionHandlerTest.java index 90d4ecaccf9..be368692aba 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/SessionHandlerTest.java @@ -16,7 +16,6 @@ import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.path.Path; import com.yahoo.transaction.NestedTransaction; import com.yahoo.transaction.Transaction; -import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.session.DummyTransaction; import com.yahoo.vespa.config.server.session.LocalSession; import com.yahoo.vespa.config.server.session.MockSessionZKClient; @@ -87,7 +86,7 @@ public class SessionHandlerTest { private ApplicationId applicationId; public MockLocalSession(long sessionId, ApplicationPackage app) { - super(TenantName.defaultName(), sessionId, app, new MockSessionZKClient(app), null, null, new HostRegistry<>()); + super(TenantName.defaultName(), sessionId, app, new MockSessionZKClient(app), null, null); } public MockLocalSession(long sessionId, ApplicationPackage app, ApplicationId applicationId) { diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java index e32071d6d16..9d84d5a7f28 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/LocalSessionTest.java @@ -17,7 +17,6 @@ import com.yahoo.vespa.config.server.application.TenantApplications; import com.yahoo.vespa.config.server.deploy.DeployHandlerLogger; import com.yahoo.vespa.config.server.deploy.TenantFileSystemDirs; import com.yahoo.vespa.config.server.deploy.ZooKeeperClient; -import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.tenant.TenantRepository; import com.yahoo.vespa.config.server.zookeeper.ConfigCurator; import com.yahoo.vespa.curator.Curator; @@ -139,7 +138,7 @@ public class LocalSessionTest { new TestComponentRegistry.Builder().curator(curator).build(), tenant); applications.createApplication(zkc.readApplicationId()); return new LocalSession(tenant, sessionId, FilesApplicationPackage.fromFile(testApp), - zkc, sessionDir, applications, new HostRegistry<>()); + zkc, sessionDir, applications); } private void doPrepare(LocalSession session) { diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java index f73fb053649..1528c82188e 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java @@ -76,7 +76,7 @@ public class SessionRepositoryTest { TenantApplications applicationRepo = TenantApplications.create(globalComponentRegistry, tenantName); sessionRepository = new SessionRepository(tenantName, globalComponentRegistry, applicationRepo, applicationRepo, new InMemoryFlagSource(), - applicationRepo, globalComponentRegistry.getSessionPreparer()); + globalComponentRegistry.getSessionPreparer()); } @Test diff --git a/container-core/src/main/java/com/yahoo/container/protect/ProcessTerminator.java b/container-core/src/main/java/com/yahoo/container/protect/ProcessTerminator.java index 38f5b72336b..16cf741813c 100644 --- a/container-core/src/main/java/com/yahoo/container/protect/ProcessTerminator.java +++ b/container-core/src/main/java/com/yahoo/container/protect/ProcessTerminator.java @@ -5,7 +5,7 @@ import com.yahoo.protect.Process; /** * An injectable terminator of the Java vm. - * Components that encounters conditions where the vm should be terminator should + * Components that encounters conditions where the vm should be terminated should * request an instance of this injected. That makes termination testable * as tests can create subclasses of this which register the termination request * rather than terminating. diff --git a/documentapi/abi-spec.json b/documentapi/abi-spec.json index 8f49b51fa1c..bb4deed2914 100644 --- a/documentapi/abi-spec.json +++ b/documentapi/abi-spec.json @@ -957,6 +957,28 @@ ], "fields": [] }, + "com.yahoo.documentapi.local.LocalVisitorSession": { + "superClass": "java.lang.Object", + "interfaces": [ + "com.yahoo.documentapi.VisitorSession" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(com.yahoo.documentapi.local.LocalDocumentAccess, com.yahoo.documentapi.VisitorParameters)", + "public boolean isDone()", + "public com.yahoo.documentapi.ProgressToken getProgress()", + "public com.yahoo.messagebus.Trace getTrace()", + "public boolean waitUntilDone(long)", + "public void ack(com.yahoo.documentapi.AckToken)", + "public void abort()", + "public com.yahoo.documentapi.VisitorResponse getNext()", + "public com.yahoo.documentapi.VisitorResponse getNext(int)", + "public void destroy()" + ], + "fields": [] + }, "com.yahoo.documentapi.messagebus.MessageBusAsyncSession": { "superClass": "java.lang.Object", "interfaces": [ diff --git a/documentapi/src/main/java/com/yahoo/documentapi/local/LocalDocumentAccess.java b/documentapi/src/main/java/com/yahoo/documentapi/local/LocalDocumentAccess.java index 202929130c7..c69a8fb48de 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/local/LocalDocumentAccess.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/local/LocalDocumentAccess.java @@ -3,6 +3,7 @@ package com.yahoo.documentapi.local; import com.yahoo.document.Document; import com.yahoo.document.DocumentId; +import com.yahoo.document.select.parser.ParseException; import com.yahoo.documentapi.AsyncParameters; import com.yahoo.documentapi.AsyncSession; import com.yahoo.documentapi.DocumentAccess; @@ -43,8 +44,8 @@ public class LocalDocumentAccess extends DocumentAccess { } @Override - public VisitorSession createVisitorSession(VisitorParameters parameters) { - throw new UnsupportedOperationException("Not supported yet"); + public VisitorSession createVisitorSession(VisitorParameters parameters) throws ParseException { + return new LocalVisitorSession(this, parameters); } @Override diff --git a/documentapi/src/main/java/com/yahoo/documentapi/local/LocalVisitorSession.java b/documentapi/src/main/java/com/yahoo/documentapi/local/LocalVisitorSession.java new file mode 100644 index 00000000000..e107be94008 --- /dev/null +++ b/documentapi/src/main/java/com/yahoo/documentapi/local/LocalVisitorSession.java @@ -0,0 +1,165 @@ +package com.yahoo.documentapi.local; + +import com.yahoo.document.Document; +import com.yahoo.document.DocumentGet; +import com.yahoo.document.DocumentId; +import com.yahoo.document.DocumentPut; +import com.yahoo.document.Field; +import com.yahoo.document.fieldset.FieldCollection; +import com.yahoo.document.fieldset.FieldSet; +import com.yahoo.document.fieldset.FieldSetRepo; +import com.yahoo.document.select.DocumentSelector; +import com.yahoo.document.select.Result; +import com.yahoo.document.select.parser.ParseException; +import com.yahoo.documentapi.AckToken; +import com.yahoo.documentapi.ProgressToken; +import com.yahoo.documentapi.VisitorControlHandler; +import com.yahoo.documentapi.VisitorDataHandler; +import com.yahoo.documentapi.VisitorDataQueue; +import com.yahoo.documentapi.VisitorParameters; +import com.yahoo.documentapi.VisitorResponse; +import com.yahoo.documentapi.VisitorSession; +import com.yahoo.documentapi.messagebus.protocol.PutDocumentMessage; +import com.yahoo.messagebus.Trace; +import com.yahoo.yolean.Exceptions; + +import java.util.Comparator; +import java.util.Map; +import java.util.concurrent.ConcurrentSkipListMap; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Local visitor session that copies and iterates through all items in the local document access. + * Each document must be ack'ed for the session to be done visiting. + * Only document puts are sent by this session, and this is done from a separate thread. + * + * @author jonmv + */ +public class LocalVisitorSession implements VisitorSession { + + private enum State { RUNNING, FAILURE, ABORTED, SUCCESS } + + private final VisitorDataHandler data; + private final VisitorControlHandler control; + private final Map<DocumentId, Document> outstanding; + private final DocumentSelector selector; + private final FieldSet fieldSet; + private final AtomicReference<State> state; + + public LocalVisitorSession(LocalDocumentAccess access, VisitorParameters parameters) throws ParseException { + if (parameters.getResumeToken() != null) + throw new UnsupportedOperationException("Continuation via progress tokens is not supported"); + + if (parameters.getRemoteDataHandler() != null) + throw new UnsupportedOperationException("Remote data handlers are not supported"); + + this.selector = new DocumentSelector(parameters.getDocumentSelection()); + this.fieldSet = new FieldSetRepo().parse(access.getDocumentTypeManager(), parameters.fieldSet()); + + this.data = parameters.getLocalDataHandler() == null ? new VisitorDataQueue() : parameters.getLocalDataHandler(); + this.data.reset(); + this.data.setSession(this); + + this.control = parameters.getControlHandler() == null ? new VisitorControlHandler() : parameters.getControlHandler(); + this.control.reset(); + this.control.setSession(this); + + this.outstanding = new ConcurrentSkipListMap<>(Comparator.comparing(DocumentId::toString)); + this.outstanding.putAll(access.documents); + this.state = new AtomicReference<>(State.RUNNING); + + start(); + } + + void start() { + new Thread(() -> { + try { + // Iterate through all documents and pass on to data handler + outstanding.forEach((id, document) -> { + if (state.get() != State.RUNNING) + return; + + if (selector.accepts(new DocumentPut(document)) != Result.TRUE) + return; + + Document copy = new Document(document.getDataType(), document.getId()); + new FieldSetRepo().copyFields(document, copy, fieldSet); + + data.onMessage(new PutDocumentMessage(new DocumentPut(copy)), + new AckToken(id)); + }); + // Transition to a terminal state when done + state.updateAndGet(current -> { + switch (current) { + case RUNNING: + control.onDone(VisitorControlHandler.CompletionCode.SUCCESS, "Success"); + return State.SUCCESS; + case ABORTED: + control.onDone(VisitorControlHandler.CompletionCode.ABORTED, "Aborted by user"); + return State.ABORTED; + default: + control.onDone(VisitorControlHandler.CompletionCode.FAILURE, "Unexpected state '" + current + "'");; + return State.FAILURE; + } + }); + } + // Transition to failure terminal state on error + catch (Exception e) { + state.set(State.FAILURE); + outstanding.clear(); + control.onDone(VisitorControlHandler.CompletionCode.FAILURE, Exceptions.toMessageString(e)); + } + finally { + data.onDone(); + } + }).start(); + } + + @Override + public boolean isDone() { + return outstanding.isEmpty() // All documents ack'ed + && control.isDone(); // Control handler has been notified + } + + @Override + public ProgressToken getProgress() { + throw new UnsupportedOperationException("Progress tokens are not supported"); + } + + @Override + public Trace getTrace() { + throw new UnsupportedOperationException("Traces are not supported"); + } + + @Override + public boolean waitUntilDone(long timeoutMs) throws InterruptedException { + return control.waitUntilDone(timeoutMs); + } + + @Override + public void ack(AckToken token) { + outstanding.remove((DocumentId) token.ackObject); + } + + @Override + public void abort() { + state.updateAndGet(current -> current == State.RUNNING ? State.ABORTED : current); + outstanding.clear(); + } + + @Override + public VisitorResponse getNext() { + return data.getNext(); + } + + @Override + public VisitorResponse getNext(int timeoutMilliseconds) throws InterruptedException { + return data.getNext(timeoutMilliseconds); + } + + @Override + public void destroy() { + abort(); + } + +} diff --git a/documentapi/src/test/java/com/yahoo/documentapi/local/LocalDocumentApiTestCase.java b/documentapi/src/test/java/com/yahoo/documentapi/local/LocalDocumentApiTestCase.java new file mode 100644 index 00000000000..d1361e50973 --- /dev/null +++ b/documentapi/src/test/java/com/yahoo/documentapi/local/LocalDocumentApiTestCase.java @@ -0,0 +1,242 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.documentapi.local; + +import com.yahoo.document.Document; +import com.yahoo.document.DocumentId; +import com.yahoo.document.DocumentPut; +import com.yahoo.document.DocumentRemove; +import com.yahoo.document.DocumentType; +import com.yahoo.document.DocumentUpdate; +import com.yahoo.document.datatypes.StringFieldValue; +import com.yahoo.document.select.parser.ParseException; +import com.yahoo.document.update.FieldUpdate; +import com.yahoo.documentapi.AsyncParameters; +import com.yahoo.documentapi.AsyncSession; +import com.yahoo.documentapi.DocumentAccess; +import com.yahoo.documentapi.DocumentAccessParams; +import com.yahoo.documentapi.DocumentResponse; +import com.yahoo.documentapi.DumpVisitorDataHandler; +import com.yahoo.documentapi.Response; +import com.yahoo.documentapi.Result; +import com.yahoo.documentapi.SyncParameters; +import com.yahoo.documentapi.SyncSession; +import com.yahoo.documentapi.VisitorControlHandler; +import com.yahoo.documentapi.VisitorParameters; +import com.yahoo.documentapi.VisitorSession; +import com.yahoo.documentapi.test.AbstractDocumentApiTestCase; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.CountDownLatch; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertTrue; + +/** + * Runs the superclass tests on this implementation + * + * @author bratseth + */ +public class LocalDocumentApiTestCase extends AbstractDocumentApiTestCase { + + protected LocalDocumentAccess access; + + @Override + protected DocumentAccess access() { + return access; + } + + @Before + public void setUp() { + DocumentAccessParams params = new DocumentAccessParams(); + params.setDocumentManagerConfigId("file:src/test/cfg/documentmanager.cfg"); + access = new LocalDocumentAccess(params); + } + + @After + public void shutdownAccess() { + access.shutdown(); + } + + @Test + public void testNoExceptionFromAsync() { + AsyncSession session = access.createAsyncSession(new AsyncParameters()); + + DocumentType type = access.getDocumentTypeManager().getDocumentType("music"); + DocumentUpdate docUp = new DocumentUpdate(type, new DocumentId("id:ns:music::2")); + + Result result = session.update(docUp); + assertTrue(result.isSuccess()); + Response response = session.getNext(); + assertEquals(result.getRequestId(), response.getRequestId()); + assertFalse(response.isSuccess()); + session.destroy(); + } + + @Test + public void testAsyncFetch() { + AsyncSession session = access.createAsyncSession(new AsyncParameters()); + List<DocumentId> ids = new ArrayList<>(); + ids.add(new DocumentId("id:music:music::1")); + ids.add(new DocumentId("id:music:music::2")); + ids.add(new DocumentId("id:music:music::3")); + for (DocumentId id : ids) + session.put(new Document(access.getDocumentTypeManager().getDocumentType("music"), id)); + int timeout = 100; + + long startTime = System.currentTimeMillis(); + Set<Long> outstandingRequests = new HashSet<>(); + for (DocumentId id : ids) { + Result result = session.get(id); + if ( ! result.isSuccess()) + throw new IllegalStateException("Failed requesting document " + id, result.getError().getCause()); + outstandingRequests.add(result.getRequestId()); + } + + List<Document> documents = new ArrayList<>(); + try { + while ( ! outstandingRequests.isEmpty()) { + int timeSinceStart = (int)(System.currentTimeMillis() - startTime); + Response response = session.getNext(timeout - timeSinceStart); + if (response == null) + throw new RuntimeException("Timed out waiting for documents"); // or return what you have + if ( ! outstandingRequests.contains(response.getRequestId())) continue; // Stale: Ignore + + if (response.isSuccess()) + documents.add(((DocumentResponse)response).getDocument()); + outstandingRequests.remove(response.getRequestId()); + } + } + catch (InterruptedException e) { + throw new RuntimeException("Interrupted while waiting for documents", e); + } + + assertEquals(3, documents.size()); + for (Document document : documents) + assertNotNull(document); + } + + @Test + public void testFeedingAndVisiting() throws InterruptedException, ParseException { + DocumentType musicType = access().getDocumentTypeManager().getDocumentType("music"); + Document doc1 = new Document(musicType, "id:ns:music::1"); doc1.setFieldValue("artist", "one"); + Document doc2 = new Document(musicType, "id:ns:music::2"); doc2.setFieldValue("artist", "two"); + Document doc3 = new Document(musicType, "id:ns:music::3"); + + // Select all music documents where the "artist" field is set + VisitorParameters parameters = new VisitorParameters("music.artist"); + parameters.setFieldSet("music:artist"); + VisitorControlHandler control = new VisitorControlHandler(); + parameters.setControlHandler(control); + Set<Document> received = new ConcurrentSkipListSet<>(); + parameters.setLocalDataHandler(new DumpVisitorDataHandler() { + @Override public void onDocument(Document doc, long timeStamp) { + received.add(doc); + } + @Override public void onRemove(DocumentId id) { + throw new IllegalStateException("Not supposed to get here"); + } + }); + + // Visit when there are no documents completes immediately + access.createVisitorSession(parameters).waitUntilDone(0); + assertSame(VisitorControlHandler.CompletionCode.SUCCESS, + control.getResult().getCode()); + assertEquals(Set.of(), + received); + + // Sync-put some documents + SyncSession out = access.createSyncSession(new SyncParameters.Builder().build()); + out.put(new DocumentPut(doc1)); + out.put(new DocumentPut(doc2)); + out.put(new DocumentPut(doc3)); + assertEquals(Map.of(doc1.getId(), doc1, + doc2.getId(), doc2, + doc3.getId(), doc3), + access.documents); + + // Expect a subset of documents to be returned, based on the selection + access.createVisitorSession(parameters).waitUntilDone(0); + assertSame(VisitorControlHandler.CompletionCode.SUCCESS, + control.getResult().getCode()); + assertEquals(Set.of(doc1, doc2), + received); + + // Remove doc2 and set artist for doc3, to see changes are reflected in subsequent visits + out.remove(new DocumentRemove(doc2.getId())); + out.update(new DocumentUpdate(musicType, doc3.getId()).addFieldUpdate(FieldUpdate.createAssign(musicType.getField("artist"), + new StringFieldValue("three")))); + assertEquals(Map.of(doc1.getId(), doc1, + doc3.getId(), doc3), + access.documents); + assertEquals("three", + ((StringFieldValue) doc3.getFieldValue("artist")).getString()); + + // Visit the documents again, retrieving none of the document fields + parameters.setFieldSet("[id]"); + received.clear(); + access.createVisitorSession(parameters).waitUntilDone(0); + assertSame(VisitorControlHandler.CompletionCode.SUCCESS, + control.getResult().getCode()); + assertEquals(Set.of(new Document(musicType, doc1.getId()), new Document(musicType, doc3.getId())), + received); + + // Visit the documents again, throwing an exception in the data handler on doc3 + received.clear(); + parameters.setLocalDataHandler(new DumpVisitorDataHandler() { + @Override public void onDocument(Document doc, long timeStamp) { + if (doc3.getId().equals(doc.getId())) + throw new RuntimeException("SEGFAULT"); + received.add(doc); + } + @Override public void onRemove(DocumentId id) { + throw new IllegalStateException("Not supposed to get here"); + } + }); + access.createVisitorSession(parameters).waitUntilDone(0); + assertSame(VisitorControlHandler.CompletionCode.FAILURE, + control.getResult().getCode()); + assertEquals("SEGFAULT", + control.getResult().getMessage()); + assertEquals(Set.of(new Document(musicType, doc1.getId())), + received); + + // Visit the documents again, aborting after the first document + received.clear(); + CountDownLatch visitLatch = new CountDownLatch(1); + CountDownLatch abortLatch = new CountDownLatch(1); + parameters.setLocalDataHandler(new DumpVisitorDataHandler() { + @Override public void onDocument(Document doc, long timeStamp) { + received.add(doc); + abortLatch.countDown(); + try { + visitLatch.await(); + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + @Override public void onRemove(DocumentId id) { throw new IllegalStateException("Not supposed to get here"); } + }); + VisitorSession visit = access.createVisitorSession(parameters); + abortLatch.await(); + control.abort(); + visitLatch.countDown(); + visit.waitUntilDone(0); + assertSame(VisitorControlHandler.CompletionCode.ABORTED, + control.getResult().getCode()); + assertEquals(Set.of(new Document(musicType, doc1.getId())), + received); + } + +} diff --git a/documentapi/src/test/java/com/yahoo/documentapi/local/test/LocalDocumentApiTestCase.java b/documentapi/src/test/java/com/yahoo/documentapi/local/test/LocalDocumentApiTestCase.java deleted file mode 100644 index 252bf739951..00000000000 --- a/documentapi/src/test/java/com/yahoo/documentapi/local/test/LocalDocumentApiTestCase.java +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.documentapi.local.test; - -import com.yahoo.document.*; -import com.yahoo.documentapi.*; -import com.yahoo.documentapi.local.LocalDocumentAccess; -import com.yahoo.documentapi.test.AbstractDocumentApiTestCase; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import static org.junit.Assert.*; - -/** - * Runs the superclass tests on this implementation - * - * @author bratseth - */ -public class LocalDocumentApiTestCase extends AbstractDocumentApiTestCase { - - protected DocumentAccess access; - - @Override - protected DocumentAccess access() { - return access; - } - - @Before - public void setUp() { - DocumentAccessParams params = new DocumentAccessParams(); - params.setDocumentManagerConfigId("file:src/test/cfg/documentmanager.cfg"); - access = new LocalDocumentAccess(params); - } - - @After - public void shutdownAccess() { - access.shutdown(); - } - - @Test - public void testNoExceptionFromAsync() { - AsyncSession session = access.createAsyncSession(new AsyncParameters()); - - DocumentType type = access.getDocumentTypeManager().getDocumentType("music"); - DocumentUpdate docUp = new DocumentUpdate(type, new DocumentId("id:ns:music::2")); - - Result result = session.update(docUp); - assertTrue(result.isSuccess()); - Response response = session.getNext(); - assertEquals(result.getRequestId(), response.getRequestId()); - assertFalse(response.isSuccess()); - session.destroy(); - } - - @Test - public void testAsyncFetch() { - AsyncSession session = access.createAsyncSession(new AsyncParameters()); - List<DocumentId> ids = new ArrayList<>(); - ids.add(new DocumentId("id:music:music::1")); - ids.add(new DocumentId("id:music:music::2")); - ids.add(new DocumentId("id:music:music::3")); - for (DocumentId id : ids) - session.put(new Document(access.getDocumentTypeManager().getDocumentType("music"), id)); - int timeout = 100; - - long startTime = System.currentTimeMillis(); - Set<Long> outstandingRequests = new HashSet<>(); - for (DocumentId id : ids) { - Result result = session.get(id); - if ( ! result.isSuccess()) - throw new IllegalStateException("Failed requesting document " + id, result.getError().getCause()); - outstandingRequests.add(result.getRequestId()); - } - - List<Document> documents = new ArrayList<>(); - try { - while ( ! outstandingRequests.isEmpty()) { - int timeSinceStart = (int)(System.currentTimeMillis() - startTime); - Response response = session.getNext(timeout - timeSinceStart); - if (response == null) - throw new RuntimeException("Timed out waiting for documents"); // or return what you have - if ( ! outstandingRequests.contains(response.getRequestId())) continue; // Stale: Ignore - - if (response.isSuccess()) - documents.add(((DocumentResponse)response).getDocument()); - outstandingRequests.remove(response.getRequestId()); - } - } - catch (InterruptedException e) { - throw new RuntimeException("Interrupted while waiting for documents", e); - } - - assertEquals(3, documents.size()); - for (Document document : documents) - assertNotNull(document); - } - -} diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index b68440795d4..67f9fa19dc0 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -17,6 +17,7 @@ vespa_define_module( src/tests/eval/function src/tests/eval/function_speed src/tests/eval/gbdt + src/tests/eval/inline_operation src/tests/eval/interpreted_function src/tests/eval/node_tools src/tests/eval/node_types diff --git a/eval/src/tests/eval/inline_operation/CMakeLists.txt b/eval/src/tests/eval/inline_operation/CMakeLists.txt new file mode 100644 index 00000000000..04cdbca3abf --- /dev/null +++ b/eval/src/tests/eval/inline_operation/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_inline_operation_test_app TEST + SOURCES + inline_operation_test.cpp + DEPENDS + vespaeval + gtest +) +vespa_add_test(NAME eval_inline_operation_test_app COMMAND eval_inline_operation_test_app) diff --git a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp new file mode 100644 index 00000000000..4520176e276 --- /dev/null +++ b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp @@ -0,0 +1,156 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/eval/eval/operation.h> +#include <vespa/eval/eval/inline_operation.h> +#include <vespa/eval/eval/function.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib::eval; +using namespace vespalib::eval::operation; + +template <typename T> struct IsInlined { constexpr static bool value = true; }; +template <> struct IsInlined<CallOp1> { constexpr static bool value = false; }; +template <> struct IsInlined<CallOp2> { constexpr static bool value = false; }; + +template <typename T> double test_op1(op1_t ref, double a, bool inlined) { + T op(ref); + EXPECT_EQ(IsInlined<T>::value, inlined); + EXPECT_EQ(op(a), ref(a)); + return op(a); +}; + +template <typename T> double test_op2(op2_t ref, double a, double b, bool inlined) { + T op(ref); + EXPECT_EQ(IsInlined<T>::value, inlined); + EXPECT_EQ(op(a,b), ref(a,b)); + return op(a,b); +}; + +op1_t as_op1(const vespalib::string &str) { + auto fun = Function::parse({"a"}, str); + auto res = lookup_op1(*fun); + EXPECT_TRUE(res.has_value()); + return res.value(); +} + +op2_t as_op2(const vespalib::string &str) { + auto fun = Function::parse({"a", "b"}, str); + auto res = lookup_op2(*fun); + EXPECT_TRUE(res.has_value()); + return res.value(); +} + +TEST(InlineOperationTest, op1_lambdas_are_recognized) { + EXPECT_EQ(as_op1("-a"), Neg::f); + EXPECT_EQ(as_op1("!a"), Not::f); + EXPECT_EQ(as_op1("cos(a)"), Cos::f); + EXPECT_EQ(as_op1("sin(a)"), Sin::f); + EXPECT_EQ(as_op1("tan(a)"), Tan::f); + EXPECT_EQ(as_op1("cosh(a)"), Cosh::f); + EXPECT_EQ(as_op1("sinh(a)"), Sinh::f); + EXPECT_EQ(as_op1("tanh(a)"), Tanh::f); + EXPECT_EQ(as_op1("acos(a)"), Acos::f); + EXPECT_EQ(as_op1("asin(a)"), Asin::f); + EXPECT_EQ(as_op1("atan(a)"), Atan::f); + EXPECT_EQ(as_op1("exp(a)"), Exp::f); + EXPECT_EQ(as_op1("log10(a)"), Log10::f); + EXPECT_EQ(as_op1("log(a)"), Log::f); + EXPECT_EQ(as_op1("sqrt(a)"), Sqrt::f); + EXPECT_EQ(as_op1("ceil(a)"), Ceil::f); + EXPECT_EQ(as_op1("fabs(a)"), Fabs::f); + EXPECT_EQ(as_op1("floor(a)"), Floor::f); + EXPECT_EQ(as_op1("isNan(a)"), IsNan::f); + EXPECT_EQ(as_op1("relu(a)"), Relu::f); + EXPECT_EQ(as_op1("sigmoid(a)"), Sigmoid::f); + EXPECT_EQ(as_op1("elu(a)"), Elu::f); +} + +TEST(InlineOperationTest, op1_lambdas_are_recognized_with_different_parameter_names) { + EXPECT_EQ(lookup_op1(*Function::parse({"x"}, "-x")).value(), Neg::f); + EXPECT_EQ(lookup_op1(*Function::parse({"x"}, "!x")).value(), Not::f); +} + +TEST(InlineOperationTest, non_op1_lambdas_are_not_recognized) { + EXPECT_FALSE(lookup_op1(*Function::parse({"a"}, "a*a")).has_value()); + EXPECT_FALSE(lookup_op1(*Function::parse({"a", "b"}, "a+b")).has_value()); +} + +TEST(InlineOperationTest, op2_lambdas_are_recognized) { + EXPECT_EQ(as_op2("a+b"), Add::f); + EXPECT_EQ(as_op2("a-b"), Sub::f); + EXPECT_EQ(as_op2("a*b"), Mul::f); + EXPECT_EQ(as_op2("a/b"), Div::f); + EXPECT_EQ(as_op2("a%b"), Mod::f); + EXPECT_EQ(as_op2("a^b"), Pow::f); + EXPECT_EQ(as_op2("a==b"), Equal::f); + EXPECT_EQ(as_op2("a!=b"), NotEqual::f); + EXPECT_EQ(as_op2("a~=b"), Approx::f); + EXPECT_EQ(as_op2("a<b"), Less::f); + EXPECT_EQ(as_op2("a<=b"), LessEqual::f); + EXPECT_EQ(as_op2("a>b"), Greater::f); + EXPECT_EQ(as_op2("a>=b"), GreaterEqual::f); + EXPECT_EQ(as_op2("a&&b"), And::f); + EXPECT_EQ(as_op2("a||b"), Or::f); + EXPECT_EQ(as_op2("atan2(a,b)"), Atan2::f); + EXPECT_EQ(as_op2("ldexp(a,b)"), Ldexp::f); + EXPECT_EQ(as_op2("pow(a,b)"), Pow::f); + EXPECT_EQ(as_op2("fmod(a,b)"), Mod::f); + EXPECT_EQ(as_op2("min(a,b)"), Min::f); + EXPECT_EQ(as_op2("max(a,b)"), Max::f); +} + +TEST(InlineOperationTest, op2_lambdas_are_recognized_with_different_parameter_names) { + EXPECT_EQ(lookup_op2(*Function::parse({"x", "y"}, "x+y")).value(), Add::f); + EXPECT_EQ(lookup_op2(*Function::parse({"x", "y"}, "x-y")).value(), Sub::f); +} + +TEST(InlineOperationTest, non_op2_lambdas_are_not_recognized) { + EXPECT_FALSE(lookup_op2(*Function::parse({"a"}, "-a")).has_value()); + EXPECT_FALSE(lookup_op2(*Function::parse({"a", "b"}, "b+a")).has_value()); +} + +TEST(InlineOperationTest, generic_op1_wrapper_works) { + CallOp1 op(Neg::f); + EXPECT_EQ(op(3), -3); + EXPECT_EQ(op(-5), 5); +} + +TEST(InlineOperationTest, generic_op2_wrapper_works) { + CallOp2 op(Add::f); + EXPECT_EQ(op(2,3), 5); + EXPECT_EQ(op(3,7), 10); +} + +TEST(InlineOperationTest, inline_op2_example_works) { + op2_t ignored = nullptr; + InlineOp2<Add> op(ignored); + EXPECT_EQ(op(2,3), 5); + EXPECT_EQ(op(3,7), 10); +} + +TEST(InlineOperationTest, parameter_swap_wrapper_works) { + CallOp2 op(Sub::f); + SwapArgs2<CallOp2> swap_op(Sub::f); + EXPECT_EQ(op(2,3), -1); + EXPECT_EQ(swap_op(2,3), 1); + EXPECT_EQ(op(3,7), -4); + EXPECT_EQ(swap_op(3,7), 4); +} + +TEST(InlineOperationTest, resolved_op1_works) { + auto a = TypifyOp1::resolve(Neg::f, [](auto t){ return test_op1<typename decltype(t)::type>(Neg::f, 2.0, false); }); + // putting the lambda inside the EXPECT does not work + EXPECT_EQ(a, -2.0); +} + +TEST(InlineOperationTest, resolved_op2_works) { + auto a = TypifyOp2::resolve(Add::f, [](auto t){ return test_op2<typename decltype(t)::type>(Add::f, 2.0, 5.0, true); }); + auto b = TypifyOp2::resolve(Mul::f, [](auto t){ return test_op2<typename decltype(t)::type>(Mul::f, 5.0, 3.0, true); }); + auto c = TypifyOp2::resolve(Sub::f, [](auto t){ return test_op2<typename decltype(t)::type>(Sub::f, 8.0, 5.0, false); }); + // putting the lambda inside the EXPECT does not work + EXPECT_EQ(a, 7.0); + EXPECT_EQ(b, 15.0); + EXPECT_EQ(c, 3.0); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp b/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp index a571837b8e9..92fdbfade46 100644 --- a/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp +++ b/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp @@ -67,7 +67,6 @@ TEST("require that matmul can be optimized") { TEST("require that matmul with lambda can be optimized") { TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, true, true)); - TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*x)),sum,d)", 2, 3, 5, true, true)); } TEST("require that expressions similar to matmul are not optimized") { @@ -75,6 +74,7 @@ TEST("require that expressions similar to matmul are not optimized") { TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum,b)")); TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,prod,d)")); TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum)")); + TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*x)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x+y)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*x)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*y)),sum,d)")); diff --git a/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp b/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp index c0823248538..f9c563c9bf8 100644 --- a/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp +++ b/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp @@ -78,7 +78,6 @@ TEST("require that single multi matmul can be optimized") { TEST("require that multi matmul with lambda can be optimized") { TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, 6, true, true)); - TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*x)),sum,d)", 2, 3, 5, 6, true, true)); } TEST("require that expressions similar to multi matmul are not optimized") { @@ -86,6 +85,7 @@ TEST("require that expressions similar to multi matmul are not optimized") { TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,b)")); TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,prod,d)")); TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum)")); + TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*x)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x+y)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*x)),sum,d)")); TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*y)),sum,d)")); diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp index 0b924451907..3ecc3f66cda 100644 --- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp +++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp @@ -130,13 +130,13 @@ TEST("require that xw product gives same results as reference join/reduce") { TEST("require that various variants of xw product can be optimized") { TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(x*y)),sum,y)", 3, 2, true)); - TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)", 3, 2, true)); } TEST("require that expressions similar to xw product are not optimized") { TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum,x)")); TEST_DO(verify_not_optimized("reduce(y3*x2y3,prod,y)")); TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum)")); + TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)")); TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x+y)),sum,y)")); TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x*x)),sum,y)")); TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*y)),sum,y)")); diff --git a/eval/src/vespa/eval/eval/inline_operation.h b/eval/src/vespa/eval/eval/inline_operation.h new file mode 100644 index 00000000000..493de9ea56c --- /dev/null +++ b/eval/src/vespa/eval/eval/inline_operation.h @@ -0,0 +1,67 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "operation.h" +#include <vespa/vespalib/util/typify.h> + +namespace vespalib::eval::operation { + +//----------------------------------------------------------------------------- + +struct CallOp1 { + op1_t my_op1; + CallOp1(op1_t op1) : my_op1(op1) {} + double operator()(double a) const { return my_op1(a); } +}; + +struct TypifyOp1 { + template <typename T> using Result = TypifyResultType<T>; + template <typename F> static decltype(auto) resolve(op1_t value, F &&f) { + (void) value; + return f(Result<CallOp1>()); + } +}; + +//----------------------------------------------------------------------------- + +struct CallOp2 { + op2_t my_op2; + CallOp2(op2_t op2) : my_op2(op2) {} + op2_t get() const { return my_op2; } + double operator()(double a, double b) const { return my_op2(a, b); } +}; + +template <typename Op2> +struct SwapArgs2 { + Op2 op2; + SwapArgs2(op2_t op2_in) : op2(op2_in) {} + template <typename A, typename B> constexpr auto operator()(A a, B b) const { return op2(b, a); } +}; + +template <typename T> struct InlineOp2; +template <> struct InlineOp2<Add> { + InlineOp2(op2_t) {} + template <typename A, typename B> constexpr auto operator()(A a, B b) const { return (a+b); } +}; +template <> struct InlineOp2<Mul> { + InlineOp2(op2_t) {} + template <typename A, typename B> constexpr auto operator()(A a, B b) const { return (a*b); } +}; + +struct TypifyOp2 { + template <typename T> using Result = TypifyResultType<T>; + template <typename F> static decltype(auto) resolve(op2_t value, F &&f) { + if (value == Add::f) { + return f(Result<InlineOp2<Add>>()); + } else if (value == Mul::f) { + return f(Result<InlineOp2<Mul>>()); + } else { + return f(Result<CallOp2>()); + } + } +}; + +//----------------------------------------------------------------------------- + +} diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index 3a73a3b8784..02d16caae6b 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -15,25 +15,6 @@ namespace vespalib::eval { namespace { using namespace nodes; -using map_fun_t = double (*)(double); -using join_fun_t = double (*)(double, double); - -//----------------------------------------------------------------------------- - -// TODO(havardpe): generic function pointer resolving for all single -// operation lambdas. - -template <typename OP2> -bool is_op2(const Function &lambda) { - if (lambda.num_params() == 2) { - if (auto op2 = as<OP2>(lambda.root())) { - auto sym1 = as<Symbol>(op2->lhs()); - auto sym2 = as<Symbol>(op2->rhs()); - return (sym1 && sym2 && (sym1->id() != sym2->id())); - } - } - return false; -} //----------------------------------------------------------------------------- @@ -63,13 +44,13 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { stack.back() = tensor_function::reduce(a, aggr, dimensions, stash); } - void make_map(const Node &, map_fun_t function) { + void make_map(const Node &, operation::op1_t function) { assert(stack.size() >= 1); const auto &a = stack.back().get(); stack.back() = tensor_function::map(a, function, stash); } - void make_join(const Node &, join_fun_t function) { + void make_join(const Node &, operation::op2_t function) { assert(stack.size() >= 2); const auto &b = stack.back().get(); stack.pop_back(); @@ -77,7 +58,7 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { stack.back() = tensor_function::join(a, b, function, stash); } - void make_merge(const Node &, join_fun_t function) { + void make_merge(const Node &, operation::op2_t function) { assert(stack.size() >= 2); const auto &b = stack.back().get(); stack.pop_back(); @@ -203,14 +184,16 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { abort(); } void visit(const TensorMap &node) override { - const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); - make_map(node, token.get()->get().get_function<1>()); + if (auto op1 = operation::lookup_op1(node.lambda())) { + make_map(node, op1.value()); + } else { + const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); + make_map(node, token.get()->get().get_function<1>()); + } } void visit(const TensorJoin &node) override { - if (is_op2<Mul>(node.lambda())) { - make_join(node, operation::Mul::f); - } else if (is_op2<Add>(node.lambda())) { - make_join(node, operation::Add::f); + if (auto op2 = operation::lookup_op2(node.lambda())) { + make_join(node, op2.value()); } else { const auto &token = stash.create<CompileCache::Token::UP>(CompileCache::compile(node.lambda(), PassParams::SEPARATE)); make_join(node, token.get()->get().get_function<2>()); diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp index fa0a99de461..581f65c0e31 100644 --- a/eval/src/vespa/eval/eval/operation.cpp +++ b/eval/src/vespa/eval/eval/operation.cpp @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "operation.h" +#include "function.h" +#include "key_gen.h" #include <vespa/vespalib/util/approx.h> #include <algorithm> @@ -48,4 +50,97 @@ double Relu::f(double a) { return std::max(a, 0.0); } double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; } +namespace { + +template <typename T> +void add_op(std::map<vespalib::string,T> &map, const Function &fun, T op) { + assert(!fun.has_error()); + auto key = gen_key(fun, PassParams::SEPARATE); + auto res = map.emplace(key, op); + assert(res.second); +} + +template <typename T> +std::optional<T> lookup_op(const std::map<vespalib::string,T> &map, const Function &fun) { + auto key = gen_key(fun, PassParams::SEPARATE); + auto pos = map.find(key); + if (pos != map.end()) { + return pos->second; + } + return std::nullopt; +} + +void add_op1(std::map<vespalib::string,op1_t> &map, const vespalib::string &expr, op1_t op) { + add_op(map, *Function::parse({"a"}, expr), op); +} + +void add_op2(std::map<vespalib::string,op2_t> &map, const vespalib::string &expr, op2_t op) { + add_op(map, *Function::parse({"a", "b"}, expr), op); +} + +std::map<vespalib::string,op1_t> make_op1_map() { + std::map<vespalib::string,op1_t> map; + add_op1(map, "-a", Neg::f); + add_op1(map, "!a", Not::f); + add_op1(map, "cos(a)", Cos::f); + add_op1(map, "sin(a)", Sin::f); + add_op1(map, "tan(a)", Tan::f); + add_op1(map, "cosh(a)", Cosh::f); + add_op1(map, "sinh(a)", Sinh::f); + add_op1(map, "tanh(a)", Tanh::f); + add_op1(map, "acos(a)", Acos::f); + add_op1(map, "asin(a)", Asin::f); + add_op1(map, "atan(a)", Atan::f); + add_op1(map, "exp(a)", Exp::f); + add_op1(map, "log10(a)", Log10::f); + add_op1(map, "log(a)", Log::f); + add_op1(map, "sqrt(a)", Sqrt::f); + add_op1(map, "ceil(a)", Ceil::f); + add_op1(map, "fabs(a)", Fabs::f); + add_op1(map, "floor(a)", Floor::f); + add_op1(map, "isNan(a)", IsNan::f); + add_op1(map, "relu(a)", Relu::f); + add_op1(map, "sigmoid(a)", Sigmoid::f); + add_op1(map, "elu(a)", Elu::f); + return map; +} + +std::map<vespalib::string,op2_t> make_op2_map() { + std::map<vespalib::string,op2_t> map; + add_op2(map, "a+b", Add::f); + add_op2(map, "a-b", Sub::f); + add_op2(map, "a*b", Mul::f); + add_op2(map, "a/b", Div::f); + add_op2(map, "a%b", Mod::f); + add_op2(map, "a^b", Pow::f); + add_op2(map, "a==b", Equal::f); + add_op2(map, "a!=b", NotEqual::f); + add_op2(map, "a~=b", Approx::f); + add_op2(map, "a<b", Less::f); + add_op2(map, "a<=b", LessEqual::f); + add_op2(map, "a>b", Greater::f); + add_op2(map, "a>=b", GreaterEqual::f); + add_op2(map, "a&&b", And::f); + add_op2(map, "a||b", Or::f); + add_op2(map, "atan2(a,b)", Atan2::f); + add_op2(map, "ldexp(a,b)", Ldexp::f); + add_op2(map, "pow(a,b)", Pow::f); + add_op2(map, "fmod(a,b)", Mod::f); + add_op2(map, "min(a,b)", Min::f); + add_op2(map, "max(a,b)", Max::f); + return map; +} + +} // namespace <unnamed> + +std::optional<op1_t> lookup_op1(const Function &fun) { + static const std::map<vespalib::string,op1_t> map = make_op1_map(); + return lookup_op(map, fun); +} + +std::optional<op2_t> lookup_op2(const Function &fun) { + static const std::map<vespalib::string,op2_t> map = make_op2_map(); + return lookup_op(map, fun); +} + } diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h index fa99f51a308..a80193e704d 100644 --- a/eval/src/vespa/eval/eval/operation.h +++ b/eval/src/vespa/eval/eval/operation.h @@ -1,6 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once +#include <optional> + +namespace vespalib::eval { class Function; } namespace vespalib::eval::operation { @@ -46,4 +49,10 @@ struct Relu { static double f(double a); }; struct Sigmoid { static double f(double a); }; struct Elu { static double f(double a); }; +using op1_t = double (*)(double); +using op2_t = double (*)(double, double); + +std::optional<op1_t> lookup_op1(const Function &fun); +std::optional<op2_t> lookup_op2(const Function &fun); + } diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index 3e91240048b..a8ae9c44bb0 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -2,6 +2,7 @@ #pragma once +#include <vespa/vespalib/util/typify.h> #include <vespa/vespalib/stllike/string.h> #include <vector> @@ -104,4 +105,15 @@ template <typename CT> inline ValueType::CellType get_cell_type(); template <> inline ValueType::CellType get_cell_type<double>() { return ValueType::CellType::DOUBLE; } template <> inline ValueType::CellType get_cell_type<float>() { return ValueType::CellType::FLOAT; } +struct TypifyCellType { + template <typename T> using Result = TypifyResultType<T>; + template <typename F> static decltype(auto) resolve(ValueType::CellType value, F &&f) { + switch(value) { + case ValueType::CellType::DOUBLE: return f(Result<double>()); + case ValueType::CellType::FLOAT: return f(Result<float>()); + } + abort(); + } +}; + } // namespace diff --git a/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp index 6b0d65c0743..c358c9d618d 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_simple_join_function.cpp @@ -5,6 +5,8 @@ #include <vespa/vespalib/objects/objectvisitor.h> #include <vespa/eval/eval/value.h> #include <vespa/eval/eval/operation.h> +#include <vespa/eval/eval/inline_operation.h> +#include <vespa/vespalib/util/typify.h> #include <optional> #include <algorithm> @@ -16,6 +18,7 @@ using eval::Value; using eval::ValueType; using eval::TensorFunction; using eval::TensorEngine; +using eval::TypifyCellType; using eval::as; using namespace eval::operation; @@ -30,6 +33,18 @@ using State = eval::InterpretedFunction::State; namespace { +struct TypifyOverlap { + template <Overlap VALUE> using Result = TypifyResultValue<Overlap, VALUE>; + template <typename F> static decltype(auto) resolve(Overlap value, F &&f) { + switch (value) { + case Overlap::INNER: return f(Result<Overlap::INNER>()); + case Overlap::OUTER: return f(Result<Overlap::OUTER>()); + case Overlap::FULL: return f(Result<Overlap::FULL>()); + } + abort(); + } +}; + struct JoinParams { const ValueType &result_type; size_t factor; @@ -38,44 +53,17 @@ struct JoinParams { : result_type(result_type_in), factor(factor_in), function(function_in) {} }; -struct CallFun { - join_fun_t function; - CallFun(const JoinParams ¶ms) : function(params.function) {} - double eval(double a, double b) const { return function(a, b); } -}; - -struct AddFun { - AddFun(const JoinParams &) {} - template <typename A, typename B> - auto eval(A a, B b) const { return (a + b); } -}; - -struct MulFun { - MulFun(const JoinParams &) {} - template <typename A, typename B> - auto eval(A a, B b) const { return (a * b); } -}; - -// needed for asymmetric operations like Sub and Div -template <typename Fun> -struct SwapFun { - Fun fun; - SwapFun(const JoinParams ¶ms) : fun(params) {} - template <typename A, typename B> - auto eval(A a, B b) const { return fun.eval(b, a); } -}; - template <typename OCT, typename PCT, typename SCT, typename Fun> void apply_fun_1_to_n(OCT *dst, const PCT *pri, SCT sec, size_t n, const Fun &fun) { for (size_t i = 0; i < n; ++i) { - dst[i] = fun.eval(pri[i], sec); + dst[i] = fun(pri[i], sec); } } template <typename OCT, typename PCT, typename SCT, typename Fun> void apply_fun_n_to_n(OCT *dst, const PCT *pri, const SCT *sec, size_t n, const Fun &fun) { for (size_t i = 0; i < n; ++i) { - dst[i] = fun.eval(pri[i], sec[i]); + dst[i] = fun(pri[i], sec[i]); } } @@ -93,9 +81,9 @@ void my_simple_join_op(State &state, uint64_t param) { using PCT = typename std::conditional<swap,RCT,LCT>::type; using SCT = typename std::conditional<swap,LCT,RCT>::type; using OCT = typename eval::UnifyCellTypes<PCT,SCT>::type; - using OP = typename std::conditional<swap,SwapFun<Fun>,Fun>::type; + using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type; const JoinParams ¶ms = *(JoinParams*)param; - OP my_op(params); + OP my_op(params.function); auto pri_cells = DenseTensorView::typify_cells<PCT>(state.peek(swap ? 0 : 1)); auto sec_cells = DenseTensorView::typify_cells<SCT>(state.peek(swap ? 1 : 0)); auto dst_cells = make_dst_cells<OCT, pri_mut>(pri_cells, state.stash); @@ -122,67 +110,13 @@ void my_simple_join_op(State &state, uint64_t param) { //----------------------------------------------------------------------------- -template <typename Fun, bool swap, Overlap overlap, bool pri_mut> -struct MySimpleJoinOp { - template <typename LCT, typename RCT> - static auto get_fun() { return my_simple_join_op<LCT,RCT,Fun,swap,overlap,pri_mut>; } -}; - -template <bool swap, Overlap overlap, bool pri_mut> -op_function my_select_4(ValueType::CellType lct, - ValueType::CellType rct, - join_fun_t fun_hint) -{ - if (fun_hint == Add::f) { - return select_2<MySimpleJoinOp<AddFun,swap,overlap,pri_mut>>(lct, rct); - } else if (fun_hint == Mul::f) { - return select_2<MySimpleJoinOp<MulFun,swap,overlap,pri_mut>>(lct, rct); - } else { - return select_2<MySimpleJoinOp<CallFun,swap,overlap,pri_mut>>(lct, rct); - } -} - -template <bool swap, Overlap overlap> -op_function my_select_3(ValueType::CellType lct, - ValueType::CellType rct, - bool pri_mut, - join_fun_t fun_hint) -{ - if (pri_mut) { - return my_select_4<swap, overlap, true>(lct, rct, fun_hint); - } else { - return my_select_4<swap, overlap, false>(lct, rct, fun_hint); +struct MyGetFun { + template <typename R1, typename R2, typename R3, typename R4, typename R5, typename R6> static auto invoke() { + return my_simple_join_op<R1, R2, R3, R4::value, R5::value, R6::value>; } -} - -template <bool swap> -op_function my_select_2(ValueType::CellType lct, - ValueType::CellType rct, - Overlap overlap, - bool pri_mut, - join_fun_t fun_hint) -{ - switch (overlap) { - case Overlap::INNER: return my_select_3<swap, Overlap::INNER>(lct, rct, pri_mut, fun_hint); - case Overlap::OUTER: return my_select_3<swap, Overlap::OUTER>(lct, rct, pri_mut, fun_hint); - case Overlap::FULL: return my_select_3<swap, Overlap::FULL>(lct, rct, pri_mut, fun_hint); - } - abort(); -} +}; -op_function my_select(ValueType::CellType lct, - ValueType::CellType rct, - Primary primary, - Overlap overlap, - bool pri_mut, - join_fun_t fun_hint) -{ - switch (primary) { - case Primary::LHS: return my_select_2<false>(lct, rct, overlap, pri_mut, fun_hint); - case Primary::RHS: return my_select_2<true>(lct, rct, overlap, pri_mut, fun_hint); - } - abort(); -} +using MyTypify = TypifyValue<TypifyCellType,TypifyOp2,TypifyBool,TypifyOverlap>; //----------------------------------------------------------------------------- @@ -280,11 +214,10 @@ Instruction DenseSimpleJoinFunction::compile_self(const TensorEngine &, Stash &stash) const { const JoinParams ¶ms = stash.create<JoinParams>(result_type(), factor(), function()); - auto op = my_select(lhs().result_type().cell_type(), - rhs().result_type().cell_type(), - _primary, _overlap, - primary_is_mutable(), - function()); + auto op = typify_invoke<6,MyTypify,MyGetFun>(lhs().result_type().cell_type(), + rhs().result_type().cell_type(), + function(), (_primary == Primary::RHS), + _overlap, primary_is_mutable()); static_assert(sizeof(uint64_t) == sizeof(¶ms)); return Instruction(op, (uint64_t)(¶ms)); } diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 74522f4c517..30383393e97 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -287,13 +287,6 @@ public class Flags { CONSOLE_USER_EMAIL ); - public static final UnboundBooleanFlag CONFIGSERVER_PROVISION_LB = defineFeatureFlag( - "configserver-provision-lb", false, - "Provision load balancer for config server cluster", - "Takes effect when zone-config-servers application is redeployed", - ZONE_ID - ); - /** WARNING: public for testing: All flags should be defined in {@link Flags}. */ public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, String description, String modificationEffect, FetchVector.Dimension... dimensions) { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerService.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerService.java index edf2932ad6e..09723d83e3e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerService.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerService.java @@ -3,7 +3,9 @@ package com.yahoo.vespa.hosted.provision.lb; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterSpec; -import com.yahoo.config.provision.NodeType; +import com.yahoo.vespa.hosted.provision.NodeRepository; + +import java.util.Set; /** * A managed load balance service. @@ -13,14 +15,17 @@ import com.yahoo.config.provision.NodeType; public interface LoadBalancerService { /** - * Create a load balancer from the given specification. Implementations are expected to be idempotent + * Create a load balancer for given application cluster. Implementations are expected to be idempotent * - * @param spec Load balancer specification + * @param application Application owning the LB + * @param cluster Target cluster of the LB + * @param reals Reals that should be configured on the LB * @param force Whether reconfiguration should be forced (e.g. allow configuring an empty set of reals on a * pre-existing load balancer). * @return The provisioned load balancer instance */ - LoadBalancerInstance create(LoadBalancerSpec spec, boolean force); + LoadBalancerInstance create(ApplicationId application, ClusterSpec.Id cluster, Set<Real> reals, boolean force, + NodeRepository nodeRepository); /** Permanently remove load balancer for given application cluster */ void remove(ApplicationId application, ClusterSpec.Id cluster); @@ -28,12 +33,6 @@ public interface LoadBalancerService { /** Returns the protocol supported by this load balancer service */ Protocol protocol(); - /** Returns whether load balancers created by this service can forward traffic to given node and cluster type */ - default boolean canForwardTo(NodeType nodeType, ClusterSpec.Type clusterType) { - return (nodeType == NodeType.tenant && clusterType.isContainer()) || - (nodeType == NodeType.config && clusterType == ClusterSpec.Type.admin); - } - /** Load balancer protocols */ enum Protocol { ipv4, diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java index f4d689056c3..9bd1189420a 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java @@ -5,11 +5,13 @@ import com.google.common.collect.ImmutableSet; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.HostName; +import com.yahoo.vespa.hosted.provision.NodeRepository; import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.Set; /** * @author mpolden @@ -28,18 +30,19 @@ public class LoadBalancerServiceMock implements LoadBalancerService { } @Override - public LoadBalancerInstance create(LoadBalancerSpec spec, boolean force) { - var id = new LoadBalancerId(spec.application(), spec.cluster()); + public LoadBalancerInstance create(ApplicationId application, ClusterSpec.Id cluster, Set<Real> reals, boolean force, + NodeRepository nodeRepository) { + var id = new LoadBalancerId(application, cluster); var oldInstance = instances.get(id); - if (!force && oldInstance != null && !oldInstance.reals().isEmpty() && spec.reals().isEmpty()) { + if (!force && oldInstance != null && !oldInstance.reals().isEmpty() && reals.isEmpty()) { throw new IllegalArgumentException("Refusing to remove all reals from load balancer " + id); } var instance = new LoadBalancerInstance( - HostName.from("lb-" + spec.application().toShortString() + "-" + spec.cluster().value()), + HostName.from("lb-" + application.toShortString() + "-" + cluster.value()), Optional.of(new DnsZone("zone-id-1")), Collections.singleton(4443), ImmutableSet.of("10.2.3.0/24", "10.4.5.0/24"), - spec.reals()); + reals); instances.put(id, instance); return instance; } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerSpec.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerSpec.java deleted file mode 100644 index 198f7f347ef..00000000000 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerSpec.java +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.provision.lb; - -import com.yahoo.config.provision.ApplicationId; -import com.yahoo.config.provision.ClusterSpec; - -import java.util.Objects; -import java.util.Set; - -/** - * A specification for a load balancer. - * - * @author mpolden - */ -public class LoadBalancerSpec { - - private final ApplicationId application; - private final ClusterSpec.Id cluster; - private final Set<Real> reals; - - public LoadBalancerSpec(ApplicationId application, ClusterSpec.Id cluster, Set<Real> reals) { - this.application = Objects.requireNonNull(application); - this.cluster = Objects.requireNonNull(cluster); - this.reals = Set.copyOf(Objects.requireNonNull(reals)); - } - - /** Owner of the load balancer */ - public ApplicationId application() { - return application; - } - - /** The target cluster of this load balancer */ - public ClusterSpec.Id cluster() { - return cluster; - } - - /** Real servers to attach to this load balancer */ - public Set<Real> reals() { - return reals; - } - -} diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/PassthroughLoadBalancerService.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/PassthroughLoadBalancerService.java index 7667672e470..07074bc45af 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/PassthroughLoadBalancerService.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/PassthroughLoadBalancerService.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.provision.lb; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.vespa.hosted.provision.NodeRepository; import java.util.Comparator; import java.util.Optional; @@ -17,10 +18,11 @@ import java.util.Set; public class PassthroughLoadBalancerService implements LoadBalancerService { @Override - public LoadBalancerInstance create(LoadBalancerSpec spec, boolean force) { - var real = spec.reals().stream() - .min(Comparator.naturalOrder()) - .orElseThrow(() -> new IllegalArgumentException("No reals given")); + public LoadBalancerInstance create(ApplicationId application, ClusterSpec.Id cluster, Set<Real> reals, boolean force, + NodeRepository nodeRepository) { + var real = reals.stream() + .min(Comparator.naturalOrder()) + .orElseThrow(() -> new IllegalArgumentException("No reals given")); return new LoadBalancerInstance(real.hostname(), Optional.empty(), Set.of(real.port()), Set.of(real.ipAddress() + "/32"), Set.of()); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java index bc4381573c6..a8faafc0bad 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java @@ -1,6 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision.lb; +import com.google.inject.Inject; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.HostName; @@ -26,14 +27,13 @@ public class SharedLoadBalancerService implements LoadBalancerService { private static final Comparator<Node> hostnameComparator = Comparator.comparing(Node::hostname); - private final NodeRepository nodeRepository; - - public SharedLoadBalancerService(NodeRepository nodeRepository) { - this.nodeRepository = Objects.requireNonNull(nodeRepository); + @Inject + public SharedLoadBalancerService() { } @Override - public LoadBalancerInstance create(LoadBalancerSpec spec, boolean force) { + public LoadBalancerInstance create(ApplicationId application, ClusterSpec.Id cluster, Set<Real> reals, boolean force, + NodeRepository nodeRepository) { var proxyNodes = new ArrayList<>(nodeRepository.getNodes(NodeType.proxy)); proxyNodes.sort(hostnameComparator); @@ -52,7 +52,7 @@ public class SharedLoadBalancerService implements LoadBalancerService { Optional.empty(), Set.of(4080, 4443), networkNames, - spec.reals() + reals ); } @@ -66,12 +66,6 @@ public class SharedLoadBalancerService implements LoadBalancerService { return Protocol.dualstack; } - @Override - public boolean canForwardTo(NodeType nodeType, ClusterSpec.Type clusterType) { - // Shared routing layer only supports routing to tenant nodes - return nodeType == NodeType.tenant && clusterType.isContainer(); - } - private static String withPrefixLength(String address) { if (IP.isV6(address)) { return address + "/128"; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/LoadBalancerExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/LoadBalancerExpirer.java index 6edd57de1c1..483b4dc8f84 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/LoadBalancerExpirer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/LoadBalancerExpirer.java @@ -1,13 +1,13 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision.maintenance; +import java.util.logging.Level; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.lb.LoadBalancer; import com.yahoo.vespa.hosted.provision.lb.LoadBalancer.State; import com.yahoo.vespa.hosted.provision.lb.LoadBalancerId; import com.yahoo.vespa.hosted.provision.lb.LoadBalancerService; -import com.yahoo.vespa.hosted.provision.lb.LoadBalancerSpec; import com.yahoo.vespa.hosted.provision.persistence.CuratorDatabaseClient; import java.time.Duration; @@ -17,7 +17,6 @@ import java.util.List; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -import java.util.logging.Level; import java.util.stream.Collectors; /** @@ -100,7 +99,7 @@ public class LoadBalancerExpirer extends NodeRepositoryMaintainer { // Remove any real no longer allocated to this application reals.removeIf(real -> !allocatedNodes.contains(real.hostname().value())); try { - service.create(new LoadBalancerSpec(lb.id().application(), lb.id().cluster(), reals), true); + service.create(lb.id().application(), lb.id().cluster(), reals, true, nodeRepository()); db.writeLoadBalancer(lb.with(lb.instance().withReals(reals))); } catch (Exception e) { failed.add(lb.id()); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java index a61032af276..ebe9327967e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java @@ -111,10 +111,10 @@ class Activator { /** Activate load balancers */ private void activateLoadBalancers(ApplicationId application, Collection<HostSpec> hosts, NestedTransaction transaction, @SuppressWarnings("unused") Mutex applicationLock) { - loadBalancerProvisioner.ifPresent(provisioner -> provisioner.activate(application, allClustersOf(hosts), applicationLock, transaction)); + loadBalancerProvisioner.ifPresent(provisioner -> provisioner.activate(application, clustersOf(hosts), applicationLock, transaction)); } - private static Set<ClusterSpec> allClustersOf(Collection<HostSpec> hosts) { + private static Set<ClusterSpec> clustersOf(Collection<HostSpec> hosts) { return hosts.stream() .map(HostSpec::membership) .flatMap(Optional::stream) diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java index 460e1e71e65..c6945e1779b 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java @@ -8,9 +8,6 @@ import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.exception.LoadBalancerServiceException; import com.yahoo.transaction.Mutex; import com.yahoo.transaction.NestedTransaction; -import com.yahoo.vespa.flags.BooleanFlag; -import com.yahoo.vespa.flags.FlagSource; -import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; @@ -18,7 +15,6 @@ import com.yahoo.vespa.hosted.provision.lb.LoadBalancer; import com.yahoo.vespa.hosted.provision.lb.LoadBalancerId; import com.yahoo.vespa.hosted.provision.lb.LoadBalancerInstance; import com.yahoo.vespa.hosted.provision.lb.LoadBalancerService; -import com.yahoo.vespa.hosted.provision.lb.LoadBalancerSpec; import com.yahoo.vespa.hosted.provision.lb.Real; import com.yahoo.vespa.hosted.provision.node.IP; import com.yahoo.vespa.hosted.provision.persistence.CuratorDatabaseClient; @@ -27,7 +23,6 @@ import java.util.ArrayList; import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.function.Function; import java.util.logging.Level; @@ -50,13 +45,11 @@ public class LoadBalancerProvisioner { private final NodeRepository nodeRepository; private final CuratorDatabaseClient db; private final LoadBalancerService service; - private final BooleanFlag provisionConfigServerLoadBalancer; - public LoadBalancerProvisioner(NodeRepository nodeRepository, LoadBalancerService service, FlagSource flagSource) { + public LoadBalancerProvisioner(NodeRepository nodeRepository, LoadBalancerService service) { this.nodeRepository = nodeRepository; this.db = nodeRepository.database(); this.service = service; - this.provisionConfigServerLoadBalancer = Flags.CONFIGSERVER_PROVISION_LB.bindTo(flagSource); // Read and write all load balancers to make sure they are stored in the latest version of the serialization format for (var id : db.readLoadBalancerIds()) { try (var lock = db.lock(id.application())) { @@ -77,12 +70,11 @@ public class LoadBalancerProvisioner { * Calling this for irrelevant node or cluster types is a no-op. */ public void prepare(ApplicationId application, ClusterSpec cluster, NodeSpec requestedNodes) { - if (!canForwardTo(requestedNodes.type(), cluster)) return; // Nothing to provision for this node and cluster type + if (requestedNodes.type() != NodeType.tenant) return; // Nothing to provision for this node type + if (!cluster.type().isContainer()) return; // Nothing to provision for this cluster type if (application.instance().isTester()) return; // Do not provision for tester instances try (var lock = db.lock(application)) { - ClusterSpec.Id clusterId = effectiveId(cluster); - List<Node> nodes = nodesOf(clusterId, application); - provision(application, clusterId, nodes,false, lock); + provision(application, effectiveId(cluster), false, lock); } } @@ -99,14 +91,13 @@ public class LoadBalancerProvisioner { public void activate(ApplicationId application, Set<ClusterSpec> clusters, @SuppressWarnings("unused") Mutex applicationLock, NestedTransaction transaction) { try (var lock = db.lock(application)) { - for (var cluster : loadBalancedClustersOf(application).entrySet()) { + var containerClusters = containerClustersOf(clusters); + for (var clusterId : containerClusters) { // Provision again to ensure that load balancer instance is re-configured with correct nodes - provision(application, cluster.getKey(), cluster.getValue(), true, lock); + provision(application, clusterId, true, lock); } // Deactivate any surplus load balancers, i.e. load balancers for clusters that have been removed - var surplusLoadBalancers = surplusLoadBalancersOf(application, clusters.stream() - .map(LoadBalancerProvisioner::effectiveId) - .collect(Collectors.toSet())); + var surplusLoadBalancers = surplusLoadBalancersOf(application, containerClusters); deactivate(surplusLoadBalancers, transaction); } } @@ -147,17 +138,9 @@ public class LoadBalancerProvisioner { db.writeLoadBalancers(deactivatedLoadBalancers, transaction); } - // TODO(mpolden): Inline when feature flag is removed - private boolean canForwardTo(NodeType type, ClusterSpec cluster) { - boolean canForwardTo = service.canForwardTo(type, cluster.type()); - if (canForwardTo && type == NodeType.config) { - return provisionConfigServerLoadBalancer.value(); - } - return canForwardTo; - } /** Idempotently provision a load balancer for given application and cluster */ - private void provision(ApplicationId application, ClusterSpec.Id clusterId, List<Node> nodes, boolean activate, + private void provision(ApplicationId application, ClusterSpec.Id clusterId, boolean activate, @SuppressWarnings("unused") Mutex loadBalancersLock) { var id = new LoadBalancerId(application, clusterId); var now = nodeRepository.clock().instant(); @@ -165,7 +148,7 @@ public class LoadBalancerProvisioner { if (loadBalancer.isEmpty() && activate) return; // Nothing to activate as this load balancer was never prepared var force = loadBalancer.isPresent() && loadBalancer.get().state() != LoadBalancer.State.active; - var instance = provisionInstance(application, clusterId, nodes, force); + var instance = create(application, clusterId, allocatedContainers(application, clusterId), force); LoadBalancer newLoadBalancer; if (loadBalancer.isEmpty()) { newLoadBalancer = new LoadBalancer(id, instance, LoadBalancer.State.reserved, now); @@ -176,8 +159,7 @@ public class LoadBalancerProvisioner { db.writeLoadBalancer(newLoadBalancer); } - private LoadBalancerInstance provisionInstance(ApplicationId application, ClusterSpec.Id cluster, List<Node> nodes, - boolean force) { + private LoadBalancerInstance create(ApplicationId application, ClusterSpec.Id cluster, List<Node> nodes, boolean force) { var reals = new LinkedHashSet<Real>(); for (var node : nodes) { for (var ip : reachableIpAddresses(node)) { @@ -187,7 +169,7 @@ public class LoadBalancerProvisioner { log.log(Level.FINE, "Creating load balancer for " + cluster + " in " + application.toShortString() + ", targeting: " + reals); try { - return service.create(new LoadBalancerSpec(application, cluster, reals), force); + return service.create(application, cluster, reals, force, nodeRepository); } catch (Exception e) { throw new LoadBalancerServiceException("Failed to (re)configure load balancer for " + cluster + " in " + application + ", targeting: " + reals + ". The operation will be " + @@ -195,21 +177,14 @@ public class LoadBalancerProvisioner { } } - /** Returns the nodes allocated to the given load balanced cluster */ - private List<Node> nodesOf(ClusterSpec.Id loadBalancedCluster, ApplicationId application) { - return loadBalancedClustersOf(application).getOrDefault(loadBalancedCluster, List.of()); - } - - /** Returns the load balanced clusters of given application and their nodes */ - private Map<ClusterSpec.Id, List<Node>> loadBalancedClustersOf(ApplicationId application) { - NodeList nodes = NodeList.copyOf(nodeRepository.getNodes(Node.State.reserved, Node.State.active)) - .owner(application); - if (nodes.stream().anyMatch(node -> node.type() == NodeType.config)) { - nodes = nodes.nodeType(NodeType.config).type(ClusterSpec.Type.admin); - } else { - nodes = nodes.nodeType(NodeType.tenant).container(); - } - return nodes.stream().collect(Collectors.groupingBy(node -> effectiveId(node.allocation().get().membership().cluster()))); + /** Returns a list of active and reserved nodes of type container in given cluster */ + private List<Node> allocatedContainers(ApplicationId application, ClusterSpec.Id clusterId) { + return NodeList.copyOf(nodeRepository.getNodes(NodeType.tenant, Node.State.reserved, Node.State.active)) + .owner(application) + .matching(node -> node.state().isAllocated()) + .container() + .matching(node -> effectiveId(node.allocation().get().membership().cluster()).equals(clusterId)) + .asList(); } /** Find IP addresses reachable by the load balancer service */ @@ -227,6 +202,14 @@ public class LoadBalancerProvisioner { return reachable; } + /** Returns the container cluster IDs of the given clusters */ + private static Set<ClusterSpec.Id> containerClustersOf(Set<ClusterSpec> clusters) { + return clusters.stream() + .filter(c -> c.type().isContainer()) + .map(LoadBalancerProvisioner::effectiveId) + .collect(Collectors.toUnmodifiableSet()); + } + private static ClusterSpec.Id effectiveId(ClusterSpec cluster) { return cluster.combinedId().orElse(cluster.id()); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java index fe46070da1d..59fca955a68 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeRepositoryProvisioner.java @@ -7,6 +7,7 @@ import com.yahoo.config.provision.Capacity; import com.yahoo.config.provision.ClusterResources; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.Environment; +import com.yahoo.config.provision.Flavor; import com.yahoo.config.provision.HostFilter; import com.yahoo.config.provision.HostSpec; import com.yahoo.config.provision.NodeResources; @@ -14,6 +15,7 @@ import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.ProvisionLogger; import com.yahoo.config.provision.Provisioner; import com.yahoo.config.provision.Zone; +import java.util.logging.Level; import com.yahoo.transaction.Mutex; import com.yahoo.transaction.NestedTransaction; import com.yahoo.vespa.flags.FlagSource; @@ -34,7 +36,6 @@ import java.util.Collection; import java.util.Comparator; import java.util.List; import java.util.Optional; -import java.util.logging.Level; import java.util.logging.Logger; /** @@ -65,7 +66,7 @@ public class NodeRepositoryProvisioner implements Provisioner { this.allocationOptimizer = new AllocationOptimizer(nodeRepository); this.capacityPolicies = new CapacityPolicies(nodeRepository); this.zone = zone; - this.loadBalancerProvisioner = provisionServiceProvider.getLoadBalancerService().map(lbService -> new LoadBalancerProvisioner(nodeRepository, lbService, flagSource)); + this.loadBalancerProvisioner = provisionServiceProvider.getLoadBalancerService().map(lbService -> new LoadBalancerProvisioner(nodeRepository, lbService)); this.nodeResourceLimits = new NodeResourceLimits(nodeRepository); this.preparer = new Preparer(nodeRepository, zone.environment() == Environment.prod ? SPARE_CAPACITY_PROD : SPARE_CAPACITY_NONPROD, diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/lb/PassthroughLoadBalancerServiceTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/lb/PassthroughLoadBalancerServiceTest.java index 997aec8a156..e70fc184b87 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/lb/PassthroughLoadBalancerServiceTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/lb/PassthroughLoadBalancerServiceTest.java @@ -20,8 +20,8 @@ public class PassthroughLoadBalancerServiceTest { var lbService = new PassthroughLoadBalancerService(); var real = new Real(HostName.from("host1.example.com"), "192.0.2.10"); var reals = Set.of(real, new Real(HostName.from("host2.example.com"), "192.0.2.11")); - var instance = lbService.create(new LoadBalancerSpec(ApplicationId.from("tenant1", "app1", "default"), - ClusterSpec.Id.from("c1"), reals), false); + var instance = lbService.create(ApplicationId.from("tenant1", "app1", "default"), + ClusterSpec.Id.from("c1"), reals, false, null); assertEquals(real.hostname(), instance.hostname()); assertEquals(Set.of(real.port()), instance.ports()); assertEquals(Set.of(real.ipAddress() + "/32"), instance.networks()); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerServiceTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerServiceTest.java index 06f18d94c5f..64d189b9111 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerServiceTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerServiceTest.java @@ -19,7 +19,7 @@ import static org.junit.Assert.assertEquals; public class SharedLoadBalancerServiceTest { private final ProvisioningTester tester = new ProvisioningTester.Builder().build(); - private final SharedLoadBalancerService loadBalancerService = new SharedLoadBalancerService(tester.nodeRepository()); + private final SharedLoadBalancerService loadBalancerService = new SharedLoadBalancerService(); private final ApplicationId applicationId = ApplicationId.from("tenant1", "application1", "default"); private final ClusterSpec.Id clusterId = ClusterSpec.Id.from("qrs1"); private final Set<Real> reals = Set.of( @@ -30,7 +30,7 @@ public class SharedLoadBalancerServiceTest { @Test public void test_create_lb() { tester.makeReadyNodes(2, "default", NodeType.proxy); - var lb = loadBalancerService.create(new LoadBalancerSpec(applicationId, clusterId, reals), false); + var lb = loadBalancerService.create(applicationId, clusterId, reals, false, tester.nodeRepository()); assertEquals(HostName.from("host-1.yahoo.com"), lb.hostname()); assertEquals(Optional.empty(), lb.dnsZone()); @@ -40,7 +40,7 @@ public class SharedLoadBalancerServiceTest { @Test(expected = IllegalStateException.class) public void test_exception_on_missing_proxies() { - loadBalancerService.create(new LoadBalancerSpec(applicationId, clusterId, reals), false); + loadBalancerService.create(applicationId, clusterId, reals, false, tester.nodeRepository()); } @Test diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisionerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisionerTest.java index f48127f650d..26039c29ae8 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisionerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisionerTest.java @@ -11,8 +11,6 @@ import com.yahoo.config.provision.HostSpec; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; import com.yahoo.transaction.NestedTransaction; -import com.yahoo.vespa.flags.Flags; -import com.yahoo.vespa.flags.InMemoryFlagSource; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.lb.LoadBalancer; import com.yahoo.vespa.hosted.provision.lb.LoadBalancerInstance; @@ -45,8 +43,7 @@ public class LoadBalancerProvisionerTest { private final ApplicationId app2 = ApplicationId.from("tenant2", "application2", "default"); private final ApplicationId infraApp1 = ApplicationId.from("vespa", "tenant-host", "default"); - private final InMemoryFlagSource flagSource = new InMemoryFlagSource(); - private final ProvisioningTester tester = new ProvisioningTester.Builder().flagSource(flagSource).build(); + private final ProvisioningTester tester = new ProvisioningTester.Builder().build(); @Test public void provision_load_balancer() { @@ -215,21 +212,6 @@ public class LoadBalancerProvisionerTest { assertEquals(combinedId, lbs.get().get(0).id().cluster()); } - @Test - public void provision_load_balancer_config_server_cluster() { - flagSource.withBooleanFlag(Flags.CONFIGSERVER_PROVISION_LB.id(), true); - ApplicationId configServerApp = ApplicationId.from("hosted-vespa", "zone-config-servers", "default"); - Supplier<List<LoadBalancer>> lbs = () -> tester.nodeRepository().loadBalancers(configServerApp).asList(); - var cluster = ClusterSpec.Id.from("zone-config-servers"); - var nodes = prepare(configServerApp, Capacity.fromRequiredNodeType(NodeType.config), false, - clusterRequest(ClusterSpec.Type.admin, cluster)); - assertEquals(1, lbs.get().size()); - assertEquals("Prepare provisions load balancer with reserved nodes", 2, lbs.get().get(0).instance().reals().size()); - tester.activate(configServerApp, nodes); - assertSame(LoadBalancer.State.active, lbs.get().get(0).state()); - assertEquals(cluster, lbs.get().get(0).id().cluster()); - } - private void dirtyNodesOf(ApplicationId application) { tester.nodeRepository().setDirty(tester.nodeRepository().getNodes(application), Agent.system, this.getClass().getSimpleName()); } diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index 2675bc16bf2..1ca9816a921 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -24,8 +24,8 @@ vespa_define_module( src/tests/assert src/tests/barrier src/tests/benchmark_timer - src/tests/btree src/tests/box + src/tests/btree src/tests/closure src/tests/component src/tests/compress @@ -128,6 +128,7 @@ vespa_define_module( src/tests/tutorial/minimal src/tests/tutorial/simple src/tests/tutorial/threads + src/tests/typify src/tests/util/generationhandler src/tests/util/generationhandler_stress src/tests/util/md5 diff --git a/vespalib/src/tests/typify/CMakeLists.txt b/vespalib/src/tests/typify/CMakeLists.txt new file mode 100644 index 00000000000..29e95af1988 --- /dev/null +++ b/vespalib/src/tests/typify/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_typify_test_app TEST + SOURCES + typify_test.cpp + DEPENDS + vespalib + gtest +) +vespa_add_test(NAME vespalib_typify_test_app COMMAND vespalib_typify_test_app) diff --git a/vespalib/src/tests/typify/typify_test.cpp b/vespalib/src/tests/typify/typify_test.cpp new file mode 100644 index 00000000000..4c3f1c512ca --- /dev/null +++ b/vespalib/src/tests/typify/typify_test.cpp @@ -0,0 +1,124 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/util/typify.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib; + +struct A { static constexpr int value_from_type = 1; }; +struct B { static constexpr int value_from_type = 2; }; + +struct MyIntA { int value; }; +struct MyIntB { int value; }; +struct MyIntC { int value; }; // no typifier for this type + +// MyIntA -> A or B +struct TypifyMyIntA { + template <typename T> using Result = TypifyResultType<T>; + template <typename F> static decltype(auto) resolve(MyIntA value, F &&f) { + if (value.value == 1) { + return f(Result<A>()); + } else if (value.value == 2) { + return f(Result<B>()); + } + abort(); + } +}; + +// MyIntB -> TypifyResultValue<int,1> or TypifyResultValue<int,2> +struct TypifyMyIntB { + template <int VALUE> using Result = TypifyResultValue<int,VALUE>; + template <typename F> static decltype(auto) resolve(MyIntB value, F &&f) { + if (value.value == 1) { + return f(Result<1>()); + } else if (value.value == 2) { + return f(Result<2>()); + } + abort(); + } +}; + +using TX = TypifyValue<TypifyBool, TypifyMyIntA, TypifyMyIntB>; + +//----------------------------------------------------------------------------- + +struct GetFromType { + template <typename T> static int invoke() { return T::value_from_type; } +}; + +TEST(TypifyTest, simple_type_typification_works) { + auto res1 = typify_invoke<1,TX,GetFromType>(MyIntA{1}); + auto res2 = typify_invoke<1,TX,GetFromType>(MyIntA{2}); + EXPECT_EQ(res1, 1); + EXPECT_EQ(res2, 2); +} + +struct GetFromValue { + template <typename R> static int invoke() { return R::value; } +}; + +TEST(TypifyTest, simple_value_typification_works) { + auto res1 = typify_invoke<1,TX,GetFromValue>(MyIntB{1}); + auto res2 = typify_invoke<1,TX,GetFromValue>(MyIntB{2}); + EXPECT_EQ(res1, 1); + EXPECT_EQ(res2, 2); +} + +struct MaybeSum { + template <typename F1, typename V1, typename F2, typename V2> static int invoke(MyIntC v3) { + int res = 0; + if (F1::value) { + res += V1::value_from_type; + } + if (F2::value) { + res += V2::value; + } + res += v3.value; + return res; + } +}; + +TEST(TypifyTest, complex_typification_works) { + auto res1 = typify_invoke<4,TX,MaybeSum>(false, MyIntA{2}, false, MyIntB{1}, MyIntC{4}); + auto res2 = typify_invoke<4,TX,MaybeSum>(false, MyIntA{2}, true, MyIntB{1}, MyIntC{4}); + auto res3 = typify_invoke<4,TX,MaybeSum>(true, MyIntA{2}, false, MyIntB{1}, MyIntC{4}); + auto res4 = typify_invoke<4,TX,MaybeSum>(true, MyIntA{2}, true, MyIntB{1}, MyIntC{4}); + EXPECT_EQ(res1, 4); + EXPECT_EQ(res2, 5); + EXPECT_EQ(res3, 6); + EXPECT_EQ(res4, 7); +} + +struct Singleton { + virtual int get() const = 0; + virtual ~Singleton() {} +}; + +template <int A, int B> +struct MySingleton : Singleton { + MySingleton() = default; + MySingleton(const MySingleton &) = delete; + MySingleton &operator=(const MySingleton &) = delete; + int get() const override { return A + B; } +}; + +struct GetSingleton { + template <typename A, typename B> + static const Singleton &invoke() { + static MySingleton<A::value, B::value> obj; + return obj; + } +}; + +TEST(TypifyTest, typify_invoke_can_return_object_reference) { + const Singleton &s1 = typify_invoke<2,TX,GetSingleton>(MyIntB{1}, MyIntB{1}); + const Singleton &s2 = typify_invoke<2,TX,GetSingleton>(MyIntB{2}, MyIntB{2}); + const Singleton &s3 = typify_invoke<2,TX,GetSingleton>(MyIntB{2}, MyIntB{2}); + EXPECT_EQ(s1.get(), 2); + EXPECT_EQ(s2.get(), 4); + EXPECT_EQ(s3.get(), 4); + EXPECT_NE(&s1, &s2); + EXPECT_EQ(&s2, &s3); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/vespa/vespalib/util/typify.h b/vespalib/src/vespa/vespalib/util/typify.h new file mode 100644 index 00000000000..a2a24baca41 --- /dev/null +++ b/vespalib/src/vespa/vespalib/util/typify.h @@ -0,0 +1,96 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <stddef.h> +#include <utility> + +namespace vespalib { + +//----------------------------------------------------------------------------- + +/** + * Typification result for values resolving into actual types. + **/ +template <typename T> struct TypifyResultType { + static constexpr bool is_type = true; + using type = T; +}; + +/** + * Typification result for values resolving into compile-time values + * which are also types as long as they are kept inside their result + * wrappers. + **/ +template <typename T, T VALUE> struct TypifyResultValue { + static constexpr bool is_type = false; + static constexpr T value = VALUE; +}; + +/** + * A Typifier is able to take a run-time value and resolve it into a + * type. The resolve result is passed to the specified function in the + * form of a thin result wrapper. + **/ +struct TypifyBool { + template <bool VALUE> using Result = TypifyResultValue<bool, VALUE>; + template <typename F> static decltype(auto) resolve(bool value, F &&f) { + if (value) { + return f(Result<true>()); + } else { + return f(Result<false>()); + } + } +}; + +//----------------------------------------------------------------------------- + +/** + * Template used to combine individual typifiers into a typifier able + * to resolve multiple types. + **/ +template <typename ...Ts> struct TypifyValue : Ts... { using Ts::resolve...; }; + +//----------------------------------------------------------------------------- + +template <size_t N, typename Typifier, typename Target, typename ...Rs> struct TypifyInvokeImpl { + static decltype(auto) select() { + static_assert(sizeof...(Rs) == N); + return Target::template invoke<Rs...>(); + } + template <typename T, typename ...Args> static decltype(auto) select(T &&value, Args &&...args) { + if constexpr (N == sizeof...(Rs)) { + return Target::template invoke<Rs...>(std::forward<T>(value), std::forward<Args>(args)...); + } else { + return Typifier::resolve(value, [&](auto t)->decltype(auto) + { + using X = decltype(t); + if constexpr (X::is_type) { + return TypifyInvokeImpl<N, Typifier, Target, Rs..., typename X::type>::select(std::forward<Args>(args)...); + } else { + return TypifyInvokeImpl<N, Typifier, Target, Rs..., X>::select(std::forward<Args>(args)...); + } + }); + } + } +}; + +/** + * Typify the N first parameters using 'Typifier' (typically an + * instantiation of the TypifyValue template) and forward the + * remaining parameters to the Target::invoke template function with + * the typification results from the N first parameters as template + * parameters. Note that typification results that are types are + * unwrapped before being used as template parameters while + * typification results that are compile-time values are kept in their + * wrappers when passed as template parameters. Please refer to the + * unit test for examples. + **/ +template <size_t N, typename Typifier, typename Target, typename ...Args> decltype(auto) typify_invoke(Args && ...args) { + static_assert(N > 0); + return TypifyInvokeImpl<N,Typifier,Target>::select(std::forward<Args>(args)...); +} + +//----------------------------------------------------------------------------- + +} |