diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java | 18 |
1 files changed, 6 insertions, 12 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index d22e6afc3d1..abca3290a31 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -8,9 +8,13 @@ import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import java.util.Set; + import static com.yahoo.embedding.ColBertEmbedderConfig.TransformerExecutionMode; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL; /** @@ -37,16 +41,14 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( getChildValue(xml, "onnx-execution-mode"), getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-model") - .map(Model::modelReference) - .orElseGet(() -> resolveDefaultVocab(model, state)); + vocabRef = Model.fromXmlOrImplicitlyFromOnnxModel(state, xml, model, "tokenizer-model", Set.of(HF_TOKENIZER)).modelReference(); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null); maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null); @@ -59,14 +61,6 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo model.registerOnnxModelCost(cluster, onnxModelOptions); } - private static ModelReference resolveDefaultVocab(Model model, DeployState state) { - var modelId = model.modelId().orElse(null); - if (state.isHosted() && modelId != null) { - return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); - } - throw new IllegalArgumentException("'tokenizer-model' must be specified"); - } - @Override public void getConfig(ColBertEmbedderConfig.Builder b) { b.transformerModel(modelRef).tokenizerPath(vocabRef); |