summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2022-04-01 09:42:13 +0200
committerLester Solbakken <lesters@oath.com>2022-04-01 09:42:13 +0200
commit4cc5a56f7006c40b059855f345ded365ace8550c (patch)
tree6142647b7fd6240e64d1ef21ca74ec50e920f004 /model-integration
parent02151a078dd5f45defaa42b19bf127bc2c999944 (diff)
Revert "Revert "Add bert base embedder""
This reverts commit 80556883dcdc350e2b5cfebde8ef482baeb36872.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/pom.xml12
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java165
-rw-r--r--model-integration/src/main/resources/configdefinitions/bert-base-embedder.def27
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java29
-rw-r--r--model-integration/src/test/models/onnx/transformer/dummy_transformer.onnxbin0 -> 27895 bytes
-rw-r--r--model-integration/src/test/models/onnx/transformer/dummy_transformer.py52
-rw-r--r--model-integration/src/test/models/onnx/transformer/dummy_vocab.txt35
7 files changed, 320 insertions, 0 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 41139394690..d064a3ff709 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -45,6 +45,18 @@
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>linguistics</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>linguistics-components</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
<dependency>
<groupId>com.google.guava</groupId>
diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
new file mode 100644
index 00000000000..42e3d653359
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java
@@ -0,0 +1,165 @@
+package ai.vespa.embedding;
+
+import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.language.wordpiece.WordPieceEmbedder;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+
+/**
+ * A BERT Base compatible embedder. This embedder uses a WordPiece embedder to
+ * produce a token sequence that is input to a transformer model. A BERT base
+ * compatible transformer model must have three inputs:
+ *
+ * - A token sequence (input_ids)
+ * - An attention mask (attention_mask)
+ * - Token types for cross encoding (token_type_ids)
+ *
+ * See bert-base-embedder.def for configurable parameters.
+ *
+ * @author lesters
+ */
+public class BertBaseEmbedder implements Embedder {
+
+ private final static int TOKEN_CLS = 101; // [CLS]
+ private final static int TOKEN_SEP = 102; // [SEP]
+
+ private final int maxTokens;
+ private final String inputIdsName;
+ private final String attentionMaskName;
+ private final String tokenTypeIdsName;
+ private final String outputName;
+ private final String poolingStrategy;
+
+ private final WordPieceEmbedder tokenizer;
+ private final OnnxEvaluator evaluator;
+
+ @Inject
+ public BertBaseEmbedder(BertBaseEmbedderConfig config) {
+ maxTokens = config.transformerMaxTokens();
+ inputIdsName = config.transformerInputIds();
+ attentionMaskName = config.transformerAttentionMask();
+ tokenTypeIdsName = config.transformerTokenTypeIds();
+ outputName = config.transformerOutput();
+ poolingStrategy = config.poolingStrategy().toString();
+
+ OnnxEvaluatorOptions options = new OnnxEvaluatorOptions();
+ options.setExecutionMode(config.onnxExecutionMode().toString());
+ options.setInterOpThreads(modifyThreadCount(config.onnxInterOpThreads()));
+ options.setIntraOpThreads(modifyThreadCount(config.onnxIntraOpThreads()));
+
+ // Todo: use either file or url
+ tokenizer = new WordPieceEmbedder.Builder(config.tokenizerVocabUrl().getAbsolutePath()).build();
+ evaluator = new OnnxEvaluator(config.transformerModelUrl().getAbsolutePath(), options);
+
+ validateModel();
+ }
+
+ private void validateModel() {
+ Map<String, TensorType> inputs = evaluator.getInputInfo();
+ validateName(inputs, inputIdsName, "input");
+ validateName(inputs, attentionMaskName, "input");
+ validateName(inputs, tokenTypeIdsName, "input");
+
+ Map<String, TensorType> outputs = evaluator.getOutputInfo();
+ validateName(outputs, outputName, "output");
+ }
+
+ private void validateName(Map<String, TensorType> types, String name, String type) {
+ if ( ! types.containsKey(name)) {
+ throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " +
+ "Model contains: " + String.join(",", types.keySet()));
+ }
+ }
+
+ @Override
+ public List<Integer> embed(String text, Context context) {
+ return tokenizer.embed(text, context);
+ }
+
+ @Override
+ public Tensor embed(String text, Context context, TensorType type) {
+ if (type.dimensions().size() != 1) {
+ throw new IllegalArgumentException("Error in embedding to type '" + type + "': should only have one dimension.");
+ }
+ if (!type.dimensions().get(0).isIndexed()) {
+ throw new IllegalArgumentException("Error in embedding to type '" + type + "': dimension should be indexed.");
+ }
+ List<Integer> tokens = embedWithSeperatorTokens(text, context, maxTokens);
+ return embedTokens(tokens, type);
+ }
+
+ Tensor embedTokens(List<Integer> tokens, TensorType type) {
+ Tensor inputSequence = createTensorRepresentation(tokens, "d1");
+ Tensor attentionMask = createAttentionMask(inputSequence);
+ Tensor tokenTypeIds = createTokenTypeIds(inputSequence);
+
+ Map<String, Tensor> inputs = Map.of(inputIdsName, inputSequence.expand("d0"),
+ attentionMaskName, attentionMask.expand("d0"),
+ tokenTypeIdsName, tokenTypeIds.expand("d0"));
+ Map<String, Tensor> outputs = evaluator.evaluate(inputs);
+
+ Tensor tokenEmbeddings = outputs.get(outputName);
+
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ if (poolingStrategy.equals("mean")) { // average over tokens
+ Tensor summedEmbeddings = tokenEmbeddings.sum("d1");
+ Tensor summedAttentionMask = attentionMask.expand("d0").sum("d1");
+ Tensor averaged = summedEmbeddings.join(summedAttentionMask, (x, y) -> x / y);
+ for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
+ builder.cell(averaged.get(TensorAddress.of(0,i)), i);
+ }
+ } else { // CLS - use first token
+ for (int i = 0; i < type.dimensions().get(0).size().get(); i++) {
+ builder.cell(tokenEmbeddings.get(TensorAddress.of(0,0,i)), i);
+ }
+ }
+ return builder.build();
+ }
+
+ private List<Integer> embedWithSeperatorTokens(String text, Context context, int maxLength) {
+ List<Integer> tokens = new ArrayList<>();
+ tokens.add(TOKEN_CLS);
+ tokens.addAll(embed(text, context));
+ tokens.add(TOKEN_SEP);
+ if (tokens.size() > maxLength) {
+ tokens = tokens.subList(0, maxLength-1);
+ tokens.add(TOKEN_SEP);
+ }
+ return tokens;
+ }
+
+ private IndexedTensor createTensorRepresentation(List<Integer> input, String dimension) {
+ int size = input.size();
+ TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
+ for (int i = 0; i < size; ++i) {
+ builder.cell(input.get(i), i);
+ }
+ return builder.build();
+ }
+
+ private static Tensor createAttentionMask(Tensor d) {
+ return d.map((x) -> x > 0 ? 1:0);
+ }
+
+ private static Tensor createTokenTypeIds(Tensor d) {
+ return d.map((x) -> x > 0 ? 0:0);
+ }
+
+ private int modifyThreadCount(int numThreads) {
+ if (numThreads >= 0)
+ return numThreads;
+ return Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / (-1 * numThreads)));
+ }
+
+}
diff --git a/model-integration/src/main/resources/configdefinitions/bert-base-embedder.def b/model-integration/src/main/resources/configdefinitions/bert-base-embedder.def
new file mode 100644
index 00000000000..7e3ff151466
--- /dev/null
+++ b/model-integration/src/main/resources/configdefinitions/bert-base-embedder.def
@@ -0,0 +1,27 @@
+package=ai.vespa.embedding
+
+# Transformer model settings
+transformerModelUrl url
+
+# Max length of token sequence model can handle
+transformerMaxTokens int default=384
+
+# Pooling strategy
+poolingStrategy enum { cls, mean } default=mean
+
+# Input names
+transformerInputIds string default=input_ids
+transformerAttentionMask string default=attention_mask
+transformerTokenTypeIds string default=token_type_ids
+
+# Output name
+transformerOutput string default=output_0
+
+# Settings for ONNX model evaluation
+onnxExecutionMode enum { parallel, sequential } default=sequential
+onnxInterOpThreads int default=1
+onnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
+
+# Settings for wordpiece tokenizer
+tokenizerVocabUrl url
+
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
new file mode 100644
index 00000000000..519f24795ca
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -0,0 +1,29 @@
+package ai.vespa.embedding;
+
+import com.yahoo.config.UrlReference;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+public class BertBaseEmbedderTest {
+
+ @Test
+ public void testEmbedder() {
+ BertBaseEmbedderConfig.Builder builder = new BertBaseEmbedderConfig.Builder();
+ builder.tokenizerVocabUrl(new UrlReference("src/test/models/onnx/transformer/dummy_vocab.txt"));
+ builder.transformerModelUrl(new UrlReference("src/test/models/onnx/transformer/dummy_transformer.onnx"));
+ BertBaseEmbedder embedder = new BertBaseEmbedder(builder.build());
+
+ TensorType destType = TensorType.fromSpec("tensor<float>(x[7])");
+ List<Integer> tokens = List.of(1,2,3,4,5); // use random tokens instead of invoking the tokenizer
+ Tensor embedding = embedder.embedTokens(tokens, destType);
+
+ Tensor expected = Tensor.from("tensor<float>(x[7]):[-0.6178509, -0.8135831, 0.34416935, 0.3912577, -0.13068882, 2.5897025E-4, -0.18638384]");
+ assertEquals(embedding, expected);
+ }
+
+}
diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer.onnx b/model-integration/src/test/models/onnx/transformer/dummy_transformer.onnx
new file mode 100644
index 00000000000..2101beec786
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/dummy_transformer.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/transformer/dummy_transformer.py b/model-integration/src/test/models/onnx/transformer/dummy_transformer.py
new file mode 100644
index 00000000000..1028035d7c0
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/dummy_transformer.py
@@ -0,0 +1,52 @@
+# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import torch
+import torch.onnx
+import torch.nn as nn
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+
+class TransformerModel(nn.Module):
+ def __init__(self, vocab_size, emb_size, num_heads, hidden_dim_size, num_layers, dropout=0.2):
+ super(TransformerModel, self).__init__()
+ self.encoder = nn.Embedding(vocab_size, emb_size)
+ encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout)
+ self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
+
+ def forward(self, tokens, attention_mask, token_type_ids):
+ src = self.encoder((tokens * attention_mask) + token_type_ids)
+ output = self.transformer_encoder(src)
+ return output
+
+
+def main():
+ vocabulary_size = 20
+ embedding_size = 16
+ hidden_dim_size = 32
+ num_layers = 2
+ num_heads = 2
+ model = TransformerModel(vocabulary_size, embedding_size, num_heads, hidden_dim_size, num_layers)
+
+ # Omit training - just export randomly initialized network
+
+ tokens = torch.LongTensor([[1,2,3,4,5]])
+ attention_mask = torch.LongTensor([[1,1,1,1,1]])
+ token_type_ids = torch.LongTensor([[0,0,0,0,0]])
+ torch.onnx.export(model,
+ (tokens, attention_mask, token_type_ids),
+ "dummy_transformer.onnx",
+ input_names = ["input_ids", "attention_mask", "token_type_ids"],
+ output_names = ["output_0"],
+ dynamic_axes = {
+ "input_ids": {0:"batch", 1:"tokens"},
+ "attention_mask": {0:"batch", 1:"tokens"},
+ "token_type_ids": {0:"batch", 1:"tokens"},
+ "output_0": {0:"batch", 1:"tokens"},
+ },
+ opset_version=12)
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/model-integration/src/test/models/onnx/transformer/dummy_vocab.txt b/model-integration/src/test/models/onnx/transformer/dummy_vocab.txt
new file mode 100644
index 00000000000..7dc0c6ecb6e
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/dummy_vocab.txt
@@ -0,0 +1,35 @@
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+x
+y
+z
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9 \ No newline at end of file