summaryrefslogtreecommitdiffstats
path: root/linguistics-components
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-12 16:41:37 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-12 16:51:26 +0200
commit4f722322cc9f8df5146ffb27d74239b3b4f2d634 (patch)
treedad0f0a70513a861844d10a35ba93c1901b48057 /linguistics-components
parent838f918baf2f64b5cb737a59e624f20773d95baa (diff)
Prefer truncation configuration from tokenizer model
Only override truncation if not specified or max length exceeds max tokens accepted by model. Use JNI wrapper directly to determine existing truncation configuration (JSON format is not really documented). Simply configuration for pure tokenizer embedder. Disable DJL usage telemetry.
Diffstat (limited to 'linguistics-components')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java76
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java41
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java9
3 files changed, 112 insertions, 14 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index 1f1757e6ade..17360efd0af 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -2,10 +2,14 @@
package com.yahoo.language.huggingface;
+import ai.djl.huggingface.tokenizers.jni.LibUtils;
+import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
+import com.yahoo.language.huggingface.ModelInfo.PaddingStrategy;
+import com.yahoo.language.huggingface.ModelInfo.TruncationStrategy;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
@@ -13,12 +17,14 @@ import com.yahoo.language.tools.Embed;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
import static com.yahoo.yolean.Exceptions.uncheck;
@@ -30,29 +36,39 @@ import static com.yahoo.yolean.Exceptions.uncheck;
@Beta
public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable {
- private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class);
+ private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models;
@Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); }
+ static {
+ // Stop HuggingFace Tokenizer from reporting usage statistics back to mothership
+ // See ai.djl.util.Ec2Utils.callHome()
+ System.setProperty("OPT_OUT_TRACKING", "true");
+ }
+
private HuggingFaceTokenizer(Builder b) {
- var original = Thread.currentThread().getContextClassLoader();
- Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
- try {
+ this.models = withContextClassloader(() -> {
+ var models = new EnumMap<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer>(Language.class);
b.models.forEach((language, path) -> {
models.put(language,
uncheck(() -> {
var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
- .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
- .optTruncation(b.truncation != null ? b.truncation : true)
- .optMaxLength(b.maxLength != null ? b.maxLength : 512);
- if (b.padding != null && b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false);
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
+ if (b.maxLength != null) {
+ hfb.optMaxLength(b.maxLength);
+ // Override modelMaxLength to workaround HF tokenizer limiting maxLength to 512
+ hfb.configure(Map.of("modelMaxLength", b.maxLength > 0 ? b.maxLength : Integer.MAX_VALUE));
+ }
+ if (b.padding != null) {
+ if (b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false);
+ }
+ if (b.truncation != null) hfb.optTruncation(b.truncation);
return hfb.build();
}));
});
- } finally {
- Thread.currentThread().setContextClassLoader(original);
- }
+ return models;
+ });
}
@Override
@@ -84,6 +100,24 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
@Override public void close() { models.forEach((__, model) -> model.close()); }
@Override public void deconstruct() { close(); }
+ public static ModelInfo getModelInfo(Path path) {
+ return withContextClassloader(() -> {
+ // Hackish solution to read padding/truncation configuration through JNI wrapper directly
+ LibUtils.checkStatus();
+ var handle = TokenizersLibrary.LIB.createTokenizerFromString(uncheck(() -> Files.readString(path)));
+ try {
+ return new ModelInfo(
+ TruncationStrategy.fromString(TokenizersLibrary.LIB.getTruncationStrategy(handle)),
+ PaddingStrategy.fromString(TokenizersLibrary.LIB.getPaddingStrategy(handle)),
+ TokenizersLibrary.LIB.getMaxLength(handle),
+ TokenizersLibrary.LIB.getStride(handle),
+ TokenizersLibrary.LIB.getPadToMultipleOf(handle));
+ } finally {
+ TokenizersLibrary.LIB.deleteTokenizer(handle);
+ }
+ });
+ }
+
private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) {
// Disregard language if there is default model
if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
@@ -91,6 +125,16 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
throw new IllegalArgumentException("No model for language " + language);
}
+ private static <R> R withContextClassloader(Supplier<R> r) {
+ var original = Thread.currentThread().getContextClassLoader();
+ Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
+ try {
+ return r.get();
+ } finally {
+ Thread.currentThread().setContextClassLoader(original);
+ }
+ }
+
private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); }
public static final class Builder {
@@ -106,8 +150,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
addModel(Language.fromLanguageTag(model.language()), model.path());
addSpecialTokens(cfg.addSpecialTokens());
if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
- if (cfg.truncation()) setTruncation(true);
- if (cfg.padding()) setPadding(true);
+ switch (cfg.truncation()) {
+ case ON -> setTruncation(true);
+ case OFF -> setTruncation(false);
+ }
+ switch (cfg.padding()) {
+ case ON -> setPadding(true);
+ case OFF -> setPadding(false);
+ }
}
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java
new file mode 100644
index 00000000000..4b30b1f0435
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java
@@ -0,0 +1,41 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.huggingface;
+
+import java.util.Arrays;
+
+/**
+ * @author bjorncs
+ */
+public record ModelInfo(
+ TruncationStrategy truncation, PaddingStrategy padding, int maxLength, int stride, int padToMultipleOf) {
+
+ public enum TruncationStrategy {
+ LONGEST_FIRST,
+ ONLY_FIRST,
+ ONLY_SECOND,
+ DO_NOT_TRUNCATE;
+
+ public static TruncationStrategy fromString(String v) {
+ if ("true".equals(v)) return LONGEST_FIRST;
+ else if ("false".equals(v)) return DO_NOT_TRUNCATE;
+ return Arrays.stream(values())
+ .filter(s -> s.name().equalsIgnoreCase(v))
+ .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v)));
+ }
+ }
+
+ public enum PaddingStrategy {
+ LONGEST,
+ MAX_LENGTH,
+ DO_NOT_PAD;
+
+ public static PaddingStrategy fromString(String v) {
+ if ("true".equals(v)) return LONGEST;
+ else if ("false".equals(v)) return DO_NOT_PAD;
+ return Arrays.stream(values())
+ .filter(s -> s.name().equalsIgnoreCase(v))
+ .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v)));
+ }
+ }
+}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
index bf2e0f82829..f727252cccb 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
@@ -99,7 +99,7 @@ class HuggingFaceTokenizerTest {
}
@Test
- void disables_padding_by_default() throws IOException {
+ void pads_to_max_length() throws IOException {
var builder = new HuggingFaceTokenizer.Builder()
.setTruncation(true)
.addDefaultModel(decompressModelFile(tmp, "bert-base-uncased"))
@@ -114,6 +114,13 @@ class HuggingFaceTokenizerTest {
}
}
+ @Test
+ void provides_model_info() throws IOException {
+ var expected = new ModelInfo(ModelInfo.TruncationStrategy.LONGEST_FIRST, ModelInfo.PaddingStrategy.LONGEST, 128, 0, 0);
+ var actual = HuggingFaceTokenizer.getModelInfo(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2"));
+ assertEquals(expected, actual);
+ }
+
private static void assertMaxLengthRespected(int maxLength, Encoding encoding) {
assertEquals(maxLength, encoding.ids().size());
assertEquals(maxLength, encoding.tokens().size());