about summary refs log tree commit diff stats
path: root/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java')
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java 17
1 files changed, 14 insertions, 3 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index 474add27023..b4d542b0d82 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -93,13 +93,24 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
*
* @param tokens the list of tokens to decode to a string
* @param context the context which specifies the language used to select a model
- * @return the string formed by decoding the tokens back to their string repreesentation
+ * @return the string formed by decoding the tokens back to their string representation
*/
@Override
public String decode(List<Integer> tokens, Embedder.Context context) {
+ return decode(tokens, context, false);
+ }
+
+ public String decode(List<Integer> tokens, Embedder.Context context, boolean skipControl) {
Model model = resolveModelFrom(context.getLanguage());
- String normalized = tokens.stream().map(model.tokenId2Token::get).collect(Collectors.joining());
- return denormalize(normalized);
+ StringBuilder sb = new StringBuilder();
+ for (var tokenId : tokens) {
+ var token = model.tokenId2Token.get(tokenId);
+ var skip = skipControl && token.type() == TokenType.control;
+ if ( ! skip) {
+ sb.append(token.text());
+ }
+ }
+ return denormalize(sb.toString());
}
/**