author     Lester Solbakken <lesters@oath.com>  2023-02-18 11:48:09 +0100
committer  Lester Solbakken <lesters@oath.com>  2023-02-18 11:48:09 +0100
commit     3c630c78c43cd3ad91cda1836ee47eff0fb4d0bf (patch)
tree       d9aebd5a4b852366e93470b518dd938d1dad184b
parent     654fc45d0ae303bd7d171d5754cc3e813f8a3b34 (diff)
Add initial text generator component
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/Generator.java  226
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java  34
-rw-r--r--  model-integration/src/main/resources/configdefinitions/llm.generator.def  32
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java  2
-rw-r--r--  model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java  36
-rw-r--r--  model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model  bin 0 -> 400869 bytes
-rw-r--r--  model-integration/src/test/models/onnx/llm/random_decoder.onnx  bin 0 -> 735210 bytes
-rw-r--r--  model-integration/src/test/models/onnx/llm/random_encoder.onnx  bin 0 -> 345336 bytes
-rw-r--r--  model-integration/src/test/models/onnx/llm/random_llm.py  82
9 files changed, 410 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/llm/Generator.java b/model-integration/src/main/java/ai/vespa/llm/Generator.java
new file mode 100644
index 00000000000..ed231a5e94c
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/Generator.java
@@ -0,0 +1,226 @@
+package ai.vespa.llm;
+
+import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.language.sentencepiece.SentencePieceEmbedder;
+import com.yahoo.llm.GeneratorConfig;
+import com.yahoo.tensor.DimensionSizes;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.PartialAddress;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * A text generator based on large language models (LLMs). By configuring a
+ * SentencePiece tokenizer and models for encoding and decoding, this
+ * component generates text based on the given prompt.
+ *
+ * See llm.generator.def for configurable parameters.
+ *
+ * @author lesters
+ */
+public class Generator {
+
+ private final static int TOKEN_EOS = 1; // end of sequence
+
+ private final static String BATCH_DIMENSION = "d0";
+ private final static String SEQUENCE_DIMENSION = "d1";
+
+ private final int tokenizerMaxTokens;
+ private final String encoderInputIdsName;
+ private final String encoderAttentionMaskName;
+ private final String encoderOutputName;
+ private final String decoderInputIdsName;
+ private final String decoderAttentionMaskName;
+ private final String decoderEncoderHiddenStateName;
+ private final String decoderOutputName;
+
+ private final SentencePieceEmbedder tokenizer;
+ private final OnnxEvaluator encoder;
+ private final OnnxEvaluator decoder;
+
+ @Inject
+ public Generator(GeneratorConfig config) {
+ // Set up tokenizer
+ tokenizer = new SentencePieceEmbedder.Builder(config.tokenizerModel().toString()).build();
+ tokenizerMaxTokens = config.tokenizerMaxTokens();
+
+ // Set up encoder
+ encoderInputIdsName = config.encoderModelInputIdsName();
+ encoderAttentionMaskName = config.encoderModelAttentionMaskName();
+ encoderOutputName = config.encoderModelOutputName();
+
+ OnnxEvaluatorOptions encoderOptions = new OnnxEvaluatorOptions();
+ encoderOptions.setExecutionMode(config.encoderOnnxExecutionMode().toString());
+ encoderOptions.setInterOpThreads(modifyThreadCount(config.encoderOnnxInterOpThreads()));
+ encoderOptions.setIntraOpThreads(modifyThreadCount(config.encoderOnnxIntraOpThreads()));
+
+ encoder = new OnnxEvaluator(config.encoderModel().toString(), encoderOptions);
+
+ // Set up decoder
+ decoderInputIdsName = config.decoderModelInputIdsName();
+ decoderAttentionMaskName = config.decoderModelAttentionMaskName();
+ decoderEncoderHiddenStateName = config.decoderModelEncoderHiddenStateName();
+ decoderOutputName = config.decoderModelOutputName();
+
+ OnnxEvaluatorOptions decoderOptions = new OnnxEvaluatorOptions();
+ decoderOptions.setExecutionMode(config.decoderOnnxExecutionMode().toString());
+ decoderOptions.setInterOpThreads(modifyThreadCount(config.decoderOnnxInterOpThreads()));
+ decoderOptions.setIntraOpThreads(modifyThreadCount(config.decoderOnnxIntraOpThreads()));
+
+ decoder = new OnnxEvaluator(config.decoderModel().toString(), decoderOptions);
+
+ validateModels();
+ }
+
+ /**
+ * Generates text by evaluating an encoder model to encode the prompt, and
+ * repeatedly evaluating a decoder model to generate tokens, until a
+ * stopping criterion has been met.
+ *
+ * @param prompt the prompt to generate text from
+ * @param options options for text generation
+ * @return a text generated from the prompt
+ */
+ public String generate(String prompt, GeneratorOptions options) {
+ return switch (options.getSearchMethod()) {
+ case GREEDY -> generateGreedy(prompt, options);
+ default -> generateNotImplemented(options);
+ };
+ }
+
+ public String generate(String prompt) {
+ return generate(prompt, new GeneratorOptions());
+ }
+
+ private String generateNotImplemented(GeneratorOptions options) {
+ throw new UnsupportedOperationException("Search method '" + options.getSearchMethod() + "' is currently not implemented");
+ }
+
+ private String generateGreedy(String prompt, GeneratorOptions options) {
+ var generatedTokens = new ArrayList<Integer>();
+ generatedTokens.add(0); // Or target tokens
+
+ // Tokenize
+ var inputTokens = tokenize(prompt); // Or source tokens
+
+ // Evaluate encoder
+ var encoderInput = createTensorRepresentation(inputTokens, SEQUENCE_DIMENSION);
+ var encoderMask = createAttentionMask(encoderInput).expand(BATCH_DIMENSION);
+ var encoderOutput = evaluateEncoder(encoderInput.expand(BATCH_DIMENSION), encoderMask);
+
+ // Greedy search just grabs the next most probable token
+ while (generatedTokens.size() < options.getMaxLength()) { // Todo: add stopping criteria
+ var decoderInput = createTensorRepresentation(generatedTokens, SEQUENCE_DIMENSION).expand(BATCH_DIMENSION);
+ var logits = evaluateDecoder(decoderInput, encoderMask, encoderOutput);
+ var nextToken = findMostProbableToken(logits, generatedTokens.size()-1, BATCH_DIMENSION, SEQUENCE_DIMENSION);
+ generatedTokens.add(nextToken);
+ }
+
+ return detokenize(generatedTokens);
+ }
+
+ private Tensor evaluateEncoder(Tensor input, Tensor mask) {
+ var encoderInputs = Map.of(encoderInputIdsName, input,
+ encoderAttentionMaskName, mask);
+ return encoder.evaluate(encoderInputs, encoderOutputName);
+ }
+
+ private IndexedTensor evaluateDecoder(Tensor input, Tensor encoderMask, Tensor encoderOutput) {
+ var inputs = Map.of(decoderInputIdsName, input,
+ decoderAttentionMaskName, encoderMask, // yes, encoder's attention mask
+ decoderEncoderHiddenStateName, encoderOutput);
+ var output = decoder.evaluate(inputs, decoderOutputName);
+ if ( ! (output instanceof IndexedTensor indexedTensor)) {
+ throw new IllegalArgumentException("Output of decoder model is not an 'IndexedTensor'");
+ }
+ return indexedTensor;
+ }
+
+ /**
+ * Given a tensor 'logits' with 3 dimensions: batch, sequence, and vocabulary
+ * find the value in the vocabulary dimension with highest score for the given
+ * token in the sequence
+ */
+ private static int findMostProbableToken(IndexedTensor logits, int seqIndex, String batchDim, String seqDim) {
+ if (logits.type().rank() != 3) {
+ throw new IllegalArgumentException("Expected a tensor with rank 3: batch, sequence, and vocabulary size. " +
+ "Got: " + logits.type());
+ }
+ var iterator = logits.cellIterator(new PartialAddress.Builder(2).
+ add(batchDim, 0).
+ add(seqDim, seqIndex).build(),
+ DimensionSizes.of(logits.type()));
+ var maxVal = iterator.next().getValue();
+ int maxIndex = 0;
+ for (int i = 1; iterator.hasNext(); ++i) {
+ var val = iterator.next().getValue();
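+ // never select the EOS token; stopping criteria are not yet implemented, so generation runs to max length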
+ if (val >= maxVal && i != TOKEN_EOS) {
+ maxVal = val;
+ maxIndex = i;
+ }
+ }
+ return maxIndex;
+ }
+
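+ // Truncates to tokenizerMaxTokens - 1 so the appended EOS token keeps the sequence within the limit.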
+ private List<Integer> tokenize(String text) {
+ var tokens = tokenizer.embed(text, new Embedder.Context("tokenizer"));
+ tokens = tokens.size() >= tokenizerMaxTokens ? tokens.subList(0, tokenizerMaxTokens - 1) : tokens;
+ tokens.add(TOKEN_EOS);
+ return tokens;
+ }
+
+ private String detokenize(List<Integer> tokens) {
+ return tokenizer.decode(tokens, new Embedder.Context("tokenizer"), true);
+ }
+
+ private static Tensor createTensorRepresentation(List<Integer> tokens, String dimension) {
+ var size = tokens.size();
+ TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
+ for (int i = 0; i < size; ++i) {
+ builder.cell(tokens.get(i), i);
+ }
+ return builder.build();
+ }
+
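+ // Builds an attention mask from a token tensor: 1 where the token id is positive, 0 elsewhere.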
+ private static Tensor createAttentionMask(Tensor d) {
+ return d.map((x) -> x > 0 ? 1 : 0);
+ }
+
+ private void validateModels() {
+ Map<String, TensorType> inputs = encoder.getInputInfo();
+ validateName(inputs, encoderInputIdsName, "input");
+ validateName(inputs, encoderAttentionMaskName, "input");
+
+ Map<String, TensorType> outputs = encoder.getOutputInfo();
+ validateName(outputs, encoderOutputName, "output");
+
+ inputs = decoder.getInputInfo();
+ validateName(inputs, decoderInputIdsName, "input");
+ validateName(inputs, decoderAttentionMaskName, "input");
+ validateName(inputs, decoderEncoderHiddenStateName, "input");
+
+ outputs = decoder.getOutputInfo();
+ validateName(outputs, decoderOutputName, "output");
+ }
+
+ private void validateName(Map<String, TensorType> types, String name, String type) {
+ if ( ! types.containsKey(name)) {
+ throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " +
+ "Model contains: " + String.join(",", types.keySet()));
+ }
+ }
+
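+ // Negative values request a fraction of the available CPUs: n < 0 gives max(1, ceil(CPUs / -n)); non-negative values are used as-is.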
+ private int modifyThreadCount(int numThreads) {
+ if (numThreads >= 0)
+ return numThreads;
+ return Math.max(1, (int) Math.ceil(((double) Runtime.getRuntime().availableProcessors()) / (-1 * numThreads)));
+ }
+}
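
For reference, a minimal sketch of using the component directly, mirroring GeneratorTest further down in this commit (in a container deployment the GeneratorConfig is injected instead; the model paths here are the test models added below):

    GeneratorConfig.Builder builder = new GeneratorConfig.Builder();
    builder.tokenizerModel(ModelReference.valueOf("src/test/models/onnx/llm/en.wiki.bpe.vs10000.model"));
    builder.encoderModel(ModelReference.valueOf("src/test/models/onnx/llm/random_encoder.onnx"));
    builder.decoderModel(ModelReference.valueOf("src/test/models/onnx/llm/random_decoder.onnx"));
    Generator generator = new Generator(builder.build());

    GeneratorOptions options = new GeneratorOptions()
            .setSearchMethod(GeneratorOptions.SearchMethod.GREEDY)
            .setMaxLength(10);

    String text = generator.generate("generate some random text", options);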
diff --git a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java b/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
new file mode 100644
index 00000000000..743bb7c2f27
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
@@ -0,0 +1,34 @@
+package ai.vespa.llm;
+
+public class GeneratorOptions {
+
+ public enum SearchMethod {
+ GREEDY,
+ CONTRASTIVE,
+ BEAM,
+ SAMPLE,
+ }
+
+ private SearchMethod searchMethod = SearchMethod.GREEDY;
+ private int maxLength = 20;
+
+ public SearchMethod getSearchMethod() {
+ return searchMethod;
+ }
+
+ public GeneratorOptions setSearchMethod(SearchMethod searchMethod) {
+ this.searchMethod = searchMethod;
+ return this;
+ }
+
+ public int getMaxLength() {
+ return maxLength;
+ }
+
+ public GeneratorOptions setMaxLength(int maxLength) {
+ this.maxLength = maxLength;
+ return this;
+ }
+
+}
diff --git a/model-integration/src/main/resources/configdefinitions/llm.generator.def b/model-integration/src/main/resources/configdefinitions/llm.generator.def
new file mode 100644
index 00000000000..478daad6ede
--- /dev/null
+++ b/model-integration/src/main/resources/configdefinitions/llm.generator.def
@@ -0,0 +1,32 @@
+namespace=llm
+
+# SentencePiece tokenizer
+tokenizerModel model
+tokenizerMaxTokens int default=1000
+
+#
+# The encoder model
+#
+encoderModel model
+encoderModelInputIdsName string default=input_ids
+encoderModelAttentionMaskName string default=attention_mask
+encoderModelOutputName string default=last_hidden_state
+
+encoderOnnxExecutionMode enum { parallel, sequential } default=sequential
+encoderOnnxInterOpThreads int default=1
+encoderOnnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
+# enable GPU?
+
+#
+# The decoder model
+#
+decoderModel model
+decoderModelInputIdsName string default=input_ids
+decoderModelAttentionMaskName string default=encoder_attention_mask
+decoderModelEncoderHiddenStateName string default=encoder_hidden_states
+decoderModelOutputName string default=logits
+
+decoderOnnxExecutionMode enum { parallel, sequential } default=sequential
+decoderOnnxInterOpThreads int default=1
+decoderOnnxIntraOpThreads int default=-4 # n=number of threads -> n<0: CPUs/(-n), n==0: CPUs, n>0: n
+# enable GPU?
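
The negative intra-op thread defaults above are resolved by modifyThreadCount in Generator.java: a value n < 0 becomes max(1, ceil(CPUs / -n)), so the default of -4 on, say, a 16-CPU host gives ceil(16 / 4) = 4 intra-op threads, while non-negative values are used as-is.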
diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
index 73359736536..b06a54d68bb 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java
@@ -1,9 +1,7 @@
package ai.vespa.embedding;
import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
-import com.yahoo.config.FileReference;
import com.yahoo.config.ModelReference;
-import com.yahoo.config.UrlReference;
import com.yahoo.embedding.BertBaseEmbedderConfig;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
new file mode 100644
index 00000000000..733430aa10d
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java
@@ -0,0 +1,36 @@
+package ai.vespa.llm;
+
+import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
+import com.yahoo.config.ModelReference;
+import com.yahoo.llm.GeneratorConfig;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assume.assumeTrue;
+
+public class GeneratorTest {
+
+ @Test
+ public void testGenerator() {
+ String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model";
+ String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx";
+ String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx";
+ assumeTrue(OnnxEvaluator.isRuntimeAvailable(encoderModelPath));
+
+ GeneratorConfig.Builder builder = new GeneratorConfig.Builder();
+ builder.tokenizerModel(ModelReference.valueOf(vocabPath));
+ builder.encoderModel(ModelReference.valueOf(encoderModelPath));
+ builder.decoderModel(ModelReference.valueOf(decoderModelPath));
+ Generator generator = new Generator(builder.build());
+
+ GeneratorOptions options = new GeneratorOptions();
+ options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY);
+ options.setMaxLength(10);
+
+ String prompt = "generate some random text";
+ String result = generator.generate(prompt, options);
+
+ assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result);
+ }
+
+}
diff --git a/model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model b/model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model
new file mode 100644
index 00000000000..89f93ef3517
--- /dev/null
+++ b/model-integration/src/test/models/onnx/llm/en.wiki.bpe.vs10000.model
Binary files differ
diff --git a/model-integration/src/test/models/onnx/llm/random_decoder.onnx b/model-integration/src/test/models/onnx/llm/random_decoder.onnx
new file mode 100644
index 00000000000..a8c5f18ddf2
--- /dev/null
+++ b/model-integration/src/test/models/onnx/llm/random_decoder.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/llm/random_encoder.onnx b/model-integration/src/test/models/onnx/llm/random_encoder.onnx
new file mode 100644
index 00000000000..a9100fcd6af
--- /dev/null
+++ b/model-integration/src/test/models/onnx/llm/random_encoder.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/llm/random_llm.py b/model-integration/src/test/models/onnx/llm/random_llm.py
new file mode 100644
index 00000000000..722906fc48b
--- /dev/null
+++ b/model-integration/src/test/models/onnx/llm/random_llm.py
@@ -0,0 +1,82 @@
+import torch
+import torch.onnx
+import torch.nn as nn
+from torch.nn import TransformerEncoderLayer, TransformerEncoder, TransformerDecoder, TransformerDecoderLayer
+
+
+class EncoderModel(nn.Module):
+ def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True):
+ super(EncoderModel, self).__init__()
+ self.embedding = nn.Embedding(vocab_size, emb_size)
+ encoder_layers = TransformerEncoderLayer(emb_size, num_heads, hidden_dim_size, dropout, batch_first=batch_first)
+ self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
+
+ def forward(self, tokens, attention_mask):
+ src = self.embedding(tokens * attention_mask) # N, S, E
+ output = self.transformer_encoder(src)
+ return output
+
+
+class DecoderModel(nn.Module):
+ def __init__(self, vocab_size, emb_size, hidden_dim_size, num_heads, num_layers, dropout=0.2, batch_first=True):
+ super(DecoderModel, self).__init__()
+ self.embedding = nn.Embedding(vocab_size, emb_size)
+ decoder_layers = nn.TransformerDecoderLayer(emb_size, num_heads, hidden_dim_size, batch_first=batch_first)
+ self.transformer_decoder = nn.TransformerDecoder(decoder_layers, num_layers)
+ self.linear = nn.Linear(emb_size, vocab_size)
+
+ def forward(self, tokens, attention_mask, encoder_hidden_state):
+ tgt = self.embedding(tokens) # N, T, E
+ out = self.transformer_decoder(tgt, encoder_hidden_state, memory_mask=attention_mask)
+ logits = self.linear(out)
+ return logits
+
+
+def main():
+ vocabulary_size = 10000
+ embedding_size = 8
+ hidden_dim_size = 16
+ num_heads = 1
+ num_layers = 1
+
+ encoder = EncoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers)
+ decoder = DecoderModel(vocabulary_size, embedding_size, hidden_dim_size, num_heads, num_layers)
+
+ # Omit training - just export randomly initialized network
+
+ tokens = torch.LongTensor([[1, 2, 3, 4, 5]])
+ attention_mask = torch.LongTensor([[1, 1, 1, 1, 1]])
+
+ torch.onnx.export(encoder,
+ (tokens, attention_mask),
+ "random_encoder.onnx",
+ input_names=["input_ids", "attention_mask"],
+ output_names=["last_hidden_state"],
+ dynamic_axes={
+ "input_ids": {0: "batch", 1: "tokens"},
+ "attention_mask": {0: "batch", 1: "tokens"},
+ "last_hidden_state": {0: "batch", 1: "tokens"},
+ },
+ opset_version=12)
+
+ last_hidden_state = encoder.forward(tokens, attention_mask)
+ tokens = torch.LongTensor([[0]])  # single decoder start token
+
+ torch.onnx.export(decoder,
+ (tokens, attention_mask.float(), last_hidden_state),
+ "random_decoder.onnx",
+ input_names=["input_ids", "encoder_attention_mask", "encoder_hidden_states"],
+ output_names=["logits"],
+ dynamic_axes={
+ "input_ids": {0: "batch", 1: "target_tokens"},
+ "encoder_attention_mask": {0: "batch", 1: "source_tokens"},
+ "encoder_hidden_states": {0: "batch", 1: "source_tokens"},
+ "logits": {0: "batch", 1: "target_tokens"},
+ },
+ opset_version=12)
+
+
+if __name__ == "__main__":
+ main()
+
+
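
Running this script (python random_llm.py) rewrites random_encoder.onnx and random_decoder.onnx. Note that the networks are randomly initialized and no seed is set, so regenerating the models will in general change the exact output string asserted in GeneratorTest.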