about | summary | refs | log | tree | commit | diff | stats
path: root/linguistics/src/test/java/com/yahoo
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@gmail.com> 2021-09-16 13:18:11 +0200
committer Jon Bratseth <bratseth@gmail.com> 2021-09-16 13:18:11 +0200
commit f84bde56a310096c7019c7d899573e36ea5a7316 (patch)
tree e2e6d0a5c7d27ac540e3d82132a988b9d31890de /linguistics/src/test/java/com/yahoo
parent 39a052f324aeeb4c0eb2d4313edf57ddbc4db2c7 (diff)
Encode to dense tensor
Diffstat (limited to 'linguistics/src/test/java/com/yahoo')
-rw-r--r-- linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java 14
-rw-r--r-- linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java 9
2 files changed, 23 insertions, 0 deletions
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index f86bc2f716b..b8fb8a2fdbb 100644
--- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -3,6 +3,7 @@
package com.yahoo.language.sentencepiece;
import com.yahoo.language.Language;
+import com.yahoo.tensor.Tensor;
import org.junit.Test;
import java.io.File;
@@ -33,11 +34,24 @@ public class SentencePieceTest {
tester.assertSegmented(")(/&#(small)/\"in quotes\")", "▁)", "(", "/", "&", "#", "(", "sm", "all", ")", "/", "\"", "in", "▁qu", "otes", "\")");
tester.assertSegmented("x.400AS", "▁x", ".", "4", "00", "AS");
tester.assertSegmented("A normal sentence. Yes one more.", "▁", "A", "▁normal", "▁sentence", ".", "▁", "Y", "es", "▁one", "▁more", ".");
+ }
+
+ @Test
+ public void testIntegerListEncoding() {
+ var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
tester.assertEncoded("hello, world!", 908, 1418, 9934, 501, 9960);
tester.assertEncoded("Hello, world!", 9912, 0, 6595, 9934, 501, 9960);
}
@Test
+ public void testDenseTensorEncoding() {
+ var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
+ tester.assertEncoded("hello, world!", "tensor(d[10])", "[908,1418,9934,501,9960,0,0,0,0,0]");
+ tester.assertEncoded("Hello, world!", "tensor(d[10])", "[9912,0,6595,9934,501,9960,0,0,0,0]");
+ tester.assertEncoded("hello, world!", "tensor(d[2])", "[908,1418]");
+ }
+
+ @Test
public void testNoCollapse() {
var tester = new SentencePieceTester(new SentencePieceEncoder.Builder()
.addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
index dee9be5aa7e..1ba7c9b472d 100644
--- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTester.java
@@ -4,10 +4,13 @@
package com.yahoo.language.sentencepiece;
import com.yahoo.language.Language;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import java.nio.file.Path;
import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
class SentencePieceTester {
@@ -29,6 +32,12 @@ class SentencePieceTester {
assertArrayEquals(expectedCodes, encoder.encode(input, Language.UNKNOWN).toArray());
}
+ public void assertEncoded(String input, String tensorType, String tensor) {
+ TensorType type = TensorType.fromSpec(tensorType);
+ Tensor expected = Tensor.from(type, tensor);
+ assertEquals(expected, encoder.encode(input, Language.UNKNOWN, type));
+ }
+
public void assertSegmented(String input, String... expectedSegments) {
assertSegmented(Language.UNKNOWN, input, expectedSegments);
}