about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-06-12 17:25:55 +0200
committer GitHub <noreply@github.com> 2023-06-12 17:25:55 +0200
commit fb5d1bf9f451fbeb4a40d7f73fa856ef81bd77ed (patch)
tree 1f8ab291370e84407a827f9a80bcf943f522ca29
parent 0647b650c3334ff86d50431e78549e25dc46caf9 (diff)
parent 4f722322cc9f8df5146ffb27d74239b3b4f2d634 (diff)
Merge pull request #27387 from vespa-engine/bjorncs/hf v8.176.13
Prefer truncation configuration from tokenizer model
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java 16
-rw-r--r-- config-model/src/main/resources/schema/common.rnc 8
-rw-r--r-- config-model/src/test/cfg/application/embed/services.xml 3
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java 6
-rw-r--r-- configdefinitions/src/vespa/hugging-face-tokenizer.def 13
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java 76
-rw-r--r-- linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java 41
-rw-r--r-- linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java 9
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java 25
9 files changed, 150 insertions, 47 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
index e0572f8391e..0bf5491e872 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -4,6 +4,8 @@ package com.yahoo.vespa.model.container.component;
import com.yahoo.config.ModelReference;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Padding;
+import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig.Truncation;
import com.yahoo.text.XML;
import com.yahoo.vespa.model.container.xml.ModelIdResolver;
import org.w3c.dom.Element;
@@ -11,7 +13,6 @@ import org.w3c.dom.Element;
import java.util.Map;
import java.util.TreeMap;
-import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChildValue;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME;
/**
@@ -20,10 +21,6 @@ import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTI
public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceTokenizerConfig.Producer {
private final Map<String, ModelReference> langToModel = new TreeMap<>();
- private final Boolean specialTokens;
- private final Integer maxLength;
- private final Boolean truncation;
- private final Boolean padding;
public HuggingFaceTokenizer(Element xml, DeployState state) {
super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
@@ -31,10 +28,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
var lang = element.hasAttribute("language") ? element.getAttribute("language") : "unknown";
langToModel.put(lang, ModelIdResolver.resolveToModelReference(element, state));
}
- specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null);
- maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null);
- truncation = getOptionalChildValue(xml, "truncation").map(Boolean::parseBoolean).orElse(null);
- padding = getOptionalChildValue(xml, "padding").map(Boolean::parseBoolean).orElse(null);
}
@Override
@@ -42,9 +35,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
langToModel.forEach((lang, vocab) -> {
builder.model.add(new HuggingFaceTokenizerConfig.Model.Builder().language(lang).path(vocab));
});
- if (specialTokens != null) builder.addSpecialTokens(specialTokens);
- if (maxLength != null) builder.maxLength(maxLength);
- if (truncation != null) builder.truncation(truncation);
- if (padding != null) builder.padding(padding);
+ builder.truncation(Truncation.Enum.OFF).padding(Padding.Enum.OFF).addSpecialTokens(false);
}
}
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index e130bed0297..ba7e2b6674e 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -88,7 +88,7 @@ HuggingFaceEmbedder =
attribute type { "hugging-face-embedder" } &
element transformer-model { ModelReference } &
element tokenizer-model { ModelReference }? &
- element max-tokens { xsd:nonNegativeInteger }? &
+ element max-tokens { xsd:positiveInteger }? &
element transformer-input-ids { xsd:string }? &
element transformer-attention-mask { xsd:string }? &
element transformer-token-type-ids { xsd:string }? &
@@ -99,11 +99,7 @@ HuggingFaceEmbedder =
HuggingFaceTokenizer =
attribute type { "hugging-face-tokenizer" } &
- element model { attribute language { xsd:string }? & ModelReference }+ &
- element special-tokens { xsd:boolean }? &
- element max-length { xsd:integer }? &
- element truncation { xsd:boolean }? &
- element padding { xsd:boolean }?
+ element model { attribute language { xsd:string }? & ModelReference }+
BertBaseEmbedder =
attribute type { "bert-embedder" } &
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 6823ef900ae..b70a9d5f5f1 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -21,9 +21,6 @@
<component id="hf-tokenizer" type="hugging-face-tokenizer">
<model language="no" model-id="multilingual-e5-base-vocab" url="https://my/url/tokenizer.json"/>
- <special-tokens>true</special-tokens>
- <max-length>768</max-length>
- <truncation>true</truncation>
</component>
<component id="bert-embedder" type="bert-embedder">
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 2a82daef9e3..dc62bfdbbef 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -77,9 +77,10 @@ public class EmbedderTestCase {
var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster);
assertEquals("my_input_ids", embedderCfg.transformerInputIds());
assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
- assertEquals(768, tokenizerCfg.maxLength());
+ assertEquals(-1, tokenizerCfg.maxLength());
}
@Test
@@ -89,9 +90,10 @@ public class EmbedderTestCase {
var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster);
assertEquals("my_input_ids", embedderCfg.transformerInputIds());
assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value());
+ assertEquals(1024, embedderCfg.transformerMaxTokens());
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
- assertEquals(768, tokenizerCfg.maxLength());
+ assertEquals(-1, tokenizerCfg.maxLength());
}
diff --git a/configdefinitions/src/vespa/hugging-face-tokenizer.def b/configdefinitions/src/vespa/hugging-face-tokenizer.def
index bc0d5300de5..896a7b03234 100644
--- a/configdefinitions/src/vespa/hugging-face-tokenizer.def
+++ b/configdefinitions/src/vespa/hugging-face-tokenizer.def
@@ -8,7 +8,14 @@ model[].language string
# The path to the model relative to the application package root
model[].path model
+# Include special tokens in output
addSpecialTokens bool default=true
-maxLength int default=512
-truncation bool default=true
-padding bool default=false
+
+# Used for truncation/padding. Use -1 for model default.
+maxLength int default=-1
+
+# Truncation strategy. Use NOTSET for model default.
+truncation enum { ON, OFF, NOTSET } default=NOTSET
+
+# Padding strategy. Use NOTSET for model default.
+padding enum { ON, OFF, NOTSET } default=NOTSET
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index 1f1757e6ade..17360efd0af 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -2,10 +2,14 @@
package com.yahoo.language.huggingface;
+import ai.djl.huggingface.tokenizers.jni.LibUtils;
+import ai.djl.huggingface.tokenizers.jni.TokenizersLibrary;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.Language;
+import com.yahoo.language.huggingface.ModelInfo.PaddingStrategy;
+import com.yahoo.language.huggingface.ModelInfo.TruncationStrategy;
import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.process.Segmenter;
@@ -13,12 +17,14 @@ import com.yahoo.language.tools.Embed;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collection;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
+import java.util.function.Supplier;
import static com.yahoo.yolean.Exceptions.uncheck;
@@ -30,29 +36,39 @@ import static com.yahoo.yolean.Exceptions.uncheck;
@Beta
public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, Segmenter, AutoCloseable {
- private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models = new EnumMap<>(Language.class);
+ private final Map<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer> models;
@Inject public HuggingFaceTokenizer(HuggingFaceTokenizerConfig cfg) { this(new Builder(cfg)); }
+ static {
+ // Stop HuggingFace Tokenizer from reporting usage statistics back to mothership
+ // See ai.djl.util.Ec2Utils.callHome()
+ System.setProperty("OPT_OUT_TRACKING", "true");
+ }
+
private HuggingFaceTokenizer(Builder b) {
- var original = Thread.currentThread().getContextClassLoader();
- Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
- try {
+ this.models = withContextClassloader(() -> {
+ var models = new EnumMap<Language, ai.djl.huggingface.tokenizers.HuggingFaceTokenizer>(Language.class);
b.models.forEach((language, path) -> {
models.put(language,
uncheck(() -> {
var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
- .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
- .optTruncation(b.truncation != null ? b.truncation : true)
- .optMaxLength(b.maxLength != null ? b.maxLength : 512);
- if (b.padding != null && b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false);
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
+ if (b.maxLength != null) {
+ hfb.optMaxLength(b.maxLength);
+ // Override modelMaxLength to workaround HF tokenizer limiting maxLength to 512
+ hfb.configure(Map.of("modelMaxLength", b.maxLength > 0 ? b.maxLength : Integer.MAX_VALUE));
+ }
+ if (b.padding != null) {
+ if (b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false);
+ }
+ if (b.truncation != null) hfb.optTruncation(b.truncation);
return hfb.build();
}));
});
- } finally {
- Thread.currentThread().setContextClassLoader(original);
- }
+ return models;
+ });
}
@Override
@@ -84,6 +100,24 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
@Override public void close() { models.forEach((__, model) -> model.close()); }
@Override public void deconstruct() { close(); }
+ public static ModelInfo getModelInfo(Path path) {
+ return withContextClassloader(() -> {
+ // Hackish solution to read padding/truncation configuration through JNI wrapper directly
+ LibUtils.checkStatus();
+ var handle = TokenizersLibrary.LIB.createTokenizerFromString(uncheck(() -> Files.readString(path)));
+ try {
+ return new ModelInfo(
+ TruncationStrategy.fromString(TokenizersLibrary.LIB.getTruncationStrategy(handle)),
+ PaddingStrategy.fromString(TokenizersLibrary.LIB.getPaddingStrategy(handle)),
+ TokenizersLibrary.LIB.getMaxLength(handle),
+ TokenizersLibrary.LIB.getStride(handle),
+ TokenizersLibrary.LIB.getPadToMultipleOf(handle));
+ } finally {
+ TokenizersLibrary.LIB.deleteTokenizer(handle);
+ }
+ });
+ }
+
private ai.djl.huggingface.tokenizers.HuggingFaceTokenizer resolve(Language language) {
// Disregard language if there is default model
if (models.size() == 1 && models.containsKey(Language.UNKNOWN)) return models.get(Language.UNKNOWN);
@@ -91,6 +125,16 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
throw new IllegalArgumentException("No model for language " + language);
}
+ private static <R> R withContextClassloader(Supplier<R> r) {
+ var original = Thread.currentThread().getContextClassLoader();
+ Thread.currentThread().setContextClassLoader(HuggingFaceTokenizer.class.getClassLoader());
+ try {
+ return r.get();
+ } finally {
+ Thread.currentThread().setContextClassLoader(original);
+ }
+ }
+
private static long[] toArray(Collection<? extends Number> c) { return c.stream().mapToLong(Number::longValue).toArray(); }
public static final class Builder {
@@ -106,8 +150,14 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
addModel(Language.fromLanguageTag(model.language()), model.path());
addSpecialTokens(cfg.addSpecialTokens());
if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
- if (cfg.truncation()) setTruncation(true);
- if (cfg.padding()) setPadding(true);
+ switch (cfg.truncation()) {
+ case ON -> setTruncation(true);
+ case OFF -> setTruncation(false);
+ }
+ switch (cfg.padding()) {
+ case ON -> setPadding(true);
+ case OFF -> setPadding(false);
+ }
}
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java
new file mode 100644
index 00000000000..4b30b1f0435
--- /dev/null
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/ModelInfo.java
@@ -0,0 +1,41 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.language.huggingface;
+
+import java.util.Arrays;
+
+/**
+ * @author bjorncs
+ */
+public record ModelInfo(
+ TruncationStrategy truncation, PaddingStrategy padding, int maxLength, int stride, int padToMultipleOf) {
+
+ public enum TruncationStrategy {
+ LONGEST_FIRST,
+ ONLY_FIRST,
+ ONLY_SECOND,
+ DO_NOT_TRUNCATE;
+
+ public static TruncationStrategy fromString(String v) {
+ if ("true".equals(v)) return LONGEST_FIRST;
+ else if ("false".equals(v)) return DO_NOT_TRUNCATE;
+ return Arrays.stream(values())
+ .filter(s -> s.name().equalsIgnoreCase(v))
+ .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v)));
+ }
+ }
+
+ public enum PaddingStrategy {
+ LONGEST,
+ MAX_LENGTH,
+ DO_NOT_PAD;
+
+ public static PaddingStrategy fromString(String v) {
+ if ("true".equals(v)) return LONGEST;
+ else if ("false".equals(v)) return DO_NOT_PAD;
+ return Arrays.stream(values())
+ .filter(s -> s.name().equalsIgnoreCase(v))
+ .findAny().orElseThrow(() -> new IllegalArgumentException("Invalid strategy '%s'".formatted(v)));
+ }
+ }
+}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
index bf2e0f82829..f727252cccb 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
@@ -99,7 +99,7 @@ class HuggingFaceTokenizerTest {
}
@Test
- void disables_padding_by_default() throws IOException {
+ void pads_to_max_length() throws IOException {
var builder = new HuggingFaceTokenizer.Builder()
.setTruncation(true)
.addDefaultModel(decompressModelFile(tmp, "bert-base-uncased"))
@@ -114,6 +114,13 @@ class HuggingFaceTokenizerTest {
}
}
+ @Test
+ void provides_model_info() throws IOException {
+ var expected = new ModelInfo(ModelInfo.TruncationStrategy.LONGEST_FIRST, ModelInfo.PaddingStrategy.LONGEST, 128, 0, 0);
+ var actual = HuggingFaceTokenizer.getModelInfo(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2"));
+ assertEquals(expected, actual);
+ }
+
private static void assertMaxLengthRespected(int maxLength, Encoding encoding) {
assertEquals(maxLength, encoding.ids().size());
assertEquals(maxLength, encoding.tokens().size());
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 17b63fb1c7d..b035541bb0f 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -18,10 +18,15 @@ import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
+import java.util.logging.Logger;
+
+import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST;
@Beta
public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
+ private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName());
+
private final String inputIdsName;
private final String attentionMaskName;
private final String tokenTypeIdsName;
@@ -38,13 +43,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
normalize = config.normalize();
- tokenizer = new HuggingFaceTokenizer.Builder()
+ var tokenizerPath = Paths.get(config.tokenizerPath().toString());
+ var builder = new HuggingFaceTokenizer.Builder()
.addSpecialTokens(true)
- .addDefaultModel(Paths.get(config.tokenizerPath().toString()))
- .setTruncation(true)
- .setPadding(false)
- .setMaxLength(config.transformerMaxTokens())
- .build();
+ .addDefaultModel(tokenizerPath)
+ .setPadding(false);
+ var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath);
+ log.fine(() -> "'%s' has info '%s'".formatted(tokenizerPath, info));
+ if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) {
+ // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration
+ int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens()
+ ? info.maxLength()
+ : config.transformerMaxTokens();
+ builder.setTruncation(true).setMaxLength(maxLength);
+ }
+ this.tokenizer = builder.build();
poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
var onnxOpts = new OnnxEvaluatorOptions();
if (config.transformerGpuDevice() >= 0)