summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java15
1 files changed, 12 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index d4a93999dff..002350ce3cf 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -67,8 +67,11 @@ public class BertBaseEmbedder implements Embedder {
Map<String, TensorType> inputs = evaluator.getInputInfo();
validateName(inputs, inputIdsName, "input");
validateName(inputs, attentionMaskName, "input");
- validateName(inputs, tokenTypeIdsName, "input");
-
+ // Some BERT-inspired models, such as DistilBERT, do not have a token_type_ids input.
+ // One can explicitly declare such a model by setting that config value to the empty string.
+ if (!"".equals(tokenTypeIdsName)) {
+ validateName(inputs, tokenTypeIdsName, "input");
+ }
Map<String, TensorType> outputs = evaluator.getOutputInfo();
validateName(outputs, outputName, "output");
}
@@ -102,9 +105,15 @@ public class BertBaseEmbedder implements Embedder {
Tensor attentionMask = createAttentionMask(inputSequence);
Tensor tokenTypeIds = createTokenTypeIds(inputSequence);
- Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ Map<String, Tensor> inputs;
+ if (!"".equals(tokenTypeIdsName)) {
+ inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
attentionMaskName, attentionMask.expand("d0"),
tokenTypeIdsName, tokenTypeIds.expand("d0"));
+ } else {
+ inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0"));
+ }
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);