summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/schema/OnnxModel.java16
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java7
-rw-r--r--config-model/src/main/resources/schema/containercluster.rnc3
-rw-r--r--config-model/src/test/cfg/application/onnx/services.xml2
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java3
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java25
-rw-r--r--searchcore/src/vespa/searchcore/config/onnx-models.def1
10 files changed, 25 insertions, 49 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index ae6f1fd96e4..3d96849fa15 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -24,7 +24,7 @@ public class OnnxModel extends DistributableResource {
private String statelessExecutionMode = null;
private Integer statelessInterOpThreads = null;
private Integer statelessIntraOpThreads = null;
- private GpuDevice gpuDevice = null;
+ private Integer gpuDevice = null;
public OnnxModel(String name) {
super(name);
@@ -114,9 +114,9 @@ public class OnnxModel extends DistributableResource {
}
}
- public void setGpuDevice(int deviceNumber, boolean required) {
+ public void setGpuDevice(int deviceNumber) {
if (deviceNumber >= 0) {
- this.gpuDevice = new GpuDevice(deviceNumber, required);
+ this.gpuDevice = deviceNumber;
}
}
@@ -124,16 +124,8 @@ public class OnnxModel extends DistributableResource {
return Optional.ofNullable(statelessIntraOpThreads);
}
- public Optional<GpuDevice> getGpuDevice() {
+ public Optional<Integer> getGpuDevice() {
return Optional.ofNullable(gpuDevice);
}
- public record GpuDevice(int deviceNumber, boolean required) {
-
- public GpuDevice {
- if (deviceNumber < 0) throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
- }
-
- }
-
}
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
index cff0776fddd..f63e872836e 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
@@ -52,10 +52,8 @@ public class FileDistributedOnnxModels {
if (model.getStatelessIntraOpThreads().isPresent())
modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
if (model.getGpuDevice().isPresent()) {
- modelBuilder.gpu_device(model.getGpuDevice().get().deviceNumber());
- modelBuilder.gpu_device_required(model.getGpuDevice().get().required());
+ modelBuilder.gpu_device(model.getGpuDevice().get());
}
-
builder.model(modelBuilder);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index e07b1a95e20..700393e84f3 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -685,12 +685,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
onnxModel.setStatelessExecutionMode(getStringValue(modelElement, "execution-mode", null));
onnxModel.setStatelessInterOpThreads(getIntValue(modelElement, "interop-threads", -1));
onnxModel.setStatelessIntraOpThreads(getIntValue(modelElement, "intraop-threads", -1));
- Element gpuDeviceElement = XML.getChild(modelElement, "gpu-device");
- if (gpuDeviceElement != null) {
- int gpuDevice = Integer.parseInt(gpuDeviceElement.getTextContent());
- boolean required = Boolean.parseBoolean(extractAttribute(gpuDeviceElement, "required"));
- onnxModel.setGpuDevice(gpuDevice, required);
- }
+ onnxModel.setGpuDevice(getIntValue(modelElement, "gpu-device", -1));
}
cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles));
diff --git a/config-model/src/main/resources/schema/containercluster.rnc b/config-model/src/main/resources/schema/containercluster.rnc
index 888851b2db2..b8c02b013aa 100644
--- a/config-model/src/main/resources/schema/containercluster.rnc
+++ b/config-model/src/main/resources/schema/containercluster.rnc
@@ -105,7 +105,8 @@ ModelEvaluation = element model-evaluation {
element interop-threads { xsd:nonNegativeInteger }? &
element execution-mode { string "sequential" | string "parallel" }? &
element gpu-device {
- attribute required { xsd:boolean } &
+ # TODO(mpolden): Remove this after 2023-02-01
+ attribute required { xsd:boolean }? &
xsd:nonNegativeInteger
}?
}*
diff --git a/config-model/src/test/cfg/application/onnx/services.xml b/config-model/src/test/cfg/application/onnx/services.xml
index 68c2e8530be..088bbcc4921 100644
--- a/config-model/src/test/cfg/application/onnx/services.xml
+++ b/config-model/src/test/cfg/application/onnx/services.xml
@@ -8,11 +8,11 @@
<models>
<model name="mul">
<intraop-threads>2</intraop-threads>
- <gpu-device required="false">0</gpu-device>
</model>
<model name="non-existent-model">
<interop-threads>400</interop-threads>
<execution-mode>parallel</execution-mode>
+ <gpu-device>0</gpu-device>
</model>
</models>
</onnx>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
index 8ccbe99f70a..b1e28649e9f 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/StatelessOnnxEvaluationTest.java
@@ -129,8 +129,7 @@ public class StatelessOnnxEvaluationTest {
assertEquals(2, mulModel.stateless_intraop_threads());
assertEquals(-1, mulModel.stateless_interop_threads());
assertEquals("", mulModel.stateless_execution_mode());
- assertFalse(mulModel.gpu_device_required());
- assertEquals(0, mulModel.gpu_device());
+ assertEquals(-1, mulModel.gpu_device());
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 9877dd69e83..924eed18633 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -182,7 +182,7 @@ public class RankProfilesConfigImporter {
options.setExecutionMode(onnxModelConfig.stateless_execution_mode());
options.setInterOpThreads(onnxModelConfig.stateless_interop_threads());
options.setIntraOpThreads(onnxModelConfig.stateless_intraop_threads());
- options.setGpuDevice(onnxModelConfig.gpu_device(), onnxModelConfig.gpu_device_required());
+ options.setGpuDevice(onnxModelConfig.gpu_device());
return new OnnxModel(name, file, options);
} catch (InterruptedException e) {
throw new IllegalStateException("Gave up waiting for ONNX model " + onnxModelConfig.name());
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index bb40333f9b3..ebed464421b 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -31,7 +31,7 @@ public class OnnxEvaluator {
public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
environment = OrtEnvironment.getEnvironment();
- session = createSession(modelPath, environment, options, true);
+ session = createSession(modelPath, environment, options);
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
@@ -86,19 +86,18 @@ public class OnnxEvaluator {
}
}
- private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) {
+ private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options) {
if (options == null) {
options = new OnnxEvaluatorOptions();
}
try {
- return environment.createSession(modelPath, options.getOptions(tryCuda && options.hasGpuDevice()));
+ return environment.createSession(modelPath, options.getOptions());
} catch (OrtException e) {
if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
throw new IllegalArgumentException("No such file: " + modelPath);
}
- if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) {
- // Failed in CUDA native code, but GPU device is optional, so we can proceed without it
- return createSession(modelPath, environment, options, false);
+ if (isCudaError(e)) {
+ throw new IllegalArgumentException("GPU device " + options.gpuDevice() + " requested, but CUDA initialization failed", e);
}
throw new RuntimeException("ONNX Runtime exception", e);
}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 9d82531df02..f838a3b3f7f 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -16,8 +16,7 @@ public class OnnxEvaluatorOptions {
private OrtSession.SessionOptions.ExecutionMode executionMode;
private int interOpThreads;
private int intraOpThreads;
- private int gpuDeviceNumber;
- private boolean gpuDeviceRequired;
+ private int gpuDevice;
public OnnxEvaluatorOptions() {
// Defaults:
@@ -25,18 +24,17 @@ public class OnnxEvaluatorOptions {
executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
interOpThreads = 1;
intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
- gpuDeviceNumber = -1;
- gpuDeviceRequired = false;
+ gpuDevice = -1;
}
- public OrtSession.SessionOptions getOptions(boolean loadCuda) throws OrtException {
+ public OrtSession.SessionOptions getOptions() throws OrtException {
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setOptimizationLevel(optimizationLevel);
options.setExecutionMode(executionMode);
options.setInterOpNumThreads(interOpThreads);
options.setIntraOpNumThreads(intraOpThreads);
- if (loadCuda) {
- options.addCUDA(gpuDeviceNumber);
+ if (gpuDevice > -1) {
+ options.addCUDA(gpuDevice);
}
return options;
}
@@ -61,17 +59,12 @@ public class OnnxEvaluatorOptions {
}
}
- public void setGpuDevice(int deviceNumber, boolean required) {
- this.gpuDeviceNumber = deviceNumber;
- this.gpuDeviceRequired = required;
+ public void setGpuDevice(int deviceNumber) {
+ this.gpuDevice = deviceNumber;
}
- public boolean hasGpuDevice() {
- return gpuDeviceNumber > -1;
- }
-
- public boolean gpuDeviceRequired() {
- return gpuDeviceRequired;
+ public int gpuDevice() {
+ return gpuDevice;
}
}
diff --git a/searchcore/src/vespa/searchcore/config/onnx-models.def b/searchcore/src/vespa/searchcore/config/onnx-models.def
index b8f5d319075..85b061fcd7c 100644
--- a/searchcore/src/vespa/searchcore/config/onnx-models.def
+++ b/searchcore/src/vespa/searchcore/config/onnx-models.def
@@ -12,4 +12,3 @@ model[].stateless_execution_mode string default=""
model[].stateless_interop_threads int default=-1
model[].stateless_intraop_threads int default=-1
model[].gpu_device int default=-1
-model[].gpu_device_required bool default=false