aboutsummaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/main/java')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java11
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java17
2 files changed, 22 insertions, 6 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
index b4f216c4c9c..28a1b9d2930 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
@@ -22,7 +22,8 @@ final class Model {
final float minScore;
final float maxScore;
final Trie tokens = new Trie();
- final Map<Integer, String> tokenId2Token = new HashMap<>();
+ final Map<Integer, Token> tokenId2Token = new HashMap<>();
+
Model(Language language, Path path) {
try {
@@ -33,8 +34,10 @@ final class Model {
float maxScore = Float.MIN_VALUE;
for (int i = 0; i < sp.getPiecesCount(); i++) {
var piece = sp.getPieces(i);
- tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore());
- tokenId2Token.put(i, piece.getPiece());
+ var type = toTokenType(piece.getType());
+ var word = piece.getPiece();
+ tokens.add(type, i, word, piece.getScore());
+ tokenId2Token.put(i, new Token(word, type));
minScore = Math.min(piece.getScore(), minScore);
maxScore = Math.max(piece.getScore(), maxScore);
}
@@ -61,4 +64,6 @@ final class Model {
return "SentencePiece model for " + language + ": '" + source + "'";
}
+ record Token(String text, TokenType type) { }
+
}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index 474add27023..b4d542b0d82 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -93,13 +93,24 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
*
* @param tokens the list of tokens to decode to a string
* @param context the context which specifies the language used to select a model
- * @return the string formed by decoding the tokens back to their string repreesentation
+ * @return the string formed by decoding the tokens back to their string representation
*/
@Override
public String decode(List<Integer> tokens, Embedder.Context context) {
+ return decode(tokens, context, false);
+ }
+
+ public String decode(List<Integer> tokens, Embedder.Context context, boolean skipControl) {
Model model = resolveModelFrom(context.getLanguage());
- String normalized = tokens.stream().map(model.tokenId2Token::get).collect(Collectors.joining());
- return denormalize(normalized);
+ StringBuilder sb = new StringBuilder();
+ for (var tokenId : tokens) {
+ var token = model.tokenId2Token.get(tokenId);
+ var skip = skipControl && token.type() == TokenType.control;
+ if ( ! skip) {
+ sb.append(token.text());
+ }
+ }
+ return denormalize(sb.toString());
}
/**