aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-08 14:13:14 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-08 14:13:14 +0200
commitf7b3f6c02a27fc0ca98c3b87f96b9b6f1b652e32 (patch)
treec837969ef00b84dcf0d316023a9b1954f025445f /model-integration/src/main/java
parentb4073925d4ce5c08ebc91620219541cb4114ac52 (diff)
Require GPU when requested and available for Bert + HF embedders
Diffstat (limited to 'model-integration/src/main/java')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java1
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java2
3 files changed, 4 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index 8e5211ccff1..3ce01c9ae08 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -58,6 +58,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
options.setExecutionMode(config.onnxExecutionMode().toString());
options.setThreads(config.onnxInterOpThreads(), config.onnxIntraOpThreads());
+ options.setGpuDevice(config.onnxGpuDevice());
tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build();
this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options);
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 21dd326689c..cc13254385b 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -40,7 +40,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString()));
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
- onnxOpts.setGpuDevice(config.transformerGpuDevice(), config.transformerGpuRequired());
+ onnxOpts.setGpuDevice(config.transformerGpuDevice());
onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 76a2031171f..6048be8aca9 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -86,6 +86,8 @@ public class OnnxEvaluatorOptions {
this.gpuDeviceRequired = required;
}
+ public void setGpuDevice(int deviceNumber) { gpuDeviceNumber = deviceNumber; }
+
public boolean requestingGpu() {
return gpuDeviceNumber > -1;
}