diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-11 11:07:23 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-11 16:41:35 +0200 |
commit | ce7dd2c983a8840981786eef95a9cc4741487be7 (patch) | |
tree | 8aa4140fb8fd7ec8a95d425adda4119d34cf46d3 /model-integration/src/main | |
parent | 386e4198d8459803eec0ead6ad81a821737082a7 (diff) |
Make HF tokenizer a separate embedder
Diffstat (limited to 'model-integration/src/main')
4 files changed, 6 insertions, 177 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java deleted file mode 100644 index 274c29a57b2..00000000000 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.embedding.huggingface; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -/** - * @author bjorncs - */ -public record Encoding( - List<Long> ids, List<Long> typeIds, List<String> tokens, List<Long> wordIds, List<Long> attentionMask, - List<Long> specialTokenMask, List<CharSpan> charTokenSpans, List<Encoding> overflowing) { - - public record CharSpan(int start, int end) { - public static final CharSpan NONE = new CharSpan(-1, -1); - static CharSpan from(ai.djl.huggingface.tokenizers.jni.CharSpan s) { - if (s == null) return NONE; - return new CharSpan(s.getStart(), s.getEnd()); - } - public boolean isNone() { return this.equals(NONE); } - } - - public Encoding { - ids = List.copyOf(ids); - typeIds = List.copyOf(typeIds); - tokens = List.copyOf(tokens); - wordIds = List.copyOf(wordIds); - attentionMask = List.copyOf(attentionMask); - specialTokenMask = List.copyOf(specialTokenMask); - charTokenSpans = List.copyOf(charTokenSpans); - overflowing = List.copyOf(overflowing); - } - - static Encoding from(ai.djl.huggingface.tokenizers.Encoding e) { - return new Encoding( - toList(e.getIds()), - toList(e.getTypeIds()), - List.of(e.getTokens()), - toList(e.getWordIds()), - toList(e.getAttentionMask()), - toList(e.getSpecialTokenMask()), - Arrays.stream(e.getCharTokenSpans()).map(CharSpan::from).toList(), - Arrays.stream(e.getOverflowing()).map(Encoding::from).toList()); - } - - private static List<Long> toList(long[] array) { - if (array == null) return List.of(); - var list = new ArrayList<Long>(array.length); - for (long e : array) list.add(e); - return list; - } -} diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 365a50f47b5..5faff435a30 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -6,6 +6,7 @@ import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -14,7 +15,6 @@ import com.yahoo.tensor.TensorType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; import java.nio.file.Paths; import java.util.List; import java.util.Map; @@ -32,13 +32,15 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final OnnxEvaluator evaluator; @Inject - public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) throws IOException { + public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) { maxTokens = config.transformerMaxTokens(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); outputName = config.transformerOutput(); normalize = config.normalize(); - tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString())); + tokenizer = new HuggingFaceTokenizer.Builder() + .addDefaultModel(Paths.get(config.tokenizerPath().toString())) + .build(); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) onnxOpts.setGpuDevice(config.transformerGpuDevice()); @@ -66,8 +68,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { @Override public List<Integer> embed(String s, Context context) { - Encoding encoding = tokenizer.encode(s); - List<Integer> tokenIds = encoding.ids().stream().map(Long::intValue).toList(); + var tokenIds = tokenizer.embed(s, context); int tokensSize = tokenIds.size(); diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java deleted file mode 100644 index e6765a4cc8a..00000000000 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.embedding.huggingface; - -import java.io.IOException; -import java.nio.file.Path; - -/** - * Wrapper around Deep Java Library's HuggingFace tokenizer. - * - * @author bjorncs - */ -public class HuggingFaceTokenizer implements AutoCloseable { - - private final ai.djl.huggingface.tokenizers.HuggingFaceTokenizer instance; - - public HuggingFaceTokenizer(Path path) throws IOException { this(path, HuggingFaceTokenizerOptions.defaults()); } - - public HuggingFaceTokenizer(Path path, HuggingFaceTokenizerOptions opts) throws IOException { - var original = Thread.currentThread().getContextClassLoader(); - Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); - try { - instance = createInstance(path, opts); - } finally { - Thread.currentThread().setContextClassLoader(original); - } - } - - public Encoding encode(String text) { return Encoding.from(instance.encode(text)); } - - @Override public void close() { instance.close(); } - - private static ai.djl.huggingface.tokenizers.HuggingFaceTokenizer createInstance( - Path path, HuggingFaceTokenizerOptions opts) throws IOException { - var builder = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder().optTokenizerPath(path); - opts.addSpecialToken().ifPresent(builder::optAddSpecialTokens); - opts.truncation().ifPresent(builder::optTruncation); - if (opts.truncateFirstOnly()) builder.optTruncateFirstOnly(); - if (opts.truncateSecondOnly()) builder.optTruncateSecondOnly(); - opts.padding().ifPresent(builder::optPadding); - if (opts.padToMaxLength()) builder.optPadToMaxLength(); - opts.maxLength().ifPresent(builder::optMaxLength); - opts.padToMultipleOf().ifPresent(builder::optPadToMultipleOf); - opts.stride().ifPresent(builder::optStride); - return builder.build(); - } -} diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java deleted file mode 100644 index 74f80756603..00000000000 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.embedding.huggingface; - -import java.util.Optional; -import java.util.OptionalInt; - -/** - * @author bjorncs - */ -public class HuggingFaceTokenizerOptions { - - private final Boolean addSpecialToken; - private final Boolean truncation; - private final boolean truncateFirstOnly; - private final boolean truncateSecondOnly; - private final Boolean padding; - private final boolean padToMaxLength; - private final Integer maxLength; - private final Integer padToMultipleOf; - private final Integer stride; - - private HuggingFaceTokenizerOptions(Builder b) { - addSpecialToken = b.addSpecialToken; - truncation = b.truncation; - truncateFirstOnly = b.truncateFirstOnly; - truncateSecondOnly = b.truncateSecondOnly; - padding = b.padding; - padToMaxLength = b.padToMaxLength; - maxLength = b.maxLength; - padToMultipleOf = b.padToMultipleOf; - stride = b.stride; - } - - public static Builder custom() { return new Builder(); } - public static HuggingFaceTokenizerOptions defaults() { return new Builder().build(); } - - Optional<Boolean> addSpecialToken() { return Optional.ofNullable(addSpecialToken); } - Optional<Boolean> truncation() { return Optional.ofNullable(truncation); } - boolean truncateFirstOnly() { return truncateFirstOnly; } - boolean truncateSecondOnly() { return truncateSecondOnly; } - Optional<Boolean> padding() { return Optional.ofNullable(padding); } - boolean padToMaxLength() { return padToMaxLength; } - OptionalInt maxLength() { return maxLength != null ? OptionalInt.of(maxLength) : OptionalInt.empty(); } - OptionalInt padToMultipleOf() { return padToMultipleOf != null ? OptionalInt.of(padToMultipleOf) : OptionalInt.empty(); } - OptionalInt stride() { return stride != null ? OptionalInt.of(stride) : OptionalInt.empty(); } - - public static class Builder { - private Boolean addSpecialToken; - private Boolean truncation; - private boolean truncateFirstOnly; - private boolean truncateSecondOnly; - private Boolean padding; - private boolean padToMaxLength; - private Integer maxLength; - private Integer padToMultipleOf; - private Integer stride; - - public Builder addSpecialToken(boolean enabled) { addSpecialToken = enabled; return this; } - public Builder truncation(boolean enabled) { truncation = enabled; return this; } - public Builder truncateFirstOnly() { truncateFirstOnly = true; return this; } - public Builder truncateSecondOnly() { truncateSecondOnly = true; return this; } - public Builder padding(boolean enabled) { padding = enabled; return this; } - public Builder padToMaxLength() { padToMaxLength = true; return this; } - public Builder maxLength(int length) { maxLength = length; return this; } - public Builder padToMultipleOf(int num) { padToMultipleOf = num; return this; } - public Builder stride(int stride) { this.stride = stride; return this; } - public HuggingFaceTokenizerOptions build() { return new HuggingFaceTokenizerOptions(this); } - } - -} |