diff options
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index 2c66fc18c9b..a0b41669b2b 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -43,9 +43,10 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
                         uncheck(() -> {
                             var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
                                     .optTokenizerPath(path)
-                                    .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
-                            if (b.maxLength != null) hfb.optMaxLength(b.maxLength);
-                            if (b.truncation != null) hfb.optTruncation(b.truncation);
+                                    .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
+                                    .optTruncation(b.truncation != null ? b.truncation : true)
+                                    .optMaxLength(b.maxLength != null ? b.maxLength : 512);
+                            if (b.padding != null && b.padding) hfb.optPadToMaxLength();
                             return hfb.build();
                         }));
             });
@@ -97,6 +98,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
         private Boolean addSpecialTokens;
         private Integer maxLength;
         private Boolean truncation;
+        private Boolean padding;
 
         public Builder() {}
         public Builder(HuggingFaceTokenizerConfig cfg) {
@@ -105,6 +107,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
             addSpecialTokens(cfg.addSpecialTokens());
             if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
             if (cfg.truncation()) setTruncation(true);
+            if (cfg.padding()) setPadding(true);
         }
 
         public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
@@ -112,6 +115,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
         public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; }
         public Builder setMaxLength(int length) { maxLength = length; return this; }
         public Builder setTruncation(boolean enabled) { truncation = enabled; return this; }
+        public Builder setPadding(boolean enabled) { padding = enabled; return this; }
         public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); }
     }
 |