From 96cba97a152b9c67da0a5860e920ab5a39887a94 Mon Sep 17 00:00:00 2001 From: Harald Musum Date: Fri, 8 Dec 2023 12:07:59 +0100 Subject: Add validator that checks if restart is needed due to Onnx model changes Validates changes and creates a restart action if needed and makes sure configs for cluster are marked as restartOnDeploy --- config-model-api/abi-spec.json | 33 ++++- .../com/yahoo/config/model/api/ModelContext.java | 1 + .../com/yahoo/config/model/api/OnnxModelCost.java | 11 +- .../yahoo/config/model/deploy/TestProperties.java | 7 + .../validation/JvmHeapSizeValidator.java | 2 +- .../model/application/validation/Validation.java | 4 +- ...estartOnDeployForOnnxModelChangesValidator.java | 67 +++++++++ .../container/ApplicationContainerCluster.java | 14 +- .../vespa/model/container/component/Model.java | 4 +- .../model/container/search/ContainerSearch.java | 2 +- .../model/container/xml/ContainerModelBuilder.java | 2 +- .../validation/JvmHeapSizeValidatorTest.java | 4 + ...rtOnDeployForOnnxModelChangesValidatorTest.java | 158 +++++++++++++++++++++ .../config/server/deploy/ModelContextImpl.java | 3 + 14 files changed, 298 insertions(+), 14 deletions(-) create mode 100644 config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java create mode 100644 config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json index 8f5d0d37c21..ba483fb0421 100644 --- a/config-model-api/abi-spec.json +++ b/config-model-api/abi-spec.json @@ -1289,7 +1289,8 @@ "public boolean usePerDocumentThrottledDeleteBucket()", "public boolean alwaysMarkPhraseExpensive()", "public boolean createPostinglistWhenNonStrict()", - "public boolean useEstimateForFetchPostings()" + "public boolean useEstimateForFetchPostings()", + "public boolean restartOnDeployWhenOnnxModelChanges()" ], "fields" : [ ] }, @@ -1457,7 +1458,10 @@ "public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile)", "public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)", "public abstract void registerModel(java.net.URI)", - "public abstract void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)" + "public abstract void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)", + "public abstract java.util.Map models()", + "public abstract void setRestartOnDeploy()", + "public abstract boolean restartOnDeploy()" ], "fields" : [ ] }, @@ -1477,7 +1481,30 @@ "public void registerModel(com.yahoo.config.application.api.ApplicationFile)", "public void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)", "public void registerModel(java.net.URI)", - "public void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)" + "public void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)", + "public java.util.Map models()", + "public void setRestartOnDeploy()", + "public boolean restartOnDeploy()" + ], + "fields" : [ ] + }, + "com.yahoo.config.model.api.OnnxModelCost$ModelInfo" : { + "superClass" : "java.lang.Record", + "interfaces" : [ ], + "attributes" : [ + "public", + "final", + "record" + ], + "methods" : [ + "public void (java.lang.String, long, long, java.util.Optional)", + "public final java.lang.String toString()", + "public final int hashCode()", + "public final boolean equals(java.lang.Object)", + "public java.lang.String modelId()", + "public long estimatedCost()", + "public long hash()", + "public java.util.Optional onnxModelOptions()" ], "fields" : [ ] }, diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java index f34f63a0cfc..e5cc13719c1 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java @@ -118,6 +118,7 @@ public interface ModelContext { @ModelFeatureFlag(owners = {"baldersheim"}) default boolean alwaysMarkPhraseExpensive() { return false; } @ModelFeatureFlag(owners = {"baldersheim"}) default boolean createPostinglistWhenNonStrict() { return true; } @ModelFeatureFlag(owners = {"baldersheim"}) default boolean useEstimateForFetchPostings() { return false; } + @ModelFeatureFlag(owners = {"hmusum"}) default boolean restartOnDeployWhenOnnxModelChanges() { return false; } } /** Warning: As elsewhere in this package, do not make backwards incompatible changes that will break old config models! */ diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java index b98667457e4..c13ce4def09 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java @@ -6,11 +6,12 @@ import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.provision.ApplicationId; import java.net.URI; +import java.util.Map; +import java.util.Optional; /** * @author bjorncs */ -// TODO: Rename public interface OnnxModelCost { Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId); @@ -21,8 +22,13 @@ public interface OnnxModelCost { void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions); void registerModel(URI uri); void registerModel(URI uri, OnnxModelOptions onnxModelOptions); + Map models(); + void setRestartOnDeploy(); + boolean restartOnDeploy(); } + record ModelInfo(String modelId, long estimatedCost, long hash, Optional onnxModelOptions) {} + static OnnxModelCost disabled() { return new DisabledOnnxModelCost(); } class DisabledOnnxModelCost implements OnnxModelCost, Calculator { @@ -32,6 +38,9 @@ public interface OnnxModelCost { @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {} @Override public void registerModel(URI uri) {} @Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {} + @Override public Map models() { return Map.of(); } + @Override public void setRestartOnDeploy() {} + @Override public boolean restartOnDeploy() { return false; } } } diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java index cd4998bb912..2e1c661e09a 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java @@ -86,6 +86,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea private boolean dynamicHeapSize = false; private long mergingMaxMemoryUsagePerNode = -1; private boolean usePerDocumentThrottledDeleteBucket = false; + private boolean restartOnDeployWhenOnnxModelChanges = false; @Override public ModelContext.FeatureFlags featureFlags() { return this; } @Override public boolean multitenant() { return multitenant; } @@ -146,6 +147,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea @Override public boolean dynamicHeapSize() { return dynamicHeapSize; } @Override public long mergingMaxMemoryUsagePerNode() { return mergingMaxMemoryUsagePerNode; } @Override public boolean usePerDocumentThrottledDeleteBucket() { return usePerDocumentThrottledDeleteBucket; } + @Override public boolean restartOnDeployWhenOnnxModelChanges() { return restartOnDeployWhenOnnxModelChanges; } public TestProperties sharedStringRepoNoReclaim(boolean sharedStringRepoNoReclaim) { this.sharedStringRepoNoReclaim = sharedStringRepoNoReclaim; @@ -388,6 +390,11 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea return this; } + public TestProperties setRestartOnDeployForOnnxModelChanges(boolean enable) { + this.restartOnDeployWhenOnnxModelChanges = enable; + return this; + } + public static class Spec implements ConfigServerSpec { private final String hostName; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java index 425a662bb2d..60f325cbe43 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java @@ -27,7 +27,7 @@ public class JvmHeapSizeValidator extends Validator { ds.getDeployLogger().log(Level.FINE, "Host resources unknown or percentage overridden with 'allocated-memory'"); return; } - long jvmModelCost = appCluster.onnxModelCost().aggregatedModelCostInBytes(); + long jvmModelCost = appCluster.onnxModelCostCalculator().aggregatedModelCostInBytes(); if (jvmModelCost > 0) { int percentLimit = 15; double gbLimit = 0.6; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java index d7699bb3180..56277345515 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java @@ -21,6 +21,7 @@ import com.yahoo.vespa.model.application.validation.change.IndexingModeChangeVal import com.yahoo.vespa.model.application.validation.change.NodeResourceChangeValidator; import com.yahoo.vespa.model.application.validation.change.RedundancyIncreaseValidator; import com.yahoo.vespa.model.application.validation.change.ResourcesReductionValidator; +import com.yahoo.vespa.model.application.validation.change.RestartOnDeployForOnnxModelChangesValidator; import com.yahoo.vespa.model.application.validation.change.StartupCommandChangeValidator; import com.yahoo.vespa.model.application.validation.change.StreamingSearchClusterChangeValidator; import com.yahoo.vespa.model.application.validation.first.RedundancyValidator; @@ -122,7 +123,8 @@ public class Validation { new NodeResourceChangeValidator(), new RedundancyIncreaseValidator(), new CertificateRemovalChangeValidator(), - new RedundancyValidator() + new RedundancyValidator(), + new RestartOnDeployForOnnxModelChangesValidator(), }; List actions = Arrays.stream(validators) .flatMap(v -> v.validate(currentModel, nextModel, deployState).stream()) diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java new file mode 100644 index 00000000000..64ada801be2 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidator.java @@ -0,0 +1,67 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.application.validation.change; + +import com.yahoo.config.model.api.ConfigChangeAction; +import com.yahoo.config.model.api.OnnxModelCost; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.vespa.model.VespaModel; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.logging.Logger; + +import static java.util.logging.Level.FINE; + +/** + * If Onnx models change in a way that requires restart of containers in + * a container cluster this validator will make sure that restartOnDeploy is set for + * configs for this cluster. + * + * @author hmusum + */ +public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValidator { + + private static final Logger log = Logger.getLogger(RestartOnDeployForOnnxModelChangesValidator.class.getSimpleName()); + + @Override + public List validate(VespaModel currentModel, VespaModel nextModel, DeployState deployState) { + if ( ! deployState.featureFlags().restartOnDeployWhenOnnxModelChanges()) return List.of(); + List actions = new ArrayList<>(); + + // Compare onnx models used by each cluster and set restart on deploy for cluster if estimated cost, + // model hash or model options have changed + // TODO: Skip if container has enough memory to handle reload of onnx model (2 models in memory at the same time) + + for (var cluster : nextModel.getContainerClusters().values()) { + var clusterInCurrentModel = currentModel.getContainerClusters().get(cluster.getName()); + if (clusterInCurrentModel == null) continue; + + log.log(FINE, "Validating cluster '" + cluster.name() + "'"); + var currentModels = clusterInCurrentModel.onnxModelCostCalculator().models(); + var nextModels = cluster.onnxModelCostCalculator().models(); + log.log(FINE, "current models=" + currentModels + ", next models=" + nextModels); + + for (var nextModelInfo : nextModels.values()) { + if (!currentModels.containsKey(nextModelInfo.modelId())) continue; + + log.log(FINE, "Checking if " + nextModelInfo + " has changed"); + modelChanged(nextModelInfo, currentModels.get(nextModelInfo.modelId())).ifPresent(change -> { + String message = "Onnx model '%s' has changed (%s), need to restart services in container cluster '%s'" + .formatted(nextModelInfo.modelId(), change, cluster.name()); + cluster.onnxModelCostCalculator().setRestartOnDeploy(); + actions.add(new VespaRestartAction(cluster.id(), message)); + }); + } + } + return actions; + } + + private Optional modelChanged(OnnxModelCost.ModelInfo a, OnnxModelCost.ModelInfo b) { + if (a.estimatedCost() != b.estimatedCost()) return Optional.of("estimated cost"); + if (a.hash() != b.hash()) return Optional.of("model hash"); + if (! a.onnxModelOptions().equals(b.onnxModelOptions())) return Optional.of("model option(s)"); + return Optional.empty(); + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java index e04711a1c56..20b5c687257 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java @@ -87,7 +87,8 @@ public final class ApplicationContainerCluster extends ContainerCluster applicationBundles = new LinkedHashSet<>(); private final Set previousHosts; - private final OnnxModelCost.Calculator onnxModelCost; + private final OnnxModelCost onnxModelCost; + private final OnnxModelCost.Calculator onnxModelCostCalculator; private final DeployLogger logger; private ContainerModelEvaluation modelEvaluation; @@ -136,7 +137,8 @@ public final class ApplicationContainerCluster extends ContainerCluster 0 ? Math.min(99, deployState.featureFlags().heapSizePercentage()) : defaultHeapSizePercentageOfAvailableMemory; - onnxModelCost = deployState.onnxModelCost().newCalculator( + onnxModelCost = deployState.onnxModelCost(); + onnxModelCostCalculator = deployState.onnxModelCost().newCalculator( deployState.getApplicationPackage(), deployState.getProperties().applicationId()); logger = deployState.getDeployLogger(); } @@ -150,6 +152,8 @@ public final class ApplicationContainerCluster extends ContainerCluster c.getHostResource().realResources().memoryGb()).min().orElseThrow() : getContainers().get(0).getHostResource().realResources().memoryGb(); - double jvmHeapDeductionGb = dynamicHeapSize ? onnxModelCost.aggregatedModelCostInBytes() / (1024D * 1024 * 1024) : 0; + double jvmHeapDeductionGb = dynamicHeapSize ? onnxModelCostCalculator.aggregatedModelCostInBytes() / (1024D * 1024 * 1024) : 0; double availableMemory = Math.max(0, totalMemory - Host.memoryOverheadGb - jvmHeapDeductionGb); int memoryPercentage = (int) (availableMemory / totalMemory * availableMemoryPercentage); logger.log(FINE, () -> "cluster id '%s': memoryPercentage=%d, availableMemory=%f, totalMemory=%f, availableMemoryPercentage=%d, jvmHeapDeductionGb=%f" @@ -381,7 +385,9 @@ public final class ApplicationContainerCluster extends ContainerCluster if ( ! owningCluster.getComponentsMap().containsKey(factory.getComponentId())) { var onnxModels = documentDb.getDerivedConfiguration().getRankProfileList().getOnnxModels(); onnxModels.asMap().forEach( - (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()), model.onnxModelOptions())); + (__, model) -> owningCluster.onnxModelCostCalculator().registerModel(app.getFile(model.getFilePath()), model.onnxModelOptions())); owningCluster.addComponent(factory); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 104d19d8953..e4038a5bca6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -800,7 +800,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder { !container.getHostResource().realResources().gpuResources().isZero()); onnxModel.setGpuDevice(gpuDevice, hasGpu); } - cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions()); + cluster.onnxModelCostCalculator().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions()); } cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models)); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java index 9cadf5cffd8..213cf4bdfcf 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java @@ -22,6 +22,7 @@ import org.xml.sax.SAXException; import java.io.IOException; import java.net.URI; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.concurrent.atomic.AtomicLong; @@ -120,6 +121,9 @@ class JvmHeapSizeValidatorTest { ModelCostDummy(long modelCost) { this.modelCost = modelCost; } @Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; } + @Override public Map models() { return Map.of(); } + @Override public void setRestartOnDeploy() {} + @Override public boolean restartOnDeploy() { return false;} @Override public long aggregatedModelCostInBytes() { return totalCost.get(); } @Override public void registerModel(ApplicationFile path) {} @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java new file mode 100644 index 00000000000..1845bcf0b52 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/RestartOnDeployForOnnxModelChangesValidatorTest.java @@ -0,0 +1,158 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.application.validation.change; + +import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.model.api.ConfigChangeAction; +import com.yahoo.config.model.api.OnnxModelCost; +import com.yahoo.config.model.api.OnnxModelOptions; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithMockPkg; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * @author hmusum + */ +public class RestartOnDeployForOnnxModelChangesValidatorTest { + + @Test + void validate_no_changes() { + VespaModel current = createModel(); + VespaModel next = createModel(); + List result = validateModel(current, next); + assertEquals(0, result.size()); + } + + @Test + void validate_changed_estimated_cost() { + VespaModel current = createModel(); + VespaModel next = createModel(onnxModelCost(123, 0)); + List result = validateModel(current, next); + assertEquals(1, result.size()); + assertTrue(result.get(0).validationId().isEmpty()); + assertEquals("Onnx model 'https://my/url/model.onnx' has changed (estimated cost), need to restart services in container cluster 'cluster1'", result.get(0).getMessage()); + } + + @Test + void validate_changed_hash() { + VespaModel current = createModel(); + VespaModel next = createModel(onnxModelCost(0, 123)); + List result = validateModel(current, next); + assertEquals(1, result.size()); + assertEquals("Onnx model 'https://my/url/model.onnx' has changed (model hash), need to restart services in container cluster 'cluster1'", result.get(0).getMessage()); + } + + @Test + void validate_changed_option() { + VespaModel current = createModel(); + VespaModel next = createModel(onnxModelCost(0, 0), "sequential"); + List result = validateModel(current, next); + assertEquals(1, result.size()); + assertEquals("Onnx model 'https://my/url/model.onnx' has changed (model option(s)), need to restart services in container cluster 'cluster1'", result.get(0).getMessage()); + } + + private static List validateModel(VespaModel current, VespaModel next) { + return new RestartOnDeployForOnnxModelChangesValidator().validate(current, next, deployStateBuilder().build()); + } + + private static OnnxModelCost onnxModelCost() { + return onnxModelCost(0, 0); + } + + private static OnnxModelCost onnxModelCost(long estimatedCost, long hash) { + return (appPkg, applicationId) -> new OnnxModelCost.Calculator() { + + private final Map models = new HashMap<>(); + private boolean restartOnDeploy = false; + + @Override + public long aggregatedModelCostInBytes() { return estimatedCost; } + + @Override + public void registerModel(ApplicationFile path) {} + + @Override + public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {} + + @Override + public void registerModel(URI uri) {} + + @Override + public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) { + models.put(uri.toString(), new OnnxModelCost.ModelInfo(uri.toString(), estimatedCost, hash, Optional.ofNullable(onnxModelOptions))); + } + + @Override + public Map models() { return models; } + + @Override + public void setRestartOnDeploy() { restartOnDeploy = true; } + + @Override + public boolean restartOnDeploy() { return restartOnDeploy; } + }; + } + + private static VespaModel createModel() { + return createModel(onnxModelCost()); + } + + private static VespaModel createModel(OnnxModelCost onnxModelCost) { + return createModel(onnxModelCost, "parallel"); + } + + private static VespaModel createModel(OnnxModelCost onnxModelCost, String executionMode) { + DeployState.Builder builder = deployStateBuilder(); + builder.onnxModelCost(onnxModelCost); + return createModel(builder, executionMode); + } + + private static VespaModel createModel(DeployState.Builder builder, String executionMode) { + String xml = """ + + + + + + + + + 1024 + my_input_ids + my_attention_mask + my_token_type_ids + my_output + true + %s + 10 + 8 + mean + + + + + + + + + """.formatted(executionMode); + + return new VespaModelCreatorWithMockPkg(null, xml).create(builder); + } + + private static DeployState.Builder deployStateBuilder() { + return new DeployState.Builder() + .properties((new TestProperties()).setRestartOnDeployForOnnxModelChanges(true)); + } + +} diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index 2f126cd84d3..e5ef36e07d8 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -210,6 +210,7 @@ public class ModelContextImpl implements ModelContext { private final int searchHandlerThreadpool; private final long mergingMaxMemoryUsagePerNode; private final boolean usePerDocumentThrottledDeleteBucket; + private final boolean restartOnDeployWhenOnnxModelChanges; public FeatureFlags(FlagSource source, ApplicationId appId, Version version) { this.defaultTermwiseLimit = flagValue(source, appId, version, Flags.DEFAULT_TERM_WISE_LIMIT); @@ -256,6 +257,7 @@ public class ModelContextImpl implements ModelContext { this.alwaysMarkPhraseExpensive = flagValue(source, appId, version, Flags.ALWAYS_MARK_PHRASE_EXPENSIVE); this.createPostinglistWhenNonStrict = flagValue(source, appId, version, Flags.CREATE_POSTINGLIST_WHEN_NON_STRICT); this.useEstimateForFetchPostings = flagValue(source, appId, version, Flags.USE_ESTIMATE_FOR_FETCH_POSTINGS); + this.restartOnDeployWhenOnnxModelChanges = flagValue(source, appId, version, Flags.RESTART_ON_DEPLOY_WHEN_ONNX_MODEL_CHANGES); } @Override public int heapSizePercentage() { return heapPercentage; } @@ -310,6 +312,7 @@ public class ModelContextImpl implements ModelContext { @Override public int searchHandlerThreadpool() { return searchHandlerThreadpool; } @Override public long mergingMaxMemoryUsagePerNode() { return mergingMaxMemoryUsagePerNode; } @Override public boolean usePerDocumentThrottledDeleteBucket() { return usePerDocumentThrottledDeleteBucket; } + @Override public boolean restartOnDeployWhenOnnxModelChanges() { return restartOnDeployWhenOnnxModelChanges; } private static V flagValue(FlagSource source, ApplicationId appId, Version vespaVersion, UnboundFlag flag) { return flag.bindTo(source) -- cgit v1.2.3