summaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java')
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java49
1 files changed, 49 insertions, 0 deletions
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
new file mode 100644
index 00000000000..1ba7c9b472d
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
@@ -0,0 +1,49 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+//
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.language.Language;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.file.Path;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+class SentencePieceTester {
+
+ private final SentencePieceEncoder encoder;
+
+ public SentencePieceTester(Path model) {
+ this(new SentencePieceEncoder.Builder().addDefaultModel(model));
+ }
+
+ public SentencePieceTester(SentencePieceEncoder.Builder builder) {
+ this(builder.build());
+ }
+
+ public SentencePieceTester(SentencePieceEncoder encoder) {
+ this.encoder = encoder;
+ }
+
+ public void assertEncoded(String input, Integer... expectedCodes) {
+ assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
+ }
+
+ public void assertEncoded(String input, String tensorType, String tensor) {
+ TensorType type = TensorType.fromSpec(tensorType);
+ Tensor expected = Tensor.from(type, tensor);
+ assertEquals(expected, encoder.encode(input, Language.UNKNOWN, type));
+ }
+
+ public void assertSegmented(String input, String... expectedSegments) {
+ assertSegmented(Language.UNKNOWN, input, expectedSegments);
+ }
+
+ public void assertSegmented(Language language, String input, String... expectedSegments) {
+ assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray());
+ }
+
+}