summary | refs | log | tree | commit | diff | stats
path: root/config-model
diff options
context:
space:
mode:
author: Martin Polden <mpolden@mpolden.no>, 2023-02-07 16:04:57 +0100
committer: Martin Polden <mpolden@mpolden.no>, 2023-02-08 09:52:53 +0100
commit: c47f27f1c3362b459e276c59ebcd09ab259b710e (patch)
tree: f68963b9b91558d9418da0cb88e5a16d67a52034 /config-model
parent: d03bf2ed1b239f4998bdfd6580965bcd0a7d62a4 (diff)
Allow fallback to CPU if nodes are provisioned without GPU
Diffstat (limited to 'config-model')
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/OnnxModel.java | 16
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java | 3
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java | 14
-rw-r--r-- config-model/src/test/cfg/application/onnx/services.xml | 2
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java | 3
5 files changed, 27 insertions, 11 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 3d96849fa15..ae6f1fd96e4 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -24,7 +24,7 @@ public class OnnxModel extends DistributableResource {
private String statelessExecutionMode = null;
private Integer statelessInterOpThreads = null;
private Integer statelessIntraOpThreads = null;
- private Integer gpuDevice = null;
+ private GpuDevice gpuDevice = null;
public OnnxModel(String name) {
super(name);
@@ -114,9 +114,9 @@ public class OnnxModel extends DistributableResource {
}
}
- public void setGpuDevice(int deviceNumber) {
+ public void setGpuDevice(int deviceNumber, boolean required) {
if (deviceNumber >= 0) {
- this.gpuDevice = deviceNumber;
+ this.gpuDevice = new GpuDevice(deviceNumber, required);
}
}
@@ -124,8 +124,16 @@ public class OnnxModel extends DistributableResource {
return Optional.ofNullable(statelessIntraOpThreads);
}
- public Optional<Integer> getGpuDevice() {
+ public Optional<GpuDevice> getGpuDevice() {
return Optional.ofNullable(gpuDevice);
}
+ public record GpuDevice(int deviceNumber, boolean required) {
+
+ public GpuDevice {
+ if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
+ }
+
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
index f63e872836e..4196af18fb6 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
@@ -52,7 +52,8 @@ public class FileDistributedOnnxModels {
if (model.getStatelessIntraOpThreads().isPresent())
modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
if (model.getGpuDevice().isPresent()) {
- modelBuilder.gpu_device(model.getGpuDevice().get());
+ modelBuilder.gpu_device(model.getGpuDevice().get().deviceNumber());
+ modelBuilder.gpu_device_required(model.getGpuDevice().get().required());
}
builder.model(modelBuilder);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index 700393e84f3..81626581722 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -207,9 +207,6 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
addConfiguredComponents(deployState, cluster, spec);
addSecretStore(cluster, spec, deployState);
- addModelEvaluation(spec, cluster, context);
- addModelEvaluationBundles(cluster);
-
addProcessing(deployState, spec, cluster, context);
addSearch(deployState, spec, cluster, context);
addDocproc(deployState, spec, cluster);
@@ -225,6 +222,9 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
addAccessLogs(deployState, cluster, spec);
addNodes(cluster, spec, context);
+ addModelEvaluation(spec, cluster, context); // NOTE: Must be done after addNodes
+ addModelEvaluationBundles(cluster);
+
addServerProviders(deployState, spec, cluster);
if (!standaloneBuilder) cluster.addAllPlatformBundles();
@@ -685,7 +685,13 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null));
onnxModel.setStatelessInterOpThreads(getIntValue(modelElement, "interop-threads", -1));
onnxModel.setStatelessIntraOpThreads(getIntValue(modelElement, "intraop-threads", -1));
- onnxModel.setGpuDevice(getIntValue(modelElement, "gpu-device", -1));
+ Element gpuDeviceElement = XML.getChild(modelElement, "gpu-device");
+ if (gpuDeviceElement != null) {
+ int gpuDevice = Integer.parseInt(gpuDeviceElement.getTextContent());
+ Capacity capacity = context.getDeployState().provisioned().all().get(cluster.id());
+ boolean gpuProvisioned = capacity != null && !capacity.minResources().nodeResources().gpuResources().isZero();
+ onnxModel.setGpuDevice(gpuDevice, gpuProvisioned);
+ }
}
cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles));
diff --git a/config-model/src/test/cfg/application/onnx/services.xml b/config-model/src/test/cfg/application/onnx/services.xml
index 088bbcc4921..b17e34e66c2 100644
--- a/config-model/src/test/cfg/application/onnx/services.xml
+++ b/config-model/src/test/cfg/application/onnx/services.xml
@@ -8,11 +8,11 @@
<models>
<model name="mul">
<intraop-threads>2</intraop-threads>
+ <gpu-device>0</gpu-device>
</model>
<model name="non-existent-model">
<interop-threads>400</interop-threads>
<execution-mode>parallel</execution-mode>
- <gpu-device>0</gpu-device>
</model>
</models>
</onnx>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
index b1e28649e9f..8ccbe99f70a 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
@@ -129,7 +129,8 @@ public class StatelessOnnxEvaluationTest {
assertEquals(2, mulModel.stateless_intraop_threads());
assertEquals(-1, mulModel.stateless_interop_threads());
assertEquals("", mulModel.stateless_execution_mode());
- assertEquals(-1, mulModel.gpu_device());
+ assertFalse(mulModel.gpu_device_required());
+ assertEquals(0, mulModel.gpu_device());
}
}