diff options
author | connell gough <connell@portrait-analytics.com> | 2023-04-11 09:28:27 -0700 |
---|---|---|
committer | connell gough <connell@portrait-analytics.com> | 2023-04-13 10:19:34 -0700 |
commit | d398c0760456b8bffb35cf30c0edd708dcde69d4 (patch) | |
tree | 30161d609af363a15f904cc5afbbf6e936d51dc2 /model-integration | |
parent | 9771cd37fe57b3b581b011c89ee0a4320298d783 (diff) |
Add special tokens as arguments and allow tokenTypeIds to be null
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java | 16 | ||||
-rw-r--r-- | model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def | 4 |
2 files changed, 13 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index b40e2b5be72..a8c4d935cae 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -32,10 +32,10 @@ import java.util.Map; */ public class BertBaseEmbedder extends AbstractComponent implements Embedder { - private final static int TOKEN_CLS = 101; // [CLS] - private final static int TOKEN_SEP = 102; // [SEP] - private final int maxTokens; + private final int startSequenceToken; + private final int endSequenceToken; + private final int separatorToken; private final String inputIdsName; private final String attentionMaskName; private final String tokenTypeIdsName; @@ -48,6 +48,8 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { @Inject public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) { maxTokens = config.transformerMaxTokens(); + startSequenceToken = config.transformerStartSequenceToken(); + endSequenceToken = config.transformerStartSequenceToken(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); tokenTypeIdsName = config.transformerTokenTypeIds(); @@ -107,7 +109,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { Tensor embedTokens(List<Integer> tokens, TensorType type) { Tensor inputSequence = createTensorRepresentation(tokens, "d1"); Tensor attentionMask = createAttentionMask(inputSequence); - Tensor tokenTypeIds = createTokenTypeIds(inputSequence); + Map<String, Tensor> inputs; if (!"".equals(tokenTypeIdsName)) { @@ -140,12 +142,12 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { private List<Integer> embedWithSeperatorTokens(String text, Context context, int maxLength) { List<Integer> tokens = new ArrayList<>(); - tokens.add(TOKEN_CLS); + tokens.add(startSequenceToken); tokens.addAll(embed(text, context)); - tokens.add(TOKEN_SEP); + tokens.add(endSequenceToken); if (tokens.size() > maxLength) { tokens = tokens.subList(0, maxLength-1); - tokens.add(TOKEN_SEP); + tokens.add(endSequenceToken); } return tokens; } diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def index 14d953eeef9..14f5e95a6b8 100644 --- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def +++ b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def @@ -17,6 +17,10 @@ transformerInputIds string default=input_ids transformerAttentionMask string default=attention_mask transformerTokenTypeIds string default=token_type_ids +# special token ids +transformerStartSequenceToken int default= +transformerEndSequenceToken int default= + # Output name transformerOutput string default=output_0 |