diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java index d98c72ab3a4..b489d3edd98 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java @@ -8,10 +8,14 @@ import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.w3c.dom.Element; +import java.util.Set; + import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.PoolingStrategy; import static com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig.TransformerExecutionMode; import static com.yahoo.text.XML.getChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.HF_TOKENIZER; +import static com.yahoo.vespa.model.container.xml.ModelIdResolver.ONNX_MODEL; /** @@ -32,14 +36,14 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm public HuggingFaceEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml); - var model = Model.fromXml(state, xml, "transformer-model").orElseThrow(); + var model = Model.fromXml(state, xml, "transformer-model", Set.of(ONNX_MODEL)).orElseThrow(); this.onnxModelOptions = new OnnxModelOptions( getChildValue(xml, "onnx-execution-mode"), getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt), getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).map(OnnxModelOptions.GpuDevice::new)); modelRef = model.modelReference(); - vocabRef = Model.fromXml(state, xml, "tokenizer-model") + vocabRef = Model.fromXml(state, xml, "tokenizer-model", Set.of(HF_TOKENIZER)) .map(Model::modelReference) .orElseGet(() -> resolveDefaultVocab(model, state)); maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); @@ -55,7 +59,7 @@ public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEm private static ModelReference resolveDefaultVocab(Model model, DeployState state) { var modelId = model.modelId().orElse(null); if (state.isHosted() && modelId != null) { - return Model.fromParams(state, model.name(), modelId + "-vocab", null, null).modelReference(); + return Model.fromParams(state, model.name(), modelId + "-vocab", null, null, Set.of(HF_TOKENIZER)).modelReference(); } throw new IllegalArgumentException("'tokenizer-model' must be specified"); } |