From a73052a800754c9e03dd4b1212e4f1c4d9e5ac7b Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Fri, 10 Feb 2023 15:18:18 +0100 Subject: Add skipping of control tokens --- .../java/com/yahoo/language/sentencepiece/Model.java | 11 ++++++++--- .../language/sentencepiece/SentencePieceEmbedder.java | 17 ++++++++++++++--- .../yahoo/language/sentencepiece/SentencePieceTest.java | 12 ++++++++++++ .../java/com/yahoo/language/tools/EmbedderTester.java | 1 - 4 files changed, 34 insertions(+), 7 deletions(-) (limited to 'linguistics-components/src') diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java index b4f216c4c9c..28a1b9d2930 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java @@ -22,7 +22,8 @@ final class Model { final float minScore; final float maxScore; final Trie tokens = new Trie(); - final Map tokenId2Token = new HashMap<>(); + final Map tokenId2Token = new HashMap<>(); + Model(Language language, Path path) { try { @@ -33,8 +34,10 @@ final class Model { float maxScore = Float.MIN_VALUE; for (int i = 0; i < sp.getPiecesCount(); i++) { var piece = sp.getPieces(i); - tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore()); - tokenId2Token.put(i, piece.getPiece()); + var type = toTokenType(piece.getType()); + var word = piece.getPiece(); + tokens.add(type, i, word, piece.getScore()); + tokenId2Token.put(i, new Token(word, type)); minScore = Math.min(piece.getScore(), minScore); maxScore = Math.max(piece.getScore(), maxScore); } @@ -61,4 +64,6 @@ final class Model { return "SentencePiece model for " + language + ": '" + source + "'"; } + record Token(String text, TokenType type) { } + } diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index 474add27023..b4d542b0d82 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -93,13 +93,24 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { * * @param tokens the list of tokens to decode to a string * @param context the context which specifies the language used to select a model - * @return the string formed by decoding the tokens back to their string repreesentation + * @return the string formed by decoding the tokens back to their string representation */ @Override public String decode(List tokens, Embedder.Context context) { + return decode(tokens, context, false); + } + + public String decode(List tokens, Embedder.Context context, boolean skipControl) { Model model = resolveModelFrom(context.getLanguage()); - String normalized = tokens.stream().map(model.tokenId2Token::get).collect(Collectors.joining()); - return denormalize(normalized); + StringBuilder sb = new StringBuilder(); + for (var tokenId : tokens) { + var token = model.tokenId2Token.get(tokenId); + var skip = skipControl && token.type() == TokenType.control; + if ( ! skip) { + sb.append(token.text()); + } + } + return denormalize(sb.toString()); } /** diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index 8a7af01a8a3..daa31f8773b 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -3,11 +3,14 @@ package com.yahoo.language.sentencepiece; import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; import com.yahoo.language.tools.EmbedderTester; import org.junit.Test; import java.io.File; +import static org.junit.Assert.assertEquals; + /** * @author bratseth */ @@ -52,6 +55,15 @@ public class SentencePieceTest { tester.assertDecoded(")(/&#(small)/ \"in quotes\")"); } + @Test + public void testSkipControl() { + var embedder = new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build(); + var context = new Embedder.Context("test"); + var tokens = embedder.embed("hello, world!", context); + assertEquals("hello, world!", embedder.decode(tokens, context, false)); + assertEquals("hello, world!", embedder.decode(tokens, context, true)); + } + @Test public void testNoCollapse() { var builder = new SentencePieceEmbedder.Builder() diff --git a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java index eddbb70494d..a403b1ee943 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java +++ b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java @@ -66,5 +66,4 @@ public class EmbedderTester { assertEquals(input, result); } - } -- cgit v1.2.3