aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java')
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java42
1 files changed, 42 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java
new file mode 100644
index 00000000000..8c9b961f4a8
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java
@@ -0,0 +1,42 @@
+package ai.vespa.llm.generation;
+
+import ai.vespa.llm.generation.Generator;
+import ai.vespa.llm.generation.GeneratorOptions;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
+import com.yahoo.config.ModelReference;
+import com.yahoo.llm.GeneratorConfig;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assume.assumeTrue;
+
+public class GeneratorTest {
+
+ @Test
+ public void testGenerator() {
+ String vocabPath = "src/test/models/onnx/llm/en.wiki.bpe.vs10000.model";
+ String encoderModelPath = "src/test/models/onnx/llm/random_encoder.onnx";
+ String decoderModelPath = "src/test/models/onnx/llm/random_decoder.onnx";
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(encoderModelPath));
+
+ GeneratorConfig.Builder builder = new GeneratorConfig.Builder();
+ builder.tokenizerModel(ModelReference.valueOf(vocabPath));
+ builder.encoderModel(ModelReference.valueOf(encoderModelPath));
+ builder.decoderModel(ModelReference.valueOf(decoderModelPath));
+ Generator generator = newGenerator(builder.build());
+
+ GeneratorOptions options = new GeneratorOptions();
+ options.setSearchMethod(GeneratorOptions.SearchMethod.GREEDY);
+ options.setMaxLength(10);
+
+ String prompt = "generate some random text";
+ String result = generator.generate(prompt, options);
+
+ assertEquals("<unk> linear recruit latest sack annually institutions cert solid references", result);
+ }
+
+ private static Generator newGenerator(GeneratorConfig cfg) {
+ return new Generator(new OnnxRuntime(), cfg);
+ }
+
+}