diff options
author | Lester Solbakken <lesters@oath.com> | 2023-08-29 14:11:13 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2023-08-29 14:11:13 +0200 |
commit | 48a3511b28731a4b45089bdd4808167c58824b00 (patch) | |
tree | ee8cb9898fe636cb786ebfe526c43c194fda51b6 /config-model | |
parent | 639deae84372d6f38944cf9a6b663ce8924a3bd7 (diff) |
Add cluster specific settings for model evaluation
Diffstat (limited to 'config-model')
9 files changed, 149 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/DistributableResource.java b/config-model/src/main/java/com/yahoo/schema/DistributableResource.java index e7bdb68a03d..8594b40a367 100644 --- a/config-model/src/main/java/com/yahoo/schema/DistributableResource.java +++ b/config-model/src/main/java/com/yahoo/schema/DistributableResource.java @@ -8,7 +8,7 @@ import com.yahoo.path.Path; import java.nio.ByteBuffer; import java.util.Objects; -public class DistributableResource implements Comparable <DistributableResource> { +public class DistributableResource implements Comparable <DistributableResource>, Cloneable { public enum PathType { FILE, URI, BLOB } @@ -35,6 +35,11 @@ public class DistributableResource implements Comparable <DistributableResource> this.pathType = type; } + @Override + public DistributableResource clone() throws CloneNotSupportedException { + return (DistributableResource) super.clone(); + } + // TODO: Remove and make path/pathType final public void setFileName(String fileName) { Objects.requireNonNull(fileName, "Filename cannot be null"); diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 90a27d1f036..fbec97c797d 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -18,7 +18,7 @@ import java.util.Set; * * @author lesters */ -public class OnnxModel extends DistributableResource { +public class OnnxModel extends DistributableResource implements Cloneable { private OnnxModelInfo modelInfo = null; private final Map<String, String> inputMap = new HashMap<>(); @@ -40,6 +40,19 @@ public class OnnxModel extends DistributableResource { } @Override + public OnnxModel clone() { + try { + OnnxModel clone = (OnnxModel) super.clone(); + clone.inputMap.putAll(inputMap); + clone.outputMap.putAll(outputMap); + clone.initializers.addAll(initializers); + return clone; + } catch (CloneNotSupportedException e) { + throw new RuntimeException("Clone not supported", e); + } + } + + @Override public void setUri(String uri) { throw new IllegalArgumentException("URI for ONNX models are not currently supported"); } diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java index e3c697e3262..c3fa6aedf31 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java @@ -35,6 +35,16 @@ public class FileDistributedOnnxModels extends Derived implements OnnxModelsConf this.models = Collections.unmodifiableMap(distributableModels); } + private FileDistributedOnnxModels(Collection<OnnxModel> models) { + Map<String, OnnxModel> distributableModels = models.stream() + .collect(LinkedHashMap::new, (m, v) -> m.put(v.getName(), v.clone()), LinkedHashMap::putAll); + this.models = Collections.unmodifiableMap(distributableModels); + } + + public FileDistributedOnnxModels clone() { + return new FileDistributedOnnxModels(models.values()); + } + public Map<String, OnnxModel> asMap() { return models; } public void getConfig(OnnxModelsConfig.Builder builder) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java index c227700733e..906ef739ef1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.model.container; import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.osgi.provider.model.ComponentModel; +import com.yahoo.schema.derived.FileDistributedOnnxModels; import com.yahoo.schema.derived.RankProfileList; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; @@ -42,9 +43,16 @@ public class ContainerModelEvaluation implements /** Global rank profiles, aka models */ private final RankProfileList rankProfileList; + private final FileDistributedOnnxModels onnxModels; // For cluster specific ONNX model settings public ContainerModelEvaluation(ApplicationContainerCluster cluster, RankProfileList rankProfileList) { + this(cluster, rankProfileList, null); + } + + public ContainerModelEvaluation(ApplicationContainerCluster cluster, + RankProfileList rankProfileList, FileDistributedOnnxModels onnxModels) { this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null"); + this.onnxModels = onnxModels; cluster.addSimpleComponent(EVALUATOR_NAME, null, EVALUATION_BUNDLE_NAME); cluster.addComponent(ContainerModelEvaluation.getHandler()); } @@ -61,7 +69,11 @@ public class ContainerModelEvaluation implements @Override public void getConfig(OnnxModelsConfig.Builder builder) { - rankProfileList.getConfig(builder); + if (onnxModels != null) { + onnxModels.getConfig(builder); + } else { + rankProfileList.getConfig(builder); + } } public void getConfig(RankingExpressionsConfig.Builder builder) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index e7692aeee7b..fb4dce1cc38 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -44,6 +44,7 @@ import com.yahoo.jdisc.http.server.jetty.VoidRequestLog; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.path.Path; import com.yahoo.schema.OnnxModel; +import com.yahoo.schema.derived.FileDistributedOnnxModels; import com.yahoo.schema.derived.RankProfileList; import com.yahoo.search.rendering.RendererRegistry; import com.yahoo.security.X509CertificateUtils; @@ -751,10 +752,13 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { RankProfileList profiles = context.vespaModel() != null ? context.vespaModel().rankProfileList() : RankProfileList.empty; + // Create a copy of models so each cluster can have its own specific settings + FileDistributedOnnxModels models = profiles.getOnnxModels().clone(); + Element onnxElement = XML.getChild(modelEvaluationElement, "onnx"); Element modelsElement = XML.getChild(onnxElement, "models"); for (Element modelElement : XML.getChildren(modelsElement, "model") ) { - OnnxModel onnxModel = profiles.getOnnxModels().asMap().get(modelElement.getAttribute("name")); + OnnxModel onnxModel = models.asMap().get(modelElement.getAttribute("name")); if (onnxModel == null) { String availableModels = String.join(", ", profiles.getOnnxModels().asMap().keySet()); context.getDeployState().getDeployLogger().logApplicationPackage(WARNING, @@ -774,7 +778,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { } } - cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles)); + cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models)); } private String getStringValue(Element element, String name, String defaultValue) { diff --git a/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.onnx b/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.onnx new file mode 100644 index 00000000000..087e2c3427f --- /dev/null +++ b/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.onnx @@ -0,0 +1,16 @@ +mul.py:f + +input1 +input2output"MulmulZ +input1 + + +Z +input2 + + +b +output + + +B
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.py b/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.py new file mode 100755 index 00000000000..9fcb8612af9 --- /dev/null +++ b/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.py @@ -0,0 +1,26 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +import onnx +from onnx import helper, TensorProto + +INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1]) +INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1]) +OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1]) + +nodes = [ + helper.make_node( + 'Mul', + ['input1', 'input2'], + ['output'], + ), +] +graph_def = helper.make_graph( + nodes, + 'mul', + [ + INPUT_1, + INPUT_2 + ], + [OUTPUT], +) +model_def = helper.make_model(graph_def, producer_name='mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) +onnx.save(model_def, 'mul.onnx') diff --git a/config-model/src/test/cfg/application/onnx_cluster_specific/services.xml b/config-model/src/test/cfg/application/onnx_cluster_specific/services.xml new file mode 100644 index 00000000000..06b9a8c3a55 --- /dev/null +++ b/config-model/src/test/cfg/application/onnx_cluster_specific/services.xml @@ -0,0 +1,34 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<services version="1.0"> + + <container id="c1" version="1.0"> + <model-evaluation> + <onnx> + <models> + <model name="mul"> + <intraop-threads>2</intraop-threads> + <gpu-device>0</gpu-device> + </model> + </models> + </onnx> + </model-evaluation> + </container> + + <container id="c2" version="1.0"> + <http> + <server id="c1Server" port="8081" /> + </http> + <model-evaluation> + <onnx> + <models> + <model name="mul"> + <intraop-threads>4</intraop-threads> + <gpu-device>1</gpu-device> + </model> + </models> + </onnx> + </model-evaluation> + </container> + +</services> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index fc70a65b394..137907cb003 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -84,6 +84,30 @@ public class ModelEvaluationTest { } } + @Test + void testContainerSpecificModelSettings() { + Path appDir = Path.fromString("src/test/cfg/application/onnx_cluster_specific"); + try { + ImportedModelTester tester = new ImportedModelTester("mul", appDir); + VespaModel model = tester.createVespaModel(); + OnnxModelsConfig.Model c1Model = getOnnxModelsConfig(model.getContainerClusters().get("c1")); + OnnxModelsConfig.Model c2Model = getOnnxModelsConfig(model.getContainerClusters().get("c2")); + assertEquals(2, c1Model.stateless_intraop_threads()); + assertEquals(4, c2Model.stateless_intraop_threads()); + assertEquals(0, c1Model.gpu_device()); + assertEquals(1, c2Model.gpu_device()); + } finally { + IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + + } + + private OnnxModelsConfig.Model getOnnxModelsConfig(ApplicationContainerCluster cluster) { + OnnxModelsConfig.Builder ob = new OnnxModelsConfig.Builder(); + cluster.getConfig(ob); + return new OnnxModelsConfig(ob).model(0); + } + private void assertHasMlModels(VespaModel model, Path appDir) { ApplicationContainerCluster cluster = model.getContainerClusters().get("container"); assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName()))); |