summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-12 16:41:37 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-12 16:51:26 +0200
commit4f722322cc9f8df5146ffb27d74239b3b4f2d634 (patch)
treedad0f0a70513a861844d10a35ba93c1901b48057
parent838f918baf2f64b5cb737a59e624f20773d95baa (diff)
Prefer truncation configuration from tokenizer model
Only override truncation if not specified or max length exceeds max tokens accepted by model. Use JNI wrapper directly to determine existing truncation configuration (JSON format is not really documented). Simply configuration for pure tokenizer embedder. Disable DJL usage telemetry.
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java16
-rw-r--r--config-model/src/main/resources/schema/common.rnc8
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml3
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java6
-rw-r--r--configdefinitions/src/vespa/hugging-face-tokenizer.def13
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java76
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java41
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java9
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java25
9 files changed, 150 insertions, 47 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
index e0572f8391e..0bf5491e872 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -4,6 +4,8 @@ package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Padding;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Truncation;
import com.yahoo.text.XML;
import com.yahoo.vespa.model.container.xml.ModelIdResolver;
import org.w3c.dom.Element;
@@ -11,7 +13,6 @@ import org.w3c.dom.Element;
import java.util.Map;
import java.util.TreeMap;
-import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME;
/**
@@ -20,10 +21,6 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTI
public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceTokenizerConfig.Producer {
private final Map<String, ModelReference> langToModel = new TreeMap<>();
- private final Boolean specialTokens;
- private final Integer maxLength;
- private final Boolean truncation;
- private final Boolean padding;
public HuggingFaceTokenizer(Element xml, DeployState state) {
super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
@@ -31,10 +28,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown";
langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state));
}
- specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null);
- maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null);
- truncation = getOptionalChildValue(xml, "truncation").map(Boolean::parseBoolean).orElse(null);
- padding = getOptionalChildValue(xml, "padding").map(Boolean::parseBoolean).orElse(null);
}
@Override
@@ -42,9 +35,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
langToModel.forEach((lang, vocab) -> {
builder.model.add(new HuggingFaceTokenizerConfig.Model.Builder().language(lang).path(vocab));
});
- if (specialTokens != null) builder.addSpecialTokens(specialTokens);
- if (maxLength != null) builder.maxLength(maxLength);
- if (truncation != null) builder.truncation(truncation);
- if (padding != null) builder.padding(padding);
+ builder.truncation(Truncation.Enum.OFF).padding(Padding.Enum.OFF).addSpecialTokens(false);
}
}
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index e130bed0297..ba7e2b6674e 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -88,7 +88,7 @@ HuggingFaceEmbedder =
attribute type { "hugging-face-embedder" } &
element transformer-model { ModelReference } &
element tokenizer-model { ModelReference }? &
- element max-tokens { xsd:nonNegativeInteger }? &
+ element max-tokens { xsd:positiveInteger }? &
element transformer-input-ids { xsd:string }? &
element transformer-attention-mask { xsd:string }? &
element transformer-token-type-ids { xsd:string }? &
@@ -99,11 +99,7 @@ HuggingFaceEmbedder =
HuggingFaceTokenizer =
attribute type { "hugging-face-tokenizer" } &
- element model { attribute language { xsd:string }? & ModelReference }+ &
- element special-tokens { xsd:boolean }? &
- element max-length { xsd:integer }? &
- element truncation { xsd:boolean }? &
- element padding { xsd:boolean }?
+ element model { attribute language { xsd:string }? & ModelReference }+
BertBaseEmbedder =
attribute type { "bert-embedder" } &
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 6823ef900ae..b70a9d5f5f1 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -21,9 +21,6 @@
<component id="hf-tokenizer" type="hugging-face-tokenizer">
<model language="no" model-id="multilingual-e5-base-vocab" url="https://my/url/tokenizer.json"/>
- <special-tokens>true</special-tokens>
- <max-length>768</max-length>
- <truncation>true</truncation>
</component>
<component id="bert-embedder" type="bert-embedder">
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 2a82daef9e3..dc62bfdbbef 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -77,9 +77,10 @@ public class EmbedderTestCase {
var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster);
assertEquals("my_input_ids", embedderCfg.transformerInputIds());
assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
- assertEquals(768, tokenizerCfg.maxLength());
+ assertEquals(-1, tokenizerCfg.maxLength());
}
@Test
@@ -89,9 +90,10 @@ public class EmbedderTestCase {
var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster);
assertEquals("my_input_ids", embedderCfg.transformerInputIds());
assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
- assertEquals(768, tokenizerCfg.maxLength());
+ assertEquals(-1, tokenizerCfg.maxLength());
}
diff --git a/configdefinitions/src/vespa/hugging-face-tokenizer.def b/configdefinitions/src/vespa/hugging-face-tokenizer.def
index bc0d5300de5..896a7b03234 100644
--- a/configdefinitions/src/vespa/hugging-face-tokenizer.def
+++ b/configdefinitions/src/vespa/hugging-face-tokenizer.def
@@ -8,7 +8,14 @@ model[].language string
# The path to the model relative to the application package root
model[].path model
+# Include special tokens in output
addSpecialTokens bool default=true
-maxLength int default=512
-truncation bool default=true
-padding bool default=false
+
+# Used for truncation/padding. Use -1 for model default.
+maxLength int default=-1
+
+# Truncation strategy. Use NOTSET for model default.
+truncation enum { ON, OFF, NOTSET } default=NOTSET
+
+# Padding strategy. Use NOTSET for model default.
+padding enum { ON, OFF, NOTSET } default=NOTSET
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index 1f1757e6ade..17360efd0af 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -2,10 +2,14 @@
package com.yahoo.language.huggingface;
+import ai.djl.huggingface.tokenizers.jni.LibUtils;
+import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
+import com.yahoo.language.huggingface.ModelInfo.PaddingStrategy;
+import com.yahoo.language.huggingface.ModelInfo.TruncationStrategy;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
@@ -13,12 +17,14 @@ import com.yahoo.language.tools.Embed;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
import static com.yahoo.yolean.Exceptions.uncheck;
@@ -30,29 +36,39 @@ import static com.yahoo.yolean.Exceptions.uncheck;
@Beta
public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable {
- private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class);
+ private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models;
@Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); }
+ static {
+ // Stop HuggingFace Tokenizer from reporting usage statistics back to mothership
+ // See ai.djl.util.Ec2Utils.callHome()
+ System.setProperty("OPT_OUT_TRACKING", "true");
+ }
+
private HuggingFaceTokenizer(Builder b) {
- var original = Thread.currentThread().getContextClassLoader();
- Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
- try {
+ this.models = withContextClassloader(() -> {
+ var models = new EnumMap<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer>(Language.class);
b.models.forEach((language, path) -> {
models.put(language,
uncheck(() -> {
var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
- .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
- .optTruncation(b.truncation != null ? b.truncation : true)
- .optMaxLength(b.maxLength != null ? b.maxLength : 512);
- if (b.padding != null && b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false);
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
+ if (b.maxLength != null) {
+ hfb.optMaxLength(b.maxLength);
+ // Override modelMaxLength to workaround HF tokenizer limiting maxLength to 512
+ hfb.configure(Map.of("modelMaxLength", b.maxLength > 0 ? b.maxLength : Integer.MAX_VALUE));
+ }
+ if (b.padding != null) {
+ if (b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false);
+ }
+ if (b.truncation != null) hfb.optTruncation(b.truncation);
return hfb.build();
}));
});
- } finally {
- Thread.currentThread().setContextClassLoader(original);
- }
+ return models;
+ });
}
@Override
@@ -84,6 +100,24 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
@Override public void close() { models.forEach((__, model) -> model.close()); }
@Override public void deconstruct() { close(); }
+ public static ModelInfo getModelInfo(Path path) {
+ return withContextClassloader(() -> {
+ // Hackish solution to read padding/truncation configuration through JNI wrapper directly
+ LibUtils.checkStatus();
+ var handle = TokenizersLibrary.LIB.createTokenizerFromString(uncheck(() -> Files.readString(path)));
+ try {
+ return new ModelInfo(
+ TruncationStrategy.fromString(TokenizersLibrary.LIB.getTruncationStrategy(handle)),
+ PaddingStrategy.fromString(TokenizersLibrary.LIB.getPaddingStrategy(handle)),
+ TokenizersLibrary.LIB.getMaxLength(handle),
+ TokenizersLibrary.LIB.getStride(handle),
+ TokenizersLibrary.LIB.getPadToMultipleOf(handle));
+ } finally {
+ TokenizersLibrary.LIB.deleteTokenizer(handle);
+ }
+ });
+ }
+
private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) {
// Disregard language if there is default model
if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
@@ -91,6 +125,16 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
throw new IllegalArgumentException("No model for language " + language);
}
+ private static <R> R withContextClassloader(Supplier<R> r) {
+ var original = Thread.currentThread().getContextClassLoader();
+ Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
+ try {
+ return r.get();
+ } finally {
+ Thread.currentThread().setContextClassLoader(original);
+ }
+ }
+
private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); }
public static final class Builder {
@@ -106,8 +150,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
addModel(Language.fromLanguageTag(model.language()), model.path());
addSpecialTokens(cfg.addSpecialTokens());
if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
- if (cfg.truncation()) setTruncation(true);
- if (cfg.padding()) setPadding(true);
+ switch (cfg.truncation()) {
+ case ON -> setTruncation(true);
+ case OFF -> setTruncation(false);
+ }
+ switch (cfg.padding()) {
+ case ON -> setPadding(true);
+ case OFF -> setPadding(false);
+ }
}
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java
new file mode 100644
index 00000000000..4b30b1f0435
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java
@@ -0,0 +1,41 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.huggingface;
+
+import java.util.Arrays;
+
+/**
+ * @author bjorncs
+ */
+public record ModelInfo(
+ TruncationStrategy truncation, PaddingStrategy padding, int maxLength, int stride, int padToMultipleOf) {
+
+ public enum TruncationStrategy {
+ LONGEST_FIRST,
+ ONLY_FIRST,
+ ONLY_SECOND,
+ DO_NOT_TRUNCATE;
+
+ public static TruncationStrategy fromString(String v) {
+ if ("true".equals(v)) return LONGEST_FIRST;
+ else if ("false".equals(v)) return DO_NOT_TRUNCATE;
+ return Arrays.stream(values())
+ .filter(s -> s.name().equalsIgnoreCase(v))
+ .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v)));
+ }
+ }
+
+ public enum PaddingStrategy {
+ LONGEST,
+ MAX_LENGTH,
+ DO_NOT_PAD;
+
+ public static PaddingStrategy fromString(String v) {
+ if ("true".equals(v)) return LONGEST;
+ else if ("false".equals(v)) return DO_NOT_PAD;
+ return Arrays.stream(values())
+ .filter(s -> s.name().equalsIgnoreCase(v))
+ .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v)));
+ }
+ }
+}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
index bf2e0f82829..f727252cccb 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
@@ -99,7 +99,7 @@ class HuggingFaceTokenizerTest {
}
@Test
- void disables_padding_by_default() throws IOException {
+ void pads_to_max_length() throws IOException {
var builder = new HuggingFaceTokenizer.Builder()
.setTruncation(true)
.addDefaultModel(decompressModelFile(tmp, "bert-base-uncased"))
@@ -114,6 +114,13 @@ class HuggingFaceTokenizerTest {
}
}
+ @Test
+ void provides_model_info() throws IOException {
+ var expected = new ModelInfo(ModelInfo.TruncationStrategy.LONGEST_FIRST, ModelInfo.PaddingStrategy.LONGEST, 128, 0, 0);
+ var actual = HuggingFaceTokenizer.getModelInfo(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2"));
+ assertEquals(expected, actual);
+ }
+
private static void assertMaxLengthRespected(int maxLength, Encoding encoding) {
assertEquals(maxLength, encoding.ids().size());
assertEquals(maxLength, encoding.tokens().size());
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 17b63fb1c7d..b035541bb0f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -18,10 +18,15 @@ import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
+import java.util.logging.Logger;
+
+import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
@Beta
public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
+ private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName());
+
private final String inputIdsName;
private final String attentionMaskName;
private final String tokenTypeIdsName;
@@ -38,13 +43,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
normalize = config.normalize();
- tokenizer = new HuggingFaceTokenizer.Builder()
+ var tokenizerPath = Paths.get(config.tokenizerPath().toString());
+ var builder = new HuggingFaceTokenizer.Builder()
.addSpecialTokens(true)
- .addDefaultModel(Paths.get(config.tokenizerPath().toString()))
- .setTruncation(true)
- .setPadding(false)
- .setMaxLength(config.transformerMaxTokens())
- .build();
+ .addDefaultModel(tokenizerPath)
+ .setPadding(false);
+ var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
+ log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, info));
+ if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) {
+ // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration
+ int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
+ ? info.maxLength()
+ : config.transformerMaxTokens();
+ builder.setTruncation(true).setMaxLength(maxLength);
+ }
+ this.tokenizer = builder.build();
poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)