diff options
author | Harald Musum <musum@yahooinc.com> | 2023-11-21 09:52:52 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-21 09:52:52 +0100 |
commit | 430c0f8c9e1ea5eaeae2b795cd4b7350091679ae (patch) | |
tree | 4e49701083ff1600df5775f481eb3057edbf88bf | |
parent | d998b2774ce916ce5a92f4879f3f47a23f1346a9 (diff) | |
parent | 9d28a47b003f5498bc59bfd10017dd55fc7ab6e0 (diff) |
Merge pull request #29388 from vespa-engine/hmusum/register-with-onnx-model-options
Register model with onnx model options
11 files changed, 87 insertions, 16 deletions
diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json index d9c68c89189..78b32d8af7b 100644 --- a/config-model-api/abi-spec.json +++ b/config-model-api/abi-spec.json @@ -1453,7 +1453,9 @@ "methods" : [ "public abstract long aggregatedModelCostInBytes()", "public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile)", - "public abstract void registerModel(java.net.URI)" + "public abstract void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)", + "public abstract void registerModel(java.net.URI)", + "public abstract void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)" ], "fields" : [ ] }, @@ -1471,7 +1473,9 @@ "public com.yahoo.config.model.api.OnnxModelCost$Calculator newCalculator(com.yahoo.config.application.api.ApplicationPackage, com.yahoo.config.provision.ApplicationId)", "public long aggregatedModelCostInBytes()", "public void registerModel(com.yahoo.config.application.api.ApplicationFile)", - "public void registerModel(java.net.URI)" + "public void registerModel(com.yahoo.config.application.api.ApplicationFile, com.yahoo.config.model.api.OnnxModelOptions)", + "public void registerModel(java.net.URI)", + "public void registerModel(java.net.URI, com.yahoo.config.model.api.OnnxModelOptions)" ], "fields" : [ ] }, @@ -1489,6 +1493,51 @@ ], "fields" : [ ] }, + "com.yahoo.config.model.api.OnnxModelOptions$GpuDevice" : { + "superClass" : "java.lang.Record", + "interfaces" : [ ], + "attributes" : [ + "public", + "final", + "record" + ], + "methods" : [ + "public void <init>(int, boolean)", + "public void <init>(int)", + "public final java.lang.String toString()", + "public final int hashCode()", + "public final boolean equals(java.lang.Object)", + "public int deviceNumber()", + "public boolean required()" + ], + "fields" : [ ] + }, + "com.yahoo.config.model.api.OnnxModelOptions" : { + "superClass" : "java.lang.Record", + "interfaces" : [ ], + "attributes" : [ + "public", + "final", + "record" + ], + "methods" : [ + "public void <init>(java.lang.String, int, int, com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)", + "public void <init>(java.util.Optional, java.util.Optional, java.util.Optional, java.util.Optional)", + "public static com.yahoo.config.model.api.OnnxModelOptions empty()", + "public com.yahoo.config.model.api.OnnxModelOptions withExecutionMode(java.lang.String)", + "public com.yahoo.config.model.api.OnnxModelOptions withInterOpThreads(java.lang.Integer)", + "public com.yahoo.config.model.api.OnnxModelOptions withIntraOpThreads(java.lang.Integer)", + "public com.yahoo.config.model.api.OnnxModelOptions withGpuDevice(com.yahoo.config.model.api.OnnxModelOptions$GpuDevice)", + "public final java.lang.String toString()", + "public final int hashCode()", + "public final boolean equals(java.lang.Object)", + "public java.util.Optional executionMode()", + "public java.util.Optional interOpThreads()", + "public java.util.Optional intraOpThreads()", + "public java.util.Optional gpuDevice()" + ], + "fields" : [ ] + }, "com.yahoo.config.model.api.PortInfo" : { "superClass" : "java.lang.Object", "interfaces" : [ ], diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java index acb88070482..b98667457e4 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java @@ -10,6 +10,7 @@ import java.net.URI; /** * @author bjorncs */ +// TODO: Rename public interface OnnxModelCost { Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId); @@ -17,7 +18,9 @@ public interface OnnxModelCost { interface Calculator { long aggregatedModelCostInBytes(); void registerModel(ApplicationFile path); + void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions); void registerModel(URI uri); + void registerModel(URI uri, OnnxModelOptions onnxModelOptions); } static OnnxModelCost disabled() { return new DisabledOnnxModelCost(); } @@ -26,7 +29,9 @@ public interface OnnxModelCost { @Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; } @Override public long aggregatedModelCostInBytes() {return 0;} @Override public void registerModel(ApplicationFile path) {} + @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {} @Override public void registerModel(URI uri) {} + @Override public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) {} } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java index 6347f0dc427..92817baae3f 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/OnnxModelOptions.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelOptions.java @@ -1,5 +1,5 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.model.container.component; +package com.yahoo.config.model.api; import java.util.Optional; @@ -12,7 +12,11 @@ import java.util.Optional; public record OnnxModelOptions(Optional<String> executionMode, Optional<Integer> interOpThreads, Optional<Integer> intraOpThreads, Optional<GpuDevice> gpuDevice) { - public static OnnxModelOptions empty() { + public OnnxModelOptions(String executionMode, int interOpThreads, int intraOpThreads, GpuDevice gpuDevice) { + this(Optional.of(executionMode), Optional.of(interOpThreads), Optional.of(intraOpThreads), Optional.of(gpuDevice)); + } + + public static OnnxModelOptions empty() { return new OnnxModelOptions(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); } diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 867ffdb3960..9456baafd57 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -1,9 +1,9 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; -import com.yahoo.vespa.model.container.component.OnnxModelOptions; import com.yahoo.vespa.model.ml.OnnxModelInfo; import java.util.Collections; @@ -171,4 +171,6 @@ public class OnnxModel extends DistributableResource implements Cloneable { return onnxModelOptions.gpuDevice(); } + public OnnxModelOptions onnxModelOptions() { return onnxModelOptions; } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index ea3caadc23a..67fb720b8c0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -47,7 +48,7 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } @Override diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index cbae50b400c..d22e6afc3d1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -55,7 +56,7 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); transformerOutput = getChildValue(xml, "transformer-output").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } private static ModelReference resolveDefaultVocab(Model model, DeployState state) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index d1bd0dce000..d98c72ab3a4 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; @@ -48,7 +49,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm transformerOutput = getChildValue(xml, "transformer-output").orElse(null); normalize = getChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null); poolingStrategy = getChildValue(xml, "pooling-strategy").orElse(null); - model.registerOnnxModelCost(cluster); + model.registerOnnxModelCost(cluster, onnxModelOptions); } private static ModelReference resolveDefaultVocab(Model model, DeployState state) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java index c5daf23d6f8..0d350242fd0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java @@ -4,6 +4,7 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.builder.xml.XmlHelper; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.path.Path; @@ -54,10 +55,10 @@ class Model { return new Model(ds, model.getTagName(), modelId, url, path); } - void registerOnnxModelCost(ApplicationContainerCluster c) { + void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) { var resolvedUrl = resolvedUrl().orElse(null); - if (file != null) c.onnxModelCost().registerModel(file); - else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl); + if (file != null) c.onnxModelCost().registerModel(file, onnxModelOptions); + else if (resolvedUrl != null) c.onnxModelCost().registerModel(resolvedUrl, onnxModelOptions); } String name() { return paramName; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java index d86d117f1d2..31468c05b99 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/ContainerSearch.java @@ -52,11 +52,11 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains> private final List<SearchCluster> searchClusters = new LinkedList<>(); private final Collection<String> schemasWithGlobalPhase; private final boolean globalPhase; + private final ApplicationPackage app; private QueryProfiles queryProfiles; private SemanticRules semanticRules; private PageTemplates pageTemplates; - private ApplicationPackage app; public ContainerSearch(DeployState deployState, ApplicationContainerCluster cluster, SearchChains chains) { super(chains); @@ -102,7 +102,7 @@ public class ContainerSearch extends ContainerSubsystem<SearchChains> if ( ! owningCluster.getComponentsMap().containsKey(factory.getComponentId())) { var onnxModels = documentDb.getDerivedConfiguration().getRankProfileList().getOnnxModels(); onnxModels.asMap().forEach( - (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()))); + (__, model) -> owningCluster.onnxModelCost().registerModel(app.getFile(model.getFilePath()), model.onnxModelOptions())); owningCluster.addComponent(factory); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 18020f5df5d..5ffd34c6557 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -800,7 +800,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { !container.getHostResource().realResources().gpuResources().isZero()); onnxModel.setGpuDevice(gpuDevice, hasGpu); } - cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath())); + cluster.onnxModelCost().registerModel(context.getApplicationPackage().getFile(onnxModel.getFilePath()), onnxModel.onnxModelOptions()); } cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models)); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java index 8531aff3b1a..9cadf5cffd8 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java @@ -1,14 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.vespa.model.application.validation; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.ApplicationClusterEndpoint; import com.yahoo.config.model.api.ContainerEndpoint; import com.yahoo.config.model.api.OnnxModelCost; +import com.yahoo.config.model.api.OnnxModelOptions; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.provision.InMemoryProvisioner; @@ -123,12 +122,20 @@ class JvmHeapSizeValidatorTest { @Override public Calculator newCalculator(ApplicationPackage appPkg, ApplicationId applicationId) { return this; } @Override public long aggregatedModelCostInBytes() { return totalCost.get(); } @Override public void registerModel(ApplicationFile path) {} + @Override public void registerModel(ApplicationFile path, OnnxModelOptions onnxModelOptions) {} @Override public void registerModel(URI uri) { assertEquals("https://my/url/model.onnx", uri.toString()); totalCost.addAndGet(modelCost); } + + @Override + public void registerModel(URI uri, OnnxModelOptions onnxModelOptions) { + assertEquals("https://my/url/model.onnx", uri.toString()); + totalCost.addAndGet(modelCost); + } + } } |