diff options
author | Lester Solbakken <lester.solbakken@gmail.com> | 2024-04-11 14:04:32 +0200 |
---|---|---|
committer | Lester Solbakken <lester.solbakken@gmail.com> | 2024-04-11 14:04:32 +0200 |
commit | a2b8ee9591ab36ccbe64c2dc31bfd84fa4caffb3 (patch) | |
tree | 269a9f0525cc2f027f18a49cd45faafd32f1a901 /container-search/src/test | |
parent | 20a9ae9b98f15cdbc24253aa6e9aa585b2759a3a (diff) |
Use 'model' config type for LLM models
Diffstat (limited to 'container-search/src/test')
-rw-r--r-- | container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java | 120 |
1 files changed, 65 insertions, 55 deletions
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java index 5ad2dff8ee1..c7de66e4c81 100644 --- a/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java +++ b/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java @@ -5,6 +5,7 @@ import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.config.ModelReference; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -36,10 +37,13 @@ public class LocalLLMTest { var prompt = StringPrompt.from("Why are ducks better than cats? Be concise, " + "but use the word 'spoon' somewhere in your answer."); var llm = createLLM(model); - var result = llm.complete(prompt, defaultOptions()); - assertEquals(Completion.FinishReason.stop, result.get(0).finishReason()); - assertTrue(result.get(0).text().contains("spoon")); - llm.deconstruct(); + try { + var result = llm.complete(prompt, defaultOptions()); + assertEquals(Completion.FinishReason.stop, result.get(0).finishReason()); + assertTrue(result.get(0).text().contains("spoon")); + } finally { + llm.deconstruct(); + } } @Test @@ -52,23 +56,25 @@ public class LocalLLMTest { .useGpu(true) .parallelRequests(1) .contextSize(1024) - .localLlmFile(model); + .model(ModelReference.valueOf(model)); var llm = new LocalLLM(config.build()); - var future = llm.completeAsync(prompt, defaultOptions(), completion -> { - sb.append(completion.text()); - System.out.print(completion.text()); - }).exceptionally(exception -> Completion.FinishReason.error); - - assertFalse(future.isDone()); - var reason = future.join(); - assertTrue(future.isDone()); - assertNotEquals(reason, Completion.FinishReason.error); + try { + var future = llm.completeAsync(prompt, defaultOptions(), completion -> { + sb.append(completion.text()); + System.out.print(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); - System.out.println(prompt.asString()); - System.out.println(sb); + assertFalse(future.isDone()); + var reason = future.join(); + assertTrue(future.isDone()); + assertNotEquals(reason, Completion.FinishReason.error); - llm.deconstruct(); + System.out.println(prompt.asString()); + System.out.println(sb); + } finally { + llm.deconstruct(); + } } @Test @@ -84,26 +90,28 @@ public class LocalLLMTest { var config = new LlmLocalClientConfig.Builder() .useGpu(true) .parallelRequests(parallelRequests) - .localLlmFile(model); + .model(ModelReference.valueOf(model)); var llm = new LocalLLM(config.build()); - var start = System.currentTimeMillis(); - for (int i = 0; i < promptsToUse; i++) { - final var seq = i; - - completions.set(seq, new StringBuilder()); - futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { - completions.get(seq).append(completion.text()); - }).exceptionally(exception -> Completion.FinishReason.error)); + try { + var start = System.currentTimeMillis(); + for (int i = 0; i < promptsToUse; i++) { + final var seq = i; + + completions.set(seq, new StringBuilder()); + futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { + completions.get(seq).append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error)); + } + for (int i = 0; i < promptsToUse; i++) { + var reason = futures.get(i).join(); + assertNotEquals(reason, Completion.FinishReason.error); + System.out.println("\n\n***\n" + prompts.get(i) + ":\n***\n" + completions.get(i)); + } + System.out.println("Time: " + (System.currentTimeMillis() - start) / 1000.0 + "s"); + } finally { + llm.deconstruct(); } - for (int i = 0; i < promptsToUse; i++) { - var reason = futures.get(i).join(); - assertNotEquals(reason, Completion.FinishReason.error); - System.out.println("\n\n***\n" + prompts.get(i) + ":\n***\n" + completions.get(i)); - } - System.out.println("Time: " + (System.currentTimeMillis() - start) / 1000.0 + "s"); - - llm.deconstruct(); } @Test @@ -122,29 +130,31 @@ public class LocalLLMTest { .useGpu(true) .parallelRequests(parallelRequests) .maxQueueSize(additionalQueue) - .localLlmFile(model); + .model(ModelReference.valueOf(model)); var llm = new LocalLLM(config.build()); - final AtomicInteger rejected = new AtomicInteger(0); - for (int i = 0; i < promptsToUse; i++) { - final var seq = i; - - completions.set(seq, new StringBuilder()); - var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { - completions.get(seq).append(completion.text()); - if (completion.finishReason() == Completion.FinishReason.error) { - rejected.incrementAndGet(); - } - }).exceptionally(exception -> Completion.FinishReason.error); - futures.set(seq, future); - } - for (int i = 0; i < promptsToUse; i++) { - futures.get(i).join(); - System.out.println("\n\n***\n" + prompts.get(i) + ":\n***\n" + completions.get(i)); + try { + final AtomicInteger rejected = new AtomicInteger(0); + for (int i = 0; i < promptsToUse; i++) { + final var seq = i; + + completions.set(seq, new StringBuilder()); + var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { + completions.get(seq).append(completion.text()); + if (completion.finishReason() == Completion.FinishReason.error) { + rejected.incrementAndGet(); + } + }).exceptionally(exception -> Completion.FinishReason.error); + futures.set(seq, future); + } + for (int i = 0; i < promptsToUse; i++) { + futures.get(i).join(); + System.out.println("\n\n***\n" + prompts.get(i) + ":\n***\n" + completions.get(i)); + } + assertEquals(9, rejected.get()); + } finally { + llm.deconstruct(); } - - assertEquals(9, rejected.get()); - llm.deconstruct(); } private static InferenceParameters defaultOptions() { @@ -225,7 +235,7 @@ public class LocalLLMTest { } private static LocalLLM createLLM(String modelPath) { - var config = new LlmLocalClientConfig.Builder().localLlmFile(modelPath).build(); + var config = new LlmLocalClientConfig.Builder().model(ModelReference.valueOf(modelPath)).build(); return new LocalLLM(config); } |