diff options
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java | 76 |
1 files changed, 63 insertions, 13 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index 1f1757e6ade..17360efd0af 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -2,10 +2,14 @@ package com.yahoo.language.huggingface; +import ai.djl.huggingface.tokenizers.jni.LibUtils; +import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary; import com.yahoo.api.annotations.Beta; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; import com.yahoo.language.Language; +import com.yahoo.language.huggingface.ModelInfo.PaddingStrategy; +import com.yahoo.language.huggingface.ModelInfo.TruncationStrategy; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.language.process.Embedder; import com.yahoo.language.process.Segmenter; @@ -13,12 +17,14 @@ import com.yahoo.language.tools.Embed; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; import java.util.Collection; import java.util.EnumMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static com.yahoo.yolean.Exceptions.uncheck; @@ -30,29 +36,39 @@ import static com.yahoo.yolean.Exceptions.uncheck; @Beta public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable { - private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class); + private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models; @Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); } + static { + // Stop HuggingFace Tokenizer from reporting usage statistics back to mothership + // See ai.djl.util.Ec2Utils.callHome() + System.setProperty("OPT_OUT_TRACKING", "true"); + } + private HuggingFaceTokenizer(Builder b) { - var original = Thread.currentThread().getContextClassLoader(); - Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); - try { + this.models = withContextClassloader(() -> { + var models = new EnumMap<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer>(Language.class); b.models.forEach((language, path) -> { models.put(language, uncheck(() -> { var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) - .optTruncation(b.truncation != null ? b.truncation : true) - .optMaxLength(b.maxLength != null ? b.maxLength : 512); - if (b.padding != null && b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false); + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); + if (b.maxLength != null) { + hfb.optMaxLength(b.maxLength); + // Override modelMaxLength to workaround HF tokenizer limiting maxLength to 512 + hfb.configure(Map.of("modelMaxLength", b.maxLength > 0 ? b.maxLength : Integer.MAX_VALUE)); + } + if (b.padding != null) { + if (b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false); + } + if (b.truncation != null) hfb.optTruncation(b.truncation); return hfb.build(); })); }); - } finally { - Thread.currentThread().setContextClassLoader(original); - } + return models; + }); } @Override @@ -84,6 +100,24 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, @Override public void close() { models.forEach((__, model) -> model.close()); } @Override public void deconstruct() { close(); } + public static ModelInfo getModelInfo(Path path) { + return withContextClassloader(() -> { + // Hackish solution to read padding/truncation configuration through JNI wrapper directly + LibUtils.checkStatus(); + var handle = TokenizersLibrary.LIB.createTokenizerFromString(uncheck(() -> Files.readString(path))); + try { + return new ModelInfo( + TruncationStrategy.fromString(TokenizersLibrary.LIB.getTruncationStrategy(handle)), + PaddingStrategy.fromString(TokenizersLibrary.LIB.getPaddingStrategy(handle)), + TokenizersLibrary.LIB.getMaxLength(handle), + TokenizersLibrary.LIB.getStride(handle), + TokenizersLibrary.LIB.getPadToMultipleOf(handle)); + } finally { + TokenizersLibrary.LIB.deleteTokenizer(handle); + } + }); + } + private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); @@ -91,6 +125,16 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, throw new IllegalArgumentException("No model for language " + language); } + private static <R> R withContextClassloader(Supplier<R> r) { + var original = Thread.currentThread().getContextClassLoader(); + Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); + try { + return r.get(); + } finally { + Thread.currentThread().setContextClassLoader(original); + } + } + private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); } public static final class Builder { @@ -106,8 +150,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, addModel(Language.fromLanguageTag(model.language()), model.path()); addSpecialTokens(cfg.addSpecialTokens()); if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); - if (cfg.truncation()) setTruncation(true); - if (cfg.padding()) setPadding(true); + switch (cfg.truncation()) { + case ON -> setTruncation(true); + case OFF -> setTruncation(false); + } + switch (cfg.padding()) { + case ON -> setPadding(true); + case OFF -> setPadding(false); + } } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } |