diff options
Diffstat (limited to 'model-integration/src')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java | 18 | ||||
-rw-r--r-- | model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def | 4 |
2 files changed, 14 insertions, 8 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index b40e2b5be72..19536f3cb32 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -32,10 +32,9 @@ import java.util.Map; */ public class BertBaseEmbedder extends AbstractComponent implements Embedder { - private final static int TOKEN_CLS = 101; // [CLS] - private final static int TOKEN_SEP = 102; // [SEP] - private final int maxTokens; + private final int startSequenceToken; + private final int endSequenceToken; private final String inputIdsName; private final String attentionMaskName; private final String tokenTypeIdsName; @@ -48,6 +47,8 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { @Inject public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) { maxTokens = config.transformerMaxTokens(); + startSequenceToken = config.transformerStartSequenceToken(); + endSequenceToken = config.transformerEndSequenceToken(); inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); tokenTypeIdsName = config.transformerTokenTypeIds(); @@ -98,7 +99,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { if (!type.dimensions().get(0).isIndexed()) { throw new IllegalArgumentException("Error in embedding to type '" + type + "': dimension should be indexed."); } - List<Integer> tokens = embedWithSeperatorTokens(text, context, maxTokens); + List<Integer> tokens = embedWithSeparatorTokens(text, context, maxTokens); return embedTokens(tokens, type); } @@ -109,6 +110,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { Tensor attentionMask = createAttentionMask(inputSequence); Tensor tokenTypeIds = createTokenTypeIds(inputSequence); + Map<String, Tensor> inputs; if 
(!"".equals(tokenTypeIdsName)) { inputs = Map.of(inputIdsName, inputSequence.expand("d0"), @@ -138,14 +140,14 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - private List<Integer> embedWithSeperatorTokens(String text, Context context, int maxLength) { + private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) { List<Integer> tokens = new ArrayList<>(); - tokens.add(TOKEN_CLS); + tokens.add(startSequenceToken); tokens.addAll(embed(text, context)); - tokens.add(TOKEN_SEP); + tokens.add(endSequenceToken); if (tokens.size() > maxLength) { tokens = tokens.subList(0, maxLength-1); - tokens.add(TOKEN_SEP); + tokens.add(endSequenceToken); } return tokens; } diff --git a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def index 14d953eeef9..ef42d81e1fe 100644 --- a/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def +++ b/model-integration/src/main/resources/configdefinitions/embedding.bert-base-embedder.def @@ -17,6 +17,10 @@ transformerInputIds string default=input_ids transformerAttentionMask string default=attention_mask transformerTokenTypeIds string default=token_type_ids +# special token ids +transformerStartSequenceToken int default=101 +transformerEndSequenceToken int default=102 + # Output name transformerOutput string default=output_0 |