diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-05 16:41:20 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-05 16:47:06 +0200 |
commit | ea351a9d4d393cbf9a2018197557f42ce3c490c1 (patch) | |
tree | cede33353b0870479e60c3b80cc1950fb13e8482 /model-integration/src/main/java/ai | |
parent | 530b1aeedbd3cc492e19a6797477727be733af68 (diff) |
Allow for manual configuration of GPU
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 3e51e1f0919..44593fa2e57 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -1,6 +1,7 @@ package ai.vespa.embedding.huggingface; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; @@ -35,7 +36,10 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { attentionMaskName = config.transformerAttentionMask(); outputName = config.transformerOutput(); tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString())); - evaluator = onnx.evaluatorOf(config.transformerModel().toString()); + var onnxOpts = new OnnxEvaluatorOptions(); + if (config.transformerGpuDevice() >= 0) + onnxOpts.setGpuDevice(config.transformerGpuDevice(), config.transformerGpuRequired()); + evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts); validateModel(); } |