summary | refs | log | tree | commit | diff | stats
path: root/linguistics-components
diff options
context:
space:
mode:
author Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-05-26 15:44:33 +0200
committer Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-05-26 15:44:33 +0200
commit da7eaaf9389b6b2d1bfd1bcff3770d61b150fd1f (patch)
tree 9562eb4764e15ea9452f8f0911bfe3737c5666b7 /linguistics-components
parent 17f53418cfaaf49b87f06123af0aaf73f3593fe8 (diff)
Make truncation and max length configurable
Diffstat (limited to 'linguistics-components')
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java 18
-rw-r--r-- linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def 4
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java 30
3 files changed, 45 insertions, 7 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index d812b85b82e..f9a37bc477b 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -39,10 +39,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
try {
b.models.forEach((language, path) -> {
models.put(language,
- uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
- .optTokenizerPath(path)
- .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
- .build()));
+ uncheck(() -> {
+ var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
+ .optTokenizerPath(path)
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
+ if (b.maxLength != null) hfb.optMaxLength(b.maxLength);
+ if (b.truncation != null) hfb.optTruncation(b.truncation);
+ return hfb.build();
+ }));
});
} finally {
Thread.currentThread().setContextClassLoader(original);
@@ -90,17 +94,23 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
public static final class Builder {
private final Map<Language, Path> models = new EnumMap<>(Language.class);
private Boolean addSpecialTokens;
+ private Integer maxLength;
+ private Boolean truncation;
public Builder() {}
public Builder(HuggingFaceTokenizerConfig cfg) {
for (var model : cfg.model())
addModel(Language.fromLanguageTag(model.language()), model.path());
addSpecialTokens(cfg.addSpecialTokens());
+ if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
+ if (cfg.truncation()) setTruncation(true);
}
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); }
public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; }
+ public Builder setMaxLength(int length) { maxLength = length; return this; }
+ public Builder setTruncation(boolean enabled) { truncation = enabled; return this; }
public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); }
}
diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
index 5e58547879c..67b3b927f94 100644
--- a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
+++ b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
@@ -8,4 +8,6 @@ model[].language string
# The path to the model relative to the application package root
model[].path model
-addSpecialTokens bool default=true
\ No newline at end of file
+addSpecialTokens bool default=true
+maxLength int default=-1
+truncation bool default=false
\ No newline at end of file
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
index c79ecbfbfbe..6197fe214f1 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
@@ -15,6 +15,9 @@ import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.zip.GZIPInputStream;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+
/**
* @author bjorncs
*/
@@ -69,14 +72,37 @@ class HuggingFaceTokenizerTest {
}
}
+ @Test
+ void truncates_to_max_length() throws IOException {
+ int maxLength = 3;
+ var builder = new HuggingFaceTokenizer.Builder()
+ .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased"))
+ .setMaxLength(maxLength)
+ .setTruncation(true);
+ String input = "what was the impact of the manhattan project";
+ try (var tokenizerWithoutSpecialTokens = builder.addSpecialTokens(false).build();
+ var tokenizerWithSpecialTokens = builder.addSpecialTokens(true).build()) {
+ assertMaxLengthRespected(maxLength, tokenizerWithoutSpecialTokens.encode(input));
+ assertMaxLengthRespected(maxLength, tokenizerWithSpecialTokens.encode(input));
+ }
+ }
+
+ private static void assertMaxLengthRespected(int maxLength, Encoding encoding) {
+ assertEquals(maxLength, encoding.ids().size());
+ assertEquals(maxLength, encoding.tokens().size());
+ assertEquals(maxLength, encoding.attentionMask().size());
+ assertEquals(maxLength, encoding.typeIds().size());
+ }
+
private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException {
return new HuggingFaceTokenizer.Builder()
.addSpecialTokens(false)
- .addDefaultModel(decompressModelFile(tmp, Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model))))
+ .addDefaultModel(decompressModelFile(tmp, model))
.build();
}
- private static Path decompressModelFile(Path tmp, Path source) throws IOException {
+ private static Path decompressModelFile(Path tmp, String model) throws IOException {
+ var source = Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model));
Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", ""));
try (InputStream in = new GZIPInputStream(Files.newInputStream(source));
OutputStream out = Files.newOutputStream(destination, StandardOpenOption.CREATE)) {