aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java3
-rw-r--r--config-model/src/main/resources/schema/common.rnc3
-rw-r--r--configdefinitions/src/vespa/hugging-face-tokenizer.def5
-rw-r--r--linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java10
-rw-r--r--linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java42
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java13
6 files changed, 48 insertions, 28 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
index 966dbe8260a..e0572f8391e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/HuggingFaceTokenizer.java
@@ -23,6 +23,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
private final Boolean specialTokens;
private final Integer maxLength;
private final Boolean truncation;
+ private final Boolean padding;
public HuggingFaceTokenizer(Element xml, DeployState state) {
super("com.yahoo.language.huggingface.HuggingFaceTokenizer", LINGUISTICS_BUNDLE_NAME, xml);
@@ -33,6 +34,7 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
specialTokens = getOptionalChildValue(xml, "special-tokens").map(Boolean::parseBoolean).orElse(null);
maxLength = getOptionalChildValue(xml, "max-length").map(Integer::parseInt).orElse(null);
truncation = getOptionalChildValue(xml, "truncation").map(Boolean::parseBoolean).orElse(null);
+ padding = getOptionalChildValue(xml, "padding").map(Boolean::parseBoolean).orElse(null);
}
@Override
@@ -43,5 +45,6 @@ public class HuggingFaceTokenizer extends TypedComponent implements HuggingFaceT
if (specialTokens != null) builder.addSpecialTokens(specialTokens);
if (maxLength != null) builder.maxLength(maxLength);
if (truncation != null) builder.truncation(truncation);
+ if (padding != null) builder.padding(padding);
}
}
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index 061e54740f1..e130bed0297 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -102,7 +102,8 @@ HuggingFaceTokenizer =
element model { attribute language { xsd:string }? & ModelReference }+ &
element special-tokens { xsd:boolean }? &
element max-length { xsd:integer }? &
- element truncation { xsd:boolean }?
+ element truncation { xsd:boolean }? &
+ element padding { xsd:boolean }?
BertBaseEmbedder =
attribute type { "bert-embedder" } &
diff --git a/configdefinitions/src/vespa/hugging-face-tokenizer.def b/configdefinitions/src/vespa/hugging-face-tokenizer.def
index 18b3631e494..bc0d5300de5 100644
--- a/configdefinitions/src/vespa/hugging-face-tokenizer.def
+++ b/configdefinitions/src/vespa/hugging-face-tokenizer.def
@@ -9,5 +9,6 @@ model[].language string
model[].path model
addSpecialTokens bool default=true
-maxLength int default=-1
-truncation bool default=false \ No newline at end of file
+maxLength int default=512
+truncation bool default=true
+padding bool default=false
diff --git a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
index 2c66fc18c9b..1f1757e6ade 100644
--- a/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
+++ b/linguistics-components/src/main/java/com/yahoo/language/huggingface/HuggingFaceTokenizer.java
@@ -43,9 +43,10 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
uncheck(() -> {
var hfb = ai.djl.huggingface.tokenizers.HuggingFaceTokenizer.builder()
.optTokenizerPath(path)
- .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true);
- if (b.maxLength != null) hfb.optMaxLength(b.maxLength);
- if (b.truncation != null) hfb.optTruncation(b.truncation);
+ .optAddSpecialTokens(b.addSpecialTokens != null ? b.addSpecialTokens : true)
+ .optTruncation(b.truncation != null ? b.truncation : true)
+ .optMaxLength(b.maxLength != null ? b.maxLength : 512);
+ if (b.padding != null && b.padding) hfb.optPadToMaxLength(); else hfb.optPadding(false);
return hfb.build();
}));
});
@@ -97,6 +98,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
private Boolean addSpecialTokens;
private Integer maxLength;
private Boolean truncation;
+ private Boolean padding;
public Builder() {}
public Builder(HuggingFaceTokenizerConfig cfg) {
@@ -105,6 +107,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
addSpecialTokens(cfg.addSpecialTokens());
if (cfg.maxLength() != -1) setMaxLength(cfg.maxLength());
if (cfg.truncation()) setTruncation(true);
+ if (cfg.padding()) setPadding(true);
}
public Builder addModel(Language lang, Path path) { models.put(lang, path); return this; }
@@ -112,6 +115,7 @@ public class HuggingFaceTokenizer extends AbstractComponent implements Embedder,
public Builder addSpecialTokens(boolean enabled) { addSpecialTokens = enabled; return this; }
public Builder setMaxLength(int length) { maxLength = length; return this; }
public Builder setTruncation(boolean enabled) { truncation = enabled; return this; }
+ public Builder setPadding(boolean enabled) { padding = enabled; return this; }
public HuggingFaceTokenizer build() { return new HuggingFaceTokenizer(this); }
}
diff --git a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
index 6197fe214f1..bf2e0f82829 100644
--- a/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
+++ b/linguistics-components/src/test/java/com/yahoo/language/huggingface/HuggingFaceTokenizerTest.java
@@ -16,6 +16,7 @@ import java.nio.file.StandardOpenOption;
import java.util.zip.GZIPInputStream;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
/**
@@ -27,7 +28,10 @@ class HuggingFaceTokenizerTest {
@Test
void bert_tokenizer() throws IOException {
- try (var tokenizer = createTokenizer(tmp, "bert-base-uncased")) {
+ try (var tokenizer = new HuggingFaceTokenizer.Builder()
+ .addSpecialTokens(false)
+ .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased"))
+ .build()) {
var tester = new EmbedderTester(tokenizer);
tester.assertSegmented("what was the impact of the manhattan project",
"what", "was", "the", "impact", "of", "the", "manhattan", "project");
@@ -41,7 +45,10 @@ class HuggingFaceTokenizerTest {
@Test
void tokenizes_using_paraphrase_multilingual_mpnet_base_v2() throws IOException {
- try (var tokenizer = createTokenizer(tmp, "paraphrase-multilingual-mpnet-base-v2")) {
+ try (var tokenizer = new HuggingFaceTokenizer.Builder()
+ .addSpecialTokens(false)
+ .addDefaultModel(decompressModelFile(tmp, "paraphrase-multilingual-mpnet-base-v2"))
+ .build()) {
var tester = new EmbedderTester(tokenizer);
tester.assertSegmented("h", "▁h");
tester.assertSegmented("he", "▁he");
@@ -82,8 +89,28 @@ class HuggingFaceTokenizerTest {
String input = "what was the impact of the manhattan project";
try (var tokenizerWithoutSpecialTokens = builder.addSpecialTokens(false).build();
var tokenizerWithSpecialTokens = builder.addSpecialTokens(true).build()) {
- assertMaxLengthRespected(maxLength, tokenizerWithoutSpecialTokens.encode(input));
- assertMaxLengthRespected(maxLength, tokenizerWithSpecialTokens.encode(input));
+ var encodingWithoutSpecialTokens = tokenizerWithoutSpecialTokens.encode(input);
+ assertMaxLengthRespected(maxLength, encodingWithoutSpecialTokens);
+ assertNotEquals(101, encodingWithoutSpecialTokens.ids().get(0));
+ var encodingWithSpecialTokens = tokenizerWithSpecialTokens.encode(input);
+ assertMaxLengthRespected(maxLength, encodingWithSpecialTokens);
+ assertEquals(101, encodingWithSpecialTokens.ids().get(0));
+ }
+ }
+
+ @Test
+ void disables_padding_by_default() throws IOException {
+ var builder = new HuggingFaceTokenizer.Builder()
+ .setTruncation(true)
+ .addDefaultModel(decompressModelFile(tmp, "bert-base-uncased"))
+ .addSpecialTokens(true).setMaxLength(32);
+ String input = "what was the impact of the manhattan project";
+ try (var tokenizerWithDefaultPadding = builder.build();
+ var tokenizerWithPaddingDisabled = builder.setPadding(false).build();
+ var tokenizerWithPaddingEnabled = builder.setPadding(true).build()) {
+ assertMaxLengthRespected(10, tokenizerWithDefaultPadding.encode(input));
+ assertMaxLengthRespected(10, tokenizerWithPaddingDisabled.encode(input));
+ assertMaxLengthRespected(32, tokenizerWithPaddingEnabled.encode(input));
}
}
@@ -94,13 +121,6 @@ class HuggingFaceTokenizerTest {
assertEquals(maxLength, encoding.typeIds().size());
}
- private static HuggingFaceTokenizer createTokenizer(Path tmp, String model) throws IOException {
- return new HuggingFaceTokenizer.Builder()
- .addSpecialTokens(false)
- .addDefaultModel(decompressModelFile(tmp, model))
- .build();
- }
-
private static Path decompressModelFile(Path tmp, String model) throws IOException {
var source = Paths.get("src/test/models/huggingface/%s.json.gz".formatted(model));
Path destination = tmp.resolve(source.getFileName().toString().replace(".gz", ""));
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index f93b1a3c1f8..17b63fb1c7d 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -42,6 +42,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
.addSpecialTokens(true)
.addDefaultModel(Paths.get(config.tokenizerPath().toString()))
.setTruncation(true)
+ .setPadding(false)
.setMaxLength(config.transformerMaxTokens())
.build();
poolingStrategy = PoolingStrategy.fromString(config.poolingStrategy().toString());
@@ -102,17 +103,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);
- Tensor.Builder builder = Tensor.Builder.of(tensorType);
-
- // Mean pooling implementation
- Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
- Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
- Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
- for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) {
- builder.cell(averaged.get(TensorAddress.of(0,i)), i);
- }
-
- Tensor result = builder.build();
+ var result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask);
return normalize ? normalize(result, tensorType) : result;
}