From 657dd9b3b4dd808d401288c1dee63d2980ea4e1b Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Fri, 5 Jan 2024 14:23:08 +0100 Subject: Deduplicate logic for determining implicit tokenizer model id --- .../vespa/model/container/component/ColBertEmbedder.java | 12 +----------- .../model/container/component/HuggingFaceEmbedder.java | 12 +----------- .../com/yahoo/vespa/model/container/component/Model.java | 13 +++++++++++++ .../vespa/model/container/component/SpladeEmbedder.java | 11 +---------- 4 files changed, 16 insertions(+), 32 deletions(-) (limited to 'config-model') diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index 780dddc6684..abca3290a31 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -48,9 +48,7 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER)) - .map(Model::modelReference) - .orElseGet(() -> resolveDefaultVocab(model, state)); + vocabRef = Model.fromXmlOrImplicitlyFromOnnxModel(state, xml, model, "tokenizer-model", Set.of(HF_TOKENIZER)).modelReference(); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null); maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null); @@ -63,14 +61,6 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo 
model.registerOnnxModelCost(cluster, onnxModelOptions); } - private static ModelReference resolveDefaultVocab(Model model, DeployState state) { - var modelId = model.modelId().orElse(null); - if (state.isHosted() && modelId != null) { - return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference(); - } - throw new IllegalArgumentException("'tokenizer-model' must be specified"); - } - @Override public void getConfig(ColBertEmbedderConfig.Builder b) { b.transformerModel(modelRef).tokenizerPath(vocabRef); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index b489d3edd98..fe0bb7c8075 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -43,9 +43,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER)) - .map(Model::modelReference) - .orElseGet(() -> resolveDefaultVocab(model, state)); + vocabRef = Model.fromXmlOrImplicitlyFromOnnxModel(state, xml, model, "tokenizer-model", Set.of(HF_TOKENIZER)).modelReference(); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); @@ -56,14 +54,6 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm model.registerOnnxModelCost(cluster, 
onnxModelOptions); } - private static ModelReference resolveDefaultVocab(Model model, DeployState state) { - var modelId = model.modelId().orElse(null); - if (state.isHosted() && modelId != null) { - return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference(); - } - throw new IllegalArgumentException("'tokenizer-model' must be specified"); - } - @Override public void getConfig(HuggingFaceEmbedderConfig.Builder b) { b.transformerModel(modelRef).tokenizerPath(vocabRef); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java index 558b58f12be..7d6285d00c1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java @@ -56,6 +56,19 @@ class Model { return new Model(ds, model.getTagName(), modelId, url, path, requiredTags); } + /** Return tokenizer model from XML if specified, alternatively use model id for ONNX model with suffix '-vocab' appended */ + static Model fromXmlOrImplicitlyFromOnnxModel( + DeployState ds, Element parent, Model onnxModel, String paramName, Set<String> requiredTags) { + return fromXml(ds, parent, paramName, requiredTags) + .orElseGet(() -> { + var modelId = onnxModel.modelId().orElse(null); + if (ds.isHosted() && modelId != null) { + return fromParams(ds, onnxModel.name(), modelId + "-vocab", null, null, requiredTags); + } + throw new IllegalArgumentException("'%s' must be specified".formatted(paramName)); + }); + } + void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) { var resolvedUrl = resolvedUrl().orElse(null); if (file != null) c.onnxModelCostCalculator().registerModel(file, onnxModelOptions); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java 
b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java index 53358e7576a..9e0a3a0ba5c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java @@ -36,9 +36,7 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER)) - .map(Model::modelReference) - .orElseGet(() -> resolveDefaultVocab(model, state)); + vocabRef = Model.fromXmlOrImplicitlyFromOnnxModel(state, xml, model, "tokenizer-model", Set.of(HF_TOKENIZER)).modelReference(); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); @@ -48,13 +46,6 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf model.registerOnnxModelCost(cluster, onnxModelOptions); } - private static ModelReference resolveDefaultVocab(Model model, DeployState state) { - var modelId = model.modelId().orElse(null); - if (state.isHosted() && modelId != null) { - return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference(); - } - throw new IllegalArgumentException("'tokenizer-model' must be specified"); - } @Override public void getConfig(SpladeEmbedderConfig.Builder b) { b.transformerModel(modelRef).tokenizerPath(vocabRef); -- cgit v1.2.3