blob: 4196af18fb64b96b124cf9a9b9c81fb67da93a72 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
|
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.derived;
import com.yahoo.config.application.api.FileRegistry;
import com.yahoo.schema.OnnxModel;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.logging.Logger;
/**
 * ONNX models distributed as files, and the onnx-models config entries derived from them.
 *
 * @author bratseth
 */
public class FileDistributedOnnxModels {

    private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName());

    /** The models keyed by name, in the iteration order of the collection given to the constructor. Unmodifiable. */
    private final Map<String, OnnxModel> models;

    /**
     * Validates each of the given models and registers it with the given file registry.
     *
     * <p>Models are keyed by {@link OnnxModel#getName()}. If several models share a name,
     * the last one wins (it replaces the value of the earlier entry, which keeps its position).
     *
     * @param fileRegistry the registry every model is registered with
     * @param models the models to distribute
     */
    public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) {
        Map<String, OnnxModel> distributableModels = new LinkedHashMap<>();
        for (var model : models) {
            model.validate();
            model.register(fileRegistry);
            distributableModels.put(model.getName(), model);
        }
        this.models = Collections.unmodifiableMap(distributableModels);
    }

    /** Returns an unmodifiable view of the models, keyed by name. */
    public Map<String, OnnxModel> asMap() { return models; }

    /**
     * Adds one model entry to the given config builder for each model with a non-empty file reference.
     * Models whose file reference is the empty string are skipped, and a warning is logged for each.
     *
     * @param builder the config builder receiving the model entries
     */
    public void getConfig(OnnxModelsConfig.Builder builder) {
        for (OnnxModel model : models.values()) {
            // An empty file reference is only warned about (not rejected) to let tests pass;
            // we should find a better way.
            if ("".equals(model.getFileReference()))
                log.warning("Illegal file reference " + model);
            else
                builder.model(toModelConfig(model));
        }
    }

    /**
     * Builds the config entry for a single model.
     * Optional model settings are only written to the builder when the model has them.
     */
    private static OnnxModelsConfig.Model.Builder toModelConfig(OnnxModel model) {
        OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
        modelBuilder.dry_run_on_setup(true);
        modelBuilder.name(model.getName());
        modelBuilder.fileref(model.getFileReference());
        model.getInputMap().forEach((name, source) ->
                modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
        model.getOutputMap().forEach((name, as) ->
                modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
        model.getStatelessExecutionMode().ifPresent(mode -> modelBuilder.stateless_execution_mode(mode));
        model.getStatelessInterOpThreads().ifPresent(threads -> modelBuilder.stateless_interop_threads(threads));
        model.getStatelessIntraOpThreads().ifPresent(threads -> modelBuilder.stateless_intraop_threads(threads));
        model.getGpuDevice().ifPresent(gpu -> {
            modelBuilder.gpu_device(gpu.deviceNumber());
            modelBuilder.gpu_device_required(gpu.required());
        });
        return modelBuilder;
    }

}
|