aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java49
1 files changed, 27 insertions, 22 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
index 1cc33cc4180..6e50e3c094c 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
@@ -5,9 +5,11 @@ import com.yahoo.config.application.api.FileRegistry;
import com.yahoo.schema.OnnxModel;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
+import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
@@ -16,13 +18,10 @@ import java.util.logging.Logger;
*
* @author bratseth
*/
-public class FileDistributedOnnxModels extends Derived implements OnnxModelsConfig.Producer {
+public class FileDistributedOnnxModels {
private static final Logger log = Logger.getLogger(FileDistributedOnnxModels.class.getName());
- @Override
- public String getDerivedName() { return "onnx-models"; }
-
private final Map<String, OnnxModel> models;
public FileDistributedOnnxModels(FileRegistry fileRegistry, Collection<OnnxModel> models) {
@@ -47,30 +46,36 @@ public class FileDistributedOnnxModels extends Derived implements OnnxModelsConf
public Map<String, OnnxModel> asMap() { return models; }
- public void getConfig(OnnxModelsConfig.Builder builder) {
+ private static OnnxModelsConfig.Model.Builder toConfig(OnnxModel model) {
+ OnnxModelsConfig.Model.Builder builder = new OnnxModelsConfig.Model.Builder();
+ builder.dry_run_on_setup(true);
+ builder.name(model.getName());
+ builder.fileref(model.getFileReference());
+ model.getInputMap().forEach((name, source) -> builder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
+ model.getOutputMap().forEach((name, as) -> builder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
+ if (model.getStatelessExecutionMode().isPresent())
+ builder.stateless_execution_mode(model.getStatelessExecutionMode().get());
+ if (model.getStatelessInterOpThreads().isPresent())
+ builder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
+ if (model.getStatelessIntraOpThreads().isPresent())
+ builder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
+ if (model.getGpuDevice().isPresent()) {
+ builder.gpu_device(model.getGpuDevice().get().deviceNumber());
+ builder.gpu_device_required(model.getGpuDevice().get().required());
+ }
+ return builder;
+ }
+
+ public List<OnnxModelsConfig.Model.Builder> getConfig() {
+ List<OnnxModelsConfig.Model.Builder> cfgList = new ArrayList<>();
for (OnnxModel model : models.values()) {
if ("".equals(model.getFileReference()))
log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
else {
- OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
- modelBuilder.dry_run_on_setup(true);
- modelBuilder.name(model.getName());
- modelBuilder.fileref(model.getFileReference());
- model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
- model.getOutputMap().forEach((name, as) -> modelBuilder.output(new OnnxModelsConfig.Model.Output.Builder().name(name).as(as)));
- if (model.getStatelessExecutionMode().isPresent())
- modelBuilder.stateless_execution_mode(model.getStatelessExecutionMode().get());
- if (model.getStatelessInterOpThreads().isPresent())
- modelBuilder.stateless_interop_threads(model.getStatelessInterOpThreads().get());
- if (model.getStatelessIntraOpThreads().isPresent())
- modelBuilder.stateless_intraop_threads(model.getStatelessIntraOpThreads().get());
- if (model.getGpuDevice().isPresent()) {
- modelBuilder.gpu_device(model.getGpuDevice().get().deviceNumber());
- modelBuilder.gpu_device_required(model.getGpuDevice().get().required());
- }
- builder.model(modelBuilder);
+ cfgList.add(toConfig(model));
}
}
+ return cfgList;
}
}