summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@vespa.ai>2024-01-05 14:23:08 +0100
committerBjørn Christian Seime <bjorncs@vespa.ai>2024-01-05 14:23:08 +0100
commit657dd9b3b4dd808d401288c1dee63d2980ea4e1b (patch)
tree8c497d363513afaf0c54c9d59648da0c6ca3ce48
parent9d03920ca70489d3d334dc90b5c0ad14b5ba2d63 (diff)
Deduplicate logic for determining implicit tokenizer model id
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java12
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java12
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java13
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java11
4 files changed, 16 insertions, 32 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index 780dddc6684..abca3290a31 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -48,9 +48,7 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER))
- .map(Model::modelReference)
- .orElseGet(() -> resolveDefaultVocab(model, state));
+ vocabRef = Model.fromXmlOrImplicitlyFromOnnxModel(state, xml, model, "tokenizer-model", Set.of(HF_TOKENIZER)).modelReference();
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null);
maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null);
@@ -63,14 +61,6 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
model.registerOnnxModelCost(cluster, onnxModelOptions);
}
- private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
- var modelId = model.modelId().orElse(null);
- if (state.isHosted() && modelId != null) {
- return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference();
- }
- throw new IllegalArgumentException("'tokenizer-model' must be specified");
- }
-
@Override
public void getConfig(ColBertEmbedderConfig.Builder b) {
b.transformerModel(modelRef).tokenizerPath(vocabRef);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
index b489d3edd98..fe0bb7c8075 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -43,9 +43,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER))
- .map(Model::modelReference)
- .orElseGet(() -> resolveDefaultVocab(model, state));
+ vocabRef = Model.fromXmlOrImplicitlyFromOnnxModel(state, xml, model, "tokenizer-model", Set.of(HF_TOKENIZER)).modelReference();
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
@@ -56,14 +54,6 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm
model.registerOnnxModelCost(cluster, onnxModelOptions);
}
- private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
- var modelId = model.modelId().orElse(null);
- if (state.isHosted() && modelId != null) {
- return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference();
- }
- throw new IllegalArgumentException("'tokenizer-model' must be specified");
- }
-
@Override
public void getConfig(HuggingFaceEmbedderConfig.Builder b) {
b.transformerModel(modelRef).tokenizerPath(vocabRef);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
index 558b58f12be..7d6285d00c1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Model.java
@@ -56,6 +56,19 @@ class Model {
return new Model(ds, model.getTagName(), modelId, url, path, requiredTags);
}
+ /** Return tokenizer model from XML if specified, alternatively use model id for ONNX model with suffix '-vocab' appended */
+ static Model fromXmlOrImplicitlyFromOnnxModel(
+ DeployState ds, Element parent, Model onnxModel, String paramName, Set<String> requiredTags) {
+ return fromXml(ds, parent, paramName, requiredTags)
+ .orElseGet(() -> {
+ var modelId = onnxModel.modelId().orElse(null);
+ if (ds.isHosted() && modelId != null) {
+ return fromParams(ds, onnxModel.name(), modelId + "-vocab", null, null, requiredTags);
+ }
+ throw new IllegalArgumentException("'%s' must be specified".formatted(paramName));
+ });
+ }
+
void registerOnnxModelCost(ApplicationContainerCluster c, OnnxModelOptions onnxModelOptions) {
var resolvedUrl = resolvedUrl().orElse(null);
if (file != null) c.onnxModelCostCalculator().registerModel(file, onnxModelOptions);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
index 53358e7576a..9e0a3a0ba5c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
@@ -36,9 +36,7 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER))
- .map(Model::modelReference)
- .orElseGet(() -> resolveDefaultVocab(model, state));
+ vocabRef = Model.fromXmlOrImplicitlyFromOnnxModel(state, xml, model, "tokenizer-model", Set.of(HF_TOKENIZER)).modelReference();
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
@@ -48,13 +46,6 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf
model.registerOnnxModelCost(cluster, onnxModelOptions);
}
- private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
- var modelId = model.modelId().orElse(null);
- if (state.isHosted() && modelId != null) {
- return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference();
- }
- throw new IllegalArgumentException("'tokenizer-model' must be specified");
- }
@Override
public void getConfig(SpladeEmbedderConfig.Builder b) {
b.transformerModel(modelRef).tokenizerPath(vocabRef);