diff options
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index d812b85b82e..f9a37bc477b 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -39,10 +39,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, try { b.models.forEach((language, path) -> { models.put(language, - uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() - .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) - .build())); + uncheck(() -> { + var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() + .optTokenizerPath(path) + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); + if (b.maxLength != null) hfb.optMaxLength(b.maxLength); + if (b.truncation != null) hfb.optTruncation(b.truncation); + return hfb.build(); + })); }); } finally { Thread.currentThread().setContextClassLoader(original); @@ -90,17 +94,23 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public static final class Builder { private final Map<Language, Path> models = new EnumMap<>(Language.class); private Boolean addSpecialTokens; + private Integer maxLength; + private Boolean truncation; public Builder() {} public Builder(HuggingFaceTokenizerConfig cfg) { for (var model : cfg.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); addSpecialTokens(cfg.addSpecialTokens()); + if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); + if (cfg.truncation()) setTruncation(true); } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); } public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; } + public Builder setMaxLength(int length) { maxLength = length; return this; } + public Builder setTruncation(boolean enabled) { truncation = enabled; return this; } public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); } } |