diff options
author | Lester Solbakken <lesters@oath.com> | 2023-02-18 11:48:09 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2023-02-18 11:48:09 +0100 |
commit | 3c630c78c43cd3ad91cda1836ee47eff0fb4d0bf (patch) | |
tree | d9aebd5a4b852366e93470b518dd938d1dad184b /model-integration/src/test | |
parent | 654fc45d0ae303bd7d171d5754cc3e813f8a3b34 (diff) |
Add initial text generator component
Diffstat (limited to 'model-integration/src/test')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java | 2 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java | 36 | ||||
-rw-r--r-- | model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model | bin | 0 -> 400869 bytes | |||
-rw-r--r-- | model-integration/src/test/models/onnx/llm/random_decoder.onnx | bin | 0 -> 735210 bytes | |||
-rw-r--r-- | model-integration/src/test/models/onnx/llm/random_encoder.onnx | bin | 0 -> 345336 bytes | |||
-rw-r--r-- | model-integration/src/test/models/onnx/llm/random_llm.py | 82 |
6 files changed, 118 insertions, 2 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java index 73359736536..b06a54d68bb 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java @@ -1,9 +1,7 @@ package ai.vespa.embedding; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; -import com.yahoo.config.FileReference; import com.yahoo.config.ModelReference; -import com.yahoo.config.UrlReference; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; diff --git a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java new file mode 100644 index 00000000000..733430aa10d --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java @@ -0,0 +1,36 @@ +package ai.vespa.llm; + +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import com.yahoo.config.ModelReference; +import com.yahoo.llm.GeneratorConfig; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assume.assumeTrue; + +public class GeneratorTest { + + @Test + public void testGenerator() { + String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model"; + String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx"; + String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx"; + assumeTrue(OnnxEvaluator.isRuntimeAvailable(encoderModelPath)); + + GeneratorConfig.Builder builder = new GeneratorConfig.Builder(); + builder.tokenizerModel(ModelReference.valueOf(vocabPath)); + builder.encoderModel(ModelReference.valueOf(encoderModelPath)); + builder.decoderModel(ModelReference.valueOf(decoderModelPath)); + Generator generator = new Generator(builder.build()); + + GeneratorOptions options = new GeneratorOptions(); + options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY); + options.setMaxLength(10); + + String prompt = "generate some random text"; + String result = generator.generate(prompt, options); + + assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result); + } + +} diff --git a/model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model b/model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model Binary files differnew file mode 100644 index 00000000000..89f93ef3517 --- /dev/null +++ b/model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model diff --git a/model-integration/src/test/models/onnx/llm/random_decoder.onnx b/model-integration/src/test/models/onnx/llm/random_decoder.onnx Binary files differnew file mode 100644 index 00000000000..a8c5f18ddf2 --- /dev/null +++ b/model-integration/src/test/models/onnx/llm/random_decoder.onnx diff --git a/model-integration/src/test/models/onnx/llm/random_encoder.onnx b/model-integration/src/test/models/onnx/llm/random_encoder.onnx Binary files differnew file mode 100644 index 00000000000..a9100fcd6af --- /dev/null +++ b/model-integration/src/test/models/onnx/llm/random_encoder.onnx diff --git a/model-integration/src/test/models/onnx/llm/random_llm.py b/model-integration/src/test/models/onnx/llm/random_llm.py new file mode 100644 index 00000000000..722906fc48b --- /dev/null +++ b/model-integration/src/test/models/onnx/llm/random_llm.py @@ -0,0 +1,82 @@ +import torch +import torch.onnx +import torch.nn as nn +from torch.nn import TransformerEncoderLayer, TransformerEncoder, TransformerDecoder, TransformerDecoderLayer + + +class EncoderModel(nn.Module): + def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True): + super(EncoderModel, self).__init__() + self.embedding = nn.Embedding(vocab_size, emb_size) + encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout, batch_first=batch_first) + self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers) + + def forward(self, tokens, attention_mask): + src = self.embedding(tokens * attention_mask) # N, S, E + output = self.transformer_encoder(src) + return output + + +class DecoderModel(nn.Module): + def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True): + super(DecoderModel, self).__init__() + self.embedding = nn.Embedding(vocab_size, emb_size) + decoder_layers = nn.TransformerDecoderLayer(emb_size, num_heads, hidden_dim_size, batch_first=batch_first) + self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_layers) + self.linear = nn.Linear(emb_size, vocab_size) + + def forward(self, tokens, attention_mask, encoder_hidden_state): + tgt = self.embedding(tokens) # N, T, E + out = self.transformer_decoder(tgt, encoder_hidden_state, memory_mask=attention_mask) + logits = self.linear(out) + return logits + + +def main(): + vocabulary_size = 10000 + embedding_size = 8 + hidden_dim_size = 16 + num_heads = 1 + num_layers = 1 + + encoder = EncoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers) + decoder = DecoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers) + + # Omit training - just export randomly initialized network + + tokens = torch.LongTensor([[1, 2, 3, 4, 5]]) + attention_mask = torch.LongTensor([[1, 1, 1, 1, 1]]) + + torch.onnx.export(encoder, + (tokens, attention_mask), + "random_encoder.onnx", + input_names=["input_ids", "attention_mask"], + output_names=["last_hidden_state"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "tokens"}, + "attention_mask": {0: "batch", 1: "tokens"}, + "last_hidden_state": {0: "batch", 1: "tokens"}, + }, + opset_version=12) + + last_hidden_state = encoder.forward(tokens, attention_mask) + tokens = torch.LongTensor([[0]]) #1, 2]]) + + torch.onnx.export(decoder, + (tokens, attention_mask.float(), last_hidden_state), + "random_decoder.onnx", + input_names=["input_ids", "encoder_attention_mask", "encoder_hidden_states"], + output_names=["logits"], + dynamic_axes={ + "input_ids": {0: "batch", 1: "target_tokens"}, + "encoder_attention_mask": {0: "batch", 1: "source_tokens"}, + "encoder_hidden_states": {0: "batch", 1: "source_tokens"}, + "logits": {0: "batch", 1: "target_tokens"}, + }, + opset_version=12) + + +if __name__ == "__main__": + main() + + |