summary | refs | log | tree | commit | diff | stats
path: root/model-integration
diff options
context:
space:
mode:
author: Jon Bratseth <bratseth@gmail.com> 2023-04-14 16:31:13 +0200
committer: GitHub <noreply@github.com> 2023-04-14 16:31:13 +0200
commit: 28e916b2a36d57e5b5e99231babd5f4e81fb0d38 (patch)
tree: 7dd2c31787ee9a7f313f8ca04d97906cca3db029 /model-integration
parent: 4f2f29e1459b900d4b074f5cfc4c126837c54bfd (diff)
Revert "Allow start end sequence tokens as args bertbaseembedder"
Diffstat (limited to 'model-integration')
-rw-r--r-- model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java 18
-rw-r--r-- model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def 4
2 files changed, 8 insertions, 14 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index 19536f3cb32..b40e2b5be72 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -32,9 +32,10 @@ import java.util.Map;
*/
public class BertBaseEmbedder extends AbstractComponent implements Embedder {
+ private final static int TOKEN_CLS = 101; // [CLS]
+ private final static int TOKEN_SEP = 102; // [SEP]
+
private final int maxTokens;
- private final int startSequenceToken;
- private final int endSequenceToken;
private final String inputIdsName;
private final String attentionMaskName;
private final String tokenTypeIdsName;
@@ -47,8 +48,6 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
@Inject
public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) {
maxTokens = config.transformerMaxTokens();
- startSequenceToken = config.transformerStartSequenceToken();
- endSequenceToken = config.transformerStartSequenceToken();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
tokenTypeIdsName = config.transformerTokenTypeIds();
@@ -99,7 +98,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
if (!type.dimensions().get(0).isIndexed()) {
throw new IllegalArgumentException("Error in embedding to type '" + type + "': dimension should be indexed.");
}
- List<Integer> tokens = embedWithSeparatorTokens(text, context, maxTokens);
+ List<Integer> tokens = embedWithSeperatorTokens(text, context, maxTokens);
return embedTokens(tokens, type);
}
@@ -110,7 +109,6 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
Tensor attentionMask = createAttentionMask(inputSequence);
Tensor tokenTypeIds = createTokenTypeIds(inputSequence);
-
Map<String, Tensor> inputs;
if (!"".equals(tokenTypeIdsName)) {
inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
@@ -140,14 +138,14 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
- private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) {
+ private List<Integer> embedWithSeperatorTokens(String text, Context context, int maxLength) {
List<Integer> tokens = new ArrayList<>();
- tokens.add(startSequenceToken);
+ tokens.add(TOKEN_CLS);
tokens.addAll(embed(text, context));
- tokens.add(endSequenceToken);
+ tokens.add(TOKEN_SEP);
if (tokens.size() > maxLength) {
tokens = tokens.subList(0, maxLength-1);
- tokens.add(endSequenceToken);
+ tokens.add(TOKEN_SEP);
}
return tokens;
}
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
index ef42d81e1fe..14d953eeef9 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
@@ -17,10 +17,6 @@ transformerInputIds string default=input_ids
transformerAttentionMask string default=attention_mask
transformerTokenTypeIds string default=token_type_ids
-# special token ids
-transformerStartSequenceToken int default=101
-transformerEndSequenceToken int default=102
-
# Output name
transformerOutput string default=output_0