diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-15 09:58:25 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-15 09:58:25 +0200 |
commit | a9cf0db43a852ef7f43e275af582786a2be62009 (patch) | |
tree | eee4301b9c6c748e8a85a2af5d776d33de2d3c84 | |
parent | 93aada3ca9572712e06ce2125e7e5f915111cef4 (diff) | |
parent | 814a750816a5213697f2c3a3f91bb18d368c5146 (diff) |
Merge pull request #27088 from vespa-engine/bjorncs/huggingface
Bjorncs/huggingface
15 files changed, 279 insertions, 166 deletions
diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml index c8c0c30e820..166236f91a0 100644 --- a/cloud-tenant-base-dependencies-enforcer/pom.xml +++ b/cloud-tenant-base-dependencies-enforcer/pom.xml @@ -175,8 +175,6 @@ <include>com.yahoo.vespa:vsm:*:test</include> <!-- 3rd party test dependencies --> - <include>ai.djl:api:jar:0.22.1:test</include> - <include>ai.djl.huggingface:tokenizers:jar:0.22.1:test</include> <include>com.google.code.findbugs:jsr305:3.0.2:test</include> <include>com.google.protobuf:protobuf-java:3.21.7:test</include> <include>com.ibm.icu:icu4j:70.1:test</include> @@ -193,7 +191,6 @@ <include>org.antlr:antlr4-runtime:4.11.1:test</include> <include>org.apache.commons:commons-exec:1.3:test</include> <include>org.apache.commons:commons-math3:3.6.1:test</include> - <include>org.apache.commons:commons-compress:jar:1.22:test</include> <include>org.apache.felix:org.apache.felix.framework:${felix.version}:test</include> <include>org.apache.felix:org.apache.felix.log:1.0.1:test</include> <include>org.apache.httpcomponents.client5:httpclient5:${httpclient5.version}:test</include> diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml index 0ae2c68e6a8..b2ee5c64b38 100644 --- a/fat-model-dependencies/pom.xml +++ b/fat-model-dependencies/pom.xml @@ -93,10 +93,6 @@ <version>${project.version}</version> <exclusions> <exclusion> - <groupId>ai.djl.huggingface</groupId> - <artifactId>tokenizers</artifactId> - </exclusion> - <exclusion> <groupId>com.theokanning.openai-gpt3-java</groupId> <artifactId>service</artifactId> </exclusion> diff --git a/linguistics-components/pom.xml b/linguistics-components/pom.xml index ad4cbd6ce22..5031ad73556 100644 --- a/linguistics-components/pom.xml +++ b/linguistics-components/pom.xml @@ -19,12 +19,42 @@ <artifactId>protobuf-java</artifactId> </dependency> <dependency> - <groupId>junit</groupId> - <artifactId>junit</artifactId> + <groupId>ai.djl.huggingface</groupId> + <artifactId>tokenizers</artifactId> + <version>0.22.1</version> + <exclusions> + <exclusion> + <groupId>com.google.code.gson</groupId> + <artifactId>gson</artifactId> + </exclusion> + <exclusion> + <groupId>net.java.dev.jna</groupId> + <artifactId>jna</artifactId> + </exclusion> + <exclusion> + <groupId>org.slf4j</groupId> + <artifactId>slf4j-api</artifactId> + </exclusion> + </exclusions> + </dependency> + <dependency> + <groupId>org.junit.jupiter</groupId> + <artifactId>junit-jupiter-engine</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.junit.vintage</groupId> + <artifactId>junit-vintage-engine</artifactId> <scope>test</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>jdisc_core</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>annotations</artifactId> <version>${project.version}</version> <scope>provided</scope> diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java index 274c29a57b2..107900ff73c 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/Encoding.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/Encoding.java @@ -1,6 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.embedding.huggingface; +package com.yahoo.language.huggingface; + +import com.yahoo.api.annotations.Beta; import java.util.ArrayList; import java.util.Arrays; @@ -9,6 +11,7 @@ import java.util.List; /** * @author bjorncs */ +@Beta public record Encoding( List<Long> ids, List<Long> typeIds, List<String> tokens, List<Long> wordIds, List<Long> attentionMask, List<Long> specialTokenMask, List<CharSpan> charTokenSpans, List<Encoding> overflowing) { diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java new file mode 100644 index 00000000000..b92e0678970 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -0,0 +1,106 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.huggingface; + +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; +import com.yahoo.language.process.Segmenter; +import com.yahoo.language.tools.Embed; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Collection; +import java.util.EnumMap; +import java.util.List; +import java.util.Map; + +import static com.yahoo.yolean.Exceptions.uncheck; + +/** + * {@link Embedder}/{@link Segmenter} using Deep Java Library's HuggingFace Tokenizer. + * + * @author bjorncs + */ +@Beta +public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable { + + private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class); + + @Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); } + + private HuggingFaceTokenizer(Builder b) { + var original = Thread.currentThread().getContextClassLoader(); + Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); + try { + b.models.forEach((language, path) -> { + models.put(language, + uncheck(() -> ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() + .optTokenizerPath(path) + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) + .build())); + }); + } finally { + Thread.currentThread().setContextClassLoader(original); + } + } + + @Override + public List<Integer> embed(String text, Context ctx) { + var encoding = resolve(ctx.getLanguage()).encode(text); + return Arrays.stream(encoding.getIds()).mapToInt(Math::toIntExact).boxed().toList(); + } + + @Override + public Tensor embed(String text, Context ctx, TensorType type) { + return Embed.asTensor(text, this, ctx, type); + } + + @Override + public List<String> segment(String input, Language language) { + return List.of(resolve(language).encode(input).getTokens()); + } + + @Override + public String decode(List<Integer> tokens, Context ctx) { + return resolve(ctx.getLanguage()).decode(toArray(tokens)); + } + + public Encoding encode(String text) { return encode(text, Language.UNKNOWN); } + public Encoding encode(String text, Language language) { return Encoding.from(resolve(language).encode(text)); } + public String decode(List<Long> tokens) { return decode(tokens, Language.UNKNOWN); } + public String decode(List<Long> tokens, Language language) { return resolve(language).decode(toArray(tokens)); } + + @Override public void close() { models.forEach((__, model) -> model.close()); } + + private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) { + // Disregard language if there is default model + if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); + if (models.containsKey(language)) return models.get(language); + throw new IllegalArgumentException("No model for language " + language); + } + + private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); } + + public static final class Builder { + private final Map<Language, Path> models = new EnumMap<>(Language.class); + private Boolean addSpecialTokens; + + public Builder() {} + public Builder(HuggingFaceTokenizerConfig cfg) { + for (var model : cfg.model()) + addModel(Language.fromLanguageTag(model.language()), model.path()); + addSpecialTokens(cfg.addSpecialTokens()); + } + + public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } + public Builder addDefaultModel(Path path) { return addModel(Language.UNKNOWN, path); } + public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; } + public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); } + } + +}
\ No newline at end of file diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java new file mode 100644 index 00000000000..7cec01ffed6 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/package-info.java @@ -0,0 +1,9 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +/** + * @author bjorncs + */ +@ExportPackage +package com.yahoo.language.huggingface; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def new file mode 100644 index 00000000000..5e58547879c --- /dev/null +++ b/linguistics-components/src/main/resources/configdefinitions/language.huggingface.hugging-face-tokenizer.def @@ -0,0 +1,11 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +namespace=language.huggingface + +# The language a model is for, one of the language tags in com.yahoo.language.Language. +# Use "unknown" for models to be used with any language. +model[].language string +# The path to the model relative to the application package root +model[].path model + +addSpecialTokens bool default=true
\ No newline at end of file diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java new file mode 100644 index 00000000000..c79ecbfbfbe --- /dev/null +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -0,0 +1,88 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.huggingface; + +import com.yahoo.language.tools.EmbedderTester; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.nio.file.StandardOpenOption; +import java.util.zip.GZIPInputStream; + +/** + * @author bjorncs + */ +class HuggingFaceTokenizerTest { + + @TempDir Path tmp; + + @Test + void bert_tokenizer() throws IOException { + try (var tokenizer = createTokenizer(tmp, "bert-base-uncased")) { + var tester = new EmbedderTester(tokenizer); + tester.assertSegmented("what was the impact of the manhattan project", + "what", "was", "the", "impact", "of", "the", "manhattan", "project"); + tester.assertSegmented("overcommunication", "over", "##com", "##mun", "##ication"); + tester.assertEmbedded("what was the impact of the manhattan project", + "tensor(x[8])", + 2054, 2001, 1996, 4254, 1997, 1996, 7128, 2622); + tester.assertDecoded("what was the impact of the manhattan project"); + } + } + + @Test + void tokenizes_using_paraphrase_multilingual_mpnet_base_v2() throws IOException { + try (var tokenizer = createTokenizer(tmp, "paraphrase-multilingual-mpnet-base-v2")) { + var tester = new EmbedderTester(tokenizer); + tester.assertSegmented("h", "▁h"); + tester.assertSegmented("he", "▁he"); + tester.assertSegmented("hel", "▁hel"); + tester.assertSegmented("hello", "▁hell", "o"); + tester.assertSegmented("hei", "▁hei"); + tester.assertSegmented("hei you", "▁hei", "▁you"); + tester.assertSegmented("hei you", "▁hei", "▁you"); + tester.assertSegmented("this is another sentence", "▁this", "▁is", "▁another", "▁sentence"); + tester.assertSegmented("hello world!", "▁hell", "o", "▁world", "!"); + tester.assertSegmented("Hello, world!", "▁Hello", ",", "▁world", "!"); + tester.assertSegmented("HELLO, world!", "▁H", "ELLO", ",", "▁world", "!"); + tester.assertSegmented("KHJKJHHKJHHSH", "▁KH", "JK", "J", "H", "HK", "J", "HH", "SH"); + tester.assertSegmented("KHJKJHHKJHHSH hello", "▁KH", "JK", "J", "H", "HK", "J", "HH", "SH", "▁hell", "o"); + tester.assertSegmented(" hello ", "▁hell", "o"); + tester.assertSegmented(")(/&#()/\"\")", "▁", ")(", "/", "&#", "(", ")", "/", "\"", "\")"); + tester.assertSegmented(")(/&#(small)/\"in quotes\")", "▁", ")(", "/", "&#", "(", "s", "mall", ")", "/", "\"", "in", "▁quote", "s", "\")"); + tester.assertSegmented("x.400AS", "▁x", ".", "400", "AS"); + tester.assertSegmented("A normal sentence. Yes one more.", "▁A", "▁normal", "▁sentence", ".", "▁Yes", "▁one", "▁more", "."); + + tester.assertEmbedded("hello, world!", "tensor(d[10])", 33600, 31, 4, 8999, 38); + tester.assertEmbedded("Hello, world!", "tensor(d[10])", 35378, 4, 8999, 38); + tester.assertEmbedded("hello, world!", "tensor(d[2])", 33600, 31, 4, 8999, 38); + + tester.assertDecoded("this is a sentence"); + tester.assertDecoded("hello, world!"); + tester.assertDecoded(")(/&#(small)/ \"in quotes\")"); + } + } + + private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException { + return new HuggingFaceTokenizer.Builder() + .addSpecialTokens(false) + .addDefaultModel(decompressModelFile(tmp, Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)))) + .build(); + } + + private static Path decompressModelFile(Path tmp, Path source) throws IOException { + Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", "")); + try (InputStream in = new GZIPInputStream(Files.newInputStream(source)); + OutputStream out = Files.newOutputStream(destination, StandardOpenOption.CREATE)) { + in.transferTo(out); + } + return destination; + } + +}
\ No newline at end of file diff --git a/linguistics-components/src/test/models/huggingface/bert-base-uncased.json.gz b/linguistics-components/src/test/models/huggingface/bert-base-uncased.json.gz Binary files differnew file mode 100644 index 00000000000..7d0541849f7 --- /dev/null +++ b/linguistics-components/src/test/models/huggingface/bert-base-uncased.json.gz diff --git a/linguistics-components/src/test/models/huggingface/paraphrase-multilingual-mpnet-base-v2.json.gz b/linguistics-components/src/test/models/huggingface/paraphrase-multilingual-mpnet-base-v2.json.gz Binary files differnew file mode 100644 index 00000000000..7b61a27198c --- /dev/null +++ b/linguistics-components/src/test/models/huggingface/paraphrase-multilingual-mpnet-base-v2.json.gz diff --git a/model-integration/pom.xml b/model-integration/pom.xml index cc5eccff2ac..681003fdc89 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -91,26 +91,6 @@ <artifactId>protobuf-java</artifactId> </dependency> - <dependency> - <groupId>ai.djl.huggingface</groupId> - <artifactId>tokenizers</artifactId> - <version>0.22.1</version> - <exclusions> - <exclusion> - <groupId>com.google.code.gson</groupId> - <artifactId>gson</artifactId> - </exclusion> - <exclusion> - <groupId>net.java.dev.jna</groupId> - <artifactId>jna</artifactId> - </exclusion> - <exclusion> - <groupId>org.slf4j</groupId> - <artifactId>slf4j-api</artifactId> - </exclusion> - </exclusions> - </dependency> - <dependency> <groupId>org.lz4</groupId> <artifactId>lz4-java</artifactId> diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 365a50f47b5..0c1cc80544e 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -3,9 +3,11 @@ package ai.vespa.embedding.huggingface; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; import ai.vespa.modelintegration.evaluator.OnnxRuntime; +import com.yahoo.api.annotations.Beta; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import com.yahoo.language.huggingface.HuggingFaceTokenizer; import com.yahoo.language.process.Embedder; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -14,17 +16,18 @@ import com.yahoo.tensor.TensorType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; import java.nio.file.Paths; import java.util.List; import java.util.Map; +@Beta public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private static final Logger LOG = LoggerFactory.getLogger(HuggingFaceEmbedder.class.getName()); private final String inputIdsName; private final String attentionMaskName; + private final String tokenTypeIdsName; private final String outputName; private final int maxTokens; private final boolean normalize; @@ -32,13 +35,17 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final OnnxEvaluator evaluator; @Inject - public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) throws IOException { + public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) { maxTokens = config.transformerMaxTokens(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); + tokenTypeIdsName = config.transformerTokenTypeIds(); outputName = config.transformerOutput(); normalize = config.normalize(); - tokenizer = new HuggingFaceTokenizer(Paths.get(config.tokenizerPath().toString())); + tokenizer = new HuggingFaceTokenizer.Builder() + .addSpecialTokens(true) + .addDefaultModel(Paths.get(config.tokenizerPath().toString())) + .build(); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) onnxOpts.setGpuDevice(config.transformerGpuDevice()); @@ -52,6 +59,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { Map<String, TensorType> inputs = evaluator.getInputInfo(); validateName(inputs, inputIdsName, "input"); validateName(inputs, attentionMaskName, "input"); + if (!tokenTypeIdsName.isEmpty()) validateName(inputs, tokenTypeIdsName, "input"); Map<String, TensorType> outputs = evaluator.getOutputInfo(); validateName(outputs, outputName, "output"); @@ -66,8 +74,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { @Override public List<Integer> embed(String s, Context context) { - Encoding encoding = tokenizer.encode(s); - List<Integer> tokenIds = encoding.ids().stream().map(Long::intValue).toList(); + var tokenIds = tokenizer.embed(s, context); int tokensSize = tokenIds.size(); @@ -87,18 +94,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { @Override public Tensor embed(String s, Context context, TensorType tensorType) { - List<Integer> tokenIds = embed(s.toLowerCase(), context); - return embedTokens(tokenIds, tensorType); - } - - Tensor embedTokens(List<Integer> tokenIds, TensorType tensorType) { - Tensor inputSequence = createTensorRepresentation(tokenIds, "d1"); - Tensor attentionMask = createAttentionMask(inputSequence); - - Map<String, Tensor> inputs = Map.of( - inputIdsName, inputSequence.expand("d0"), - attentionMaskName, attentionMask.expand("d0") - ); + var encoding = tokenizer.encode(s, context.getLanguage()); + Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1"); + Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1"); + Tensor tokenTypeIds = createTensorRepresentation(encoding.typeIds(), "d1"); + + + Map<String, Tensor> inputs; + if (tokenTypeIds.isEmpty()) { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0")); + } else { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0"), + tokenTypeIdsName, tokenTypeIds.expand("d0")); + } Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); @@ -136,7 +146,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) { + private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) { int size = input.size(); TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java deleted file mode 100644 index e6765a4cc8a..00000000000 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizer.java +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.embedding.huggingface; - -import java.io.IOException; -import java.nio.file.Path; - -/** - * Wrapper around Deep Java Library's HuggingFace tokenizer. - * - * @author bjorncs - */ -public class HuggingFaceTokenizer implements AutoCloseable { - - private final ai.djl.huggingface.tokenizers.HuggingFaceTokenizer instance; - - public HuggingFaceTokenizer(Path path) throws IOException { this(path, HuggingFaceTokenizerOptions.defaults()); } - - public HuggingFaceTokenizer(Path path, HuggingFaceTokenizerOptions opts) throws IOException { - var original = Thread.currentThread().getContextClassLoader(); - Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); - try { - instance = createInstance(path, opts); - } finally { - Thread.currentThread().setContextClassLoader(original); - } - } - - public Encoding encode(String text) { return Encoding.from(instance.encode(text)); } - - @Override public void close() { instance.close(); } - - private static ai.djl.huggingface.tokenizers.HuggingFaceTokenizer createInstance( - Path path, HuggingFaceTokenizerOptions opts) throws IOException { - var builder = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder().optTokenizerPath(path); - opts.addSpecialToken().ifPresent(builder::optAddSpecialTokens); - opts.truncation().ifPresent(builder::optTruncation); - if (opts.truncateFirstOnly()) builder.optTruncateFirstOnly(); - if (opts.truncateSecondOnly()) builder.optTruncateSecondOnly(); - opts.padding().ifPresent(builder::optPadding); - if (opts.padToMaxLength()) builder.optPadToMaxLength(); - opts.maxLength().ifPresent(builder::optMaxLength); - opts.padToMultipleOf().ifPresent(builder::optPadToMultipleOf); - opts.stride().ifPresent(builder::optStride); - return builder.build(); - } -} diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java deleted file mode 100644 index 74f80756603..00000000000 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceTokenizerOptions.java +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.embedding.huggingface; - -import java.util.Optional; -import java.util.OptionalInt; - -/** - * @author bjorncs - */ -public class HuggingFaceTokenizerOptions { - - private final Boolean addSpecialToken; - private final Boolean truncation; - private final boolean truncateFirstOnly; - private final boolean truncateSecondOnly; - private final Boolean padding; - private final boolean padToMaxLength; - private final Integer maxLength; - private final Integer padToMultipleOf; - private final Integer stride; - - private HuggingFaceTokenizerOptions(Builder b) { - addSpecialToken = b.addSpecialToken; - truncation = b.truncation; - truncateFirstOnly = b.truncateFirstOnly; - truncateSecondOnly = b.truncateSecondOnly; - padding = b.padding; - padToMaxLength = b.padToMaxLength; - maxLength = b.maxLength; - padToMultipleOf = b.padToMultipleOf; - stride = b.stride; - } - - public static Builder custom() { return new Builder(); } - public static HuggingFaceTokenizerOptions defaults() { return new Builder().build(); } - - Optional<Boolean> addSpecialToken() { return Optional.ofNullable(addSpecialToken); } - Optional<Boolean> truncation() { return Optional.ofNullable(truncation); } - boolean truncateFirstOnly() { return truncateFirstOnly; } - boolean truncateSecondOnly() { return truncateSecondOnly; } - Optional<Boolean> padding() { return Optional.ofNullable(padding); } - boolean padToMaxLength() { return padToMaxLength; } - OptionalInt maxLength() { return maxLength != null ? OptionalInt.of(maxLength) : OptionalInt.empty(); } - OptionalInt padToMultipleOf() { return padToMultipleOf != null ? OptionalInt.of(padToMultipleOf) : OptionalInt.empty(); } - OptionalInt stride() { return stride != null ? OptionalInt.of(stride) : OptionalInt.empty(); } - - public static class Builder { - private Boolean addSpecialToken; - private Boolean truncation; - private boolean truncateFirstOnly; - private boolean truncateSecondOnly; - private Boolean padding; - private boolean padToMaxLength; - private Integer maxLength; - private Integer padToMultipleOf; - private Integer stride; - - public Builder addSpecialToken(boolean enabled) { addSpecialToken = enabled; return this; } - public Builder truncation(boolean enabled) { truncation = enabled; return this; } - public Builder truncateFirstOnly() { truncateFirstOnly = true; return this; } - public Builder truncateSecondOnly() { truncateSecondOnly = true; return this; } - public Builder padding(boolean enabled) { padding = enabled; return this; } - public Builder padToMaxLength() { padToMaxLength = true; return this; } - public Builder maxLength(int length) { maxLength = length; return this; } - public Builder padToMultipleOf(int num) { padToMultipleOf = num; return this; } - public Builder stride(int stride) { this.stride = stride; return this; } - public HuggingFaceTokenizerOptions build() { return new HuggingFaceTokenizerOptions(this); } - } - -} diff --git a/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def index 1dccea0ddf6..97515818f14 100644 --- a/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def +++ b/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def @@ -12,6 +12,7 @@ transformerMaxTokens int default=512 # Input names transformerInputIds string default=input_ids transformerAttentionMask string default=attention_mask +transformerTokenTypeIds string default=token_type_ids # Output name transformerOutput string default=last_hidden_state |