diff options
Diffstat (limited to 'linguistics-components/src/main/java')
2 files changed, 22 insertions, 6 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java index b4f216c4c9c..28a1b9d2930 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java @@ -22,7 +22,8 @@ final class Model { final float minScore; final float maxScore; final Trie tokens = new Trie(); - final Map<Integer, String> tokenId2Token = new HashMap<>(); + final Map<Integer, Token> tokenId2Token = new HashMap<>(); + Model(Language language, Path path) { try { @@ -33,8 +34,10 @@ final class Model { float maxScore = Float.MIN_VALUE; for (int i = 0; i < sp.getPiecesCount(); i++) { var piece = sp.getPieces(i); - tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore()); - tokenId2Token.put(i, piece.getPiece()); + var type = toTokenType(piece.getType()); + var word = piece.getPiece(); + tokens.add(type, i, word, piece.getScore()); + tokenId2Token.put(i, new Token(word, type)); minScore = Math.min(piece.getScore(), minScore); maxScore = Math.max(piece.getScore(), maxScore); } @@ -61,4 +64,6 @@ final class Model { return "SentencePiece model for " + language + ": '" + source + "'"; } + record Token(String text, TokenType type) { } + } diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index 474add27023..b4d542b0d82 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -93,13 +93,24 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { * * @param tokens the list of tokens to decode to a string * @param context the context which specifies the language used to select a model - * @return the string formed by decoding the tokens back to their string repreesentation + * @return the string formed by decoding the tokens back to their string representation */ @Override public String decode(List<Integer> tokens, Embedder.Context context) { + return decode(tokens, context, false); + } + + public String decode(List<Integer> tokens, Embedder.Context context, boolean skipControl) { Model model = resolveModelFrom(context.getLanguage()); - String normalized = tokens.stream().map(model.tokenId2Token::get).collect(Collectors.joining()); - return denormalize(normalized); + StringBuilder sb = new StringBuilder(); + for (var tokenId : tokens) { + var token = model.tokenId2Token.get(tokenId); + var skip = skipControl && token.type() == TokenType.control; + if ( ! skip) { + sb.append(token.text()); + } + } + return denormalize(sb.toString()); } /** |