diff options
author | Jon Bratseth <bratseth@gmail.com> | 2021-10-19 11:35:59 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2021-10-19 11:35:59 +0200 |
commit | 5915b0d1470e7b6ae7e30ad4e532835843b75f63 (patch) | |
tree | e5a04732c852391bebca230314edb00e17e9aee4 /config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java | |
parent | f6eff7508eba3a8772d6cf0f3ed6d230fd95daef (diff) |
Inherit ONNX models
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java | 32 |
1 files changed, 26 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java index 9c1ee2bb609..e249fe59f14 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java @@ -8,6 +8,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Optional; /** * ONNX models tied to a search definition or global. @@ -16,11 +17,16 @@ import java.util.Map; */ public class OnnxModels { - private final Map<String, OnnxModel> models = new HashMap<>(); private final FileRegistry fileRegistry; - public OnnxModels(FileRegistry fileRegistry) { + /** The schema this belongs to, or empty if it is global */ + private final Optional<Schema> owner; + + private final Map<String, OnnxModel> models = new HashMap<>(); + + public OnnxModels(FileRegistry fileRegistry, Optional<Schema> owner) { this.fileRegistry = fileRegistry; + this.owner = owner; } public void add(OnnxModel model) { @@ -35,20 +41,34 @@ public class OnnxModels { } public OnnxModel get(String name) { - return models.get(name); + var model = models.get(name); + if (model != null) return model; + if (owner.isPresent() && owner.get().inherited().isPresent()) + return owner.get().inherited().get().onnxModels().get(name); + return null; } public boolean has(String name) { - return models.containsKey(name); + boolean has = models.containsKey(name); + if (has) return true; + if (owner.isPresent() && owner.get().inherited().isPresent()) + return owner.get().inherited().get().onnxModels().has(name); + return false; } public Map<String, OnnxModel> asMap() { - return Collections.unmodifiableMap(models); + // Shortcuts + if (owner.isEmpty() || owner.get().inherited().isEmpty()) return Collections.unmodifiableMap(models); + if (models.isEmpty()) return owner.get().inherited().get().onnxModels().asMap(); + + var allModels = new HashMap<>(owner.get().inherited().get().onnxModels().asMap()); + allModels.putAll(models); + return Collections.unmodifiableMap(allModels); } /** Initiate sending of these models to some services over file distribution */ public void sendTo(Collection<? extends AbstractService> services) { - models.values().forEach(model -> model.sendTo(services)); + asMap().values().forEach(model -> model.sendTo(services)); } } |