summaryrefslogtreecommitdiffstats
path: root/linguistics
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics')
-rw-r--r--linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java20
-rw-r--r--linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java21
2 files changed, 27 insertions, 14 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
index 9509c1d070d..a755a9e6ff3 100644
--- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
+++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java
@@ -96,23 +96,17 @@ public class SentencePieceEncoder implements Segmenter {
while (start < input.length()) { // segment from this position to the end of the text
Trie.Node node = model.tokens.root;
int characterPosition = start;
- boolean addedSingleCharacterSegment = false;
- while (characterPosition < input.length()) { // traverse the trie one character at the time from this position
- node = node.children.get(input.charAt(characterPosition));
- characterPosition++;
- if (node == null) break;
+ while (node != null && characterPosition < input.length()) { // traverse the trie one character at the time from this position
+ node = node.children.get(input.charAt(characterPosition++));
int length = characterPosition - start;
- if (node.isToken()) {
- if (node.type == TokenType.unused) continue;
-
+ if (node != null && node.isToken() && node.type != TokenType.unused) {
float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score;
addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds);
}
- if (! addedSingleCharacterSegment && length == 1)
- addedSingleCharacterSegment = true;
+ else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable
+ addSegment(TokenType.unknown, 0, start, start + 1, unknownScore, segmentEnds);
+ }
}
- if ( ! addedSingleCharacterSegment) // add an unknown 1 character token to be able to start from the next character
- addSegment(TokenType.unknown, 0, start, start + 1, unknownScore, segmentEnds);
start++;
}
@@ -248,7 +242,7 @@ public class SentencePieceEncoder implements Segmenter {
Float score;
private final Map<Character, Node> children = new HashMap<>();
- boolean isToken() { return score != null; }
+ boolean isToken() { return type != null; }
}
diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index 70361f55750..7d0c1c5c78e 100644
--- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -17,7 +17,7 @@ import static org.junit.Assert.assertArrayEquals;
public class SentencePieceTest {
@Test
- public void testEnglishTokenization() throws IOException {
+ public void testEnglishTokenization() {
var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath());
tester.assertSegmented("h", "▁h");
tester.assertSegmented("he", "▁he");
@@ -42,6 +42,25 @@ public class SentencePieceTest {
}
@Test
+ public void testNoCollapse() {
+ var tester = new SentencePieceTester(new SentencePieceEncoder.Builder()
+ .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
+ .setCollapseUnknowns(false));
+ tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo");
+ }
+
+ @Test
+ public void testHighestScore() {
+ var tester = new SentencePieceTester(new SentencePieceEncoder.Builder()
+ .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
+ .setScoring(SentencePieceEncoder.Scoring.highestScore));
+ tester.assertSegmented("h", "▁h");
+ tester.assertSegmented("he", "▁he");
+ tester.assertSegmented("hel", "▁h", "el");
+ tester.assertSegmented("hello", "▁h", "el", "lo");
+ }
+
+ @Test
public void testJapaneseTokenization() throws IOException {
SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder();
builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath());