From 4f722322cc9f8df5146ffb27d74239b3b4f2d634 Mon Sep 17 00:00:00 2001 From: Bjørn Christian Seime Date: Mon, 12 Jun 2023 16:41:37 +0200 Subject: Prefer truncation configuration from tokenizer model Only override truncation if not specified or max length exceeds max tokens accepted by model. Use JNI wrapper directly to determine existing truncation configuration (JSON format is not really documented). Simply configuration for pure tokenizer embedder. Disable DJL usage telemetry. --- .../container/component/HuggingFaceTokenizer.java | 16 +---- config-model/src/main/resources/schema/common.rnc | 8 +-- .../src/test/cfg/application/embed/services.xml | 3 - .../model/container/xml/EmbedderTestCase.java | 6 +- .../src/vespa/hugging-face-tokenizer.def | 13 +++- .../language/huggingface/HuggingFaceTokenizer.java | 76 ++++++++++++++++++---- .../com/yahoo/language/huggingface/ModelInfo.java | 41 ++++++++++++ .../huggingface/HuggingFaceTokenizerTest.java | 9 ++- .../embedding/huggingface/HuggingFaceEmbedder.java | 25 +++++-- 9 files changed, 150 insertions(+), 47 deletions(-) create mode 100644 linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java index e0572f8391e..0bf5491e872 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java @@ -4,6 +4,8 @@ package com.yahoo.vespa.model.container.component; import com.yahoo.config.ModelReference; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; +import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Padding; +import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Truncation; import com.yahoo.text.XML; import com.yahoo.vespa.model.container.xml.ModelIdResolver; import org.w3c.dom.Element; @@ -11,7 +13,6 @@ import org.w3c.dom.Element; import java.util.Map; import java.util.TreeMap; -import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME; /** @@ -20,10 +21,6 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTI public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceTokenizerConfig.Producer { private final Map langToModel = new TreeMap<>(); - private final Boolean specialTokens; - private final Integer maxLength; - private final Boolean truncation; - private final Boolean padding; public HuggingFaceTokenizer(Element xml, DeployState state) { super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml); @@ -31,10 +28,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown"; langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state)); } - specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null); - maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null); - truncation = getOptionalChildValue(xml, "truncation").map(Boolean::parseBoolean).orElse(null); - padding = getOptionalChildValue(xml, "padding").map(Boolean::parseBoolean).orElse(null); } @Override @@ -42,9 +35,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT langToModel.forEach((lang, vocab) -> { builder.model.add(new HuggingFaceTokenizerConfig.Model.Builder().language(lang).path(vocab)); }); - if (specialTokens != null) builder.addSpecialTokens(specialTokens); - if (maxLength != null) builder.maxLength(maxLength); - if (truncation != null) builder.truncation(truncation); - if (padding != null) builder.padding(padding); + builder.truncation(Truncation.Enum.OFF).padding(Padding.Enum.OFF).addSpecialTokens(false); } } diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index e130bed0297..ba7e2b6674e 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -88,7 +88,7 @@ HuggingFaceEmbedder = attribute type { "hugging-face-embedder" } & element transformer-model { ModelReference } & element tokenizer-model { ModelReference }? & - element max-tokens { xsd:nonNegativeInteger }? & + element max-tokens { xsd:positiveInteger }? & element transformer-input-ids { xsd:string }? & element transformer-attention-mask { xsd:string }? & element transformer-token-type-ids { xsd:string }? & @@ -99,11 +99,7 @@ HuggingFaceEmbedder = HuggingFaceTokenizer = attribute type { "hugging-face-tokenizer" } & - element model { attribute language { xsd:string }? & ModelReference }+ & - element special-tokens { xsd:boolean }? & - element max-length { xsd:integer }? & - element truncation { xsd:boolean }? & - element padding { xsd:boolean }? + element model { attribute language { xsd:string }? & ModelReference }+ BertBaseEmbedder = attribute type { "bert-embedder" } & diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 6823ef900ae..b70a9d5f5f1 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -21,9 +21,6 @@ - true - 768 - true diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 2a82daef9e3..dc62bfdbbef 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -77,9 +77,10 @@ public class EmbedderTestCase { var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster); assertEquals("my_input_ids", embedderCfg.transformerInputIds()); assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); - assertEquals(768, tokenizerCfg.maxLength()); + assertEquals(-1, tokenizerCfg.maxLength()); } @Test @@ -89,9 +90,10 @@ public class EmbedderTestCase { var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster); assertEquals("my_input_ids", embedderCfg.transformerInputIds()); assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); - assertEquals(768, tokenizerCfg.maxLength()); + assertEquals(-1, tokenizerCfg.maxLength()); } diff --git a/configdefinitions/src/vespa/hugging-face-tokenizer.def b/configdefinitions/src/vespa/hugging-face-tokenizer.def index bc0d5300de5..896a7b03234 100644 --- a/configdefinitions/src/vespa/hugging-face-tokenizer.def +++ b/configdefinitions/src/vespa/hugging-face-tokenizer.def @@ -8,7 +8,14 @@ model[].language string # The path to the model relative to the application package root model[].path model +# Include special tokens in output addSpecialTokens bool default=true -maxLength int default=512 -truncation bool default=true -padding bool default=false + +# Used for truncation/padding. Use -1 for model default. +maxLength int default=-1 + +# Truncation strategy. Use NOTSET for model default. +truncation enum { ON, OFF, NOTSET } default=NOTSET + +# Padding strategy. Use NOTSET for model default. +padding enum { ON, OFF, NOTSET } default=NOTSET diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index 1f1757e6ade..17360efd0af 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -2,10 +2,14 @@ package com.yahoo.language.huggingface; +import ai.djl.huggingface.tokenizers.jni.LibUtils; +import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary; import com.yahoo.api.annotations.Beta; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; import com.yahoo.language.Language; +import com.yahoo.language.huggingface.ModelInfo.PaddingStrategy; +import com.yahoo.language.huggingface.ModelInfo.TruncationStrategy; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.language.process.Embedder; import com.yahoo.language.process.Segmenter; @@ -13,12 +17,14 @@ import com.yahoo.language.tools.Embed; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.nio.file.Files; import java.nio.file.Path; import java.util.Arrays; import java.util.Collection; import java.util.EnumMap; import java.util.List; import java.util.Map; +import java.util.function.Supplier; import static com.yahoo.yolean.Exceptions.uncheck; @@ -30,29 +36,39 @@ import static com.yahoo.yolean.Exceptions.uncheck; @Beta public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable { - private final Map models = new EnumMap<>(Language.class); + private final Map models; @Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); } + static { + // Stop HuggingFace Tokenizer from reporting usage statistics back to mothership + // See ai.djl.util.Ec2Utils.callHome() + System.setProperty("OPT_OUT_TRACKING", "true"); + } + private HuggingFaceTokenizer(Builder b) { - var original = Thread.currentThread().getContextClassLoader(); - Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); - try { + this.models = withContextClassloader(() -> { + var models = new EnumMap(Language.class); b.models.forEach((language, path) -> { models.put(language, uncheck(() -> { var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) - .optTruncation(b.truncation != null ? b.truncation : true) - .optMaxLength(b.maxLength != null ? b.maxLength : 512); - if (b.padding != null && b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false); + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); + if (b.maxLength != null) { + hfb.optMaxLength(b.maxLength); + // Override modelMaxLength to workaround HF tokenizer limiting maxLength to 512 + hfb.configure(Map.of("modelMaxLength", b.maxLength > 0 ? b.maxLength : Integer.MAX_VALUE)); + } + if (b.padding != null) { + if (b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false); + } + if (b.truncation != null) hfb.optTruncation(b.truncation); return hfb.build(); })); }); - } finally { - Thread.currentThread().setContextClassLoader(original); - } + return models; + }); } @Override @@ -84,6 +100,24 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, @Override public void close() { models.forEach((__, model) -> model.close()); } @Override public void deconstruct() { close(); } + public static ModelInfo getModelInfo(Path path) { + return withContextClassloader(() -> { + // Hackish solution to read padding/truncation configuration through JNI wrapper directly + LibUtils.checkStatus(); + var handle = TokenizersLibrary.LIB.createTokenizerFromString(uncheck(() -> Files.readString(path))); + try { + return new ModelInfo( + TruncationStrategy.fromString(TokenizersLibrary.LIB.getTruncationStrategy(handle)), + PaddingStrategy.fromString(TokenizersLibrary.LIB.getPaddingStrategy(handle)), + TokenizersLibrary.LIB.getMaxLength(handle), + TokenizersLibrary.LIB.getStride(handle), + TokenizersLibrary.LIB.getPadToMultipleOf(handle)); + } finally { + TokenizersLibrary.LIB.deleteTokenizer(handle); + } + }); + } + private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) { // Disregard language if there is default model if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN); @@ -91,6 +125,16 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, throw new IllegalArgumentException("No model for language " + language); } + private static R withContextClassloader(Supplier r) { + var original = Thread.currentThread().getContextClassLoader(); + Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader()); + try { + return r.get(); + } finally { + Thread.currentThread().setContextClassLoader(original); + } + } + private static long[] toArray(Collection c) { return c.stream().mapToLong(Number::longValue).toArray(); } public static final class Builder { @@ -106,8 +150,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, addModel(Language.fromLanguageTag(model.language()), model.path()); addSpecialTokens(cfg.addSpecialTokens()); if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); - if (cfg.truncation()) setTruncation(true); - if (cfg.padding()) setPadding(true); + switch (cfg.truncation()) { + case ON -> setTruncation(true); + case OFF -> setTruncation(false); + } + switch (cfg.padding()) { + case ON -> setPadding(true); + case OFF -> setPadding(false); + } } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java new file mode 100644 index 00000000000..4b30b1f0435 --- /dev/null +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java @@ -0,0 +1,41 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.language.huggingface; + +import java.util.Arrays; + +/** + * @author bjorncs + */ +public record ModelInfo( + TruncationStrategy truncation, PaddingStrategy padding, int maxLength, int stride, int padToMultipleOf) { + + public enum TruncationStrategy { + LONGEST_FIRST, + ONLY_FIRST, + ONLY_SECOND, + DO_NOT_TRUNCATE; + + public static TruncationStrategy fromString(String v) { + if ("true".equals(v)) return LONGEST_FIRST; + else if ("false".equals(v)) return DO_NOT_TRUNCATE; + return Arrays.stream(values()) + .filter(s -> s.name().equalsIgnoreCase(v)) + .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v))); + } + } + + public enum PaddingStrategy { + LONGEST, + MAX_LENGTH, + DO_NOT_PAD; + + public static PaddingStrategy fromString(String v) { + if ("true".equals(v)) return LONGEST; + else if ("false".equals(v)) return DO_NOT_PAD; + return Arrays.stream(values()) + .filter(s -> s.name().equalsIgnoreCase(v)) + .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v))); + } + } +} diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java index bf2e0f82829..f727252cccb 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -99,7 +99,7 @@ class HuggingFaceTokenizerTest { } @Test - void disables_padding_by_default() throws IOException { + void pads_to_max_length() throws IOException { var builder = new HuggingFaceTokenizer.Builder() .setTruncation(true) .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased")) @@ -114,6 +114,13 @@ class HuggingFaceTokenizerTest { } } + @Test + void provides_model_info() throws IOException { + var expected = new ModelInfo(ModelInfo.TruncationStrategy.LONGEST_FIRST, ModelInfo.PaddingStrategy.LONGEST, 128, 0, 0); + var actual = HuggingFaceTokenizer.getModelInfo(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2")); + assertEquals(expected, actual); + } + private static void assertMaxLengthRespected(int maxLength, Encoding encoding) { assertEquals(maxLength, encoding.ids().size()); assertEquals(maxLength, encoding.tokens().size()); diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 17b63fb1c7d..b035541bb0f 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -18,10 +18,15 @@ import com.yahoo.tensor.TensorType; import java.nio.file.Paths; import java.util.List; import java.util.Map; +import java.util.logging.Logger; + +import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST; @Beta public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { + private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName()); + private final String inputIdsName; private final String attentionMaskName; private final String tokenTypeIdsName; @@ -38,13 +43,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { tokenTypeIdsName = config.transformerTokenTypeIds(); outputName = config.transformerOutput(); normalize = config.normalize(); - tokenizer = new HuggingFaceTokenizer.Builder() + var tokenizerPath = Paths.get(config.tokenizerPath().toString()); + var builder = new HuggingFaceTokenizer.Builder() .addSpecialTokens(true) - .addDefaultModel(Paths.get(config.tokenizerPath().toString())) - .setTruncation(true) - .setPadding(false) - .setMaxLength(config.transformerMaxTokens()) - .build(); + .addDefaultModel(tokenizerPath) + .setPadding(false); + var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath); + log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, info)); + if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) { + // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration + int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens() + ? info.maxLength() + : config.transformerMaxTokens(); + builder.setTruncation(true).setMaxLength(maxLength); + } + this.tokenizer = builder.build(); poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); var onnxOpts = new OnnxEvaluatorOptions(); if (config.transformerGpuDevice() >= 0) -- cgit v1.2.3