From 9d28a47b003f5498bc59bfd10017dd55fc7ab6e0 Mon Sep 17 00:00:00 2001 From: Harald Musum Date: Mon, 20 Nov 2023 15:17:28 +0100 Subject: Register model with onnx model options --- .../src/main/java/com/yahoo/schema/OnnxModel.java | 4 +- .../model/container/component/BertEmbedder.java | 3 +- .../model/container/component/ColBertEmbedder.java | 3 +- .../container/component/HuggingFaceEmbedder.java | 3 +- .../vespa/model/container/component/Model.java | 7 ++-- .../container/component/OnnxModelOptions.java | 45 ---------------------- .../model/container/search/ContainerSearch.java | 4 +- .../model/container/xml/ContainerModelBuilder.java | 2 +- 8 files changed, 16 insertions(+), 55 deletions(-) delete mode 100644 config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java (limited to 'config-model/src/main') diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 867ffdb3960..9456baafd57 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -1,9 +1,9 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.schema; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; -import com.yahoo.vespa.model.container.component.OnnxModelOptions; import com.yahoo.vespa.model.ml.OnnxModelInfo; import java.util.Collections; @@ -171,4 +171,6 @@ public class OnnxModel extends DistributableResource implements Cloneable { return onnxModelOptions.gpuDevice(); } + public OnnxModelOptions onnxModelOptions() { return onnxModelOptions; } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index ea3caadc23a..67fb720b8c0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -47,7 +48,7 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } @Override diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index 
cbae50b400c..d22e6afc3d1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -55,7 +56,7 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); transformerOutput = getChildValue(xml, "transformer-output").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } private static ModelReference resolveDefaultVocab(Model model, DeployState state) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index d1bd0dce000..d98c72ab3a4 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -48,7 +49,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm transformerOutput = 
getChildValue(xml, "transformer-output").orElse(null); normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } private static ModelReference resolveDefaultVocab(Model model, DeployState state) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java index c5daf23d6f8..0d350242fd0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java @@ -4,6 +4,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.builder.xml.XmlHelper; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.path.Path; @@ -54,10 +55,10 @@ class Model { return new Model(ds, model.getTagName(), modelId, url, path); } - void registerOnnxModelCost(ApplicationContainerCluster c) { + void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) { var resolvedUrl = resolvedUrl().orElse(null); - if (file != null) c.onnxModelCost().registerModel(file); - else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl); + if (file != null) c.onnxModelCost().registerModel(file, onnxModelOptions); + else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl, onnxModelOptions); } String name() { return paramName; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java deleted file mode 
100644 index 6347f0dc427..00000000000 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.model.container.component; - -import java.util.Optional; - -/** - * Onnx model options that are relevant when deciding if an Onnx model needs to be reloaded. If any of the - * values in this class change, reload is needed. - * - * @author hmusum - */ -public record OnnxModelOptions(Optional<String> executionMode, Optional<Integer> interOpThreads, - Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice) { - - public static OnnxModelOptions empty() { - return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); - } - - public OnnxModelOptions withExecutionMode(String executionMode) { - return new OnnxModelOptions(Optional.ofNullable(executionMode), interOpThreads, intraOpThreads, gpuDevice); - } - - public OnnxModelOptions withInterOpThreads(Integer interOpThreads) { - return new OnnxModelOptions(executionMode, Optional.ofNullable(interOpThreads), intraOpThreads, gpuDevice); - } - - public OnnxModelOptions withIntraOpThreads(Integer intraOpThreads) { - return new OnnxModelOptions(executionMode, interOpThreads, Optional.ofNullable(intraOpThreads), gpuDevice); - } - - public OnnxModelOptions withGpuDevice(GpuDevice gpuDevice) { - return new OnnxModelOptions(executionMode, interOpThreads, intraOpThreads, Optional.ofNullable(gpuDevice)); - } - - public record GpuDevice(int deviceNumber, boolean required) { - public GpuDevice { - if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber); - } - - public GpuDevice(int deviceNumber) { - this(deviceNumber, false); - } - } - -} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java 
b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java index d86d117f1d2..31468c05b99 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java @@ -52,11 +52,11 @@ public class ContainerSearch extends ContainerSubsystem private final List searchClusters = new LinkedList<>(); private final Collection schemasWithGlobalPhase; private final boolean globalPhase; + private final ApplicationPackage app; private QueryProfiles queryProfiles; private SemanticRules semanticRules; private PageTemplates pageTemplates; - private ApplicationPackage app; public ContainerSearch(DeployState deployState, ApplicationContainerCluster cluster, SearchChains chains) { super(chains); @@ -102,7 +102,7 @@ public class ContainerSearch extends ContainerSubsystem if ( ! owningCluster.getComponentsMap().containsKey(factory.getComponentId())) { var onnxModels = documentDb.getDerivedConfiguration().getRankProfileList().getOnnxModels(); onnxModels.asMap().forEach( - (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()))); + (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()), model.onnxModelOptions())); owningCluster.addComponent(factory); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 18020f5df5d..5ffd34c6557 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -800,7 +800,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder { !container.getHostResource().realResources().gpuResources().isZero()); onnxModel.setGpuDevice(gpuDevice, hasGpu); } 
- cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath())); + cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions()); } cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models)); -- cgit v1.2.3