summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Bjørn Christian Seime <bjorncs@vespa.ai> 2024-01-05 11:55:14 +0100
committer Bjørn Christian Seime <bjorncs@vespa.ai> 2024-01-05 13:58:47 +0100
commit 9d03920ca70489d3d334dc90b5c0ad14b5ba2d63 (patch)
tree 3ed98c68603948f8902a2010be0f4f102f9633ee
parent 21b8ca070ade5a2a35cf89c4b8b5f9748510ce3b (diff)
Tag provided models with type to catch misconfiguration
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java 8
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java 10
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java 10
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java 4
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java 17
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java 11
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java 86
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/container/xml/ModelIdResolverTest.java 35
8 files changed, 127 insertions, 54 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
index 67fb720b8c0..f546f5060ca 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java
@@ -8,10 +8,14 @@ import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import java.util.Set;
+
import static com.yahoo.embedding.BertBaseEmbedderConfig.OnnxExecutionMode;
import static com.yahoo.embedding.BertBaseEmbedderConfig.PoolingStrategy;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.BERT_VOCAB;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL;
/**
* @author bjorncs
@@ -32,14 +36,14 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf
public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
+ var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow();
this.onnxModelOptions = new OnnxModelOptions(
getChildValue(xml, "onnx-execution-mode"),
getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference();
+ vocabRef = Model.fromXml(state, xml, "tokenizer-vocab", Set.of(BERT_VOCAB)).orElseThrow().modelReference();
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index d22e6afc3d1..780dddc6684 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -8,9 +8,13 @@ import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import java.util.Set;
+
import static com.yahoo.embedding.ColBertEmbedderConfig.TransformerExecutionMode;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL;
/**
@@ -37,14 +41,14 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
+ var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow();
this.onnxModelOptions = new OnnxModelOptions(
getChildValue(xml, "onnx-execution-mode"),
getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-model")
+ vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER))
.map(Model::modelReference)
.orElseGet(() -> resolveDefaultVocab(model, state));
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
@@ -62,7 +66,7 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
var modelId = model.modelId().orElse(null);
if (state.isHosted() && modelId != null) {
- return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference();
+ return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference();
}
throw new IllegalArgumentException("'tokenizer-model' must be specified");
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index d98c72ab3a4..b489d3edd98 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -8,10 +8,14 @@ import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import java.util.Set;
+
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy;
import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL;
/**
@@ -32,14 +36,14 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
+ var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow();
this.onnxModelOptions = new OnnxModelOptions(
getChildValue(xml, "onnx-execution-mode"),
getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-model")
+ vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER))
.map(Model::modelReference)
.orElseGet(() -> resolveDefaultVocab(model, state));
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
@@ -55,7 +59,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
var modelId = model.modelId().orElse(null);
if (state.isHosted() && modelId != null) {
- return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference();
+ return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference();
}
throw new IllegalArgumentException("'tokenizer-model' must be specified");
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
index f808916d83b..c73cf9dce37 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -10,9 +10,11 @@ import com.yahoo.text.XML;
import org.w3c.dom.Element;
import java.util.Map;
+import java.util.Set;
import java.util.TreeMap;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER;
/**
* @author bjorncs
@@ -25,7 +27,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
for (Element element : XML.getChildren(xml, "model")) {
var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown";
- langToModel.put(lang, Model.fromXml(state, element).modelReference());
+ langToModel.put(lang, Model.fromXml(state, element, Set.of(HF_TOKENIZER)).modelReference());
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
index 102ed926fad..558b58f12be 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
@@ -15,6 +15,7 @@ import org.w3c.dom.Element;
import java.net.URI;
import java.util.Objects;
import java.util.Optional;
+import java.util.Set;
/**
* Represents a model, e.g ONNX model for an embedder.
@@ -28,7 +29,7 @@ class Model {
private final ApplicationFile file;
private final ModelReference ref;
- private Model(DeployState ds, String paramName, String modelId, URI url, Path file) {
+ private Model(DeployState ds, String paramName, String modelId, URI url, Path file, Set<String> requiredTags) {
this.paramName = Objects.requireNonNull(paramName);
if (modelId == null && url == null && file == null)
throw new IllegalArgumentException("At least one of 'model-id', 'url' or 'path' must be specified");
@@ -37,22 +38,22 @@ class Model {
this.file = file != null ? ds.getApplicationPackage().getFile(file) : null;
this.ref = ModelIdResolver.resolveToModelReference(
paramName, Optional.ofNullable(modelId), Optional.ofNullable(url).map(URI::toString),
- Optional.ofNullable(file).map(Path::toString), ds);
+ Optional.ofNullable(file).map(Path::toString), requiredTags, ds);
}
- static Model fromParams(DeployState ds, String paramName, String modelId, URI url, Path file) {
- return new Model(ds, paramName, modelId, url, file);
+ static Model fromParams(DeployState ds, String paramName, String modelId, URI url, Path file, Set<String> requiredTags) {
+ return new Model(ds, paramName, modelId, url, file, requiredTags);
}
- static Optional<Model> fromXml(DeployState ds, Element parent, String name) {
- return XmlHelper.getOptionalChild(parent, name).map(e -> fromXml(ds, e));
+ static Optional<Model> fromXml(DeployState ds, Element parent, String name, Set<String> requiredTags) {
+ return XmlHelper.getOptionalChild(parent, name).map(e -> fromXml(ds, e, requiredTags));
}
- static Model fromXml(DeployState ds, Element model) {
+ static Model fromXml(DeployState ds, Element model, Set<String> requiredTags) {
var modelId = XmlHelper.getOptionalAttribute(model, "model-id").orElse(null);
var url = XmlHelper.getOptionalAttribute(model, "url").map(URI::create).orElse(null);
var path = XmlHelper.getOptionalAttribute(model, "path").map(Path::fromString).orElse(null);
- return new Model(ds, model.getTagName(), modelId, url, path);
+ return new Model(ds, model.getTagName(), modelId, url, path, requiredTags);
}
void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
index 038a6cb78c8..53358e7576a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
@@ -7,8 +7,13 @@ import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.SpladeEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+
+import java.util.Set;
+
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL;
public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConfig.Producer {
@@ -24,14 +29,14 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf
public SpladeEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.SpladeEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
+ var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow();
this.onnxModelOptions = new OnnxModelOptions(
getChildValue(xml, "onnx-execution-mode"),
getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-model")
+ vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER))
.map(Model::modelReference)
.orElseGet(() -> resolveDefaultVocab(model, state));
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
@@ -46,7 +51,7 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf
private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
var modelId = model.modelId().orElse(null);
if (state.isHosted() && modelId != null) {
- return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference();
+ return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference();
}
throw new IllegalArgumentException("'tokenizer-model' must be specified");
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
index 0142b7f246a..9ff9344edcb 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
@@ -7,10 +7,11 @@ import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.text.XML;
import org.w3c.dom.Element;
-import java.util.Collections;
+import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.stream.Collectors;
/**
@@ -21,44 +22,55 @@ import java.util.stream.Collectors;
*/
public class ModelIdResolver {
- private static Map<String, String> setupProvidedModels() {
- Map<String, String> models = new HashMap<>();
- models.put("minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx");
- models.put("mpnet-base-v2", "https://data.vespa.oath.cloud/onnx_models/sentence-all-mpnet-base-v2.onnx");
- models.put("bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt");
- models.put("flan-t5-vocab", "https://data.vespa.oath.cloud/onnx_models/flan-t5-spiece.model");
- models.put("flan-t5-small-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-encoder-model.onnx");
- models.put("flan-t5-small-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-decoder-model.onnx");
- models.put("flan-t5-base-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-encoder-model.onnx");
- models.put("flan-t5-base-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-decoder-model.onnx");
- models.put("flan-t5-large-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-encoder-model.onnx");
- models.put("flan-t5-large-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-decoder-model.onnx");
+ public static final String HF_TOKENIZER = "huggingface-tokenizer";
+ public static final String ONNX_MODEL = "onnx-model";
+ public static final String BERT_VOCAB = "bert-vocabulary";
- models.put("multilingual-e5-base", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/model.onnx");
- models.put("multilingual-e5-base-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json");
+ private static Map<String, ProvidedModel> setupProvidedModels() {
+ var m = new HashMap<String, ProvidedModel>();
+ register(m, "minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", Set.of(ONNX_MODEL));
+ register(m, "mpnet-base-v2", "https://data.vespa.oath.cloud/onnx_models/sentence-all-mpnet-base-v2.onnx", Set.of(ONNX_MODEL));
+ register(m, "bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt", Set.of(BERT_VOCAB));
+ register(m, "flan-t5-vocab", "https://data.vespa.oath.cloud/onnx_models/flan-t5-spiece.model", Set.of());
+ register(m, "flan-t5-small-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-encoder-model.onnx", Set.of(ONNX_MODEL));
+ register(m, "flan-t5-small-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-decoder-model.onnx", Set.of(ONNX_MODEL));
+ register(m, "flan-t5-base-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-encoder-model.onnx", Set.of(ONNX_MODEL));
+ register(m, "flan-t5-base-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-decoder-model.onnx", Set.of(ONNX_MODEL));
+ register(m, "flan-t5-large-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-encoder-model.onnx", Set.of(ONNX_MODEL));
+ register(m, "flan-t5-large-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-decoder-model.onnx", Set.of(ONNX_MODEL));
- models.put("multilingual-e5-small", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small/model.onnx");
- models.put("multilingual-e5-small-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small/tokenizer.json");
+ register(m, "multilingual-e5-base", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/model.onnx", Set.of(ONNX_MODEL));
+ register(m, "multilingual-e5-base-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", Set.of(HF_TOKENIZER));
- models.put("multilingual-e5-small-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small-cpu-friendly/model.onnx");
- models.put("multilingual-e5-small-cpu-friendly-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small-cpu-friendly/tokenizer.json");
+ register(m, "multilingual-e5-small", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small/model.onnx", Set.of(ONNX_MODEL));
+ register(m, "multilingual-e5-small-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small/tokenizer.json", Set.of(HF_TOKENIZER));
- models.put("e5-small-v2", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2/model.onnx");
- models.put("e5-small-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2/tokenizer.json");
+ register(m, "multilingual-e5-small-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small-cpu-friendly/model.onnx", Set.of(ONNX_MODEL));
+ register(m, "multilingual-e5-small-cpu-friendly-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small-cpu-friendly/tokenizer.json", Set.of(HF_TOKENIZER));
- models.put("e5-small-v2-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2-cpu-friendly/model.onnx");
- models.put("e5-small-v2-cpu-friendly-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2-cpu-friendly/tokenizer.json");
+ register(m, "e5-small-v2", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2/model.onnx", Set.of(ONNX_MODEL));
+ register(m, "e5-small-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2/tokenizer.json", Set.of(HF_TOKENIZER));
- models.put("e5-base-v2", "https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx");
- models.put("e5-base-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-base-v2/tokenizer.json");
+ register(m, "e5-small-v2-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2-cpu-friendly/model.onnx", Set.of(ONNX_MODEL));
+ register(m, "e5-small-v2-cpu-friendly-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2-cpu-friendly/tokenizer.json", Set.of(HF_TOKENIZER));
- models.put("e5-large-v2", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/model.onnx");
- models.put("e5-large-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/tokenizer.json");
+ register(m, "e5-base-v2", "https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", Set.of(ONNX_MODEL));
+ register(m, "e5-base-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-base-v2/tokenizer.json", Set.of(HF_TOKENIZER));
- return Collections.unmodifiableMap(models);
+ register(m, "e5-large-v2", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/model.onnx", Set.of(ONNX_MODEL));
+ register(m, "e5-large-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/tokenizer.json", Set.of(HF_TOKENIZER));
+ return Map.copyOf(m);
}
- private static final Map<String, String> providedModels = setupProvidedModels();
+ private record ProvidedModel(String modelId, URI uri, Set<String> tags) {
+ ProvidedModel { tags = Set.copyOf(tags); }
+ }
+
+ private static void register(Map<String, ProvidedModel> models, String modelId, String uri, Set<String> tags) {
+ models.put(modelId, new ProvidedModel(modelId, URI.create(uri), tags));
+ }
+
+ private static final Map<String, ProvidedModel> providedModels = setupProvidedModels();
/**
* Finds any config values of type 'model' below the given config element and
@@ -79,7 +91,7 @@ public class ModelIdResolver {
if ( ! value.hasAttribute("model-id")) return;
if (hosted) {
- value.setAttribute("url", modelIdToUrl(value.getTagName(), value.getAttribute("model-id")));
+ value.setAttribute("url", modelIdToUrl(value.getTagName(), value.getAttribute("model-id"), Set.of()));
value.removeAttribute("path");
}
else if ( ! value.hasAttribute("url") && ! value.hasAttribute("path")) {
@@ -88,10 +100,10 @@ public class ModelIdResolver {
}
public static ModelReference resolveToModelReference(
- String paramName, Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) {
+ String paramName, Optional<String> id, Optional<String> url, Optional<String> path, Set<String> requiredTags, DeployState state) {
if (id.isEmpty()) return createModelReference(Optional.empty(), url, path, state);
else if (state.isHosted())
- return createModelReference(id, Optional.of(modelIdToUrl(paramName, id.get())), Optional.empty(), state);
+ return createModelReference(id, Optional.of(modelIdToUrl(paramName, id.get(), requiredTags)), Optional.empty(), state);
else if (url.isEmpty() && path.isEmpty()) throw onlyModelIdInHostedException(paramName);
else return createModelReference(id, url, path, state);
}
@@ -106,11 +118,17 @@ public class ModelIdResolver {
"Add a 'path' or 'url' to deploy this outside Vespa Cloud");
}
- private static String modelIdToUrl(String valueName, String modelId) {
+ private static String modelIdToUrl(String valueName, String modelId, Set<String> requiredTags) {
if ( ! providedModels.containsKey(modelId))
throw new IllegalArgumentException("Unknown model id '" + modelId + "' on '" + valueName + "'. Available models are [" +
providedModels.keySet().stream().sorted().collect(Collectors.joining(", ")) + "]");
- return providedModels.get(modelId);
+ var providedModel = providedModels.get(modelId);
+ if (!providedModel.tags().containsAll(requiredTags)) {
+ throw new IllegalArgumentException(
+ "Model '%s' on '%s' has tags %s but are missing required tags %s"
+ .formatted(modelId, valueName, providedModel.tags(), requiredTags));
+ }
+ return providedModel.uri().toString();
}
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ModelIdResolverTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ModelIdResolverTest.java
new file mode 100644
index 00000000000..409c3ac833a
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ModelIdResolverTest.java
@@ -0,0 +1,35 @@
+package com.yahoo.vespa.model.container.xml;
+
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.config.model.deploy.TestProperties;
+import org.junit.jupiter.api.Test;
+
+import java.util.Optional;
+import java.util.Set;
+
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL;
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+/**
+ * @author bjorncs
+ */
+class ModelIdResolverTest {
+
+ @Test
+ void throws_on_known_model_with_missing_tags() {
+ var state = new DeployState.Builder().properties(new TestProperties().setHostedVespa(true)).build();
+ var e = assertThrows(IllegalArgumentException.class, () ->
+ ModelIdResolver.resolveToModelReference(
+ "param", Optional.of("minilm-l6-v2"), Optional.empty(), Optional.empty(), Set.of(HF_TOKENIZER), state));
+ var expectedMsg = "Model 'minilm-l6-v2' on 'param' has tags [onnx-model] but are missing required tags [huggingface-tokenizer]";
+ assertEquals(expectedMsg, e.getMessage());
+
+ assertDoesNotThrow(
+ () -> ModelIdResolver.resolveToModelReference(
+ "param", Optional.of("minilm-l6-v2"), Optional.empty(), Optional.empty(), Set.of(ONNX_MODEL), state));
+ }
+
+}
\ No newline at end of file