aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorOptions.java
blob: cefafc3654b702fb1559d7a36fe12ea5ff211c9d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package ai.vespa.modelintegration.evaluator;

import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

import java.util.Objects;

import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.PARALLEL;
import static ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode.SEQUENTIAL;

/**
 * Session options for ONNX Runtime evaluation.
 *
 * <p>Instances are mutable and carry no synchronization. Use {@link #copy()} to obtain an
 * independent instance before modifying options that may be shared.
 *
 * @author lesters
 */
public class OnnxEvaluatorOptions {

    private OrtSession.SessionOptions.OptLevel optimizationLevel;
    private OrtSession.SessionOptions.ExecutionMode executionMode;
    private int interOpThreads;
    private int intraOpThreads;
    private int gpuDeviceNumber;
    private boolean gpuDeviceRequired;

    /**
     * Creates options with the defaults:
     * <ul>
     *   <li>all graph optimizations ({@code ALL_OPT});</li>
     *   <li>sequential execution;</li>
     *   <li>a quarter of the available processors for both inter- and intra-op threads,
     *       rounded up and at least 1;</li>
     *   <li>no GPU device requested (device number -1) and GPU not required.</li>
     * </ul>
     */
    public OnnxEvaluatorOptions() {
        optimizationLevel = OrtSession.SessionOptions.OptLevel.ALL_OPT;
        executionMode = SEQUENTIAL;
        // Same scaling rule as setThreads(-4, -4): ceil(CPU / 4), at least 1.
        int quarterVcpu = calculateThreads(-4);
        interOpThreads = quarterVcpu;
        intraOpThreads = quarterVcpu;
        gpuDeviceNumber = -1;
        gpuDeviceRequired = false;
    }

    /**
     * Creates a new ONNX Runtime session options object reflecting these options.
     *
     * <p>The inter-op thread count is only applied in parallel execution mode. In sequential
     * mode it is set to 1. The CPU arena allocator is always disabled.
     *
     * @param loadCuda whether to add the CUDA execution provider for {@link #gpuDeviceNumber()}
     * @return a newly created session options object; the caller owns it and is responsible for closing it
     * @throws OrtException if ONNX Runtime rejects any option, e.g. when the CUDA provider cannot be added
     */
    public OrtSession.SessionOptions getOptions(boolean loadCuda) throws OrtException {
        OrtSession.SessionOptions options = new OrtSession.SessionOptions();
        try {
            options.setOptimizationLevel(optimizationLevel);
            options.setExecutionMode(executionMode);
            options.setInterOpNumThreads(executionMode == PARALLEL ? interOpThreads : 1);
            options.setIntraOpNumThreads(intraOpThreads);
            options.setCPUArenaAllocator(false);
            if (loadCuda) {
                options.addCUDA(gpuDeviceNumber);
            }
            return options;
        } catch (OrtException | RuntimeException e) {
            // The caller never receives a reference to the options on failure, so release them here
            // instead of leaking them. A failure while closing must not mask the original error.
            try {
                options.close();
            } catch (Exception suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
    }

    /**
     * Sets the execution mode from its name.
     *
     * <p>{@code "parallel"} and {@code "sequential"} are accepted case-insensitively. Any other
     * value, including {@code null}, leaves the current mode unchanged.
     */
    public void setExecutionMode(String mode) {
        if ("parallel".equalsIgnoreCase(mode)) {
            executionMode = PARALLEL;
        } else if ("sequential".equalsIgnoreCase(mode)) {
            executionMode = SEQUENTIAL;
        }
    }

    /** Sets the number of inter-op threads; negative values are ignored. */
    public void setInterOpThreads(int threads) {
        if (threads >= 0) {
            interOpThreads = threads;
        }
    }

    /** Sets the number of intra-op threads; negative values are ignored. */
    public void setIntraOpThreads(int threads) {
        if (threads >= 0) {
            intraOpThreads = threads;
        }
    }

    /**
     * Sets the number of threads for inter and intra op execution.
     * A negative number is interpreted as an inverse scaling factor <code>threads=CPU/-n</code>,
     * rounded up and at least 1.
     */
    public void setThreads(int interOp, int intraOp) {
        interOpThreads = calculateThreads(interOp);
        intraOpThreads = calculateThreads(intraOp);
    }

    /**
     * Resolves a thread count.
     *
     * <p>Non-negative values are returned as-is. A negative value {@code -n} yields
     * {@code ceil(availableProcessors / n)}, but at least 1.
     */
    private static int calculateThreads(int t) {
        if (t >= 0) return t;
        return Math.max(1, (int) Math.ceil(-1d * Runtime.getRuntime().availableProcessors() / t));
    }

    /**
     * Sets the GPU device to use and whether it is required.
     *
     * @param deviceNumber the GPU device number; a value of -1 or lower means no GPU is requested
     * @param required whether the GPU device is required
     */
    public void setGpuDevice(int deviceNumber, boolean required) {
        this.gpuDeviceNumber = deviceNumber;
        this.gpuDeviceRequired = required;
    }

    /** Sets the GPU device number, leaving the "required" flag unchanged. */
    public void setGpuDevice(int deviceNumber) { gpuDeviceNumber = deviceNumber; }

    /** Returns true if a GPU device number greater than -1 has been set. */
    public boolean requestingGpu() {
        return gpuDeviceNumber > -1;
    }

    /** Returns whether the configured GPU device has been marked as required. */
    public boolean gpuDeviceRequired() {
        return gpuDeviceRequired;
    }

    /** Returns the configured GPU device number (-1 by default). */
    public int gpuDeviceNumber() { return gpuDeviceNumber; }

    /** Returns an independent copy of these options. */
    public OnnxEvaluatorOptions copy() {
        var copy = new OnnxEvaluatorOptions();
        copy.gpuDeviceNumber = gpuDeviceNumber;
        copy.gpuDeviceRequired = gpuDeviceRequired;
        copy.executionMode = executionMode;
        copy.interOpThreads = interOpThreads;
        copy.intraOpThreads = intraOpThreads;
        copy.optimizationLevel = optimizationLevel;
        return copy;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        OnnxEvaluatorOptions that = (OnnxEvaluatorOptions) o;
        return interOpThreads == that.interOpThreads && intraOpThreads == that.intraOpThreads
                && gpuDeviceNumber == that.gpuDeviceNumber && gpuDeviceRequired == that.gpuDeviceRequired
                && optimizationLevel == that.optimizationLevel && executionMode == that.executionMode;
    }

    @Override
    public int hashCode() {
        return Objects.hash(optimizationLevel, executionMode, interOpThreads, intraOpThreads, gpuDeviceNumber, gpuDeviceRequired);
    }

    @Override
    public String toString() {
        return "OnnxEvaluatorOptions{optimizationLevel=" + optimizationLevel
                + ", executionMode=" + executionMode
                + ", interOpThreads=" + interOpThreads
                + ", intraOpThreads=" + intraOpThreads
                + ", gpuDeviceNumber=" + gpuDeviceNumber
                + ", gpuDeviceRequired=" + gpuDeviceRequired + '}';
    }
}