summaryrefslogtreecommitdiffstats
path: root/linguistics-components
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-08 14:23:16 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-08 14:23:16 +0200
commitc0652d7794a90e0afb593fc1a3db17c99606a808 (patch)
tree17887acf2818107bbeb7355f5ee463f5fb02873d /linguistics-components
parentc3d8c532e0f5b1db896d8693409098e8c2980da1 (diff)
Disable padding and make it configurable
Diffstat (limited to 'linguistics-components')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java10
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java32
2 files changed, 30 insertions, 12 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index 2c66fc18c9b..a0b41669b2b 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -43,9 +43,10 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
uncheck(() -> {
var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
- .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
- if (b.maxLength != null) hfb.optMaxLength(b.maxLength);
- if (b.truncation != null) hfb.optTruncation(b.truncation);
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
+ .optTruncation(b.truncation != null ? b.truncation : true)
+ .optMaxLength(b.maxLength != null ? b.maxLength : 512);
+ if (b.padding != null && b.padding) hfb.optPadToMaxLength();
return hfb.build();
}));
});
@@ -97,6 +98,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
private Boolean addSpecialTokens;
private Integer maxLength;
private Boolean truncation;
+ private Boolean padding;
public Builder() {}
public Builder(HuggingFaceTokenizerConfig cfg) {
@@ -105,6 +107,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
addSpecialTokens(cfg.addSpecialTokens());
if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
if (cfg.truncation()) setTruncation(true);
+ if (cfg.padding()) setPadding(true);
}
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
@@ -112,6 +115,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; }
public Builder setMaxLength(int length) { maxLength = length; return this; }
public Builder setTruncation(boolean enabled) { truncation = enabled; return this; }
+ public Builder setPadding(boolean enabled) { padding = enabled; return this; }
public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); }
}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
index 6197fe214f1..8b34e1487be 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
@@ -27,7 +27,10 @@ class HuggingFaceTokenizerTest {
@Test
void bert_tokenizer() throws IOException {
- try (var tokenizer = createTokenizer(tmp, "bert-base-uncased")) {
+ try (var tokenizer = new HuggingFaceTokenizer.Builder()
+ .addSpecialTokens(false)
+ .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased"))
+ .build()) {
var tester = new EmbedderTester(tokenizer);
tester.assertSegmented("what was the impact of the manhattan project",
"what", "was", "the", "impact", "of", "the", "manhattan", "project");
@@ -41,7 +44,10 @@ class HuggingFaceTokenizerTest {
@Test
void tokenizes_using_paraphrase_multilingual_mpnet_base_v2() throws IOException {
- try (var tokenizer = createTokenizer(tmp, "paraphrase-multilingual-mpnet-base-v2")) {
+ try (var tokenizer = new HuggingFaceTokenizer.Builder()
+ .addSpecialTokens(false)
+ .addDefaultModel(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2"))
+ .build()) {
var tester = new EmbedderTester(tokenizer);
tester.assertSegmented("h", "▁h");
tester.assertSegmented("he", "▁he");
@@ -87,6 +93,21 @@ class HuggingFaceTokenizerTest {
}
}
+ @Test
+ void disables_padding_by_default() throws IOException {
+ var builder = new HuggingFaceTokenizer.Builder()
+ .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased"))
+ .addSpecialTokens(true).setMaxLength(16);
+ String input = "what was the impact of the manhattan project";
+ try (var tokenizerWithDefaultPadding = builder.build();
+ var tokenizerWithPaddingDisabled = builder.setPadding(false).build();
+ var tokenizerWithPaddingEnabled = builder.setPadding(true).build()) {
+ assertMaxLengthRespected(10, tokenizerWithDefaultPadding.encode(input));
+ assertMaxLengthRespected(10, tokenizerWithPaddingDisabled.encode(input));
+ assertMaxLengthRespected(16, tokenizerWithPaddingEnabled.encode(input));
+ }
+ }
+
private static void assertMaxLengthRespected(int maxLength, Encoding encoding) {
assertEquals(maxLength, encoding.ids().size());
assertEquals(maxLength, encoding.tokens().size());
@@ -94,13 +115,6 @@ class HuggingFaceTokenizerTest {
assertEquals(maxLength, encoding.typeIds().size());
}
- private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException {
- return new HuggingFaceTokenizer.Builder()
- .addSpecialTokens(false)
- .addDefaultModel(decompressModelFile(tmp, model))
- .build();
- }
-
private static Path decompressModelFile(Path tmp, String model) throws IOException {
var source = Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model));
Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", ""));