summary | refs | log | tree | commit | diff | stats
path: root/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java')
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java | 76
1 files changed, 63 insertions, 13 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index 1f1757e6ade..17360efd0af 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -2,10 +2,14 @@
package com.yahoo.language.huggingface;
+import ai.djl.huggingface.tokenizers.jni.LibUtils;
+import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
+import com.yahoo.language.huggingface.ModelInfo.PaddingStrategy;
+import com.yahoo.language.huggingface.ModelInfo.TruncationStrategy;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
@@ -13,12 +17,14 @@ import com.yahoo.language.tools.Embed;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
import static com.yahoo.yolean.Exceptions.uncheck;
@@ -30,29 +36,39 @@ import static com.yahoo.yolean.Exceptions.uncheck;
@Beta
public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable {
- private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class);
+ private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models;
@Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); }
+ static {
+ // Opt out of the usage statistics that the HuggingFace tokenizer library would otherwise report back ("call home")
+ // See ai.djl.util.Ec2Utils.callHome()
+ System.setProperty("OPT_OUT_TRACKING", "true");
+ }
+
private HuggingFaceTokenizer(Builder b) {
- var original = Thread.currentThread().getContextClassLoader();
- Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
- try {
+ this.models = withContextClassloader(() -> {
+ var models = new EnumMap<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer>(Language.class);
b.models.forEach((language, path) -> {
models.put(language,
uncheck(() -> {
var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
- .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
- .optTruncation(b.truncation != null ? b.truncation : true)
- .optMaxLength(b.maxLength != null ? b.maxLength : 512);
- if (b.padding != null && b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false);
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
+ if (b.maxLength != null) {
+ hfb.optMaxLength(b.maxLength);
+ // Override modelMaxLength to work around the HF tokenizer limiting maxLength to 512
+ hfb.configure(Map.of("modelMaxLength", b.maxLength > 0 ? b.maxLength : Integer.MAX_VALUE));
+ }
+ if (b.padding != null) {
+ if (b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false);
+ }
+ if (b.truncation != null) hfb.optTruncation(b.truncation);
return hfb.build();
}));
});
- } finally {
- Thread.currentThread().setContextClassLoader(original);
- }
+ return models;
+ });
}
@Override
@@ -84,6 +100,24 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
@Override public void close() { models.forEach((__, model) -> model.close()); }
@Override public void deconstruct() { close(); }
+ public static ModelInfo getModelInfo(Path path) {
+ return withContextClassloader(() -> {
+ // Hack: read the padding/truncation configuration directly through the JNI wrapper
+ LibUtils.checkStatus();
+ var handle = TokenizersLibrary.LIB.createTokenizerFromString(uncheck(() -> Files.readString(path)));
+ try {
+ return new ModelInfo(
+ TruncationStrategy.fromString(TokenizersLibrary.LIB.getTruncationStrategy(handle)),
+ PaddingStrategy.fromString(TokenizersLibrary.LIB.getPaddingStrategy(handle)),
+ TokenizersLibrary.LIB.getMaxLength(handle),
+ TokenizersLibrary.LIB.getStride(handle),
+ TokenizersLibrary.LIB.getPadToMultipleOf(handle));
+ } finally {
+ TokenizersLibrary.LIB.deleteTokenizer(handle);
+ }
+ });
+ }
+
private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) {
// Disregard the language if the only model is the default one (Language.UNKNOWN)
if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
@@ -91,6 +125,16 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
throw new IllegalArgumentException("No model for language " + language);
}
+ private static <R> R withContextClassloader(Supplier<R> r) {
+ var original = Thread.currentThread().getContextClassLoader();
+ Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
+ try {
+ return r.get();
+ } finally {
+ Thread.currentThread().setContextClassLoader(original);
+ }
+ }
+
private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); }
public static final class Builder {
@@ -106,8 +150,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
addModel(Language.fromLanguageTag(model.language()), model.path());
addSpecialTokens(cfg.addSpecialTokens());
if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
- if (cfg.truncation()) setTruncation(true);
- if (cfg.padding()) setPadding(true);
+ switch (cfg.truncation()) {
+ case ON -> setTruncation(true);
+ case OFF -> setTruncation(false);
+ }
+ switch (cfg.padding()) {
+ case ON -> setPadding(true);
+ case OFF -> setPadding(false);
+ }
}
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }