summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Harald Musum <musum@yahooinc.com> 2023-12-08 14:18:39 +0100
committer GitHub <noreply@github.com> 2023-12-08 14:18:39 +0100
commit 39032146a75f003bffb14459fee69bf8b9436fac (patch)
tree 6da847fe9cf57dee4746c0df806acb20239561e5
parent 8d9c697019bc52ce02001048eb263077a7c4b98e (diff)
parent 96cba97a152b9c67da0a5860e920ab5a39887a94 (diff)
Merge pull request #29591 from vespa-engine/hmusum/validate-onnx-model-changes
Add validator that checks if restart is needed due to Onnx model changes
-rw-r--r-- config-model-api/abi-spec.json 33
-rw-r--r-- config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java 1
-rw-r--r-- config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java 11
-rw-r--r-- config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java 7
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java 2
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java 4
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java 67
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java 14
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java 4
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java 2
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java 2
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java 4
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java 158
-rw-r--r-- configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java 3
14 files changed, 298 insertions, 14 deletions
diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json
index 8f5d0d37c21..ba483fb0421 100644
--- a/config-model-api/abi-spec.json
+++ b/config-model-api/abi-spec.json
@@ -1289,7 +1289,8 @@
"public boolean usePerDocumentThrottledDeleteBucket()",
"public boolean alwaysMarkPhraseExpensive()",
"public boolean createPostinglistWhenNonStrict()",
- "public boolean useEstimateForFetchPostings()"
+ "public boolean useEstimateForFetchPostings()",
+ "public boolean restartOnDeployWhenOnnxModelChanges()"
],
"fields" : [ ]
},
@@ -1457,7 +1458,10 @@
"public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile)",
"public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)",
"public abstract void registerModel(java.net.URI)",
- "public abstract void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)"
+ "public abstract void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)",
+ "public abstract java.util.Map models()",
+ "public abstract void setRestartOnDeploy()",
+ "public abstract boolean restartOnDeploy()"
],
"fields" : [ ]
},
@@ -1477,7 +1481,30 @@
"public void registerModel(com.yahoo.config.application.api.ApplicationFile)",
"public void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)",
"public void registerModel(java.net.URI)",
- "public void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)"
+ "public void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)",
+ "public java.util.Map models()",
+ "public void setRestartOnDeploy()",
+ "public boolean restartOnDeploy()"
+ ],
+ "fields" : [ ]
+ },
+ "com.yahoo.config.model.api.OnnxModelCost$ModelInfo" : {
+ "superClass" : "java.lang.Record",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final",
+ "record"
+ ],
+ "methods" : [
+ "public void <init>(java.lang.String, long, long, java.util.Optional)",
+ "public final java.lang.String toString()",
+ "public final int hashCode()",
+ "public final boolean equals(java.lang.Object)",
+ "public java.lang.String modelId()",
+ "public long estimatedCost()",
+ "public long hash()",
+ "public java.util.Optional onnxModelOptions()"
],
"fields" : [ ]
},
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
index f34f63a0cfc..e5cc13719c1 100644
--- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
@@ -118,6 +118,7 @@ public interface ModelContext {
@ModelFeatureFlag(owners = {"baldersheim"}) default boolean alwaysMarkPhraseExpensive() { return false; }
@ModelFeatureFlag(owners = {"baldersheim"}) default boolean createPostinglistWhenNonStrict() { return true; }
@ModelFeatureFlag(owners = {"baldersheim"}) default boolean useEstimateForFetchPostings() { return false; }
+ @ModelFeatureFlag(owners = {"hmusum"}) default boolean restartOnDeployWhenOnnxModelChanges() { return false; }
}
/** Warning: As elsewhere in this package, do not make backwards incompatible changes that will break old config models! */
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
index b98667457e4..c13ce4def09 100644
--- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java
@@ -6,11 +6,12 @@ import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.provision.ApplicationId;
import java.net.URI;
+import java.util.Map;
+import java.util.Optional;
/**
* @author bjorncs
*/
-// TODO: Rename
public interface OnnxModelCost {
Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId);
@@ -21,8 +22,13 @@ public interface OnnxModelCost {
void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions);
void registerModel(URI uri);
void registerModel(URI uri, OnnxModelOptions onnxModelOptions);
+ Map<String, ModelInfo> models();
+ void setRestartOnDeploy();
+ boolean restartOnDeploy();
}
+ record ModelInfo(String modelId, long estimatedCost, long hash, Optional<OnnxModelOptions> onnxModelOptions) {}
+
static OnnxModelCost disabled() { return new DisabledOnnxModelCost(); }
class DisabledOnnxModelCost implements OnnxModelCost, Calculator {
@@ -32,6 +38,9 @@ public interface OnnxModelCost {
@Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {}
@Override public void registerModel(URI uri) {}
@Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {}
+ @Override public Map<String, ModelInfo> models() { return Map.of(); }
+ @Override public void setRestartOnDeploy() {}
+ @Override public boolean restartOnDeploy() { return false; }
}
}
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java
index cd4998bb912..2e1c661e09a 100644
--- a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java
+++ b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java
@@ -86,6 +86,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea
private boolean dynamicHeapSize = false;
private long mergingMaxMemoryUsagePerNode = -1;
private boolean usePerDocumentThrottledDeleteBucket = false;
+ private boolean restartOnDeployWhenOnnxModelChanges = false;
@Override public ModelContext.FeatureFlags featureFlags() { return this; }
@Override public boolean multitenant() { return multitenant; }
@@ -146,6 +147,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea
@Override public boolean dynamicHeapSize() { return dynamicHeapSize; }
@Override public long mergingMaxMemoryUsagePerNode() { return mergingMaxMemoryUsagePerNode; }
@Override public boolean usePerDocumentThrottledDeleteBucket() { return usePerDocumentThrottledDeleteBucket; }
+ @Override public boolean restartOnDeployWhenOnnxModelChanges() { return restartOnDeployWhenOnnxModelChanges; }
public TestProperties sharedStringRepoNoReclaim(boolean sharedStringRepoNoReclaim) {
this.sharedStringRepoNoReclaim = sharedStringRepoNoReclaim;
@@ -388,6 +390,11 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea
return this;
}
+ public TestProperties setRestartOnDeployForOnnxModelChanges(boolean enable) {
+ this.restartOnDeployWhenOnnxModelChanges = enable;
+ return this;
+ }
+
public static class Spec implements ConfigServerSpec {
private final String hostName;
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java
index 425a662bb2d..60f325cbe43 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java
@@ -27,7 +27,7 @@ public class JvmHeapSizeValidator extends Validator {
ds.getDeployLogger().log(Level.FINE, "Host resources unknown or percentage overridden with 'allocated-memory'");
return;
}
- long jvmModelCost = appCluster.onnxModelCost().aggregatedModelCostInBytes();
+ long jvmModelCost = appCluster.onnxModelCostCalculator().aggregatedModelCostInBytes();
if (jvmModelCost > 0) {
int percentLimit = 15;
double gbLimit = 0.6;
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java
index d7699bb3180..56277345515 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java
@@ -21,6 +21,7 @@ import com.yahoo.vespa.model.application.validation.change.IndexingModeChangeVal
import com.yahoo.vespa.model.application.validation.change.NodeResourceChangeValidator;
import com.yahoo.vespa.model.application.validation.change.RedundancyIncreaseValidator;
import com.yahoo.vespa.model.application.validation.change.ResourcesReductionValidator;
+import com.yahoo.vespa.model.application.validation.change.RestartOnDeployForOnnxModelChangesValidator;
import com.yahoo.vespa.model.application.validation.change.StartupCommandChangeValidator;
import com.yahoo.vespa.model.application.validation.change.StreamingSearchClusterChangeValidator;
import com.yahoo.vespa.model.application.validation.first.RedundancyValidator;
@@ -122,7 +123,8 @@ public class Validation {
new NodeResourceChangeValidator(),
new RedundancyIncreaseValidator(),
new CertificateRemovalChangeValidator(),
- new RedundancyValidator()
+ new RedundancyValidator(),
+ new RestartOnDeployForOnnxModelChangesValidator(),
};
List<ConfigChangeAction> actions = Arrays.stream(validators)
.flatMap(v -> v.validate(currentModel, nextModel, deployState).stream())
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java
new file mode 100644
index 00000000000..64ada801be2
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java
@@ -0,0 +1,67 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.application.validation.change;
+
+import com.yahoo.config.model.api.ConfigChangeAction;
+import com.yahoo.config.model.api.OnnxModelCost;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.vespa.model.VespaModel;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.logging.Logger;
+
+import static java.util.logging.Level.FINE;
+
+/**
+ * Flags container clusters for restart when an Onnx model they already use changes between the current
+ * and the next model, i.e. when the model's estimated cost, hash or options differ. For each such model
+ * restartOnDeploy is set on the cluster's cost calculator and a restart action is returned.
+ * Does nothing unless the restartOnDeployWhenOnnxModelChanges feature flag is enabled.
+ * @author hmusum
+ */
+public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValidator {
+
+    private static final Logger log = Logger.getLogger(RestartOnDeployForOnnxModelChangesValidator.class.getName());
+
+    /** Returns one restart action per changed Onnx model in clusters that exist in both models. */
+    @Override
+    public List<ConfigChangeAction> validate(VespaModel currentModel, VespaModel nextModel, DeployState deployState) {
+        if ( ! deployState.featureFlags().restartOnDeployWhenOnnxModelChanges()) return List.of();
+        List<ConfigChangeAction> actions = new ArrayList<>();
+        // Compare onnx models used by each cluster and set restart on deploy for cluster if estimated cost,
+        // model hash or model options have changed
+        // TODO: Skip if container has enough memory to handle reload of onnx model (2 models in memory at the same time)
+        for (var cluster : nextModel.getContainerClusters().values()) {
+            var clusterInCurrentModel = currentModel.getContainerClusters().get(cluster.name());
+            if (clusterInCurrentModel == null) continue; // cluster only exists in next model
+
+            log.log(FINE, () -> "Validating cluster '" + cluster.name() + "'");
+            var currentModels = clusterInCurrentModel.onnxModelCostCalculator().models();
+            var nextModels = cluster.onnxModelCostCalculator().models();
+            log.log(FINE, () -> "current models=" + currentModels + ", next models=" + nextModels); // built lazily
+
+            for (var nextModelInfo : nextModels.values()) {
+                var currentModelInfo = currentModels.get(nextModelInfo.modelId());
+                if (currentModelInfo == null) continue; // model not used by the current cluster
+                log.log(FINE, () -> "Checking if " + nextModelInfo + " has changed");
+                modelChanged(nextModelInfo, currentModelInfo).ifPresent(change -> {
+                    String message = "Onnx model '%s' has changed (%s), need to restart services in container cluster '%s'"
+                            .formatted(nextModelInfo.modelId(), change, cluster.name());
+                    cluster.onnxModelCostCalculator().setRestartOnDeploy();
+                    actions.add(new VespaRestartAction(cluster.id(), message));
+                });
+            }
+        }
+        return actions;
+    }
+
+    /** Returns what differs between a and b (first difference found), or empty if nothing differs. */
+    private Optional<String> modelChanged(OnnxModelCost.ModelInfo a, OnnxModelCost.ModelInfo b) {
+        if (a.estimatedCost() != b.estimatedCost()) return Optional.of("estimated cost");
+        if (a.hash() != b.hash()) return Optional.of("model hash");
+        if ( ! a.onnxModelOptions().equals(b.onnxModelOptions())) return Optional.of("model option(s)");
+        return Optional.empty();
+    }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
index e04711a1c56..20b5c687257 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java
@@ -87,7 +87,8 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat
private final Set<FileReference> applicationBundles = new LinkedHashSet<>();
private final Set<String> previousHosts;
- private final OnnxModelCost.Calculator onnxModelCost;
+ private final OnnxModelCost onnxModelCost;
+ private final OnnxModelCost.Calculator onnxModelCostCalculator;
private final DeployLogger logger;
private ContainerModelEvaluation modelEvaluation;
@@ -136,7 +137,8 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat
heapSizePercentageOfAvailableMemory = deployState.featureFlags().heapSizePercentage() > 0
? Math.min(99, deployState.featureFlags().heapSizePercentage())
: defaultHeapSizePercentageOfAvailableMemory;
- onnxModelCost = deployState.onnxModelCost().newCalculator(
+ onnxModelCost = deployState.onnxModelCost();
+ onnxModelCostCalculator = deployState.onnxModelCost().newCalculator(
deployState.getApplicationPackage(), deployState.getProperties().applicationId());
logger = deployState.getDeployLogger();
}
@@ -150,6 +152,8 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat
registerApplicationBundles(deployState);
registerUserConfiguredFiles(deployState);
createEndpoints(deployState);
+ if (onnxModelCostCalculator.restartOnDeploy())
+ setDeferChangesUntilRestart(true);
}
private void registerApplicationBundles(DeployState deployState) {
@@ -215,7 +219,7 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat
double totalMemory = dynamicHeapSize
? getContainers().stream().mapToDouble(c -> c.getHostResource().realResources().memoryGb()).min().orElseThrow()
: getContainers().get(0).getHostResource().realResources().memoryGb();
- double jvmHeapDeductionGb = dynamicHeapSize ? onnxModelCost.aggregatedModelCostInBytes() / (1024D * 1024 * 1024) : 0;
+ double jvmHeapDeductionGb = dynamicHeapSize ? onnxModelCostCalculator.aggregatedModelCostInBytes() / (1024D * 1024 * 1024) : 0;
double availableMemory = Math.max(0, totalMemory - Host.memoryOverheadGb - jvmHeapDeductionGb);
int memoryPercentage = (int) (availableMemory / totalMemory * availableMemoryPercentage);
logger.log(FINE, () -> "cluster id '%s': memoryPercentage=%d, availableMemory=%f, totalMemory=%f, availableMemoryPercentage=%d, jvmHeapDeductionGb=%f"
@@ -381,7 +385,9 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat
@Override
public String name() { return getName(); }
- public OnnxModelCost.Calculator onnxModelCost() { return onnxModelCost; }
+ public OnnxModelCost onnxModelCost() { return onnxModelCost; }
+
+ public OnnxModelCost.Calculator onnxModelCostCalculator() { return onnxModelCostCalculator; }
/** Returns whether the deployment in given deploy state should have endpoints */
private static boolean configureEndpoints(DeployState deployState) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
index 0d350242fd0..102ed926fad 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
@@ -57,8 +57,8 @@ class Model {
void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) {
var resolvedUrl = resolvedUrl().orElse(null);
- if (file != null) c.onnxModelCost().registerModel(file, onnxModelOptions);
- else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl, onnxModelOptions);
+ if (file != null) c.onnxModelCostCalculator().registerModel(file, onnxModelOptions);
+ else if (resolvedUrl != null) c.onnxModelCostCalculator().registerModel(resolvedUrl, onnxModelOptions);
}
String name() { return paramName; }
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java
index 36b6d0fe07a..f58928354db 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java
@@ -107,7 +107,7 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains>
if ( ! owningCluster.getComponentsMap().containsKey(factory.getComponentId())) {
var onnxModels = documentDb.getDerivedConfiguration().getRankProfileList().getOnnxModels();
onnxModels.asMap().forEach(
- (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()), model.onnxModelOptions()));
+ (__, model) -> owningCluster.onnxModelCostCalculator().registerModel(app.getFile(model.getFilePath()), model.onnxModelOptions()));
owningCluster.addComponent(factory);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index 104d19d8953..e4038a5bca6 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -800,7 +800,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
!container.getHostResource().realResources().gpuResources().isZero());
onnxModel.setGpuDevice(gpuDevice, hasGpu);
}
- cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions());
+ cluster.onnxModelCostCalculator().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions());
}
cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models));
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
index 9cadf5cffd8..213cf4bdfcf 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java
@@ -22,6 +22,7 @@ import org.xml.sax.SAXException;
import java.io.IOException;
import java.net.URI;
import java.util.List;
+import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;
@@ -120,6 +121,9 @@ class JvmHeapSizeValidatorTest {
ModelCostDummy(long modelCost) { this.modelCost = modelCost; }
@Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; }
+ @Override public Map<String, ModelInfo> models() { return Map.of(); }
+ @Override public void setRestartOnDeploy() {}
+ @Override public boolean restartOnDeploy() { return false;}
@Override public long aggregatedModelCostInBytes() { return totalCost.get(); }
@Override public void registerModel(ApplicationFile path) {}
@Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java
new file mode 100644
index 00000000000..1845bcf0b52
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java
@@ -0,0 +1,158 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.application.validation.change;
+
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.model.api.ConfigChangeAction;
+import com.yahoo.config.model.api.OnnxModelCost;
+import com.yahoo.config.model.api.OnnxModelOptions;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithMockPkg;
+import org.junit.jupiter.api.Test;
+
+import java.net.URI;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/** Tests that a change in an Onnx model's estimated cost, hash or options yields a restart action.
+ * @author hmusum
+ */
+public class RestartOnDeployForOnnxModelChangesValidatorTest {
+ /** Identical current and next models must not produce any actions. */
+ @Test
+ void validate_no_changes() {
+ VespaModel current = createModel();
+ VespaModel next = createModel();
+ List<ConfigChangeAction> result = validateModel(current, next);
+ assertEquals(0, result.size());
+ }
+ /** A different estimated cost in the next model gives one restart action without a validation id. */
+ @Test
+ void validate_changed_estimated_cost() {
+ VespaModel current = createModel();
+ VespaModel next = createModel(onnxModelCost(123, 0));
+ List<ConfigChangeAction> result = validateModel(current, next);
+ assertEquals(1, result.size());
+ assertTrue(result.get(0).validationId().isEmpty());
+ assertEquals("Onnx model 'https://my/url/model.onnx' has changed (estimated cost), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ }
+ /** A different model hash in the next model gives one restart action. */
+ @Test
+ void validate_changed_hash() {
+ VespaModel current = createModel();
+ VespaModel next = createModel(onnxModelCost(0, 123));
+ List<ConfigChangeAction> result = validateModel(current, next);
+ assertEquals(1, result.size());
+ assertEquals("Onnx model 'https://my/url/model.onnx' has changed (model hash), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ }
+ /** A different execution mode ("sequential" instead of "parallel") gives one restart action. */
+ @Test
+ void validate_changed_option() {
+ VespaModel current = createModel();
+ VespaModel next = createModel(onnxModelCost(0, 0), "sequential");
+ List<ConfigChangeAction> result = validateModel(current, next);
+ assertEquals(1, result.size());
+ assertEquals("Onnx model 'https://my/url/model.onnx' has changed (model option(s)), need to restart services in container cluster 'cluster1'", result.get(0).getMessage());
+ }
+ /** Runs the validator on the two models with a deploy state that has the feature flag enabled. */
+ private static List<ConfigChangeAction> validateModel(VespaModel current, VespaModel next) {
+ return new RestartOnDeployForOnnxModelChangesValidator().validate(current, next, deployStateBuilder().build());
+ }
+ /** Model cost with estimated cost 0 and hash 0, the baseline used for the current model in the tests. */
+ private static OnnxModelCost onnxModelCost() {
+ return onnxModelCost(0, 0);
+ }
+ /** Returns an OnnxModelCost whose calculators record models registered by URI with options, using the given cost and hash. */
+ private static OnnxModelCost onnxModelCost(long estimatedCost, long hash) {
+ return (appPkg, applicationId) -> new OnnxModelCost.Calculator() {
+
+ private final Map<String, OnnxModelCost.ModelInfo> models = new HashMap<>();
+ private boolean restartOnDeploy = false;
+
+ @Override
+ public long aggregatedModelCostInBytes() { return estimatedCost; }
+
+ @Override
+ public void registerModel(ApplicationFile path) {}
+
+ @Override
+ public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {}
+
+ @Override
+ public void registerModel(URI uri) {}
+
+ @Override
+ public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {
+ models.put(uri.toString(), new OnnxModelCost.ModelInfo(uri.toString(), estimatedCost, hash, Optional.ofNullable(onnxModelOptions)));
+ }
+
+ @Override
+ public Map<String, OnnxModelCost.ModelInfo> models() { return models; }
+
+ @Override
+ public void setRestartOnDeploy() { restartOnDeploy = true; }
+
+ @Override
+ public boolean restartOnDeploy() { return restartOnDeploy; }
+ };
+ }
+ /** Creates a model with the baseline model cost and execution mode "parallel". */
+ private static VespaModel createModel() {
+ return createModel(onnxModelCost());
+ }
+
+ private static VespaModel createModel(OnnxModelCost onnxModelCost) {
+ return createModel(onnxModelCost, "parallel");
+ }
+
+ private static VespaModel createModel(OnnxModelCost onnxModelCost, String executionMode) {
+ DeployState.Builder builder = deployStateBuilder();
+ builder.onnxModelCost(onnxModelCost);
+ return createModel(builder, executionMode);
+ }
+ /** Builds a model from the XML below: cluster1 has an embedder with a model given by URL, cluster2 has no components. */
+ private static VespaModel createModel(DeployState.Builder builder, String executionMode) {
+ String xml = """
+ <services version='1.0'>
+ <container id='cluster1' version='1.0'>
+ <http>
+ <server id='server1' port='8080'/>
+ </http>
+ <component id="hf-embedder" type="hugging-face-embedder">
+ <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/>
+ <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/>
+ <max-tokens>1024</max-tokens>
+ <transformer-input-ids>my_input_ids</transformer-input-ids>
+ <transformer-attention-mask>my_attention_mask</transformer-attention-mask>
+ <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
+ <transformer-output>my_output</transformer-output>
+ <normalize>true</normalize>
+ <onnx-execution-mode>%s</onnx-execution-mode>
+ <onnx-intraop-threads>10</onnx-intraop-threads>
+ <onnx-interop-threads>8</onnx-interop-threads>
+ <pooling-strategy>mean</pooling-strategy>
+ </component>
+ </container>
+ <container id='cluster2' version='1.0'>
+ <http>
+ <server id='server1' port='8081'/>
+ </http>
+ </container>
+ </services>
+ """.formatted(executionMode);
+
+ return new VespaModelCreatorWithMockPkg(null, xml).create(builder);
+ }
+ /** Deploy state builder with the restart-on-deploy feature flag for Onnx model changes enabled. */
+ private static DeployState.Builder deployStateBuilder() {
+ return new DeployState.Builder()
+ .properties((new TestProperties()).setRestartOnDeployForOnnxModelChanges(true));
+ }
+
+}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
index 2f126cd84d3..e5ef36e07d8 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
@@ -210,6 +210,7 @@ public class ModelContextImpl implements ModelContext {
private final int searchHandlerThreadpool;
private final long mergingMaxMemoryUsagePerNode;
private final boolean usePerDocumentThrottledDeleteBucket;
+ private final boolean restartOnDeployWhenOnnxModelChanges;
public FeatureFlags(FlagSource source, ApplicationId appId, Version version) {
this.defaultTermwiseLimit = flagValue(source, appId, version, Flags.DEFAULT_TERM_WISE_LIMIT);
@@ -256,6 +257,7 @@ public class ModelContextImpl implements ModelContext {
this.alwaysMarkPhraseExpensive = flagValue(source, appId, version, Flags.ALWAYS_MARK_PHRASE_EXPENSIVE);
this.createPostinglistWhenNonStrict = flagValue(source, appId, version, Flags.CREATE_POSTINGLIST_WHEN_NON_STRICT);
this.useEstimateForFetchPostings = flagValue(source, appId, version, Flags.USE_ESTIMATE_FOR_FETCH_POSTINGS);
+ this.restartOnDeployWhenOnnxModelChanges = flagValue(source, appId, version, Flags.RESTART_ON_DEPLOY_WHEN_ONNX_MODEL_CHANGES);
}
@Override public int heapSizePercentage() { return heapPercentage; }
@@ -310,6 +312,7 @@ public class ModelContextImpl implements ModelContext {
@Override public int searchHandlerThreadpool() { return searchHandlerThreadpool; }
@Override public long mergingMaxMemoryUsagePerNode() { return mergingMaxMemoryUsagePerNode; }
@Override public boolean usePerDocumentThrottledDeleteBucket() { return usePerDocumentThrottledDeleteBucket; }
+ @Override public boolean restartOnDeployWhenOnnxModelChanges() { return restartOnDeployWhenOnnxModelChanges; }
private static <V> V flagValue(FlagSource source, ApplicationId appId, Version vespaVersion, UnboundFlag<? extends V, ?, ?> flag) {
return flag.bindTo(source)