summary | refs | log | tree | commit | diff | stats
path: root/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java')
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java | 10
1 files changed, 7 insertions, 3 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index 2c66fc18c9b..a0b41669b2b 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -43,9 +43,10 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
uncheck(() -> {
var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
- .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
- if (b.maxLength != null) hfb.optMaxLength(b.maxLength);
- if (b.truncation != null) hfb.optTruncation(b.truncation);
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
+ .optTruncation(b.truncation != null ? b.truncation : true)
+ .optMaxLength(b.maxLength != null ? b.maxLength : 512);
+ if (b.padding != null && b.padding) hfb.optPadToMaxLength();
return hfb.build();
}));
});
@@ -97,6 +98,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
private Boolean addSpecialTokens;
private Integer maxLength;
private Boolean truncation;
+ private Boolean padding;
public Builder() {}
public Builder(HuggingFaceTokenizerConfig cfg) {
@@ -105,6 +107,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
addSpecialTokens(cfg.addSpecialTokens());
if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
if (cfg.truncation()) setTruncation(true);
+ if (cfg.padding()) setPadding(true);
}
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
@@ -112,6 +115,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; }
public Builder setMaxLength(int length) { maxLength = length; return this; }
public Builder setTruncation(boolean enabled) { truncation = enabled; return this; }
+ public Builder setPadding(boolean enabled) { padding = enabled; return this; }
public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); }
}