summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-05 16:27:13 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-05 16:27:13 +0200
commitb4699156f2b641e383b2459b736bd2e45c724a98 (patch)
treed600339cd96a6359625941b40140f1cdb486c36e /model-integration
parenta3f05ec2a40a7b870ca06ef0dfd7aed00d7afeb2 (diff)
Split out HF Tokenizer
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java50
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java29
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java47
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java71
4 files changed, 174 insertions, 23 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java
new file mode 100644
index 00000000000..f1c0244bfb3
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java
@@ -0,0 +1,50 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.embedding.huggingface;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * @author bjorncs
+ */
+public record Encoding(
+ List<Long> ids, List<Long> typeIds, List<String> tokens, List<Long> wordIds, List<Long> attentionMask,
+ List<Long> specialTokenMask, List<CharSpan> charTokenSpans, List<Encoding> overflowing) {
+
+ public record CharSpan(int start, int end) {
+ static CharSpan from(ai.djl.huggingface.tokenizers.jni.CharSpan s) {
+ return new CharSpan(s.getStart(), s.getEnd());
+ }
+ }
+
+ public Encoding {
+ ids = List.copyOf(ids);
+ typeIds = List.copyOf(typeIds);
+ tokens = List.copyOf(tokens);
+ wordIds = List.copyOf(wordIds);
+ attentionMask = List.copyOf(attentionMask);
+ specialTokenMask = List.copyOf(specialTokenMask);
+ charTokenSpans = List.copyOf(charTokenSpans);
+ overflowing = List.copyOf(overflowing);
+ }
+
+ static Encoding from(ai.djl.huggingface.tokenizers.Encoding e) {
+ return new Encoding(
+ toList(e.getIds()),
+ toList(e.getTypeIds()),
+ List.of(e.getTokens()),
+ toList(e.getWordIds()),
+ toList(e.getAttentionMask()),
+ toList(e.getSpecialTokenMask()),
+ Arrays.stream(e.getCharTokenSpans()).map(CharSpan::from).toList(),
+ Arrays.stream(e.getOverflowing()).map(Encoding::from).toList());
+ }
+
+ private static List<Long> toList(long[] array) {
+ var list = new ArrayList<Long>(array.length);
+ for (long e : array) list.add(e);
+ return list;
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 9572cfcb0e4..c7ccb33aca9 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -1,7 +1,5 @@
package ai.vespa.embedding.huggingface;
-import ai.djl.huggingface.tokenizers.Encoding;
-import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.AbstractComponent;
@@ -17,7 +15,6 @@ import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.nio.file.Paths;
-import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -38,19 +35,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
outputName = config.transformerOutput();
-
- try {
- ClassLoader tccl = Thread.currentThread().getContextClassLoader();
- try {
- Thread.currentThread().setContextClassLoader(getClass().getClassLoader());
- tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(config.tokenizerPath().toString()));
- } finally {
- Thread.currentThread().setContextClassLoader(tccl);
- }
- } catch (IOException e){
- LOG.info("Could not initialize the tokenizer");
- throw new IOException("Could not initialize the tokenizer.");
- }
+ tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString()));
evaluator = onnx.evaluatorOf(config.transformerModel().toString());
validateModel();
}
@@ -74,7 +59,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
@Override
public List<Integer> embed(String s, Context context) {
Encoding encoding = tokenizer.encode(s);
- List<Integer> tokenIds = longToInteger(encoding.getIds());
+ List<Integer> tokenIds = encoding.ids().stream().map(Long::intValue).toList();
int tokensSize = tokenIds.size();
@@ -86,12 +71,10 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
return tokenIds;
}
- @Override public void deconstruct() { evaluator.close(); }
-
- public List<Integer> longToInteger(long[] values) {
- return Arrays.stream(values)
- .boxed().map(Long::intValue)
- .toList();
+ @Override
+ public void deconstruct() {
+ evaluator.close();
+ tokenizer.close();
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java
new file mode 100644
index 00000000000..e6765a4cc8a
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java
@@ -0,0 +1,47 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.embedding.huggingface;
+
+import java.io.IOException;
+import java.nio.file.Path;
+
+/**
+ * Wrapper around Deep Java Library's HuggingFace tokenizer.
+ *
+ * @author bjorncs
+ */
+public class HuggingFaceTokenizer implements AutoCloseable {
+
+ private final ai.djl.huggingface.tokenizers.HuggingFaceTokenizer instance;
+
+ public HuggingFaceTokenizer(Path path) throws IOException { this(path, HuggingFaceTokenizerOptions.defaults()); }
+
+ public HuggingFaceTokenizer(Path path, HuggingFaceTokenizerOptions opts) throws IOException {
+ var original = Thread.currentThread().getContextClassLoader();
+ Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
+ try {
+ instance = createInstance(path, opts);
+ } finally {
+ Thread.currentThread().setContextClassLoader(original);
+ }
+ }
+
+ public Encoding encode(String text) { return Encoding.from(instance.encode(text)); }
+
+ @Override public void close() { instance.close(); }
+
+ private static ai.djl.huggingface.tokenizers.HuggingFaceTokenizer createInstance(
+ Path path, HuggingFaceTokenizerOptions opts) throws IOException {
+ var builder = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder().optTokenizerPath(path);
+ opts.addSpecialToken().ifPresent(builder::optAddSpecialTokens);
+ opts.truncation().ifPresent(builder::optTruncation);
+ if (opts.truncateFirstOnly()) builder.optTruncateFirstOnly();
+ if (opts.truncateSecondOnly()) builder.optTruncateSecondOnly();
+ opts.padding().ifPresent(builder::optPadding);
+ if (opts.padToMaxLength()) builder.optPadToMaxLength();
+ opts.maxLength().ifPresent(builder::optMaxLength);
+ opts.padToMultipleOf().ifPresent(builder::optPadToMultipleOf);
+ opts.stride().ifPresent(builder::optStride);
+ return builder.build();
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java
new file mode 100644
index 00000000000..74f80756603
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java
@@ -0,0 +1,71 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.embedding.huggingface;
+
+import java.util.Optional;
+import java.util.OptionalInt;
+
+/**
+ * @author bjorncs
+ */
+public class HuggingFaceTokenizerOptions {
+
+ private final Boolean addSpecialToken;
+ private final Boolean truncation;
+ private final boolean truncateFirstOnly;
+ private final boolean truncateSecondOnly;
+ private final Boolean padding;
+ private final boolean padToMaxLength;
+ private final Integer maxLength;
+ private final Integer padToMultipleOf;
+ private final Integer stride;
+
+ private HuggingFaceTokenizerOptions(Builder b) {
+ addSpecialToken = b.addSpecialToken;
+ truncation = b.truncation;
+ truncateFirstOnly = b.truncateFirstOnly;
+ truncateSecondOnly = b.truncateSecondOnly;
+ padding = b.padding;
+ padToMaxLength = b.padToMaxLength;
+ maxLength = b.maxLength;
+ padToMultipleOf = b.padToMultipleOf;
+ stride = b.stride;
+ }
+
+ public static Builder custom() { return new Builder(); }
+ public static HuggingFaceTokenizerOptions defaults() { return new Builder().build(); }
+
+ Optional<Boolean> addSpecialToken() { return Optional.ofNullable(addSpecialToken); }
+ Optional<Boolean> truncation() { return Optional.ofNullable(truncation); }
+ boolean truncateFirstOnly() { return truncateFirstOnly; }
+ boolean truncateSecondOnly() { return truncateSecondOnly; }
+ Optional<Boolean> padding() { return Optional.ofNullable(padding); }
+ boolean padToMaxLength() { return padToMaxLength; }
+ OptionalInt maxLength() { return maxLength != null ? OptionalInt.of(maxLength) : OptionalInt.empty(); }
+ OptionalInt padToMultipleOf() { return padToMultipleOf != null ? OptionalInt.of(padToMultipleOf) : OptionalInt.empty(); }
+ OptionalInt stride() { return stride != null ? OptionalInt.of(stride) : OptionalInt.empty(); }
+
+ public static class Builder {
+ private Boolean addSpecialToken;
+ private Boolean truncation;
+ private boolean truncateFirstOnly;
+ private boolean truncateSecondOnly;
+ private Boolean padding;
+ private boolean padToMaxLength;
+ private Integer maxLength;
+ private Integer padToMultipleOf;
+ private Integer stride;
+
+ public Builder addSpecialToken(boolean enabled) { addSpecialToken = enabled; return this; }
+ public Builder truncation(boolean enabled) { truncation = enabled; return this; }
+ public Builder truncateFirstOnly() { truncateFirstOnly = true; return this; }
+ public Builder truncateSecondOnly() { truncateSecondOnly = true; return this; }
+ public Builder padding(boolean enabled) { padding = enabled; return this; }
+ public Builder padToMaxLength() { padToMaxLength = true; return this; }
+ public Builder maxLength(int length) { maxLength = length; return this; }
+ public Builder padToMultipleOf(int num) { padToMultipleOf = num; return this; }
+ public Builder stride(int stride) { this.stride = stride; return this; }
+ public HuggingFaceTokenizerOptions build() { return new HuggingFaceTokenizerOptions(this); }
+ }
+
+}