author    Bjørn Christian Seime <bjorncs@yahooinc.com>  2023-05-08 15:15:19 +0200
committer GitHub <noreply@github.com>                   2023-05-08 15:15:19 +0200
commit    52c04d5a633f4571300f75e3024c5198d484a267 (patch)
tree      405bc0110d209df3c7f22de3e5e579f7b0830797 /model-integration/src
parent    bb2930c8f815ee74c3e042b551effc03b1111f0c (diff)
parent    61231ac123e46c459fb2b996bf0eeedb68529ceb (diff)
Merge pull request #27026 from vespa-engine/bjorncs/embedder-onnx-gpu
Bjorncs/embedder onnx gpu
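Enable GPU execution for ONNX-based embedders. The Hugging Face embedder now selects the device
through transformerGpuDevice alone (the separate transformerGpuRequired flag is removed), the BERT
base embedder gains an onnxGpuDevice option, and OnnxRuntime upgrades a requested GPU to "required"
when the injected OnnxModelsConfig reports GPUs available on the node.

A minimal sketch of the evaluator API after this change; the injected runtime instance and the
model path are illustrative assumptions, not part of this diff:

    import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
    import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
    import ai.vespa.modelintegration.evaluator.OnnxRuntime;

    OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
    options.setExecutionMode("sequential");
    options.setThreads(1, -4);     // inter-op / intra-op threads; negative n means CPUs / (-n)
    options.setGpuDevice(0);       // request GPU device 0 without forcing it
    // OnnxRuntime.overrideOptions() marks the GPU as required when the node actually has GPUs
    OnnxEvaluator evaluator = runtime.evaluatorOf("models/embedder.onnx", options);  // runtime: injected OnnxRuntime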
Diffstat (limited to 'model-integration/src')
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java                                 |  1
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java                  |  2
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java            | 17
-rw-r--r--  model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java                     | 26
-rw-r--r--  model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def                  |  2
-rw-r--r--  model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def   |  4
6 files changed, 42 insertions(+), 10 deletions(-)
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index 8e5211ccff1..b172ef7beee 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -58,6 +58,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
options.setExecutionMode(config.onnxExecutionMode().toString());
options.setThreads(config.onnxInterOpThreads(), config.onnxIntraOpThreads());
+ if (config.onnxGpuDevice() >= 0) options.setGpuDevice(config.onnxGpuDevice());
tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build();
this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options);
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 21dd326689c..cc13254385b 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -40,7 +40,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString()));
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
- onnxOpts.setGpuDevice(config.transformerGpuDevice(), config.transformerGpuRequired());
+ onnxOpts.setGpuDevice(config.transformerGpuDevice());
onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 4a35f4275fa..6048be8aca9 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -17,7 +17,7 @@ import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
*/
public class OnnxEvaluatorOptions {
- private final OrtSession.SessionOptions.OptLevel optimizationLevel;
+ private OrtSession.SessionOptions.OptLevel optimizationLevel;
private OrtSession.SessionOptions.ExecutionMode executionMode;
private int interOpThreads;
private int intraOpThreads;
@@ -86,6 +86,8 @@ public class OnnxEvaluatorOptions {
this.gpuDeviceRequired = required;
}
+ public void setGpuDevice(int deviceNumber) { gpuDeviceNumber = deviceNumber; }
+
public boolean requestingGpu() {
return gpuDeviceNumber > -1;
}
@@ -94,6 +96,19 @@ public class OnnxEvaluatorOptions {
return gpuDeviceRequired;
}
+ public int gpuDeviceNumber() { return gpuDeviceNumber; }
+
+ public OnnxEvaluatorOptions copy() {
+ var copy = new OnnxEvaluatorOptions();
+ copy.gpuDeviceNumber = gpuDeviceNumber;
+ copy.gpuDeviceRequired = gpuDeviceRequired;
+ copy.executionMode = executionMode;
+ copy.interOpThreads = interOpThreads;
+ copy.intraOpThreads = intraOpThreads;
+ copy.optimizationLevel = optimizationLevel;
+ return copy;
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
index ece1db55c1e..ab44a2ae33f 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
@@ -10,6 +10,7 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.jdisc.ResourceReference;
import com.yahoo.jdisc.refcount.DebugReferencesWithStack;
import com.yahoo.jdisc.refcount.References;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import net.jpountz.xxhash.XXHashFactory;
import java.io.IOException;
@@ -52,17 +53,24 @@ public class OnnxRuntime extends AbstractComponent {
private final Object monitor = new Object();
private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>();
private final OrtSessionFactory factory;
+ private final int gpusAvailable;
- @Inject public OnnxRuntime() { this(defaultFactory); }
+ // For test use only
+ public OnnxRuntime() { this(defaultFactory, new OnnxModelsConfig.Builder().build()); }
- OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; }
+ @Inject public OnnxRuntime(OnnxModelsConfig cfg) { this(defaultFactory, cfg); }
+
+ OnnxRuntime(OrtSessionFactory factory, OnnxModelsConfig cfg) {
+ this.factory = factory;
+ this.gpusAvailable = cfg.gpu().count();
+ }
public OnnxEvaluator evaluatorOf(byte[] model) {
return new OnnxEvaluator(model, null, this);
}
public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) {
- return new OnnxEvaluator(model, options, this);
+ return new OnnxEvaluator(model, overrideOptions(options), this);
}
public OnnxEvaluator evaluatorOf(String modelPath) {
@@ -70,7 +78,7 @@ public class OnnxRuntime extends AbstractComponent {
}
public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
- return new OnnxEvaluator(modelPath, options, this);
+ return new OnnxEvaluator(modelPath, overrideOptions(options), this);
}
public static OrtEnvironment ortEnvironment() {
@@ -167,6 +175,16 @@ public class OnnxRuntime extends AbstractComponent {
}
}
+ private OnnxEvaluatorOptions overrideOptions(OnnxEvaluatorOptions opts) {
+ // Set GPU device required if GPU requested and GPUs are available on system
+ if (gpusAvailable > 0 && opts.requestingGpu() && !opts.gpuDeviceRequired()) {
+ var copy = opts.copy();
+ copy.setGpuDevice(opts.gpuDeviceNumber(), true);
+ return copy;
+ }
+ return opts;
+ }
+
int sessionsCached() { synchronized(monitor) { return sessions.size(); } }
static class ReferencedOrtSession implements AutoCloseable {
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
index ef42d81e1fe..e37a33d3b81 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
@@ -28,4 +28,4 @@ transformerOutput string default=output_0
onnxExecutionMode enum { parallel, sequential } default=sequential
onnxInterOpThreads int default=1
onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
-
+onnxGpuDevice int default=-1
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
index adc8f653168..584f23046ba 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
@@ -17,9 +17,6 @@ transformerAttentionMask string default=attention_mask
# Output name
transformerOutput string default=last_hidden_state
-# GPU configuration
-transformerGpuDevice int default=-1
-transformerGpuRequired bool default=false
# Normalize tensors from tokenizer
normalize bool default=false
@@ -28,3 +25,4 @@ normalize bool default=false
transformerExecutionMode enum { parallel, sequential } default=sequential
transformerInterOpThreads int default=1
transformerIntraOpThreads int default=-4
+transformerGpuDevice int default=-1