summary refs log tree commit diff stats
path: root/linguistics-components
diff options
context:
space:
mode:
author Lester Solbakken <lesters@oath.com> 2023-02-10 15:18:18 +0100
committer Lester Solbakken <lesters@oath.com> 2023-02-10 15:18:18 +0100
commit a73052a800754c9e03dd4b1212e4f1c4d9e5ac7b (patch)
tree f640fc1e8a80aff4d77d6d2b56ac4524e4582c91 /linguistics-components
parent 2c261b2b9718d344690a7d202156a054776695d3 (diff)
Add skipping of control tokens
Diffstat (limited to 'linguistics-components')
-rw-r--r-- linguistics-components/abi-spec.json 1
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java 11
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java 17
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java 12
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java 1
5 files changed, 35 insertions, 7 deletions
diff --git a/linguistics-components/abi-spec.json b/linguistics-components/abi-spec.json
index f010ffd3e7c..4b713afba83 100644
--- a/linguistics-components/abi-spec.json
+++ b/linguistics-components/abi-spec.json
@@ -183,6 +183,7 @@
"public java.util.List segment(java.lang.String, com.yahoo.language.Language)",
"public java.util.List embed(java.lang.String, com.yahoo.language.process.Embedder$Context)",
"public java.lang.String decode(java.util.List, com.yahoo.language.process.Embedder$Context)",
+ "public java.lang.String decode(java.util.List, com.yahoo.language.process.Embedder$Context, boolean)",
"public com.yahoo.tensor.Tensor embed(java.lang.String, com.yahoo.language.process.Embedder$Context, com.yahoo.tensor.TensorType)",
"public java.lang.String normalize(java.lang.String)",
"public java.lang.String denormalize(java.lang.String)"
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
index b4f216c4c9c..28a1b9d2930 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/Model.java
@@ -22,7 +22,8 @@ final class Model {
final float minScore;
final float maxScore;
final Trie tokens = new Trie();
- final Map<Integer, String> tokenId2Token = new HashMap<>();
+ final Map<Integer, Token> tokenId2Token = new HashMap<>();
+
Model(Language language, Path path) {
try {
@@ -33,8 +34,10 @@ final class Model {
float maxScore = Float.MIN_VALUE;
for (int i = 0; i < sp.getPiecesCount(); i++) {
var piece = sp.getPieces(i);
- tokens.add(toTokenType(piece.getType()), i, piece.getPiece(), piece.getScore());
- tokenId2Token.put(i, piece.getPiece());
+ var type = toTokenType(piece.getType());
+ var word = piece.getPiece();
+ tokens.add(type, i, word, piece.getScore());
+ tokenId2Token.put(i, new Token(word, type));
minScore = Math.min(piece.getScore(), minScore);
maxScore = Math.max(piece.getScore(), maxScore);
}
@@ -61,4 +64,6 @@ final class Model {
return "SentencePiece model for " + language + ": '" + source + "'";
}
+ record Token(String text, TokenType type) { }
+
}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
index 474add27023..b4d542b0d82 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/sentencepiece/SentencePieceEmbedder.java
@@ -93,13 +93,24 @@ public class SentencePieceEmbedder implements Segmenter, Embedder {
*
* @param tokens the list of tokens to decode to a string
* @param context the context which specifies the language used to select a model
- * @return the string formed by decoding the tokens back to their string repreesentation
+ * @return the string formed by decoding the tokens back to their string representation
*/
@Override
public String decode(List<Integer> tokens, Embedder.Context context) {
+ return decode(tokens, context, false);
+ }
+
+ public String decode(List<Integer> tokens, Embedder.Context context, boolean skipControl) {
Model model = resolveModelFrom(context.getLanguage());
- String normalized = tokens.stream().map(model.tokenId2Token::get).collect(Collectors.joining());
- return denormalize(normalized);
+ StringBuilder sb = new StringBuilder();
+ for (var tokenId : tokens) {
+ var token = model.tokenId2Token.get(tokenId);
+ var skip = skipControl && token.type() == TokenType.control;
+ if ( ! skip) {
+ sb.append(token.text());
+ }
+ }
+ return denormalize(sb.toString());
}
/**
diff --git a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
index 8a7af01a8a3..daa31f8773b 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/sentencepiece/SentencePieceTest.java
@@ -3,11 +3,14 @@
package com.yahoo.language.sentencepiece;
import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
import com.yahoo.language.tools.EmbedderTester;
import org.junit.Test;
import java.io.File;
+import static org.junit.Assert.assertEquals;
+
/**
* @author bratseth
*/
@@ -53,6 +56,15 @@ public class SentencePieceTest {
}
@Test
+ public void testSkipControl() {
+ var embedder = new SentencePieceEmbedder.Builder("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").build();
+ var context = new Embedder.Context("test");
+ var tokens = embedder.embed("<s>hello</s>, world!", context);
+ assertEquals("<s>hello</s>, world!", embedder.decode(tokens, context, false));
+ assertEquals("hello, world!", embedder.decode(tokens, context, true));
+ }
+
+ @Test
public void testNoCollapse() {
var builder = new SentencePieceEmbedder.Builder()
.addDefaultModel(new File("src/test/models/sentencepiece/en.wiki.bpe.vs10000.model").toPath())
diff --git a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
index eddbb70494d..a403b1ee943 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/tools/EmbedderTester.java
@@ -66,5 +66,4 @@ public class EmbedderTester {
assertEquals(input, result);
}
-
}