blob: 8c9b961f4a8718c023a28e32bedf4a81f3af7461 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
|
package ai.vespa.llm.generation;
import ai.vespa.llm.generation.Generator;
import ai.vespa.llm.generation.GeneratorOptions;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.llm.GeneratorConfig;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeTrue;
public class GeneratorTest {

    /** Tokenizer model file handed to the config's {@code tokenizerModel} entry. */
    private static final String TOKENIZER_MODEL = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model";

    /** ONNX encoder model; the file name suggests randomly initialized weights (not verified here). */
    private static final String ENCODER_MODEL = "src/test/models/onnx/llm/random_encoder.onnx";

    /** ONNX decoder model; the file name suggests randomly initialized weights (not verified here). */
    private static final String DECODER_MODEL = "src/test/models/onnx/llm/random_decoder.onnx";

    /**
     * Generates text from a fixed prompt with greedy search and a max length of 10,
     * then compares the result against a pinned expected string.
     * The test is skipped through a JUnit assumption when
     * {@code OnnxRuntime.isRuntimeAvailable} returns false for the encoder model.
     */
    @Test
    public void testGenerator() {
        // Check runtime availability before touching any model references or runtime instances.
        assumeTrue(OnnxRuntime.isRuntimeAvailable(ENCODER_MODEL));

        Generator generator = newGenerator(configFor(TOKENIZER_MODEL, ENCODER_MODEL, DECODER_MODEL));

        GeneratorOptions greedy = new GeneratorOptions();
        greedy.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY);
        greedy.setMaxLength(10);

        String expected = "<unk> linear recruit latest sack annually institutions cert solid references";
        String actual = generator.generate("generate some random text", greedy);
        assertEquals(expected, actual);
    }

    /**
     * Builds a {@link GeneratorConfig} that references the given model files.
     *
     * @param tokenizerPath path used for the tokenizer model
     * @param encoderPath   path used for the encoder model
     * @param decoderPath   path used for the decoder model
     * @return the built config
     */
    private static GeneratorConfig configFor(String tokenizerPath, String encoderPath, String decoderPath) {
        GeneratorConfig.Builder config = new GeneratorConfig.Builder();
        config.tokenizerModel(ModelReference.valueOf(tokenizerPath));
        config.encoderModel(ModelReference.valueOf(encoderPath));
        config.decoderModel(ModelReference.valueOf(decoderPath));
        return config.build();
    }

    /**
     * Creates a {@link Generator} backed by a freshly constructed {@link OnnxRuntime}.
     *
     * @param cfg generator configuration
     * @return a new generator instance
     */
    private static Generator newGenerator(GeneratorConfig cfg) {
        OnnxRuntime runtime = new OnnxRuntime();
        return new Generator(runtime, cfg);
    }
}
|