aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java26
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java1
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java3
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java79
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java47
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java20
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java35
-rw-r--r--config-model/src/main/resources/schema/common.rnc35
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml21
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilderTest.java4
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java27
-rw-r--r--configdefinitions/src/main/java/com/yahoo/embedding/huggingface/package-info.java9
-rw-r--r--configdefinitions/src/main/java/com/yahoo/language/huggingface/config/package-info.java9
-rw-r--r--configdefinitions/src/vespa/hugging-face-embedder.def (renamed from model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def)0
-rw-r--r--configdefinitions/src/vespa/language.huggingface.hugging-face-tokenizer.def (renamed from linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def)2
-rw-r--r--linguistics-components/pom.xml6
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java1
-rw-r--r--model-integration/pom.xml6
18 files changed, 315 insertions, 16 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
index c57122e5bf5..fa0ee3f9857 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java
@@ -3,12 +3,13 @@ package com.yahoo.vespa.model.builder.xml.dom;
import com.yahoo.component.ComponentId;
import com.yahoo.config.model.deploy.DeployState;
-import com.yahoo.container.bundle.BundleInstantiationSpecification;
-import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.config.model.producer.AnyConfigProducer;
import com.yahoo.config.model.producer.TreeConfigProducer;
+import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.text.XML;
import com.yahoo.vespa.model.container.component.Component;
+import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
+import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
import com.yahoo.vespa.model.container.xml.BundleInstantiationSpecificationBuilder;
import org.w3c.dom.Element;
@@ -31,17 +32,24 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde
}
@Override
- protected Component doBuild(DeployState deployState, TreeConfigProducer<AnyConfigProducer> ancestor, Element spec) {
- Component component = buildComponent(spec);
+ protected Component<? super Component<?, ?>, ?> doBuild(DeployState deployState, TreeConfigProducer<AnyConfigProducer> ancestor, Element spec) {
+ var component = buildComponent(spec, deployState);
addChildren(deployState, ancestor, spec, component);
return component;
}
- private Component buildComponent(Element spec) {
- BundleInstantiationSpecification bundleSpec =
- BundleInstantiationSpecificationBuilder.build(spec).nestInNamespace(namespace);
-
- return new Component<Component<?, ?>, ComponentModel>(new ComponentModel(bundleSpec));
+    private Component<? super Component<?, ?>, ?> buildComponent(Element spec, DeployState state) {
+        if ( ! spec.hasAttribute("type")) {
+            var bundleSpec = BundleInstantiationSpecificationBuilder.build(spec).nestInNamespace(namespace);
+            return new Component<>(new ComponentModel(bundleSpec));
+        }
+        // A 'type' attribute selects a dedicated component model instead of a class/bundle spec
+        var type = spec.getAttribute("type");
+        return switch (type) {
+            case "hugging-face-embedder" -> new HuggingFaceEmbedder(spec, state);
+            case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state);
+            default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type));
+        };
+    }
public static void addChildren(DeployState deployState, TreeConfigProducer<AnyConfigProducer> ancestor, Element componentNode, Component<? super Component<?, ?>, ?> component) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
index 34c565871db..c227700733e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
@@ -26,6 +26,7 @@ public class ContainerModelEvaluation implements
OnnxModelsConfig.Producer,
RankingExpressionsConfig.Producer {
+ public final static String LINGUISTICS_BUNDLE_NAME = "linguistics-components";
public final static String EVALUATION_BUNDLE_NAME = "model-evaluation";
public final static String INTEGRATION_BUNDLE_NAME = "model-integration";
public final static String ONNXRUNTIME_BUNDLE_NAME = "container-onnxruntime.jar";
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
index 19df9a4064f..dbc7cd62fbd 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
@@ -12,6 +12,7 @@ import java.util.stream.Stream;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.EVALUATION_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.ONNXRUNTIME_BUNDLE_NAME;
/**
@@ -57,7 +58,7 @@ public class PlatformBundles {
public static final Set<Path> SEARCH_AND_DOCPROC_BUNDLES = toBundlePaths(
SEARCH_AND_DOCPROC_BUNDLE,
"docprocs",
- "linguistics-components",
+ LINGUISTICS_BUNDLE_NAME,
EVALUATION_BUNDLE_NAME,
INTEGRATION_BUNDLE_NAME,
ONNXRUNTIME_BUNDLE_NAME
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
new file mode 100644
index 00000000000..1c36716699e
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceEmbedder.java
@@ -0,0 +1,99 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import org.w3c.dom.Element;
+
+import java.util.Optional;
+
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild;
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
+
+
+/**
+ * Config model for a component declared with {@code type="hugging-face-embedder"}. The element's
+ * children are read once at construction and emitted as {@link HuggingFaceEmbedderConfig}.
+ * Optional settings that are absent from the XML are kept as {@code null} and are not set on the
+ * config builder.
+ *
+ * @author bjorncs
+ */
+public class HuggingFaceEmbedder extends TypedComponent implements HuggingFaceEmbedderConfig.Producer {
+    private final ModelReference model;
+    private final ModelReference vocab;
+    private final Integer maxTokens;
+    private final String transformerInputIds;
+    private final String transformerAttentionMask;
+    private final String transformerTokenTypeIds;
+    private final String transformerOutput;
+    private final Boolean normalize;
+    private final String onnxExecutionMode;
+    private final Integer onnxInteropThreads;
+    private final Integer onnxIntraopThreads;
+    private final Integer onnxGpuDevice;
+
+    /**
+     * @param xml the component element to read settings from
+     * @param state deploy state, used to tell hosted from self-hosted deployments
+     * @throws IllegalArgumentException if 'transformer-model' is missing, or if 'tokenizer-model' is
+     *         missing and no default can be derived from the transformer model's 'model-id'
+     */
+    public HuggingFaceEmbedder(Element xml, DeployState state) {
+        super("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", INTEGRATION_BUNDLE_NAME, xml);
+        boolean hosted = state.isHosted();
+        // Fail with a descriptive message instead of a bare NoSuchElementException
+        var transformerModelElem = getOptionalChild(xml, "transformer-model")
+                .orElseThrow(() -> new IllegalArgumentException("'transformer-model' must be specified"));
+        model = ModelIdResolver.resolveToModelReference(transformerModelElem, hosted);
+        vocab = getOptionalChild(xml, "tokenizer-model")
+                .map(elem -> ModelIdResolver.resolveToModelReference(elem, hosted))
+                .orElseGet(() -> resolveDefaultVocab(transformerModelElem, hosted));
+        maxTokens = getOptionalChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null);
+        transformerInputIds = getOptionalChildValue(xml, "transformer-input-ids").orElse(null);
+        transformerAttentionMask = getOptionalChildValue(xml, "transformer-attention-mask").orElse(null);
+        transformerTokenTypeIds = getOptionalChildValue(xml, "transformer-token-type-ids").orElse(null);
+        transformerOutput = getOptionalChildValue(xml, "transformer-output").orElse(null);
+        normalize = getOptionalChildValue(xml, "normalize").map(Boolean::parseBoolean).orElse(null);
+        onnxExecutionMode = getOptionalChildValue(xml, "onnx-execution-mode").orElse(null);
+        onnxInteropThreads = getOptionalChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null);
+        onnxIntraopThreads = getOptionalChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null);
+        onnxGpuDevice = getOptionalChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null);
+    }
+
+    /**
+     * Derives the tokenizer model when no 'tokenizer-model' element is given: in hosted deployments
+     * the transformer model's 'model-id' suffixed with "-vocab" is resolved as a model id.
+     *
+     * @throws IllegalArgumentException if not hosted, or if the transformer model has no 'model-id'
+     */
+    private static ModelReference resolveDefaultVocab(Element model, boolean hosted) {
+        if (hosted && model.hasAttribute("model-id")) {
+            var implicitVocabId = model.getAttribute("model-id") + "-vocab";
+            return ModelIdResolver.resolveToModelReference(
+                    "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), true);
+        }
+        throw new IllegalArgumentException("'tokenizer-model' must be specified");
+    }
+
+    /** Writes the two model references and every optional setting that was present in the XML. */
+    @Override
+    public void getConfig(HuggingFaceEmbedderConfig.Builder b) {
+        b.transformerModel(model).tokenizerPath(vocab);
+        if (maxTokens != null) b.transformerMaxTokens(maxTokens);
+        if (transformerInputIds != null) b.transformerInputIds(transformerInputIds);
+        if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask);
+        if (transformerTokenTypeIds != null) b.transformerTokenTypeIds(transformerTokenTypeIds);
+        if (transformerOutput != null) b.transformerOutput(transformerOutput);
+        if (normalize != null) b.normalize(normalize);
+        if (onnxExecutionMode != null) b.transformerExecutionMode(
+                HuggingFaceEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode));
+        if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads);
+        if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads);
+        if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice);
+    }
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
new file mode 100644
index 00000000000..ba8521a0089
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -0,0 +1,70 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
+import com.yahoo.text.XML;
+import com.yahoo.vespa.model.container.xml.ModelIdResolver;
+import org.w3c.dom.Element;
+
+import java.util.Map;
+import java.util.TreeMap;
+
+import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME;
+
+/**
+ * Config model for a component declared with {@code type="hugging-face-tokenizer"}. Collects one
+ * model reference per language plus the optional tokenizer settings, and emits them as
+ * {@link HuggingFaceTokenizerConfig}.
+ *
+ * @author bjorncs
+ */
+public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceTokenizerConfig.Producer {
+
+    /** Model reference per language, iterated in sorted key order. */
+    private final Map<String, ModelReference> langToModel = new TreeMap<>();
+    private final Boolean specialTokens;
+    private final Integer maxLength;
+    private final Boolean truncation;
+
+    /**
+     * @param xml the component element to read models and settings from
+     * @param state deploy state, used to tell hosted from self-hosted deployments
+     */
+    public HuggingFaceTokenizer(Element xml, DeployState state) {
+        super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
+        for (Element modelElem : XML.getChildren(xml, "model")) {
+            // A model without a 'language' attribute is registered under "unknown"
+            String language = "unknown";
+            if (modelElem.hasAttribute("language")) {
+                language = modelElem.getAttribute("language");
+            }
+            var reference = ModelIdResolver.resolveToModelReference(modelElem, state.isHosted());
+            langToModel.put(language, reference);
+        }
+        specialTokens = optionalBoolean(xml, "special-tokens");
+        maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null);
+        truncation = optionalBoolean(xml, "truncation");
+    }
+
+    /** Parses the named child's value as a boolean, or returns null when getOptionalChildValue yields no value. */
+    private static Boolean optionalBoolean(Element xml, String childName) {
+        return getOptionalChildValue(xml, childName).map(Boolean::parseBoolean).orElse(null);
+    }
+
+    /** Adds one model entry per language, then each optional setting that was present in the XML. */
+    @Override
+    public void getConfig(HuggingFaceTokenizerConfig.Builder builder) {
+        for (var entry : langToModel.entrySet()) {
+            var modelBuilder = new HuggingFaceTokenizerConfig.Model.Builder()
+                    .language(entry.getKey())
+                    .path(entry.getValue());
+            builder.model.add(modelBuilder);
+        }
+        if (specialTokens != null) builder.addSpecialTokens(specialTokens);
+        if (maxLength != null) builder.maxLength(maxLength);
+        if (truncation != null) builder.truncation(truncation);
+    }
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java
new file mode 100644
index 00000000000..522c78f2f25
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/TypedComponent.java
@@ -0,0 +1,25 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.vespa.model.container.component;
+
+import com.yahoo.osgi.provider.model.ComponentModel;
+import org.w3c.dom.Element;
+
+/**
+ * Base class for components whose implementation class and bundle are fixed by the subclass,
+ * with only the component id taken from the XML element.
+ *
+ * @author bjorncs
+ */
+abstract class TypedComponent extends SimpleComponent {
+
+    /**
+     * @param className name of the component implementation class
+     * @param bundle name of the bundle providing the implementation
+     * @param xml the component element; only its 'id' attribute is read, and the element is not retained
+     */
+    protected TypedComponent(String className, String bundle, Element xml) {
+        super(new ComponentModel(xml.getAttribute("id"), className, bundle));
+    }
+
+}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
index 16864b71646..c0f49f3148d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
@@ -1,12 +1,17 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.xml;
+import com.yahoo.config.FileReference;
+import com.yahoo.config.ModelReference;
+import com.yahoo.config.UrlReference;
+import com.yahoo.config.model.builder.xml.XmlHelper;
import com.yahoo.text.XML;
import org.w3c.dom.Element;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
import java.util.stream.Collectors;
/**
@@ -70,11 +75,37 @@ public class ModelIdResolver {
value.removeAttribute("path");
}
else if ( ! value.hasAttribute("url") && ! value.hasAttribute("path")) {
- throw new IllegalArgumentException(value.getTagName() + " is configured with only a 'model-id'. " +
- "Add a 'path' or 'url' to deploy this outside Vespa Cloud");
+ throw onlyModelIdInHostedException(value.getTagName());
}
}
+
+    public static ModelReference resolveToModelReference(Element elem, boolean hosted) {
+        // The tag name doubles as the parameter name used in error messages
+        var modelId = XmlHelper.getOptionalAttribute(elem, "model-id");
+        var url = XmlHelper.getOptionalAttribute(elem, "url");
+        return resolveToModelReference(elem.getTagName(), modelId, url, XmlHelper.getOptionalAttribute(elem, "path"), hosted);
+
+    public static ModelReference resolveToModelReference(
+            String paramName, Optional<String> id, Optional<String> url, Optional<String> path, boolean hosted) {
+        // Hosted: a model id wins and is mapped to the URL of the provided model
+        if (id.isPresent() && hosted) {
+            var providedUrl = UrlReference.valueOf(modelIdToUrl(paramName, id.get()));
+            return ModelReference.unresolved(id, Optional.of(providedUrl), Optional.empty());
+        }
+        // Not hosted: a model id without url or path cannot be resolved
+        if (id.isPresent() && url.isEmpty() && path.isEmpty()) throw onlyModelIdInHostedException(paramName);
+        // Otherwise only the explicit url/path is carried forward; any model id is dropped
+        var urlRef = url.map(UrlReference::valueOf);
+        var fileRef = path.map(FileReference::new);
+        return ModelReference.unresolved(Optional.empty(), urlRef, fileRef);
+    }
+
+    private static IllegalArgumentException onlyModelIdInHostedException(String paramName) {
+        var advice = "Add a 'path' or 'url' to deploy this outside Vespa Cloud"; // how to fix the configuration
+        return new IllegalArgumentException(paramName + " is configured with only a 'model-id'. " + advice);
+    }
+
private static String modelIdToUrl(String valueName, String modelId) {
if ( ! providedModels.containsKey(modelId))
throw new IllegalArgumentException("Unknown model id '" + modelId + "' on '" + valueName + "'. Available models are [" +
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index 21f3399a027..4e7cb526efb 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -53,6 +53,11 @@ GenericConfig = element config {
anyElement*
}
+ModelReference =
+ attribute model-id { xsd:string }? &
+ attribute path { xsd:string }? &
+ attribute url { xsd:string }?
+
ComponentSpec =
( attribute id { xsd:Name | JavaId } | attribute idref { xsd:Name } | attribute ident { xsd:Name } )
@@ -64,7 +69,7 @@ BundleSpec =
attribute bundle { xsd:Name }?
Component = element component {
- ComponentDefinition
+ (ComponentDefinition | TypedComponentDefinition)
}
ComponentDefinition =
@@ -72,3 +77,31 @@ ComponentDefinition =
BundleSpec &
GenericConfig* &
Component*
+
+TypedComponentDefinition =
+ attribute id { xsd:Name } &
+ (HuggingFaceEmbedder | HuggingFaceTokenizer) &
+ GenericConfig* &
+ Component*
+
+HuggingFaceEmbedder =
+ attribute type { "hugging-face-embedder" } &
+ element transformer-model { ModelReference } &
+ element tokenizer-model { ModelReference }? &
+ element max-tokens { xsd:nonNegativeInteger }? &
+ element transformer-input-ids { xsd:string }? &
+ element transformer-attention-mask { xsd:string }? &
+ element transformer-token-type-ids { xsd:string }? &
+ element transformer-output { xsd:string }? &
+ element normalize { xsd:boolean }? &
+ element onnx-execution-mode { "parallel" | "sequential" }? &
+ element onnx-interop-threads { xsd:integer }? &
+ element onnx-intraop-threads { xsd:integer }? &
+ element onnx-gpu-device { xsd:integer }?
+
+HuggingFaceTokenizer =
+ attribute type { "hugging-face-tokenizer" } &
+ element model { attribute language { xsd:string }? & ModelReference }+ &
+ element special-tokens { xsd:boolean }? &
+ element max-length { xsd:integer }? &
+ element truncation { xsd:boolean }? \ No newline at end of file
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index fcb1f10f32c..99c89bc4324 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -3,6 +3,27 @@
<services version="1.0">
<container version="1.0">
+ <component id="hf-embedder" type="hugging-face-embedder">
+ <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/>
+ <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/>
+ <max-tokens>1024</max-tokens>
+ <transformer-input-ids>my_input_ids</transformer-input-ids>
+ <transformer-attention-mask>my_attention_mask</transformer-attention-mask>
+ <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
+ <transformer-output>my_output</transformer-output>
+ <normalize>true</normalize>
+ <onnx-execution-mode>parallel</onnx-execution-mode>
+ <onnx-intraop-threads>10</onnx-intraop-threads>
+ <onnx-interop-threads>8</onnx-interop-threads>
+ <onnx-gpu-device>1</onnx-gpu-device>
+ </component>
+
+ <component id="hf-tokenizer" type="hugging-face-tokenizer">
+ <model language="no" model-id="multilingual-e5-base-vocab" url="https://my/url/tokenizer.json"/>
+ <special-tokens>true</special-tokens>
+ <max-length>768</max-length>
+ <truncation>true</truncation>
+ </component>
<component id="transformer" class="ai.vespa.embedding.BertBaseEmbedder" bundle="model-integration">
<config name="embedding.bert-base-embedder">
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilderTest.java
index ed3073a0ef4..78c95c03b44 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilderTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilderTest.java
@@ -30,13 +30,13 @@ public class DomComponentBuilderTest extends DomBuilderTest {
@Test
@SuppressWarnings("unchecked")
void components_can_be_nested() {
- Component<Component<?, ?>, ?> parent = new DomComponentBuilder().doBuild(root.getDeployState(), root, parse(
+ Component<? super Component<?, ?>, ?> parent = new DomComponentBuilder().doBuild(root.getDeployState(), root, parse(
"<component id='parent'>",
" <component id='child' />",
"</component>"));
assertEquals(ComponentId.fromString("parent"), parent.getGlobalComponentId());
- Component<?, ?> child = first(parent.getChildren().values());
+ Component<?, ?> child = (Component<?, ?>) first(parent.getChildren().values());
assertNotNull(child);
assertEquals(ComponentId.fromString("child@parent"), child.getGlobalComponentId());
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 50416d50fe5..69981233c3f 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -5,6 +5,8 @@ import com.yahoo.component.ComponentId;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.path.Path;
import com.yahoo.text.XML;
import com.yahoo.vespa.config.ConfigDefinitionKey;
@@ -12,6 +14,9 @@ import com.yahoo.vespa.config.ConfigPayloadBuilder;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import com.yahoo.vespa.model.container.component.Component;
+import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder;
+import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer;
+import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg;
import com.yahoo.yolean.Exceptions;
import org.junit.jupiter.api.Test;
import org.w3c.dom.Document;
@@ -108,6 +113,28 @@ public class EmbedderTestCase {
assertEquals("minilm-l6-v2 application-url \"\"", config.getObject("transformerModel").getValue());
assertEquals("\"\" \"\" files/vocab.txt", config.getObject("tokenizerVocab").getValue());
assertEquals("4", config.getObject("onnxIntraOpThreads").getValue());
+
+ {
+ var hfEmbedder = (HuggingFaceEmbedder)containerCluster.getComponentsMap().get(new ComponentId("hf-embedder"));
+ assertEquals("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", hfEmbedder.getClassId().getName());
+ var cfgBuilder = new HuggingFaceEmbedderConfig.Builder();
+ hfEmbedder.getConfig(cfgBuilder);
+ var cfg = cfgBuilder.build();
+ assertEquals("my_input_ids", cfg.transformerInputIds());
+ }
+ {
+ var hfTokenizer = (HuggingFaceTokenizer)containerCluster.getComponentsMap().get(new ComponentId("hf-tokenizer"));
+ assertEquals("com.yahoo.language.huggingface.HuggingFaceTokenizer", hfTokenizer.getClassId().getName());
+ var cfgBuilder = new HuggingFaceTokenizerConfig.Builder();
+ hfTokenizer.getConfig(cfgBuilder);
+ var cfg = cfgBuilder.build();
+ assertEquals(768, cfg.maxLength());
+ }
+ }
+
+    @Test
+    void passesXmlValidation() { // NOTE(review): presumably create() validates services.xml against the schema; confirm
+        new VespaModelCreatorWithFilePkg("src/test/cfg/application/embed/").create();
     }
@Test
diff --git a/configdefinitions/src/main/java/com/yahoo/embedding/huggingface/package-info.java b/configdefinitions/src/main/java/com/yahoo/embedding/huggingface/package-info.java
new file mode 100644
index 00000000000..7bcc994e616
--- /dev/null
+++ b/configdefinitions/src/main/java/com/yahoo/embedding/huggingface/package-info.java
@@ -0,0 +1,9 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+/**
+ * @author bjorncs
+ */
+@ExportPackage
+package com.yahoo.embedding.huggingface;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/configdefinitions/src/main/java/com/yahoo/language/huggingface/config/package-info.java b/configdefinitions/src/main/java/com/yahoo/language/huggingface/config/package-info.java
new file mode 100644
index 00000000000..fb9048b5fb4
--- /dev/null
+++ b/configdefinitions/src/main/java/com/yahoo/language/huggingface/config/package-info.java
@@ -0,0 +1,9 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+/**
+ * @author bjorncs
+ */
+@ExportPackage
+package com.yahoo.language.huggingface.config;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def b/configdefinitions/src/vespa/hugging-face-embedder.def
index 36957004e02..36957004e02 100644
--- a/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def
+++ b/configdefinitions/src/vespa/hugging-face-embedder.def
diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/configdefinitions/src/vespa/language.huggingface.hugging-face-tokenizer.def
index 67b3b927f94..18b3631e494 100644
--- a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
+++ b/configdefinitions/src/vespa/language.huggingface.hugging-face-tokenizer.def
@@ -1,6 +1,6 @@
# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-namespace=language.huggingface
+namespace=language.huggingface.config
# The language a model is for, one of the language tags in com.yahoo.language.Language.
# Use "unknown" for models to be used with any language.
diff --git a/linguistics-components/pom.xml b/linguistics-components/pom.xml
index 5031ad73556..b3bc52c5e23 100644
--- a/linguistics-components/pom.xml
+++ b/linguistics-components/pom.xml
@@ -89,6 +89,12 @@
<scope>provided</scope>
<classifier>no_aop</classifier>
</dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>configdefinitions</artifactId>
+ <version>${project.version}</version>
+ <scope>compile</scope>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index f9a37bc477b..2c66fc18c9b 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -6,6 +6,7 @@ import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.tools.Embed;
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 681003fdc89..519aebe6f79 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -81,6 +81,12 @@
<scope>provided</scope>
</dependency>
<dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>configdefinitions</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>net.java.dev.jna</groupId>
<artifactId>jna</artifactId>
<scope>provided</scope>