diff options
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index b92e0678970..2c66fc18c9b 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -6,6 +6,7 @@ import com.yahoo.api.annotations.Beta; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; import com.yahoo.language.Language; +import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.language.process.Embedder; import com.yahoo.language.process.Segmenter; import com.yahoo.language.tools.Embed; @@ -39,10 +40,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, try { b.models.forEach((language, path) -> { models.put(language, - uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() - .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) - .build())); + uncheck(() -> { + var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() + .optTokenizerPath(path) + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); + if (b.maxLength != null) hfb.optMaxLength(b.maxLength); + if (b.truncation != null) hfb.optTruncation(b.truncation); + return hfb.build(); + })); }); } finally { Thread.currentThread().setContextClassLoader(original); @@ -76,6 +81,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public String decode(List<Long> tokens, Language language) { return resolve(language).decode(toArray(tokens)); } @Override public void close() { models.forEach((__, model) -> model.close()); } + @Override public void deconstruct() { close(); } private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) { // Disregard language if there is default model @@ -89,17 +95,23 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public static final class Builder { private final Map<Language, Path> models = new EnumMap<>(Language.class); private Boolean addSpecialTokens; + private Integer maxLength; + private Boolean truncation; public Builder() {} public Builder(HuggingFaceTokenizerConfig cfg) { for (var model : cfg.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); addSpecialTokens(cfg.addSpecialTokens()); + if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); + if (cfg.truncation()) setTruncation(true); } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); } public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; } + public Builder setMaxLength(int length) { maxLength = length; return this; } + public Builder setTruncation(boolean enabled) { truncation = enabled; return this; } public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); } } |