diff options
author | Lester Solbakken <lesters@oath.com> | 2023-02-10 14:06:29 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2023-02-10 14:06:29 +0100 |
commit | f5118dcd8b04293cf65434f1509fa0e06833492b (patch) | |
tree | c6d77c5a81c7fbfe697e219897a459879871ef0e /linguistics-components/src | |
parent | f62bb48baf715609606faa82a6119012b8a727de (diff) |
Add decoding of sentencepiece token sequence to text
Diffstat (limited to 'linguistics-components/src')
5 files changed, 40 insertions, 2 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java index 74f300057dc..b4f216c4c9c 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java @@ -7,6 +7,8 @@ import sentencepiece.SentencepieceModel; import java.io.IOException; import java.nio.file.Path; +import java.util.HashMap; +import java.util.Map; /** * A SentencePiece model @@ -20,6 +22,7 @@ final class Model { final float minScore; final float maxScore; final Trie tokens = new Trie(); + final Map<Integer, String> tokenId2Token = new HashMap<>(); Model(Language language, Path path) { try { @@ -31,6 +34,7 @@ final class Model { for (int i = 0; i < sp.getPiecesCount(); i++) { var piece = sp.getPieces(i); tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore()); + tokenId2Token.put(i, piece.getPiece()); minScore = Math.min(piece.getScore(), minScore); maxScore = Math.max(piece.getScore(), maxScore); } @@ -48,7 +52,7 @@ final class Model { case NORMAL : return TokenType.text; case CONTROL : return TokenType.control; case UNUSED : return TokenType.unused; - default : throw new IllegalArgumentException("Unknkown token type " + type); + default : throw new IllegalArgumentException("Unknown token type " + type); } } diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java index e3a0457d77a..474add27023 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java +++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java @@ -89,6 +89,20 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { } /** + * Converts the list of token id's into a text. The opposite operation of embed. + * + * @param tokens the list of tokens to decode to a string + * @param context the context which specifies the language used to select a model + * @return the string formed by decoding the tokens back to their string repreesentation + */ + @Override + public String decode(List<Integer> tokens, Embedder.Context context) { + Model model = resolveModelFrom(context.getLanguage()); + String normalized = tokens.stream().map(model.tokenId2Token::get).collect(Collectors.joining()); + return denormalize(normalized); + } + + /** * <p>Embeds text into a tensor.</p> * * <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order @@ -137,6 +151,11 @@ public class SentencePieceEmbedder implements Segmenter, Embedder { return b.toString(); } + public String denormalize(String s) { + String result = s.replace(SentencePieceAlgorithm.spaceSymbol, ' '); + return result.charAt(0) == ' ' ? result.substring(1) : result; // Skip first space + } + public static final class Builder { private final Map<Language, Path> models = new EnumMap<>(Language.class); diff --git a/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java index a724d6542d5..0643eee7094 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java +++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java @@ -20,7 +20,6 @@ import java.util.List; import java.util.Map; import java.util.NavigableMap; import java.util.TreeMap; -import java.util.stream.Collectors; /** * A WordPiece embedder "model" - just a vocabulary of strings with a fixed id (index). diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java index 2fbafb23485..8a7af01a8a3 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java @@ -45,6 +45,14 @@ public class SentencePieceTest { } @Test + public void testEnglishDecoding() { + var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build()); + tester.assertDecoded("this is a sentence"); + tester.assertDecoded("hello, world!"); + tester.assertDecoded(")(/&#(small)/ \"in quotes\")"); + } + + @Test public void testNoCollapse() { var builder = new SentencePieceEmbedder.Builder() .addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath()) diff --git a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java index 638efba2480..eddbb70494d 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java +++ b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java @@ -59,4 +59,12 @@ public class EmbedderTester { expectedSegments, ((Segmenter)embedder).segment(input, language).toArray()); } + public void assertDecoded(String input) { + var context = new Embedder.Context("test"); + var tokens = embedder.embed(input, context); + var result = embedder.decode(tokens, context); + assertEquals(input, result); + } + + } |