summaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java
diff options
context:
space:
mode:
authorHarald Musum <musum@yahooinc.com>2023-12-12 07:43:39 +0100
committerHarald Musum <musum@yahooinc.com>2023-12-12 07:43:39 +0100
commitb41221bfb9894d0182eef237a0e1c0d4f8130c00 (patch)
treef1f904899d0f8e54a87c273f43270cdec6b6f715 /config-model/src/test/java
parent945279cb802887375234bed3a18413ac47271f12 (diff)
Restart if the set of onnx models has changed for a container cluster
Diffstat (limited to 'config-model/src/test/java')
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java41
1 files changed, 25 insertions, 16 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java
index 1845bcf0b52..5873d15bd9a 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java
@@ -40,7 +40,8 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
List<ConfigChangeAction> result = validateModel(current, next);
assertEquals(1, result.size());
assertTrue(result.get(0).validationId().isEmpty());
- assertEquals("Onnx model 'https://my/url/model.onnx' has changed (estimated cost), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ assertEquals("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (estimated cost), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ assertStartsWith("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (estimated cost)", result);
}
@Test
@@ -49,7 +50,7 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
VespaModel next = createModel(onnxModelCost(0, 123));
List<ConfigChangeAction> result = validateModel(current, next);
assertEquals(1, result.size());
- assertEquals("Onnx model 'https://my/url/model.onnx' has changed (model hash), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ assertStartsWith("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (model hash)", result);
}
@Test
@@ -58,7 +59,16 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
VespaModel next = createModel(onnxModelCost(0, 0), "sequential");
List<ConfigChangeAction> result = validateModel(current, next);
assertEquals(1, result.size());
- assertEquals("Onnx model 'https://my/url/model.onnx' has changed (model option(s)), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ assertStartsWith("Onnx model 'https://my/url/e5-base-v2.onnx' has changed (model option(s))", result);
+ }
+
+ @Test
+ void validate_changed_model_set() {
+ VespaModel current = createModel();
+ VespaModel next = createModel(onnxModelCost(0, 0), "parallel", "e5-small-v2");
+ List<ConfigChangeAction> result = validateModel(current, next);
+ assertEquals(1, result.size());
+ assertStartsWith("Onnx model set has changed from [https://my/url/e5-base-v2.onnx] to [https://my/url/e5-small-v2.onnx", result);
}
private static List<ConfigChangeAction> validateModel(VespaModel current, VespaModel next) {
@@ -112,12 +122,16 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
}
private static VespaModel createModel(OnnxModelCost onnxModelCost, String executionMode) {
+ return createModel(onnxModelCost, executionMode, "e5-base-v2");
+ }
+
+ private static VespaModel createModel(OnnxModelCost onnxModelCost, String executionMode, String modelId) {
DeployState.Builder builder = deployStateBuilder();
builder.onnxModelCost(onnxModelCost);
- return createModel(builder, executionMode);
+ return createModel(builder, executionMode, modelId);
}
- private static VespaModel createModel(DeployState.Builder builder, String executionMode) {
+ private static VespaModel createModel(DeployState.Builder builder, String executionMode, String modelId) {
String xml = """
<services version='1.0'>
<container id='cluster1' version='1.0'>
@@ -125,18 +139,9 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
<server id='server1' port='8080'/>
</http>
<component id="hf-embedder" type="hugging-face-embedder">
- <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/>
+ <transformer-model model-id="%s" url="https://my/url/%s.onnx"/>
<tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/>
- <max-tokens>1024</max-tokens>
- <transformer-input-ids>my_input_ids</transformer-input-ids>
- <transformer-attention-mask>my_attention_mask</transformer-attention-mask>
- <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
- <transformer-output>my_output</transformer-output>
- <normalize>true</normalize>
<onnx-execution-mode>%s</onnx-execution-mode>
- <onnx-intraop-threads>10</onnx-intraop-threads>
- <onnx-interop-threads>8</onnx-interop-threads>
- <pooling-strategy>mean</pooling-strategy>
</component>
</container>
<container id='cluster2' version='1.0'>
@@ -145,7 +150,7 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
</http>
</container>
</services>
- """.formatted(executionMode);
+ """.formatted(modelId, modelId, executionMode);
return new VespaModelCreatorWithMockPkg(null, xml).create(builder);
}
@@ -155,4 +160,8 @@ public class RestartOnDeployForOnnxModelChangesValidatorTest {
.properties((new TestProperties()).setRestartOnDeployForOnnxModelChanges(true));
}
+ private static void assertStartsWith(String expected, List<ConfigChangeAction> result) {
+ assertTrue(result.get(0).getMessage().startsWith(expected));
+ }
+
}