summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorconnell gough <connell@portrait-analytics.com>2023-04-11 09:28:27 -0700
committerconnell gough <connell@portrait-analytics.com>2023-04-13 10:19:34 -0700
commitd398c0760456b8bffb35cf30c0edd708dcde69d4 (patch)
tree30161d609af363a15f904cc5afbbf6e936d51dc2 /model-integration
parent9771cd37fe57b3b581b011c89ee0a4320298d783 (diff)
Add special tokens as arguments and allow tokenTypeIds to be null
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java16
-rw-r--r--model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def4
2 files changed, 13 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index b40e2b5be72..a8c4d935cae 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -32,10 +32,10 @@ import java.util.Map;
*/
public class BertBaseEmbedder extends AbstractComponent implements Embedder {
- private final static int TOKEN_CLS = 101; // [CLS]
- private final static int TOKEN_SEP = 102; // [SEP]
-
private final int maxTokens;
+ private final int startSequenceToken;
+ private final int endSequenceToken;
+ private final int separatorToken;
private final String inputIdsName;
private final String attentionMaskName;
private final String tokenTypeIdsName;
@@ -48,6 +48,8 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
@Inject
public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) {
maxTokens = config.transformerMaxTokens();
+ startSequenceToken = config.transformerStartSequenceToken();
+ endSequenceToken = config.transformerStartSequenceToken();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
tokenTypeIdsName = config.transformerTokenTypeIds();
@@ -107,7 +109,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
Tensor embedTokens(List<Integer> tokens, TensorType type) {
Tensor inputSequence = createTensorRepresentation(tokens, "d1");
Tensor attentionMask = createAttentionMask(inputSequence);
- Tensor tokenTypeIds = createTokenTypeIds(inputSequence);
+
Map<String, Tensor> inputs;
if (!"".equals(tokenTypeIdsName)) {
@@ -140,12 +142,12 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
private List<Integer> embedWithSeperatorTokens(String text, Context context, int maxLength) {
List<Integer> tokens = new ArrayList<>();
- tokens.add(TOKEN_CLS);
+ tokens.add(startSequenceToken);
tokens.addAll(embed(text, context));
- tokens.add(TOKEN_SEP);
+ tokens.add(endSequenceToken);
if (tokens.size() > maxLength) {
tokens = tokens.subList(0, maxLength-1);
- tokens.add(TOKEN_SEP);
+ tokens.add(endSequenceToken);
}
return tokens;
}
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
index 14d953eeef9..14f5e95a6b8 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
@@ -17,6 +17,10 @@ transformerInputIds string default=input_ids
transformerAttentionMask string default=attention_mask
transformerTokenTypeIds string default=token_type_ids
+# special token ids
+transformerStartSequenceToken int default=
+transformerEndSequenceToken int default=
+
# Output name
transformerOutput string default=output_0