diff options
Diffstat (limited to 'linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java')
-rw-r--r-- | linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java | 30 |
1 files changed, 28 insertions, 2 deletions
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java index c79ecbfbfbe..6197fe214f1 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -15,6 +15,9 @@ import java.nio.file.Paths; import java.nio.file.StandardOpenOption; import java.util.zip.GZIPInputStream; +import static org.junit.jupiter.api.Assertions.assertEquals; + + /** * @author bjorncs */ @@ -69,14 +72,37 @@ class HuggingFaceTokenizerTest { } } + @Test + void truncates_to_max_length() throws IOException { + int maxLength = 3; + var builder = new HuggingFaceTokenizer.Builder() + .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased")) + .setMaxLength(maxLength) + .setTruncation(true); + String input = "what was the impact of the manhattan project"; + try (var tokenizerWithoutSpecialTokens = builder.addSpecialTokens(false).build(); + var tokenizerWithSpecialTokens = builder.addSpecialTokens(true).build()) { + assertMaxLengthRespected(maxLength, tokenizerWithoutSpecialTokens.encode(input)); + assertMaxLengthRespected(maxLength, tokenizerWithSpecialTokens.encode(input)); + } + } + + private static void assertMaxLengthRespected(int maxLength, Encoding encoding) { + assertEquals(maxLength, encoding.ids().size()); + assertEquals(maxLength, encoding.tokens().size()); + assertEquals(maxLength, encoding.attentionMask().size()); + assertEquals(maxLength, encoding.typeIds().size()); + } + private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException { return new HuggingFaceTokenizer.Builder() .addSpecialTokens(false) - .addDefaultModel(decompressModelFile(tmp, Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)))) + .addDefaultModel(decompressModelFile(tmp, model)) .build(); } - private static Path decompressModelFile(Path tmp, Path source) throws IOException { + private static Path decompressModelFile(Path tmp, String model) throws IOException { + var source = Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)); Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", "")); try (InputStream in = new GZIPInputStream(Files.newInputStream(source)); OutputStream out = Files.newOutputStream(destination, StandardOpenOption.CREATE)) { |