diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-26 15:44:33 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-26 15:44:33 +0200 |
commit | da7eaaf9389b6b2d1bfd1bcff3770d61b150fd1f (patch) | |
tree | 9562eb4764e15ea9452f8f0911bfe3737c5666b7 /linguistics-components | |
parent | 17f53418cfaaf49b87f06123af0aaf73f3593fe8 (diff) |
Make truncation and max length configurable
Diffstat (limited to 'linguistics-components')
3 files changed, 45 insertions, 7 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index d812b85b82e..f9a37bc477b 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -39,10 +39,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, try { b.models.forEach((language, path) -> { models.put(language, - uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() - .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) - .build())); + uncheck(() -> { + var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() + .optTokenizerPath(path) + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); + if (b.maxLength != null) hfb.optMaxLength(b.maxLength); + if (b.truncation != null) hfb.optTruncation(b.truncation); + return hfb.build(); + })); }); } finally { Thread.currentThread().setContextClassLoader(original); @@ -90,17 +94,23 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public static final class Builder { private final Map<Language, Path> models = new EnumMap<>(Language.class); private Boolean addSpecialTokens; + private Integer maxLength; + private Boolean truncation; public Builder() {} public Builder(HuggingFaceTokenizerConfig cfg) { for (var model : cfg.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); addSpecialTokens(cfg.addSpecialTokens()); + if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); + if (cfg.truncation()) setTruncation(true); } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); } public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; } + public Builder setMaxLength(int length) { maxLength = length; return this; } + public Builder setTruncation(boolean enabled) { truncation = enabled; return this; } public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); } } diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def index 5e58547879c..67b3b927f94 100644 --- a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def +++ b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def @@ -8,4 +8,6 @@ model[].language string # The path to the model relative to the application package root model[].path model -addSpecialTokens bool default=true
\ No newline at end of file +addSpecialTokens bool default=true +maxLength int default=-1 +truncation bool default=false
\ No newline at end of file diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java index c79ecbfbfbe..6197fe214f1 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -15,6 +15,9 @@ import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.zip.GZIPInputStream; +import static org.junit.jupiter.api.Assertions.assertEquals; + + /** * @author bjorncs */ @@ -69,14 +72,37 @@ class HuggingFaceTokenizerTest { } } + @Test + void truncates_to_max_length() throws IOException { + int maxLength = 3; + var builder = new HuggingFaceTokenizer.Builder() + .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased")) + .setMaxLength(maxLength) + .setTruncation(true); + String input = "what was the impact of the manhattan project"; + try (var tokenizerWithoutSpecialTokens = builder.addSpecialTokens(false).build(); + var tokenizerWithSpecialTokens = builder.addSpecialTokens(true).build()) { + assertMaxLengthRespected(maxLength, tokenizerWithoutSpecialTokens.encode(input)); + assertMaxLengthRespected(maxLength, tokenizerWithSpecialTokens.encode(input)); + } + } + + private static void assertMaxLengthRespected(int maxLength, Encoding encoding) { + assertEquals(maxLength, encoding.ids().size()); + assertEquals(maxLength, encoding.tokens().size()); + assertEquals(maxLength, encoding.attentionMask().size()); + assertEquals(maxLength, encoding.typeIds().size()); + } + private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException { return new HuggingFaceTokenizer.Builder() .addSpecialTokens(false) - .addDefaultModel(decompressModelFile(tmp, Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)))) + .addDefaultModel(decompressModelFile(tmp, model)) .build(); } - private static Path decompressModelFile(Path tmp, Path source) throws IOException { + private static Path decompressModelFile(Path tmp, String model) throws IOException { + var source = Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)); Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", "")); try (InputStream in = new GZIPInputStream(Files.newInputStream(source)); OutputStream out = Files.newOutputStream(destination, StandardOpenOption.CREATE)) { |