summaryrefslogtreecommitdiffstats
path: root/linguistics-components
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2023-02-10 14:06:29 +0100
committerLester Solbakken <lesters@oath.com>2023-02-10 14:06:29 +0100
commitf5118dcd8b04293cf65434f1509fa0e06833492b (patch)
treec6d77c5a81c7fbfe697e219897a459879871ef0e /linguistics-components
parentf62bb48baf715609606faa82a6119012b8a727de (diff)
Add decoding of sentencepiece token sequence to text
Diffstat (limited to 'linguistics-components')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java6
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java19
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java1
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java8
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java8
5 files changed, 40 insertions, 2 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
index 74f300057dc..b4f216c4c9c 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
@@ -7,6 +7,8 @@ import sentencepiece.SentencepieceModel;
import java.io.IOException;
import java.nio.file.Path;
+import java.util.HashMap;
+import java.util.Map;
/**
* A SentencePiece model
@@ -20,6 +22,7 @@ final class Model {
final float minScore;
final float maxScore;
final Trie tokens = new Trie();
+ final Map<Integer, String> tokenId2Token = new HashMap<>();
Model(Language language, Path path) {
try {
@@ -31,6 +34,7 @@ final class Model {
for (int i = 0; i < sp.getPiecesCount(); i++) {
var piece = sp.getPieces(i);
tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore());
+ tokenId2Token.put(i, piece.getPiece());
minScore = Math.min(piece.getScore(), minScore);
maxScore = Math.max(piece.getScore(), maxScore);
}
@@ -48,7 +52,7 @@ final class Model {
case NORMAL : return TokenType.text;
case CONTROL : return TokenType.control;
case UNUSED : return TokenType.unused;
- default : throw new IllegalArgumentException("Unknkown token type " + type);
+ default : throw new IllegalArgumentException("Unknown token type " + type);
}
}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index e3a0457d77a..474add27023 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -89,6 +89,20 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
}
/**
+ * Converts a list of token ids into text. The opposite operation of embed.
+ *
+ * @param tokens the list of tokens to decode to a string
+ * @param context the context which specifies the language used to select a model
+ * @return the string formed by decoding the tokens back to their string representation
+ */
+ @Override
+ public String decode(List<Integer> tokens, Embedder.Context context) {
+ Model model = resolveModelFrom(context.getLanguage());
+ String normalized = tokens.stream().map(model.tokenId2Token::get).collect(Collectors.joining());
+ return denormalize(normalized);
+ }
+
+ /**
* <p>Embeds text into a tensor.</p>
*
* <p>If the tensor type is indexed 1-d (bound or unbound) this will return a tensor containing the token ids in the order
@@ -137,6 +151,11 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
return b.toString();
}
+ public String denormalize(String s) {
+ String result = s.replace(SentencePieceAlgorithm.spaceSymbol, ' ');
+ return result.startsWith(" ") ? result.substring(1) : result; // Skip the first space (left by the leading space symbol); startsWith, unlike charAt(0), is safe when the text is empty
+ }
+
public static final class Builder {
private final Map<Language, Path> models = new EnumMap<>(Language.class);
diff --git a/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java
index a724d6542d5..0643eee7094 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/wordpiece/Model.java
@@ -20,7 +20,6 @@ import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
-import java.util.stream.Collectors;
/**
* A WordPiece embedder "model" - just a vocabulary of strings with a fixed id (index).
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index 2fbafb23485..8a7af01a8a3 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -45,6 +45,14 @@ public class SentencePieceTest {
}
@Test
+ public void testEnglishDecoding() {
+ var tester = new EmbedderTester(new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build());
+ tester.assertDecoded("this is a sentence");
+ tester.assertDecoded("hello, world!");
+ tester.assertDecoded(")(/&#(small)/ \"in quotes\")");
+ }
+
+ @Test
public void testNoCollapse() {
var builder = new SentencePieceEmbedder.Builder()
.addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
diff --git a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
index 638efba2480..eddbb70494d 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
@@ -59,4 +59,12 @@ public class EmbedderTester {
expectedSegments, ((Segmenter)embedder).segment(input, language).toArray());
}
+ public void assertDecoded(String input) {
+ var context = new Embedder.Context("test");
+ var tokens = embedder.embed(input, context);
+ var result = embedder.decode(tokens, context);
+ assertEquals(input, result);
+ }
+
+
}