blob: 05b623a00c0eb2b4232ed6416fd307c3354b68ad (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.llm.generation;
import ai.vespa.llm.generation.Generator;
import ai.vespa.llm.generation.GeneratorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.llm.GeneratorConfig;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeTrue;
public class GeneratorTest {
@Test
public void testGenerator() {
String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model";
String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx";
String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx";
assumeTrue(OnnxRuntime.isRuntimeAvailable(encoderModelPath));
GeneratorConfig.Builder builder = new GeneratorConfig.Builder();
builder.tokenizerModel(ModelReference.valueOf(vocabPath));
builder.encoderModel(ModelReference.valueOf(encoderModelPath));
builder.decoderModel(ModelReference.valueOf(decoderModelPath));
Generator generator = newGenerator(builder.build());
GeneratorOptions options = new GeneratorOptions();
options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY);
options.setMaxLength(10);
String prompt = "generate some random text";
String result = generator.generate(prompt, options);
assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result);
}
private static Generator newGenerator(GeneratorConfig cfg) {
return new Generator(new OnnxRuntime(), cfg);
}
}
|