summaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java18
1 files changed, 14 insertions, 4 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index d812b85b82e..f9a37bc477b 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -39,10 +39,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
try {
b.models.forEach((language, path) -> {
models.put(language,
- uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
- .optTokenizerPath(path)
- .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
- .build()));
+ uncheck(() -> {
+ var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
+ .optTokenizerPath(path)
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
+ if (b.maxLength != null) hfb.optMaxLength(b.maxLength);
+ if (b.truncation != null) hfb.optTruncation(b.truncation);
+ return hfb.build();
+ }));
});
} finally {
Thread.currentThread().setContextClassLoader(original);
@@ -90,17 +94,23 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
public static final class Builder {
private final Map<Language, Path> models = new EnumMap<>(Language.class);
private Boolean addSpecialTokens;
+ private Integer maxLength;
+ private Boolean truncation;
public Builder() {}
public Builder(HuggingFaceTokenizerConfig cfg) {
for (var model : cfg.model())
addModel(Language.fromLanguageTag(model.language()), model.path());
addSpecialTokens(cfg.addSpecialTokens());
+ if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
+ if (cfg.truncation()) setTruncation(true);
}
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); }
public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; }
+ public Builder setMaxLength(int length) { maxLength = length; return this; }
+ public Builder setTruncation(boolean enabled) { truncation = enabled; return this; }
public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); }
}