diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-08 13:55:00 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-08 13:55:00 +0200 |
commit | b4073925d4ce5c08ebc91620219541cb4114ac52 (patch) | |
tree | ee96ae200505f7ca0d6a8cc46855577692fecec8 /model-integration | |
parent | 85289c1c179d3469bfe7681ad3d04488185e6c7d (diff) |
Require GPU when available for ONNX evaluation in global-phase and embedders
Diffstat (limited to 'model-integration')
3 files changed, 42 insertions, 5 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml index c27ed9d2c31..d5d7ae534a4 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -40,6 +40,12 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>searchcore</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>searchlib</artifactId> <version>${project.version}</version> <scope>provided</scope> diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java index 4a35f4275fa..76a2031171f 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java @@ -17,7 +17,7 @@ import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL; */ public class OnnxEvaluatorOptions { - private final OrtSession.SessionOptions.OptLevel optimizationLevel; + private OrtSession.SessionOptions.OptLevel optimizationLevel; private OrtSession.SessionOptions.ExecutionMode executionMode; private int interOpThreads; private int intraOpThreads; @@ -94,6 +94,19 @@ public class OnnxEvaluatorOptions { return gpuDeviceRequired; } + public int gpuDeviceNumber() { return gpuDeviceNumber; } + + public OnnxEvaluatorOptions copy() { + var copy = new OnnxEvaluatorOptions(); + copy.gpuDeviceNumber = gpuDeviceNumber; + copy.gpuDeviceRequired = gpuDeviceRequired; + copy.executionMode = executionMode; + copy.interOpThreads = interOpThreads; + copy.intraOpThreads = intraOpThreads; + copy.optimizationLevel = optimizationLevel; + return copy; + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java index ece1db55c1e..ab44a2ae33f 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java @@ -10,6 +10,7 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.jdisc.ResourceReference; import com.yahoo.jdisc.refcount.DebugReferencesWithStack; import com.yahoo.jdisc.refcount.References; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import net.jpountz.xxhash.XXHashFactory; import java.io.IOException; @@ -52,17 +53,24 @@ public class OnnxRuntime extends AbstractComponent { private final Object monitor = new Object(); private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>(); private final OrtSessionFactory factory; + private final int gpusAvailable; - @Inject public OnnxRuntime() { this(defaultFactory); } + // For test use only + public OnnxRuntime() { this(defaultFactory, new OnnxModelsConfig.Builder().build()); } - OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; } + @Inject public OnnxRuntime(OnnxModelsConfig cfg) { this(defaultFactory, cfg); } + + OnnxRuntime(OrtSessionFactory factory, OnnxModelsConfig cfg) { + this.factory = factory; + this.gpusAvailable = cfg.gpu().count(); + } public OnnxEvaluator evaluatorOf(byte[] model) { return new OnnxEvaluator(model, null, this); } public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) { - return new OnnxEvaluator(model, options, this); + return new OnnxEvaluator(model, overrideOptions(options), this); } public OnnxEvaluator evaluatorOf(String modelPath) { @@ -70,7 +78,7 @@ public class OnnxRuntime extends AbstractComponent { } public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) { - return new OnnxEvaluator(modelPath, options, this); + return new OnnxEvaluator(modelPath, overrideOptions(options), this); } public static OrtEnvironment ortEnvironment() { @@ -167,6 +175,16 @@ public class OnnxRuntime extends AbstractComponent { } } + private OnnxEvaluatorOptions overrideOptions(OnnxEvaluatorOptions opts) { + // Set GPU device required if GPU requested and GPUs are available on system + if (gpusAvailable > 0 && opts.requestingGpu() && !opts.gpuDeviceRequired()) { + var copy = opts.copy(); + copy.setGpuDevice(opts.gpuDeviceNumber(), true); + return copy; + } + return opts; + } + int sessionsCached() { synchronized(monitor) { return sessions.size(); } } static class ReferencedOrtSession implements AutoCloseable { |