From f3a6f2dedc0f6c43b7a51d348b9096a676b168c5 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 2 May 2024 14:45:56 +0200 Subject: Restart on deploy for local LLMs --- .../RestartOnDeployForLocalLLMValidatorTest.java | 69 ++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java (limited to 'config-model/src/test/java') diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java new file mode 100644 index 00000000000..30915ad02fc --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java @@ -0,0 +1,69 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.application.validation.change; + +import com.yahoo.config.model.api.ConfigChangeAction; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.application.validation.ValidationTester; +import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithMockPkg; +import org.junit.jupiter.api.Test; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * @author lesters + */ +public class RestartOnDeployForLocalLLMValidatorTest { + + @Test + void validate_no_restart_on_deploy() { + VespaModel current = createModelWithComponent("ai.vespa.llm.clients.OpenAI"); + VespaModel next = createModelWithComponent("ai.vespa.llm.clients.OpenAI"); + List result = validateModel(current, next); + assertEquals(0, result.size()); + } + + @Test + void validate_restart_on_deploy() { + VespaModel current = createModelWithComponent("ai.vespa.llm.clients.LocalLLM"); + VespaModel next = createModelWithComponent("ai.vespa.llm.clients.LocalLLM"); + List result = validateModel(current, next); + assertEquals(1, result.size()); + assertTrue(result.get(0).validationId().isEmpty()); + assertEquals("Restarting services in container cluster 'cluster1' because of local LLM use", result.get(0).getMessage()); + } + + private static List validateModel(VespaModel current, VespaModel next) { + return ValidationTester.validateChanges(new RestartOnDeployForLocalLLMValidator(), + next, + deployStateBuilder().previousModel(current).build()); + } + + private static VespaModel createModelWithComponent(String componentClass) { + var xml = """ + + + + + + + + + """.formatted(componentClass); + DeployState.Builder builder = deployStateBuilder(); + return new VespaModelCreatorWithMockPkg(null, xml).create(builder); + } + + private static DeployState.Builder deployStateBuilder() { + return new DeployState.Builder().properties(new TestProperties()); + } + + private static void assertStartsWith(String expected, List result) { + assertTrue(result.get(0).getMessage().startsWith(expected)); + } + +} -- cgit v1.2.3 From 1f22a18d27c99abbf81ce208d5584f58ea5b34ac Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Fri, 3 May 2024 12:22:52 +0200 Subject: Only restart if local LLM is found in both previous and next generation --- .../RestartOnDeployForLocalLLMValidator.java | 41 +++++++++++++++------- .../RestartOnDeployForLocalLLMValidatorTest.java | 13 ++++--- 2 files changed, 37 insertions(+), 17 deletions(-) (limited to 'config-model/src/test/java') diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java index 88cfcfaf67c..c9b67ca4240 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java @@ -1,11 +1,17 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation.change; +import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.application.validation.Validation.ChangeContext; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; +import java.util.HashSet; +import java.util.Set; import java.util.logging.Logger; import static java.util.logging.Level.INFO; +import static java.util.stream.Collectors.toUnmodifiableSet; /** * If using local LLMs, this validator will make sure that restartOnDeploy is set for @@ -15,24 +21,35 @@ import static java.util.logging.Level.INFO; */ public class RestartOnDeployForLocalLLMValidator implements ChangeValidator { + private static final String LOCAL_LLM_COMPONENT = "ai.vespa.llm.clients.LocalLLM"; + private static final Logger log = Logger.getLogger(RestartOnDeployForLocalLLMValidator.class.getName()); @Override public void validate(ChangeContext context) { + var previousClustersWithLocalLLM = findClustersWithLocalLLMs(context.previousModel()); + var nextClustersWithLocalLLM = findClustersWithLocalLLMs(context.model()); + + // Only restart services if we use a local LLM in both the next and previous generation + for (var clusterId : intersect(previousClustersWithLocalLLM, nextClustersWithLocalLLM)) { + String message = "Need to restart services in %s due to use of local LLM".formatted(clusterId); + context.require(new VespaRestartAction(clusterId, message)); + log.log(INFO, message); + } + } - for (var cluster : context.model().getContainerClusters().values()) { - - // For now, if a local LLM is used, force a restart of the services - // Later, be more sophisticated and only restart if redeploy does not fit in (GPU) memory - cluster.getAllComponents().forEach(component -> { - if (component.getClassId().getName().equals("ai.vespa.llm.clients.LocalLLM")) { - String message = "Restarting services in %s because of local LLM use".formatted(cluster); - log.log(INFO, message); - context.require(new VespaRestartAction(cluster.id(), message)); - } - }); + private Set findClustersWithLocalLLMs(VespaModel model) { + return model.getContainerClusters().values().stream() + .filter(cluster -> cluster.getAllComponents().stream() + .anyMatch(component -> component.getClassId().getName().equals(LOCAL_LLM_COMPONENT))) + .map(ApplicationContainerCluster::id) + .collect(toUnmodifiableSet()); + } - } + private Set intersect(Set a, Set b) { + Set result = new HashSet<>(a); + result.retainAll(b); + return result; } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java index 30915ad02fc..311d4f39fcd 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java @@ -19,22 +19,25 @@ import static org.junit.jupiter.api.Assertions.assertTrue; */ public class RestartOnDeployForLocalLLMValidatorTest { + public static final String OPENAI_LLM_COMPONENT = "ai.vespa.llm.clients.OpenAI"; + public static final String LOCAL_LLM_COMPONENT = "ai.vespa.llm.clients.LocalLLM"; + @Test void validate_no_restart_on_deploy() { - VespaModel current = createModelWithComponent("ai.vespa.llm.clients.OpenAI"); - VespaModel next = createModelWithComponent("ai.vespa.llm.clients.OpenAI"); + VespaModel current = createModelWithComponent(OPENAI_LLM_COMPONENT); + VespaModel next = createModelWithComponent(LOCAL_LLM_COMPONENT); List result = validateModel(current, next); assertEquals(0, result.size()); } @Test void validate_restart_on_deploy() { - VespaModel current = createModelWithComponent("ai.vespa.llm.clients.LocalLLM"); - VespaModel next = createModelWithComponent("ai.vespa.llm.clients.LocalLLM"); + VespaModel current = createModelWithComponent(LOCAL_LLM_COMPONENT); + VespaModel next = createModelWithComponent(LOCAL_LLM_COMPONENT); List result = validateModel(current, next); assertEquals(1, result.size()); assertTrue(result.get(0).validationId().isEmpty()); - assertEquals("Restarting services in container cluster 'cluster1' because of local LLM use", result.get(0).getMessage()); + assertEquals("Need to restart services in cluster 'cluster1' due to use of local LLM", result.get(0).getMessage()); } private static List validateModel(VespaModel current, VespaModel next) { -- cgit v1.2.3 From 0c0868e895c2ad0c1b82c1f57992e68378d1f3b0 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Fri, 3 May 2024 13:04:05 +0200 Subject: Use class.getName() instead of string --- .../RestartOnDeployForLocalLLMValidator.java | 2 +- .../RestartOnDeployForLocalLLMValidatorTest.java | 25 ++++++++++++++-------- 2 files changed, 17 insertions(+), 10 deletions(-) (limited to 'config-model/src/test/java') diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java index c9b67ca4240..ccfc611c3dc 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java @@ -21,7 +21,7 @@ import static java.util.stream.Collectors.toUnmodifiableSet; */ public class RestartOnDeployForLocalLLMValidator implements ChangeValidator { - private static final String LOCAL_LLM_COMPONENT = "ai.vespa.llm.clients.LocalLLM"; + public static final String LOCAL_LLM_COMPONENT = ai.vespa.llm.clients.LocalLLM.class.getName(); private static final Logger log = Logger.getLogger(RestartOnDeployForLocalLLMValidator.class.getName()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java index 311d4f39fcd..13e91f60712 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java @@ -19,21 +19,20 @@ import static org.junit.jupiter.api.Assertions.assertTrue; */ public class RestartOnDeployForLocalLLMValidatorTest { - public static final String OPENAI_LLM_COMPONENT = "ai.vespa.llm.clients.OpenAI"; - public static final String LOCAL_LLM_COMPONENT = "ai.vespa.llm.clients.LocalLLM"; + private static final String LOCAL_LLM_COMPONENT = RestartOnDeployForLocalLLMValidator.LOCAL_LLM_COMPONENT; @Test void validate_no_restart_on_deploy() { - VespaModel current = createModelWithComponent(OPENAI_LLM_COMPONENT); - VespaModel next = createModelWithComponent(LOCAL_LLM_COMPONENT); + VespaModel current = createModel(); + VespaModel next = createModel(withComponent(LOCAL_LLM_COMPONENT)); List result = validateModel(current, next); assertEquals(0, result.size()); } @Test void validate_restart_on_deploy() { - VespaModel current = createModelWithComponent(LOCAL_LLM_COMPONENT); - VespaModel next = createModelWithComponent(LOCAL_LLM_COMPONENT); + VespaModel current = createModel(withComponent(LOCAL_LLM_COMPONENT)); + VespaModel next = createModel(withComponent(LOCAL_LLM_COMPONENT)); List result = validateModel(current, next); assertEquals(1, result.size()); assertTrue(result.get(0).validationId().isEmpty()); @@ -46,21 +45,29 @@ public class RestartOnDeployForLocalLLMValidatorTest { deployStateBuilder().previousModel(current).build()); } - private static VespaModel createModelWithComponent(String componentClass) { + private static VespaModel createModel(String component) { var xml = """ - + %s - """.formatted(componentClass); + """.formatted(component); DeployState.Builder builder = deployStateBuilder(); return new VespaModelCreatorWithMockPkg(null, xml).create(builder); } + private static VespaModel createModel() { + return createModel(""); + } + + private static String withComponent(String componentClass) { + return "".formatted(componentClass); + } + private static DeployState.Builder deployStateBuilder() { return new DeployState.Builder().properties(new TestProperties()); } -- cgit v1.2.3