From b41221bfb9894d0182eef237a0e1c0d4f8130c00 Mon Sep 17 00:00:00 2001 From: Harald Musum Date: Tue, 12 Dec 2023 07:43:39 +0100 Subject: Restart if the set of onnx models has changed for a container cluster --- ...estartOnDeployForOnnxModelChangesValidator.java | 69 +++++++++++++++++----- 1 file changed, 54 insertions(+), 15 deletions(-) (limited to 'config-model/src/main/java') diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java index c72be979a16..b1189dbf923 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java @@ -4,14 +4,19 @@ package com.yahoo.vespa.model.application.validation.change; import com.yahoo.config.model.api.ConfigChangeAction; import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.vespa.model.Host; import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.logging.Logger; import static java.util.logging.Level.FINE; +import static com.yahoo.config.model.api.OnnxModelCost.ModelInfo; /** * If Onnx models change in a way that requires restart of containers in @@ -32,27 +37,49 @@ public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValida // Compare onnx models used by each cluster and set restart on deploy for cluster if estimated cost, // model hash or model options have changed // TODO: Skip if container has enough memory to handle reload of onnx model (2 models in memory at the same time) - for (var cluster : nextModel.getContainerClusters().values()) { var clusterInCurrentModel = currentModel.getContainerClusters().get(cluster.getName()); if (clusterInCurrentModel == null) continue; - log.log(FINE, "Validating cluster '" + cluster.name() + "'"); var currentModels = clusterInCurrentModel.onnxModelCostCalculator().models(); var nextModels = cluster.onnxModelCostCalculator().models(); - log.log(FINE, "current models=" + currentModels + ", next models=" + nextModels); - - for (var nextModelInfo : nextModels.values()) { - if (!currentModels.containsKey(nextModelInfo.modelId())) continue; - - log.log(FINE, "Checking if " + nextModelInfo + " has changed"); - modelChanged(nextModelInfo, currentModels.get(nextModelInfo.modelId())).ifPresent(change -> { - String message = "Onnx model '%s' has changed (%s), need to restart services in container cluster '%s'" - .formatted(nextModelInfo.modelId(), change, cluster.name()); - cluster.onnxModelCostCalculator().setRestartOnDeploy(); - actions.add(new VespaRestartAction(cluster.id(), message)); - }); - } + + log.log(FINE, "Validating " + cluster + ", current models=" + currentModels + ", next models=" + nextModels); + actions.addAll(validateModelChanges(cluster, currentModels, nextModels)); + actions.addAll(validateSetOfModels(cluster, currentModels, nextModels)); + } + return actions; + } + + private List validateModelChanges(ApplicationContainerCluster cluster, + Map currentModels, + Map nextModels) { + List actions = new ArrayList<>(); + for (var nextModelInfo : nextModels.values()) { + if (! currentModels.containsKey(nextModelInfo.modelId())) continue; + + log.log(FINE, "Checking if " + nextModelInfo + " has changed"); + modelChanged(nextModelInfo, currentModels.get(nextModelInfo.modelId())).ifPresent(change -> { + String message = "Onnx model '%s' has changed (%s), need to restart services in %s" + .formatted(nextModelInfo.modelId(), change, cluster); + cluster.onnxModelCostCalculator().setRestartOnDeploy(); + addRestartAction(actions, cluster, message); + }); + } + return actions; + } + + private List validateSetOfModels(ApplicationContainerCluster cluster, + Map currentModels, + Map nextModels) { + List actions = new ArrayList<>(); + Set currentModelIds = currentModels.keySet(); + Set nextModelIds = nextModels.keySet(); + log.log(FINE, "Checking if model set has changed (%s) -> (%s)".formatted(currentModelIds, nextModelIds)); + if (! currentModelIds.equals(nextModelIds)) { + String message = "Onnx model set has changed from %s to %s, need to restart services in %s" + .formatted(currentModelIds, nextModelIds, cluster); + addRestartAction(actions, cluster, message); } return actions; } @@ -67,4 +94,16 @@ public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValida return Optional.empty(); } + private static void addRestartAction(List actions, ApplicationContainerCluster cluster, String message) { + actions.add(new VespaRestartAction(cluster.id(), message)); + } + + private static boolean enoughMemoryToAvoidRestart(ApplicationContainerCluster cluster) { + // Node memory is known so convert available memory percentage to node memory percentage + double totalMemory = cluster.getContainers().get(0).getHostResource().realResources().memoryGb(); + double availableMemory = Math.max(0, totalMemory - Host.memoryOverheadGb); + double costInGb = (double) cluster.onnxModelCostCalculator().aggregatedModelCostInBytes() / 1024 / 1024 / 1024; + return ( 2 * costInGb < availableMemory); + } + } -- cgit v1.2.3