summary refs log tree commit diff stats
path: root/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java')
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java 30
1 files changed, 28 insertions, 2 deletions
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
index c79ecbfbfbe..6197fe214f1 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
@@ -15,6 +15,9 @@ import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.util.zip.GZIPInputStream;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+
/**
* @author bjorncs
*/
@@ -69,14 +72,37 @@ class HuggingFaceTokenizerTest {
}
}
+ @Test
+ void truncates_to_max_length() throws IOException {
+ int maxLength = 3;
+ var builder = new HuggingFaceTokenizer.Builder()
+ .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased"))
+ .setMaxLength(maxLength)
+ .setTruncation(true);
+ String input = "what was the impact of the manhattan project";
+ try (var tokenizerWithoutSpecialTokens = builder.addSpecialTokens(false).build();
+ var tokenizerWithSpecialTokens = builder.addSpecialTokens(true).build()) {
+ assertMaxLengthRespected(maxLength, tokenizerWithoutSpecialTokens.encode(input));
+ assertMaxLengthRespected(maxLength, tokenizerWithSpecialTokens.encode(input));
+ }
+ }
+
+ private static void assertMaxLengthRespected(int maxLength, Encoding encoding) {
+ assertEquals(maxLength, encoding.ids().size());
+ assertEquals(maxLength, encoding.tokens().size());
+ assertEquals(maxLength, encoding.attentionMask().size());
+ assertEquals(maxLength, encoding.typeIds().size());
+ }
+
private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException {
return new HuggingFaceTokenizer.Builder()
.addSpecialTokens(false)
- .addDefaultModel(decompressModelFile(tmp, Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model))))
+ .addDefaultModel(decompressModelFile(tmp, model))
.build();
}
- private static Path decompressModelFile(Path tmp, Path source) throws IOException {
+ private static Path decompressModelFile(Path tmp, String model) throws IOException {
+ var source = Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model));
Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", ""));
try (InputStream in = new GZIPInputStream(Files.newInputStream(source));
OutputStream out = Files.newOutputStream(destination, StandardOpenOption.CREATE)) {