summaryrefslogtreecommitdiffstats
path: root/config-model/src
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-12 16:41:37 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-12 16:51:26 +0200
commit4f722322cc9f8df5146ffb27d74239b3b4f2d634 (patch)
treedad0f0a70513a861844d10a35ba93c1901b48057 /config-model/src
parent838f918baf2f64b5cb737a59e624f20773d95baa (diff)
Prefer truncation configuration from tokenizer model
Only override truncation if not specified or max length exceeds max tokens accepted by model. Use JNI wrapper directly to determine existing truncation configuration (JSON format is not really documented). Simply configuration for pure tokenizer embedder. Disable DJL usage telemetry.
Diffstat (limited to 'config-model/src')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java16
-rw-r--r--config-model/src/main/resources/schema/common.rnc8
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml3
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java6
4 files changed, 9 insertions, 24 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
index e0572f8391e..0bf5491e872 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -4,6 +4,8 @@ package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Padding;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Truncation;
import com.yahoo.text.XML;
import com.yahoo.vespa.model.container.xml.ModelIdResolver;
import org.w3c.dom.Element;
@@ -11,7 +13,6 @@ import org.w3c.dom.Element;
import java.util.Map;
import java.util.TreeMap;
-import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME;
/**
@@ -20,10 +21,6 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTI
public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceTokenizerConfig.Producer {
private final Map<String, ModelReference> langToModel = new TreeMap<>();
- private final Boolean specialTokens;
- private final Integer maxLength;
- private final Boolean truncation;
- private final Boolean padding;
public HuggingFaceTokenizer(Element xml, DeployState state) {
super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
@@ -31,10 +28,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown";
langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state));
}
- specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null);
- maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null);
- truncation = getOptionalChildValue(xml, "truncation").map(Boolean::parseBoolean).orElse(null);
- padding = getOptionalChildValue(xml, "padding").map(Boolean::parseBoolean).orElse(null);
}
@Override
@@ -42,9 +35,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
langToModel.forEach((lang, vocab) -> {
builder.model.add(new HuggingFaceTokenizerConfig.Model.Builder().language(lang).path(vocab));
});
- if (specialTokens != null) builder.addSpecialTokens(specialTokens);
- if (maxLength != null) builder.maxLength(maxLength);
- if (truncation != null) builder.truncation(truncation);
- if (padding != null) builder.padding(padding);
+ builder.truncation(Truncation.Enum.OFF).padding(Padding.Enum.OFF).addSpecialTokens(false);
}
}
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index e130bed0297..ba7e2b6674e 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -88,7 +88,7 @@ HuggingFaceEmbedder =
attribute type { "hugging-face-embedder" } &
element transformer-model { ModelReference } &
element tokenizer-model { ModelReference }? &
- element max-tokens { xsd:nonNegativeInteger }? &
+ element max-tokens { xsd:positiveInteger }? &
element transformer-input-ids { xsd:string }? &
element transformer-attention-mask { xsd:string }? &
element transformer-token-type-ids { xsd:string }? &
@@ -99,11 +99,7 @@ HuggingFaceEmbedder =
HuggingFaceTokenizer =
attribute type { "hugging-face-tokenizer" } &
- element model { attribute language { xsd:string }? & ModelReference }+ &
- element special-tokens { xsd:boolean }? &
- element max-length { xsd:integer }? &
- element truncation { xsd:boolean }? &
- element padding { xsd:boolean }?
+ element model { attribute language { xsd:string }? & ModelReference }+
BertBaseEmbedder =
attribute type { "bert-embedder" } &
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 6823ef900ae..b70a9d5f5f1 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -21,9 +21,6 @@
<component id="hf-tokenizer" type="hugging-face-tokenizer">
<model language="no" model-id="multilingual-e5-base-vocab" url="https://my/url/tokenizer.json"/>
- <special-tokens>true</special-tokens>
- <max-length>768</max-length>
- <truncation>true</truncation>
</component>
<component id="bert-embedder" type="bert-embedder">
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 2a82daef9e3..dc62bfdbbef 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -77,9 +77,10 @@ public class EmbedderTestCase {
var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster);
assertEquals("my_input_ids", embedderCfg.transformerInputIds());
assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
- assertEquals(768, tokenizerCfg.maxLength());
+ assertEquals(-1, tokenizerCfg.maxLength());
}
@Test
@@ -89,9 +90,10 @@ public class EmbedderTestCase {
var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster);
assertEquals("my_input_ids", embedderCfg.transformerInputIds());
assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
- assertEquals(768, tokenizerCfg.maxLength());
+ assertEquals(-1, tokenizerCfg.maxLength());
}