diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-04-14 16:31:13 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-14 16:31:13 +0200 |
commit | 28e916b2a36d57e5b5e99231babd5f4e81fb0d38 (patch) | |
tree | 7dd2c31787ee9a7f313f8ca04d97906cca3db029 /model-integration | |
parent | 4f2f29e1459b900d4b074f5cfc4c126837c54bfd (diff) |
Revert "Allow start end sequence tokens as args bertbaseembedder"
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java | 18 | ||||
-rw-r--r-- | model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def | 4 |
2 files changed, 8 insertions, 14 deletions
```diff
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index 19536f3cb32..b40e2b5be72 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -32,9 +32,10 @@ import java.util.Map;
  */
 public class BertBaseEmbedder extends AbstractComponent implements Embedder {
 
+    private final static int TOKEN_CLS = 101; // [CLS]
+    private final static int TOKEN_SEP = 102; // [SEP]
+
     private final int maxTokens;
-    private final int startSequenceToken;
-    private final int endSequenceToken;
     private final String inputIdsName;
     private final String attentionMaskName;
     private final String tokenTypeIdsName;
@@ -47,8 +48,6 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
     @Inject
     public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) {
         maxTokens = config.transformerMaxTokens();
-        startSequenceToken = config.transformerStartSequenceToken();
-        endSequenceToken = config.transformerStartSequenceToken();
         inputIdsName = config.transformerInputIds();
         attentionMaskName = config.transformerAttentionMask();
         tokenTypeIdsName = config.transformerTokenTypeIds();
@@ -99,7 +98,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
         if (!type.dimensions().get(0).isIndexed()) {
             throw new IllegalArgumentException("Error in embedding to type '" + type + "': dimension should be indexed.");
         }
-        List<Integer> tokens = embedWithSeparatorTokens(text, context, maxTokens);
+        List<Integer> tokens = embedWithSeperatorTokens(text, context, maxTokens);
         return embedTokens(tokens, type);
     }
 
@@ -110,7 +109,6 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
         Tensor attentionMask = createAttentionMask(inputSequence);
         Tensor tokenTypeIds = createTokenTypeIds(inputSequence);
 
-
         Map<String, Tensor> inputs;
         if (!"".equals(tokenTypeIdsName)) {
             inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
@@ -140,14 +138,14 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder {
         return builder.build();
     }
 
-    private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) {
+    private List<Integer> embedWithSeperatorTokens(String text, Context context, int maxLength) {
         List<Integer> tokens = new ArrayList<>();
-        tokens.add(startSequenceToken);
+        tokens.add(TOKEN_CLS);
         tokens.addAll(embed(text, context));
-        tokens.add(endSequenceToken);
+        tokens.add(TOKEN_SEP);
         if (tokens.size() > maxLength) {
             tokens = tokens.subList(0, maxLength-1);
-            tokens.add(endSequenceToken);
+            tokens.add(TOKEN_SEP);
         }
         return tokens;
     }
diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
index ef42d81e1fe..14d953eeef9 100644
--- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
+++ b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def
@@ -17,10 +17,6 @@ transformerInputIds string default=input_ids
 transformerAttentionMask string default=attention_mask
 transformerTokenTypeIds string default=token_type_ids
 
-# special token ids
-transformerStartSequenceToken int default=101
-transformerEndSequenceToken int default=102
-
 # Output name
 transformerOutput string default=output_0
 
```