aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-08 13:55:00 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-08 13:55:00 +0200
commitb4073925d4ce5c08ebc91620219541cb4114ac52 (patch)
treeee96ae200505f7ca0d6a8cc46855577692fecec8 /model-integration
parent85289c1c179d3469bfe7681ad3d04488185e6c7d (diff)
Require GPU when available for ONNX evaluation in global-phase and embedders
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/pom.xml6
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java15
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java26
3 files changed, 42 insertions, 5 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index c27ed9d2c31..d5d7ae534a4 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -40,6 +40,12 @@
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
+ <artifactId>searchcore</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
<artifactId>searchlib</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 4a35f4275fa..76a2031171f 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -17,7 +17,7 @@ import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
*/
public class OnnxEvaluatorOptions {
- private final OrtSession.SessionOptions.OptLevel optimizationLevel;
+ private OrtSession.SessionOptions.OptLevel optimizationLevel;
private OrtSession.SessionOptions.ExecutionMode executionMode;
private int interOpThreads;
private int intraOpThreads;
@@ -94,6 +94,19 @@ public class OnnxEvaluatorOptions {
return gpuDeviceRequired;
}
+ public int gpuDeviceNumber() { return gpuDeviceNumber; }
+
+ public OnnxEvaluatorOptions copy() {
+ var copy = new OnnxEvaluatorOptions();
+ copy.gpuDeviceNumber = gpuDeviceNumber;
+ copy.gpuDeviceRequired = gpuDeviceRequired;
+ copy.executionMode = executionMode;
+ copy.interOpThreads = interOpThreads;
+ copy.intraOpThreads = intraOpThreads;
+ copy.optimizationLevel = optimizationLevel;
+ return copy;
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
index ece1db55c1e..ab44a2ae33f 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
@@ -10,6 +10,7 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.jdisc.ResourceReference;
import com.yahoo.jdisc.refcount.DebugReferencesWithStack;
import com.yahoo.jdisc.refcount.References;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import net.jpountz.xxhash.XXHashFactory;
import java.io.IOException;
@@ -52,17 +53,24 @@ public class OnnxRuntime extends AbstractComponent {
private final Object monitor = new Object();
private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>();
private final OrtSessionFactory factory;
+ private final int gpusAvailable;
- @Inject public OnnxRuntime() { this(defaultFactory); }
+ // For test use only
+ public OnnxRuntime() { this(defaultFactory, new OnnxModelsConfig.Builder().build()); }
- OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; }
+ @Inject public OnnxRuntime(OnnxModelsConfig cfg) { this(defaultFactory, cfg); }
+
+ OnnxRuntime(OrtSessionFactory factory, OnnxModelsConfig cfg) {
+ this.factory = factory;
+ this.gpusAvailable = cfg.gpu().count();
+ }
public OnnxEvaluator evaluatorOf(byte[] model) {
return new OnnxEvaluator(model, null, this);
}
public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) {
- return new OnnxEvaluator(model, options, this);
+ return new OnnxEvaluator(model, overrideOptions(options), this);
}
public OnnxEvaluator evaluatorOf(String modelPath) {
@@ -70,7 +78,7 @@ public class OnnxRuntime extends AbstractComponent {
}
public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
- return new OnnxEvaluator(modelPath, options, this);
+ return new OnnxEvaluator(modelPath, overrideOptions(options), this);
}
public static OrtEnvironment ortEnvironment() {
@@ -167,6 +175,16 @@ public class OnnxRuntime extends AbstractComponent {
}
}
+ private OnnxEvaluatorOptions overrideOptions(OnnxEvaluatorOptions opts) {
+ // Set GPU device required if GPU requested and GPUs are available on system
+ if (gpusAvailable > 0 && opts.requestingGpu() && !opts.gpuDeviceRequired()) {
+ var copy = opts.copy();
+ copy.setGpuDevice(opts.gpuDeviceNumber(), true);
+ return copy;
+ }
+ return opts;
+ }
+
int sessionsCached() { synchronized(monitor) { return sessions.size(); } }
static class ReferencedOrtSession implements AutoCloseable {