From c47f27f1c3362b459e276c59ebcd09ab259b710e Mon Sep 17 00:00:00 2001 From: Martin Polden Date: Tue, 7 Feb 2023 16:04:57 +0100 Subject: Allow fallback to CPU if nodes are provisioned without GPU --- .../modelintegration/evaluator/OnnxEvaluator.java | 12 +++++++---- .../evaluator/OnnxEvaluatorOptions.java | 25 ++++++++++++++-------- 2 files changed, 24 insertions(+), 13 deletions(-) (limited to 'model-integration') diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index ebed464421b..563ef911f8f 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -31,7 +31,7 @@ public class OnnxEvaluator { public OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options) { environment = OrtEnvironment.getEnvironment(); - session = createSession(modelPath, environment, options); + session = createSession(modelPath, environment, options, true); } public Tensor evaluate(Map inputs, String output) { @@ -86,18 +86,22 @@ public class OnnxEvaluator { } } - private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options) { + private static OrtSession createSession(String modelPath, OrtEnvironment environment, OnnxEvaluatorOptions options, boolean tryCuda) { if (options == null) { options = new OnnxEvaluatorOptions(); } try { - return environment.createSession(modelPath, options.getOptions()); + return environment.createSession(modelPath, options.getOptions(tryCuda && options.requestingGpu())); } catch (OrtException e) { if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) { throw new IllegalArgumentException("No such file: " + modelPath); } + if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) { + // Failed in CUDA native code, but GPU device is optional, so we can proceed without it + return createSession(modelPath, environment, options, false); + } if (isCudaError(e)) { - throw new IllegalArgumentException("GPU device " + options.gpuDevice() + " requested, but CUDA initialization failed", e); + throw new IllegalArgumentException("GPU device is requested, but CUDA initialization failed", e); } throw new RuntimeException("ONNX Runtime exception", e); } diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java index f838a3b3f7f..b6de9698f1a 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java @@ -16,7 +16,8 @@ public class OnnxEvaluatorOptions { private OrtSession.SessionOptions.ExecutionMode executionMode; private int interOpThreads; private int intraOpThreads; - private int gpuDevice; + private int gpuDeviceNumber; + private boolean gpuDeviceRequired; public OnnxEvaluatorOptions() { // Defaults: @@ -24,17 +25,18 @@ public class OnnxEvaluatorOptions { executionMode = OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL; interOpThreads = 1; intraOpThreads = Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / 4)); - gpuDevice = -1; + gpuDeviceNumber = -1; + gpuDeviceRequired = false; } - public OrtSession.SessionOptions getOptions() throws OrtException { + public OrtSession.SessionOptions getOptions(boolean loadCuda) throws OrtException { OrtSession.SessionOptions options = new OrtSession.SessionOptions(); options.setOptimizationLevel(optimizationLevel); options.setExecutionMode(executionMode); options.setInterOpNumThreads(interOpThreads); options.setIntraOpNumThreads(intraOpThreads); - if (gpuDevice > -1) { - options.addCUDA(gpuDevice); + if (loadCuda) { + options.addCUDA(gpuDeviceNumber); } return options; } @@ -59,12 +61,17 @@ public class OnnxEvaluatorOptions { } } - public void setGpuDevice(int deviceNumber) { - this.gpuDevice = deviceNumber; + public void setGpuDevice(int deviceNumber, boolean required) { + this.gpuDeviceNumber = deviceNumber; + this.gpuDeviceRequired = required; } - public int gpuDevice() { - return gpuDevice; + public boolean requestingGpu() { + return gpuDeviceNumber > -1; + } + + public boolean gpuDeviceRequired() { + return gpuDeviceRequired; } } -- cgit v1.2.3