aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-11 11:07:23 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-11 16:41:35 +0200
commitce7dd2c983a8840981786eef95a9cc4741487be7 (patch)
tree8aa4140fb8fd7ec8a95d425adda4119d34cf46d3 /model-integration
parent386e4198d8459803eec0ead6ad81a821737082a7 (diff)
Make HF tokenizer a separate embedder
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/pom.xml20
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java54
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java47
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java71
5 files changed, 6 insertions, 197 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index cc5eccff2ac..681003fdc89 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -91,26 +91,6 @@
<artifactId>protobuf-java</artifactId>
</dependency>
- <dependency>
- <groupId>ai.djl.huggingface</groupId>
- <artifactId>tokenizers</artifactId>
- <version>0.22.1</version>
- <exclusions>
- <exclusion>
- <groupId>com.google.code.gson</groupId>
- <artifactId>gson</artifactId>
- </exclusion>
- <exclusion>
- <groupId>net.java.dev.jna</groupId>
- <artifactId>jna</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.slf4j</groupId>
- <artifactId>slf4j-api</artifactId>
- </exclusion>
- </exclusions>
- </dependency>
-
<dependency>
<groupId>org.lz4</groupId>
<artifactId>lz4-java</artifactId>
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java
deleted file mode 100644
index 274c29a57b2..00000000000
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java
+++ /dev/null
@@ -1,54 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.embedding.huggingface;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-/**
- * @author bjorncs
- */
-public record Encoding(
- List<Long> ids, List<Long> typeIds, List<String> tokens, List<Long> wordIds, List<Long> attentionMask,
- List<Long> specialTokenMask, List<CharSpan> charTokenSpans, List<Encoding> overflowing) {
-
- public record CharSpan(int start, int end) {
- public static final CharSpan NONE = new CharSpan(-1, -1);
- static CharSpan from(ai.djl.huggingface.tokenizers.jni.CharSpan s) {
- if (s == null) return NONE;
- return new CharSpan(s.getStart(), s.getEnd());
- }
- public boolean isNone() { return this.equals(NONE); }
- }
-
- public Encoding {
- ids = List.copyOf(ids);
- typeIds = List.copyOf(typeIds);
- tokens = List.copyOf(tokens);
- wordIds = List.copyOf(wordIds);
- attentionMask = List.copyOf(attentionMask);
- specialTokenMask = List.copyOf(specialTokenMask);
- charTokenSpans = List.copyOf(charTokenSpans);
- overflowing = List.copyOf(overflowing);
- }
-
- static Encoding from(ai.djl.huggingface.tokenizers.Encoding e) {
- return new Encoding(
- toList(e.getIds()),
- toList(e.getTypeIds()),
- List.of(e.getTokens()),
- toList(e.getWordIds()),
- toList(e.getAttentionMask()),
- toList(e.getSpecialTokenMask()),
- Arrays.stream(e.getCharTokenSpans()).map(CharSpan::from).toList(),
- Arrays.stream(e.getOverflowing()).map(Encoding::from).toList());
- }
-
- private static List<Long> toList(long[] array) {
- if (array == null) return List.of();
- var list = new ArrayList<Long>(array.length);
- for (long e : array) list.add(e);
- return list;
- }
-}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 365a50f47b5..5faff435a30 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -6,6 +6,7 @@ import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -14,7 +15,6 @@ import com.yahoo.tensor.TensorType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
@@ -32,13 +32,15 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final OnnxEvaluator evaluator;
@Inject
- public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) throws IOException {
+ public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
maxTokens = config.transformerMaxTokens();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
outputName = config.transformerOutput();
normalize = config.normalize();
- tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString()));
+ tokenizer = new HuggingFaceTokenizer.Builder()
+ .addDefaultModel(Paths.get(config.tokenizerPath().toString()))
+ .build();
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
onnxOpts.setGpuDevice(config.transformerGpuDevice());
@@ -66,8 +68,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
@Override
public List<Integer> embed(String s, Context context) {
- Encoding encoding = tokenizer.encode(s);
- List<Integer> tokenIds = encoding.ids().stream().map(Long::intValue).toList();
+ var tokenIds = tokenizer.embed(s, context);
int tokensSize = tokenIds.size();
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java
deleted file mode 100644
index e6765a4cc8a..00000000000
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.embedding.huggingface;
-
-import java.io.IOException;
-import java.nio.file.Path;
-
-/**
- * Wrapper around Deep Java Library's HuggingFace tokenizer.
- *
- * @author bjorncs
- */
-public class HuggingFaceTokenizer implements AutoCloseable {
-
- private final ai.djl.huggingface.tokenizers.HuggingFaceTokenizer instance;
-
- public HuggingFaceTokenizer(Path path) throws IOException { this(path, HuggingFaceTokenizerOptions.defaults()); }
-
- public HuggingFaceTokenizer(Path path, HuggingFaceTokenizerOptions opts) throws IOException {
- var original = Thread.currentThread().getContextClassLoader();
- Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
- try {
- instance = createInstance(path, opts);
- } finally {
- Thread.currentThread().setContextClassLoader(original);
- }
- }
-
- public Encoding encode(String text) { return Encoding.from(instance.encode(text)); }
-
- @Override public void close() { instance.close(); }
-
- private static ai.djl.huggingface.tokenizers.HuggingFaceTokenizer createInstance(
- Path path, HuggingFaceTokenizerOptions opts) throws IOException {
- var builder = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder().optTokenizerPath(path);
- opts.addSpecialToken().ifPresent(builder::optAddSpecialTokens);
- opts.truncation().ifPresent(builder::optTruncation);
- if (opts.truncateFirstOnly()) builder.optTruncateFirstOnly();
- if (opts.truncateSecondOnly()) builder.optTruncateSecondOnly();
- opts.padding().ifPresent(builder::optPadding);
- if (opts.padToMaxLength()) builder.optPadToMaxLength();
- opts.maxLength().ifPresent(builder::optMaxLength);
- opts.padToMultipleOf().ifPresent(builder::optPadToMultipleOf);
- opts.stride().ifPresent(builder::optStride);
- return builder.build();
- }
-}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java
deleted file mode 100644
index 74f80756603..00000000000
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.embedding.huggingface;
-
-import java.util.Optional;
-import java.util.OptionalInt;
-
-/**
- * @author bjorncs
- */
-public class HuggingFaceTokenizerOptions {
-
- private final Boolean addSpecialToken;
- private final Boolean truncation;
- private final boolean truncateFirstOnly;
- private final boolean truncateSecondOnly;
- private final Boolean padding;
- private final boolean padToMaxLength;
- private final Integer maxLength;
- private final Integer padToMultipleOf;
- private final Integer stride;
-
- private HuggingFaceTokenizerOptions(Builder b) {
- addSpecialToken = b.addSpecialToken;
- truncation = b.truncation;
- truncateFirstOnly = b.truncateFirstOnly;
- truncateSecondOnly = b.truncateSecondOnly;
- padding = b.padding;
- padToMaxLength = b.padToMaxLength;
- maxLength = b.maxLength;
- padToMultipleOf = b.padToMultipleOf;
- stride = b.stride;
- }
-
- public static Builder custom() { return new Builder(); }
- public static HuggingFaceTokenizerOptions defaults() { return new Builder().build(); }
-
- Optional<Boolean> addSpecialToken() { return Optional.ofNullable(addSpecialToken); }
- Optional<Boolean> truncation() { return Optional.ofNullable(truncation); }
- boolean truncateFirstOnly() { return truncateFirstOnly; }
- boolean truncateSecondOnly() { return truncateSecondOnly; }
- Optional<Boolean> padding() { return Optional.ofNullable(padding); }
- boolean padToMaxLength() { return padToMaxLength; }
- OptionalInt maxLength() { return maxLength != null ? OptionalInt.of(maxLength) : OptionalInt.empty(); }
- OptionalInt padToMultipleOf() { return padToMultipleOf != null ? OptionalInt.of(padToMultipleOf) : OptionalInt.empty(); }
- OptionalInt stride() { return stride != null ? OptionalInt.of(stride) : OptionalInt.empty(); }
-
- public static class Builder {
- private Boolean addSpecialToken;
- private Boolean truncation;
- private boolean truncateFirstOnly;
- private boolean truncateSecondOnly;
- private Boolean padding;
- private boolean padToMaxLength;
- private Integer maxLength;
- private Integer padToMultipleOf;
- private Integer stride;
-
- public Builder addSpecialToken(boolean enabled) { addSpecialToken = enabled; return this; }
- public Builder truncation(boolean enabled) { truncation = enabled; return this; }
- public Builder truncateFirstOnly() { truncateFirstOnly = true; return this; }
- public Builder truncateSecondOnly() { truncateSecondOnly = true; return this; }
- public Builder padding(boolean enabled) { padding = enabled; return this; }
- public Builder padToMaxLength() { padToMaxLength = true; return this; }
- public Builder maxLength(int length) { maxLength = length; return this; }
- public Builder padToMultipleOf(int num) { padToMultipleOf = num; return this; }
- public Builder stride(int stride) { this.stride = stride; return this; }
- public HuggingFaceTokenizerOptions build() { return new HuggingFaceTokenizerOptions(this); }
- }
-
-}