author     Jon Bratseth <bratseth@vespa.ai>    2023-04-18 11:11:20 +0200
committer  Jon Bratseth <bratseth@vespa.ai>    2023-04-18 11:11:20 +0200
commit     5082e42350a117c46f045a31c3502f2c8ad97a0d
tree       58828b8922b16c19dc2746982a0a11b2e1dbf2ef /model-integration
parent     95805d16d0b043f0149edf733876067f1a8ee8e7
Revert "Merge pull request #26744 from vespa-engine/revert-26708-allow-start-end-sequence-tokens-as-args-bertbaseembedder"
This reverts commit d025a93015e66efc0027d81a64e70530d6cb240e, reversing changes made to 4f2f29e1459b900d4b074f5cfc4c126837c54bfd.
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java                18
-rw-r--r--  model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def   4
2 files changed, 14 insertions, 8 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index b40e2b5be72..19536f3cb32 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -32,10 +32,9 @@ import java.util.Map;
*/
public class BertBaseEmbedder extends AbstractComponent implements Embedder {
- private final static int TOKEN_CLS = 101; // [CLS]
- private final static int TOKEN_SEP = 102; // [SEP]
-
private final int maxTokens;
+ private final int startSequenceToken;
+ private final int endSequenceToken;
private final String inputIdsName;
private final String attentionMaskName;
private final String tokenTypeIdsName;
@@ -48,6 +47,8 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
@Inject
public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) {
maxTokens = config.transformerMaxTokens();
+ startSequenceToken = config.transformerStartSequenceToken();
+ endSequenceToken = config.transformerEndSequenceToken();
inputIdsName = config.transformerInputIds();
attentionMaskName = config.transformerAttentionMask();
tokenTypeIdsName = config.transformerTokenTypeIds();
@@ -98,7 +99,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
if (!type.dimensions().get(0).isIndexed()) {
throw new IllegalArgumentException("Error in embedding to type '" + type + "': dimension should be indexed.");
}
- List<Integer> tokens = embedWithSeperatorTokens(text, context, maxTokens);
+ List<Integer> tokens = embedWithSeparatorTokens(text, context, maxTokens);
return embedTokens(tokens, type);
}
@@ -109,6 +110,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
Tensor attentionMask = createAttentionMask(inputSequence);
Tensor tokenTypeIds = createTokenTypeIds(inputSequence);
+
Map<String, Tensor> inputs;
if (!"".equals(tokenTypeIdsName)) {
inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
@@ -138,14 +140,14 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
- private List<Integer> embedWithSeperatorTokens(String text, Context context, int maxLength) {
+ private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) {
List<Integer> tokens = new ArrayList<>();
- tokens.add(TOKEN_CLS);
+ tokens.add(startSequenceToken);
tokens.addAll(embed(text, context));
- tokens.add(TOKEN_SEP);
+ tokens.add(endSequenceToken);
if (tokens.size() > maxLength) {
tokens = tokens.subList(0, maxLength-1);
- tokens.add(TOKEN_SEP);
+ tokens.add(endSequenceToken);
}
return tokens;
}
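
A minimal, self-contained sketch (not part of the commit) of the behaviour in the hunk above: the tokenized text is wrapped in the configured start and end sequence tokens, and when the result exceeds the maximum length it is cut to maxLength - 1 and the end token is re-appended so the sequence still terminates correctly. The ids 101 and 102 used here are just the BERT defaults declared in the config definition below.

import java.util.ArrayList;
import java.util.List;

class SeparatorTokenSketch {

    // Mirrors embedWithSeparatorTokens: wrap the body tokens in start/end
    // sequence tokens, then truncate and re-terminate if too long.
    static List<Integer> wrapAndTruncate(List<Integer> bodyTokens,
                                         int startSequenceToken,
                                         int endSequenceToken,
                                         int maxLength) {
        List<Integer> tokens = new ArrayList<>();
        tokens.add(startSequenceToken);
        tokens.addAll(bodyTokens);
        tokens.add(endSequenceToken);
        if (tokens.size() > maxLength) {
            tokens = new ArrayList<>(tokens.subList(0, maxLength - 1));
            tokens.add(endSequenceToken);
        }
        return tokens;
    }

    public static void main(String[] args) {
        // 5 body tokens plus 2 separator tokens exceed maxLength = 5:
        // prints [101, 7, 8, 9, 102]
        System.out.println(wrapAndTruncate(List.of(7, 8, 9, 10, 11), 101, 102, 5));
    }
}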
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
index 14d953eeef9..ef42d81e1fe 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
@@ -17,6 +17,10 @@ transformerInputIds string default=input_ids
transformerAttentionMask string default=attention_mask
transformerTokenTypeIds string default=token_type_ids
+# special token ids
+transformerStartSequenceToken int default=101
+transformerEndSequenceToken int default=102
+
# Output name
transformerOutput string default=output_0
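
The two new parameters default to the standard BERT [CLS]/[SEP] ids (101/102) but can now be overridden for models whose vocabularies use different special tokens. The following is a hedged sketch only, assuming the usual Vespa-generated config builder with fluent setters named after the .def parameters and a generated class in com.yahoo.embedding; the model paths and the RoBERTa-style token ids are hypothetical placeholders, not values from this commit.

import com.yahoo.config.ModelReference;
import com.yahoo.embedding.BertBaseEmbedderConfig;

class TokenIdConfigSketch {

    // Builds a config that overrides the two new token-id parameters.
    static BertBaseEmbedderConfig configWithCustomTokens() {
        return new BertBaseEmbedderConfig.Builder()
                .tokenizerVocab(ModelReference.valueOf("models/vocab.txt"))       // hypothetical path
                .transformerModel(ModelReference.valueOf("models/encoder.onnx"))  // hypothetical path
                .transformerStartSequenceToken(0)   // e.g. <s> in a RoBERTa-style vocabulary
                .transformerEndSequenceToken(2)     // e.g. </s>
                .build();
    }
}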