diff options
author | Martin Polden <mpolden@mpolden.no> | 2023-01-26 14:42:15 +0100 |
---|---|---|
committer | Martin Polden <mpolden@mpolden.no> | 2023-01-26 14:42:15 +0100 |
commit | 3d4f199ed58cce65680c775df824ae376e518a23 (patch) | |
tree | ba5cb7bc03ccbdbc0cdc724f63c031c2683df309 /model-integration/src/main/java | |
parent | ed5d394998b5538e2a3330409391ab18acadb1df (diff) |
Remove 'required' attribute
Diffstat (limited to 'model-integration/src/main/java')
2 files changed, 14 insertions, 22 deletions
```diff
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index bb40333f9b3..ebed464421b 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -31,7 +31,7 @@ public class OnnxEvaluator {
 
     public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) {
         environment = OrtEnvironment.getEnvironment();
-        session = createSession(modelPath, environment, options, true);
+        session = createSession(modelPath, environment, options);
     }
 
     public Tensor evaluate(Map<String, Tensor> inputs, String output) {
@@ -86,19 +86,18 @@ public class OnnxEvaluator {
         }
     }
 
-    private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) {
+    private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options) {
         if (options == null) {
             options = new OnnxEvaluatorOptions();
         }
         try {
-            return environment.createSession(modelPath, options.getOptions(tryCuda && options.hasGpuDevice()));
+            return environment.createSession(modelPath, options.getOptions());
         } catch (OrtException e) {
             if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
                 throw new IllegalArgumentException("No such file: " + modelPath);
             }
-            if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) {
-                // Failed in CUDA native code, but GPU device is optional, so we can proceed without it
-                return createSession(modelPath, environment, options, false);
+            if (isCudaError(e)) {
+                throw new IllegalArgumentException("GPU device " + options.gpuDevice() + " requested, but CUDA initialization failed", e);
             }
             throw new RuntimeException("ONNX Runtime exception", e);
         }
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 9d82531df02..f838a3b3f7f 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -16,8 +16,7 @@ public class OnnxEvaluatorOptions {
     private OrtSession.SessionOptions.ExecutionMode executionMode;
     private int interOpThreads;
     private int intraOpThreads;
-    private int gpuDeviceNumber;
-    private boolean gpuDeviceRequired;
+    private int gpuDevice;
 
     public OnnxEvaluatorOptions() {
         // Defaults:
@@ -25,18 +24,17 @@ public class OnnxEvaluatorOptions {
         executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
         interOpThreads = 1;
         intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4));
-        gpuDeviceNumber = -1;
-        gpuDeviceRequired = false;
+        gpuDevice = -1;
     }
 
-    public OrtSession.SessionOptions getOptions(boolean loadCuda) throws OrtException {
+    public OrtSession.SessionOptions getOptions() throws OrtException {
         OrtSession.SessionOptions options = new OrtSession.SessionOptions();
         options.setOptimizationLevel(optimizationLevel);
         options.setExecutionMode(executionMode);
         options.setInterOpNumThreads(interOpThreads);
         options.setIntraOpNumThreads(intraOpThreads);
-        if (loadCuda) {
-            options.addCUDA(gpuDeviceNumber);
+        if (gpuDevice > -1) {
+            options.addCUDA(gpuDevice);
         }
         return options;
     }
@@ -61,17 +59,12 @@ public class OnnxEvaluatorOptions {
         }
     }
 
-    public void setGpuDevice(int deviceNumber, boolean required) {
-        this.gpuDeviceNumber = deviceNumber;
-        this.gpuDeviceRequired = required;
+    public void setGpuDevice(int deviceNumber) {
+        this.gpuDevice = deviceNumber;
     }
 
-    public boolean hasGpuDevice() {
-        return gpuDeviceNumber > -1;
-    }
-
-    public boolean gpuDeviceRequired() {
-        return gpuDeviceRequired;
+    public int gpuDevice() {
+        return gpuDevice;
     }
 
 }
```