summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-05 16:41:20 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-05 16:47:06 +0200
commitea351a9d4d393cbf9a2018197557f42ce3c490c1 (patch)
treecede33353b0870479e60c3b80cc1950fb13e8482 /model-integration
parent530b1aeedbd3cc492e19a6797477727be733af68 (diff)
Allow for manual configuration of GPU
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java6
-rw-r--r--model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def3
2 files changed, 8 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 3e51e1f0919..44593fa2e57 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -1,6 +1,7 @@
package ai.vespa.embedding.huggingface;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
@@ -35,7 +36,10 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
attentionMaskName = config.transformerAttentionMask();
outputName = config.transformerOutput();
tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString()));
- evaluator = onnx.evaluatorOf(config.transformerModel().toString());
+ var onnxOpts = new OnnxEvaluatorOptions();
+ if (config.transformerGpuDevice() >= 0)
+ onnxOpts.setGpuDevice(config.transformerGpuDevice(), config.transformerGpuRequired());
+ evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
validateModel();
}
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
index ae7a972f1d2..5ecdb59eae3 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.huggingface.hugging-face-embedder.def
@@ -17,3 +17,6 @@ transformerAttentionMask string default=attention_mask
# Output name
transformerOutput string default=last_hidden_state
+# GPU configuration
+transformerGpuDevice int default=-1
+transformerGpuRequired bool default=false \ No newline at end of file