summaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@oath.com>, 2021-09-27 23:09:03 +0200
committer: GitHub <noreply@github.com>, 2021-09-27 23:09:03 +0200
commit: 2df97d23d9f25ae60f010a2e9f273cb5b38e049b (patch)
tree: d2923a45682e91d80e7011c60cfb301e05acead3 /linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
parent: 037f756caf4cfb99bcd988174839d7bc385267b9 (diff)
parent: 8f3fb1a105ded07144f6de527266a438e48a1766 (diff)
Merge pull request #19294 from vespa-engine/bratseth/linguistics-components (tag: v7.473.17)
Bratseth/linguistics components
Diffstat (limited to 'linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java')
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java | 49
1 files changed, 49 insertions, 0 deletions
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
new file mode 100644
index 00000000000..1ba7c9b472d
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
@@ -0,0 +1,49 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+//
+
+package com.yahoo.language.sentencepiece;
+
+import com.yahoo.language.Language;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.file.Path;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+class SentencePieceTester { // Test helper: wraps a SentencePieceEncoder and exposes JUnit assertions on its output.
+ // The encoder that every assertion in this class runs against.
+ private final SentencePieceEncoder encoder;
+ // Creates a tester whose encoder is built from a new Builder with the given path passed to addDefaultModel.
+ public SentencePieceTester(Path model) {
+ this(new SentencePieceEncoder.Builder().addDefaultModel(model));
+ }
+ // Creates a tester around the encoder returned by builder.build().
+ public SentencePieceTester(SentencePieceEncoder.Builder builder) {
+ this(builder.build());
+ }
+ // Creates a tester around an already constructed encoder; the other constructors delegate here.
+ public SentencePieceTester(SentencePieceEncoder encoder) {
+ this.encoder = encoder;
+ }
+ // Asserts that encoding input with Language.UNKNOWN yields exactly expectedCodes, compared element by element in order.
+ public void assertEncoded(String input, Integer... expectedCodes) {
+ assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
+ }
+ // Asserts that encoding input (Language.UNKNOWN) into the type parsed from tensorType equals the tensor parsed from the tensor string.
+ public void assertEncoded(String input, String tensorType, String tensor) {
+ TensorType type = TensorType.fromSpec(tensorType);
+ Tensor expected = Tensor.from(type, tensor);
+ assertEquals(expected, encoder.encode(input, Language.UNKNOWN, type));
+ }
+ // Convenience overload: same as assertSegmented(Language.UNKNOWN, input, expectedSegments).
+ public void assertSegmented(String input, String... expectedSegments) {
+ assertSegmented(Language.UNKNOWN, input, expectedSegments);
+ }
+ // Asserts that segmenting input for the given language yields exactly expectedSegments, compared element by element in order.
+ public void assertSegmented(Language language, String input, String... expectedSegments) {
+ assertArrayEquals(expectedSegments, encoder.segment(input, language).toArray());
+ }
+
+}