aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-12 10:21:48 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-12 10:21:48 +0200
commite030993d0c356ba6acd50c3e64da5a1f6e1538fd (patch)
tree853878e89743d224a67ba6edb44ec803e1ca9bcf
parentbef1950a75be8b256df07ca5ef6aacd1731c5ef9 (diff)
Revert "Revert "Bjorncs/huggingface tokenizer""
This reverts commit 2bb74878879b3acb1919fd658b8f2c476d8129d6.
-rw-r--r--cloud-tenant-base-dependencies-enforcer/pom.xml3
-rw-r--r--fat-model-dependencies/pom.xml4
-rw-r--r--linguistics-components/pom.xml34
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java (renamed from model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java)5
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java106
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java9
-rw-r--r--linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def11
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java88
-rw-r--r--linguistics-components/src/test/models/huggingface/bert-base-uncased.json.gzbin0 -> 191737 bytes
-rw-r--r--linguistics-components/src/test/models/huggingface/paraphrase-multilingual-mpnet-base-v2.json.gzbin0 -> 3543796 bytes
-rw-r--r--model-integration/pom.xml20
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java46
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java47
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java71
-rw-r--r--model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def1
15 files changed, 279 insertions, 166 deletions
diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml
index c8c0c30e820..166236f91a0 100644
--- a/cloud-tenant-base-dependencies-enforcer/pom.xml
+++ b/cloud-tenant-base-dependencies-enforcer/pom.xml
@@ -175,8 +175,6 @@
<include>com.yahoo.vespa:vsm:*:test</include>
<!-- 3rd party test dependencies -->
- <include>ai.djl:api:jar:0.22.1:test</include>
- <include>ai.djl.huggingface:tokenizers:jar:0.22.1:test</include>
<include>com.google.code.findbugs:jsr305:3.0.2:test</include>
<include>com.google.protobuf:protobuf-java:3.21.7:test</include>
<include>com.ibm.icu:icu4j:70.1:test</include>
@@ -193,7 +191,6 @@
<include>org.antlr:antlr4-runtime:4.11.1:test</include>
<include>org.apache.commons:commons-exec:1.3:test</include>
<include>org.apache.commons:commons-math3:3.6.1:test</include>
- <include>org.apache.commons:commons-compress:jar:1.22:test</include>
<include>org.apache.felix:org.apache.felix.framework:${felix.version}:test</include>
<include>org.apache.felix:org.apache.felix.log:1.0.1:test</include>
<include>org.apache.httpcomponents.client5:httpclient5:${httpclient5.version}:test</include>
diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml
index 0ae2c68e6a8..b2ee5c64b38 100644
--- a/fat-model-dependencies/pom.xml
+++ b/fat-model-dependencies/pom.xml
@@ -93,10 +93,6 @@
<version>${project.version}</version>
<exclusions>
<exclusion>
- <groupId>ai.djl.huggingface</groupId>
- <artifactId>tokenizers</artifactId>
- </exclusion>
- <exclusion>
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>service</artifactId>
</exclusion>
diff --git a/linguistics-components/pom.xml b/linguistics-components/pom.xml
index ad4cbd6ce22..5031ad73556 100644
--- a/linguistics-components/pom.xml
+++ b/linguistics-components/pom.xml
@@ -19,12 +19,42 @@
<artifactId>protobuf-java</artifactId>
</dependency>
<dependency>
- <groupId>junit</groupId>
- <artifactId>junit</artifactId>
+ <groupId>ai.djl.huggingface</groupId>
+ <artifactId>tokenizers</artifactId>
+ <version>0.22.1</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.google.code.gson</groupId>
+ <artifactId>gson</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>net.java.dev.jna</groupId>
+ <artifactId>jna</artifactId>
+ </exclusion>
+ <exclusion>
+ <groupId>org.slf4j</groupId>
+ <artifactId>slf4j-api</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
+ <groupId>org.junit.jupiter</groupId>
+ <artifactId>junit-jupiter-engine</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>org.junit.vintage</groupId>
+ <artifactId>junit-vintage-engine</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
+ <artifactId>jdisc_core</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
<artifactId>annotations</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java
index 274c29a57b2..107900ff73c 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java
@@ -1,6 +1,8 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.embedding.huggingface;
+package com.yahoo.language.huggingface;
+
+import com.yahoo.api.annotations.Beta;
import java.util.ArrayList;
import java.util.Arrays;
@@ -9,6 +11,7 @@ import java.util.List;
/**
* @author bjorncs
*/
+@Beta
public record Encoding(
List<Long> ids, List<Long> typeIds, List<String> tokens, List<Long> wordIds, List<Long> attentionMask,
List<Long> specialTokenMask, List<CharSpan> charTokenSpans, List<Encoding> overflowing) {
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
new file mode 100644
index 00000000000..b92e0678970
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -0,0 +1,106 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.huggingface;
+
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.AbstractComponent;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.language.process.Segmenter;
+import com.yahoo.language.tools.Embed;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.EnumMap;
+import java.util.List;
+import java.util.Map;
+
+import static com.yahoo.yolean.Exceptions.uncheck;
+
+/**
+ * {@link Embedder}/{@link Segmenter} using Deep Java Library's HuggingFace Tokenizer.
+ *
+ * @author bjorncs
+ */
+@Beta
+public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable {
+
+ private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class);
+
+ @Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); }
+
+ private HuggingFaceTokenizer(Builder b) {
+ var original = Thread.currentThread().getContextClassLoader();
+ Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
+ try {
+ b.models.forEach((language, path) -> {
+ models.put(language,
+ uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
+ .optTokenizerPath(path)
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
+ .build()));
+ });
+ } finally {
+ Thread.currentThread().setContextClassLoader(original);
+ }
+ }
+
+ @Override
+ public List<Integer> embed(String text, Context ctx) {
+ var encoding = resolve(ctx.getLanguage()).encode(text);
+ return Arrays.stream(encoding.getIds()).mapToInt(Math::toIntExact).boxed().toList();
+ }
+
+ @Override
+ public Tensor embed(String text, Context ctx, TensorType type) {
+ return Embed.asTensor(text, this, ctx, type);
+ }
+
+ @Override
+ public List<String> segment(String input, Language language) {
+ return List.of(resolve(language).encode(input).getTokens());
+ }
+
+ @Override
+ public String decode(List<Integer> tokens, Context ctx) {
+ return resolve(ctx.getLanguage()).decode(toArray(tokens));
+ }
+
+ public Encoding encode(String text) { return encode(text, Language.UNKNOWN); }
+ public Encoding encode(String text, Language language) { return Encoding.from(resolve(language).encode(text)); }
+ public String decode(List<Long> tokens) { return decode(tokens, Language.UNKNOWN); }
+ public String decode(List<Long> tokens, Language language) { return resolve(language).decode(toArray(tokens)); }
+
+ @Override public void close() { models.forEach((__, model) -> model.close()); }
+
+ private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) {
+ // Disregard language if there is default model
+ if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
+ if (models.containsKey(language)) return models.get(language);
+ throw new IllegalArgumentException("No model for language " + language);
+ }
+
+ private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); }
+
+ public static final class Builder {
+ private final Map<Language, Path> models = new EnumMap<>(Language.class);
+ private Boolean addSpecialTokens;
+
+ public Builder() {}
+ public Builder(HuggingFaceTokenizerConfig cfg) {
+ for (var model : cfg.model())
+ addModel(Language.fromLanguageTag(model.language()), model.path());
+ addSpecialTokens(cfg.addSpecialTokens());
+ }
+
+ public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
+ public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); }
+ public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; }
+ public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); }
+ }
+
+} \ No newline at end of file
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java
new file mode 100644
index 00000000000..7cec01ffed6
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java
@@ -0,0 +1,9 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+/**
+ * @author bjorncs
+ */
+@ExportPackage
+package com.yahoo.language.huggingface;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
new file mode 100644
index 00000000000..a3e54ea38da
--- /dev/null
+++ b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def
@@ -0,0 +1,11 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+namespace=language.huggingface
+
+# The language a model is for, one of the language tags in com.yahoo.language.Language.
+# Use "unknown" for models to be used with any language.
+model[].language string
+# The path to the model relative to the application package root
+model[].path path
+
+addSpecialTokens bool default=true \ No newline at end of file
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
new file mode 100644
index 00000000000..c79ecbfbfbe
--- /dev/null
+++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
@@ -0,0 +1,88 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.huggingface;
+
+import com.yahoo.language.tools.EmbedderTester;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.StandardOpenOption;
+import java.util.zip.GZIPInputStream;
+
+/**
+ * @author bjorncs
+ */
+class HuggingFaceTokenizerTest {
+
+ @TempDir Path tmp;
+
+ @Test
+ void bert_tokenizer() throws IOException {
+ try (var tokenizer = createTokenizer(tmp, "bert-base-uncased")) {
+ var tester = new EmbedderTester(tokenizer);
+ tester.assertSegmented("what was the impact of the manhattan project",
+ "what", "was", "the", "impact", "of", "the", "manhattan", "project");
+ tester.assertSegmented("overcommunication", "over", "##com", "##mun", "##ication");
+ tester.assertEmbedded("what was the impact of the manhattan project",
+ "tensor(x[8])",
+ 2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622);
+ tester.assertDecoded("what was the impact of the manhattan project");
+ }
+ }
+
+ @Test
+ void tokenizes_using_paraphrase_multilingual_mpnet_base_v2() throws IOException {
+ try (var tokenizer = createTokenizer(tmp, "paraphrase-multilingual-mpnet-base-v2")) {
+ var tester = new EmbedderTester(tokenizer);
+ tester.assertSegmented("h", "▁h");
+ tester.assertSegmented("he", "▁he");
+ tester.assertSegmented("hel", "▁hel");
+ tester.assertSegmented("hello", "▁hell", "o");
+ tester.assertSegmented("hei", "▁hei");
+ tester.assertSegmented("hei you", "▁hei", "▁you");
+ tester.assertSegmented("hei you", "▁hei", "▁you");
+ tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence");
+ tester.assertSegmented("hello world!", "▁hell", "o", "▁world", "!");
+ tester.assertSegmented("Hello, world!", "▁Hello", ",", "▁world", "!");
+ tester.assertSegmented("HELLO, world!", "▁H", "ELLO", ",", "▁world", "!");
+ tester.assertSegmented("KHJKJHHKJHHSH", "▁KH", "JK", "J", "H", "HK", "J", "HH", "SH");
+ tester.assertSegmented("KHJKJHHKJHHSH hello", "▁KH", "JK", "J", "H", "HK", "J", "HH", "SH", "▁hell", "o");
+ tester.assertSegmented(" hello ", "▁hell", "o");
+ tester.assertSegmented(")(/&#()/\"\")", "▁", ")(", "/", "&#", "(", ")", "/", "\"", "\")");
+ tester.assertSegmented(")(/&#(small)/\"in quotes\")", "▁", ")(", "/", "&#", "(", "s", "mall", ")", "/", "\"", "in", "▁quote", "s", "\")");
+ tester.assertSegmented("x.400AS", "▁x", ".", "400", "AS");
+ tester.assertSegmented("A normal sentence. Yes one more.", "▁A", "▁normal", "▁sentence", ".", "▁Yes", "▁one", "▁more", ".");
+
+ tester.assertEmbedded("hello, world!", "tensor(d[10])", 33600, 31, 4, 8999, 38);
+ tester.assertEmbedded("Hello, world!", "tensor(d[10])", 35378, 4, 8999, 38);
+ tester.assertEmbedded("hello, world!", "tensor(d[2])", 33600, 31, 4, 8999, 38);
+
+ tester.assertDecoded("this is a sentence");
+ tester.assertDecoded("hello, world!");
+ tester.assertDecoded(")(/&#(small)/ \"in quotes\")");
+ }
+ }
+
+ private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException {
+ return new HuggingFaceTokenizer.Builder()
+ .addSpecialTokens(false)
+ .addDefaultModel(decompressModelFile(tmp, Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model))))
+ .build();
+ }
+
+ private static Path decompressModelFile(Path tmp, Path source) throws IOException {
+ Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", ""));
+ try (InputStream in = new GZIPInputStream(Files.newInputStream(source));
+ OutputStream out = Files.newOutputStream(destination, StandardOpenOption.CREATE)) {
+ in.transferTo(out);
+ }
+ return destination;
+ }
+
+} \ No newline at end of file
diff --git a/linguistics-components/src/test/models/huggingface/bert-base-uncased.json.gz b/linguistics-components/src/test/models/huggingface/bert-base-uncased.json.gz
new file mode 100644
index 00000000000..7d0541849f7
--- /dev/null
+++ b/linguistics-components/src/test/models/huggingface/bert-base-uncased.json.gz
Binary files differ
diff --git a/linguistics-components/src/test/models/huggingface/paraphrase-multilingual-mpnet-base-v2.json.gz b/linguistics-components/src/test/models/huggingface/paraphrase-multilingual-mpnet-base-v2.json.gz
new file mode 100644
index 00000000000..7b61a27198c
--- /dev/null
+++ b/linguistics-components/src/test/models/huggingface/paraphrase-multilingual-mpnet-base-v2.json.gz
Binary files differ
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index cc5eccff2ac..681003fdc89 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -91,26 +91,6 @@
<artifactId>protobuf-java</artifactId>
</dependency>
- <dependency>
- <groupId>ai.djl.huggingface</groupId>
- <artifactId>tokenizers</artifactId>
- <version>0.22.1</version>
- <exclusions>
- <exclusion>
- <groupId>com.google.code.gson</groupId>
- <artifactId>gson</artifactId>
- </exclusion>
- <exclusion>
- <groupId>net.java.dev.jna</groupId>
- <artifactId>jna</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.slf4j</groupId>
- <artifactId>slf4j-api</artifactId>
- </exclusion>
- </exclusions>
- </dependency>
-
<dependency>
<groupId>org.lz4</groupId>
<artifactId>lz4-java</artifactId>
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 365a50f47b5..0c1cc80544e 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -3,9 +3,11 @@ package ai.vespa.embedding.huggingface;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
+import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.language.huggingface.HuggingFaceTokenizer;
import com.yahoo.language.process.Embedder;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -14,17 +16,18 @@ import com.yahoo.tensor.TensorType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
+@Beta
public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private static final Logger LOG = LoggerFactory.getLogger(HuggingFaceEmbedder.class.getName());
private final String inputIdsName;
private final String attentionMaskName;
+ private final String tokenTypeIdsName;
private final String outputName;
private final int maxTokens;
private final boolean normalize;
@@ -32,13 +35,17 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final OnnxEvaluator evaluator;
@Inject
- public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) throws IOException {
+ public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) {
maxTokens = config.transformerMaxTokens();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
+ tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
normalize = config.normalize();
- tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString()));
+ tokenizer = new HuggingFaceTokenizer.Builder()
+ .addSpecialTokens(true)
+ .addDefaultModel(Paths.get(config.tokenizerPath().toString()))
+ .build();
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)
onnxOpts.setGpuDevice(config.transformerGpuDevice());
@@ -52,6 +59,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
Map<String, TensorType> inputs = evaluator.getInputInfo();
validateName(inputs, inputIdsName, "input");
validateName(inputs, attentionMaskName, "input");
+ if (!tokenTypeIdsName.isEmpty()) validateName(inputs, tokenTypeIdsName, "input");
Map<String, TensorType> outputs = evaluator.getOutputInfo();
validateName(outputs, outputName, "output");
@@ -66,8 +74,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
@Override
public List<Integer> embed(String s, Context context) {
- Encoding encoding = tokenizer.encode(s);
- List<Integer> tokenIds = encoding.ids().stream().map(Long::intValue).toList();
+ var tokenIds = tokenizer.embed(s, context);
int tokensSize = tokenIds.size();
@@ -87,18 +94,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
@Override
public Tensor embed(String s, Context context, TensorType tensorType) {
- List<Integer> tokenIds = embed(s.toLowerCase(), context);
- return embedTokens(tokenIds, tensorType);
- }
-
- Tensor embedTokens(List<Integer> tokenIds, TensorType tensorType) {
- Tensor inputSequence = createTensorRepresentation(tokenIds, "d1");
- Tensor attentionMask = createAttentionMask(inputSequence);
-
- Map<String, Tensor> inputs = Map.of(
- inputIdsName, inputSequence.expand("d0"),
- attentionMaskName, attentionMask.expand("d0")
- );
+ var encoding = tokenizer.encode(s, context.getLanguage());
+ Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1");
+ Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
+ Tensor tokenTypeIds = createTensorRepresentation(encoding.typeIds(), "d1");
+
+
+ Map<String, Tensor> inputs;
+ if (tokenTypeIds.isEmpty()) {
+ inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0"));
+ } else {
+ inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0"),
+ tokenTypeIdsName, tokenTypeIds.expand("d0"));
+ }
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);
@@ -136,7 +146,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
- private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) {
+ private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
int size = input.size();
TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java
deleted file mode 100644
index e6765a4cc8a..00000000000
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java
+++ /dev/null
@@ -1,47 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.embedding.huggingface;
-
-import java.io.IOException;
-import java.nio.file.Path;
-
-/**
- * Wrapper around Deep Java Library's HuggingFace tokenizer.
- *
- * @author bjorncs
- */
-public class HuggingFaceTokenizer implements AutoCloseable {
-
- private final ai.djl.huggingface.tokenizers.HuggingFaceTokenizer instance;
-
- public HuggingFaceTokenizer(Path path) throws IOException { this(path, HuggingFaceTokenizerOptions.defaults()); }
-
- public HuggingFaceTokenizer(Path path, HuggingFaceTokenizerOptions opts) throws IOException {
- var original = Thread.currentThread().getContextClassLoader();
- Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
- try {
- instance = createInstance(path, opts);
- } finally {
- Thread.currentThread().setContextClassLoader(original);
- }
- }
-
- public Encoding encode(String text) { return Encoding.from(instance.encode(text)); }
-
- @Override public void close() { instance.close(); }
-
- private static ai.djl.huggingface.tokenizers.HuggingFaceTokenizer createInstance(
- Path path, HuggingFaceTokenizerOptions opts) throws IOException {
- var builder = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder().optTokenizerPath(path);
- opts.addSpecialToken().ifPresent(builder::optAddSpecialTokens);
- opts.truncation().ifPresent(builder::optTruncation);
- if (opts.truncateFirstOnly()) builder.optTruncateFirstOnly();
- if (opts.truncateSecondOnly()) builder.optTruncateSecondOnly();
- opts.padding().ifPresent(builder::optPadding);
- if (opts.padToMaxLength()) builder.optPadToMaxLength();
- opts.maxLength().ifPresent(builder::optMaxLength);
- opts.padToMultipleOf().ifPresent(builder::optPadToMultipleOf);
- opts.stride().ifPresent(builder::optStride);
- return builder.build();
- }
-}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java
deleted file mode 100644
index 74f80756603..00000000000
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.embedding.huggingface;
-
-import java.util.Optional;
-import java.util.OptionalInt;
-
-/**
- * @author bjorncs
- */
-public class HuggingFaceTokenizerOptions {
-
- private final Boolean addSpecialToken;
- private final Boolean truncation;
- private final boolean truncateFirstOnly;
- private final boolean truncateSecondOnly;
- private final Boolean padding;
- private final boolean padToMaxLength;
- private final Integer maxLength;
- private final Integer padToMultipleOf;
- private final Integer stride;
-
- private HuggingFaceTokenizerOptions(Builder b) {
- addSpecialToken = b.addSpecialToken;
- truncation = b.truncation;
- truncateFirstOnly = b.truncateFirstOnly;
- truncateSecondOnly = b.truncateSecondOnly;
- padding = b.padding;
- padToMaxLength = b.padToMaxLength;
- maxLength = b.maxLength;
- padToMultipleOf = b.padToMultipleOf;
- stride = b.stride;
- }
-
- public static Builder custom() { return new Builder(); }
- public static HuggingFaceTokenizerOptions defaults() { return new Builder().build(); }
-
- Optional<Boolean> addSpecialToken() { return Optional.ofNullable(addSpecialToken); }
- Optional<Boolean> truncation() { return Optional.ofNullable(truncation); }
- boolean truncateFirstOnly() { return truncateFirstOnly; }
- boolean truncateSecondOnly() { return truncateSecondOnly; }
- Optional<Boolean> padding() { return Optional.ofNullable(padding); }
- boolean padToMaxLength() { return padToMaxLength; }
- OptionalInt maxLength() { return maxLength != null ? OptionalInt.of(maxLength) : OptionalInt.empty(); }
- OptionalInt padToMultipleOf() { return padToMultipleOf != null ? OptionalInt.of(padToMultipleOf) : OptionalInt.empty(); }
- OptionalInt stride() { return stride != null ? OptionalInt.of(stride) : OptionalInt.empty(); }
-
- public static class Builder {
- private Boolean addSpecialToken;
- private Boolean truncation;
- private boolean truncateFirstOnly;
- private boolean truncateSecondOnly;
- private Boolean padding;
- private boolean padToMaxLength;
- private Integer maxLength;
- private Integer padToMultipleOf;
- private Integer stride;
-
- public Builder addSpecialToken(boolean enabled) { addSpecialToken = enabled; return this; }
- public Builder truncation(boolean enabled) { truncation = enabled; return this; }
- public Builder truncateFirstOnly() { truncateFirstOnly = true; return this; }
- public Builder truncateSecondOnly() { truncateSecondOnly = true; return this; }
- public Builder padding(boolean enabled) { padding = enabled; return this; }
- public Builder padToMaxLength() { padToMaxLength = true; return this; }
- public Builder maxLength(int length) { maxLength = length; return this; }
- public Builder padToMultipleOf(int num) { padToMultipleOf = num; return this; }
- public Builder stride(int stride) { this.stride = stride; return this; }
- public HuggingFaceTokenizerOptions build() { return new HuggingFaceTokenizerOptions(this); }
- }
-
-}
diff --git a/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def
index 1dccea0ddf6..97515818f14 100644
--- a/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def
@@ -12,6 +12,7 @@ transformerMaxTokens int default=512
# Input names
transformerInputIds string default=input_ids
transformerAttentionMask string default=attention_mask
+transformerTokenTypeIds string default=token_type_ids
# Output name
transformerOutput string default=last_hidden_state