diff options
author | Lester Solbakken <lesters@oath.com> | 2023-02-18 11:48:09 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2023-02-18 11:48:09 +0100 |
commit | 3c630c78c43cd3ad91cda1836ee47eff0fb4d0bf (patch) | |
tree | d9aebd5a4b852366e93470b518dd938d1dad184b | |
parent | 654fc45d0ae303bd7d171d5754cc3e813f8a3b34 (diff) |
Add initial text generator component
-rw-r--r-- | model-integration/src/main/java/ai/vespa/llm/Generator.java | 226 | ||||
-rw-r--r-- | model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java | 34 | ||||
-rw-r--r-- | model-integration/src/main/resources/configdefinitions/llm.generator.def | 32 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java | 2 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java | 36 | ||||
-rw-r--r-- | model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model | bin | 0 -> 400869 bytes | |||
-rw-r--r-- | model-integration/src/test/models/onnx/llm/random_decoder.onnx | bin | 0 -> 735210 bytes | |||
-rw-r--r-- | model-integration/src/test/models/onnx/llm/random_encoder.onnx | bin | 0 -> 345336 bytes | |||
-rw-r--r-- | model-integration/src/test/models/onnx/llm/random_llm.py | 82 |
9 files changed, 410 insertions, 2 deletions
// model-integration/src/main/java/ai/vespa/llm/Generator.java (new file)
package ai.vespa.llm;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.sentencepiece.SentencePieceEmbedder;
import com.yahoo.llm.GeneratorConfig;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * A text generator based on language models (LLMs). By configuring a
 * sentencepiece tokenizer and models for encoding and decoding, this
 * component generates text based on the given prompt.
 *
 * See llm.generator.def for configurable parameters.
 *
 * @author lesters
 */
public class Generator {

    private final static int TOKEN_EOS = 1; // end of sequence

    // Dimension names used for all tensors passed to and from the ONNX models
    private final static String BATCH_DIMENSION = "d0";
    private final static String SEQUENCE_DIMENSION = "d1";

    private final int tokenizerMaxTokens;
    private final String encoderInputIdsName;
    private final String encoderAttentionMaskName;
    private final String encoderOutputName;
    private final String decoderInputIdsName;
    private final String decoderAttentionMaskName;
    private final String decoderEncoderHiddenStateName;
    private final String decoderOutputName;

    private final SentencePieceEmbedder tokenizer;
    private final OnnxEvaluator encoder;
    private final OnnxEvaluator decoder;

    @Inject
    public Generator(GeneratorConfig config) {
        // Set up tokenizer
        tokenizer = new SentencePieceEmbedder.Builder(config.tokenizerModel().toString()).build();
        tokenizerMaxTokens = config.tokenizerMaxTokens();

        // Set up encoder: input/output tensor names are configurable to match
        // whatever the exported ONNX model declares.
        encoderInputIdsName = config.encoderModelInputIdsName();
        encoderAttentionMaskName = config.encoderModelAttentionMaskName();
        encoderOutputName = config.encoderModelOutputName();

        OnnxEvaluatorOptions encoderOptions = new OnnxEvaluatorOptions();
        encoderOptions.setExecutionMode(config.encoderOnnxExecutionMode().toString());
        encoderOptions.setInterOpThreads(modifyThreadCount(config.encoderOnnxInterOpThreads()));
        encoderOptions.setIntraOpThreads(modifyThreadCount(config.encoderOnnxIntraOpThreads()));

        encoder = new OnnxEvaluator(config.encoderModel().toString(), encoderOptions);

        // Set up decoder
        decoderInputIdsName = config.decoderModelInputIdsName();
        decoderAttentionMaskName = config.decoderModelAttentionMaskName();
        decoderEncoderHiddenStateName = config.decoderModelEncoderHiddenStateName();
        decoderOutputName = config.decoderModelOutputName();

        OnnxEvaluatorOptions decoderOptions = new OnnxEvaluatorOptions();
        decoderOptions.setExecutionMode(config.decoderOnnxExecutionMode().toString());
        decoderOptions.setInterOpThreads(modifyThreadCount(config.decoderOnnxInterOpThreads()));
        decoderOptions.setIntraOpThreads(modifyThreadCount(config.decoderOnnxIntraOpThreads()));

        decoder = new OnnxEvaluator(config.decoderModel().toString(), decoderOptions);

        // Fail fast if the configured tensor names do not exist in the models
        validateModels();
    }

    /**
     * Generates text by evaluating an encoder model to encode the prompt, and
     * repeatedly evaluating a decoding model to generate tokens until some
     * stopping criteria has been met.
     *
     * @param prompt the prompt to generate text from
     * @param options options for text generation
     * @return a text generated from the prompt
     */
    public String generate(String prompt, GeneratorOptions options) {
        return switch (options.getSearchMethod()) {
            case GREEDY -> generateGreedy(prompt, options);
            default -> generateNotImplemented(options);
        };
    }

    /** Generates text from the prompt using default {@link GeneratorOptions}. */
    public String generate(String prompt) {
        return generate(prompt, new GeneratorOptions());
    }

    /** Only GREEDY is implemented; all other search methods end up here. */
    private String generateNotImplemented(GeneratorOptions options) {
        throw new UnsupportedOperationException("Search method '" + options.getSearchMethod() + "' is currently not implemented");
    }

    /**
     * Greedy decoding: the encoder is evaluated once for the prompt, then the
     * decoder is evaluated repeatedly, appending the single most probable next
     * token each iteration until maxLength tokens have been generated.
     */
    private String generateGreedy(String prompt, GeneratorOptions options) {
        var generatedTokens = new ArrayList<Integer>();
        generatedTokens.add(0); // Or target tokens -- token id 0 is the decoder start token

        // Tokenize
        var inputTokens = tokenize(prompt); // Or source tokens

        // Evaluate encoder
        var encoderInput = createTensorRepresentation(inputTokens, SEQUENCE_DIMENSION);
        var encoderMask = createAttentionMask(encoderInput).expand(BATCH_DIMENSION);
        var encoderOutput = evaluateEncoder(encoderInput.expand(BATCH_DIMENSION), encoderMask);

        // Greedy search just grabs the next most probable token
        while (generatedTokens.size() < options.getMaxLength()) { // Todo: add stopping criteria
            var decoderInput = createTensorRepresentation(generatedTokens, SEQUENCE_DIMENSION).expand(BATCH_DIMENSION);
            var logits = evaluateDecoder(decoderInput, encoderMask, encoderOutput);
            var nextToken = findMostProbableToken(logits, generatedTokens.size()-1, BATCH_DIMENSION, SEQUENCE_DIMENSION);
            generatedTokens.add(nextToken);
        }

        return detokenize(generatedTokens);
    }

    /** Evaluates the encoder model, returning the encoder's hidden state tensor. */
    private Tensor evaluateEncoder(Tensor input, Tensor mask) {
        var encoderInputs = Map.of(encoderInputIdsName, input,
                                   encoderAttentionMaskName, mask);
        return encoder.evaluate(encoderInputs, encoderOutputName);
    }

    /**
     * Evaluates the decoder model for the tokens generated so far, returning
     * the logits over the vocabulary for every sequence position.
     */
    private IndexedTensor evaluateDecoder(Tensor input, Tensor encoderMask, Tensor encoderOutput) {
        var inputs = Map.of(decoderInputIdsName, input,
                            decoderAttentionMaskName, encoderMask,  // yes, encoder's attention mask
                            decoderEncoderHiddenStateName, encoderOutput);
        var output = decoder.evaluate(inputs, decoderOutputName);
        if ( ! (output instanceof IndexedTensor indexedTensor)) {
            throw new IllegalArgumentException("Output of decoder model is not an 'IndexedTensor'");
        }
        return indexedTensor;
    }

    /**
     * Given a tensor 'logits' with 3 dimensions: batch, sequence, and vocabulary
     * find the value in the vocabulary dimension with highest score for the given
     * token in the sequence.
     *
     * NOTE(review): the 'i != TOKEN_EOS' condition means the EOS token can never
     * be selected, so generation only stops at maxLength -- consistent with the
     * "Todo: add stopping criteria" above.
     */
    private static int findMostProbableToken(IndexedTensor logits, int seqIndex, String batchDim, String seqDim) {
        if (logits.type().rank() != 3) {
            throw new IllegalArgumentException("Expected a tensor with rank 3: batch, sequence, and vocabulary size. " +
                    "Got: " + logits.type());
        }
        // Iterate over the vocabulary dimension for batch 0 at the given sequence position
        var iterator = logits.cellIterator(new PartialAddress.Builder(2).
                add(batchDim, 0).
                add(seqDim, seqIndex).build(),
                DimensionSizes.of(logits.type()));
        var maxVal = iterator.next().getValue();
        int maxIndex = 0;
        for (int i = 1; iterator.hasNext(); ++i) {
            var val = iterator.next().getValue();
            if (val >= maxVal && i != TOKEN_EOS) {
                maxVal = val;
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Tokenizes the text with the sentencepiece tokenizer, truncating to one
     * less than the configured maximum so there is room for the appended EOS.
     */
    private List<Integer> tokenize(String text) {
        var tokens = tokenizer.embed(text, new Embedder.Context("tokenizer"));
        tokens = tokens.size() >= tokenizerMaxTokens ? tokens.subList(0,tokenizerMaxTokens-1): tokens;
        tokens.add(TOKEN_EOS);
        return tokens;
    }

    /** Converts token ids back to a string (third argument strips special tokens). */
    private String detokenize(List<Integer> tokens) {
        return tokenizer.decode(tokens, new Embedder.Context("tokenizer"), true);
    }

    /** Builds a 1-d float tensor over the given dimension holding the token ids. */
    private static Tensor createTensorRepresentation(List<Integer> tokens, String dimension) {
        var size = tokens.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; ++i) {
            builder.cell(tokens.get(i), i);
        }
        return builder.build();
    }

    /**
     * Builds a 0/1 attention mask from a token tensor.
     * NOTE(review): token id 0 is masked out here; if the vocabulary assigns a
     * real token (e.g. unk) to id 0 in encoder input, it will be masked -- verify.
     */
    private static Tensor createAttentionMask(Tensor d)  {
        return d.map((x) -> x > 0 ? 1:0);
    }

    /** Checks that all configured input/output tensor names exist in the loaded models. */
    private void validateModels() {
        Map<String, TensorType> inputs = encoder.getInputInfo();
        validateName(inputs, encoderInputIdsName, "input");
        validateName(inputs, encoderAttentionMaskName, "input");

        Map<String, TensorType> outputs = encoder.getOutputInfo();
        validateName(outputs, encoderOutputName, "output");

        inputs = decoder.getInputInfo();
        validateName(inputs, decoderInputIdsName, "input");
        validateName(inputs, decoderAttentionMaskName, "input");
        validateName(inputs, decoderEncoderHiddenStateName, "input");

        outputs = decoder.getOutputInfo();
        validateName(outputs, decoderOutputName, "output");
    }

    /** Throws IllegalArgumentException if 'name' is not a key in 'types'. */
    private void validateName(Map<String, TensorType> types, String name, String type) {
        if ( ! types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " +
                    "Model contains: " + String.join(",", types.keySet()));
        }
    }

    /**
     * Maps the configured thread count to an actual count: non-negative values
     * are used as-is; a negative value -n means "available processors / n"
     * (see llm.generator.def), with a minimum of 1.
     */
    private int modifyThreadCount(int numThreads) {
        if (numThreads >= 0)
            return numThreads;
        return Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / (-1 * numThreads)));
    }

}

// model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java (new file)
package ai.vespa.llm;

/**
 * Options controlling text generation in {@link Generator}: the search method
 * used for decoding and the maximum number of tokens to generate.
 * Setters return this for chaining.
 */
public class GeneratorOptions {

    /** Decoding strategies. Only GREEDY is currently implemented in Generator. */
    public enum SearchMethod {
        GREEDY,
        CONTRASTIVE,
        BEAM,
        SAMPLE,
    }

    private SearchMethod searchMethod = SearchMethod.GREEDY;
    private int maxLength = 20;  // maximum number of generated tokens, including the start token

    public SearchMethod getSearchMethod() {
        return searchMethod;
    }

    public GeneratorOptions setSearchMethod(SearchMethod searchMethod) {
        this.searchMethod = searchMethod;
        return this;
    }

    public int getMaxLength() {
        return maxLength;
    }

    public GeneratorOptions setMaxLength(int maxLength) {
        this.maxLength = maxLength;
        return this;
    }

}

# model-integration/src/main/resources/configdefinitions/llm.generator.def (new file)
namespace=llm

# SentencePiece tokenizer
tokenizerModel model
tokenizerMaxTokens int default=1000

#
# The encoder model
#
encoderModel model
encoderModelInputIdsName string default=input_ids
encoderModelAttentionMaskName string default=attention_mask
encoderModelOutputName string default=last_hidden_state

encoderOnnxExecutionMode enum {
parallel, sequential } default=sequential
encoderOnnxInterOpThreads int default=1
encoderOnnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
# enable GPU?

#
# The decoder model
#
decoderModel model
decoderModelInputIdsName string default=input_ids
decoderModelAttentionMaskName string default=encoder_attention_mask
decoderModelEncoderHiddenStateName string default=encoder_hidden_states
decoderModelOutputName string default=logits

decoderOnnxExecutionMode enum { parallel, sequential } default=sequential
decoderOnnxInterOpThreads int default=1
decoderOnnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
# enable GPU?

// model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
// (change in this commit: removed the unused imports com.yahoo.config.FileReference
//  and com.yahoo.config.UrlReference; no behavior change)

// model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java (new file)
package ai.vespa.llm;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import com.yahoo.config.ModelReference;
import com.yahoo.llm.GeneratorConfig;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeTrue;

/**
 * Tests greedy generation end-to-end against small, randomly initialized
 * encoder/decoder ONNX models (see random_llm.py). The expected string is
 * whatever those fixed random weights deterministically produce; it is not
 * meaningful text.
 */
public class GeneratorTest {

    @Test
    public void testGenerator() {
        String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model";
        String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx";
        String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx";
        // Skip (not fail) the test when the ONNX runtime is unavailable on this platform
        assumeTrue(OnnxEvaluator.isRuntimeAvailable(encoderModelPath));

        GeneratorConfig.Builder builder = new GeneratorConfig.Builder();
        builder.tokenizerModel(ModelReference.valueOf(vocabPath));
        builder.encoderModel(ModelReference.valueOf(encoderModelPath));
        builder.decoderModel(ModelReference.valueOf(decoderModelPath));
        Generator generator = new Generator(builder.build());

        GeneratorOptions options = new GeneratorOptions();
        options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY);
        options.setMaxLength(10);

        String prompt = "generate some random text";
        String result = generator.generate(prompt, options);

        // 10 tokens generated from the random model; deterministic given the checked-in weights
        assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result);
    }

}

// Binary test fixtures added by this commit (not representable as text):
//   model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model (sentencepiece vocab)
//   model-integration/src/test/models/onnx/llm/random_decoder.onnx
//   model-integration/src/test/models/onnx/llm/random_encoder.onnx
# model-integration/src/test/models/onnx/llm/random_llm.py (new file)
#
# Generates the randomly initialized encoder/decoder ONNX models used by
# GeneratorTest. Training is deliberately omitted: the test only verifies
# that generation is wired up correctly, not that the output is sensible.

import torch
import torch.onnx
import torch.nn as nn
from torch.nn import TransformerEncoderLayer, TransformerEncoder, TransformerDecoder, TransformerDecoderLayer


class EncoderModel(nn.Module):
    """Token embedding followed by a stack of transformer encoder layers."""

    def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True):
        super(EncoderModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout, batch_first=batch_first)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)

    def forward(self, tokens, attention_mask):
        # Zero out masked token ids before embedding (mask is 0/1 per token)
        src = self.embedding(tokens * attention_mask)  # N, S, E
        output = self.transformer_encoder(src)
        return output


class DecoderModel(nn.Module):
    """Transformer decoder attending to the encoder hidden state, followed by
    a linear projection to vocabulary logits."""

    def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True):
        super(DecoderModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        # Fix: the original omitted the dropout argument here, so the layer
        # silently used nn.TransformerDecoderLayer's default (0.1) instead of
        # the value passed to this constructor. Irrelevant for ONNX export
        # (export runs the model in eval mode), but the parameter should be
        # honored. Also use the already-imported TransformerDecoderLayer /
        # TransformerDecoder names for consistency with EncoderModel.
        decoder_layers = TransformerDecoderLayer(emb_size, num_heads, hidden_dim_size, dropout, batch_first=batch_first)
        self.transformer_decoder = TransformerDecoder(decoder_layers, num_layers)
        self.linear = nn.Linear(emb_size, vocab_size)

    def forward(self, tokens, attention_mask, encoder_hidden_state):
        tgt = self.embedding(tokens)  # N, T, E
        # NOTE(review): memory_mask is expected to be (T, S); this works here
        # because the exported example input has T=1, so the (1, S) attention
        # mask happens to match. Confirm before reusing with longer targets.
        out = self.transformer_decoder(tgt, encoder_hidden_state, memory_mask=attention_mask)
        logits = self.linear(out)
        return logits


def main():
    """Exports random_encoder.onnx and random_decoder.onnx with fixed tiny sizes."""
    vocabulary_size = 10000
    embedding_size = 8
    hidden_dim_size = 16
    num_heads = 1
    num_layers = 1

    encoder = EncoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers)
    decoder = DecoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers)

    # Omit training - just export randomly initialized networks

    tokens = torch.LongTensor([[1, 2, 3, 4, 5]])
    attention_mask = torch.LongTensor([[1, 1, 1, 1, 1]])

    torch.onnx.export(encoder,
                      (tokens, attention_mask),
                      "random_encoder.onnx",
                      input_names=["input_ids", "attention_mask"],
                      output_names=["last_hidden_state"],
                      dynamic_axes={
                          "input_ids": {0: "batch", 1: "tokens"},
                          "attention_mask": {0: "batch", 1: "tokens"},
                          "last_hidden_state": {0: "batch", 1: "tokens"},
                      },
                      opset_version=12)

    # The decoder is exported with a single start token (id 0) as example input
    last_hidden_state = encoder.forward(tokens, attention_mask)
    tokens = torch.LongTensor([[0]])

    torch.onnx.export(decoder,
                      (tokens, attention_mask.float(), last_hidden_state),
                      "random_decoder.onnx",
                      input_names=["input_ids", "encoder_attention_mask", "encoder_hidden_states"],
                      output_names=["logits"],
                      dynamic_axes={
                          "input_ids": {0: "batch", 1: "target_tokens"},
                          "encoder_attention_mask": {0: "batch", 1: "source_tokens"},
                          "encoder_hidden_states": {0: "batch", 1: "source_tokens"},
                          "logits": {0: "batch", 1: "target_tokens"},
                      },
                      opset_version=12)


if __name__ == "__main__":
    main()