diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-26 15:44:33 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-26 15:44:33 +0200 |
commit | da7eaaf9389b6b2d1bfd1bcff3770d61b150fd1f (patch) | |
tree | 9562eb4764e15ea9452f8f0911bfe3737c5666b7 | |
parent | 17f53418cfaaf49b87f06123af0aaf73f3593fe8 (diff) |
Make truncation and max length configurable
4 files changed, 48 insertions, 19 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index d812b85b82e..f9a37bc477b 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -39,10 +39,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, try { b.models.forEach((language, path) -> { models.put(language, - uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() - .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) - .build())); + uncheck(() -> { + var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() + .optTokenizerPath(path) + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); + if (b.maxLength != null) hfb.optMaxLength(b.maxLength); + if (b.truncation != null) hfb.optTruncation(b.truncation); + return hfb.build(); + })); }); } finally { Thread.currentThread().setContextClassLoader(original); @@ -90,17 +94,23 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public static final class Builder { private final Map<Language, Path> models = new EnumMap<>(Language.class); private Boolean addSpecialTokens; + private Integer maxLength; + private Boolean truncation; public Builder() {} public Builder(HuggingFaceTokenizerConfig cfg) { for (var model : cfg.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); addSpecialTokens(cfg.addSpecialTokens()); + if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); + if (cfg.truncation()) setTruncation(true); } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); } public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; } + public Builder setMaxLength(int length) { maxLength = length; return this; } + public Builder setTruncation(boolean enabled) { truncation = enabled; return this; } public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); } } diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def index 5e58547879c..67b3b927f94 100644 --- a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def +++ b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def @@ -8,4 +8,6 @@ model[].language string # The path to the model relative to the application package root model[].path model -addSpecialTokens bool default=true
\ No newline at end of file +addSpecialTokens bool default=true +maxLength int default=-1 +truncation bool default=false
\ No newline at end of file diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java index c79ecbfbfbe..6197fe214f1 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -15,6 +15,9 @@ import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.zip.GZIPInputStream; +import static org.junit.jupiter.api.Assertions.assertEquals; + + /** * @author bjorncs */ @@ -69,14 +72,37 @@ class HuggingFaceTokenizerTest { } } + @Test + void truncates_to_max_length() throws IOException { + int maxLength = 3; + var builder = new HuggingFaceTokenizer.Builder() + .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased")) + .setMaxLength(maxLength) + .setTruncation(true); + String input = "what was the impact of the manhattan project"; + try (var tokenizerWithoutSpecialTokens = builder.addSpecialTokens(false).build(); + var tokenizerWithSpecialTokens = builder.addSpecialTokens(true).build()) { + assertMaxLengthRespected(maxLength, tokenizerWithoutSpecialTokens.encode(input)); + assertMaxLengthRespected(maxLength, tokenizerWithSpecialTokens.encode(input)); + } + } + + private static void assertMaxLengthRespected(int maxLength, Encoding encoding) { + assertEquals(maxLength, encoding.ids().size()); + assertEquals(maxLength, encoding.tokens().size()); + assertEquals(maxLength, encoding.attentionMask().size()); + assertEquals(maxLength, encoding.typeIds().size()); + } + private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException { return new HuggingFaceTokenizer.Builder() .addSpecialTokens(false) - .addDefaultModel(decompressModelFile(tmp, Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)))) + .addDefaultModel(decompressModelFile(tmp, model)) .build(); } - private static Path decompressModelFile(Path tmp, Path source) throws IOException { + private static Path decompressModelFile(Path tmp, String model) throws IOException { + var source = Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)); Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", "")); try (InputStream in = new GZIPInputStream(Files.newInputStream(source)); OutputStream out = Files.newOutputStream(destination, StandardOpenOption.CREATE)) { diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 0c1cc80544e..d08bc8a3e8c 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -29,14 +29,12 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final String attentionMaskName; private final String tokenTypeIdsName; private final String outputName; - private final int maxTokens; private final boolean normalize; private final HuggingFaceTokenizer tokenizer; private final OnnxEvaluator evaluator; @Inject public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) { - maxTokens = config.transformerMaxTokens(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); tokenTypeIdsName = config.transformerTokenTypeIds(); @@ -45,6 +43,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { tokenizer = new HuggingFaceTokenizer.Builder() .addSpecialTokens(true) .addDefaultModel(Paths.get(config.tokenizerPath().toString())) + .setTruncation(true) + .setMaxLength(config.transformerMaxTokens()) .build(); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) @@ -74,16 +74,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { @Override public List<Integer> embed(String s, Context context) { - var tokenIds = tokenizer.embed(s, context); - - int tokensSize = tokenIds.size(); - - if (tokensSize > maxTokens) { - Integer lastElement = tokenIds.get(tokensSize - 1); - tokenIds = tokenIds.subList(0, maxTokens - 1); - tokenIds.add(lastElement); - } - return tokenIds; + return tokenizer.embed(s, context); } @Override |