diff options
author | Harald Musum <musum@yahooinc.com> | 2023-12-12 07:43:39 +0100 |
---|---|---|
committer | Harald Musum <musum@yahooinc.com> | 2023-12-12 07:43:39 +0100 |
commit | b41221bfb9894d0182eef237a0e1c0d4f8130c00 (patch) | |
tree | f1f904899d0f8e54a87c273f43270cdec6b6f715 /config-model | |
parent | 945279cb802887375234bed3a18413ac47271f12 (diff) |
Restart if the set of onnx models has changed for a container cluster
Diffstat (limited to 'config-model')
2 files changed, 79 insertions, 31 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java index c72be979a16..b1189dbf923 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java @@ -4,14 +4,19 @@ package com.yahoo.vespa.model.application.validation.change; import com.yahoo.config.model.api.ConfigChangeAction; import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.vespa.model.Host; import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.logging.Logger; import static java.util.logging.Level.FINE; +import static com.yahoo.config.model.api.OnnxModelCost.ModelInfo; /** * If Onnx models change in a way that requires restart of containers in @@ -32,27 +37,49 @@ public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValida // Compare onnx models used by each cluster and set restart on deploy for cluster if estimated cost, // model hash or model options have changed // TODO: Skip if container has enough memory to handle reload of onnx model (2 models in memory at the same time) - for (var cluster : nextModel.getContainerClusters().values()) { var clusterInCurrentModel = currentModel.getContainerClusters().get(cluster.getName()); if (clusterInCurrentModel == null) continue; - log.log(FINE, "Validating cluster '" + cluster.name() + "'"); var currentModels = clusterInCurrentModel.onnxModelCostCalculator().models(); var nextModels = cluster.onnxModelCostCalculator().models(); - log.log(FINE, "current models=" + currentModels + ", next models=" + nextModels); - - for (var nextModelInfo : nextModels.values()) { - if (!currentModels.containsKey(nextModelInfo.modelId())) continue; - - log.log(FINE, "Checking if " + nextModelInfo + " has changed"); - modelChanged(nextModelInfo, currentModels.get(nextModelInfo.modelId())).ifPresent(change -> { - String message = "Onnx model '%s' has changed (%s), need to restart services in container cluster '%s'" - .formatted(nextModelInfo.modelId(), change, cluster.name()); - cluster.onnxModelCostCalculator().setRestartOnDeploy(); - actions.add(new VespaRestartAction(cluster.id(), message)); - }); - } + + log.log(FINE, "Validating " + cluster + ", current models=" + currentModels + ", next models=" + nextModels); + actions.addAll(validateModelChanges(cluster, currentModels, nextModels)); + actions.addAll(validateSetOfModels(cluster, currentModels, nextModels)); + } + return actions; + } + + private List<ConfigChangeAction> validateModelChanges(ApplicationContainerCluster cluster, + Map<String, ModelInfo> currentModels, + Map<String, ModelInfo> nextModels) { + List<ConfigChangeAction> actions = new ArrayList<>(); + for (var nextModelInfo : nextModels.values()) { + if (! currentModels.containsKey(nextModelInfo.modelId())) continue; + + log.log(FINE, "Checking if " + nextModelInfo + " has changed"); + modelChanged(nextModelInfo, currentModels.get(nextModelInfo.modelId())).ifPresent(change -> { + String message = "Onnx model '%s' has changed (%s), need to restart services in %s" + .formatted(nextModelInfo.modelId(), change, cluster); + cluster.onnxModelCostCalculator().setRestartOnDeploy(); + addRestartAction(actions, cluster, message); + }); + } + return actions; + } + + private List<ConfigChangeAction> validateSetOfModels(ApplicationContainerCluster cluster, + Map<String, ModelInfo> currentModels, + Map<String, ModelInfo> nextModels) { + List<ConfigChangeAction> actions = new ArrayList<>(); + Set<String> currentModelIds = currentModels.keySet(); + Set<String> nextModelIds = nextModels.keySet(); + log.log(FINE, "Checking if model set has changed (%s) -> (%s)".formatted(currentModelIds, nextModelIds)); + if (! currentModelIds.equals(nextModelIds)) { + String message = "Onnx model set has changed from %s to %s, need to restart services in %s" + .formatted(currentModelIds, nextModelIds, cluster); + addRestartAction(actions, cluster, message); } return actions; } @@ -67,4 +94,16 @@ public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValida return Optional.empty(); } + private static void addRestartAction(List<ConfigChangeAction> actions, ApplicationContainerCluster cluster, String message) { + actions.add(new VespaRestartAction(cluster.id(), message)); + } + + private static boolean enoughMemoryToAvoidRestart(ApplicationContainerCluster cluster) { + // Node memory is known so convert available memory percentage to node memory percentage + double totalMemory = cluster.getContainers().get(0).getHostResource().realResources().memoryGb(); + double availableMemory = Math.max(0, totalMemory - Host.memoryOverheadGb); + double costInGb = (double) cluster.onnxModelCostCalculator().aggregatedModelCostInBytes() / 1024 / 1024 / 1024; + return ( 2 * costInGb < availableMemory); + } + } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java index 1845bcf0b52..5873d15bd9a 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java @@ -40,7 +40,8 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest { List<ConfigChangeAction> result = validateModel(current, next); assertEquals(1, result.size()); assertTrue(result.get(0).validationId().isEmpty()); - assertEquals("Onnx model 'https://my/url/model.onnx' has changed (estimated cost), need to restart services in container cluster 'cluster1'", result.get(0).getMessage()); + assertEquals("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (estimated cost), need to restart services in container cluster 'cluster1'", result.get(0).getMessage()); + assertStartsWith("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (estimated cost)", result); } @Test @@ -49,7 +50,7 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest { VespaModel next = createModel(onnxModelCost(0, 123)); List<ConfigChangeAction> result = validateModel(current, next); assertEquals(1, result.size()); - assertEquals("Onnx model 'https://my/url/model.onnx' has changed (model hash), need to restart services in container cluster 'cluster1'", result.get(0).getMessage()); + assertStartsWith("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (model hash)", result); } @Test @@ -58,7 +59,16 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest { VespaModel next = createModel(onnxModelCost(0, 0), "sequential"); List<ConfigChangeAction> result = validateModel(current, next); assertEquals(1, result.size()); - assertEquals("Onnx model 'https://my/url/model.onnx' has changed (model option(s)), need to restart services in container cluster 'cluster1'", result.get(0).getMessage()); + assertStartsWith("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (model option(s))", result); + } + + @Test + void validate_changed_model_set() { + VespaModel current = createModel(); + VespaModel next = createModel(onnxModelCost(0, 0), "parallel", "e5-small-v2"); + List<ConfigChangeAction> result = validateModel(current, next); + assertEquals(1, result.size()); + assertStartsWith("Onnx model set has changed from [https://my/url/e5-base-v2.onnx] to [https://my/url/e5-small-v2.onnx", result); } private static List<ConfigChangeAction> validateModel(VespaModel current, VespaModel next) { @@ -112,12 +122,16 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest { } private static VespaModel createModel(OnnxModelCost onnxModelCost, String executionMode) { + return createModel(onnxModelCost, executionMode, "e5-base-v2"); + } + + private static VespaModel createModel(OnnxModelCost onnxModelCost, String executionMode, String modelId) { DeployState.Builder builder = deployStateBuilder(); builder.onnxModelCost(onnxModelCost); - return createModel(builder, executionMode); + return createModel(builder, executionMode, modelId); } - private static VespaModel createModel(DeployState.Builder builder, String executionMode) { + private static VespaModel createModel(DeployState.Builder builder, String executionMode, String modelId) { String xml = """ <services version='1.0'> <container id='cluster1' version='1.0'> @@ -125,18 +139,9 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest { <server id='server1' port='8080'/> </http> <component id="hf-embedder" type="hugging-face-embedder"> - <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/> + <transformer-model model-id="%s" url="https://my/url/%s.onnx"/> <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/> - <max-tokens>1024</max-tokens> - <transformer-input-ids>my_input_ids</transformer-input-ids> - <transformer-attention-mask>my_attention_mask</transformer-attention-mask> - <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> - <transformer-output>my_output</transformer-output> - <normalize>true</normalize> <onnx-execution-mode>%s</onnx-execution-mode> - <onnx-intraop-threads>10</onnx-intraop-threads> - <onnx-interop-threads>8</onnx-interop-threads> - <pooling-strategy>mean</pooling-strategy> </component> </container> <container id='cluster2' version='1.0'> @@ -145,7 +150,7 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest { </http> </container> </services> - """.formatted(executionMode); + """.formatted(modelId, modelId, executionMode); return new VespaModelCreatorWithMockPkg(null, xml).create(builder); } @@ -155,4 +160,8 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest { .properties((new TestProperties()).setRestartOnDeployForOnnxModelChanges(true)); } + private static void assertStartsWith(String expected, List<ConfigChangeAction> result) { + assertTrue(result.get(0).getMessage().startsWith(expected)); + } + } |