summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java18
1 files changed, 6 insertions, 12 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index d22e6afc3d1..abca3290a31 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -8,9 +8,13 @@ import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import org.w3c.dom.Element;
+import java.util.Set;
+
import static com.yahoo.embedding.ColBertEmbedderConfig.TransformerExecutionMode;
import static com.yahoo.text.XML.getChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER;
+import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL;
/**
@@ -37,16 +41,14 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml);
- var model = Model.fromXml(state, xml, "transformer-model").orElseThrow();
+ var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow();
this.onnxModelOptions = new OnnxModelOptions(
getChildValue(xml, "onnx-execution-mode"),
getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt),
getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new));
modelRef = model.modelReference();
- vocabRef = Model.fromXml(state, xml, "tokenizer-model")
- .map(Model::modelReference)
- .orElseGet(() -> resolveDefaultVocab(model, state));
+ vocabRef = Model.fromXmlOrImplicitlyFromOnnxModel(state, xml, model, "tokenizer-model", Set.of(HF_TOKENIZER)).modelReference();
maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null);
maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null);
@@ -59,14 +61,6 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
model.registerOnnxModelCost(cluster, onnxModelOptions);
}
- private static ModelReference resolveDefaultVocab(Model model, DeployState state) {
- var modelId = model.modelId().orElse(null);
- if (state.isHosted() && modelId != null) {
- return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference();
- }
- throw new IllegalArgumentException("'tokenizer-model' must be specified");
- }
-
@Override
public void getConfig(ColBertEmbedderConfig.Builder b) {
b.transformerModel(modelRef).tokenizerPath(vocabRef);