From 18568733247f2e4d0f603416061e1002a83e9317 Mon Sep 17 00:00:00 2001
From: Lester Solbakken
Date: Mon, 6 May 2024 13:10:21 +0200
Subject: Add timeout for requests waiting to start local llm inference

---
 .../main/java/ai/vespa/llm/clients/LocalLLM.java   | 38 +++++++++++++++++++---
 .../configdefinitions/llm-local-client.def         |  7 ++--
 2 files changed, 38 insertions(+), 7 deletions(-)

(limited to 'model-integration/src/main')

diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
index aa7c071b93a..b6409b5466d 100644
--- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
@@ -3,6 +3,7 @@ package ai.vespa.llm.clients;
 
 import ai.vespa.llm.InferenceParameters;
 import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LanguageModelException;
 import ai.vespa.llm.completion.Completion;
 import ai.vespa.llm.completion.Prompt;
 import com.yahoo.component.AbstractComponent;
@@ -14,10 +15,14 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.SynchronousQueue;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
 import java.util.logging.Logger;
 
@@ -29,14 +34,19 @@ import java.util.logging.Logger;
 public class LocalLLM extends AbstractComponent implements LanguageModel {
 
     private final static Logger logger = Logger.getLogger(LocalLLM.class.getName());
+
+    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
+
     private final LlamaModel model;
     private final ThreadPoolExecutor executor;
+    private final long queueTimeoutMilliseconds;
     private final int contextSize;
     private final int maxTokens;
 
     @Inject
     public LocalLLM(LlmLocalClientConfig config) {
         executor = createExecutor(config);
+        queueTimeoutMilliseconds = config.maxQueueWait();
 
         // Maximum number of tokens to generate - need this since some models can just generate infinitely
         maxTokens = config.maxTokens();
@@ -74,6 +84,7 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
         logger.info("Closing LLM model...");
         model.close();
         executor.shutdownNow();
+        scheduler.shutdownNow();
     }
 
     @Override
@@ -104,22 +115,39 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
         // Todo: more options?
 
         var completionFuture = new CompletableFuture<Completion.FinishReason>();
+        var hasStarted = new AtomicBoolean(false);
         try {
-            executor.submit(() -> {
+            Future<?> future = executor.submit(() -> {
+                hasStarted.set(true);
                 for (LlamaModel.Output output : model.generate(inferParams)) {
                     consumer.accept(Completion.from(output.text, Completion.FinishReason.none));
                 }
                 completionFuture.complete(Completion.FinishReason.stop);
             });
+
+            if (queueTimeoutMilliseconds > 0) {
+                scheduler.schedule(() -> {
+                    if ( ! hasStarted.get()) {
+                        future.cancel(false);
+                        String error = rejectedExecutionReason("Rejected completion due to timeout waiting to start");
+                        completionFuture.completeExceptionally(new LanguageModelException(504, error));
+                    }
+                }, queueTimeoutMilliseconds, TimeUnit.MILLISECONDS);
+            }
+
         } catch (RejectedExecutionException e) {
             // If we have too many requests (active + any waiting in queue), we reject the completion
-            int activeCount = executor.getActiveCount();
-            int queueSize = executor.getQueue().size();
-            String error = String.format("Rejected completion due to too many requests, " +
-                    "%d active, %d in queue", activeCount, queueSize);
+            String error = rejectedExecutionReason("Rejected completion due to too many requests");
            throw new RejectedExecutionException(error);
         }
         return completionFuture;
     }
 
+    private String rejectedExecutionReason(String prepend) {
+        int activeCount = executor.getActiveCount();
+        int queueSize = executor.getQueue().size();
+        return String.format("%s, %d active, %d in queue", prepend, activeCount, queueSize);
+    }
+
+
 }
diff --git a/model-integration/src/main/resources/configdefinitions/llm-local-client.def b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
index 4823a53ec46..6b83ffd0751 100755
--- a/model-integration/src/main/resources/configdefinitions/llm-local-client.def
+++ b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
@@ -8,7 +8,10 @@ model model
 parallelRequests int default=1
 
 # Additional number of requests to put in queue for processing before starting to reject new requests
-maxQueueSize int default=10
+maxQueueSize int default=100
+
+# Max number of milliseconds to wait in the queue before rejecting a request
+maxQueueWait int default=10000
 
 # Use GPU
 useGpu bool default=true
@@ -24,6 +27,6 @@ threads int default=-1
 
 # Context is divided between parallel requests. So for 10 parallel requests, each "slot" gets 1/10 of the context
 contextSize int default=4096
 
-# Maximum number of tokens to process in one request - overriden by inference parameters
+# Maximum number of tokens to process in one request - overridden by inference parameters
 maxTokens int default=512
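
The essence of the change is a queue-wait watchdog: the completion task is submitted to the bounded ThreadPoolExecutor as before, and a single-threaded ScheduledExecutorService cancels it and fails the returned future with a 504-style LanguageModelException if no worker has picked it up within maxQueueWait milliseconds. The following is a minimal, self-contained sketch of that pattern using only the JDK; the class QueueWaitTimeoutSketch, its run method, the pool sizes, and the TimeoutException are illustrative stand-ins, not names from the Vespa code base.

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Sketch of the queue-wait timeout pattern: a task is queued on a bounded
 * executor, and a scheduler cancels it if it has not started within a deadline.
 */
public class QueueWaitTimeoutSketch {

    // Bounded pool: 1 worker, up to 100 queued tasks (mirrors parallelRequests/maxQueueSize).
    // When the queue is full, submit() throws RejectedExecutionException, as in the patch.
    private final ThreadPoolExecutor executor = new ThreadPoolExecutor(
            1, 1, 0L, TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(100));

    // Watchdog thread that enforces the queue-wait deadline (mirrors maxQueueWait)
    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();

    public CompletableFuture<String> run(Runnable work, long queueTimeoutMillis) {
        var result = new CompletableFuture<String>();
        var hasStarted = new AtomicBoolean(false);

        // Submit the work; it may sit in the executor queue before a worker is free.
        Future<?> future = executor.submit(() -> {
            hasStarted.set(true);   // flips as soon as a worker picks the task up
            work.run();
            result.complete("done");
        });

        // After the deadline, cancel the task only if it is still waiting in the queue.
        if (queueTimeoutMillis > 0) {
            scheduler.schedule(() -> {
                if (!hasStarted.get()) {
                    future.cancel(false);   // does not interrupt work that already started
                    result.completeExceptionally(
                            new TimeoutException("Timed out waiting to start"));
                }
            }, queueTimeoutMillis, TimeUnit.MILLISECONDS);
        }
        return result;
    }

    public void shutdown() {
        executor.shutdownNow();
        scheduler.shutdownNow();
    }
}

Because the cancel uses cancel(false), work that has already begun runs to completion; the hasStarted flag is what distinguishes "still queued" from "in progress". If the task happens to start in the instant between the flag check and the cancel, the cancel is a no-op and the caller still receives the timeout, a small race the patch appears to accept as well.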