diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-11 11:07:23 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-11 16:41:35 +0200 |
commit | ce7dd2c983a8840981786eef95a9cc4741487be7 (patch) | |
tree | 8aa4140fb8fd7ec8a95d425adda4119d34cf46d3 /linguistics-components | |
parent | 386e4198d8459803eec0ead6ad81a821737082a7 (diff) |
Make HF tokenizer a separate embedder
Diffstat (limited to 'linguistics-components')
8 files changed, 300 insertions, 2 deletions
diff --git a/linguistics-components/pom.xml b/linguistics-components/pom.xml index ad4cbd6ce22..5031ad73556 100644 --- a/linguistics-components/pom.xml +++ b/linguistics-components/pom.xml @@ -19,12 +19,42 @@ <artifactId>protobuf-java</artifactId> </dependency> <dependency> - <groupId>junit</groupId> - <artifactId>junit</artifactId> + <groupId>ai.djl.huggingface</groupId> + <artifactId>tokenizers</artifactId> + <version>0.22.1</version> + <exclusions> + <exclusion> + <groupId>com.google.code.gson</groupId> + <artifactId>gson</artifactId> + </exclusion> + <exclusion> + <groupId>net.java.dev.jna</groupId> + <artifactId>jna</artifactId> + </exclusion> + <exclusion> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + </exclusion> + </exclusions> + </dependency> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-engine</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.junit.vintage</groupId> + <artifactId>junit-vintage-engine</artifactId> <scope>test</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>jdisc_core</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>annotations</artifactId> <version>${project.version}</version> <scope>provided</scope> diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java new file mode 100644 index 00000000000..ddb098c911d --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java @@ -0,0 +1,54 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.huggingface; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * @author bjorncs + */ +public record Encoding( + List<Long> ids, List<Long> typeIds, List<String> tokens, List<Long> wordIds, List<Long> attentionMask, + List<Long> specialTokenMask, List<CharSpan> charTokenSpans, List<Encoding> overflowing) { + + public record CharSpan(int start, int end) { + public static final CharSpan NONE = new CharSpan(-1, -1); + static CharSpan from(ai.djl.huggingface.tokenizers.jni.CharSpan s) { + if (s == null) return NONE; + return new CharSpan(s.getStart(), s.getEnd()); + } + public boolean isNone() { return this.equals(NONE); } + } + + public Encoding { + ids = List.copyOf(ids); + typeIds = List.copyOf(typeIds); + tokens = List.copyOf(tokens); + wordIds = List.copyOf(wordIds); + attentionMask = List.copyOf(attentionMask); + specialTokenMask = List.copyOf(specialTokenMask); + charTokenSpans = List.copyOf(charTokenSpans); + overflowing = List.copyOf(overflowing); + } + + static Encoding from(ai.djl.huggingface.tokenizers.Encoding e) { + return new Encoding( + toList(e.getIds()), + toList(e.getTypeIds()), + List.of(e.getTokens()), + toList(e.getWordIds()), + toList(e.getAttentionMask()), + toList(e.getSpecialTokenMask()), + Arrays.stream(e.getCharTokenSpans()).map(CharSpan::from).toList(), + Arrays.stream(e.getOverflowing()).map(Encoding::from).toList()); + } + + private static List<Long> toList(long[] array) { + if (array == null) return List.of(); + var list = new ArrayList<Long>(array.length); + for (long e : array) list.add(e); + return list; + } +} diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java new file mode 100644 index 00000000000..56fba370470 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -0,0 +1,109 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.huggingface; + +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Segmenter; +import com.yahoo.language.tools.Embed; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collection; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; + +import static com.yahoo.yolean.Exceptions.uncheck; + +/** + * {@link Embedder}/{@link Segmenter} using Deep Java Library's HuggingFace Tokenizer. + * + * @author bjorncs + */ +public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable { + + private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class); + + @Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); } + + private HuggingFaceTokenizer(Builder b) { + var original = Thread.currentThread().getContextClassLoader(); + Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); + try { + b.models.forEach((language, path) -> { + models.put(language, + uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() + .optTokenizerPath(path) + .build())); + }); + } finally { + Thread.currentThread().setContextClassLoader(original); + } + } + + @Override + public List<Integer> embed(String text, Context ctx) { + var encoding = resolve(ctx.getLanguage()).encode(text); + var ids = encoding.getIds(); + var result = new ArrayList<Integer>(ids.length-2); // heuristic: -2 to exclude start/end tokens + for (int i = 0; i < ids.length; i++) + if (encoding.getSpecialTokenMask()[i] == 0) result.add(Math.toIntExact(ids[i])); + return result; + } + + @Override + public Tensor embed(String text, Context ctx, TensorType type) { + return Embed.asTensor(text, this, ctx, type); + } + + @Override + public List<String> segment(String input, Language language) { + var encoding = resolve(language).encode(input); + var tokens = encoding.getTokens(); + var result = new ArrayList<String>(tokens.length-2); // heuristic: -2 to exclude start/end tokens + for (int i = 0; i < tokens.length; i++) + if (encoding.getSpecialTokenMask()[i] == 0) result.add(tokens[i]); + return result; + } + + @Override + public String decode(List<Integer> tokens, Context ctx) { + return resolve(ctx.getLanguage()).decode(toArray(tokens)); + } + + public Encoding encode(String text) { return encode(text, Language.UNKNOWN); } + public Encoding encode(String text, Language language) { return Encoding.from(resolve(language).encode(text)); } + public String decode(List<Long> tokens) { return decode(tokens, Language.UNKNOWN); } + public String decode(List<Long> tokens, Language language) { return resolve(language).decode(toArray(tokens)); } + + @Override public void close() { models.forEach((__, model) -> model.close()); } + + private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) { + // Disregard language if there is default model + if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); + if (models.containsKey(language)) return models.get(language); + throw new IllegalArgumentException("No model for language " + language); + } + + private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); } + + public static final class Builder { + private final Map<Language, Path> models = new EnumMap<>(Language.class); + + public Builder() {} + public Builder(HuggingFaceTokenizerConfig cfg) { + for (var model : cfg.model()) + addModel(Language.fromLanguageTag(model.language()), model.path()); + } + + public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } + public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); } + public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); } + } + +}
\ No newline at end of file diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java new file mode 100644 index 00000000000..7cec01ffed6 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java @@ -0,0 +1,9 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +/** + * @author bjorncs + */ +@ExportPackage +package com.yahoo.language.huggingface; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def new file mode 100644 index 00000000000..9d0ab65c28f --- /dev/null +++ b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +namespace=language.huggingface + +# The language a model is for, one of the language tags in com.yahoo.language.Language. +# Use "unknown" for models to be used with any language. +model[].language string +# The path to the model relative to the application package root +model[].path path diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java new file mode 100644 index 00000000000..f9fa0ef2afe --- /dev/null +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -0,0 +1,87 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.huggingface; + +import com.yahoo.language.tools.EmbedderTester; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.zip.GZIPInputStream; + +/** + * @author bjorncs + */ +class HuggingFaceTokenizerTest { + + @TempDir Path tmp; + + @Test + void bert_tokenizer() throws IOException { + try (var tokenizer = createTokenizer(tmp, "bert-base-uncased")) { + var tester = new EmbedderTester(tokenizer); + tester.assertSegmented("what was the impact of the manhattan project", + "what", "was", "the", "impact", "of", "the", "manhattan", "project"); + tester.assertSegmented("overcommunication", "over", "##com", "##mun", "##ication"); + tester.assertEmbedded("what was the impact of the manhattan project", + "tensor(x[8])", + 2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622); + tester.assertDecoded("what was the impact of the manhattan project"); + } + } + + @Test + void tokenizes_using_paraphrase_multilingual_mpnet_base_v2() throws IOException { + try (var tokenizer = createTokenizer(tmp, "paraphrase-multilingual-mpnet-base-v2")) { + var tester = new EmbedderTester(tokenizer); + tester.assertSegmented("h", "▁h"); + tester.assertSegmented("he", "▁he"); + tester.assertSegmented("hel", "▁hel"); + tester.assertSegmented("hello", "▁hell", "o"); + tester.assertSegmented("hei", "▁hei"); + tester.assertSegmented("hei you", "▁hei", "▁you"); + tester.assertSegmented("hei you", "▁hei", "▁you"); + tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); + tester.assertSegmented("hello world!", "▁hell", "o", "▁world", "!"); + tester.assertSegmented("Hello, world!", "▁Hello", ",", "▁world", "!"); + tester.assertSegmented("HELLO, world!", "▁H", "ELLO", ",", "▁world", "!"); + tester.assertSegmented("KHJKJHHKJHHSH", "▁KH", "JK", "J", "H", "HK", "J", "HH", "SH"); + tester.assertSegmented("KHJKJHHKJHHSH hello", "▁KH", "JK", "J", "H", "HK", "J", "HH", "SH", "▁hell", "o"); + tester.assertSegmented(" hello ", "▁hell", "o"); + tester.assertSegmented(")(/&#()/\"\")", "▁", ")(", "/", "&#", "(", ")", "/", "\"", "\")"); + tester.assertSegmented(")(/&#(small)/\"in quotes\")", "▁", ")(", "/", "&#", "(", "s", "mall", ")", "/", "\"", "in", "▁quote", "s", "\")"); + tester.assertSegmented("x.400AS", "▁x", ".", "400", "AS"); + tester.assertSegmented("A normal sentence. Yes one more.", "▁A", "▁normal", "▁sentence", ".", "▁Yes", "▁one", "▁more", "."); + + tester.assertEmbedded("hello, world!", "tensor(d[10])", 33600, 31, 4, 8999, 38); + tester.assertEmbedded("Hello, world!", "tensor(d[10])", 35378, 4, 8999, 38); + tester.assertEmbedded("hello, world!", "tensor(d[2])", 33600, 31, 4, 8999, 38); + + tester.assertDecoded("this is a sentence"); + tester.assertDecoded("hello, world!"); + tester.assertDecoded(")(/&#(small)/ \"in quotes\")"); + } + } + + private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException { + return new HuggingFaceTokenizer.Builder() + .addDefaultModel(decompressModelFile(tmp, Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)))) + .build(); + } + + private static Path decompressModelFile(Path tmp, Path source) throws IOException { + Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", "")); + try (InputStream in = new GZIPInputStream(Files.newInputStream(source)); + OutputStream out = Files.newOutputStream(destination, StandardOpenOption.CREATE)) { + in.transferTo(out); + } + return destination; + } + +}
\ No newline at end of file diff --git a/linguistics-components/src/test/models/huggingface/bert-base-uncased.json.gz b/linguistics-components/src/test/models/huggingface/bert-base-uncased.json.gz Binary files differnew file mode 100644 index 00000000000..7d0541849f7 --- /dev/null +++ b/linguistics-components/src/test/models/huggingface/bert-base-uncased.json.gz diff --git a/linguistics-components/src/test/models/huggingface/paraphrase-multilingual-mpnet-base-v2.json.gz b/linguistics-components/src/test/models/huggingface/paraphrase-multilingual-mpnet-base-v2.json.gz Binary files differnew file mode 100644 index 00000000000..7b61a27198c --- /dev/null +++ b/linguistics-components/src/test/models/huggingface/paraphrase-multilingual-mpnet-base-v2.json.gz |