diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java index 4a35f4275fa..6048be8aca9 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java @@ -17,7 +17,7 @@ import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL; */ public class OnnxEvaluatorOptions { - private final OrtSession.SessionOptions.OptLevel optimizationLevel; + private OrtSession.SessionOptions.OptLevel optimizationLevel; private OrtSession.SessionOptions.ExecutionMode executionMode; private int interOpThreads; private int intraOpThreads; @@ -86,6 +86,8 @@ public class OnnxEvaluatorOptions { this.gpuDeviceRequired = required; } + public void setGpuDevice(int deviceNumber) { gpuDeviceNumber = deviceNumber; } + public boolean requestingGpu() { return gpuDeviceNumber > -1; } @@ -94,6 +96,19 @@ public class OnnxEvaluatorOptions { return gpuDeviceRequired; } + public int gpuDeviceNumber() { return gpuDeviceNumber; } + + public OnnxEvaluatorOptions copy() { + var copy = new OnnxEvaluatorOptions(); + copy.gpuDeviceNumber = gpuDeviceNumber; + copy.gpuDeviceRequired = gpuDeviceRequired; + copy.executionMode = executionMode; + copy.interOpThreads = interOpThreads; + copy.intraOpThreads = intraOpThreads; + copy.optimizationLevel = optimizationLevel; + return copy; + } + @Override public boolean equals(Object o) { if (this == o) return true; |