aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
blob: ae6f1fd96e491079d9b6f54f4aba7bceaff259be (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema;

import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.ml.OnnxModelInfo;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/**
 * An ONNX model made available globally through file distribution, in the same manner as ranking constants.
 * Apart from the distributed file itself, this keeps the mapping from ONNX input/output names to the names
 * used on the Vespa side, along with optional settings for stateless evaluation and GPU placement.
 *
 * @author lesters
 */
public class OnnxModel extends DistributableResource {

    // Set by setModelInfo(); stays null until that has been called.
    private OnnxModelInfo modelInfo = null;

    // ONNX name -> Vespa name, for the model's inputs and outputs respectively.
    private final Map<String, String> inputMap = new HashMap<>();
    private final Map<String, String> outputMap = new HashMap<>();

    // Optional stateless-evaluation and device settings; null means "not configured".
    private GpuDevice gpuDevice = null;
    private String  statelessExecutionMode = null;
    private Integer statelessInterOpThreads = null;
    private Integer statelessIntraOpThreads = null;

    /** Creates a model identified only by its name. */
    public OnnxModel(String name) {
        super(name);
    }

    /** Creates a model bound to the given file name, then runs the inherited {@code validate()} check. */
    public OnnxModel(String name, String fileName) {
        super(name, fileName);
        validate();
    }

    /**
     * Rejects every call: ONNX models cannot be referenced by URI.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public void setUri(String uri) {
        throw new IllegalArgumentException("URI for ONNX models are not currently supported");
    }

    /** Maps an ONNX input name to a Vespa name, replacing any mapping already present for that input. */
    public void addInputNameMapping(String onnxName, String vespaName) {
        addInputNameMapping(onnxName, vespaName, true);
    }

    /**
     * Maps an ONNX input name to a Vespa name.
     *
     * @param overwrite when false, an existing mapping for {@code onnxName} is left as it is
     * @throws NullPointerException if either name is null
     */
    public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
        putMapping(inputMap, onnxName, vespaName, overwrite);
    }

    /** Maps an ONNX output name to a Vespa name, replacing any mapping already present for that output. */
    public void addOutputNameMapping(String onnxName, String vespaName) {
        addOutputNameMapping(onnxName, vespaName, true);
    }

    /**
     * Maps an ONNX output name to a Vespa name.
     *
     * @param overwrite when false, an existing mapping for {@code onnxName} is left as it is
     * @throws NullPointerException if either name is null
     */
    public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
        putMapping(outputMap, onnxName, vespaName, overwrite);
    }

    /** Common body of the input/output name-mapping methods. */
    private static void putMapping(Map<String, String> target, String onnxName, String vespaName, boolean overwrite) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(vespaName, "Vespa name cannot be null");
        if (overwrite) {
            target.put(onnxName, vespaName);
        } else {
            // Stored values are never null (checked above), so this matches a containsKey() guard.
            target.putIfAbsent(onnxName, vespaName);
        }
    }

    /**
     * Attaches model metadata. Every input and output it lists gets a default mapping to
     * {@code OnnxModelInfo.asValidIdentifier(name)}, unless a mapping for that name already exists.
     *
     * @throws NullPointerException if {@code modelInfo} is null
     */
    public void setModelInfo(OnnxModelInfo modelInfo) {
        Objects.requireNonNull(modelInfo, "Onnx model info cannot be null");
        for (String inputName : modelInfo.getInputs()) {
            String defaultName = OnnxModelInfo.asValidIdentifier(inputName);
            addInputNameMapping(inputName, defaultName, false);
        }
        for (String outputName : modelInfo.getOutputs()) {
            String defaultName = OnnxModelInfo.asValidIdentifier(outputName);
            addOutputNameMapping(outputName, defaultName, false);
        }
        this.modelInfo = modelInfo;
    }

    /** Returns a read-only live view of the ONNX-to-Vespa input name mapping. */
    public Map<String, String> getInputMap() {
        return Collections.unmodifiableMap(inputMap);
    }

    /** Returns a read-only live view of the ONNX-to-Vespa output name mapping. */
    public Map<String, String> getOutputMap() {
        return Collections.unmodifiableMap(outputMap);
    }

    /** Returns the model info's default output, or the empty string if no model info has been set. */
    public String getDefaultOutput() {
        if (modelInfo == null) {
            return "";
        }
        return modelInfo.getDefaultOutput();
    }

    /** Delegates to the model info, or returns {@code TensorType.empty} if no model info has been set. */
    TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
        if (modelInfo == null) {
            return TensorType.empty;
        }
        return modelInfo.getTensorType(onnxName, inputTypes);
    }

    /**
     * Sets the stateless execution mode. The argument is matched case-insensitively against
     * "parallel" and "sequential" and stored in lower case. Any other value, including null,
     * is ignored and leaves the current setting as it was.
     */
    public void setStatelessExecutionMode(String executionMode) {
        String canonical;
        if ("parallel".equalsIgnoreCase(executionMode)) {
            canonical = "parallel";
        } else if ("sequential".equalsIgnoreCase(executionMode)) {
            canonical = "sequential";
        } else {
            return;
        }
        this.statelessExecutionMode = canonical;
    }

    public Optional<String> getStatelessExecutionMode() {
        return Optional.ofNullable(statelessExecutionMode);
    }

    /** Sets the stateless inter-op thread count; negative values are ignored. */
    public void setStatelessInterOpThreads(int interOpThreads) {
        if (interOpThreads < 0) {
            return;
        }
        this.statelessInterOpThreads = interOpThreads;
    }

    public Optional<Integer> getStatelessInterOpThreads() {
        return Optional.ofNullable(statelessInterOpThreads);
    }

    /** Sets the stateless intra-op thread count; negative values are ignored. */
    public void setStatelessIntraOpThreads(int intraOpThreads) {
        if (intraOpThreads < 0) {
            return;
        }
        this.statelessIntraOpThreads = intraOpThreads;
    }

    public Optional<Integer> getStatelessIntraOpThreads() {
        return Optional.ofNullable(statelessIntraOpThreads);
    }

    /** Selects a GPU device; a negative device number is ignored and leaves any previous choice in place. */
    public void setGpuDevice(int deviceNumber, boolean required) {
        if (deviceNumber < 0) {
            return;
        }
        this.gpuDevice = new GpuDevice(deviceNumber, required);
    }

    public Optional<GpuDevice> getGpuDevice() {
        return Optional.ofNullable(gpuDevice);
    }

    /**
     * A GPU device selection: a non-negative device number plus a {@code required} flag.
     * NOTE(review): how {@code required} is interpreted is decided by consumers outside this file.
     */
    public record GpuDevice(int deviceNumber, boolean required) {

        public GpuDevice {
            if (deviceNumber < 0) {
                throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
            }
        }

    }

}