diff options
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java new file mode 100644 index 00000000000..8c9b961f4a8 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java @@ -0,0 +1,42 @@ +package ai.vespa.llm.generation; + +import ai.vespa.llm.generation.Generator; +import ai.vespa.llm.generation.GeneratorOptions; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; +import com.yahoo.config.ModelReference; +import com.yahoo.llm.GeneratorConfig; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assume.assumeTrue; + +public class GeneratorTest { + + @Test + public void testGenerator() { + String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model"; + String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx"; + String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx"; + assumeTrue(OnnxRuntime.isRuntimeAvailable(encoderModelPath)); + + GeneratorConfig.Builder builder = new GeneratorConfig.Builder(); + builder.tokenizerModel(ModelReference.valueOf(vocabPath)); + builder.encoderModel(ModelReference.valueOf(encoderModelPath)); + builder.decoderModel(ModelReference.valueOf(decoderModelPath)); + Generator generator = newGenerator(builder.build()); + + GeneratorOptions options = new GeneratorOptions(); + options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY); + options.setMaxLength(10); + + String prompt = "generate some random text"; + String result = generator.generate(prompt, options); + + assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result); + } + + private static Generator newGenerator(GeneratorConfig cfg) { + return new Generator(new OnnxRuntime(), cfg); + } + +} |