diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-11 15:41:00 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-11 16:41:54 +0200 |
commit | fe63824738fc1892221311e7ddd777efcb209f5b (patch) | |
tree | dc7d3ce16c4e56ab7cbbc941f2cb9f162d6dacb2 /linguistics-components | |
parent | ae700d12753e1a81de4def087d2f64607f0361df (diff) |
Disable special tokens by default
Diffstat (limited to 'linguistics-components')
3 files changed, 10 insertions, 12 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index dd53bd1c695..b92e0678970 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -13,7 +13,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.nio.file.Path; -import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.EnumMap; import java.util.List; @@ -41,6 +41,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, models.put(language, uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() .optTokenizerPath(path) + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) .build())); }); } finally { @@ -51,11 +52,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, @Override public List<Integer> embed(String text, Context ctx) { var encoding = resolve(ctx.getLanguage()).encode(text); - var ids = encoding.getIds(); - var result = new ArrayList<Integer>(ids.length-2); // heuristic: -2 to exclude start/end tokens - for (int i = 0; i < ids.length; i++) - if (encoding.getSpecialTokenMask()[i] == 0) result.add(Math.toIntExact(ids[i])); - return result; + return Arrays.stream(encoding.getIds()).mapToInt(Math::toIntExact).boxed().toList(); } @Override @@ -65,12 +62,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, @Override public List<String> segment(String input, Language language) { - var encoding = resolve(language).encode(input); - var tokens = encoding.getTokens(); - var result = new ArrayList<String>(tokens.length-2); // heuristic: -2 to exclude start/end tokens - for (int i = 0; i < tokens.length; i++) - if (encoding.getSpecialTokenMask()[i] == 0) result.add(tokens[i]); - return result; + return List.of(resolve(language).encode(input).getTokens()); } @Override @@ -96,15 +88,18 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public static final class Builder { private final Map<Language, Path> models = new EnumMap<>(Language.class); + private Boolean addSpecialTokens; public Builder() {} public Builder(HuggingFaceTokenizerConfig cfg) { for (var model : cfg.model()) addModel(Language.fromLanguageTag(model.language()), model.path()); + addSpecialTokens(cfg.addSpecialTokens()); } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); } + public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; } public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); } } diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def index 9d0ab65c28f..a3e54ea38da 100644 --- a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def +++ b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def @@ -7,3 +7,5 @@ namespace=language.huggingface model[].language string # The path to the model relative to the application package root model[].path path + +addSpecialTokens bool default=true
\ No newline at end of file diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java index f9fa0ef2afe..c79ecbfbfbe 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -71,6 +71,7 @@ class HuggingFaceTokenizerTest { private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException { return new HuggingFaceTokenizer.Builder() + .addSpecialTokens(false) .addDefaultModel(decompressModelFile(tmp, Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)))) .build(); } |