summary | refs | log | tree | commit | diff | stats
path: root/model-integration/src/test
diff options
context:
space:
mode:
author    Lester Solbakken <lesters@oath.com>    2023-02-18 11:48:09 +0100
committer Lester Solbakken <lesters@oath.com>    2023-02-18 11:48:09 +0100
commit  3c630c78c43cd3ad91cda1836ee47eff0fb4d0bf (patch)
tree    d9aebd5a4b852366e93470b518dd938d1dad184b /model-integration/src/test
parent  654fc45d0ae303bd7d171d5754cc3e813f8a3b34 (diff)
Add initial text generator component
Diffstat (limited to 'model-integration/src/test')
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java    2
-rw-r--r--  model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java                36
-rw-r--r--  model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model           bin 0 -> 400869 bytes
-rw-r--r--  model-integration/src/test/models/onnx/llm/random_decoder.onnx                 bin 0 -> 735210 bytes
-rw-r--r--  model-integration/src/test/models/onnx/llm/random_encoder.onnx                 bin 0 -> 345336 bytes
-rw-r--r--  model-integration/src/test/models/onnx/llm/random_llm.py                       82
6 files changed, 118 insertions, 2 deletions
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index 73359736536..b06a54d68bb 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -1,9 +1,7 @@
package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
-import com.yahoo.config.FileReference;
import com.yahoo.config.ModelReference;
-import com.yahoo.config.UrlReference;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
new file mode 100644
index 00000000000..733430aa10d
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
@@ -0,0 +1,36 @@
+package ai.vespa.llm;
+
+import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import com.yahoo.config.ModelReference;
+import com.yahoo.llm.GeneratorConfig;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assume.assumeTrue;
+
+public class GeneratorTest {
+
+ @Test
+ public void testGenerator() {
+ String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model";
+ String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx";
+ String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx";
+ assumeTrue(OnnxEvaluator.isRuntimeAvailable(encoderModelPath));
+
+ GeneratorConfig.Builder builder = new GeneratorConfig.Builder();
+ builder.tokenizerModel(ModelReference.valueOf(vocabPath));
+ builder.encoderModel(ModelReference.valueOf(encoderModelPath));
+ builder.decoderModel(ModelReference.valueOf(decoderModelPath));
+ Generator generator = new Generator(builder.build());
+
+ GeneratorOptions options = new GeneratorOptions();
+ options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY);
+ options.setMaxLength(10);
+
+ String prompt = "generate some random text";
+ String result = generator.generate(prompt, options);
+
+ assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result);
+ }
+
+}
diff --git a/model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model b/model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model
new file mode 100644
index 00000000000..89f93ef3517
--- /dev/null
+++ b/model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model
Binary files differ
diff --git a/model-integration/src/test/models/onnx/llm/random_decoder.onnx b/model-integration/src/test/models/onnx/llm/random_decoder.onnx
new file mode 100644
index 00000000000..a8c5f18ddf2
--- /dev/null
+++ b/model-integration/src/test/models/onnx/llm/random_decoder.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/llm/random_encoder.onnx b/model-integration/src/test/models/onnx/llm/random_encoder.onnx
new file mode 100644
index 00000000000..a9100fcd6af
--- /dev/null
+++ b/model-integration/src/test/models/onnx/llm/random_encoder.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/llm/random_llm.py b/model-integration/src/test/models/onnx/llm/random_llm.py
new file mode 100644
index 00000000000..722906fc48b
--- /dev/null
+++ b/model-integration/src/test/models/onnx/llm/random_llm.py
@@ -0,0 +1,82 @@
+import torch
+import torch.onnx
+import torch.nn as nn
+from torch.nn import TransformerEncoderLayer, TransformerEncoder, TransformerDecoder, TransformerDecoderLayer
+
+
+class EncoderModel(nn.Module):
+ def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True):
+ super(EncoderModel, self).__init__()
+ self.embedding = nn.Embedding(vocab_size, emb_size)
+ encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout, batch_first=batch_first)
+ self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
+
+ def forward(self, tokens, attention_mask):
+ src = self.embedding(tokens * attention_mask) # N, S, E
+ output = self.transformer_encoder(src)
+ return output
+
+
+class DecoderModel(nn.Module):
+ def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True):
+ super(DecoderModel, self).__init__()
+ self.embedding = nn.Embedding(vocab_size, emb_size)
+ decoder_layers = nn.TransformerDecoderLayer(emb_size, num_heads, hidden_dim_size, batch_first=batch_first)
+ self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_layers)
+ self.linear = nn.Linear(emb_size, vocab_size)
+
+ def forward(self, tokens, attention_mask, encoder_hidden_state):
+ tgt = self.embedding(tokens) # N, T, E
+ out = self.transformer_decoder(tgt, encoder_hidden_state, memory_mask=attention_mask)
+ logits = self.linear(out)
+ return logits
+
+
+def main():
+ vocabulary_size = 10000
+ embedding_size = 8
+ hidden_dim_size = 16
+ num_heads = 1
+ num_layers = 1
+
+ encoder = EncoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers)
+ decoder = DecoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers)
+
+ # Omit training - just export randomly initialized network
+
+ tokens = torch.LongTensor([[1, 2, 3, 4, 5]])
+ attention_mask = torch.LongTensor([[1, 1, 1, 1, 1]])
+
+ torch.onnx.export(encoder,
+ (tokens, attention_mask),
+ "random_encoder.onnx",
+ input_names=["input_ids", "attention_mask"],
+ output_names=["last_hidden_state"],
+ dynamic_axes={
+ "input_ids": {0: "batch", 1: "tokens"},
+ "attention_mask": {0: "batch", 1: "tokens"},
+ "last_hidden_state": {0: "batch", 1: "tokens"},
+ },
+ opset_version=12)
+
+ last_hidden_state = encoder.forward(tokens, attention_mask)
+ tokens = torch.LongTensor([[0]]) #1, 2]])
+
+ torch.onnx.export(decoder,
+ (tokens, attention_mask.float(), last_hidden_state),
+ "random_decoder.onnx",
+ input_names=["input_ids", "encoder_attention_mask", "encoder_hidden_states"],
+ output_names=["logits"],
+ dynamic_axes={
+ "input_ids": {0: "batch", 1: "target_tokens"},
+ "encoder_attention_mask": {0: "batch", 1: "source_tokens"},
+ "encoder_hidden_states": {0: "batch", 1: "source_tokens"},
+ "logits": {0: "batch", 1: "target_tokens"},
+ },
+ opset_version=12)
+
+
+if __name__ == "__main__":
+ main()
+
+