diff options
author | Martin Polden <mpolden@mpolden.no> | 2023-01-23 14:34:35 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2023-01-23 14:34:35 +0100 |
commit | c9ba67dd5b8402125ea84a5c5fd12562ca7ebd15 (patch) | |
tree | 415a4d8ba38f8a9b52cd158f1540d19cf337bccf | |
parent | 4ffc7ef27da2a8d824c9f04a80f89e569e36b322 (diff) |
Support configuration of GPU device to use in ONNX model
7 files changed, 61 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 26a0b3e595d..ae6f1fd96e4 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -24,6 +24,7 @@ public class OnnxModel extends DistributableResource { private String statelessExecutionMode = null; private Integer statelessInterOpThreads = null; private Integer statelessIntraOpThreads = null; + private GpuDevice gpuDevice = null; public OnnxModel(String name) { super(name); @@ -113,8 +114,26 @@ public class OnnxModel extends DistributableResource { } } + public void setGpuDevice(int deviceNumber, boolean required) { + if (deviceNumber >= 0) { + this.gpuDevice = new GpuDevice(deviceNumber, required); + } + } + public Optional<Integer> getStatelessIntraOpThreads() { return Optional.ofNullable(statelessIntraOpThreads); } + public Optional<GpuDevice> getGpuDevice() { + return Optional.ofNullable(gpuDevice); + } + + public record GpuDevice(int deviceNumber, boolean required) { + + public GpuDevice { + if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber); + } + + } + } diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java index b5c3909c78c..cff0776fddd 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java @@ -51,6 +51,10 @@ public class FileDistributedOnnxModels { modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); if (model.getStatelessIntraOpThreads().isPresent()) modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); + if (model.getGpuDevice().isPresent()) { + 
modelBuilder.gpu_device(model.getGpuDevice().get().deviceNumber()); + modelBuilder.gpu_device_required(model.getGpuDevice().get().required()); + } builder.model(modelBuilder); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index d0e1531ca94..53b95e2d455 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -685,6 +685,12 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null)); onnxModel.setStatelessInterOpThreads(getIntValue(modelElement, "interop-threads", -1)); onnxModel.setStatelessIntraOpThreads(getIntValue(modelElement, "intraop-threads", -1)); + Element gpuDeviceElement = XML.getChild(modelsElement, "gpu-device"); + if (gpuDeviceElement != null) { + int gpuDevice = Integer.parseInt(gpuDeviceElement.getTextContent()); + boolean required = Boolean.parseBoolean(extractAttribute(gpuDeviceElement, "required")); + onnxModel.setGpuDevice(gpuDevice, required); + } } cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles)); diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc index 81455084ad2..888851b2db2 100644 --- a/config-model/src/main/resources/schema/containercluster.rnc +++ b/config-model/src/main/resources/schema/containercluster.rnc @@ -103,7 +103,11 @@ ModelEvaluation = element model-evaluation { attribute name { string } & element intraop-threads { xsd:nonNegativeInteger }? & element interop-threads { xsd:nonNegativeInteger }? & - element execution-mode { string "sequential" | string "parallel" }? 
+ element execution-mode { string "sequential" | string "parallel" }? & + element gpu-device { + attribute required { xsd:boolean } & + xsd:nonNegativeInteger + }? }* }? }? diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 83674d6789e..9877dd69e83 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -182,7 +182,7 @@ public class RankProfilesConfigImporter { options.setExecutionMode(onnxModelConfig.stateless_execution_mode()); options.setInterOpThreads(onnxModelConfig.stateless_interop_threads()); options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads()); - + options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required()); return new OnnxModel(name, file, options); } catch (InterruptedException e) { throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name()); diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java index 8467040e5c0..fceb63e6ae6 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java @@ -2,7 +2,6 @@ package ai.vespa.modelintegration.evaluator; -import ai.onnxruntime.OrtEnvironment; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; @@ -17,6 +16,8 @@ public class OnnxEvaluatorOptions { private OrtSession.SessionOptions.ExecutionMode executionMode; private int interOpThreads; private int intraOpThreads; + private int gpuDeviceNumber; + private boolean gpuDeviceRequired; public 
OnnxEvaluatorOptions() { // Defaults: @@ -24,6 +25,8 @@ public class OnnxEvaluatorOptions { executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL; interOpThreads = 1; intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4)); + gpuDeviceNumber = -1; + gpuDeviceRequired = false; } public OrtSession.SessionOptions getOptions() throws OrtException { @@ -32,9 +35,24 @@ public class OnnxEvaluatorOptions { options.setExecutionMode(executionMode); options.setInterOpNumThreads(interOpThreads); options.setIntraOpNumThreads(intraOpThreads); + addCuda(options); return options; } + private void addCuda(OrtSession.SessionOptions options) throws OrtException { + if (gpuDeviceNumber < 0) return; + try { + options.addCUDA(gpuDeviceNumber); + } catch (OrtException e) { + if (e.getCode() != OrtException.OrtErrorCode.ORT_EP_FAIL) { + throw e; + } + if (gpuDeviceRequired) { + throw new IllegalArgumentException("GPU device " + gpuDeviceNumber + " is required, but CUDA backend could not be initialized", e); + } + } + } + public void setExecutionMode(String mode) { if ("parallel".equalsIgnoreCase(mode)) { executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL; @@ -55,4 +73,9 @@ public class OnnxEvaluatorOptions { } } + public void setGpuDevice(int deviceNumber, boolean required) { + this.gpuDeviceNumber = deviceNumber; + this.gpuDeviceRequired = required; + } + } diff --git a/searchcore/src/vespa/searchcore/config/onnx-models.def b/searchcore/src/vespa/searchcore/config/onnx-models.def index c117d6f162c..b8f5d319075 100644 --- a/searchcore/src/vespa/searchcore/config/onnx-models.def +++ b/searchcore/src/vespa/searchcore/config/onnx-models.def @@ -11,3 +11,5 @@ model[].dry_run_on_setup bool default=false model[].stateless_execution_mode string default="" model[].stateless_interop_threads int default=-1 model[].stateless_intraop_threads int default=-1 +model[].gpu_device int default=-1 +model[].gpu_device_required bool 
default=false