about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-05-11 16:24:23 +0200
committer Bjørn Christian Seime <bjorncs@yahooinc.com> 2023-05-11 16:41:54 +0200
commit 9498a63f89a08f862c2f6a4a7c17441a6365e69a (patch)
tree 740f62e276306a329a6f382ca40ef87182226700
parent a7e91df672012078a9ab6566c6ee604460a4dcc5 (diff)
Handle models requiring token type ids
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java 32
-rw-r--r-- model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def 1
2 files changed, 20 insertions, 13 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 123ca621d0a..0c1cc80544e 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -27,6 +27,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
private final String inputIdsName;
private final String attentionMaskName;
+ private final String tokenTypeIdsName;
private final String outputName;
private final int maxTokens;
private final boolean normalize;
@@ -38,6 +39,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
maxTokens = config.transformerMaxTokens();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
+ tokenTypeIdsName = config.transformerTokenTypeIds();
outputName = config.transformerOutput();
normalize = config.normalize();
tokenizer = new HuggingFaceTokenizer.Builder()
@@ -57,6 +59,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
Map<String, TensorType> inputs = evaluator.getInputInfo();
validateName(inputs, inputIdsName, "input");
validateName(inputs, attentionMaskName, "input");
+ if (!tokenTypeIdsName.isEmpty()) validateName(inputs, tokenTypeIdsName, "input");
Map<String, TensorType> outputs = evaluator.getOutputInfo();
validateName(outputs, outputName, "output");
@@ -91,18 +94,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
@Override
public Tensor embed(String s, Context context, TensorType tensorType) {
- List<Integer> tokenIds = embed(s, context);
- return embedTokens(tokenIds, tensorType);
- }
-
- Tensor embedTokens(List<Integer> tokenIds, TensorType tensorType) {
- Tensor inputSequence = createTensorRepresentation(tokenIds, "d1");
- Tensor attentionMask = createAttentionMask(inputSequence);
-
- Map<String, Tensor> inputs = Map.of(
- inputIdsName, inputSequence.expand("d0"),
- attentionMaskName, attentionMask.expand("d0")
- );
+ var encoding = tokenizer.encode(s, context.getLanguage());
+ Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1");
+ Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
+ Tensor tokenTypeIds = createTensorRepresentation(encoding.typeIds(), "d1");
+
+
+ Map<String, Tensor> inputs;
+ if (tokenTypeIds.isEmpty()) {
+ inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0"));
+ } else {
+ inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0"),
+ tokenTypeIdsName, tokenTypeIds.expand("d0"));
+ }
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);
@@ -140,7 +146,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
- private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) {
+ private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
int size = input.size();
TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
diff --git a/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def
index 1dccea0ddf6..97515818f14 100644
--- a/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def
@@ -12,6 +12,7 @@ transformerMaxTokens int default=512
# Input names
transformerInputIds string default=input_ids
transformerAttentionMask string default=attention_mask
+transformerTokenTypeIds string default=token_type_ids
# Output name
transformerOutput string default=last_hidden_state