aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai
diff options
context:
space:
mode:
author Geir Storli <geirst@yahooinc.com> 2023-05-08 16:54:28 +0200
committer GitHub <noreply@github.com> 2023-05-08 16:54:28 +0200
commit 1169a91fa02ea0bbc01c9e110bb7efe352b9054d (patch)
tree 8f0787e8e43ea79f5fa3072e4d19b51eb3c1b9c1 /model-integration/src/main/java/ai
parent 93d0034b4603aca06771f646dd814d17faf598f8 (diff)
Revert "Bjorncs/embedder onnx gpu"
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java 1
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java 2
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java 17
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java 26
4 files changed, 6 insertions, 40 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index b172ef7beee..8e5211ccff1 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -58,7 +58,6 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
options.setExecutionMode(config.onnxExecutionMode().toString());
options.setThreads(config.onnxInterOpThreads(), config.onnxIntraOpThreads());
- if (config.onnxGpuDevice() >= 0) options.setGpuDevice(config.onnxGpuDevice());
tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocab().toString()).build();
this.evaluator = onnx.evaluatorOf(config.transformerModel().toString(), options);
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index cc13254385b..21dd326689c 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -40,7 +40,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString()));
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
- onnxOpts.setGpuDevice(config.transformerGpuDevice());
+ onnxOpts.setGpuDevice(config.transformerGpuDevice(), config.transformerGpuRequired());
onnxOpts.setExecutionMode(config.transformerExecutionMode().toString());
onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads());
evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts);
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 6048be8aca9..4a35f4275fa 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -17,7 +17,7 @@ import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
*/
public class OnnxEvaluatorOptions {
- private OrtSession.SessionOptions.OptLevel optimizationLevel;
+ private final OrtSession.SessionOptions.OptLevel optimizationLevel;
private OrtSession.SessionOptions.ExecutionMode executionMode;
private int interOpThreads;
private int intraOpThreads;
@@ -86,8 +86,6 @@ public class OnnxEvaluatorOptions {
this.gpuDeviceRequired = required;
}
- public void setGpuDevice(int deviceNumber) { gpuDeviceNumber = deviceNumber; }
-
public boolean requestingGpu() {
return gpuDeviceNumber > -1;
}
@@ -96,19 +94,6 @@ public class OnnxEvaluatorOptions {
return gpuDeviceRequired;
}
- public int gpuDeviceNumber() { return gpuDeviceNumber; }
-
- public OnnxEvaluatorOptions copy() {
- var copy = new OnnxEvaluatorOptions();
- copy.gpuDeviceNumber = gpuDeviceNumber;
- copy.gpuDeviceRequired = gpuDeviceRequired;
- copy.executionMode = executionMode;
- copy.interOpThreads = interOpThreads;
- copy.intraOpThreads = intraOpThreads;
- copy.optimizationLevel = optimizationLevel;
- return copy;
- }
-
@Override
public boolean equals(Object o) {
if (this == o) return true;
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
index ab44a2ae33f..ece1db55c1e 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
@@ -10,7 +10,6 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.jdisc.ResourceReference;
import com.yahoo.jdisc.refcount.DebugReferencesWithStack;
import com.yahoo.jdisc.refcount.References;
-import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import net.jpountz.xxhash.XXHashFactory;
import java.io.IOException;
@@ -53,24 +52,17 @@ public class OnnxRuntime extends AbstractComponent {
private final Object monitor = new Object();
private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>();
private final OrtSessionFactory factory;
- private final int gpusAvailable;
- // For test use only
- public OnnxRuntime() { this(defaultFactory, new OnnxModelsConfig.Builder().build()); }
+ @Inject public OnnxRuntime() { this(defaultFactory); }
- @Inject public OnnxRuntime(OnnxModelsConfig cfg) { this(defaultFactory, cfg); }
-
- OnnxRuntime(OrtSessionFactory factory, OnnxModelsConfig cfg) {
- this.factory = factory;
- this.gpusAvailable = cfg.gpu().count();
- }
+ OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; }
public OnnxEvaluator evaluatorOf(byte[] model) {
return new OnnxEvaluator(model, null, this);
}
public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) {
- return new OnnxEvaluator(model, overrideOptions(options), this);
+ return new OnnxEvaluator(model, options, this);
}
public OnnxEvaluator evaluatorOf(String modelPath) {
@@ -78,7 +70,7 @@ public class OnnxRuntime extends AbstractComponent {
}
public OnnxEvaluator evaluatorOf(String modelPath, OnnxEvaluatorOptions options) {
- return new OnnxEvaluator(modelPath, overrideOptions(options), this);
+ return new OnnxEvaluator(modelPath, options, this);
}
public static OrtEnvironment ortEnvironment() {
@@ -175,16 +167,6 @@ public class OnnxRuntime extends AbstractComponent {
}
}
- private OnnxEvaluatorOptions overrideOptions(OnnxEvaluatorOptions opts) {
- // Set GPU device required if GPU requested and GPUs are available on system
- if (gpusAvailable > 0 && opts.requestingGpu() && !opts.gpuDeviceRequired()) {
- var copy = opts.copy();
- copy.setGpuDevice(opts.gpuDeviceNumber(), true);
- return copy;
- }
- return opts;
- }
-
int sessionsCached() { synchronized(monitor) { return sessions.size(); } }
static class ReferencedOrtSession implements AutoCloseable {