aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
blob: f3f09150c1d033cb6a05450d9ef86f1976e085d3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema;

import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.ml.OnnxModelInfo;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;

/**
 * A global ONNX model which is distributed using file distribution, similar to ranking constants.
 *
 * Besides the resource itself, this holds the mappings between ONNX input/output names and the
 * Vespa-side names they are bound to, the model's initializers, and optional runtime settings.
 *
 * @author lesters
 */
public class OnnxModel extends DistributableResource implements Cloneable {

    // Information about the model
    private OnnxModelInfo modelInfo = null;
    private final Map<String, String> inputMap = new HashMap<>();
    private final Map<String, String> outputMap = new HashMap<>();
    private final Set<String> initializers = new HashSet<>();

    // Runtime options; a null value means the option has not been set
    private String statelessExecutionMode = null;
    private Integer statelessInterOpThreads = null;
    private Integer statelessIntraOpThreads = null;
    private GpuDevice gpuDevice = null;

    /** Creates a model with the given name only. */
    public OnnxModel(String name) {
        super(name);
    }

    /** Creates a model with the given name and file name, and then validates it. */
    public OnnxModel(String name, String fileName) {
        super(name, fileName);
        validate();
    }

    /**
     * Returns a copy of this model as produced by {@code super.clone()}.
     *
     * @throws RuntimeException wrapping the {@link CloneNotSupportedException} if cloning fails
     */
    @Override
    public OnnxModel clone() {
        try {
            // A shallow clone is sufficient here
            Object copy = super.clone();
            return (OnnxModel) copy;
        } catch (CloneNotSupportedException e) {
            throw new RuntimeException("Clone not supported", e);
        }
    }

    /**
     * Always rejects the call: this class does not accept a URI.
     *
     * @throws IllegalArgumentException always
     */
    @Override
    public void setUri(String uri) {
        throw new IllegalArgumentException("URI for ONNX models are not currently supported");
    }

    /** Maps the given ONNX input name to the given Vespa name, replacing any existing mapping. */
    public void addInputNameMapping(String onnxName, String vespaName) {
        addInputNameMapping(onnxName, vespaName, true);
    }

    /**
     * Checks that the given string is usable as an input source and returns its string form
     * as rendered by {@link Reference}.
     *
     * @throws IllegalArgumentException if the source parses as a simple reference but is neither
     *         a simple feature nor a simple ranking expression wrapper with an argument
     */
    private String validateInputSource(String source) {
        var parsed = Reference.simple(source);
        if ( ! parsed.isPresent()) {
            // not a simple reference, so it must be an identifier
            return Reference.fromIdentifier(source).toString();
        }

        Reference reference = parsed.get();

        // accepted as-is: attribute(foo), query(foo), constant(foo)
        if (FeatureNames.isSimpleFeature(reference)) {
            return reference.toString();
        }

        // accepted as well: a function (evaluated by backend)
        if (reference.isSimpleRankingExpressionWrapper() && reference.simpleArgument().isPresent()) {
            return reference.toString();
        }

        // anything else is an invalid input source
        throw new IllegalArgumentException("invalid input for ONNX model " + getName() + ": " + source);
    }

    /**
     * Maps the given ONNX input name to the given Vespa name.
     * The Vespa name is validated even when the mapping ends up not being stored.
     *
     * @param overwrite whether an existing mapping for {@code onnxName} should be replaced
     * @throws NullPointerException if either name is null
     */
    public void addInputNameMapping(String onnxName, String vespaName, boolean overwrite) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(vespaName, "Vespa name cannot be null");
        String validatedSource = validateInputSource(vespaName);
        if ( ! overwrite && inputMap.containsKey(onnxName)) {
            return; // keep the mapping which is already there
        }
        inputMap.put(onnxName, validatedSource);
    }

    /** Maps the given ONNX output name to the given Vespa name, replacing any existing mapping. */
    public void addOutputNameMapping(String onnxName, String vespaName) {
        addOutputNameMapping(onnxName, vespaName, true);
    }

    /**
     * Maps the given ONNX output name to the given Vespa name.
     * The Vespa name is passed through {@code Reference.fromIdentifier} even when the mapping
     * ends up not being stored.
     *
     * @param overwrite whether an existing mapping for {@code onnxName} should be replaced
     * @throws NullPointerException if either name is null
     */
    public void addOutputNameMapping(String onnxName, String vespaName, boolean overwrite) {
        Objects.requireNonNull(onnxName, "Onnx name cannot be null");
        Objects.requireNonNull(vespaName, "Vespa name cannot be null");
        // the output name must be a valid identifier
        var identifier = Reference.fromIdentifier(vespaName);
        if ( ! overwrite && outputMap.containsKey(onnxName)) {
            return; // keep the mapping which is already there
        }
        outputMap.put(onnxName, identifier.toString());
    }

    /**
     * Sets the model info. Each input and output of the model which has no mapping yet
     * is mapped to its name converted by {@code OnnxModelInfo.asValidIdentifier}, and the
     * initializers of the model are added to those of this.
     *
     * @throws NullPointerException if the model info is null
     */
    public void setModelInfo(OnnxModelInfo modelInfo) {
        Objects.requireNonNull(modelInfo, "Onnx model info cannot be null");
        for (String inputName : modelInfo.getInputs()) {
            String identifier = OnnxModelInfo.asValidIdentifier(inputName);
            addInputNameMapping(inputName, identifier, false);
        }
        for (String outputName : modelInfo.getOutputs()) {
            String identifier = OnnxModelInfo.asValidIdentifier(outputName);
            addOutputNameMapping(outputName, identifier, false);
        }
        initializers.addAll(modelInfo.getInitializers());
        this.modelInfo = modelInfo;
    }

    /** Returns an unmodifiable view of the ONNX input name to Vespa name mappings. */
    public Map<String, String> getInputMap() {
        return Collections.unmodifiableMap(inputMap);
    }

    /** Returns an unmodifiable view of the ONNX output name to Vespa name mappings. */
    public Map<String, String> getOutputMap() {
        return Collections.unmodifiableMap(outputMap);
    }

    /** Returns an unmodifiable copy of the initializers. */
    public Set<String> getInitializers() {
        return Set.copyOf(initializers);
    }

    /** Returns the default output given by the model info, or the empty string if no model info is set. */
    public String getDefaultOutput() {
        if (modelInfo == null) {
            return "";
        }
        return modelInfo.getDefaultOutput();
    }

    /** Returns the tensor type given by the model info, or the empty type if no model info is set. */
    TensorType getTensorType(String onnxName, Map<String, TensorType> inputTypes) {
        if (modelInfo == null) {
            return TensorType.empty;
        }
        return modelInfo.getTensorType(onnxName, inputTypes);
    }

    /**
     * Sets the stateless execution mode to "parallel" or "sequential", matched ignoring case.
     * Any other value, including null, leaves the current setting unchanged.
     */
    public void setStatelessExecutionMode(String executionMode) {
        String[] knownModes = { "parallel", "sequential" };
        for (String knownMode : knownModes) {
            if (knownMode.equalsIgnoreCase(executionMode)) {
                this.statelessExecutionMode = knownMode; // always stored in lower case
                return;
            }
        }
    }

    /** Returns the stateless execution mode, or empty if it has not been set. */
    public Optional<String> getStatelessExecutionMode() {
        return Optional.ofNullable(statelessExecutionMode);
    }

    /** Sets the stateless inter-op thread count. A negative value is ignored. */
    public void setStatelessInterOpThreads(int interOpThreads) {
        if (interOpThreads < 0) {
            return;
        }
        this.statelessInterOpThreads = interOpThreads;
    }

    /** Returns the stateless inter-op thread count, or empty if it has not been set. */
    public Optional<Integer> getStatelessInterOpThreads() {
        return Optional.ofNullable(statelessInterOpThreads);
    }

    /** Sets the stateless intra-op thread count. A negative value is ignored. */
    public void setStatelessIntraOpThreads(int intraOpThreads) {
        if (intraOpThreads < 0) {
            return;
        }
        this.statelessIntraOpThreads = intraOpThreads;
    }

    /** Returns the stateless intra-op thread count, or empty if it has not been set. */
    public Optional<Integer> getStatelessIntraOpThreads() {
        return Optional.ofNullable(statelessIntraOpThreads);
    }

    /** Sets the GPU device of this model. A negative device number is ignored. */
    public void setGpuDevice(int deviceNumber, boolean required) {
        if (deviceNumber < 0) {
            return;
        }
        this.gpuDevice = new GpuDevice(deviceNumber, required);
    }

    /** Returns the GPU device of this model, or empty if it has not been set. */
    public Optional<GpuDevice> getGpuDevice() {
        return Optional.ofNullable(gpuDevice);
    }

    /**
     * A GPU device given by its number, and whether it is required.
     *
     * @throws IllegalArgumentException on construction if the device number is negative
     */
    public record GpuDevice(int deviceNumber, boolean required) {

        public GpuDevice {
            if (deviceNumber < 0) {
                throw new IllegalArgumentException("deviceNumber cannot be negative, got " + deviceNumber);
            }
        }

    }

}