summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-02-22 10:40:23 +0100
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-02-22 10:40:23 +0100
commitc744731f33372b8176f09f21959f4a2b321f4b64 (patch)
tree9083470746dc3ecfe48eebf1f9bc52181ab526ba /model-integration
parent813df553bcf33a6fe161a22b6f6b93b2780308b5 (diff)
Implement equals()/hashCode()
Required for using instances as key in HashMap
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java18
1 files changed, 17 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
index b6de9698f1a..1ed219a8560 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
@@ -5,6 +5,8 @@ package ai.vespa.modelintegration.evaluator;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
+import java.util.Objects;
+
/**
* Session options for ONNX Runtime evaluation
*
@@ -12,7 +14,7 @@ import ai.onnxruntime.OrtSession;
*/
public class OnnxEvaluatorOptions {
- private OrtSession.SessionOptions.OptLevel optimizationLevel;
+ private final OrtSession.SessionOptions.OptLevel optimizationLevel;
private OrtSession.SessionOptions.ExecutionMode executionMode;
private int interOpThreads;
private int intraOpThreads;
@@ -74,4 +76,18 @@ public class OnnxEvaluatorOptions {
return gpuDeviceRequired;
}
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ OnnxEvaluatorOptions that = (OnnxEvaluatorOptions) o;
+ return interOpThreads == that.interOpThreads && intraOpThreads == that.intraOpThreads
+ && gpuDeviceNumber == that.gpuDeviceNumber && gpuDeviceRequired == that.gpuDeviceRequired
+ && optimizationLevel == that.optimizationLevel && executionMode == that.executionMode;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(optimizationLevel, executionMode, interOpThreads, intraOpThreads, gpuDeviceNumber, gpuDeviceRequired);
+ }
}