summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java
blob: 398538d187fe4919e8597b44c890f4c1d840d437 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation.change;

import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.model.api.OnnxModelCost;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.vespa.model.Host;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Logger;

import static com.yahoo.vespa.model.application.validation.JvmHeapSizeValidator.gbLimit;
import static com.yahoo.vespa.model.application.validation.JvmHeapSizeValidator.percentLimit;
import static java.util.logging.Level.FINE;
import static com.yahoo.config.model.api.OnnxModelCost.ModelInfo;
import static java.util.logging.Level.INFO;

/**
 * If ONNX models change in a way that requires a restart of the containers in a
 * container cluster, this validator makes sure that restartOnDeploy is set for
 * configs for this cluster.
 *
 * <p>For each container cluster present in both the current and the next model, a restart is
 * required when the estimated cost, hash or options of a model present in both have changed, or
 * when the set of model ids has changed, <em>and</em> the memory left after accounting for the
 * estimated cost of both the current and the next models is below the limits defined by
 * JvmHeapSizeValidator. The validator does nothing unless the
 * restartOnDeployWhenOnnxModelChanges feature flag is enabled.
 *
 * @author hmusum
 */
public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValidator {

    private static final Logger log = Logger.getLogger(RestartOnDeployForOnnxModelChangesValidator.class.getName());

    /**
     * Compares the ONNX models used by each container cluster in the current and next model.
     *
     * @return one restart action per detected change, for clusters that changed and do not have enough
     *         available memory to avoid a restart; empty if the feature flag is disabled or nothing changed
     */
    @Override
    public List<ConfigChangeAction> validate(VespaModel currentModel, VespaModel nextModel, DeployState deployState) {
        if ( ! deployState.featureFlags().restartOnDeployWhenOnnxModelChanges()) return List.of();
        List<ConfigChangeAction> actions = new ArrayList<>();

        for (var cluster : nextModel.getContainerClusters().values()) {
            var clusterInCurrentModel = currentModel.getContainerClusters().get(cluster.getName());
            if (clusterInCurrentModel == null) continue; // Not in current model: nothing to compare against

            var currentModels = clusterInCurrentModel.onnxModelCostCalculator().models();
            var nextModels = cluster.onnxModelCostCalculator().models();
            log.log(FINE, () -> "Validating " + cluster + ", current models=" + currentModels + ", next models=" + nextModels);

            // Detect changes first (no side effects), so that the memory check, which may write to the
            // deploy log, only runs for clusters where a restart would actually be needed
            List<String> changes = new ArrayList<>(modelChangeMessages(cluster, currentModels, nextModels));
            modelSetChangeMessage(cluster, currentModels, nextModels).ifPresent(changes::add);
            if (changes.isEmpty()) continue;

            if (enoughMemoryToAvoidRestart(clusterInCurrentModel, cluster, deployState.getDeployLogger()))
                continue;

            for (String message : changes)
                setRestartOnDeployAndAddRestartAction(actions, cluster, message);
        }
        return actions;
    }

    /**
     * Returns a restart message for each model that exists in both model maps and whose estimated cost,
     * hash or options differ. Added or removed models are reported by {@link #modelSetChangeMessage}.
     */
    private static List<String> modelChangeMessages(ApplicationContainerCluster cluster,
                                                    Map<String, ModelInfo> currentModels,
                                                    Map<String, ModelInfo> nextModels) {
        List<String> messages = new ArrayList<>();
        for (var nextModelInfo : nextModels.values()) {
            var currentModelInfo = currentModels.get(nextModelInfo.modelId());
            if (currentModelInfo == null) continue;

            modelChanged(currentModelInfo, nextModelInfo)
                    .map(change -> "Onnx model '%s' has changed (%s), need to restart services in %s"
                            .formatted(nextModelInfo.modelId(), change, cluster))
                    .ifPresent(messages::add);
        }
        return messages;
    }

    /** Returns a restart message if the set of model ids differs between current and next models. */
    private static Optional<String> modelSetChangeMessage(ApplicationContainerCluster cluster,
                                                          Map<String, ModelInfo> currentModels,
                                                          Map<String, ModelInfo> nextModels) {
        Set<String> currentModelIds = currentModels.keySet();
        Set<String> nextModelIds = nextModels.keySet();
        log.log(FINE, () -> "Checking if model set has changed (%s) -> (%s)".formatted(currentModelIds, nextModelIds));
        if (currentModelIds.equals(nextModelIds)) return Optional.empty();

        return Optional.of("Onnx model set has changed from %s to %s, need to restart services in %s"
                                   .formatted(currentModelIds, nextModelIds, cluster));
    }

    /**
     * Returns a short description of the first difference found between the current and next
     * version of a model (estimated cost, then hash, then options), or empty if none differ.
     */
    private static Optional<String> modelChanged(ModelInfo current, ModelInfo next) {
        log.log(FINE, () -> "Checking if model has changed (%s) -> (%s)".formatted(current, next));
        if (current.estimatedCost() != next.estimatedCost()) return Optional.of("estimated cost");
        if (current.hash() != next.hash()) return Optional.of("model hash");
        if ( ! current.onnxModelOptions().equals(next.onnxModelOptions())) return Optional.of("model option(s)");
        return Optional.empty();
    }

    /**
     * Logs the message, calls setRestartOnDeploy() and store() on the cluster's ONNX model cost
     * calculator, and adds a restart action for the cluster to {@code actions}.
     */
    private static void setRestartOnDeployAndAddRestartAction(List<ConfigChangeAction> actions, ApplicationContainerCluster cluster, String message) {
        log.log(INFO, message);
        cluster.onnxModelCostCalculator().setRestartOnDeploy();
        cluster.onnxModelCostCalculator().store();
        actions.add(new VespaRestartAction(cluster.id(), message));
    }

    /**
     * Returns whether the memory on the (first) node of the next cluster, after subtracting host overhead
     * and the estimated cost of both the current and the next models, is above both the percentage limit
     * and the absolute Gb limit. Returns false if available memory cannot be determined (no containers,
     * or no known memory on the node), so that a restart is required in that case.
     */
    private static boolean enoughMemoryToAvoidRestart(ApplicationContainerCluster clusterInCurrentModel,
                                                      ApplicationContainerCluster cluster,
                                                      DeployLogger deployLogger) {
        var containers = cluster.getContainers();
        if (containers.isEmpty()) {
            log.log(FINE, () -> "Validating %s, no containers, cannot determine available memory".formatted(cluster));
            return false;
        }

        double totalMemory = containers.get(0).getHostResource().realResources().memoryGb();
        if (totalMemory <= 0) {
            // Avoids a 0/0 (NaN) percentage below and a misleading "use a larger flavor" message
            log.log(FINE, () -> "Validating %s, node memory unknown, cannot determine available memory".formatted(cluster));
            return false;
        }

        double currentModelCostInGb = onnxModelCostInGb(clusterInCurrentModel);
        double nextModelCostInGb = onnxModelCostInGb(cluster);
        // Both current and next models are counted
        double memoryUsedByModels = currentModelCostInGb + nextModelCostInGb;
        double availableMemory = Math.max(0, totalMemory - Host.memoryOverheadGb - memoryUsedByModels);

        var availableMemoryPercentage = cluster.availableMemoryPercentage();
        int memoryPercentage = (int) (availableMemory / totalMemory * availableMemoryPercentage);

        if (memoryPercentage < percentLimit) {
            deployLogger.log(INFO, "Validating %s, percentage of available memory too low (%d < %d) to avoid restart, consider a flavor with more memory to avoid this"
                    .formatted(cluster, memoryPercentage, percentLimit));
            return false;
        }

        if (availableMemory < gbLimit) {
            deployLogger.log(INFO, "Validating %s, available memory too low (%.2f Gb < %.2f Gb) to avoid restart, consider a flavor with more memory to avoid this"
                    .formatted(cluster, availableMemory, gbLimit));
            return false;
        }

        log.log(FINE, () -> "Validating %s, enough available memory (%.2f Gb) to avoid restart (models use %.2f Gb)"
                .formatted(cluster, availableMemory, memoryUsedByModels));
        return true;
    }

    /** Returns the aggregated estimated ONNX model cost of the given cluster, in Gb (1024^3 bytes). */
    private static double onnxModelCostInGb(ApplicationContainerCluster cluster) {
        return (double) cluster.onnxModelCostCalculator().aggregatedModelCostInBytes() / 1024 / 1024 / 1024;
    }

}