summary refs log tree commit diff stats
path: root/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java')
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java 20
1 files changed, 16 insertions, 4 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index b92e0678970..2c66fc18c9b 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -6,6 +6,7 @@ import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
import com.yahoo.language.tools.Embed;
@@ -39,10 +40,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
try {
b.models.forEach((language, path) -> {
models.put(language,
- uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
- .optTokenizerPath(path)
- .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
- .build()));
+ uncheck(() -> {
+ var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
+ .optTokenizerPath(path)
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
+ if (b.maxLength != null) hfb.optMaxLength(b.maxLength);
+ if (b.truncation != null) hfb.optTruncation(b.truncation);
+ return hfb.build();
+ }));
});
} finally {
Thread.currentThread().setContextClassLoader(original);
@@ -76,6 +81,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
/** Decodes the given token ids back to text, using the tokenizer model that {@code resolve} selects for the language. */
public String decode(List<Long> tokens, Language language) { return resolve(language).decode(toArray(tokens)); }
/** Closes every tokenizer model held in {@code models}. */
@Override public void close() { models.forEach((__, model) -> model.close()); }
/** Delegates to {@link #close()}. NOTE(review): presumably the component teardown hook inherited from AbstractComponent - confirm. */
+ @Override public void deconstruct() { close(); }
private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) {
// Disregard language if there is default model
@@ -89,17 +95,23 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
/**
 * Fluent builder collecting the per-language model paths and optional tokenizer
 * settings that are handed to the {@code HuggingFaceTokenizer} constructor.
 * The optional settings are boxed so that {@code null} can represent "not set".
 */
public static final class Builder {
// Tokenizer model file per language; Language.UNKNOWN holds the default model (see addDefaultModel).
private final Map<Language, Path> models = new EnumMap<>(Language.class);
// null = not set; the tokenizer constructor then falls back to true.
private Boolean addSpecialTokens;
// null = not set; the tokenizer constructor then does not call optMaxLength.
+ private Integer maxLength;
// null = not set; the tokenizer constructor then does not call optTruncation.
+ private Boolean truncation;
/** Creates an empty builder with no models and no optional settings. */
public Builder() {}
/**
 * Creates a builder populated from the given config: one model per configured
 * entry, keyed by the language parsed from the entry's language tag, plus the
 * special-token, max-length and truncation settings.
 */
public Builder(HuggingFaceTokenizerConfig cfg) {
for (var model : cfg.model())
addModel(Language.fromLanguageTag(model.language()), model.path());
addSpecialTokens(cfg.addSpecialTokens());
// -1 is treated as "no max length configured" and leaves maxLength unset.
// NOTE(review): presumably -1 is the config default - confirm against the config definition.
+ if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
// Only an explicit true is forwarded; false leaves truncation unset (null) rather than disabling it.
+ if (cfg.truncation()) setTruncation(true);
}
/** Registers the model at {@code path} for {@code lang}, replacing any model previously registered for it. */
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
/** Registers the model at {@code path} as the default model, i.e. under {@code Language.UNKNOWN}. */
public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); }
/** Sets the add-special-tokens option that is passed to the underlying tokenizer builder. */
public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; }
/** Sets the max length that is passed to the underlying tokenizer builder. */
+ public Builder setMaxLength(int length) { maxLength = length; return this; }
/** Sets the truncation option that is passed to the underlying tokenizer builder. */
+ public Builder setTruncation(boolean enabled) { truncation = enabled; return this; }
/** Creates a {@code HuggingFaceTokenizer} from the current state of this builder. */
public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); }
}