diff options
Diffstat (limited to 'linguistics-components/src/main')
2 files changed, 17 insertions, 5 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index d812b85b82e..f9a37bc477b 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -39,10 +39,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, try { b.models.forEach((language, path) -> { models.put(language, - uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() - .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) - .build())); + uncheck(() -> { + var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() + .optTokenizerPath(path) + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); + if (b.maxLength != null) hfb.optMaxLength(b.maxLength); + if (b.truncation != null) hfb.optTruncation(b.truncation); + return hfb.build(); + })); }); } finally { Thread.currentThread().setContextClassLoader(original); @@ -90,17 +94,23 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public static final class Builder { private final Map<Language, Path> models = new EnumMap<>(Language.class); private Boolean addSpecialTokens; + private Integer maxLength; + private Boolean truncation; public Builder() {} public Builder(HuggingFaceTokenizerConfig cfg) { for (var model : cfg.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); addSpecialTokens(cfg.addSpecialTokens()); + if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); + if (cfg.truncation()) setTruncation(true); } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); } public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; } + public Builder setMaxLength(int length) { maxLength = length; return this; } + public Builder setTruncation(boolean enabled) { truncation = enabled; return this; } public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); } } diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def index 5e58547879c..67b3b927f94 100644 --- a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def +++ b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def @@ -8,4 +8,6 @@ model[].language string # The path to the model relative to the application package root model[].path model -addSpecialTokens bool default=true
\ No newline at end of file +addSpecialTokens bool default=true +maxLength int default=-1 +truncation bool default=false
\ No newline at end of file |