summaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-06-12 17:25:55 +0200
committerGitHub <noreply@github.com>2023-06-12 17:25:55 +0200
commitfb5d1bf9f451fbeb4a40d7f73fa856ef81bd77ed (patch)
tree1f8ab291370e84407a827f9a80bcf943f522ca29 /linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
parent0647b650c3334ff86d50431e78549e25dc46caf9 (diff)
parent4f722322cc9f8df5146ffb27d74239b3b4f2d634 (diff)
Merge pull request #27387 from vespa-engine/bjorncs/hfv8.176.13
Prefer truncation configuration from tokenizer model
Diffstat (limited to 'linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java')
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java9
1 files changed, 8 insertions, 1 deletions
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
index bf2e0f82829..f727252cccb 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
@@ -99,7 +99,7 @@ class HuggingFaceTokenizerTest {
}
@Test
- void disables_padding_by_default() throws IOException {
+ void pads_to_max_length() throws IOException {
var builder = new HuggingFaceTokenizer.Builder()
.setTruncation(true)
.addDefaultModel(decompressModelFile(tmp, "bert-base-uncased"))
@@ -114,6 +114,13 @@ class HuggingFaceTokenizerTest {
}
}
+ @Test
+ void provides_model_info() throws IOException {
+ var expected = new ModelInfo(ModelInfo.TruncationStrategy.LONGEST_FIRST, ModelInfo.PaddingStrategy.LONGEST, 128, 0, 0);
+ var actual = HuggingFaceTokenizer.getModelInfo(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2"));
+ assertEquals(expected, actual);
+ }
+
private static void assertMaxLengthRespected(int maxLength, Encoding encoding) {
assertEquals(maxLength, encoding.ids().size());
assertEquals(maxLength, encoding.tokens().size());