summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorLester Solbakken <lester.solbakken@gmail.com>2024-04-10 17:24:44 +0200
committerLester Solbakken <lester.solbakken@gmail.com>2024-04-10 17:24:44 +0200
commitb321d23f99f7ee87dd19044de5951d250c29ec27 (patch)
tree6ab7f59326b69d02cc5b773ca90f6e2362a9f339 /container-search
parentdedeea2ae252ea75bf6991dc2ef6cf228155825f (diff)
Non-functional changes
Diffstat (limited to 'container-search')
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java35
1 files changed, 20 insertions, 15 deletions
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java
index e34f914729c..c550406ff92 100644
--- a/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java
+++ b/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java
@@ -36,20 +36,7 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
@Inject
public LocalLLM(LlmLocalClientConfig config) {
- this.executor = new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(),
- 0L, TimeUnit.MILLISECONDS,
- config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(),
- new ThreadPoolExecutor.AbortPolicy());
-
- String modelFile;
- if ( ! config.localLlmFile().isEmpty()) { // for testing
- modelFile = config.localLlmFile();
- } else if (config.modelUrl().exists()){
- modelFile = config.modelUrl().getAbsolutePath();
- } else {
- throw new IllegalArgumentException("Local LLM model not set. " +
- "Either set 'localLlmFile' or 'modelUrl' in 'llm-local-client' config.");
- }
+ executor = createExecutor(config);
// Maximum number of tokens to generate - need this since some models can just generate infinitely
maxTokens = config.maxTokens();
@@ -57,6 +44,7 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
// Only used if GPU is not used
var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2;
+ var modelFile = selectModelFile(config);
var modelParams = new ModelParameters()
.setModelFilePath(modelFile)
.setContinuousBatching(true)
@@ -71,7 +59,24 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000)));
// Todo: handle prompt context size - such as give a warning when prompt exceeds context size
- this.contextSize = config.contextSize();
+ contextSize = config.contextSize();
+ }
+
+ private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) {
+ return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(),
+ 0L, TimeUnit.MILLISECONDS,
+ config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(),
+ new ThreadPoolExecutor.AbortPolicy());
+ }
+
+ private String selectModelFile(LlmLocalClientConfig config) {
+ if ( ! config.localLlmFile().isEmpty()) { // primarily for testing
+ return config.localLlmFile();
+ } else if (config.modelUrl().exists()) {
+ return config.modelUrl().getAbsolutePath();
+ }
+ throw new IllegalArgumentException("Local LLM model not set. " +
+ "Either set 'localLlmFile' or 'modelUrl' in 'llm-local-client' config.");
}
@Override