diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-08 14:13:14 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-08 14:13:14 +0200 |
commit | f7b3f6c02a27fc0ca98c3b87f96b9b6f1b652e32 (patch) | |
tree | c837969ef00b84dcf0d316023a9b1954f025445f /model-integration/src/main/java | |
parent | b4073925d4ce5c08ebc91620219541cb4114ac52 (diff) |
Require GPU when requested and available for Bert + HF embedders
Diffstat (limited to 'model-integration/src/main/java')
3 files changed, 4 insertions, 1 deletion
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index 8e5211ccff1..3ce01c9ae08 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -58,6 +58,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { OnnxEvaluatorOptions options = new OnnxEvaluatorOptions(); options.setExecutionMode(config.onnxExecutionMode().toString()); options.setThreads(config.onnxInterOpThreads(), config.onnxIntraOpThreads()); + options.setGpuDevice(config.onnxGpuDevice()); tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build(); this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options); diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 21dd326689c..cc13254385b 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -40,7 +40,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString())); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) - onnxOpts.setGpuDevice(config.transformerGpuDevice(), config.transformerGpuRequired()); + onnxOpts.setGpuDevice(config.transformerGpuDevice()); onnxOpts.setExecutionMode(config.transformerExecutionMode().toString()); onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads()); evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts); diff --git 
a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java index 76a2031171f..6048be8aca9 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java @@ -86,6 +86,8 @@ public class OnnxEvaluatorOptions { this.gpuDeviceRequired = required; } + public void setGpuDevice(int deviceNumber) { gpuDeviceNumber = deviceNumber; } + public boolean requestingGpu() { return gpuDeviceNumber > -1; } |