diff options
Diffstat (limited to 'linguistics')
-rw-r--r-- | linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java | 20 | ||||
-rw-r--r-- | linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java | 21 |
2 files changed, 27 insertions, 14 deletions
diff --git a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java index 9509c1d070d..a755a9e6ff3 100644 --- a/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java +++ b/linguistics/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEncoder.java @@ -96,23 +96,17 @@ public class SentencePieceEncoder implements Segmenter { while (start < input.length()) { // segment from this position to the end of the text Trie.Node node = model.tokens.root; int characterPosition = start; - boolean addedSingleCharacterSegment = false; - while (characterPosition < input.length()) { // traverse the trie one character at the time from this position - node = node.children.get(input.charAt(characterPosition)); - characterPosition++; - if (node == null) break; + while (node != null && characterPosition < input.length()) { // traverse the trie one character at the time from this position + node = node.children.get(input.charAt(characterPosition++)); int length = characterPosition - start; - if (node.isToken()) { - if (node.type == TokenType.unused) continue; - + if (node != null && node.isToken() && node.type != TokenType.unused) { float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score; addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds); } - if (! addedSingleCharacterSegment && length == 1) - addedSingleCharacterSegment = true; + else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable + addSegment(TokenType.unknown, 0, start, start + 1, unknownScore, segmentEnds); + } } - if ( ! addedSingleCharacterSegment) // add an unknown 1 character token to be able to start from the next character - addSegment(TokenType.unknown, 0, start, start + 1, unknownScore, segmentEnds); start++; } @@ -248,7 +242,7 @@ public class SentencePieceEncoder implements Segmenter { Float score; private final Map<Character, Node> children = new HashMap<>(); - boolean isToken() { return score != null; } + boolean isToken() { return type != null; } } diff --git a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index 70361f55750..7d0c1c5c78e 100644 --- a/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -17,7 +17,7 @@ import static org.junit.Assert.assertArrayEquals; public class SentencePieceTest { @Test - public void testEnglishTokenization() throws IOException { + public void testEnglishTokenization() { var tester = new SentencePieceTester(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()); tester.assertSegmented("h", "▁h"); tester.assertSegmented("he", "▁he"); @@ -42,6 +42,25 @@ public class SentencePieceTest { } @Test + public void testNoCollapse() { + var tester = new SentencePieceTester(new SentencePieceEncoder.Builder() + .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) + .setCollapseUnknowns(false)); + tester.assertSegmented("KHJ hello", "▁", "K", "H", "J", "▁hel", "lo"); + } + + @Test + public void testHighestScore() { + var tester = new SentencePieceTester(new SentencePieceEncoder.Builder() + .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) + .setScoring(SentencePieceEncoder.Scoring.highestScore)); + tester.assertSegmented("h", "▁h"); + tester.assertSegmented("he", "▁he"); + tester.assertSegmented("hel", "▁h", "el"); + tester.assertSegmented("hello", "▁h", "el", "lo"); + } + + @Test public void testJapaneseTokenization() throws IOException { SentencePieceEncoder.Builder builder = new SentencePieceEncoder.Builder(); builder.addModel(Language.JAPANESE, new File("src/test/models/sentencepiece/ja.wiki.bpe.vs5000.model").toPath()); |