summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorMartin Polden <mpolden@mpolden.no>2023-01-26 14:42:15 +0100
committerMartin Polden <mpolden@mpolden.no>2023-01-26 14:42:15 +0100
commit3d4f199ed58cce65680c775df824ae376e518a23 (patch)
treeba5cb7bc03ccbdbc0cdc724f63c031c2683df309 /model-integration
parented5d394998b5538e2a3330409391ab18acadb1df (diff)
Remove 'required' attribute
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java25
2 files changed, 14 insertions, 22 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index bb40333f9b3..ebed464421b 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -31,7 +31,7 @@ public class OnnxEvaluator {
public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
environment = OrtEnvironment.getEnvironment();
- session = createSession(modelPath, environment, options, true);
+ session = createSession(modelPath, environment, options);
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
@@ -86,19 +86,18 @@ public class OnnxEvaluator {
}
}
- private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) {
+ private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options) {
if (options == null) {
options = new OnnxEvaluatorOptions();
}
try {
- return environment.createSession(modelPath, options.getOptions(tryCuda && options.hasGpuDevice()));
+ return environment.createSession(modelPath, options.getOptions());
} catch (OrtException e) {
if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
throw new IllegalArgumentException("No such file: " + modelPath);
}
- if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) {
- // Failed in CUDA native code, but GPU device is optional, so we can proceed without it
- return createSession(modelPath, environment, options, false);
+ if (isCudaError(e)) {
+ throw new IllegalArgumentException("GPU device " + options.gpuDevice() + " requested, but CUDA initialization failed", e);
}
throw new RuntimeException("ONNX Runtime exception", e);
}
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 9d82531df02..f838a3b3f7f 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -16,8 +16,7 @@ public class OnnxEvaluatorOptions {
private OrtSession.SessionOptions.ExecutionMode executionMode;
private int interOpThreads;
private int intraOpThreads;
- private int gpuDeviceNumber;
- private boolean gpuDeviceRequired;
+ private int gpuDevice;
public OnnxEvaluatorOptions() {
// Defaults:
@@ -25,18 +24,17 @@ public class OnnxEvaluatorOptions {
executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
interOpThreads = 1;
intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
- gpuDeviceNumber = -1;
- gpuDeviceRequired = false;
+ gpuDevice = -1;
}
- public OrtSession.SessionOptions getOptions(boolean loadCuda) throws OrtException {
+ public OrtSession.SessionOptions getOptions() throws OrtException {
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
options.setOptimizationLevel(optimizationLevel);
options.setExecutionMode(executionMode);
options.setInterOpNumThreads(interOpThreads);
options.setIntraOpNumThreads(intraOpThreads);
- if (loadCuda) {
- options.addCUDA(gpuDeviceNumber);
+ if (gpuDevice > -1) {
+ options.addCUDA(gpuDevice);
}
return options;
}
@@ -61,17 +59,12 @@ public class OnnxEvaluatorOptions {
}
}
- public void setGpuDevice(int deviceNumber, boolean required) {
- this.gpuDeviceNumber = deviceNumber;
- this.gpuDeviceRequired = required;
+ public void setGpuDevice(int deviceNumber) {
+ this.gpuDevice = deviceNumber;
}
- public boolean hasGpuDevice() {
- return gpuDeviceNumber > -1;
- }
-
- public boolean gpuDeviceRequired() {
- return gpuDeviceRequired;
+ public int gpuDevice() {
+ return gpuDevice;
}
}