summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
blob: 4196af18fb64b96b124cf9a9b9c81fb67da93a72 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema.derived;

import com.yahoo.config.application.api.FileRegistry;
import com.yahoo.schema.OnnxModel;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;

import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.logging.Logger;

/**
 * ONNX models distributed as files.
 *
 * <p>Each model is validated and registered with a {@link FileRegistry} when this is constructed,
 * and the models can afterwards be written out as {@link OnnxModelsConfig} entries.
 *
 * @author bratseth
 */
public class FileDistributedOnnxModels {

    private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName());

    /** The models keyed by model name, in the order they were supplied. Unmodifiable. */
    private final Map<String, OnnxModel> models;

    /**
     * Validates each model, registers it with the given file registry, and indexes it by name.
     *
     * <p>For each model, {@code validate()} is called before {@code register(fileRegistry)}.
     * The resulting map iterates in the iteration order of {@code models}. If two models share a
     * name, the later one replaces the earlier one, and the entry keeps the position of the first
     * occurrence.
     *
     * @param fileRegistry the registry each model is registered with
     * @param models the models to distribute
     */
    public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) {
        Map<String, OnnxModel> distributableModels = new LinkedHashMap<>();
        for (var model : models) {
            model.validate();
            model.register(fileRegistry);
            distributableModels.put(model.getName(), model);
        }
        this.models = Collections.unmodifiableMap(distributableModels);
    }

    /** Returns an unmodifiable view of the models keyed by name, in insertion order. */
    public Map<String, OnnxModel> asMap() { return models; }

    /**
     * Adds one model entry to the given builder for each model that has a file reference.
     *
     * <p>A model whose file reference is the empty string is skipped with a logged warning
     * instead of failing config production.
     *
     * @param builder the config builder to add model entries to
     */
    public void getConfig(OnnxModelsConfig.Builder builder) {
        for (OnnxModel model : models.values()) {
            if ("".equals(model.getFileReference())) {
                // Tolerated so tests pass; the original author notes a better way should be found.
                log.warning("Illegal file reference " + model);
            } else {
                builder.model(toModelConfig(model));
            }
        }
    }

    /**
     * Builds the config entry for a single model with a non-empty file reference.
     * Optional execution settings are copied only when the model has them set.
     */
    private static OnnxModelsConfig.Model.Builder toModelConfig(OnnxModel model) {
        OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
        modelBuilder.dry_run_on_setup(true);
        modelBuilder.name(model.getName());
        modelBuilder.fileref(model.getFileReference());
        model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
        model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
        model.getStatelessExecutionMode().ifPresent(mode -> modelBuilder.stateless_execution_mode(mode));
        model.getStatelessInterOpThreads().ifPresent(threads -> modelBuilder.stateless_interop_threads(threads));
        model.getStatelessIntraOpThreads().ifPresent(threads -> modelBuilder.stateless_intraop_threads(threads));
        model.getGpuDevice().ifPresent(gpu -> {
            modelBuilder.gpu_device(gpu.deviceNumber());
            modelBuilder.gpu_device_required(gpu.required());
        });
        return modelBuilder;
    }

}