diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-05 16:27:13 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-05 16:27:13 +0200 |
commit | b4699156f2b641e383b2459b736bd2e45c724a98 (patch) | |
tree | d600339cd96a6359625941b40140f1cdb486c36e /model-integration/src/main | |
parent | a3f05ec2a40a7b870ca06ef0dfd7aed00d7afeb2 (diff) |
Split out HF Tokenizer
Diffstat (limited to 'model-integration/src/main')
4 files changed, 174 insertions, 23 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java new file mode 100644 index 00000000000..f1c0244bfb3 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java @@ -0,0 +1,50 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.embedding.huggingface; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * @author bjorncs + */ +public record Encoding( + List<Long> ids, List<Long> typeIds, List<String> tokens, List<Long> wordIds, List<Long> attentionMask, + List<Long> specialTokenMask, List<CharSpan> charTokenSpans, List<Encoding> overflowing) { + + public record CharSpan(int start, int end) { + static CharSpan from(ai.djl.huggingface.tokenizers.jni.CharSpan s) { + return new CharSpan(s.getStart(), s.getEnd()); + } + } + + public Encoding { + ids = List.copyOf(ids); + typeIds = List.copyOf(typeIds); + tokens = List.copyOf(tokens); + wordIds = List.copyOf(wordIds); + attentionMask = List.copyOf(attentionMask); + specialTokenMask = List.copyOf(specialTokenMask); + charTokenSpans = List.copyOf(charTokenSpans); + overflowing = List.copyOf(overflowing); + } + + static Encoding from(ai.djl.huggingface.tokenizers.Encoding e) { + return new Encoding( + toList(e.getIds()), + toList(e.getTypeIds()), + List.of(e.getTokens()), + toList(e.getWordIds()), + toList(e.getAttentionMask()), + toList(e.getSpecialTokenMask()), + Arrays.stream(e.getCharTokenSpans()).map(CharSpan::from).toList(), + Arrays.stream(e.getOverflowing()).map(Encoding::from).toList()); + } + + private static List<Long> toList(long[] array) { + var list = new ArrayList<Long>(array.length); + for (long e : array) list.add(e); + return list; + } +} diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 9572cfcb0e4..c7ccb33aca9 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -1,7 +1,5 @@ package ai.vespa.embedding.huggingface; -import ai.djl.huggingface.tokenizers.Encoding; -import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.component.AbstractComponent; @@ -17,7 +15,6 @@ import org.slf4j.LoggerFactory; import java.io.IOException; import java.nio.file.Paths; -import java.util.Arrays; import java.util.List; import java.util.Map; @@ -38,19 +35,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); outputName = config.transformerOutput(); - - try { - ClassLoader tccl = Thread.currentThread().getContextClassLoader(); - try { - Thread.currentThread().setContextClassLoader(getClass().getClassLoader()); - tokenizer = HuggingFaceTokenizer.newInstance(Paths.get(config.tokenizerPath().toString())); - } finally { - Thread.currentThread().setContextClassLoader(tccl); - } - } catch (IOException e){ - LOG.info("Could not initialize the tokenizer"); - throw new IOException("Could not initialize the tokenizer."); - } + tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString())); evaluator = onnx.evaluatorOf(config.transformerModel().toString()); validateModel(); } @@ -74,7 +59,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { @Override public List<Integer> embed(String s, Context context) { Encoding encoding = tokenizer.encode(s); - List<Integer> tokenIds = longToInteger(encoding.getIds()); + List<Integer> tokenIds = encoding.ids().stream().map(Long::intValue).toList(); int tokensSize = tokenIds.size(); @@ -86,12 +71,10 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return tokenIds; } - @Override public void deconstruct() { evaluator.close(); } - - public List<Integer> longToInteger(long[] values) { - return Arrays.stream(values) - .boxed().map(Long::intValue) - .toList(); + @Override + public void deconstruct() { + evaluator.close(); + tokenizer.close(); } @Override diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java new file mode 100644 index 00000000000..e6765a4cc8a --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java @@ -0,0 +1,47 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.embedding.huggingface; + +import java.io.IOException; +import java.nio.file.Path; + +/** + * Wrapper around Deep Java Library's HuggingFace tokenizer. + * + * @author bjorncs + */ +public class HuggingFaceTokenizer implements AutoCloseable { + + private final ai.djl.huggingface.tokenizers.HuggingFaceTokenizer instance; + + public HuggingFaceTokenizer(Path path) throws IOException { this(path, HuggingFaceTokenizerOptions.defaults()); } + + public HuggingFaceTokenizer(Path path, HuggingFaceTokenizerOptions opts) throws IOException { + var original = Thread.currentThread().getContextClassLoader(); + Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); + try { + instance = createInstance(path, opts); + } finally { + Thread.currentThread().setContextClassLoader(original); + } + } + + public Encoding encode(String text) { return Encoding.from(instance.encode(text)); } + + @Override public void close() { instance.close(); } + + private static ai.djl.huggingface.tokenizers.HuggingFaceTokenizer createInstance( + Path path, HuggingFaceTokenizerOptions opts) throws IOException { + var builder = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder().optTokenizerPath(path); + opts.addSpecialToken().ifPresent(builder::optAddSpecialTokens); + opts.truncation().ifPresent(builder::optTruncation); + if (opts.truncateFirstOnly()) builder.optTruncateFirstOnly(); + if (opts.truncateSecondOnly()) builder.optTruncateSecondOnly(); + opts.padding().ifPresent(builder::optPadding); + if (opts.padToMaxLength()) builder.optPadToMaxLength(); + opts.maxLength().ifPresent(builder::optMaxLength); + opts.padToMultipleOf().ifPresent(builder::optPadToMultipleOf); + opts.stride().ifPresent(builder::optStride); + return builder.build(); + } +} diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java new file mode 100644 index 00000000000..74f80756603 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java @@ -0,0 +1,71 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.embedding.huggingface; + +import java.util.Optional; +import java.util.OptionalInt; + +/** + * @author bjorncs + */ +public class HuggingFaceTokenizerOptions { + + private final Boolean addSpecialToken; + private final Boolean truncation; + private final boolean truncateFirstOnly; + private final boolean truncateSecondOnly; + private final Boolean padding; + private final boolean padToMaxLength; + private final Integer maxLength; + private final Integer padToMultipleOf; + private final Integer stride; + + private HuggingFaceTokenizerOptions(Builder b) { + addSpecialToken = b.addSpecialToken; + truncation = b.truncation; + truncateFirstOnly = b.truncateFirstOnly; + truncateSecondOnly = b.truncateSecondOnly; + padding = b.padding; + padToMaxLength = b.padToMaxLength; + maxLength = b.maxLength; + padToMultipleOf = b.padToMultipleOf; + stride = b.stride; + } + + public static Builder custom() { return new Builder(); } + public static HuggingFaceTokenizerOptions defaults() { return new Builder().build(); } + + Optional<Boolean> addSpecialToken() { return Optional.ofNullable(addSpecialToken); } + Optional<Boolean> truncation() { return Optional.ofNullable(truncation); } + boolean truncateFirstOnly() { return truncateFirstOnly; } + boolean truncateSecondOnly() { return truncateSecondOnly; } + Optional<Boolean> padding() { return Optional.ofNullable(padding); } + boolean padToMaxLength() { return padToMaxLength; } + OptionalInt maxLength() { return maxLength != null ? OptionalInt.of(maxLength) : OptionalInt.empty(); } + OptionalInt padToMultipleOf() { return padToMultipleOf != null ? OptionalInt.of(padToMultipleOf) : OptionalInt.empty(); } + OptionalInt stride() { return stride != null ? OptionalInt.of(stride) : OptionalInt.empty(); } + + public static class Builder { + private Boolean addSpecialToken; + private Boolean truncation; + private boolean truncateFirstOnly; + private boolean truncateSecondOnly; + private Boolean padding; + private boolean padToMaxLength; + private Integer maxLength; + private Integer padToMultipleOf; + private Integer stride; + + public Builder addSpecialToken(boolean enabled) { addSpecialToken = enabled; return this; } + public Builder truncation(boolean enabled) { truncation = enabled; return this; } + public Builder truncateFirstOnly() { truncateFirstOnly = true; return this; } + public Builder truncateSecondOnly() { truncateSecondOnly = true; return this; } + public Builder padding(boolean enabled) { padding = enabled; return this; } + public Builder padToMaxLength() { padToMaxLength = true; return this; } + public Builder maxLength(int length) { maxLength = length; return this; } + public Builder padToMultipleOf(int num) { padToMultipleOf = num; return this; } + public Builder stride(int stride) { this.stride = stride; return this; } + public HuggingFaceTokenizerOptions build() { return new HuggingFaceTokenizerOptions(this); } + } + +} |