author     Christophe Jolif <cjolif@gmail.com>    2022-09-29 08:54:52 +0200
committer  Christophe Jolif <cjolif@gmail.com>    2022-10-20 13:59:59 +0200
commit     e39231d1f72080ebb6232f70bd5b388ba83232ec (patch)
tree       7903558a30a297e7d8db1f49a8cbbed8f123393e /model-integration
parent     8a91e259064b40f2f5fde5f8233c9892446d105e (diff)
support models without tokenTypeIds (like DistilBERT)
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java                    15
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java                36
-rw-r--r--  model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx  bin 0 -> 28580 bytes
-rw-r--r--  model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.py    49
4 files changed, 97 insertions, 3 deletions
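
The change makes token type ids optional: an embedder config can now declare that the model has no token_type_ids input by setting the corresponding name to the empty string. A minimal sketch of the builder usage exercised by the tests below (the vocabulary and model paths are placeholders):

    BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
    builder.tokenizerVocab(ModelReference.valueOf("models/vocab.txt"));          // placeholder path
    builder.transformerModel(ModelReference.valueOf("models/distilbert.onnx"));  // placeholder path
    builder.transformerTokenTypeIds("");  // empty string => model takes no token_type_ids input
    BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());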
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
index d4a93999dff..002350ce3cf 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -67,8 +67,11 @@ public class BertBaseEmbedder implements Embedder {
         Map<String, TensorType> inputs = evaluator.getInputInfo();
         validateName(inputs, inputIdsName, "input");
         validateName(inputs, attentionMaskName, "input");
-        validateName(inputs, tokenTypeIdsName, "input");
-
+        // Some BERT-inspired models, such as DistilBERT, do not have a token_type_ids input.
+        // One can explicitly declare such a model by setting this config value to the empty string.
+        if (!"".equals(tokenTypeIdsName)) {
+            validateName(inputs, tokenTypeIdsName, "input");
+        }
         Map<String, TensorType> outputs = evaluator.getOutputInfo();
         validateName(outputs, outputName, "output");
     }
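
The tests below expect this constructor to throw IllegalArgumentException when a configured input name is missing from the model, so validateName presumably checks map membership and throws. A hypothetical sketch, not part of this diff:

    private static void validateName(Map<String, TensorType> types, String name, String type) {
        if ( ! types.containsKey(name))
            throw new IllegalArgumentException("Model does not have a required " + type + " named '" + name + "'");
    }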
@@ -102,9 +105,15 @@ public class BertBaseEmbedder implements Embedder {
         Tensor attentionMask = createAttentionMask(inputSequence);
         Tensor tokenTypeIds = createTokenTypeIds(inputSequence);
 
-        Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+        Map<String, Tensor> inputs;
+        if (!"".equals(tokenTypeIdsName)) {
+            inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
                     attentionMaskName, attentionMask.expand("d0"),
                     tokenTypeIdsName, tokenTypeIds.expand("d0"));
+        } else {
+            inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+                    attentionMaskName, attentionMask.expand("d0"));
+        }
 
         Map<String, Tensor> outputs = evaluator.evaluate(inputs);
         Tensor tokenEmbeddings = outputs.get(outputName);
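
The duplicated Map.of call is needed because Map.of builds an immutable map from a fixed argument list, so an entry cannot be added conditionally. An equivalent sketch using a mutable map (assuming the evaluator only reads the map; requires java.util.HashMap):

    Map<String, Tensor> inputs = new HashMap<>();
    inputs.put(inputIdsName, inputSequence.expand("d0"));
    inputs.put(attentionMaskName, attentionMask.expand("d0"));
    if (!"".equals(tokenTypeIdsName))
        inputs.put(tokenTypeIdsName, tokenTypeIds.expand("d0"));  // only when the model declares this input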
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index dbcae24b28f..73359736536 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -9,9 +9,11 @@ import com.yahoo.tensor.Tensor;
 import com.yahoo.tensor.TensorType;
 import org.junit.Test;
 
+import java.lang.IllegalArgumentException;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assume.assumeTrue;
 
 public class BertBaseEmbedderTest {
@@ -35,4 +37,38 @@ public class BertBaseEmbedderTest {
         assertEquals(embedding, expected);
     }
 
+    @Test
+    public void testEmbedderWithoutTokenTypeIdsName() {
+        String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
+        String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx";
+        assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+
+        BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
+        builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
+        builder.transformerModel(ModelReference.valueOf(modelPath));
+        builder.transformerTokenTypeIds("");
+        BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
+
+        TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");
+        List<Integer> tokens = List.of(1, 2, 3, 4, 5); // use fixed dummy tokens instead of invoking the tokenizer
+        Tensor embedding = embedder.embedTokens(tokens, destType);
+
+        Tensor expected = Tensor.from("tensor<float>(x[7]):[0.10873623, 0.56411576, 0.6044973, -0.4819714, 0.7519982, -0.83261716, 0.30430704]");
+        assertEquals(embedding, expected);
+    }
+
+    @Test
+    public void testEmbedderWithoutTokenTypeIdsNameButWithConfig() {
+        String vocabPath = "src/test/models/onnx/transformer/dummy_vocab.txt";
+        String modelPath = "src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx";
+        assumeTrue(OnnxEvaluator.isRuntimeAvailable(modelPath));
+
+        BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
+        builder.tokenizerVocab(ModelReference.valueOf(vocabPath));
+        builder.transformerModel(ModelReference.valueOf(modelPath));
+        // We did not configure BertBaseEmbedder to accept missing token type ids,
+        // so we expect the constructor to throw.
+        assertThrows(IllegalArgumentException.class, () -> { new BertBaseEmbedder(builder.build()); });
+    }
+
 }
diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx b/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx
new file mode 100644
index 00000000000..927f1c607c4
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.py b/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.py
new file mode 100644
index 00000000000..4c5f5ebe330
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/dummy_transformer_without_type_ids.py
@@ -0,0 +1,49 @@
+import torch
+import torch.onnx
+import torch.nn as nn
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+
+class TransformerModel(nn.Module):
+    def __init__(self, vocab_size, emb_size, num_heads, hidden_dim_size, num_layers, dropout=0.2):
+        super(TransformerModel, self).__init__()
+        self.encoder = nn.Embedding(vocab_size, emb_size)
+        encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout)
+        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
+
+    def forward(self, tokens, attention_mask):
+        src = self.encoder(tokens * attention_mask)
+        output = self.transformer_encoder(src)
+        return output
+
+
+def main():
+    vocabulary_size = 20
+    embedding_size = 16
+    hidden_dim_size = 32
+    num_layers = 2
+    num_heads = 2
+    model = TransformerModel(vocabulary_size, embedding_size, num_heads, hidden_dim_size, num_layers)
+
+    # Omit training - just export the randomly initialized network
+
+    tokens = torch.LongTensor([[1, 2, 3, 4, 5]])
+    attention_mask = torch.LongTensor([[1, 1, 1, 1, 1]])
+    token_type_ids = torch.LongTensor([[0, 0, 0, 0, 0]])  # unused: this model has no token_type_ids input
+    torch.onnx.export(model,
+                      (tokens, attention_mask),
+                      "dummy_transformer_without_type_ids.onnx",
+                      input_names=["input_ids", "attention_mask"],
+                      output_names=["output_0"],
+                      dynamic_axes={
+                          "input_ids": {0: "batch", 1: "tokens"},
+                          "attention_mask": {0: "batch", 1: "tokens"},
+                          "output_0": {0: "batch", 1: "tokens"},
+                      },
+                      opset_version=12)
+
+
+if __name__ == "__main__":
+    main()
+
+