diff options
author | Martin Polden <mpolden@mpolden.no> | 2023-02-07 16:04:57 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2023-02-08 09:52:53 +0100 |
commit | c47f27f1c3362b459e276c59ebcd09ab259b710e (patch) | |
tree | f68963b9b91558d9418da0cb88e5a16d67a52034 /config-model | |
parent | d03bf2ed1b239f4998bdfd6580965bcd0a7d62a4 (diff) |
Allow fallback to CPU if nodes are provisioned without GPU
Diffstat (limited to 'config-model')
5 files changed, 27 insertions, 11 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 3d96849fa15..ae6f1fd96e4 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -24,7 +24,7 @@ public class OnnxModel extends DistributableResource { private String statelessExecutionMode = null; private Integer statelessInterOpThreads = null; private Integer statelessIntraOpThreads = null; - private Integer gpuDevice = null; + private GpuDevice gpuDevice = null; public OnnxModel(String name) { super(name); @@ -114,9 +114,9 @@ public class OnnxModel extends DistributableResource { } } - public void setGpuDevice(int deviceNumber) { + public void setGpuDevice(int deviceNumber, boolean required) { if (deviceNumber >= 0) { - this.gpuDevice = deviceNumber; + this.gpuDevice = new GpuDevice(deviceNumber, required); } } @@ -124,8 +124,16 @@ public class OnnxModel extends DistributableResource { return Optional.ofNullable(statelessIntraOpThreads); } - public Optional<Integer> getGpuDevice() { + public Optional<GpuDevice> getGpuDevice() { return Optional.ofNullable(gpuDevice); } + public record GpuDevice(int deviceNumber, boolean required) { + + public GpuDevice { + if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber); + } + + } + } diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java index f63e872836e..4196af18fb6 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java @@ -52,7 +52,8 @@ public class FileDistributedOnnxModels { if (model.getStatelessIntraOpThreads().isPresent()) modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); if (model.getGpuDevice().isPresent()) { - modelBuilder.gpu_device(model.getGpuDevice().get()); + modelBuilder.gpu_device(model.getGpuDevice().get().deviceNumber()); + modelBuilder.gpu_device_required(model.getGpuDevice().get().required()); } builder.model(modelBuilder); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 700393e84f3..81626581722 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -207,9 +207,6 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { addConfiguredComponents(deployState, cluster, spec); addSecretStore(cluster, spec, deployState); - addModelEvaluation(spec, cluster, context); - addModelEvaluationBundles(cluster); - addProcessing(deployState, spec, cluster, context); addSearch(deployState, spec, cluster, context); addDocproc(deployState, spec, cluster); @@ -225,6 +222,9 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { addAccessLogs(deployState, cluster, spec); addNodes(cluster, spec, context); + addModelEvaluation(spec, cluster, context); // NOTE: Must be done after addNodes + addModelEvaluationBundles(cluster); + addServerProviders(deployState, spec, cluster); if (!standaloneBuilder) cluster.addAllPlatformBundles(); @@ -685,7 +685,13 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null)); onnxModel.setStatelessInterOpThreads(getIntValue(modelElement, "interop-threads", -1)); onnxModel.setStatelessIntraOpThreads(getIntValue(modelElement, "intraop-threads", -1)); - onnxModel.setGpuDevice(getIntValue(modelElement, "gpu-device", -1)); + Element gpuDeviceElement = XML.getChild(modelElement, "gpu-device"); + if (gpuDeviceElement != null) { + int gpuDevice = Integer.parseInt(gpuDeviceElement.getTextContent()); + Capacity capacity = context.getDeployState().provisioned().all().get(cluster.id()); + boolean gpuProvisioned = capacity != null && !capacity.minResources().nodeResources().gpuResources().isZero(); + onnxModel.setGpuDevice(gpuDevice, gpuProvisioned); + } } cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles)); diff --git a/config-model/src/test/cfg/application/onnx/services.xml b/config-model/src/test/cfg/application/onnx/services.xml index 088bbcc4921..b17e34e66c2 100644 --- a/config-model/src/test/cfg/application/onnx/services.xml +++ b/config-model/src/test/cfg/application/onnx/services.xml @@ -8,11 +8,11 @@ <models> <model name="mul"> <intraop-threads>2</intraop-threads> + <gpu-device>0</gpu-device> </model> <model name="non-existent-model"> <interop-threads>400</interop-threads> <execution-mode>parallel</execution-mode> - <gpu-device>0</gpu-device> </model> </models> </onnx> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java index b1e28649e9f..8ccbe99f70a 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java @@ -129,7 +129,8 @@ public class StatelessOnnxEvaluationTest { assertEquals(2, mulModel.stateless_intraop_threads()); assertEquals(-1, mulModel.stateless_interop_threads()); assertEquals("", mulModel.stateless_execution_mode()); - assertEquals(-1, mulModel.gpu_device()); + assertFalse(mulModel.gpu_device_required()); + assertEquals(0, mulModel.gpu_device()); } } |