summaryrefslogtreecommitdiffstats
path: root/config-model/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/DistributableResource.java7
-rw-r--r--config-model/src/main/java/com/yahoo/schema/OnnxModel.java15
-rw-r--r--config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java10
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java14
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java8
5 files changed, 49 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/DistributableResource.java b/config-model/src/main/java/com/yahoo/schema/DistributableResource.java
index e7bdb68a03d..8594b40a367 100644
--- a/config-model/src/main/java/com/yahoo/schema/DistributableResource.java
+++ b/config-model/src/main/java/com/yahoo/schema/DistributableResource.java
@@ -8,7 +8,7 @@ import com.yahoo.path.Path;
import java.nio.ByteBuffer;
import java.util.Objects;
-public class DistributableResource implements Comparable <DistributableResource> {
+public class DistributableResource implements Comparable <DistributableResource>, Cloneable {
public enum PathType { FILE, URI, BLOB }
@@ -35,6 +35,11 @@ public class DistributableResource implements Comparable <DistributableResource>
this.pathType = type;
}
+ @Override
+ public DistributableResource clone() throws CloneNotSupportedException {
+ return (DistributableResource) super.clone();
+ }
+
// TODO: Remove and make path/pathType final
public void setFileName(String fileName) {
Objects.requireNonNull(fileName, "Filename cannot be null");
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 90a27d1f036..fbec97c797d 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -18,7 +18,7 @@ import java.util.Set;
*
* @author lesters
*/
-public class OnnxModel extends DistributableResource {
+public class OnnxModel extends DistributableResource implements Cloneable {
private OnnxModelInfo modelInfo = null;
private final Map<String, String> inputMap = new HashMap<>();
@@ -40,6 +40,19 @@ public class OnnxModel extends DistributableResource {
}
@Override
+ public OnnxModel clone() {
+ try {
+ OnnxModel clone = (OnnxModel) super.clone();
+ clone.inputMap.putAll(inputMap);
+ clone.outputMap.putAll(outputMap);
+ clone.initializers.addAll(initializers);
+ return clone;
+ } catch (CloneNotSupportedException e) {
+ throw new RuntimeException("Clone not supported", e);
+ }
+ }
+
+ @Override
public void setUri(String uri) {
throw new IllegalArgumentException("URI for ONNX models are not currently supported");
}
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
index e3c697e3262..c3fa6aedf31 100644
--- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
+++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java
@@ -35,6 +35,16 @@ public class FileDistributedOnnxModels extends Derived implements OnnxModelsConf
this.models = Collections.unmodifiableMap(distributableModels);
}
+ private FileDistributedOnnxModels(Collection<OnnxModel> models) {
+ Map<String, OnnxModel> distributableModels = models.stream()
+ .collect(LinkedHashMap::new, (m, v) -> m.put(v.getName(), v.clone()), LinkedHashMap::putAll);
+ this.models = Collections.unmodifiableMap(distributableModels);
+ }
+
+ public FileDistributedOnnxModels clone() {
+ return new FileDistributedOnnxModels(models.values());
+ }
+
public Map<String, OnnxModel> asMap() { return models; }
public void getConfig(OnnxModelsConfig.Builder builder) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
index c227700733e..906ef739ef1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
@@ -3,6 +3,7 @@ package com.yahoo.vespa.model.container;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.osgi.provider.model.ComponentModel;
+import com.yahoo.schema.derived.FileDistributedOnnxModels;
import com.yahoo.schema.derived.RankProfileList;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
@@ -42,9 +43,16 @@ public class ContainerModelEvaluation implements
/** Global rank profiles, aka models */
private final RankProfileList rankProfileList;
+ private final FileDistributedOnnxModels onnxModels; // For cluster specific ONNX model settings
public ContainerModelEvaluation(ApplicationContainerCluster cluster, RankProfileList rankProfileList) {
+ this(cluster, rankProfileList, null);
+ }
+
+ public ContainerModelEvaluation(ApplicationContainerCluster cluster,
+ RankProfileList rankProfileList, FileDistributedOnnxModels onnxModels) {
this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null");
+ this.onnxModels = onnxModels;
cluster.addSimpleComponent(EVALUATOR_NAME, null, EVALUATION_BUNDLE_NAME);
cluster.addComponent(ContainerModelEvaluation.getHandler());
}
@@ -61,7 +69,11 @@ public class ContainerModelEvaluation implements
@Override
public void getConfig(OnnxModelsConfig.Builder builder) {
- rankProfileList.getConfig(builder);
+ if (onnxModels != null) {
+ onnxModels.getConfig(builder);
+ } else {
+ rankProfileList.getConfig(builder);
+ }
}
public void getConfig(RankingExpressionsConfig.Builder builder) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index e7692aeee7b..fb4dce1cc38 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -44,6 +44,7 @@ import com.yahoo.jdisc.http.server.jetty.VoidRequestLog;
import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.path.Path;
import com.yahoo.schema.OnnxModel;
+import com.yahoo.schema.derived.FileDistributedOnnxModels;
import com.yahoo.schema.derived.RankProfileList;
import com.yahoo.search.rendering.RendererRegistry;
import com.yahoo.security.X509CertificateUtils;
@@ -751,10 +752,13 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
RankProfileList profiles =
context.vespaModel() != null ? context.vespaModel().rankProfileList() : RankProfileList.empty;
+ // Create a copy of models so each cluster can have its own specific settings
+ FileDistributedOnnxModels models = profiles.getOnnxModels().clone();
+
Element onnxElement = XML.getChild(modelEvaluationElement, "onnx");
Element modelsElement = XML.getChild(onnxElement, "models");
for (Element modelElement : XML.getChildren(modelsElement, "model") ) {
- OnnxModel onnxModel = profiles.getOnnxModels().asMap().get(modelElement.getAttribute("name"));
+ OnnxModel onnxModel = models.asMap().get(modelElement.getAttribute("name"));
if (onnxModel == null) {
String availableModels = String.join(", ", profiles.getOnnxModels().asMap().keySet());
context.getDeployState().getDeployLogger().logApplicationPackage(WARNING,
@@ -774,7 +778,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
}
}
- cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles));
+ cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models));
}
private String getStringValue(Element element, String name, String defaultValue) {