aboutsummaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
author Lester Solbakken <lesters@users.noreply.github.com> 2024-05-06 09:18:47 +0200
committer GitHub <noreply@github.com> 2024-05-06 09:18:47 +0200
commit 2f4511677d4da29e615f3543fd167d4bbce8588e (patch)
tree e908baaa3a534d5036d56d6370a17d41e4cbba1c /config-model
parent 2721b3e11d361a01e36cc4030792d3daea01c740 (diff)
parent 0c0868e895c2ad0c1b82c1f57992e68378d1f3b0 (diff)
Merge pull request #31097 from vespa-engine/lesters/restart-on-deploy-for-local-llm
Restart on deploy for local LLMs
Diffstat (limited to 'config-model')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java | 2
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java | 55
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java | 79
3 files changed, 136 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java
index ed0804f7420..7f624032627 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java
@@ -19,6 +19,7 @@ import com.yahoo.vespa.model.application.validation.change.IndexingModeChangeVal
import com.yahoo.vespa.model.application.validation.change.NodeResourceChangeValidator;
import com.yahoo.vespa.model.application.validation.change.RedundancyIncreaseValidator;
import com.yahoo.vespa.model.application.validation.change.ResourcesReductionValidator;
+import com.yahoo.vespa.model.application.validation.change.RestartOnDeployForLocalLLMValidator;
import com.yahoo.vespa.model.application.validation.change.RestartOnDeployForOnnxModelChangesValidator;
import com.yahoo.vespa.model.application.validation.change.StartupCommandChangeValidator;
import com.yahoo.vespa.model.application.validation.change.StreamingSearchClusterChangeValidator;
@@ -129,6 +130,7 @@ public class Validation {
new CertificateRemovalChangeValidator().validate(execution);
new RedundancyValidator().validate(execution);
new RestartOnDeployForOnnxModelChangesValidator().validate(execution);
+ new RestartOnDeployForLocalLLMValidator().validate(execution);
}
public interface Context {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java
new file mode 100644
index 00000000000..ccfc611c3dc
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java
@@ -0,0 +1,55 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.application.validation.change;
+
+import com.yahoo.config.provision.ClusterSpec;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.application.validation.Validation.ChangeContext;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.logging.Logger;
+
+import static java.util.logging.Level.INFO;
+import static java.util.stream.Collectors.toUnmodifiableSet;
+
+/**
+ * Change validator for container clusters that use a local LLM component
+ * ({@code ai.vespa.llm.clients.LocalLLM}).
+ *
+ * <p>For every container cluster that has such a component in both the previous and the
+ * next model, this validator requires a {@link VespaRestartAction} for that cluster on the
+ * change context and logs the same message at INFO level. A cluster where the component
+ * is present in only one of the two models gets no restart action from this validator.
+ *
+ * <p>Clusters are selected by comparing the class id name of each of the cluster's
+ * components against {@link #LOCAL_LLM_COMPONENT}.
+ *
+ * @author lesters
+ */
+public class RestartOnDeployForLocalLLMValidator implements ChangeValidator {
+
+ public static final String LOCAL_LLM_COMPONENT = ai.vespa.llm.clients.LocalLLM.class.getName(); // fully qualified class name used for matching; also read by the test
+
+ private static final Logger log = Logger.getLogger(RestartOnDeployForLocalLLMValidator.class.getName());
+
+ @Override
+ public void validate(ChangeContext context) {
+ var previousClustersWithLocalLLM = findClustersWithLocalLLMs(context.previousModel());
+ var nextClustersWithLocalLLM = findClustersWithLocalLLMs(context.model());
+
+ // Only restart services if we use a local LLM in both the next and previous generation
+ for (var clusterId : intersect(previousClustersWithLocalLLM, nextClustersWithLocalLLM)) {
+ String message = "Need to restart services in %s due to use of local LLM".formatted(clusterId);
+ context.require(new VespaRestartAction(clusterId, message));
+ log.log(INFO, message);
+ }
+ }
+
+ private Set<ClusterSpec.Id> findClustersWithLocalLLMs(VespaModel model) { // ids of container clusters having at least one LocalLLM component
+ return model.getContainerClusters().values().stream()
+ .filter(cluster -> cluster.getAllComponents().stream()
+ .anyMatch(component -> component.getClassId().getName().equals(LOCAL_LLM_COMPONENT)))
+ .map(ApplicationContainerCluster::id)
+ .collect(toUnmodifiableSet());
+ }
+
+ private Set<ClusterSpec.Id> intersect(Set<ClusterSpec.Id> a, Set<ClusterSpec.Id> b) {
+ Set<ClusterSpec.Id> result = new HashSet<>(a); // work on a copy: the inputs come from toUnmodifiableSet() and must not be mutated
+ result.retainAll(b);
+ return result;
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java
new file mode 100644
index 00000000000..13e91f60712
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidatorTest.java
@@ -0,0 +1,79 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.application.validation.change;
+
+import com.yahoo.config.model.api.ConfigChangeAction;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.application.validation.ValidationTester;
+import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithMockPkg;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Tests {@link RestartOnDeployForLocalLLMValidator} using models built from a services.xml
+ * with a single container cluster ({@code cluster1}), with or without a component whose
+ * class is the local LLM class.
+ *
+ * <ul>
+ * <li>component only in the next model: no config change actions are expected</li>
+ * <li>component in both the current and the next model: exactly one action is expected,
+ * with no validation id and the restart message for {@code cluster1}</li>
+ * </ul>
+ *
+ * @author lesters
+ */
+public class RestartOnDeployForLocalLLMValidatorTest {
+
+ private static final String LOCAL_LLM_COMPONENT = RestartOnDeployForLocalLLMValidator.LOCAL_LLM_COMPONENT;
+
+ @Test
+ void validate_no_restart_on_deploy() {
+ VespaModel current = createModel();
+ VespaModel next = createModel(withComponent(LOCAL_LLM_COMPONENT));
+ List<ConfigChangeAction> result = validateModel(current, next);
+ assertEquals(0, result.size());
+ }
+
+ @Test
+ void validate_restart_on_deploy() {
+ VespaModel current = createModel(withComponent(LOCAL_LLM_COMPONENT));
+ VespaModel next = createModel(withComponent(LOCAL_LLM_COMPONENT));
+ List<ConfigChangeAction> result = validateModel(current, next);
+ assertEquals(1, result.size());
+ assertTrue(result.get(0).validationId().isEmpty());
+ assertEquals("Need to restart services in cluster 'cluster1' due to use of local LLM", result.get(0).getMessage());
+ }
+
+ private static List<ConfigChangeAction> validateModel(VespaModel current, VespaModel next) { // runs the validator with 'current' as the previous model and 'next' as the new one
+ return ValidationTester.validateChanges(new RestartOnDeployForLocalLLMValidator(),
+ next,
+ deployStateBuilder().previousModel(current).build());
+ }
+
+ private static VespaModel createModel(String component) { // 'component' is raw XML spliced into the container element; may be empty
+ var xml = """
+ <services version='1.0'>
+ <container id='cluster1' version='1.0'>
+ <http>
+ <server id='server1' port='8080'/>
+ </http>
+ %s
+ </container>
+ </services>
+ """.formatted(component);
+ DeployState.Builder builder = deployStateBuilder();
+ return new VespaModelCreatorWithMockPkg(null, xml).create(builder);
+ }
+
+ private static VespaModel createModel() {
+ return createModel("");
+ }
+
+ private static String withComponent(String componentClass) {
+ return "<component id='llm' class='%s' />".formatted(componentClass);
+ }
+
+ private static DeployState.Builder deployStateBuilder() {
+ return new DeployState.Builder().properties(new TestProperties());
+ }
+
+ private static void assertStartsWith(String expected, List<ConfigChangeAction> result) { // NOTE(review): not called by any test in this class; candidate for removal
+ assertTrue(result.get(0).getMessage().startsWith(expected));
+ }
+
+}