diff options
author | Lester Solbakken <lester.solbakken@gmail.com> | 2024-05-02 14:45:56 +0200 |
---|---|---|
committer | Lester Solbakken <lester.solbakken@gmail.com> | 2024-05-02 14:45:56 +0200 |
commit | f3a6f2dedc0f6c43b7a51d348b9096a676b168c5 (patch) | |
tree | bcab6d4b746dffb7edee3bd86a200cf1688d5695 /config-model | |
parent | 2a383e9597d400c3e57105434d17f3b0ed434398 (diff) |
Restart on deploy for local LLMs
Diffstat (limited to 'config-model')
3 files changed, 109 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java index ed0804f7420..7f624032627 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java @@ -19,6 +19,7 @@ import com.yahoo.vespa.model.application.validation.change.IndexingModeChangeVal import com.yahoo.vespa.model.application.validation.change.NodeResourceChangeValidator; import com.yahoo.vespa.model.application.validation.change.RedundancyIncreaseValidator; import com.yahoo.vespa.model.application.validation.change.ResourcesReductionValidator; +import com.yahoo.vespa.model.application.validation.change.RestartOnDeployForLocalLLMValidator; import com.yahoo.vespa.model.application.validation.change.RestartOnDeployForOnnxModelChangesValidator; import com.yahoo.vespa.model.application.validation.change.StartupCommandChangeValidator; import com.yahoo.vespa.model.application.validation.change.StreamingSearchClusterChangeValidator; @@ -129,6 +130,7 @@ public class Validation { new CertificateRemovalChangeValidator().validate(execution); new RedundancyValidator().validate(execution); new RestartOnDeployForOnnxModelChangesValidator().validate(execution); + new RestartOnDeployForLocalLLMValidator().validate(execution); } public interface Context { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java new file mode 100644 index 00000000000..88cfcfaf67c --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java @@ -0,0 +1,38 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.application.validation.change; + +import com.yahoo.vespa.model.application.validation.Validation.ChangeContext; + +import java.util.logging.Logger; + +import static java.util.logging.Level.INFO; + +/** + * If using local LLMs, this validator will make sure that restartOnDeploy is set for + * configs for this cluster. + * + * @author lesters + */ +public class RestartOnDeployForLocalLLMValidator implements ChangeValidator { + + private static final Logger log = Logger.getLogger(RestartOnDeployForLocalLLMValidator.class.getName()); + + @Override + public void validate(ChangeContext context) { + + for (var cluster : context.model().getContainerClusters().values()) { + + // For now, if a local LLM is used, force a restart of the services + // Later, be more sophisticated and only restart if redeploy does not fit in (GPU) memory + cluster.getAllComponents().forEach(component -> { + if (component.getClassId().getName().equals("ai.vespa.llm.clients.LocalLLM")) { + String message = "Restarting services in %s because of local LLM use".formatted(cluster); + log.log(INFO, message); + context.require(new VespaRestartAction(cluster.id(), message)); + } + }); + + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java new file mode 100644 index 00000000000..30915ad02fc --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java @@ -0,0 +1,69 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.application.validation.change; + +import com.yahoo.config.model.api.ConfigChangeAction; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.application.validation.ValidationTester; +import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithMockPkg; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * @author lesters + */ +public class RestartOnDeployForLocalLLMValidatorTest { + + @Test + void validate_no_restart_on_deploy() { + VespaModel current = createModelWithComponent("ai.vespa.llm.clients.OpenAI"); + VespaModel next = createModelWithComponent("ai.vespa.llm.clients.OpenAI"); + List<ConfigChangeAction> result = validateModel(current, next); + assertEquals(0, result.size()); + } + + @Test + void validate_restart_on_deploy() { + VespaModel current = createModelWithComponent("ai.vespa.llm.clients.LocalLLM"); + VespaModel next = createModelWithComponent("ai.vespa.llm.clients.LocalLLM"); + List<ConfigChangeAction> result = validateModel(current, next); + assertEquals(1, result.size()); + assertTrue(result.get(0).validationId().isEmpty()); + assertEquals("Restarting services in container cluster 'cluster1' because of local LLM use", result.get(0).getMessage()); + } + + private static List<ConfigChangeAction> validateModel(VespaModel current, VespaModel next) { + return ValidationTester.validateChanges(new RestartOnDeployForLocalLLMValidator(), + next, + deployStateBuilder().previousModel(current).build()); + } + + private static VespaModel createModelWithComponent(String componentClass) { + var xml = """ + <services version='1.0'> + <container id='cluster1' version='1.0'> + <http> + <server id='server1' port='8080'/> + </http> + <component id="llm" class="%s" /> + </container> + </services> + """.formatted(componentClass); + DeployState.Builder builder = deployStateBuilder(); + return new VespaModelCreatorWithMockPkg(null, xml).create(builder); + } + + private static DeployState.Builder deployStateBuilder() { + return new DeployState.Builder().properties(new TestProperties()); + } + + private static void assertStartsWith(String expected, List<ConfigChangeAction> result) { + assertTrue(result.get(0).getMessage().startsWith(expected)); + } + +} |