diff options
Diffstat (limited to 'model-integration/src/main')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index d4a93999dff..002350ce3cf 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -67,8 +67,11 @@ public class BertBaseEmbedder implements Embedder { Map<String, TensorType> inputs = evaluator.getInputInfo(); validateName(inputs, inputIdsName, "input"); validateName(inputs, attentionMaskName, "input"); - validateName(inputs, tokenTypeIdsName, "input"); - + // some BERT inspired models such as DistilBERT do not have token_type_ids input + // one can explicitly declare this is such model by setting that config to empty string + if (!"".equals(tokenTypeIdsName)) { + validateName(inputs, tokenTypeIdsName, "input"); + } Map<String, TensorType> outputs = evaluator.getOutputInfo(); validateName(outputs, outputName, "output"); } @@ -102,9 +105,15 @@ public class BertBaseEmbedder implements Embedder { Tensor attentionMask = createAttentionMask(inputSequence); Tensor tokenTypeIds = createTokenTypeIds(inputSequence); - Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + Map<String, Tensor> inputs; + if (!"".equals(tokenTypeIdsName)) { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), attentionMaskName, attentionMask.expand("d0"), tokenTypeIdsName, tokenTypeIds.expand("d0")); + } else { + inputs = Map.of(inputIdsName, inputSequence.expand("d0"), + attentionMaskName, attentionMask.expand("d0")); + } Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); |