diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-12 16:41:37 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-12 16:51:26 +0200 |
commit | 4f722322cc9f8df5146ffb27d74239b3b4f2d634 (patch) | |
tree | dad0f0a70513a861844d10a35ba93c1901b48057 /linguistics-components | |
parent | 838f918baf2f64b5cb737a59e624f20773d95baa (diff) |
Prefer truncation configuration from tokenizer model
Only override truncation if not specified or max length exceeds max tokens accepted by model.
Use JNI wrapper directly to determine existing truncation configuration (JSON format is not really documented).
Simply configuration for pure tokenizer embedder.
Disable DJL usage telemetry.
Diffstat (limited to 'linguistics-components')
3 files changed, 112 insertions, 14 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index 1f1757e6ade..17360efd0af 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -2,10 +2,14 @@ package com.yahoo.language.huggingface; +import ai.djl.huggingface.tokenizers.jni.LibUtils; +import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary; import com.yahoo.api.annotations.Beta; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; import com.yahoo.language.Language; +import com.yahoo.language.huggingface.ModelInfo.PaddingStrategy; +import com.yahoo.language.huggingface.ModelInfo.TruncationStrategy; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.language.process.Embedder; import com.yahoo.language.process.Segmenter; @@ -13,12 +17,14 @@ import com.yahoo.language.tools.Embed; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; import java.util.Collection; import java.util.EnumMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static com.yahoo.yolean.Exceptions.uncheck; @@ -30,29 +36,39 @@ import static com.yahoo.yolean.Exceptions.uncheck; @Beta public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable { - private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class); + private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models; @Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); } + static { + // Stop HuggingFace Tokenizer from reporting usage statistics back to mothership + // See ai.djl.util.Ec2Utils.callHome() + System.setProperty("OPT_OUT_TRACKING", "true"); + } + private HuggingFaceTokenizer(Builder b) { - var original = Thread.currentThread().getContextClassLoader(); - Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); - try { + this.models = withContextClassloader(() -> { + var models = new EnumMap<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer>(Language.class); b.models.forEach((language, path) -> { models.put(language, uncheck(() -> { var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) - .optTruncation(b.truncation != null ? b.truncation : true) - .optMaxLength(b.maxLength != null ? b.maxLength : 512); - if (b.padding != null && b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false); + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); + if (b.maxLength != null) { + hfb.optMaxLength(b.maxLength); + // Override modelMaxLength to workaround HF tokenizer limiting maxLength to 512 + hfb.configure(Map.of("modelMaxLength", b.maxLength > 0 ? b.maxLength : Integer.MAX_VALUE)); + } + if (b.padding != null) { + if (b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false); + } + if (b.truncation != null) hfb.optTruncation(b.truncation); return hfb.build(); })); }); - } finally { - Thread.currentThread().setContextClassLoader(original); - } + return models; + }); } @Override @@ -84,6 +100,24 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, @Override public void close() { models.forEach((__, model) -> model.close()); } @Override public void deconstruct() { close(); } + public static ModelInfo getModelInfo(Path path) { + return withContextClassloader(() -> { + // Hackish solution to read padding/truncation configuration through JNI wrapper directly + LibUtils.checkStatus(); + var handle = TokenizersLibrary.LIB.createTokenizerFromString(uncheck(() -> Files.readString(path))); + try { + return new ModelInfo( + TruncationStrategy.fromString(TokenizersLibrary.LIB.getTruncationStrategy(handle)), + PaddingStrategy.fromString(TokenizersLibrary.LIB.getPaddingStrategy(handle)), + TokenizersLibrary.LIB.getMaxLength(handle), + TokenizersLibrary.LIB.getStride(handle), + TokenizersLibrary.LIB.getPadToMultipleOf(handle)); + } finally { + TokenizersLibrary.LIB.deleteTokenizer(handle); + } + }); + } + private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); @@ -91,6 +125,16 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, throw new IllegalArgumentException("No model for language " + language); } + private static <R> R withContextClassloader(Supplier<R> r) { + var original = Thread.currentThread().getContextClassLoader(); + Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); + try { + return r.get(); + } finally { + Thread.currentThread().setContextClassLoader(original); + } + } + private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); } public static final class Builder { @@ -106,8 +150,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, addModel(Language.fromLanguageTag(model.language()), model.path()); addSpecialTokens(cfg.addSpecialTokens()); if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); - if (cfg.truncation()) setTruncation(true); - if (cfg.padding()) setPadding(true); + switch (cfg.truncation()) { + case ON -> setTruncation(true); + case OFF -> setTruncation(false); + } + switch (cfg.padding()) { + case ON -> setPadding(true); + case OFF -> setPadding(false); + } } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java new file mode 100644 index 00000000000..4b30b1f0435 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java @@ -0,0 +1,41 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.huggingface; + +import java.util.Arrays; + +/** + * @author bjorncs + */ +public record ModelInfo( + TruncationStrategy truncation, PaddingStrategy padding, int maxLength, int stride, int padToMultipleOf) { + + public enum TruncationStrategy { + LONGEST_FIRST, + ONLY_FIRST, + ONLY_SECOND, + DO_NOT_TRUNCATE; + + public static TruncationStrategy fromString(String v) { + if ("true".equals(v)) return LONGEST_FIRST; + else if ("false".equals(v)) return DO_NOT_TRUNCATE; + return Arrays.stream(values()) + .filter(s -> s.name().equalsIgnoreCase(v)) + .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v))); + } + } + + public enum PaddingStrategy { + LONGEST, + MAX_LENGTH, + DO_NOT_PAD; + + public static PaddingStrategy fromString(String v) { + if ("true".equals(v)) return LONGEST; + else if ("false".equals(v)) return DO_NOT_PAD; + return Arrays.stream(values()) + .filter(s -> s.name().equalsIgnoreCase(v)) + .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v))); + } + } +} diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java index bf2e0f82829..f727252cccb 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -99,7 +99,7 @@ class HuggingFaceTokenizerTest { } @Test - void disables_padding_by_default() throws IOException { + void pads_to_max_length() throws IOException { var builder = new HuggingFaceTokenizer.Builder() .setTruncation(true) .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased")) @@ -114,6 +114,13 @@ class HuggingFaceTokenizerTest { } } + @Test + void provides_model_info() throws IOException { + var expected = new ModelInfo(ModelInfo.TruncationStrategy.LONGEST_FIRST, ModelInfo.PaddingStrategy.LONGEST, 128, 0, 0); + var actual = HuggingFaceTokenizer.getModelInfo(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2")); + assertEquals(expected, actual); + } + private static void assertMaxLengthRespected(int maxLength, Encoding encoding) { assertEquals(maxLength, encoding.ids().size()); assertEquals(maxLength, encoding.tokens().size()); |