aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java
blob: 8c9b961f4a8718c023a28e32bedf4a81f3af7461 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
package ai.vespa.llm.generation;

import ai.vespa.llm.generation.Generator;
import ai.vespa.llm.generation.GeneratorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.llm.GeneratorConfig;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeTrue;

public class GeneratorTest {

    @Test
    public void testGenerator() {
        String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model";
        String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx";
        String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx";
        assumeTrue(OnnxRuntime.isRuntimeAvailable(encoderModelPath));

        GeneratorConfig.Builder builder = new GeneratorConfig.Builder();
        builder.tokenizerModel(ModelReference.valueOf(vocabPath));
        builder.encoderModel(ModelReference.valueOf(encoderModelPath));
        builder.decoderModel(ModelReference.valueOf(decoderModelPath));
        Generator generator = newGenerator(builder.build());

        GeneratorOptions options = new GeneratorOptions();
        options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY);
        options.setMaxLength(10);

        String prompt = "generate some random text";
        String result = generator.generate(prompt, options);

        assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result);
    }

    private static Generator newGenerator(GeneratorConfig cfg) {
        return new Generator(new OnnxRuntime(), cfg);
    }

}