diff options
author | Harald Musum <musum@yahooinc.com> | 2023-11-20 15:17:28 +0100 |
---|---|---|
committer | Harald Musum <musum@yahooinc.com> | 2023-11-20 15:17:28 +0100 |
commit | 9d28a47b003f5498bc59bfd10017dd55fc7ab6e0 (patch) | |
tree | add8988cb32377e8395c36d2fb46cc9259b5dd63 /config-model/src | |
parent | e8c0a04b67b632ea3f98327d8f39cd0293ad8581 (diff) |
Register model with onnx model options
Diffstat (limited to 'config-model/src')
9 files changed, 25 insertions, 57 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 867ffdb3960..9456baafd57 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -1,9 +1,9 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; -import com.yahoo.vespa.model.container.component.OnnxModelOptions; import com.yahoo.vespa.model.ml.OnnxModelInfo; import java.util.Collections; @@ -171,4 +171,6 @@ public class OnnxModel extends DistributableResource implements Cloneable { return onnxModelOptions.gpuDevice(); } + public OnnxModelOptions onnxModelOptions() { return onnxModelOptions; } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index ea3caadc23a..67fb720b8c0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -47,7 +48,7 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } @Override diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index cbae50b400c..d22e6afc3d1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -55,7 +56,7 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); transformerOutput = getChildValue(xml, "transformer-output").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } private static ModelReference resolveDefaultVocab(Model model, DeployState state) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index d1bd0dce000..d98c72ab3a4 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -48,7 +49,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm transformerOutput = getChildValue(xml, "transformer-output").orElse(null); normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } private static ModelReference resolveDefaultVocab(Model model, DeployState state) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java index c5daf23d6f8..0d350242fd0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java @@ -4,6 +4,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.builder.xml.XmlHelper; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.path.Path; @@ -54,10 +55,10 @@ class Model { return new Model(ds, model.getTagName(), modelId, url, path); } - void registerOnnxModelCost(ApplicationContainerCluster c) { + void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) { var resolvedUrl = resolvedUrl().orElse(null); - if (file != null) c.onnxModelCost().registerModel(file); - else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl); + if (file != null) c.onnxModelCost().registerModel(file, onnxModelOptions); + else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl, onnxModelOptions); } String name() { return paramName; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java deleted file mode 100644 index 6347f0dc427..00000000000 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.model.container.component; - -import java.util.Optional; - -/** - * Onnx model options that are relevant when deciding if an Onnx model needs to be reloaded. If any of the - * values in this class change, reload is needed. - * - * @author hmusum - */ -public record OnnxModelOptions(Optional<String> executionMode, Optional<Integer> interOpThreads, - Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice) { - - public static OnnxModelOptions empty() { - return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); - } - - public OnnxModelOptions withExecutionMode(String executionMode) { - return new OnnxModelOptions(Optional.ofNullable(executionMode), interOpThreads, intraOpThreads, gpuDevice); - } - - public OnnxModelOptions withInterOpThreads(Integer interOpThreads) { - return new OnnxModelOptions(executionMode, Optional.ofNullable(interOpThreads), intraOpThreads, gpuDevice); - } - - public OnnxModelOptions withIntraOpThreads(Integer intraOpThreads) { - return new OnnxModelOptions(executionMode, interOpThreads, Optional.ofNullable(intraOpThreads), gpuDevice); - } - - public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) { - return new OnnxModelOptions(executionMode, interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice)); - } - - public record GpuDevice(int deviceNumber, boolean required) { - public GpuDevice { - if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber); - } - - public GpuDevice(int deviceNumber) { - this(deviceNumber, false); - } - } - -} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java index d86d117f1d2..31468c05b99 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java @@ -52,11 +52,11 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains> private final List<SearchCluster> searchClusters = new LinkedList<>(); private final Collection<String> schemasWithGlobalPhase; private final boolean globalPhase; + private final ApplicationPackage app; private QueryProfiles queryProfiles; private SemanticRules semanticRules; private PageTemplates pageTemplates; - private ApplicationPackage app; public ContainerSearch(DeployState deployState, ApplicationContainerCluster cluster, SearchChains chains) { super(chains); @@ -102,7 +102,7 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains> if ( ! owningCluster.getComponentsMap().containsKey(factory.getComponentId())) { var onnxModels = documentDb.getDerivedConfiguration().getRankProfileList().getOnnxModels(); onnxModels.asMap().forEach( - (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()))); + (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()), model.onnxModelOptions())); owningCluster.addComponent(factory); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 18020f5df5d..5ffd34c6557 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -800,7 +800,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { !container.getHostResource().realResources().gpuResources().isZero()); onnxModel.setGpuDevice(gpuDevice, hasGpu); } - cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath())); + cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions()); } cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models)); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java index 8531aff3b1a..9cadf5cffd8 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java @@ -1,14 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.vespa.model.application.validation; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.ApplicationClusterEndpoint; import com.yahoo.config.model.api.ContainerEndpoint; import com.yahoo.config.model.api.OnnxModelCost; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.provision.InMemoryProvisioner; @@ -123,12 +122,20 @@ class JvmHeapSizeValidatorTest { @Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; } @Override public long aggregatedModelCostInBytes() { return totalCost.get(); } @Override public void registerModel(ApplicationFile path) {} + @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {} @Override public void registerModel(URI uri) { assertEquals("https://my/url/model.onnx", uri.toString()); totalCost.addAndGet(modelCost); } + + @Override + public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) { + assertEquals("https://my/url/model.onnx", uri.toString()); + totalCost.addAndGet(modelCost); + } + } } |