about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Bjørn Christian Seime <bjorn.christian@seime.no> 2023-10-26 11:29:01 +0200
committer: GitHub <noreply@github.com> 2023-10-26 11:29:01 +0200
commit: be43a57b172ec346d1615348977ad45959a4db1c (patch)
tree: 56d0a1ca585db6528b9d57598b2de876b7a6ed93
parent: 8abb16bdd81cb19311b58254d959169e9bc657c2 (diff)
parent: 8aea914273b599112047f262e7598c9b9dc2e05a (diff)
Merge pull request #29107 from vespa-engine/jobergum/cuda-session-logging
Less verbose logging when failing to find CUDA and it is optional
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java 4
-rw-r--r-- model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java 51
2 files changed, 53 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 628fe933bf5..cd698eb1647 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -151,8 +151,8 @@ public class OnnxEvaluator implements AutoCloseable {
throw new IllegalArgumentException("No such file: " + model.path().get());
}
if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) {
- LOG.log(Level.WARNING, "Failed to create session with CUDA using GPU device " +
- options.gpuDeviceNumber() + ". Falling back to CPU", e);
+ LOG.log(Level.INFO, "Failed to create session with CUDA using GPU device " +
+ options.gpuDeviceNumber() + ". Falling back to CPU. Reason: " + e.getMessage());
return createSession(model, runtime, options, false);
}
if (isCudaError(e)) {
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index db2e9db1277..9bb01fc8073 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -9,9 +9,16 @@ import org.junit.Test;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
+import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.logging.Handler;
+import java.util.logging.LogRecord;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeTrue;
@@ -170,6 +177,29 @@ public class OnnxEvaluatorTest {
evaluator.close();
}
+    /**
+     * Checks that requesting GPU device 0 without requiring it produces exactly one
+     * INFO-level record on the {@link OnnxEvaluator} logger, carrying the CPU fallback message.
+     *
+     * NOTE(review): the expected message assumes the host has no CUDA shared provider;
+     * on a machine where CUDA loads successfully no fallback record would be emitted
+     * and this test would fail - confirm that is acceptable for the CI environment.
+     */
+    @Test
+    public void testLoggingMessages() throws IOException {
+        assumeTrue(OnnxRuntime.isRuntimeAvailable());
+        Logger logger = Logger.getLogger(OnnxEvaluator.class.getName());
+        CustomLogHandler logHandler = new CustomLogHandler();
+        logger.addHandler(logHandler);
+        try {
+            var runtime = new OnnxRuntime();
+            var model = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx"));
+            OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
+            options.setGpuDevice(0);
+            var evaluator = runtime.evaluatorOf(model, options);
+            evaluator.close();
+            List<LogRecord> records = logHandler.getLogRecords();
+            assertEquals(1, records.size());
+            assertEquals(Level.INFO, records.get(0).getLevel());
+            String message = records.get(0).getMessage();
+            assertEquals("Failed to create session with CUDA using GPU device 0. " +
+                         "Falling back to CPU. Reason: Error code - ORT_EP_FAIL - message:" +
+                         " Failed to find CUDA shared provider", message);
+        } finally {
+            // Detach even when an assertion fails, so the capturing handler
+            // does not stay registered on the shared logger for later tests.
+            logger.removeHandler(logHandler);
+        }
+    }
+
private void assertEvaluate(OnnxRuntime runtime, String model, String output, String... input) {
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model);
Map<String, Tensor> inputs = new HashMap<>();
@@ -182,4 +212,25 @@ public class OnnxEvaluatorTest {
assertEquals(expected.type().valueType(), result.type().valueType());
}
+    /**
+     * Minimal {@link Handler} that keeps every published {@link LogRecord} in memory
+     * so a test can inspect what was logged. It performs no formatting or filtering.
+     */
+    static class CustomLogHandler extends Handler {
+
+        /** Records in the order they were published; never reassigned. */
+        private final List<LogRecord> records = new ArrayList<>();
+
+        /** Stores the record as-is; nothing is written anywhere. */
+        @Override
+        public void publish(LogRecord record) {
+            records.add(record);
+        }
+
+        /** No-op: there is no buffered output to flush. */
+        @Override
+        public void flush() {
+        }
+
+        /** No-op: the handler holds no external resources. */
+        @Override
+        public void close() throws SecurityException {
+        }
+
+        /**
+         * Returns an immutable snapshot of the records captured so far, so callers
+         * cannot modify the handler's internal list.
+         */
+        public List<LogRecord> getLogRecords() {
+            return List.copyOf(records);
+        }
+    }
+
}