author    Lester Solbakken <lester.solbakken@gmail.com>  2024-04-11 14:04:32 +0200
committer Lester Solbakken <lester.solbakken@gmail.com>  2024-04-11 14:04:32 +0200
commit    a2b8ee9591ab36ccbe64c2dc31bfd84fa4caffb3 (patch)
tree      269a9f0525cc2f027f18a49cd45faafd32f1a901
parent    20a9ae9b98f15cdbc24253aa6e9aa585b2759a3a (diff)
Use 'model' config type for LLM models
-rw-r--r--  container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java | 12
-rwxr-xr-x  container-search/src/main/resources/configdefinitions/llm-local-client.def | 7
-rw-r--r--  container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java | 120
3 files changed, 68 insertions(+), 71 deletions(-)
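
The change replaces the two ad-hoc fields in llm-local-client.def ('modelUrl' and 'localLlmFile') with a single field of Vespa's built-in 'model' config type, so model resolution and download are handled by the config system instead of inside the component. For illustration, a deployer would then point the client at a model in services.xml roughly like this (a minimal sketch; the component wiring and model URL are hypothetical, and the attributes follow Vespa's model-type convention of model-id, url and path):

    <component id="llm" class="ai.vespa.llm.clients.LocalLLM">
        <config name="ai.vespa.llm.clients.llm-local-client">
            <!-- one of model-id, url or path; model-id is used on Vespa Cloud when given -->
            <model url="https://example.com/models/some-model.gguf"/>
        </config>
    </component>
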
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java
index 3b99e5f0a09..1e204d29a19 100644
--- a/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java
+++ b/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java
@@ -44,7 +44,7 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
// Only used if GPU is not used
var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2;
- var modelFile = selectModelFile(config);
+ var modelFile = config.model().toFile().getAbsolutePath();
var modelParams = new ModelParameters()
.setModelFilePath(modelFile)
.setContinuousBatching(true)
@@ -69,16 +69,6 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
new ThreadPoolExecutor.AbortPolicy());
}
- private String selectModelFile(LlmLocalClientConfig config) {
- if ( ! config.localLlmFile().isEmpty()) { // primarily for testing
- return config.localLlmFile();
- } else if (config.modelUrl().exists()) {
- return config.modelUrl().getAbsolutePath();
- }
- throw new IllegalArgumentException("Local LLM model not set. " +
- "Either set 'localLlmFile' or 'modelUrl' in 'llm-local-client' config.");
- }
-
@Override
public void deconstruct() {
logger.info("Closing LLM model...");
diff --git a/container-search/src/main/resources/configdefinitions/llm-local-client.def b/container-search/src/main/resources/configdefinitions/llm-local-client.def
index 08eab19f0f8..c06c24b33e5 100755
--- a/container-search/src/main/resources/configdefinitions/llm-local-client.def
+++ b/container-search/src/main/resources/configdefinitions/llm-local-client.def
@@ -1,11 +1,8 @@
# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package=ai.vespa.llm.clients
-# Url to the model to use
-modelUrl url default=""
-
-# Local file path to the model to use - will have precedence over model_url if set - mostly for testing
-localLlmFile string default=""
+# The LLM model to use
+model model
# Maximum number of requests to handle in parallel per container node
parallelRequests int default=10
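
On the Java side, the generated LlmLocalClientConfig.Builder now takes a com.yahoo.config.ModelReference rather than a string path, as the updated tests below show. A minimal sketch for wiring a local file in a unit test (the model path is hypothetical; deconstruct() frees the native llama.cpp resources, hence the try/finally):

    var config = new LlmLocalClientConfig.Builder()
            .parallelRequests(1)
            .model(ModelReference.valueOf("models/tiny-llm.gguf"))  // hypothetical local file
            .build();
    var llm = new LocalLLM(config);
    try {
        // defaultOptions() is the helper defined in LocalLLMTest below
        var result = llm.complete(StringPrompt.from("Hello"), defaultOptions());
        // assert on result here
    } finally {
        llm.deconstruct();
    }
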
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
index 5ad2dff8ee1..c7de66e4c81 100644
--- a/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
+++ b/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
@@ -5,6 +5,7 @@ import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.config.ModelReference;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
@@ -36,10 +37,13 @@ public class LocalLLMTest {
var prompt = StringPrompt.from("Why are ducks better than cats? Be concise, " +
"but use the word 'spoon' somewhere in your answer.");
var llm = createLLM(model);
- var result = llm.complete(prompt, defaultOptions());
- assertEquals(Completion.FinishReason.stop, result.get(0).finishReason());
- assertTrue(result.get(0).text().contains("spoon"));
- llm.deconstruct();
+ try {
+ var result = llm.complete(prompt, defaultOptions());
+ assertEquals(Completion.FinishReason.stop, result.get(0).finishReason());
+ assertTrue(result.get(0).text().contains("spoon"));
+ } finally {
+ llm.deconstruct();
+ }
}
@Test
@@ -52,23 +56,25 @@ public class LocalLLMTest {
.useGpu(true)
.parallelRequests(1)
.contextSize(1024)
- .localLlmFile(model);
+ .model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());
- var future = llm.completeAsync(prompt, defaultOptions(), completion -> {
- sb.append(completion.text());
- System.out.print(completion.text());
- }).exceptionally(exception -> Completion.FinishReason.error);
-
- assertFalse(future.isDone());
- var reason = future.join();
- assertTrue(future.isDone());
- assertNotEquals(reason, Completion.FinishReason.error);
+ try {
+ var future = llm.completeAsync(prompt, defaultOptions(), completion -> {
+ sb.append(completion.text());
+ System.out.print(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
- System.out.println(prompt.asString());
- System.out.println(sb);
+ assertFalse(future.isDone());
+ var reason = future.join();
+ assertTrue(future.isDone());
+ assertNotEquals(reason, Completion.FinishReason.error);
- llm.deconstruct();
+ System.out.println(prompt.asString());
+ System.out.println(sb);
+ } finally {
+ llm.deconstruct();
+ }
}
@Test
@@ -84,26 +90,28 @@ public class LocalLLMTest {
var config = new LlmLocalClientConfig.Builder()
.useGpu(true)
.parallelRequests(parallelRequests)
- .localLlmFile(model);
+ .model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());
- var start = System.currentTimeMillis();
- for (int i = 0; i < promptsToUse; i++) {
- final var seq = i;
-
- completions.set(seq, new StringBuilder());
- futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
- completions.get(seq).append(completion.text());
- }).exceptionally(exception -> Completion.FinishReason.error));
+ try {
+ var start = System.currentTimeMillis();
+ for (int i = 0; i < promptsToUse; i++) {
+ final var seq = i;
+
+ completions.set(seq, new StringBuilder());
+ futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
+ completions.get(seq).append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error));
+ }
+ for (int i = 0; i < promptsToUse; i++) {
+ var reason = futures.get(i).join();
+ assertNotEquals(reason, Completion.FinishReason.error);
+ System.out.println("\n\n***\n" + prompts.get(i) + ":\n***\n" + completions.get(i));
+ }
+ System.out.println("Time: " + (System.currentTimeMillis() - start) / 1000.0 + "s");
+ } finally {
+ llm.deconstruct();
}
- for (int i = 0; i < promptsToUse; i++) {
- var reason = futures.get(i).join();
- assertNotEquals(reason, Completion.FinishReason.error);
- System.out.println("\n\n***\n" + prompts.get(i) + ":\n***\n" + completions.get(i));
- }
- System.out.println("Time: " + (System.currentTimeMillis() - start) / 1000.0 + "s");
-
- llm.deconstruct();
}
@Test
@@ -122,29 +130,31 @@ public class LocalLLMTest {
.useGpu(true)
.parallelRequests(parallelRequests)
.maxQueueSize(additionalQueue)
- .localLlmFile(model);
+ .model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());
- final AtomicInteger rejected = new AtomicInteger(0);
- for (int i = 0; i < promptsToUse; i++) {
- final var seq = i;
-
- completions.set(seq, new StringBuilder());
- var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
- completions.get(seq).append(completion.text());
- if (completion.finishReason() == Completion.FinishReason.error) {
- rejected.incrementAndGet();
- }
- }).exceptionally(exception -> Completion.FinishReason.error);
- futures.set(seq, future);
- }
- for (int i = 0; i < promptsToUse; i++) {
- futures.get(i).join();
- System.out.println("\n\n***\n" + prompts.get(i) + ":\n***\n" + completions.get(i));
+ try {
+ final AtomicInteger rejected = new AtomicInteger(0);
+ for (int i = 0; i < promptsToUse; i++) {
+ final var seq = i;
+
+ completions.set(seq, new StringBuilder());
+ var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
+ completions.get(seq).append(completion.text());
+ if (completion.finishReason() == Completion.FinishReason.error) {
+ rejected.incrementAndGet();
+ }
+ }).exceptionally(exception -> Completion.FinishReason.error);
+ futures.set(seq, future);
+ }
+ for (int i = 0; i < promptsToUse; i++) {
+ futures.get(i).join();
+ System.out.println("\n\n***\n" + prompts.get(i) + ":\n***\n" + completions.get(i));
+ }
+ assertEquals(9, rejected.get());
+ } finally {
+ llm.deconstruct();
}
-
- assertEquals(9, rejected.get());
- llm.deconstruct();
}
private static InferenceParameters defaultOptions() {
@@ -225,7 +235,7 @@ public class LocalLLMTest {
}
private static LocalLLM createLLM(String modelPath) {
- var config = new LlmLocalClientConfig.Builder().localLlmFile(modelPath).build();
+ var config = new LlmLocalClientConfig.Builder().model(ModelReference.valueOf(modelPath)).build();
return new LocalLLM(config);
}