aboutsummaryrefslogtreecommitdiffstats
path: root/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2021-09-16 11:04:56 +0200
committerJon Bratseth <bratseth@gmail.com>2021-09-16 11:04:56 +0200
commita2afdafbffcbc09594fd629c65746ec253f180be (patch)
tree987804d4362c801c6c47cfe469ec6b97409de6ef /linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
parent381033510b992049d55cae9964d942b4b47eb9df (diff)
Make SentencePieceEncoder configurable
Diffstat (limited to 'linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java')
-rw-r--r--linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java40
1 files changed, 40 insertions, 0 deletions
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
new file mode 100644
index 00000000000..dee9be5aa7e
--- /dev/null
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
@@ -0,0 +1,40 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+//
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.language.Language;
+
+import java.nio.file.Path;
+
+import static org.junit.Assert.assertArrayEquals;
+
+class SentencePieceTester {
+
+ private final SentencePieceEncoder encoder;
+
+ public SentencePieceTester(Path model) {
+ this(new SentencePieceEncoder.Builder().addDefaultModel(model));
+ }
+
+ public SentencePieceTester(SentencePieceEncoder.Builder builder) {
+ this(builder.build());
+ }
+
+ public SentencePieceTester(SentencePieceEncoder encoder) {
+ this.encoder = encoder;
+ }
+
+ public void assertEncoded(String input, Integer... expectedCodes) {
+ assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
+ }
+
+ public void assertSegmented(String input, String... expectedSegments) {
+ assertSegmented(Language.UNKNOWN, input, expectedSegments);
+ }
+
+ public void assertSegmented(Language language, String input, String... expectedSegments) {
+ assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray());
+ }
+
+}