diff options
author | Bjørn Christian Seime <bjorncs@vespa.ai> | 2024-01-05 11:55:14 +0100 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@vespa.ai> | 2024-01-05 13:58:47 +0100 |
commit | 9d03920ca70489d3d334dc90b5c0ad14b5ba2d63 (patch) | |
tree | 3ed98c68603948f8902a2010be0f4f102f9633ee /config-model | |
parent | 21b8ca070ade5a2a35cf89c4b8b5f9748510ce3b (diff) |
Tag provided models with type to catch misconfiguration
Diffstat (limited to 'config-model')
8 files changed, 127 insertions, 54 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java index 67fb720b8c0..f546f5060ca 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/BertEmbedder.java @@ -8,10 +8,14 @@ import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import java.util.Set; + import static com.yahoo.embedding.BertBaseEmbedderConfig.OnnxExecutionMode; import static com.yahoo.embedding.BertBaseEmbedderConfig.PoolingStrategy; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.BERT_VOCAB; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL; /** * @author bjorncs @@ -32,14 +36,14 @@ public class BertEmbedder extends TypedComponent implements BertBaseEmbedderConf public BertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.BertBaseEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( getChildValue(xml, "onnx-execution-mode"), getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-vocab").orElseThrow().modelReference(); + vocabRef = Model.fromXml(state, xml, "tokenizer-vocab", Set.of(BERT_VOCAB)).orElseThrow().modelReference(); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index d22e6afc3d1..780dddc6684 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -8,9 +8,13 @@ import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import java.util.Set; + import static com.yahoo.embedding.ColBertEmbedderConfig.TransformerExecutionMode; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL; /** @@ -37,14 +41,14 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( getChildValue(xml, "onnx-execution-mode"), getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-model") + vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER)) .map(Model::modelReference) .orElseGet(() -> resolveDefaultVocab(model, state)); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); @@ -62,7 +66,7 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo private static ModelReference resolveDefaultVocab(Model model, DeployState state) { var modelId = model.modelId().orElse(null); if (state.isHosted() && modelId != null) { - return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); + return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference(); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index d98c72ab3a4..b489d3edd98 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -8,10 +8,14 @@ import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import java.util.Set; + import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL; /** @@ -32,14 +36,14 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( getChildValue(xml, "onnx-execution-mode"), getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-model") + vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER)) .map(Model::modelReference) .orElseGet(() -> resolveDefaultVocab(model, state)); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); @@ -55,7 +59,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private static ModelReference resolveDefaultVocab(Model model, DeployState state) { var modelId = model.modelId().orElse(null); if (state.isHosted() && modelId != null) { - return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); + return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference(); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java index f808916d83b..c73cf9dce37 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java @@ -10,9 +10,11 @@ import com.yahoo.text.XML; import org.w3c.dom.Element; import java.util.Map; +import java.util.Set; import java.util.TreeMap; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; /** * @author bjorncs @@ -25,7 +27,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml); for (Element element : XML.getChildren(xml, "model")) { var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown"; - langToModel.put(lang, Model.fromXml(state, element).modelReference()); + langToModel.put(lang, Model.fromXml(state, element, Set.of(HF_TOKENIZER)).modelReference()); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java index 102ed926fad..558b58f12be 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java @@ -15,6 +15,7 @@ import org.w3c.dom.Element; import java.net.URI; import java.util.Objects; import java.util.Optional; +import java.util.Set; /** * Represents a model, e.g ONNX model for an embedder. @@ -28,7 +29,7 @@ class Model { private final ApplicationFile file; private final ModelReference ref; - private Model(DeployState ds, String paramName, String modelId, URI url, Path file) { + private Model(DeployState ds, String paramName, String modelId, URI url, Path file, Set<String> requiredTags) { this.paramName = Objects.requireNonNull(paramName); if (modelId == null && url == null && file == null) throw new IllegalArgumentException("At least one of 'model-id', 'url' or 'path' must be specified"); @@ -37,22 +38,22 @@ class Model { this.file = file != null ? ds.getApplicationPackage().getFile(file) : null; this.ref = ModelIdResolver.resolveToModelReference( paramName, Optional.ofNullable(modelId), Optional.ofNullable(url).map(URI::toString), - Optional.ofNullable(file).map(Path::toString), ds); + Optional.ofNullable(file).map(Path::toString), requiredTags, ds); } - static Model fromParams(DeployState ds, String paramName, String modelId, URI url, Path file) { - return new Model(ds, paramName, modelId, url, file); + static Model fromParams(DeployState ds, String paramName, String modelId, URI url, Path file, Set<String> requiredTags) { + return new Model(ds, paramName, modelId, url, file, requiredTags); } - static Optional<Model> fromXml(DeployState ds, Element parent, String name) { - return XmlHelper.getOptionalChild(parent, name).map(e -> fromXml(ds, e)); + static Optional<Model> fromXml(DeployState ds, Element parent, String name, Set<String> requiredTags) { + return XmlHelper.getOptionalChild(parent, name).map(e -> fromXml(ds, e, requiredTags)); } - static Model fromXml(DeployState ds, Element model) { + static Model fromXml(DeployState ds, Element model, Set<String> requiredTags) { var modelId = XmlHelper.getOptionalAttribute(model, "model-id").orElse(null); var url = XmlHelper.getOptionalAttribute(model, "url").map(URI::create).orElse(null); var path = XmlHelper.getOptionalAttribute(model, "path").map(Path::fromString).orElse(null); - return new Model(ds, model.getTagName(), modelId, url, path); + return new Model(ds, model.getTagName(), modelId, url, path, requiredTags); } void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java index 038a6cb78c8..53358e7576a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java @@ -7,8 +7,13 @@ import com.yahoo.config.model.deploy.DeployState; import com.yahoo.embedding.SpladeEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; + +import java.util.Set; + import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL; public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConfig.Producer { @@ -24,14 +29,14 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf public SpladeEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.SpladeEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( getChildValue(xml, "onnx-execution-mode"), getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-model") + vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER)) .map(Model::modelReference) .orElseGet(() -> resolveDefaultVocab(model, state)); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); @@ -46,7 +51,7 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf private static ModelReference resolveDefaultVocab(Model model, DeployState state) { var modelId = model.modelId().orElse(null); if (state.isHosted() && modelId != null) { - return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); + return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference(); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java index 0142b7f246a..9ff9344edcb 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java @@ -7,10 +7,11 @@ import com.yahoo.config.model.deploy.DeployState; import com.yahoo.text.XML; import org.w3c.dom.Element; -import java.util.Collections; +import java.net.URI; import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; /** @@ -21,44 +22,55 @@ import java.util.stream.Collectors; */ public class ModelIdResolver { - private static Map<String, String> setupProvidedModels() { - Map<String, String> models = new HashMap<>(); - models.put("minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx"); - models.put("mpnet-base-v2", "https://data.vespa.oath.cloud/onnx_models/sentence-all-mpnet-base-v2.onnx"); - models.put("bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt"); - models.put("flan-t5-vocab", "https://data.vespa.oath.cloud/onnx_models/flan-t5-spiece.model"); - models.put("flan-t5-small-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-encoder-model.onnx"); - models.put("flan-t5-small-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-decoder-model.onnx"); - models.put("flan-t5-base-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-encoder-model.onnx"); - models.put("flan-t5-base-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-decoder-model.onnx"); - models.put("flan-t5-large-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-encoder-model.onnx"); - models.put("flan-t5-large-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-decoder-model.onnx"); + public static final String HF_TOKENIZER = "huggingface-tokenizer"; + public static final String ONNX_MODEL = "onnx-model"; + public static final String BERT_VOCAB = "bert-vocabulary"; - models.put("multilingual-e5-base", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/model.onnx"); - models.put("multilingual-e5-base-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json"); + private static Map<String, ProvidedModel> setupProvidedModels() { + var m = new HashMap<String, ProvidedModel>(); + register(m, "minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", Set.of(ONNX_MODEL)); + register(m, "mpnet-base-v2", "https://data.vespa.oath.cloud/onnx_models/sentence-all-mpnet-base-v2.onnx", Set.of(ONNX_MODEL)); + register(m, "bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt", Set.of(BERT_VOCAB)); + register(m, "flan-t5-vocab", "https://data.vespa.oath.cloud/onnx_models/flan-t5-spiece.model", Set.of()); + register(m, "flan-t5-small-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-encoder-model.onnx", Set.of(ONNX_MODEL)); + register(m, "flan-t5-small-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-decoder-model.onnx", Set.of(ONNX_MODEL)); + register(m, "flan-t5-base-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-encoder-model.onnx", Set.of(ONNX_MODEL)); + register(m, "flan-t5-base-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-decoder-model.onnx", Set.of(ONNX_MODEL)); + register(m, "flan-t5-large-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-encoder-model.onnx", Set.of(ONNX_MODEL)); + register(m, "flan-t5-large-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-decoder-model.onnx", Set.of(ONNX_MODEL)); - models.put("multilingual-e5-small", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small/model.onnx"); - models.put("multilingual-e5-small-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small/tokenizer.json"); + register(m, "multilingual-e5-base", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/model.onnx", Set.of(ONNX_MODEL)); + register(m, "multilingual-e5-base-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", Set.of(HF_TOKENIZER)); - models.put("multilingual-e5-small-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small-cpu-friendly/model.onnx"); - models.put("multilingual-e5-small-cpu-friendly-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small-cpu-friendly/tokenizer.json"); + register(m, "multilingual-e5-small", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small/model.onnx", Set.of(ONNX_MODEL)); + register(m, "multilingual-e5-small-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small/tokenizer.json", Set.of(HF_TOKENIZER)); - models.put("e5-small-v2", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2/model.onnx"); - models.put("e5-small-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2/tokenizer.json"); + register(m, "multilingual-e5-small-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small-cpu-friendly/model.onnx", Set.of(ONNX_MODEL)); + register(m, "multilingual-e5-small-cpu-friendly-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small-cpu-friendly/tokenizer.json", Set.of(HF_TOKENIZER)); - models.put("e5-small-v2-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2-cpu-friendly/model.onnx"); - models.put("e5-small-v2-cpu-friendly-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2-cpu-friendly/tokenizer.json"); + register(m, "e5-small-v2", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2/model.onnx", Set.of(ONNX_MODEL)); + register(m, "e5-small-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2/tokenizer.json", Set.of(HF_TOKENIZER)); - models.put("e5-base-v2", "https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx"); - models.put("e5-base-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-base-v2/tokenizer.json"); + register(m, "e5-small-v2-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2-cpu-friendly/model.onnx", Set.of(ONNX_MODEL)); + register(m, "e5-small-v2-cpu-friendly-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2-cpu-friendly/tokenizer.json", Set.of(HF_TOKENIZER)); - models.put("e5-large-v2", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/model.onnx"); - models.put("e5-large-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/tokenizer.json"); + register(m, "e5-base-v2", "https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", Set.of(ONNX_MODEL)); + register(m, "e5-base-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-base-v2/tokenizer.json", Set.of(HF_TOKENIZER)); - return Collections.unmodifiableMap(models); + register(m, "e5-large-v2", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/model.onnx", Set.of(ONNX_MODEL)); + register(m, "e5-large-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/tokenizer.json", Set.of(HF_TOKENIZER)); + return Map.copyOf(m); } - private static final Map<String, String> providedModels = setupProvidedModels(); + private record ProvidedModel(String modelId, URI uri, Set<String> tags) { + ProvidedModel { tags = Set.copyOf(tags); } + } + + private static void register(Map<String, ProvidedModel> models, String modelId, String uri, Set<String> tags) { + models.put(modelId, new ProvidedModel(modelId, URI.create(uri), tags)); + } + + private static final Map<String, ProvidedModel> providedModels = setupProvidedModels(); /** * Finds any config values of type 'model' below the given config element and @@ -79,7 +91,7 @@ public class ModelIdResolver { if ( ! value.hasAttribute("model-id")) return; if (hosted) { - value.setAttribute("url", modelIdToUrl(value.getTagName(), value.getAttribute("model-id"))); + value.setAttribute("url", modelIdToUrl(value.getTagName(), value.getAttribute("model-id"), Set.of())); value.removeAttribute("path"); } else if ( ! value.hasAttribute("url") && ! value.hasAttribute("path")) { @@ -88,10 +100,10 @@ public class ModelIdResolver { } public static ModelReference resolveToModelReference( - String paramName, Optional<String> id, Optional<String> url, Optional<String> path, DeployState state) { + String paramName, Optional<String> id, Optional<String> url, Optional<String> path, Set<String> requiredTags, DeployState state) { if (id.isEmpty()) return createModelReference(Optional.empty(), url, path, state); else if (state.isHosted()) - return createModelReference(id, Optional.of(modelIdToUrl(paramName, id.get())), Optional.empty(), state); + return createModelReference(id, Optional.of(modelIdToUrl(paramName, id.get(), requiredTags)), Optional.empty(), state); else if (url.isEmpty() && path.isEmpty()) throw onlyModelIdInHostedException(paramName); else return createModelReference(id, url, path, state); } @@ -106,11 +118,17 @@ public class ModelIdResolver { "Add a 'path' or 'url' to deploy this outside Vespa Cloud"); } - private static String modelIdToUrl(String valueName, String modelId) { + private static String modelIdToUrl(String valueName, String modelId, Set<String> requiredTags) { if ( ! providedModels.containsKey(modelId)) throw new IllegalArgumentException("Unknown model id '" + modelId + "' on '" + valueName + "'. Available models are [" + providedModels.keySet().stream().sorted().collect(Collectors.joining(", ")) + "]"); - return providedModels.get(modelId); + var providedModel = providedModels.get(modelId); + if (!providedModel.tags().containsAll(requiredTags)) { + throw new IllegalArgumentException( + "Model '%s' on '%s' has tags %s but are missing required tags %s" + .formatted(modelId, valueName, providedModel.tags(), requiredTags)); + } + return providedModel.uri().toString(); } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ModelIdResolverTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ModelIdResolverTest.java new file mode 100644 index 00000000000..409c3ac833a --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/ModelIdResolverTest.java @@ -0,0 +1,35 @@ +package com.yahoo.vespa.model.container.xml; + +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.config.model.deploy.TestProperties; +import org.junit.jupiter.api.Test; + +import java.util.Optional; +import java.util.Set; + +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * @author bjorncs + */ +class ModelIdResolverTest { + + @Test + void throws_on_known_model_with_missing_tags() { + var state = new DeployState.Builder().properties(new TestProperties().setHostedVespa(true)).build(); + var e = assertThrows(IllegalArgumentException.class, () -> + ModelIdResolver.resolveToModelReference( + "param", Optional.of("minilm-l6-v2"), Optional.empty(), Optional.empty(), Set.of(HF_TOKENIZER), state)); + var expectedMsg = "Model 'minilm-l6-v2' on 'param' has tags [onnx-model] but are missing required tags [huggingface-tokenizer]"; + assertEquals(expectedMsg, e.getMessage()); + + assertDoesNotThrow( + () -> ModelIdResolver.resolveToModelReference( + "param", Optional.of("minilm-l6-v2"), Optional.empty(), Optional.empty(), Set.of(ONNX_MODEL), state)); + } + +}
\ No newline at end of file |