summaryrefslogtreecommitdiffstats
path: root/linguistics-components/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'linguistics-components/src/main')
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java54
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java109
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java9
-rw-r--r--linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def9
4 files changed, 181 insertions, 0 deletions
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java
new file mode 100644
index 00000000000..ddb098c911d
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java
@@ -0,0 +1,54 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.huggingface;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * @author bjorncs
+ */
+public record Encoding(
+ List<Long> ids, List<Long> typeIds, List<String> tokens, List<Long> wordIds, List<Long> attentionMask,
+ List<Long> specialTokenMask, List<CharSpan> charTokenSpans, List<Encoding> overflowing) {
+
+ public record CharSpan(int start, int end) {
+ public static final CharSpan NONE = new CharSpan(-1, -1);
+ static CharSpan from(ai.djl.huggingface.tokenizers.jni.CharSpan s) {
+ if (s == null) return NONE;
+ return new CharSpan(s.getStart(), s.getEnd());
+ }
+ public boolean isNone() { return this.equals(NONE); }
+ }
+
+ public Encoding {
+ ids = List.copyOf(ids);
+ typeIds = List.copyOf(typeIds);
+ tokens = List.copyOf(tokens);
+ wordIds = List.copyOf(wordIds);
+ attentionMask = List.copyOf(attentionMask);
+ specialTokenMask = List.copyOf(specialTokenMask);
+ charTokenSpans = List.copyOf(charTokenSpans);
+ overflowing = List.copyOf(overflowing);
+ }
+
+ static Encoding from(ai.djl.huggingface.tokenizers.Encoding e) {
+ return new Encoding(
+ toList(e.getIds()),
+ toList(e.getTypeIds()),
+ List.of(e.getTokens()),
+ toList(e.getWordIds()),
+ toList(e.getAttentionMask()),
+ toList(e.getSpecialTokenMask()),
+ Arrays.stream(e.getCharTokenSpans()).map(CharSpan::from).toList(),
+ Arrays.stream(e.getOverflowing()).map(Encoding::from).toList());
+ }
+
+ private static List<Long> toList(long[] array) {
+ if (array == null) return List.of();
+ var list = new ArrayList<Long>(array.length);
+ for (long e : array) list.add(e);
+ return list;
+ }
+}
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
new file mode 100644
index 00000000000..56fba370470
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -0,0 +1,109 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.huggingface;
+
+import com.yahoo.component.AbstractComponent;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.language.process.Segmenter;
+import com.yahoo.language.tools.Embed;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.EnumMap;
+import java.util.List;
+import java.util.Map;
+
+import static com.yahoo.yolean.Exceptions.uncheck;
+
+/**
+ * {@link Embedder}/{@link Segmenter} using Deep Java Library's HuggingFace Tokenizer.
+ *
+ * @author bjorncs
+ */
+public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable {
+
+ private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class);
+
+ @Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); }
+
+ private HuggingFaceTokenizer(Builder b) {
+ var original = Thread.currentThread().getContextClassLoader();
+ Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
+ try {
+ b.models.forEach((language, path) -> {
+ models.put(language,
+ uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
+ .optTokenizerPath(path)
+ .build()));
+ });
+ } finally {
+ Thread.currentThread().setContextClassLoader(original);
+ }
+ }
+
+ @Override
+ public List<Integer> embed(String text, Context ctx) {
+ var encoding = resolve(ctx.getLanguage()).encode(text);
+ var ids = encoding.getIds();
+ var result = new ArrayList<Integer>(ids.length-2); // heuristic: -2 to exclude start/end tokens
+ for (int i = 0; i < ids.length; i++)
+ if (encoding.getSpecialTokenMask()[i] == 0) result.add(Math.toIntExact(ids[i]));
+ return result;
+ }
+
+ @Override
+ public Tensor embed(String text, Context ctx, TensorType type) {
+ return Embed.asTensor(text, this, ctx, type);
+ }
+
+ @Override
+ public List<String> segment(String input, Language language) {
+ var encoding = resolve(language).encode(input);
+ var tokens = encoding.getTokens();
+ var result = new ArrayList<String>(tokens.length-2); // heuristic: -2 to exclude start/end tokens
+ for (int i = 0; i < tokens.length; i++)
+ if (encoding.getSpecialTokenMask()[i] == 0) result.add(tokens[i]);
+ return result;
+ }
+
+ @Override
+ public String decode(List<Integer> tokens, Context ctx) {
+ return resolve(ctx.getLanguage()).decode(toArray(tokens));
+ }
+
+ public Encoding encode(String text) { return encode(text, Language.UNKNOWN); }
+ public Encoding encode(String text, Language language) { return Encoding.from(resolve(language).encode(text)); }
+ public String decode(List<Long> tokens) { return decode(tokens, Language.UNKNOWN); }
+ public String decode(List<Long> tokens, Language language) { return resolve(language).decode(toArray(tokens)); }
+
+ @Override public void close() { models.forEach((__, model) -> model.close()); }
+
+ private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) {
+ // Disregard language if there is default model
+ if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
+ if (models.containsKey(language)) return models.get(language);
+ throw new IllegalArgumentException("No model for language " + language);
+ }
+
+ private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); }
+
+ public static final class Builder {
+ private final Map<Language, Path> models = new EnumMap<>(Language.class);
+
+ public Builder() {}
+ public Builder(HuggingFaceTokenizerConfig cfg) {
+ for (var model : cfg.model())
+ addModel(Language.fromLanguageTag(model.language()), model.path());
+ }
+
+ public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
+ public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); }
+ public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); }
+ }
+
+} \ No newline at end of file
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java
new file mode 100644
index 00000000000..7cec01ffed6
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java
@@ -0,0 +1,9 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+/**
+ * @author bjorncs
+ */
+@ExportPackage
+package com.yahoo.language.huggingface;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
new file mode 100644
index 00000000000..9d0ab65c28f
--- /dev/null
+++ b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
@@ -0,0 +1,9 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+namespace=language.huggingface
+
+# The language a model is for, one of the language tags in com.yahoo.language.Language.
+# Use "unknown" for models to be used with any language.
+model[].language string
+# The path to the model relative to the application package root
+model[].path path