diff options
Diffstat (limited to 'config-model/src/main')
5 files changed, 49 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/DistributableResource.java b/config-model/src/main/java/com/yahoo/schema/DistributableResource.java index e7bdb68a03d..8594b40a367 100644 --- a/config-model/src/main/java/com/yahoo/schema/DistributableResource.java +++ b/config-model/src/main/java/com/yahoo/schema/DistributableResource.java @@ -8,7 +8,7 @@ import com.yahoo.path.Path; import java.nio.ByteBuffer; import java.util.Objects; -public class DistributableResource implements Comparable <DistributableResource> { +public class DistributableResource implements Comparable <DistributableResource>, Cloneable { public enum PathType { FILE, URI, BLOB } @@ -35,6 +35,11 @@ public class DistributableResource implements Comparable <DistributableResource> this.pathType = type; } + @Override + public DistributableResource clone() throws CloneNotSupportedException { + return (DistributableResource) super.clone(); + } + // TODO: Remove and make path/pathType final public void setFileName(String fileName) { Objects.requireNonNull(fileName, "Filename cannot be null"); diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 90a27d1f036..fbec97c797d 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -18,7 +18,7 @@ import java.util.Set; * * @author lesters */ -public class OnnxModel extends DistributableResource { +public class OnnxModel extends DistributableResource implements Cloneable { private OnnxModelInfo modelInfo = null; private final Map<String, String> inputMap = new HashMap<>(); @@ -40,6 +40,19 @@ public class OnnxModel extends DistributableResource { } @Override + public OnnxModel clone() { + try { + OnnxModel clone = (OnnxModel) super.clone(); + clone.inputMap.putAll(inputMap); + clone.outputMap.putAll(outputMap); + clone.initializers.addAll(initializers); + return clone; + } catch (CloneNotSupportedException e) { + throw new RuntimeException("Clone not supported", e); + } + } + + @Override public void setUri(String uri) { throw new IllegalArgumentException("URI for ONNX models are not currently supported"); } diff --git a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java index e3c697e3262..c3fa6aedf31 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/FileDistributedOnnxModels.java @@ -35,6 +35,16 @@ public class FileDistributedOnnxModels extends Derived implements OnnxModelsConf this.models = Collections.unmodifiableMap(distributableModels); } + private FileDistributedOnnxModels(Collection<OnnxModel> models) { + Map<String, OnnxModel> distributableModels = models.stream() + .collect(LinkedHashMap::new, (m, v) -> m.put(v.getName(), v.clone()), LinkedHashMap::putAll); + this.models = Collections.unmodifiableMap(distributableModels); + } + + public FileDistributedOnnxModels clone() { + return new FileDistributedOnnxModels(models.values()); + } + public Map<String, OnnxModel> asMap() { return models; } public void getConfig(OnnxModelsConfig.Builder builder) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java index c227700733e..906ef739ef1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.model.container; import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.osgi.provider.model.ComponentModel; +import com.yahoo.schema.derived.FileDistributedOnnxModels; import com.yahoo.schema.derived.RankProfileList; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.OnnxModelsConfig; @@ -42,9 +43,16 @@ public class ContainerModelEvaluation implements /** Global rank profiles, aka models */ private final RankProfileList rankProfileList; + private final FileDistributedOnnxModels onnxModels; // For cluster specific ONNX model settings public ContainerModelEvaluation(ApplicationContainerCluster cluster, RankProfileList rankProfileList) { + this(cluster, rankProfileList, null); + } + + public ContainerModelEvaluation(ApplicationContainerCluster cluster, + RankProfileList rankProfileList, FileDistributedOnnxModels onnxModels) { this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null"); + this.onnxModels = onnxModels; cluster.addSimpleComponent(EVALUATOR_NAME, null, EVALUATION_BUNDLE_NAME); cluster.addComponent(ContainerModelEvaluation.getHandler()); } @@ -61,7 +69,11 @@ public class ContainerModelEvaluation implements @Override public void getConfig(OnnxModelsConfig.Builder builder) { - rankProfileList.getConfig(builder); + if (onnxModels != null) { + onnxModels.getConfig(builder); + } else { + rankProfileList.getConfig(builder); + } } public void getConfig(RankingExpressionsConfig.Builder builder) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index e7692aeee7b..fb4dce1cc38 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -44,6 +44,7 @@ import com.yahoo.jdisc.http.server.jetty.VoidRequestLog; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.path.Path; import com.yahoo.schema.OnnxModel; +import com.yahoo.schema.derived.FileDistributedOnnxModels; import com.yahoo.schema.derived.RankProfileList; import com.yahoo.search.rendering.RendererRegistry; import com.yahoo.security.X509CertificateUtils; @@ -751,10 +752,13 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { RankProfileList profiles = context.vespaModel() != null ? context.vespaModel().rankProfileList() : RankProfileList.empty; + // Create a copy of models so each cluster can have its own specific settings + FileDistributedOnnxModels models = profiles.getOnnxModels().clone(); + Element onnxElement = XML.getChild(modelEvaluationElement, "onnx"); Element modelsElement = XML.getChild(onnxElement, "models"); for (Element modelElement : XML.getChildren(modelsElement, "model") ) { - OnnxModel onnxModel = profiles.getOnnxModels().asMap().get(modelElement.getAttribute("name")); + OnnxModel onnxModel = models.asMap().get(modelElement.getAttribute("name")); if (onnxModel == null) { String availableModels = String.join(", ", profiles.getOnnxModels().asMap().keySet()); context.getDeployState().getDeployLogger().logApplicationPackage(WARNING, @@ -774,7 +778,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { } } - cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles)); + cluster.setModelEvaluation(new ContainerModelEvaluation(cluster, profiles, models)); } private String getStringValue(Element element, String name, String defaultValue) { |