diff options
author | Lester Solbakken <lester.solbakken@gmail.com> | 2024-05-03 12:22:52 +0200 |
---|---|---|
committer | Lester Solbakken <lester.solbakken@gmail.com> | 2024-05-03 12:22:52 +0200 |
commit | 1f22a18d27c99abbf81ce208d5584f58ea5b34ac (patch) | |
tree | ff0c211861b41f1bfd45c0f57a7da1cb32217dbe /config-model/src/main | |
parent | f3a6f2dedc0f6c43b7a51d348b9096a676b168c5 (diff) |
Only restart if local LLM is found in both previous and next generation
Diffstat (limited to 'config-model/src/main')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java | 41 |
1 files changed, 29 insertions, 12 deletions
```diff
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java
index 88cfcfaf67c..c9b67ca4240 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForLocalLLMValidator.java
@@ -1,11 +1,17 @@
 // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
 package com.yahoo.vespa.model.application.validation.change;
 
+import com.yahoo.config.provision.ClusterSpec;
+import com.yahoo.vespa.model.VespaModel;
 import com.yahoo.vespa.model.application.validation.Validation.ChangeContext;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
 
+import java.util.HashSet;
+import java.util.Set;
 import java.util.logging.Logger;
 
 import static java.util.logging.Level.INFO;
+import static java.util.stream.Collectors.toUnmodifiableSet;
 
 /**
  * If using local LLMs, this validator will make sure that restartOnDeploy is set for
@@ -15,24 +21,35 @@ import static java.util.logging.Level.INFO;
  */
 public class RestartOnDeployForLocalLLMValidator implements ChangeValidator {
 
+    private static final String LOCAL_LLM_COMPONENT = "ai.vespa.llm.clients.LocalLLM";
+
     private static final Logger log = Logger.getLogger(RestartOnDeployForLocalLLMValidator.class.getName());
 
     @Override
     public void validate(ChangeContext context) {
+        var previousClustersWithLocalLLM = findClustersWithLocalLLMs(context.previousModel());
+        var nextClustersWithLocalLLM = findClustersWithLocalLLMs(context.model());
+
+        // Only restart services if we use a local LLM in both the next and previous generation
+        for (var clusterId : intersect(previousClustersWithLocalLLM, nextClustersWithLocalLLM)) {
+            String message = "Need to restart services in %s due to use of local LLM".formatted(clusterId);
+            context.require(new VespaRestartAction(clusterId, message));
+            log.log(INFO, message);
+        }
+    }
 
-        for (var cluster : context.model().getContainerClusters().values()) {
-
-            // For now, if a local LLM is used, force a restart of the services
-            // Later, be more sophisticated and only restart if redeploy does not fit in (GPU) memory
-            cluster.getAllComponents().forEach(component -> {
-                if (component.getClassId().getName().equals("ai.vespa.llm.clients.LocalLLM")) {
-                    String message = "Restarting services in %s because of local LLM use".formatted(cluster);
-                    log.log(INFO, message);
-                    context.require(new VespaRestartAction(cluster.id(), message));
-                }
-            });
+    private Set<ClusterSpec.Id> findClustersWithLocalLLMs(VespaModel model) {
+        return model.getContainerClusters().values().stream()
+                .filter(cluster -> cluster.getAllComponents().stream()
+                        .anyMatch(component -> component.getClassId().getName().equals(LOCAL_LLM_COMPONENT)))
+                .map(ApplicationContainerCluster::id)
+                .collect(toUnmodifiableSet());
+    }
 
-        }
+    private Set<ClusterSpec.Id> intersect(Set<ClusterSpec.Id> a, Set<ClusterSpec.Id> b) {
+        Set<ClusterSpec.Id> result = new HashSet<>(a);
+        result.retainAll(b);
+        return result;
     }
 
 }
```