aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-08 14:13:14 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-08 14:13:14 +0200
commitf7b3f6c02a27fc0ca98c3b87f96b9b6f1b652e32 (patch)
treec837969ef00b84dcf0d316023a9b1954f025445f /model-integration
parentb4073925d4ce5c08ebc91620219541cb4114ac52 (diff)
Require GPU when requested and available for Bert + HF embedders
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java1
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java2
-rw-r--r--model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def2
-rw-r--r--model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def4
5 files changed, 6 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index 8e5211ccff1..3ce01c9ae08 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -58,6 +58,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
options.setExecutionMode(config.onnxExecutionMode().toString());
options.setThreads(config.onnxInterOpThreads(), config.onnxIntraOpThreads());
+ options.setGpuDevice(config.onnxGpuDevice());
tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build();
this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options);
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 21dd326689c..cc13254385b 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -40,7 +40,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString()));
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
- onnxOpts.setGpuDevice(config.transformerGpuDevice(), config.transformerGpuRequired());
+ onnxOpts.setGpuDevice(config.transformerGpuDevice());
onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 76a2031171f..6048be8aca9 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -86,6 +86,8 @@ public class OnnxEvaluatorOptions {
this.gpuDeviceRequired = required;
}
+ public void setGpuDevice(int deviceNumber) { gpuDeviceNumber = deviceNumber; }
+
public boolean requestingGpu() {
return gpuDeviceNumber > -1;
}
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
index ef42d81e1fe..e37a33d3b81 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
@@ -28,4 +28,4 @@ transformerOutput string default=output_0
onnxExecutionMode enum { parallel, sequential } default=sequential
onnxInterOpThreads int default=1
onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
-
+onnxGpuDevice int default=-1
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
index adc8f653168..584f23046ba 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
@@ -17,9 +17,6 @@ transformerAttentionMask string default=attention_mask
# Output name
transformerOutput string default=last_hidden_state
-# GPU configuration
-transformerGpuDevice int default=-1
-transformerGpuRequired bool default=false
# Normalize tensors from tokenizer
normalize bool default=false
@@ -28,3 +25,4 @@ normalize bool default=false
transformerExecutionMode enum { parallel, sequential } default=sequential
transformerInterOpThreads int default=1
transformerIntraOpThreads int default=-4
+transformerGpuDevice int default=-1