diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-08 14:23:16 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-08 14:23:16 +0200 |
commit | c0652d7794a90e0afb593fc1a3db17c99606a808 (patch) | |
tree | 17887acf2818107bbeb7355f5ee463f5fb02873d /linguistics-components | |
parent | c3d8c532e0f5b1db896d8693409098e8c2980da1 (diff) |
Disable padding and make it configurable
Diffstat (limited to 'linguistics-components')
2 files changed, 30 insertions, 12 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index 2c66fc18c9b..a0b41669b2b 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -43,9 +43,10 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, uncheck(() -> { var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); - if (b.maxLength != null) hfb.optMaxLength(b.maxLength); - if (b.truncation != null) hfb.optTruncation(b.truncation); + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) + .optTruncation(b.truncation != null ? b.truncation : true) + .optMaxLength(b.maxLength != null ? b.maxLength : 512); + if (b.padding != null && b.padding) hfb.optPadToMaxLength(); return hfb.build(); })); }); @@ -97,6 +98,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, private Boolean addSpecialTokens; private Integer maxLength; private Boolean truncation; + private Boolean padding; public Builder() {} public Builder(HuggingFaceTokenizerConfig cfg) { @@ -105,6 +107,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, addSpecialTokens(cfg.addSpecialTokens()); if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); if (cfg.truncation()) setTruncation(true); + if (cfg.padding()) setPadding(true); } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } @@ -112,6 +115,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; } public Builder setMaxLength(int length) { maxLength = length; return this; } public Builder setTruncation(boolean enabled) { truncation = enabled; return this; } + public Builder setPadding(boolean enabled) { padding = enabled; return this; } public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); } } diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java index 6197fe214f1..8b34e1487be 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -27,7 +27,10 @@ class HuggingFaceTokenizerTest { @Test void bert_tokenizer() throws IOException { - try (var tokenizer = createTokenizer(tmp, "bert-base-uncased")) { + try (var tokenizer = new HuggingFaceTokenizer.Builder() + .addSpecialTokens(false) + .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased")) + .build()) { var tester = new EmbedderTester(tokenizer); tester.assertSegmented("what was the impact of the manhattan project", "what", "was", "the", "impact", "of", "the", "manhattan", "project"); @@ -41,7 +44,10 @@ class HuggingFaceTokenizerTest { @Test void tokenizes_using_paraphrase_multilingual_mpnet_base_v2() throws IOException { - try (var tokenizer = createTokenizer(tmp, "paraphrase-multilingual-mpnet-base-v2")) { + try (var tokenizer = new HuggingFaceTokenizer.Builder() + .addSpecialTokens(false) + .addDefaultModel(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2")) + .build()) { var tester = new EmbedderTester(tokenizer); tester.assertSegmented("h", "▁h"); tester.assertSegmented("he", "▁he"); @@ -87,6 +93,21 @@ class HuggingFaceTokenizerTest { } } + @Test + void disables_padding_by_default() throws IOException { + var builder = new HuggingFaceTokenizer.Builder() + .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased")) + .addSpecialTokens(true).setMaxLength(16); + String input = "what was the impact of the manhattan project"; + try (var tokenizerWithDefaultPadding = builder.build(); + var tokenizerWithPaddingDisabled = builder.setPadding(false).build(); + var tokenizerWithPaddingEnabled = builder.setPadding(true).build()) { + assertMaxLengthRespected(10, tokenizerWithDefaultPadding.encode(input)); + assertMaxLengthRespected(10, tokenizerWithPaddingDisabled.encode(input)); + assertMaxLengthRespected(16, tokenizerWithPaddingEnabled.encode(input)); + } + } + private static void assertMaxLengthRespected(int maxLength, Encoding encoding) { assertEquals(maxLength, encoding.ids().size()); assertEquals(maxLength, encoding.tokens().size()); @@ -94,13 +115,6 @@ class HuggingFaceTokenizerTest { assertEquals(maxLength, encoding.typeIds().size()); } - private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException { - return new HuggingFaceTokenizer.Builder() - .addSpecialTokens(false) - .addDefaultModel(decompressModelFile(tmp, model)) - .build(); - } - private static Path decompressModelFile(Path tmp, String model) throws IOException { var source = Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)); Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", "")); |