aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-30 16:52:45 +0200
committerBjørn Christian Seime <bjorncs@yahooinc.com>2023-05-30 16:52:47 +0200
commitbafc5758799fa2e9e2ece36a85ea27c2c92982fd (patch)
treeba55b18d5a1d9b7e48209b18f659a47db6d77f23 /model-integration
parent4279a9423629c5ea9daa15d5d297f2347b7c22ec (diff)
Properly ignore token type ids from tokenizer if disabled
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java4
1 files changed, 2 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 88028da85c1..01804656bb6 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -84,11 +84,11 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
var encoding = tokenizer.encode(s, context.getLanguage());
Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1");
Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1");
- Tensor tokenTypeIds = createTensorRepresentation(encoding.typeIds(), "d1");
+ Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1");
Map<String, Tensor> inputs;
- if (tokenTypeIds.isEmpty()) {
+ if (tokenTypeIdsName.isEmpty() || tokenTypeIds.isEmpty()) {
inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
attentionMaskName, attentionMask.expand("d0"));
} else {