aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/llm/generation/Generator.java
blob: 64dafee646f420d1b1c3c8ef1f5400ac0ed2fa4e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
package ai.vespa.llm.generation;

import ai.vespa.modelintegration.evaluator.OnnxEvaluator;
import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.language.process.Embedder;
import com.yahoo.language.sentencepiece.SentencePieceEmbedder;
import com.yahoo.llm.GeneratorConfig;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.api.annotations.Beta;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * A text generator based on language models (LLMs). By configuring a
 * SentencePiece tokenizer and models for encoding and decoding, this
 * component generates text based on the given prompt.
 *
 * See llm.generator.def for configurable parameters.
 *
 * @author lesters
 */
@Beta
public class Generator extends AbstractComponent {

    // SentencePiece token id reserved for end-of-sequence. Note that
    // findMostProbableToken deliberately never selects this id, so greedy
    // generation currently always runs to the configured max length.
    private final static int TOKEN_EOS = 1;  // end of sequence

    // Names of the batch and sequence axes of the tensors fed to the models.
    private final static String BATCH_DIMENSION = "d0";
    private final static String SEQUENCE_DIMENSION = "d1";

    private final int tokenizerMaxTokens;
    private final String encoderInputIdsName;
    private final String encoderAttentionMaskName;
    private final String encoderOutputName;
    private final String decoderInputIdsName;
    private final String decoderAttentionMaskName;
    private final String decoderEncoderHiddenStateName;
    private final String decoderOutputName;

    private final SentencePieceEmbedder tokenizer;
    private final OnnxEvaluator encoder;
    private final OnnxEvaluator decoder;

    /**
     * Creates a generator: builds the SentencePiece tokenizer, the encoder and
     * decoder ONNX evaluators, and validates that all configured input/output
     * names actually exist in the loaded models (fail fast at construction).
     *
     * @param onnx the ONNX runtime used to create model evaluators
     * @param config configurable parameters, see llm.generator.def
     * @throws IllegalArgumentException if a configured model input/output name
     *         is not present in the corresponding model
     */
    @Inject
    public Generator(OnnxRuntime onnx, GeneratorConfig config) {
        // Set up tokenizer
        tokenizer = new SentencePieceEmbedder.Builder(config.tokenizerModel().toString()).build();
        tokenizerMaxTokens = config.tokenizerMaxTokens();

        // Set up encoder
        encoderInputIdsName = config.encoderModelInputIdsName();
        encoderAttentionMaskName = config.encoderModelAttentionMaskName();
        encoderOutputName = config.encoderModelOutputName();

        OnnxEvaluatorOptions encoderOptions = new OnnxEvaluatorOptions();
        encoderOptions.setExecutionMode(config.encoderOnnxExecutionMode().toString());
        encoderOptions.setThreads(config.encoderOnnxInterOpThreads(), config.encoderOnnxIntraOpThreads());

        encoder = onnx.evaluatorOf(config.encoderModel().toString(), encoderOptions);

        // Set up decoder
        decoderInputIdsName = config.decoderModelInputIdsName();
        decoderAttentionMaskName = config.decoderModelAttentionMaskName();
        decoderEncoderHiddenStateName = config.decoderModelEncoderHiddenStateName();
        decoderOutputName = config.decoderModelOutputName();

        OnnxEvaluatorOptions decoderOptions = new OnnxEvaluatorOptions();
        decoderOptions.setExecutionMode(config.decoderOnnxExecutionMode().toString());
        decoderOptions.setThreads(config.decoderOnnxInterOpThreads(), config.decoderOnnxIntraOpThreads());

        decoder = onnx.evaluatorOf(config.decoderModel().toString(), decoderOptions);

        validateModels();
    }

    /**
     * Generates text by evaluating an encoder model to encode the prompt, and
     * repeatedly evaluating a decoding model to generate tokens until some
     * stopping criteria has been met.
     *
     * @param prompt the prompt to generate text from
     * @param options options for text generation
     * @return a text generated from the prompt
     * @throws UnsupportedOperationException for search methods other than GREEDY
     */
    public String generate(String prompt, GeneratorOptions options) {
        return switch (options.getSearchMethod()) {
            case GREEDY -> generateGreedy(prompt, options);
            default -> generateNotImplemented(options);
        };
    }

    /** Generates text from the prompt using default {@link GeneratorOptions}. */
    public String generate(String prompt) {
        return generate(prompt, new GeneratorOptions());
    }

    /** Releases the native resources held by the ONNX evaluators. */
    @Override
    public void deconstruct() {
        encoder.close();
        decoder.close();
        super.deconstruct();  // overrides of AbstractComponent.deconstruct must chain to super
    }

    private String generateNotImplemented(GeneratorOptions options) {
        throw new UnsupportedOperationException("Search method '" + options.getSearchMethod() + "' is currently not implemented");
    }

    /**
     * Greedy decoding: encodes the prompt once, then repeatedly feeds the
     * tokens generated so far back to the decoder, appending the highest-scoring
     * next token each step until max length is reached.
     */
    private String generateGreedy(String prompt, GeneratorOptions options) {
        var generatedTokens = new ArrayList<Integer>();
        generatedTokens.add(0);  // decoder start token (or target tokens)

        // Tokenize
        var inputTokens = tokenize(prompt);  // Or source tokens

        // Evaluate encoder once; its output and mask are reused every decoder step
        var encoderInput  = createTensorRepresentation(inputTokens, SEQUENCE_DIMENSION);
        var encoderMask   = createAttentionMask(encoderInput).expand(BATCH_DIMENSION);
        var encoderOutput = evaluateEncoder(encoderInput.expand(BATCH_DIMENSION), encoderMask);

        // Greedy search just grabs the next most probable token
        while (generatedTokens.size() < options.getMaxLength()) {  // Todo: add stopping criteria
            var decoderInput = createTensorRepresentation(generatedTokens, SEQUENCE_DIMENSION).expand(BATCH_DIMENSION);
            var logits       = evaluateDecoder(decoderInput, encoderMask, encoderOutput);
            // Logits for the last position in the sequence predict the next token
            var nextToken    = findMostProbableToken(logits, generatedTokens.size()-1, BATCH_DIMENSION, SEQUENCE_DIMENSION);
            generatedTokens.add(nextToken);
        }

        return detokenize(generatedTokens);
    }

    /** Runs the encoder model and returns its hidden-state output tensor. */
    private Tensor evaluateEncoder(Tensor input, Tensor mask) {
        var encoderInputs = Map.of(encoderInputIdsName, input,
                                   encoderAttentionMaskName, mask);
        return encoder.evaluate(encoderInputs, encoderOutputName);
    }

    /**
     * Runs the decoder model for one step.
     *
     * @throws IllegalArgumentException if the decoder output is not an IndexedTensor
     */
    private IndexedTensor evaluateDecoder(Tensor input, Tensor encoderMask, Tensor encoderOutput) {
        var inputs = Map.of(decoderInputIdsName, input,
                            decoderAttentionMaskName, encoderMask,  // yes, encoder's attention mask
                            decoderEncoderHiddenStateName, encoderOutput);
        var output  = decoder.evaluate(inputs, decoderOutputName);
        if ( ! (output instanceof IndexedTensor indexedTensor)) {
            throw new IllegalArgumentException("Output of decoder model is not an 'IndexedTensor'");
        }
        return indexedTensor;
    }

    /**
     * Given a tensor 'logits' with 3 dimensions: batch, sequence, and vocabulary
     * find the value in the vocabulary dimension with highest score for the given
     * token in the sequence. The EOS token id is skipped, so this never returns
     * TOKEN_EOS; on ties the later index wins.
     *
     * @throws IllegalArgumentException if the logits tensor is not rank 3
     */
    private static int findMostProbableToken(IndexedTensor logits, int seqIndex, String batchDim, String seqDim) {
        if (logits.type().rank() != 3) {
            throw new IllegalArgumentException("Expected a tensor with rank 3: batch, sequence, and vocabulary size. " +
                                               "Got: " + logits.type());
        }
        // Iterate over the vocabulary dimension at (batch 0, seqIndex)
        var iterator = logits.cellIterator(new PartialAddress.Builder(2).
                                                add(batchDim, 0).
                                                add(seqDim, seqIndex).build(),
                                           DimensionSizes.of(logits.type()));
        var maxVal = iterator.next().getValue();  // vocabulary index 0
        int maxIndex = 0;
        for (int i = 1; iterator.hasNext(); ++i) {
            var val = iterator.next().getValue();
            if (val >= maxVal && i != TOKEN_EOS) {  // skip EOS so generation keeps going (see stopping-criteria TODO)
                maxVal = val;
                maxIndex = i;
            }
        }
        return maxIndex;
    }

    /**
     * Tokenizes the text with the SentencePiece tokenizer, truncating so the
     * appended end-of-sequence token fits within tokenizerMaxTokens.
     * Copies the embedder's result to avoid mutating a list we do not own.
     */
    private List<Integer> tokenize(String text) {
        var tokens = new ArrayList<>(tokenizer.embed(text, new Embedder.Context("tokenizer")));
        if (tokens.size() >= tokenizerMaxTokens) {
            tokens = new ArrayList<>(tokens.subList(0, tokenizerMaxTokens - 1));
        }
        tokens.add(TOKEN_EOS);
        return tokens;
    }

    /** Converts token ids back to text, removing special tokens. */
    private String detokenize(List<Integer> tokens) {
        return tokenizer.decode(tokens, new Embedder.Context("tokenizer"), true);
    }

    /** Builds a 1-d indexed float tensor holding the token ids along the given dimension. */
    private static Tensor createTensorRepresentation(List<Integer> tokens, String dimension) {
        var size = tokens.size();
        TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
        IndexedTensor.Builder builder = IndexedTensor.Builder.of(type);
        for (int i = 0; i < size; ++i) {
            builder.cell(tokens.get(i), i);
        }
        return builder.build();
    }

    /** Attention mask: 1 for positions holding a positive token id, 0 otherwise (e.g. padding). */
    private static Tensor createAttentionMask(Tensor d)  {
        return d.map((x) -> x > 0 ? 1:0);
    }

    /** Fails fast if a configured model input/output name is absent from the loaded models. */
    private void validateModels() {
        Map<String, TensorType> inputs = encoder.getInputInfo();
        validateName(inputs, encoderInputIdsName, "input");
        validateName(inputs, encoderAttentionMaskName, "input");

        Map<String, TensorType> outputs = encoder.getOutputInfo();
        validateName(outputs, encoderOutputName, "output");

        inputs = decoder.getInputInfo();
        validateName(inputs, decoderInputIdsName, "input");
        validateName(inputs, decoderAttentionMaskName, "input");
        validateName(inputs, decoderEncoderHiddenStateName, "input");

        outputs = decoder.getOutputInfo();
        validateName(outputs, decoderOutputName, "output");
    }

    /** Throws IllegalArgumentException if 'name' is not a key in 'types'. Static: uses no instance state. */
    private static void validateName(Map<String, TensorType> types, String name, String type) {
        if ( ! types.containsKey(name)) {
            throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " +
                    "Model contains: " + String.join(",", types.keySet()));
        }
    }

}