diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-06-12 17:25:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-12 17:25:55 +0200 |
commit | fb5d1bf9f451fbeb4a40d7f73fa856ef81bd77ed (patch) | |
tree | 1f8ab291370e84407a827f9a80bcf943f522ca29 | |
parent | 0647b650c3334ff86d50431e78549e25dc46caf9 (diff) | |
parent | 4f722322cc9f8df5146ffb27d74239b3b4f2d634 (diff) |
Merge pull request #27387 from vespa-engine/bjorncs/hfv8.176.13
Prefer truncation configuration from tokenizer model
9 files changed, 150 insertions, 47 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java index e0572f8391e..0bf5491e872 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java @@ -4,6 +4,8 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; +import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Padding; +import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Truncation; import com.yahoo.text.XML; import com.yahoo.vespa.model.container.xml.ModelIdResolver; import org.w3c.dom.Element; @@ -11,7 +13,6 @@ import org.w3c.dom.Element; import java.util.Map; import java.util.TreeMap; -import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME; /** @@ -20,10 +21,6 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTI public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceTokenizerConfig.Producer { private final Map<String, ModelReference> langToModel = new TreeMap<>(); - private final Boolean specialTokens; - private final Integer maxLength; - private final Boolean truncation; - private final Boolean padding; public HuggingFaceTokenizer(Element xml, DeployState state) { super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml); @@ -31,10 +28,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown"; langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state)); } - specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null); - maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null); - truncation = getOptionalChildValue(xml, "truncation").map(Boolean::parseBoolean).orElse(null); - padding = getOptionalChildValue(xml, "padding").map(Boolean::parseBoolean).orElse(null); } @Override @@ -42,9 +35,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT langToModel.forEach((lang, vocab) -> { builder.model.add(new HuggingFaceTokenizerConfig.Model.Builder().language(lang).path(vocab)); }); - if (specialTokens != null) builder.addSpecialTokens(specialTokens); - if (maxLength != null) builder.maxLength(maxLength); - if (truncation != null) builder.truncation(truncation); - if (padding != null) builder.padding(padding); + builder.truncation(Truncation.Enum.OFF).padding(Padding.Enum.OFF).addSpecialTokens(false); } } diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index e130bed0297..ba7e2b6674e 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -88,7 +88,7 @@ HuggingFaceEmbedder = attribute type { "hugging-face-embedder" } & element transformer-model { ModelReference } & element tokenizer-model { ModelReference }? & - element max-tokens { xsd:nonNegativeInteger }? & + element max-tokens { xsd:positiveInteger }? & element transformer-input-ids { xsd:string }? & element transformer-attention-mask { xsd:string }? & element transformer-token-type-ids { xsd:string }? & @@ -99,11 +99,7 @@ HuggingFaceEmbedder = HuggingFaceTokenizer = attribute type { "hugging-face-tokenizer" } & - element model { attribute language { xsd:string }? & ModelReference }+ & - element special-tokens { xsd:boolean }? & - element max-length { xsd:integer }? & - element truncation { xsd:boolean }? & - element padding { xsd:boolean }? + element model { attribute language { xsd:string }? & ModelReference }+ BertBaseEmbedder = attribute type { "bert-embedder" } & diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 6823ef900ae..b70a9d5f5f1 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -21,9 +21,6 @@ <component id="hf-tokenizer" type="hugging-face-tokenizer"> <model language="no" model-id="multilingual-e5-base-vocab" url="https://my/url/tokenizer.json"/> - <special-tokens>true</special-tokens> - <max-length>768</max-length> - <truncation>true</truncation> </component> <component id="bert-embedder" type="bert-embedder"> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 2a82daef9e3..dc62bfdbbef 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -77,9 +77,10 @@ public class EmbedderTestCase { var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster); assertEquals("my_input_ids", embedderCfg.transformerInputIds()); assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); - assertEquals(768, tokenizerCfg.maxLength()); + assertEquals(-1, tokenizerCfg.maxLength()); } @Test @@ -89,9 +90,10 @@ public class EmbedderTestCase { var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster); assertEquals("my_input_ids", embedderCfg.transformerInputIds()); assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); - assertEquals(768, tokenizerCfg.maxLength()); + assertEquals(-1, tokenizerCfg.maxLength()); } diff --git a/configdefinitions/src/vespa/hugging-face-tokenizer.def b/configdefinitions/src/vespa/hugging-face-tokenizer.def index bc0d5300de5..896a7b03234 100644 --- a/configdefinitions/src/vespa/hugging-face-tokenizer.def +++ b/configdefinitions/src/vespa/hugging-face-tokenizer.def @@ -8,7 +8,14 @@ model[].language string # The path to the model relative to the application package root model[].path model +# Include special tokens in output addSpecialTokens bool default=true -maxLength int default=512 -truncation bool default=true -padding bool default=false + +# Used for truncation/padding. Use -1 for model default. +maxLength int default=-1 + +# Truncation strategy. Use NOTSET for model default. +truncation enum { ON, OFF, NOTSET } default=NOTSET + +# Padding strategy. Use NOTSET for model default. +padding enum { ON, OFF, NOTSET } default=NOTSET diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index 1f1757e6ade..17360efd0af 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -2,10 +2,14 @@ package com.yahoo.language.huggingface; +import ai.djl.huggingface.tokenizers.jni.LibUtils; +import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary; import com.yahoo.api.annotations.Beta; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; import com.yahoo.language.Language; +import com.yahoo.language.huggingface.ModelInfo.PaddingStrategy; +import com.yahoo.language.huggingface.ModelInfo.TruncationStrategy; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.language.process.Embedder; import com.yahoo.language.process.Segmenter; @@ -13,12 +17,14 @@ import com.yahoo.language.tools.Embed; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; import java.util.Collection; import java.util.EnumMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static com.yahoo.yolean.Exceptions.uncheck; @@ -30,29 +36,39 @@ import static com.yahoo.yolean.Exceptions.uncheck; @Beta public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable { - private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class); + private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models; @Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); } + static { + // Stop HuggingFace Tokenizer from reporting usage statistics back to mothership + // See ai.djl.util.Ec2Utils.callHome() + System.setProperty("OPT_OUT_TRACKING", "true"); + } + private HuggingFaceTokenizer(Builder b) { - var original = Thread.currentThread().getContextClassLoader(); - Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); - try { + this.models = withContextClassloader(() -> { + var models = new EnumMap<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer>(Language.class); b.models.forEach((language, path) -> { models.put(language, uncheck(() -> { var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) - .optTruncation(b.truncation != null ? b.truncation : true) - .optMaxLength(b.maxLength != null ? b.maxLength : 512); - if (b.padding != null && b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false); + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); + if (b.maxLength != null) { + hfb.optMaxLength(b.maxLength); + // Override modelMaxLength to workaround HF tokenizer limiting maxLength to 512 + hfb.configure(Map.of("modelMaxLength", b.maxLength > 0 ? b.maxLength : Integer.MAX_VALUE)); + } + if (b.padding != null) { + if (b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false); + } + if (b.truncation != null) hfb.optTruncation(b.truncation); return hfb.build(); })); }); - } finally { - Thread.currentThread().setContextClassLoader(original); - } + return models; + }); } @Override @@ -84,6 +100,24 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, @Override public void close() { models.forEach((__, model) -> model.close()); } @Override public void deconstruct() { close(); } + public static ModelInfo getModelInfo(Path path) { + return withContextClassloader(() -> { + // Hackish solution to read padding/truncation configuration through JNI wrapper directly + LibUtils.checkStatus(); + var handle = TokenizersLibrary.LIB.createTokenizerFromString(uncheck(() -> Files.readString(path))); + try { + return new ModelInfo( + TruncationStrategy.fromString(TokenizersLibrary.LIB.getTruncationStrategy(handle)), + PaddingStrategy.fromString(TokenizersLibrary.LIB.getPaddingStrategy(handle)), + TokenizersLibrary.LIB.getMaxLength(handle), + TokenizersLibrary.LIB.getStride(handle), + TokenizersLibrary.LIB.getPadToMultipleOf(handle)); + } finally { + TokenizersLibrary.LIB.deleteTokenizer(handle); + } + }); + } + private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); @@ -91,6 +125,16 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, throw new IllegalArgumentException("No model for language " + language); } + private static <R> R withContextClassloader(Supplier<R> r) { + var original = Thread.currentThread().getContextClassLoader(); + Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); + try { + return r.get(); + } finally { + Thread.currentThread().setContextClassLoader(original); + } + } + private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); } public static final class Builder { @@ -106,8 +150,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, addModel(Language.fromLanguageTag(model.language()), model.path()); addSpecialTokens(cfg.addSpecialTokens()); if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); - if (cfg.truncation()) setTruncation(true); - if (cfg.padding()) setPadding(true); + switch (cfg.truncation()) { + case ON -> setTruncation(true); + case OFF -> setTruncation(false); + } + switch (cfg.padding()) { + case ON -> setPadding(true); + case OFF -> setPadding(false); + } } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java new file mode 100644 index 00000000000..4b30b1f0435 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java @@ -0,0 +1,41 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.huggingface; + +import java.util.Arrays; + +/** + * @author bjorncs + */ +public record ModelInfo( + TruncationStrategy truncation, PaddingStrategy padding, int maxLength, int stride, int padToMultipleOf) { + + public enum TruncationStrategy { + LONGEST_FIRST, + ONLY_FIRST, + ONLY_SECOND, + DO_NOT_TRUNCATE; + + public static TruncationStrategy fromString(String v) { + if ("true".equals(v)) return LONGEST_FIRST; + else if ("false".equals(v)) return DO_NOT_TRUNCATE; + return Arrays.stream(values()) + .filter(s -> s.name().equalsIgnoreCase(v)) + .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v))); + } + } + + public enum PaddingStrategy { + LONGEST, + MAX_LENGTH, + DO_NOT_PAD; + + public static PaddingStrategy fromString(String v) { + if ("true".equals(v)) return LONGEST; + else if ("false".equals(v)) return DO_NOT_PAD; + return Arrays.stream(values()) + .filter(s -> s.name().equalsIgnoreCase(v)) + .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v))); + } + } +} diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java index bf2e0f82829..f727252cccb 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -99,7 +99,7 @@ class HuggingFaceTokenizerTest { } @Test - void disables_padding_by_default() throws IOException { + void pads_to_max_length() throws IOException { var builder = new HuggingFaceTokenizer.Builder() .setTruncation(true) .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased")) @@ -114,6 +114,13 @@ class HuggingFaceTokenizerTest { } } + @Test + void provides_model_info() throws IOException { + var expected = new ModelInfo(ModelInfo.TruncationStrategy.LONGEST_FIRST, ModelInfo.PaddingStrategy.LONGEST, 128, 0, 0); + var actual = HuggingFaceTokenizer.getModelInfo(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2")); + assertEquals(expected, actual); + } + private static void assertMaxLengthRespected(int maxLength, Encoding encoding) { assertEquals(maxLength, encoding.ids().size()); assertEquals(maxLength, encoding.tokens().size()); diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 17b63fb1c7d..b035541bb0f 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -18,10 +18,15 @@ import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; import java.util.Map; +import java.util.logging.Logger; + +import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST; @Beta public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { + private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName()); + private final String inputIdsName; private final String attentionMaskName; private final String tokenTypeIdsName; @@ -38,13 +43,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { tokenTypeIdsName = config.transformerTokenTypeIds(); outputName = config.transformerOutput(); normalize = config.normalize(); - tokenizer = new HuggingFaceTokenizer.Builder() + var tokenizerPath = Paths.get(config.tokenizerPath().toString()); + var builder = new HuggingFaceTokenizer.Builder() .addSpecialTokens(true) - .addDefaultModel(Paths.get(config.tokenizerPath().toString())) - .setTruncation(true) - .setPadding(false) - .setMaxLength(config.transformerMaxTokens()) - .build(); + .addDefaultModel(tokenizerPath) + .setPadding(false); + var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath); + log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, info)); + if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) { + // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration + int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens() + ? info.maxLength() + : config.transformerMaxTokens(); + builder.setTruncation(true).setMaxLength(maxLength); + } + this.tokenizer = builder.build(); poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) |