summary | refs | log | tree | commit | diff | stats
path: root/config-model/src
diff options
context:
space:
mode:
author: Harald Musum <musum@yahooinc.com> 2023-12-12 07:43:39 +0100
committer: Harald Musum <musum@yahooinc.com> 2023-12-12 07:43:39 +0100
commit: b41221bfb9894d0182eef237a0e1c0d4f8130c00 (patch)
tree: f1f904899d0f8e54a87c273f43270cdec6b6f715 /config-model/src
parent: 945279cb802887375234bed3a18413ac47271f12 (diff)
Restart if the set of onnx models has changed for a container cluster
Diffstat (limited to 'config-model/src')
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java 69
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java 41
2 files changed, 79 insertions, 31 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java
index c72be979a16..b1189dbf923 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java
@@ -4,14 +4,19 @@ package com.yahoo.vespa.model.application.validation.change;
import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.model.api.OnnxModelCost;
import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.vespa.model.Host;
import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.logging.Logger;
import static java.util.logging.Level.FINE;
+import static com.yahoo.config.model.api.OnnxModelCost.ModelInfo;
/**
* If Onnx models change in a way that requires restart of containers in
@@ -32,27 +37,49 @@ public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValida
// Compare onnx models used by each cluster and set restart on deploy for cluster if estimated cost,
// model hash or model options have changed
// TODO: Skip if container has enough memory to handle reload of onnx model (2 models in memory at the same time)
-
for (var cluster : nextModel.getContainerClusters().values()) {
var clusterInCurrentModel = currentModel.getContainerClusters().get(cluster.getName());
if (clusterInCurrentModel == null) continue;
- log.log(FINE, "Validating cluster '" + cluster.name() + "'");
var currentModels = clusterInCurrentModel.onnxModelCostCalculator().models();
var nextModels = cluster.onnxModelCostCalculator().models();
- log.log(FINE, "current models=" + currentModels + ", next models=" + nextModels);
-
- for (var nextModelInfo : nextModels.values()) {
- if (!currentModels.containsKey(nextModelInfo.modelId())) continue;
-
- log.log(FINE, "Checking if " + nextModelInfo + " has changed");
- modelChanged(nextModelInfo, currentModels.get(nextModelInfo.modelId())).ifPresent(change -> {
- String message = "Onnx model '%s' has changed (%s), need to restart services in container cluster '%s'"
- .formatted(nextModelInfo.modelId(), change, cluster.name());
- cluster.onnxModelCostCalculator().setRestartOnDeploy();
- actions.add(new VespaRestartAction(cluster.id(), message));
- });
- }
+
+ log.log(FINE, "Validating " + cluster + ", current models=" + currentModels + ", next models=" + nextModels);
+ actions.addAll(validateModelChanges(cluster, currentModels, nextModels));
+ actions.addAll(validateSetOfModels(cluster, currentModels, nextModels));
+ }
+ return actions;
+ }
+
+ /**
+ * Checks the models that are present in both the current and the next model map and
+ * produces a restart action for each one that {@code modelChanged} reports as changed.
+ * Models that only exist in {@code nextModels} are skipped here.
+ *
+ * @param cluster the container cluster (from the next model) being validated
+ * @param currentModels models used by the cluster in the currently deployed model
+ * (looked up by {@code modelId()}, so presumably keyed by model id - verify)
+ * @param nextModels models used by the cluster in the model being deployed
+ * @return the restart actions created, empty if no shared model has changed
+ */
+ private List<ConfigChangeAction> validateModelChanges(ApplicationContainerCluster cluster,
+ Map<String, ModelInfo> currentModels,
+ Map<String, ModelInfo> nextModels) {
+ List<ConfigChangeAction> actions = new ArrayList<>();
+ for (var nextModelInfo : nextModels.values()) {
+ // Only models that also exist in the current model can be compared
+ if (! currentModels.containsKey(nextModelInfo.modelId())) continue;
+
+ log.log(FINE, "Checking if " + nextModelInfo + " has changed");
+ modelChanged(nextModelInfo, currentModels.get(nextModelInfo.modelId())).ifPresent(change -> {
+ String message = "Onnx model '%s' has changed (%s), need to restart services in %s"
+ .formatted(nextModelInfo.modelId(), change, cluster);
+ // Flag the cluster's cost calculator in addition to reporting the restart action
+ cluster.onnxModelCostCalculator().setRestartOnDeploy();
+ addRestartAction(actions, cluster, message);
+ });
+ }
+ return actions;
+ }
+
+ /**
+ * Produces a restart action if the set of model ids used by the cluster differs between
+ * the current and the next model, i.e. a model has been added or removed.
+ *
+ * @param cluster the container cluster (from the next model) being validated
+ * @param currentModels models used by the cluster in the currently deployed model
+ * @param nextModels models used by the cluster in the model being deployed
+ * @return a list with one restart action if the key sets differ, otherwise an empty list
+ */
+ private List<ConfigChangeAction> validateSetOfModels(ApplicationContainerCluster cluster,
+ Map<String, ModelInfo> currentModels,
+ Map<String, ModelInfo> nextModels) {
+ List<ConfigChangeAction> actions = new ArrayList<>();
+ Set<String> currentModelIds = currentModels.keySet();
+ Set<String> nextModelIds = nextModels.keySet();
+ log.log(FINE, "Checking if model set has changed (%s) -> (%s)".formatted(currentModelIds, nextModelIds));
+ if (! currentModelIds.equals(nextModelIds)) {
+ String message = "Onnx model set has changed from %s to %s, need to restart services in %s"
+ .formatted(currentModelIds, nextModelIds, cluster);
+ // Flag the cluster the same way validateModelChanges does; without this the
+ // restart is reported but the cluster is never marked for restart on deploy
+ cluster.onnxModelCostCalculator().setRestartOnDeploy();
+ addRestartAction(actions, cluster, message);
}
return actions;
}
@@ -67,4 +94,16 @@ public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValida
return Optional.empty();
}
+ /** Appends a {@code VespaRestartAction} for the given cluster, carrying the given message, to {@code actions}. */
+ private static void addRestartAction(List<ConfigChangeAction> actions, ApplicationContainerCluster cluster, String message) {
+ actions.add(new VespaRestartAction(cluster.id(), message));
+ }
+
+ /**
+ * Returns whether twice the aggregated onnx model cost of the cluster is less than the
+ * memory left on a node after subtracting {@code Host.memoryOverheadGb}.
+ * NOTE(review): not called from any line visible in this patch - confirm whether it is used.
+ */
+ private static boolean enoughMemoryToAvoidRestart(ApplicationContainerCluster cluster) {
+ // Uses the first container's host as representative for the cluster;
+ // assumes the cluster has at least one container (get(0) throws otherwise) - TODO confirm
+ double totalMemory = cluster.getContainers().get(0).getHostResource().realResources().memoryGb();
+ // Never negative, even if the overhead exceeds the node memory
+ double availableMemory = Math.max(0, totalMemory - Host.memoryOverheadGb);
+ // Bytes -> GiB
+ double costInGb = (double) cluster.onnxModelCostCalculator().aggregatedModelCostInBytes() / 1024 / 1024 / 1024;
+ // Factor 2: presumably room for two copies of the models while reloading - verify
+ return ( 2 * costInGb < availableMemory);
+ }
+
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java
index 1845bcf0b52..5873d15bd9a 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java
@@ -40,7 +40,8 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
List<ConfigChangeAction> result = validateModel(current, next);
assertEquals(1, result.size());
assertTrue(result.get(0).validationId().isEmpty());
- assertEquals("Onnx model 'https://my/url/model.onnx' has changed (estimated cost), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ assertEquals("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (estimated cost), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ assertStartsWith("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (estimated cost)", result);
}
@Test
@@ -49,7 +50,7 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
VespaModel next = createModel(onnxModelCost(0, 123));
List<ConfigChangeAction> result = validateModel(current, next);
assertEquals(1, result.size());
- assertEquals("Onnx model 'https://my/url/model.onnx' has changed (model hash), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ assertStartsWith("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (model hash)", result);
}
@Test
@@ -58,7 +59,16 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
VespaModel next = createModel(onnxModelCost(0, 0), "sequential");
List<ConfigChangeAction> result = validateModel(current, next);
assertEquals(1, result.size());
- assertEquals("Onnx model 'https://my/url/model.onnx' has changed (model option(s)), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ assertStartsWith("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (model option(s))", result);
+ }
+
+ @Test
+ void validate_changed_model_set() {
+ // The next model uses model id "e5-small-v2", which also changes the model url,
+ // so the set of model ids differs from the one in the current model
+ VespaModel current = createModel();
+ VespaModel next = createModel(onnxModelCost(0, 0), "parallel", "e5-small-v2");
+ List<ConfigChangeAction> result = validateModel(current, next);
+ assertEquals(1, result.size());
+ // Only the start of the message is compared (see assertStartsWith)
+ assertStartsWith("Onnx model set has changed from [https://my/url/e5-base-v2.onnx] to [https://my/url/e5-small-v2.onnx", result);
}
private static List<ConfigChangeAction> validateModel(VespaModel current, VespaModel next) {
@@ -112,12 +122,16 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
}
private static VespaModel createModel(OnnxModelCost onnxModelCost, String executionMode) {
+ // Convenience overload: delegates with the fixed model id "e5-base-v2"
+ return createModel(onnxModelCost, executionMode, "e5-base-v2");
+ }
+
+ private static VespaModel createModel(OnnxModelCost onnxModelCost, String executionMode, String modelId) {
+ // Builds a deploy state with the given onnx model cost and passes execution mode and model id on
DeployState.Builder builder = deployStateBuilder();
builder.onnxModelCost(onnxModelCost);
- return createModel(builder, executionMode);
+ return createModel(builder, executionMode, modelId);
}
- private static VespaModel createModel(DeployState.Builder builder, String executionMode) {
+ private static VespaModel createModel(DeployState.Builder builder, String executionMode, String modelId) {
String xml = """
<services version='1.0'>
<container id='cluster1' version='1.0'>
@@ -125,18 +139,9 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
<server id='server1' port='8080'/>
</http>
<component id="hf-embedder" type="hugging-face-embedder">
- <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/>
+ <transformer-model model-id="%s" url="https://my/url/%s.onnx"/>
<tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/>
- <max-tokens>1024</max-tokens>
- <transformer-input-ids>my_input_ids</transformer-input-ids>
- <transformer-attention-mask>my_attention_mask</transformer-attention-mask>
- <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
- <transformer-output>my_output</transformer-output>
- <normalize>true</normalize>
<onnx-execution-mode>%s</onnx-execution-mode>
- <onnx-intraop-threads>10</onnx-intraop-threads>
- <onnx-interop-threads>8</onnx-interop-threads>
- <pooling-strategy>mean</pooling-strategy>
</component>
</container>
<container id='cluster2' version='1.0'>
@@ -145,7 +150,7 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
</http>
</container>
</services>
- """.formatted(executionMode);
+ """.formatted(modelId, modelId, executionMode);
return new VespaModelCreatorWithMockPkg(null, xml).create(builder);
}
@@ -155,4 +160,8 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
.properties((new TestProperties()).setRestartOnDeployForOnnxModelChanges(true));
}
+ /**
+ * Asserts that the message of the first action in {@code result} starts with {@code expected}.
+ * On failure the expected prefix and the actual message are included in the assertion message.
+ *
+ * @param expected the prefix the first action's message must start with
+ * @param result the actions returned by the validator; must contain at least one element
+ */
+ private static void assertStartsWith(String expected, List<ConfigChangeAction> result) {
+ String actual = result.get(0).getMessage();
+ assertTrue(actual.startsWith(expected),
+ "Expected message starting with '" + expected + "', but was '" + actual + "'");
+ }
+
}