diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java | 46 |
1 files changed, 27 insertions, 19 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java index 1cc33cc4180..6c990eec04b 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java @@ -5,9 +5,11 @@ import com.yahoo.config.application.api.FileRegistry; import com.yahoo.schema.OnnxModel; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; import java.util.logging.Logger; @@ -16,7 +18,7 @@ import java.util.logging.Logger; * * @author bratseth */ -public class FileDistributedOnnxModels extends Derived implements OnnxModelsConfig.Producer { +public class FileDistributedOnnxModels extends Derived { private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName()); @@ -47,30 +49,36 @@ public class FileDistributedOnnxModels extends Derived implements OnnxModelsConf public Map<String, OnnxModel> asMap() { return models; } - public void getConfig(OnnxModelsConfig.Builder builder) { + private static OnnxModelsConfig.Model.Builder toConfig(OnnxModel model) { + OnnxModelsConfig.Model.Builder builder = new OnnxModelsConfig.Model.Builder(); + builder.dry_run_on_setup(true); + builder.name(model.getName()); + builder.fileref(model.getFileReference()); + model.getInputMap().forEach((name, source) -> builder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); + model.getOutputMap().forEach((name, as) -> builder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); + if (model.getStatelessExecutionMode().isPresent()) + builder.stateless_execution_mode(model.getStatelessExecutionMode().get()); + if (model.getStatelessInterOpThreads().isPresent()) + builder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); + if (model.getStatelessIntraOpThreads().isPresent()) + builder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); + if (model.getGpuDevice().isPresent()) { + builder.gpu_device(model.getGpuDevice().get().deviceNumber()); + builder.gpu_device_required(model.getGpuDevice().get().required()); + } + return builder; + } + + public List<OnnxModelsConfig.Model.Builder> getConfig() { + List<OnnxModelsConfig.Model.Builder> cfgList = new ArrayList<>(); for (OnnxModel model : models.values()) { if ("".equals(model.getFileReference())) log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way else { - OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); - modelBuilder.dry_run_on_setup(true); - modelBuilder.name(model.getName()); - modelBuilder.fileref(model.getFileReference()); - model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); - model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as))); - if (model.getStatelessExecutionMode().isPresent()) - modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get()); - if (model.getStatelessInterOpThreads().isPresent()) - modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get()); - if (model.getStatelessIntraOpThreads().isPresent()) - modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get()); - if (model.getGpuDevice().isPresent()) { - modelBuilder.gpu_device(model.getGpuDevice().get().deviceNumber()); - modelBuilder.gpu_device_required(model.getGpuDevice().get().required()); - } - builder.model(modelBuilder); + cfgList.add(toConfig(model)); } } + return cfgList; } } |