diff options
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index 239f2c74d7b..627f450502f 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -134,13 +134,18 @@ public class OnnxEvaluator implements AutoCloseable { } } - private static ReferencedOrtSession createSession( - ModelPathOrData model, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) { + private static ReferencedOrtSession createSession(ModelPathOrData model, OnnxRuntime runtime, + OnnxEvaluatorOptions options, boolean tryCuda) { if (options == null) { options = new OnnxEvaluatorOptions(); } try { - return runtime.acquireSession(model, options, tryCuda && options.requestingGpu()); + boolean loadCuda = tryCuda && options.requestingGpu(); + ReferencedOrtSession session = runtime.acquireSession(model, options, loadCuda); + if (loadCuda) { + LOG.log(Level.INFO, "Created session with CUDA using GPU device " + options.gpuDeviceNumber()); + } + return session; } catch (OrtException e) { if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) { throw new IllegalArgumentException("No such file: " + model.path().get()); |