diff options
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java')
-rw-r--r-- | linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java | 17 |
1 files changed, 14 insertions, 3 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index 474add27023..b4d542b0d82 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -93,13 +93,24 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
      *
      * @param tokens the list of tokens to decode to a string
      * @param context the context which specifies the language used to select a model
-     * @return the string formed by decoding the tokens back to their string repreesentation
+     * @return the string formed by decoding the tokens back to their string representation
      */
     @Override
     public String decode(List<Integer> tokens, Embedder.Context context) {
+        return decode(tokens, context, false);
+    }
+
+    public String decode(List<Integer> tokens, Embedder.Context context, boolean skipControl) {
         Model model = resolveModelFrom(context.getLanguage());
-        String normalized = tokens.stream().map(model.tokenId2Token::get).collect(Collectors.joining());
-        return denormalize(normalized);
+        StringBuilder sb = new StringBuilder();
+        for (var tokenId : tokens) {
+            var token = model.tokenId2Token.get(tokenId);
+            var skip = skipControl && token.type() == TokenType.control;
+            if ( ! skip) {
+                sb.append(token.text());
+            }
+        }
+        return denormalize(sb.toString());
     }
 
     /**