diff options
author | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-11 16:24:23 +0200 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@yahooinc.com> | 2023-05-11 16:41:54 +0200 |
commit | 9498a63f89a08f862c2f6a4a7c17441a6365e69a (patch) | |
tree | 740f62e276306a329a6f382ca40ef87182226700 | |
parent | a7e91df672012078a9ab6566c6ee604460a4dcc5 (diff) |
Handle models requiring token type ids
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 32 | ||||
-rw-r--r-- | model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def | 1 |
2 files changed, 20 insertions, 13 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index 123ca621d0a..0c1cc80544e 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -27,6 +27,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final String inputIdsName; private final String attentionMaskName; + private final String tokenTypeIdsName; private final String outputName; private final int maxTokens; private final boolean normalize; @@ -38,6 +39,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { maxTokens = config.transformerMaxTokens(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); + tokenTypeIdsName = config.transformerTokenTypeIds(); outputName = config.transformerOutput(); normalize = config.normalize(); tokenizer = new HuggingFaceTokenizer.Builder() @@ -57,6 +59,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { Map<String, TensorType> inputs = evaluator.getInputInfo(); validateName(inputs, inputIdsName, "input"); validateName(inputs, attentionMaskName, "input"); + if (!tokenTypeIdsName.isEmpty()) validateName(inputs, tokenTypeIdsName, "input"); Map<String, TensorType> outputs = evaluator.getOutputInfo(); validateName(outputs, outputName, "output"); @@ -91,18 +94,21 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { @Override public Tensor embed(String s, Context context, TensorType tensorType) { - List<Integer> tokenIds = embed(s, context); - return embedTokens(tokenIds, tensorType); - } - - Tensor embedTokens(List<Integer> tokenIds, TensorType tensorType) { - Tensor inputSequence = createTensorRepresentation(tokenIds, "d1"); - Tensor attentionMask = createAttentionMask(inputSequence); - - Map<String, Tensor> inputs = Map.of( - inputIdsName, inputSequence.expand("d0"), - attentionMaskName, attentionMask.expand("d0") - ); + var encoding = tokenizer.encode(s, context.getLanguage()); + Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1"); + Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1"); + Tensor tokenTypeIds = createTensorRepresentation(encoding.typeIds(), "d1"); + + + Map<String, Tensor> inputs; + if (tokenTypeIds.isEmpty()) { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0")); + } else { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0"), + tokenTypeIdsName, tokenTypeIds.expand("d0")); + } Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); @@ -140,7 +146,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) { + private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) { int size = input.size(); TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); diff --git a/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def b/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def index 1dccea0ddf6..97515818f14 100644 --- a/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def +++ b/model-integration/src/main/resources/configdefinitions/hugging-face-embedder.def @@ -12,6 +12,7 @@ transformerMaxTokens int default=512 # Input names transformerInputIds string default=input_ids transformerAttentionMask string default=attention_mask +transformerTokenTypeIds string default=token_type_ids # Output name transformerOutput string default=last_hidden_state |