summaryrefslogtreecommitdiffstats
path: root/linguistics-components
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-11 15:41:00 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-11 16:41:54 +0200
commitfe63824738fc1892221311e7ddd777efcb209f5b (patch)
treedc7d3ce16c4e56ab7cbbc941f2cb9f162d6dacb2 /linguistics-components
parentae700d12753e1a81de4def087d2f64607f0361df (diff)
Disable special tokens by default
Diffstat (limited to 'linguistics-components')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java19
-rw-r--r--linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def2
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java1
3 files changed, 10 insertions, 12 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index dd53bd1c695..b92e0678970 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -13,7 +13,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.nio.file.Path;
-import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
@@ -41,6 +41,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
models.put(language,
uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
.build()));
});
} finally {
@@ -51,11 +52,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
@Override
public List<Integer> embed(String text, Context ctx) {
var encoding = resolve(ctx.getLanguage()).encode(text);
- var ids = encoding.getIds();
- var result = new ArrayList<Integer>(ids.length-2); // heuristic: -2 to exclude start/end tokens
- for (int i = 0; i < ids.length; i++)
- if (encoding.getSpecialTokenMask()[i] == 0) result.add(Math.toIntExact(ids[i]));
- return result;
+ return Arrays.stream(encoding.getIds()).mapToInt(Math::toIntExact).boxed().toList();
}
@Override
@@ -65,12 +62,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
@Override
public List<String> segment(String input, Language language) {
- var encoding = resolve(language).encode(input);
- var tokens = encoding.getTokens();
- var result = new ArrayList<String>(tokens.length-2); // heuristic: -2 to exclude start/end tokens
- for (int i = 0; i < tokens.length; i++)
- if (encoding.getSpecialTokenMask()[i] == 0) result.add(tokens[i]);
- return result;
+ return List.of(resolve(language).encode(input).getTokens());
}
@Override
@@ -96,15 +88,18 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
public static final class Builder {
private final Map<Language, Path> models = new EnumMap<>(Language.class);
+ private Boolean addSpecialTokens;
public Builder() {}
public Builder(HuggingFaceTokenizerConfig cfg) {
for (var model : cfg.model())
addModel(Language.fromLanguageTag(model.language()), model.path());
+ addSpecialTokens(cfg.addSpecialTokens());
}
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); }
+ public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; }
public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); }
}
diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
index 9d0ab65c28f..a3e54ea38da 100644
--- a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
+++ b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
@@ -7,3 +7,5 @@ namespace=language.huggingface
model[].language string
# The path to the model relative to the application package root
model[].path path
+
+addSpecialTokens bool default=true \ No newline at end of file
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
index f9fa0ef2afe..c79ecbfbfbe 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
@@ -71,6 +71,7 @@ class HuggingFaceTokenizerTest {
private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException {
return new HuggingFaceTokenizer.Builder()
+ .addSpecialTokens(false)
.addDefaultModel(decompressModelFile(tmp, Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model))))
.build();
}