diff options
author | Geir Storli <geirst@yahooinc.com> | 2023-05-08 16:54:28 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-08 16:54:28 +0200 |
commit | 1169a91fa02ea0bbc01c9e110bb7efe352b9054d (patch) | |
tree | 8f0787e8e43ea79f5fa3072e4d19b51eb3c1b9c1 /model-integration/src/main/java/ai | |
parent | 93d0034b4603aca06771f646dd814d17faf598f8 (diff) |
Revert "Bjorncs/embedder onnx gpu"
Diffstat (limited to 'model-integration/src/main/java/ai')
4 files changed, 6 insertions, 40 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index b172ef7beee..8e5211ccff1 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -58,7 +58,6 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { OnnxEvaluatorOptions options = new OnnxEvaluatorOptions(); options.setExecutionMode(config.onnxExecutionMode().toString()); options.setThreads(config.onnxInterOpThreads(), config.onnxIntraOpThreads()); - if (config.onnxGpuDevice() >= 0) options.setGpuDevice(config.onnxGpuDevice()); tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build(); this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options); diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index cc13254385b..21dd326689c 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -40,7 +40,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString())); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) - onnxOpts.setGpuDevice(config.transformerGpuDevice()); + onnxOpts.setGpuDevice(config.transformerGpuDevice(), config.transformerGpuRequired()); onnxOpts.setExecutionMode(config.transformerExecutionMode().toString()); onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads()); evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts); diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java index 6048be8aca9..4a35f4275fa 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java @@ -17,7 +17,7 @@ import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL; */ public class OnnxEvaluatorOptions { - private OrtSession.SessionOptions.OptLevel optimizationLevel; + private final OrtSession.SessionOptions.OptLevel optimizationLevel; private OrtSession.SessionOptions.ExecutionMode executionMode; private int interOpThreads; private int intraOpThreads; @@ -86,8 +86,6 @@ public class OnnxEvaluatorOptions { this.gpuDeviceRequired = required; } - public void setGpuDevice(int deviceNumber) { gpuDeviceNumber = deviceNumber; } - public boolean requestingGpu() { return gpuDeviceNumber > -1; } @@ -96,19 +94,6 @@ public class OnnxEvaluatorOptions { return gpuDeviceRequired; } - public int gpuDeviceNumber() { return gpuDeviceNumber; } - - public OnnxEvaluatorOptions copy() { - var copy = new OnnxEvaluatorOptions(); - copy.gpuDeviceNumber = gpuDeviceNumber; - copy.gpuDeviceRequired = gpuDeviceRequired; - copy.executionMode = executionMode; - copy.interOpThreads = interOpThreads; - copy.intraOpThreads = intraOpThreads; - copy.optimizationLevel = optimizationLevel; - return copy; - } - @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java index ab44a2ae33f..ece1db55c1e 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java @@ -10,7 +10,6 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.jdisc.ResourceReference; import com.yahoo.jdisc.refcount.DebugReferencesWithStack; import com.yahoo.jdisc.refcount.References; -import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import net.jpountz.xxhash.XXHashFactory; import java.io.IOException; @@ -53,24 +52,17 @@ public class OnnxRuntime extends AbstractComponent { private final Object monitor = new Object(); private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>(); private final OrtSessionFactory factory; - private final int gpusAvailable; - // For test use only - public OnnxRuntime() { this(defaultFactory, new OnnxModelsConfig.Builder().build()); } + @Inject public OnnxRuntime() { this(defaultFactory); } - @Inject public OnnxRuntime(OnnxModelsConfig cfg) { this(defaultFactory, cfg); } - - OnnxRuntime(OrtSessionFactory factory, OnnxModelsConfig cfg) { - this.factory = factory; - this.gpusAvailable = cfg.gpu().count(); - } + OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; } public OnnxEvaluator evaluatorOf(byte[] model) { return new OnnxEvaluator(model, null, this); } public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) { - return new OnnxEvaluator(model, overrideOptions(options), this); + return new OnnxEvaluator(model, options, this); } public OnnxEvaluator evaluatorOf(String modelPath) { @@ -78,7 +70,7 @@ public class OnnxRuntime extends AbstractComponent { } public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) { - return new OnnxEvaluator(modelPath, overrideOptions(options), this); + return new OnnxEvaluator(modelPath, options, this); } public static OrtEnvironment ortEnvironment() { @@ -175,16 +167,6 @@ public class OnnxRuntime extends AbstractComponent { } } - private OnnxEvaluatorOptions overrideOptions(OnnxEvaluatorOptions opts) { - // Set GPU device required if GPU requested and GPUs are available on system - if (gpusAvailable > 0 && opts.requestingGpu() && !opts.gpuDeviceRequired()) { - var copy = opts.copy(); - copy.setGpuDevice(opts.gpuDeviceNumber(), true); - return copy; - } - return opts; - } - int sessionsCached() { synchronized(monitor) { return sessions.size(); } } static class ReferencedOrtSession implements AutoCloseable { |