summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/llm/CompletionTest.java')
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/CompletionTest.java38
1 files changed, 38 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java b/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java
new file mode 100644
index 00000000000..acf7d14c438
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java
@@ -0,0 +1,38 @@
+package ai.vespa.llm;
+
+import ai.vespa.llm.testing.MockLanguageModel;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.function.Function;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Tests completion with a mock completer.
+ *
+ * @author bratseth
+ */
+public class CompletionTest {
+
+ @Test
+ public void testCompletion() {
+ Function<Prompt, List<Completion>> completer = in ->
+ switch (in.asString()) {
+ case "Complete this: " -> List.of(Completion.from("The completion"));
+ default -> throw new RuntimeException("Cannot complete '" + in + "'");
+ };
+ var llm = new MockLanguageModel.Builder().completer(completer).build();
+
+ String input = "Complete this: ";
+ StringPrompt prompt = StringPrompt.from(input);
+ for (int i = 0; i < 10; i++) {
+ var completion = llm.complete(prompt).get(0);
+ prompt = prompt.append(completion);
+ if (completion.finishReason() == Completion.FinishReason.stop) break;
+ }
+ assertEquals("Complete this: The completion", prompt.asString());
+ }
+
+}