summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java17
1 files changed, 16 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index 4a35f4275fa..6048be8aca9 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -17,7 +17,7 @@ import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;
*/
public class OnnxEvaluatorOptions {
- private final OrtSession.SessionOptions.OptLevel optimizationLevel;
+ private OrtSession.SessionOptions.OptLevel optimizationLevel;
private OrtSession.SessionOptions.ExecutionMode executionMode;
private int interOpThreads;
private int intraOpThreads;
@@ -86,6 +86,8 @@ public class OnnxEvaluatorOptions {
this.gpuDeviceRequired = required;
}
+ public void setGpuDevice(int deviceNumber) { gpuDeviceNumber = deviceNumber; }
+
public boolean requestingGpu() {
return gpuDeviceNumber > -1;
}
@@ -94,6 +96,19 @@ public class OnnxEvaluatorOptions {
return gpuDeviceRequired;
}
+ public int gpuDeviceNumber() { return gpuDeviceNumber; }
+
+ public OnnxEvaluatorOptions copy() {
+ var copy = new OnnxEvaluatorOptions();
+ copy.gpuDeviceNumber = gpuDeviceNumber;
+ copy.gpuDeviceRequired = gpuDeviceRequired;
+ copy.executionMode = executionMode;
+ copy.interOpThreads = interOpThreads;
+ copy.intraOpThreads = intraOpThreads;
+ copy.optimizationLevel = optimizationLevel;
+ return copy;
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;