aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-10-19 11:35:59 +0200
committerJon Bratseth <bratseth@gmail.com>2021-10-19 11:35:59 +0200
commit5915b0d1470e7b6ae7e30ad4e532835843b75f63 (patch)
treee5a04732c852391bebca230314edb00e17e9aee4 /config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
parentf6eff7508eba3a8772d6cf0f3ed6d230fd95daef (diff)
Inherit ONNX models
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java32
1 files changed, 26 insertions, 6 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
index 9c1ee2bb609..e249fe59f14 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/OnnxModels.java
@@ -8,6 +8,7 @@ import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
/**
* ONNX models tied to a search definition or global.
@@ -16,11 +17,16 @@ import java.util.Map;
*/
public class OnnxModels {
- private final Map<String, OnnxModel> models = new HashMap<>();
private final FileRegistry fileRegistry;
- public OnnxModels(FileRegistry fileRegistry) {
+ /** The schema this belongs to, or empty if it is global */
+ private final Optional<Schema> owner;
+
+ private final Map<String, OnnxModel> models = new HashMap<>();
+
+ public OnnxModels(FileRegistry fileRegistry, Optional<Schema> owner) {
this.fileRegistry = fileRegistry;
+ this.owner = owner;
}
public void add(OnnxModel model) {
@@ -35,20 +41,34 @@ public class OnnxModels {
}
public OnnxModel get(String name) {
- return models.get(name);
+ var model = models.get(name);
+ if (model != null) return model;
+ if (owner.isPresent() && owner.get().inherited().isPresent())
+ return owner.get().inherited().get().onnxModels().get(name);
+ return null;
}
public boolean has(String name) {
- return models.containsKey(name);
+ boolean has = models.containsKey(name);
+ if (has) return true;
+ if (owner.isPresent() && owner.get().inherited().isPresent())
+ return owner.get().inherited().get().onnxModels().has(name);
+ return false;
}
public Map<String, OnnxModel> asMap() {
- return Collections.unmodifiableMap(models);
+ // Shortcuts
+ if (owner.isEmpty() || owner.get().inherited().isEmpty()) return Collections.unmodifiableMap(models);
+ if (models.isEmpty()) return owner.get().inherited().get().onnxModels().asMap();
+
+ var allModels = new HashMap<>(owner.get().inherited().get().onnxModels().asMap());
+ allModels.putAll(models);
+ return Collections.unmodifiableMap(allModels);
}
/** Initiate sending of these models to some services over file distribution */
public void sendTo(Collection<? extends AbstractService> services) {
- models.values().forEach(model -> model.sendTo(services));
+ asMap().values().forEach(model -> model.sendTo(services));
}
}