diff options
author | Jon Bratseth <bratseth@oath.com> | 2021-09-27 23:09:03 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-27 23:09:03 +0200 |
commit | 2df97d23d9f25ae60f010a2e9f273cb5b38e049b (patch) | |
tree | d2923a45682e91d80e7011c60cfb301e05acead3 /linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java | |
parent | 037f756caf4cfb99bcd988174839d7bc385267b9 (diff) | |
parent | 8f3fb1a105ded07144f6de527266a438e48a1766 (diff) |
Merge pull request #19294 from vespa-engine/bratseth/linguistics-componentsv7.473.17
Bratseth/linguistics components
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java | 90 |
1 files changed, 90 insertions, 0 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java new file mode 100644 index 00000000000..1659e3c0fa7 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceAlgorithm.java @@ -0,0 +1,90 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.language.sentencepiece; + +/** + * SentencePiece algorithm implementation + * + * @author bratseth + */ +class SentencePieceAlgorithm { + + // TODO: Support characters beyond BMP + + static final char spaceSymbol = '▁'; + + private final boolean collapseUnknowns; + private final Scoring scoring; + + SentencePieceAlgorithm(boolean collapseUnknowns, Scoring scoring) { + this.collapseUnknowns = collapseUnknowns; + this.scoring = scoring; + } + + public <RESULTTYPE> void segment(String input, ResultBuilder<RESULTTYPE> resultBuilder, Model model) { + SegmentEnd[] segmentEnds = new SegmentEnd[input.length() + 1]; + segmentEnds[0] = new SegmentEnd(TokenType.unknown, 0, 0, 0, 0); + int start = 0; + while (start < input.length()) { // segment from this position to the end of the text + Trie.Node node = model.tokens.root; + int characterPosition = start; + while (node != null && characterPosition < input.length()) { // traverse the trie one character at the time from this position + node = node.children.get(input.charAt(characterPosition++)); + int length = characterPosition - start; + if (node != null && node.isToken() && node.type != TokenType.unused) { + float score = node.type == TokenType.userDefined ? (length * model.maxScore - 0.1f) : node.score; + addSegment(TokenType.text, node.id, start, characterPosition, score, segmentEnds); + } + else if (length == 1) { // add an 'unknown' length 1 token to make the next position reachable + addSegment(TokenType.unknown, 0, start, start + 1, model.minScore - 10.0f, segmentEnds); + } + } + start++; + } + resultBuilder.build(input, segmentEnds, collapseUnknowns); + } + + private void addSegment(TokenType type, int id, int start, int end, float score, SegmentEnd[] segmentEnds) { + if (segmentEnds[end] == null || + segmentEnds[start].scoreWith(score) > segmentEnds[end].score()) { + segmentEnds[end] = new SegmentEnd(type, id, + segmentEnds[start].pathScoreSum + score, + segmentEnds[start].pathSegmentCount + 1, + start); + } + } + + final class SegmentEnd { + + final TokenType type; + final int id; + final float pathScoreSum; + final int pathSegmentCount; + final int segmentStart; + + SegmentEnd(TokenType type, int id, float pathScoreSum, int pathSegmentCount, int segmentStart) { + this.type = type; + this.id = id; + this.pathScoreSum = pathScoreSum; + this.pathSegmentCount = pathSegmentCount; + this.segmentStart = segmentStart; + } + + public float score() { + switch (scoring) { + case fewestSegments: return 1f / pathSegmentCount * 10_000_000 + pathScoreSum; + case highestScore: return pathScoreSum; + default : throw new IllegalArgumentException("Unknown scoring " + scoring); + } + } + + public float scoreWith(float additionalSegmentScore) { + switch (scoring) { + case fewestSegments: return 1f / (pathSegmentCount + 1) * 10_000_000 + (pathScoreSum + additionalSegmentScore ); + case highestScore: return pathScoreSum + additionalSegmentScore; + default : throw new IllegalArgumentException("Unknown scoring " + scoring); + } + } + + } + +} |