diff options
6 files changed, 48 insertions, 28 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java index 966dbe8260a..e0572f8391e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java @@ -23,6 +23,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT private final Boolean specialTokens; private final Integer maxLength; private final Boolean truncation; + private final Boolean padding; public HuggingFaceTokenizer(Element xml, DeployState state) { super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml); @@ -33,6 +34,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null); maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null); truncation = getOptionalChildValue(xml, "truncation").map(Boolean::parseBoolean).orElse(null); + padding = getOptionalChildValue(xml, "padding").map(Boolean::parseBoolean).orElse(null); } @Override @@ -43,5 +45,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT if (specialTokens != null) builder.addSpecialTokens(specialTokens); if (maxLength != null) builder.maxLength(maxLength); if (truncation != null) builder.truncation(truncation); + if (padding != null) builder.padding(padding); } } diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index 061e54740f1..e130bed0297 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -102,7 +102,8 @@ HuggingFaceTokenizer = element model { attribute language { xsd:string }? & ModelReference }+ & element special-tokens { xsd:boolean }? & element max-length { xsd:integer }? & - element truncation { xsd:boolean }? + element truncation { xsd:boolean }? & + element padding { xsd:boolean }? BertBaseEmbedder = attribute type { "bert-embedder" } & diff --git a/configdefinitions/src/vespa/hugging-face-tokenizer.def b/configdefinitions/src/vespa/hugging-face-tokenizer.def index 18b3631e494..bc0d5300de5 100644 --- a/configdefinitions/src/vespa/hugging-face-tokenizer.def +++ b/configdefinitions/src/vespa/hugging-face-tokenizer.def @@ -9,5 +9,6 @@ model[].language string model[].path model addSpecialTokens bool default=true -maxLength int default=-1 -truncation bool default=false
\ No newline at end of file +maxLength int default=512 +truncation bool default=true +padding bool default=false diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java index 2c66fc18c9b..1f1757e6ade 100644 --- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java +++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java @@ -43,9 +43,10 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, uncheck(() -> { var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder() .optTokenizerPath(path) - .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true); - if (b.maxLength != null) hfb.optMaxLength(b.maxLength); - if (b.truncation != null) hfb.optTruncation(b.truncation); + .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true) + .optTruncation(b.truncation != null ? b.truncation : true) + .optMaxLength(b.maxLength != null ? b.maxLength : 512); + if (b.padding != null && b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false); return hfb.build(); })); }); @@ -97,6 +98,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, private Boolean addSpecialTokens; private Integer maxLength; private Boolean truncation; + private Boolean padding; public Builder() {} public Builder(HuggingFaceTokenizerConfig cfg) { @@ -105,6 +107,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, addSpecialTokens(cfg.addSpecialTokens()); if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength()); if (cfg.truncation()) setTruncation(true); + if (cfg.padding()) setPadding(true); } public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; } @@ -112,6 +115,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder, public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; } public Builder setMaxLength(int length) { maxLength = length; return this; } public Builder setTruncation(boolean enabled) { truncation = enabled; return this; } + public Builder setPadding(boolean enabled) { padding = enabled; return this; } public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); } } diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java index 6197fe214f1..bf2e0f82829 100644 --- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java +++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java @@ -16,6 +16,7 @@ import java.nio.file.StandardOpenOption; import java.util.zip.GZIPInputStream; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; /** @@ -27,7 +28,10 @@ class HuggingFaceTokenizerTest { @Test void bert_tokenizer() throws IOException { - try (var tokenizer = createTokenizer(tmp, "bert-base-uncased")) { + try (var tokenizer = new HuggingFaceTokenizer.Builder() + .addSpecialTokens(false) + .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased")) + .build()) { var tester = new EmbedderTester(tokenizer); tester.assertSegmented("what was the impact of the manhattan project", "what", "was", "the", "impact", "of", "the", "manhattan", "project"); @@ -41,7 +45,10 @@ class HuggingFaceTokenizerTest { @Test void tokenizes_using_paraphrase_multilingual_mpnet_base_v2() throws IOException { - try (var tokenizer = createTokenizer(tmp, "paraphrase-multilingual-mpnet-base-v2")) { + try (var tokenizer = new HuggingFaceTokenizer.Builder() + .addSpecialTokens(false) + .addDefaultModel(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2")) + .build()) { var tester = new EmbedderTester(tokenizer); tester.assertSegmented("h", "▁h"); tester.assertSegmented("he", "▁he"); @@ -82,8 +89,28 @@ class HuggingFaceTokenizerTest { String input = "what was the impact of the manhattan project"; try (var tokenizerWithoutSpecialTokens = builder.addSpecialTokens(false).build(); var tokenizerWithSpecialTokens = builder.addSpecialTokens(true).build()) { - assertMaxLengthRespected(maxLength, tokenizerWithoutSpecialTokens.encode(input)); - assertMaxLengthRespected(maxLength, tokenizerWithSpecialTokens.encode(input)); + var encodingWithoutSpecialTokens = tokenizerWithoutSpecialTokens.encode(input); + assertMaxLengthRespected(maxLength, encodingWithoutSpecialTokens); + assertNotEquals(101, encodingWithoutSpecialTokens.ids().get(0)); + var encodingWithSpecialTokens = tokenizerWithSpecialTokens.encode(input); + assertMaxLengthRespected(maxLength, encodingWithSpecialTokens); + assertEquals(101, encodingWithSpecialTokens.ids().get(0)); + } + } + + @Test + void disables_padding_by_default() throws IOException { + var builder = new HuggingFaceTokenizer.Builder() + .setTruncation(true) + .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased")) + .addSpecialTokens(true).setMaxLength(32); + String input = "what was the impact of the manhattan project"; + try (var tokenizerWithDefaultPadding = builder.build(); + var tokenizerWithPaddingDisabled = builder.setPadding(false).build(); + var tokenizerWithPaddingEnabled = builder.setPadding(true).build()) { + assertMaxLengthRespected(10, tokenizerWithDefaultPadding.encode(input)); + assertMaxLengthRespected(10, tokenizerWithPaddingDisabled.encode(input)); + assertMaxLengthRespected(32, tokenizerWithPaddingEnabled.encode(input)); } } @@ -94,13 +121,6 @@ class HuggingFaceTokenizerTest { assertEquals(maxLength, encoding.typeIds().size()); } - private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException { - return new HuggingFaceTokenizer.Builder() - .addSpecialTokens(false) - .addDefaultModel(decompressModelFile(tmp, model)) - .build(); - } - private static Path decompressModelFile(Path tmp, String model) throws IOException { var source = Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model)); Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", "")); diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index f93b1a3c1f8..17b63fb1c7d 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -42,6 +42,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { .addSpecialTokens(true) .addDefaultModel(Paths.get(config.tokenizerPath().toString())) .setTruncation(true) + .setPadding(false) .setMaxLength(config.transformerMaxTokens()) .build(); poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString()); @@ -102,17 +103,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); - Tensor.Builder builder = Tensor.Builder.of(tensorType); - - // Mean pooling implementation - Tensor summedEmbeddings = tokenEmbeddings.sum("d1"); - Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1"); - Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y); - for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) { - builder.cell(averaged.get(TensorAddress.of(0,i)), i); - } - - Tensor result = builder.build(); + var result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask); return normalize ? normalize(result, tensorType) : result; } |