summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2023-01-23 14:34:35 +0100
committerMartin Polden <mpolden@mpolden.no>2023-01-23 14:34:35 +0100
commitc9ba67dd5b8402125ea84a5c5fd12562ca7ebd15 (patch)
tree415a4d8ba38f8a9b52cd158f1540d19cf337bccf
parent4ffc7ef27da2a8d824c9f04a80f89e569e36b322 (diff)
Support configuration of GPU device to use in ONNX model
-rw-r--r--config-model/src/main/java/com/yahoo/schema/OnnxModel.java19
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java6
-rw-r--r--config-model/src/main/resources/schema/containercluster.rnc6
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java25
-rw-r--r--searchcore/src/vespa/searchcore/config/onnx-models.def2
7 files changed, 61 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 26a0b3e595d..ae6f1fd96e4 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -24,6 +24,7 @@ public class OnnxModel extends DistributableResource {
private String statelessExecutionMode = null;
private Integer statelessInterOpThreads = null;
private Integer statelessIntraOpThreads = null;
+ private GpuDevice gpuDevice = null;
public OnnxModel(String name) {
super(name);
@@ -113,8 +114,26 @@ public class OnnxModel extends DistributableResource {
}
}
+ public void setGpuDevice(int deviceNumber, boolean required) {
+ if (deviceNumber >= 0) {
+ this.gpuDevice = new GpuDevice(deviceNumber, required);
+ }
+ }
+
public Optional<Integer> getStatelessIntraOpThreads() {
return Optional.ofNullable(statelessIntraOpThreads);
}
+ public Optional<GpuDevice> getGpuDevice() {
+ return Optional.ofNullable(gpuDevice);
+ }
+
+ public record GpuDevice(int deviceNumber, boolean required) {
+
+ public GpuDevice {
+ if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
+ }
+
+ }
+
}
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
index b5c3909c78c..cff0776fddd 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
@@ -51,6 +51,10 @@ public class FileDistributedOnnxModels {
modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
if (model.getStatelessIntraOpThreads().isPresent())
modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
+ if (model.getGpuDevice().isPresent()) {
+ modelBuilder.gpu_device(model.getGpuDevice().get().deviceNumber());
+ modelBuilder.gpu_device_required(model.getGpuDevice().get().required());
+ }
builder.model(modelBuilder);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index d0e1531ca94..53b95e2d455 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -685,6 +685,12 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null));
onnxModel.setStatelessInterOpThreads(getIntValue(modelElement, "interop-threads", -1));
onnxModel.setStatelessIntraOpThreads(getIntValue(modelElement, "intraop-threads", -1));
+ Element gpuDeviceElement = XML.getChild(modelsElement, "gpu-device");
+ if (gpuDeviceElement != null) {
+ int gpuDevice = Integer.parseInt(gpuDeviceElement.getTextContent());
+ boolean required = Boolean.parseBoolean(extractAttribute(gpuDeviceElement, "required"));
+ onnxModel.setGpuDevice(gpuDevice, required);
+ }
}
cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles));
diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc
index 81455084ad2..888851b2db2 100644
--- a/config-model/src/main/resources/schema/containercluster.rnc
+++ b/config-model/src/main/resources/schema/containercluster.rnc
@@ -103,7 +103,11 @@ ModelEvaluation = element model-evaluation {
attribute name { string } &
element intraop-threads { xsd:nonNegativeInteger }? &
element interop-threads { xsd:nonNegativeInteger }? &
- element execution-mode { string "sequential" | string "parallel" }?
+ element execution-mode { string "sequential" | string "parallel" }? &
+ element gpu-device {
+ attribute required { xsd:boolean } &
+ xsd:nonNegativeInteger
+ }?
}*
}?
}?
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 83674d6789e..9877dd69e83 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -182,7 +182,7 @@ public class RankProfilesConfigImporter {
options.setExecutionMode(onnxModelConfig.stateless_execution_mode());
options.setInterOpThreads(onnxModelConfig.stateless_interop_threads());
options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads());
-
+ options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required());
return new OnnxModel(name, file, options);
} catch (InterruptedException e) {
throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name());
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 8467040e5c0..fceb63e6ae6 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -2,7 +2,6 @@
package ai.vespa.modelintegration.evaluator;
-import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
@@ -17,6 +16,8 @@ public class OnnxEvaluatorOptions {
private OrtSession.SessionOptions.ExecutionMode executionMode;
private int interOpThreads;
private int intraOpThreads;
+ private int gpuDeviceNumber;
+ private boolean gpuDeviceRequired;
public OnnxEvaluatorOptions() {
// Defaults:
@@ -24,6 +25,8 @@ public class OnnxEvaluatorOptions {
executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
interOpThreads = 1;
intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
+ gpuDeviceNumber = -1;
+ gpuDeviceRequired = false;
}
public OrtSession.SessionOptions getOptions() throws OrtException {
@@ -32,9 +35,24 @@ public class OnnxEvaluatorOptions {
options.setExecutionMode(executionMode);
options.setInterOpNumThreads(interOpThreads);
options.setIntraOpNumThreads(intraOpThreads);
+ addCuda(options);
return options;
}
+ private void addCuda(OrtSession.SessionOptions options) throws OrtException {
+ if (gpuDeviceNumber < 0) return;
+ try {
+ options.addCUDA(gpuDeviceNumber);
+ } catch (OrtException e) {
+ if (e.getCode() != OrtException.OrtErrorCode.ORT_EP_FAIL) {
+ throw e;
+ }
+ if (gpuDeviceRequired) {
+ throw new IllegalArgumentException("GPU device " + gpuDeviceNumber + " is required, but CUDA backend could not be initialized", e);
+ }
+ }
+ }
+
public void setExecutionMode(String mode) {
if ("parallel".equalsIgnoreCase(mode)) {
executionMode = OrtSession.SessionOptions.ExecutionMode.PARALLEL;
@@ -55,4 +73,9 @@ public class OnnxEvaluatorOptions {
}
}
+ public void setGpuDevice(int deviceNumber, boolean required) {
+ this.gpuDeviceNumber = deviceNumber;
+ this.gpuDeviceRequired = required;
+ }
+
}
diff --git a/searchcore/src/vespa/searchcore/config/onnx-models.def b/searchcore/src/vespa/searchcore/config/onnx-models.def
index c117d6f162c..b8f5d319075 100644
--- a/searchcore/src/vespa/searchcore/config/onnx-models.def
+++ b/searchcore/src/vespa/searchcore/config/onnx-models.def
@@ -11,3 +11,5 @@ model[].dry_run_on_setup bool default=false
model[].stateless_execution_mode string default=""
model[].stateless_interop_threads int default=-1
model[].stateless_intraop_threads int default=-1
+model[].gpu_device int default=-1
+model[].gpu_device_required bool default=false