aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java11
1 files changed, 8 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 239f2c74d7b..627f450502f 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -134,13 +134,18 @@ public class OnnxEvaluator implements AutoCloseable {
}
}
- private static ReferencedOrtSession createSession(
- ModelPathOrData model, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) {
+ private static ReferencedOrtSession createSession(ModelPathOrData model, OnnxRuntime runtime,
+ OnnxEvaluatorOptions options, boolean tryCuda) {
if (options == null) {
options = new OnnxEvaluatorOptions();
}
try {
- return runtime.acquireSession(model, options, tryCuda && options.requestingGpu());
+ boolean loadCuda = tryCuda && options.requestingGpu();
+ ReferencedOrtSession session = runtime.acquireSession(model, options, loadCuda);
+ if (loadCuda) {
+ LOG.log(Level.INFO, "Created session with CUDA using GPU device " + options.gpuDeviceNumber());
+ }
+ return session;
} catch (OrtException e) {
if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
throw new IllegalArgumentException("No such file: " + model.path().get());