about summary refs log tree commit diff stats
path: root/config-model
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@gmail.com> 2023-09-01 09:16:44 +0200
committer GitHub <noreply@github.com> 2023-09-01 09:16:44 +0200
commit 6ac1607dddfff252b99816a8cafc078aedcc4d03 (patch)
tree 75a9e5bfb6b88cf4acfeb9042c88721a076a72ff /config-model
parent af0bf5ff10dfa2093df5ed5f5b5fd0c965d0c8bf (diff)
parent cefce7443669e8dcd03c1d4905e875bef2fe83cb (diff)
Merge pull request #28241 from vespa-engine/lesters/add-cluster-specific-settings-for-model-evaluation
Add cluster specific settings for model evaluation
Diffstat (limited to 'config-model')
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/DistributableResource.java 7
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/OnnxModel.java 23
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java 10
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java 14
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java 8
-rw-r--r-- config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.onnx 16
-rwxr-xr-x config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.py 26
-rw-r--r-- config-model/src/test/cfg/application/onnx_cluster_specific/services.xml 34
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java 24
9 files changed, 151 insertions, 11 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/DistributableResource.java b/config-model/src/main/java/com/yahoo/schema/DistributableResource.java
index e7bdb68a03d..8594b40a367 100644
--- a/config-model/src/main/java/com/yahoo/schema/DistributableResource.java
+++ b/config-model/src/main/java/com/yahoo/schema/DistributableResource.java
@@ -8,7 +8,7 @@ import com.yahoo.path.Path;
import java.nio.ByteBuffer;
import java.util.Objects;
-public class DistributableResource implements Comparable <DistributableResource> {
+public class DistributableResource implements Comparable <DistributableResource>, Cloneable {
public enum PathType { FILE, URI, BLOB }
@@ -35,6 +35,11 @@ public class DistributableResource implements Comparable <DistributableResource>
this.pathType = type;
}
+ @Override
+ public DistributableResource clone() throws CloneNotSupportedException {
+ return (DistributableResource) super.clone();
+ }
+
// TODO: Remove and make path/pathType final
public void setFileName(String fileName) {
Objects.requireNonNull(fileName, "Filename cannot be null");
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 90a27d1f036..3295b2e93aa 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -18,13 +18,15 @@ import java.util.Set;
*
* @author lesters
*/
-public class OnnxModel extends DistributableResource {
+public class OnnxModel extends DistributableResource implements Cloneable {
+ // Model information
private OnnxModelInfo modelInfo = null;
private final Map<String, String> inputMap = new HashMap<>();
private final Map<String, String> outputMap = new HashMap<>();
private final Set<String> initializers = new HashSet<>();
+ // Runtime options
private String statelessExecutionMode = null;
private Integer statelessInterOpThreads = null;
private Integer statelessIntraOpThreads = null;
@@ -40,6 +42,15 @@ public class OnnxModel extends DistributableResource {
}
@Override
+ public OnnxModel clone() {
+ try {
+ return (OnnxModel) super.clone(); // Shallow clone is sufficient here
+ } catch (CloneNotSupportedException e) {
+ throw new RuntimeException("Clone not supported", e);
+ }
+ }
+
+ @Override
public void setUri(String uri) {
throw new IllegalArgumentException("URI for ONNX models are not currently supported");
}
@@ -148,26 +159,24 @@ public class OnnxModel extends DistributableResource {
}
}
+ public Optional<Integer> getStatelessIntraOpThreads() {
+ return Optional.ofNullable(statelessIntraOpThreads);
+ }
+
public void setGpuDevice(int deviceNumber, boolean required) {
if (deviceNumber >= 0) {
this.gpuDevice = new GpuDevice(deviceNumber, required);
}
}
- public Optional<Integer> getStatelessIntraOpThreads() {
- return Optional.ofNullable(statelessIntraOpThreads);
- }
-
public Optional<GpuDevice> getGpuDevice() {
return Optional.ofNullable(gpuDevice);
}
public record GpuDevice(int deviceNumber, boolean required) {
-
public GpuDevice {
if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
}
-
}
}
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
index e3c697e3262..c3fa6aedf31 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
@@ -35,6 +35,16 @@ public class FileDistributedOnnxModels extends Derived implements OnnxModelsConf
this.models = Collections.unmodifiableMap(distributableModels);
}
+ private FileDistributedOnnxModels(Collection<OnnxModel> models) {
+ Map<String, OnnxModel> distributableModels = models.stream()
+ .collect(LinkedHashMap::new, (m, v) -> m.put(v.getName(), v.clone()), LinkedHashMap::putAll);
+ this.models = Collections.unmodifiableMap(distributableModels);
+ }
+
+ public FileDistributedOnnxModels clone() {
+ return new FileDistributedOnnxModels(models.values());
+ }
+
public Map<String, OnnxModel> asMap() { return models; }
public void getConfig(OnnxModelsConfig.Builder builder) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
index c227700733e..906ef739ef1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
@@ -3,6 +3,7 @@ package com.yahoo.vespa.model.container;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.osgi.provider.model.ComponentModel;
+import com.yahoo.schema.derived.FileDistributedOnnxModels;
import com.yahoo.schema.derived.RankProfileList;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
@@ -42,9 +43,16 @@ public class ContainerModelEvaluation implements
/** Global rank profiles, aka models */
private final RankProfileList rankProfileList;
+ private final FileDistributedOnnxModels onnxModels; // For cluster specific ONNX model settings
public ContainerModelEvaluation(ApplicationContainerCluster cluster, RankProfileList rankProfileList) {
+ this(cluster, rankProfileList, null);
+ }
+
+ public ContainerModelEvaluation(ApplicationContainerCluster cluster,
+ RankProfileList rankProfileList, FileDistributedOnnxModels onnxModels) {
this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null");
+ this.onnxModels = onnxModels;
cluster.addSimpleComponent(EVALUATOR_NAME, null, EVALUATION_BUNDLE_NAME);
cluster.addComponent(ContainerModelEvaluation.getHandler());
}
@@ -61,7 +69,11 @@ public class ContainerModelEvaluation implements
@Override
public void getConfig(OnnxModelsConfig.Builder builder) {
- rankProfileList.getConfig(builder);
+ if (onnxModels != null) {
+ onnxModels.getConfig(builder);
+ } else {
+ rankProfileList.getConfig(builder);
+ }
}
public void getConfig(RankingExpressionsConfig.Builder builder) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index 0e72cff1688..b603f9f0ba1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -44,6 +44,7 @@ import com.yahoo.jdisc.http.server.jetty.VoidRequestLog;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.path.Path;
import com.yahoo.schema.OnnxModel;
+import com.yahoo.schema.derived.FileDistributedOnnxModels;
import com.yahoo.schema.derived.RankProfileList;
import com.yahoo.search.rendering.RendererRegistry;
import com.yahoo.security.X509CertificateUtils;
@@ -751,10 +752,13 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
RankProfileList profiles =
context.vespaModel() != null ? context.vespaModel().rankProfileList() : RankProfileList.empty;
+ // Create a copy of models so each cluster can have its own specific settings
+ FileDistributedOnnxModels models = profiles.getOnnxModels().clone();
+
Element onnxElement = XML.getChild(modelEvaluationElement, "onnx");
Element modelsElement = XML.getChild(onnxElement, "models");
for (Element modelElement : XML.getChildren(modelsElement, "model") ) {
- OnnxModel onnxModel = profiles.getOnnxModels().asMap().get(modelElement.getAttribute("name"));
+ OnnxModel onnxModel = models.asMap().get(modelElement.getAttribute("name"));
if (onnxModel == null) {
String availableModels = String.join(", ", profiles.getOnnxModels().asMap().keySet());
context.getDeployState().getDeployLogger().logApplicationPackage(WARNING,
@@ -774,7 +778,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
}
}
- cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles));
+ cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models));
}
private String getStringValue(Element element, String name, String defaultValue) {
diff --git a/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.onnx b/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.onnx
new file mode 100644
index 00000000000..087e2c3427f
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.onnx
@@ -0,0 +1,16 @@
+mul.py:f
+
+input1
+input2output"MulmulZ
+input1
+
+
+Z
+input2
+
+
+b
+output
+
+
+B \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.py b/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.py
new file mode 100755
index 00000000000..9fcb8612af9
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx_cluster_specific/models/mul.py
@@ -0,0 +1,26 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
+OUTPUT = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Mul',
+ ['input1', 'input2'],
+ ['output'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'mul',
+ [
+ INPUT_1,
+ INPUT_2
+ ],
+ [OUTPUT],
+)
+model_def = helper.make_model(graph_def, producer_name='mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'mul.onnx')
diff --git a/config-model/src/test/cfg/application/onnx_cluster_specific/services.xml b/config-model/src/test/cfg/application/onnx_cluster_specific/services.xml
new file mode 100644
index 00000000000..06b9a8c3a55
--- /dev/null
+++ b/config-model/src/test/cfg/application/onnx_cluster_specific/services.xml
@@ -0,0 +1,34 @@
+<?xml version="1.0" encoding="utf-8" ?>
+<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
+<services version="1.0">
+
+ <container id="c1" version="1.0">
+ <model-evaluation>
+ <onnx>
+ <models>
+ <model name="mul">
+ <intraop-threads>2</intraop-threads>
+ <gpu-device>0</gpu-device>
+ </model>
+ </models>
+ </onnx>
+ </model-evaluation>
+ </container>
+
+ <container id="c2" version="1.0">
+ <http>
+ <server id="c1Server" port="8081" />
+ </http>
+ <model-evaluation>
+ <onnx>
+ <models>
+ <model name="mul">
+ <intraop-threads>4</intraop-threads>
+ <gpu-device>1</gpu-device>
+ </model>
+ </models>
+ </onnx>
+ </model-evaluation>
+ </container>
+
+</services>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index fc70a65b394..137907cb003 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -84,6 +84,30 @@ public class ModelEvaluationTest {
}
}
+ @Test
+ void testContainerSpecificModelSettings() {
+ Path appDir = Path.fromString("src/test/cfg/application/onnx_cluster_specific");
+ try {
+ ImportedModelTester tester = new ImportedModelTester("mul", appDir);
+ VespaModel model = tester.createVespaModel();
+ OnnxModelsConfig.Model c1Model = getOnnxModelsConfig(model.getContainerClusters().get("c1"));
+ OnnxModelsConfig.Model c2Model = getOnnxModelsConfig(model.getContainerClusters().get("c2"));
+ assertEquals(2, c1Model.stateless_intraop_threads());
+ assertEquals(4, c2Model.stateless_intraop_threads());
+ assertEquals(0, c1Model.gpu_device());
+ assertEquals(1, c2Model.gpu_device());
+ } finally {
+ IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+
+ }
+
+ private OnnxModelsConfig.Model getOnnxModelsConfig(ApplicationContainerCluster cluster) {
+ OnnxModelsConfig.Builder ob = new OnnxModelsConfig.Builder();
+ cluster.getConfig(ob);
+ return new OnnxModelsConfig(ob).model(0);
+ }
+
private void assertHasMlModels(VespaModel model, Path appDir) {
ApplicationContainerCluster cluster = model.getContainerClusters().get("container");
assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));