Diffstat (limited to 'model-integration/src')
4 files changed, 60 insertions, 16 deletions
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
index aa7c071b93a..bbb82db7139 100644
--- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
@@ -3,6 +3,7 @@ package ai.vespa.llm.clients;
 
 import ai.vespa.llm.InferenceParameters;
 import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LanguageModelException;
 import ai.vespa.llm.completion.Completion;
 import ai.vespa.llm.completion.Prompt;
 import com.yahoo.component.AbstractComponent;
@@ -14,10 +15,14 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.SynchronousQueue;
 import java.util.concurrent.ThreadPoolExecutor;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.function.Consumer;
 import java.util.logging.Logger;
 
@@ -29,14 +34,19 @@ import java.util.logging.Logger;
 public class LocalLLM extends AbstractComponent implements LanguageModel {
 
     private final static Logger logger = Logger.getLogger(LocalLLM.class.getName());
+
+    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();
+
     private final LlamaModel model;
     private final ThreadPoolExecutor executor;
+    private final long queueTimeoutMilliseconds;
     private final int contextSize;
     private final int maxTokens;
 
     @Inject
     public LocalLLM(LlmLocalClientConfig config) {
         executor = createExecutor(config);
+        queueTimeoutMilliseconds = config.maxQueueWait();
 
         // Maximum number of tokens to generate - need this since some models can just generate infinitely
         maxTokens = config.maxTokens();
@@ -74,6 +84,7 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
         logger.info("Closing LLM model...");
         model.close();
         executor.shutdownNow();
+        scheduler.shutdownNow();
     }
 
     @Override
@@ -103,23 +114,42 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
         options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v)));
         // Todo: more options?
 
+        inferParams.setUseChatTemplate(true);
+
         var completionFuture = new CompletableFuture<Completion.FinishReason>();
+        var hasStarted = new AtomicBoolean(false);
         try {
-            executor.submit(() -> {
-                for (LlamaModel.Output output : model.generate(inferParams)) {
+            Future<?> future = executor.submit(() -> {
+                hasStarted.set(true);
+                for (var output : model.generate(inferParams)) {
                     consumer.accept(Completion.from(output.text, Completion.FinishReason.none));
                 }
                 completionFuture.complete(Completion.FinishReason.stop);
            });
+
+            if (queueTimeoutMilliseconds > 0) {
+                scheduler.schedule(() -> {
+                    if ( ! hasStarted.get()) {
+                        future.cancel(false);
+                        String error = rejectedExecutionReason("Rejected completion due to timeout waiting to start");
+                        completionFuture.completeExceptionally(new LanguageModelException(504, error));
+                    }
+                }, queueTimeoutMilliseconds, TimeUnit.MILLISECONDS);
+            }
+
        } catch (RejectedExecutionException e) {
             // If we have too many requests (active + any waiting in queue), we reject the completion
-            int activeCount = executor.getActiveCount();
-            int queueSize = executor.getQueue().size();
-            String error = String.format("Rejected completion due to too many requests, " +
-                    "%d active, %d in queue", activeCount, queueSize);
+            String error = rejectedExecutionReason("Rejected completion due to too many requests");
             throw new RejectedExecutionException(error);
         }
         return completionFuture;
     }
 
+    private String rejectedExecutionReason(String prepend) {
+        int activeCount = executor.getActiveCount();
+        int queueSize = executor.getQueue().size();
+        return String.format("%s, %d active, %d in queue", prepend, activeCount, queueSize);
+    }
+
 }
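The core of this change is the watchdog: a generation task records that it has started, and a scheduled check fails the completion with a 504 if the task is still waiting in the executor queue when maxQueueWait expires. Below is a minimal, standalone sketch of that pattern using only JDK types; QueueTimeoutDemo, the sleep helper and the plain RuntimeException are illustrative stand-ins, not Vespa API (LocalLLM signals the timeout with LanguageModelException(504, ...) instead).

import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

public class QueueTimeoutDemo {

    public static void main(String[] args) {
        // One worker and a small queue, analogous to parallelRequests=1 / maxQueueSize=10
        var executor = new ThreadPoolExecutor(1, 1, 0, TimeUnit.SECONDS, new ArrayBlockingQueue<>(10));
        var scheduler = Executors.newSingleThreadScheduledExecutor();
        long queueTimeoutMilliseconds = 100;

        // Occupy the only worker thread so the next task has to wait in the queue
        executor.submit(() -> sleep(1000));

        var completionFuture = new CompletableFuture<String>();
        var hasStarted = new AtomicBoolean(false);

        // The queued task flips hasStarted as its first action, like the generation task in the patch
        Future<?> future = executor.submit(() -> {
            hasStarted.set(true);
            completionFuture.complete("completed");
        });

        // Watchdog: if the task is still queued when the timeout fires, cancel it without
        // interrupting and fail the future the caller is holding on to
        scheduler.schedule(() -> {
            if ( ! hasStarted.get()) {
                future.cancel(false);
                completionFuture.completeExceptionally(new RuntimeException("timed out waiting to start"));
            }
        }, queueTimeoutMilliseconds, TimeUnit.MILLISECONDS);

        // The worker stays busy longer than the timeout, so the watchdog path fires here
        System.out.println(completionFuture.handle((ok, err) -> err == null ? ok : err.getMessage()).join());

        executor.shutdownNow();
        scheduler.shutdownNow();
    }

    private static void sleep(long ms) {
        try { Thread.sleep(ms); } catch (InterruptedException ignored) { Thread.currentThread().interrupt(); }
    }
}

Passing false to cancel() matters: a task that has already started is left alone, and only tasks still waiting in the queue are dropped.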
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
index e1d2f8802a6..6a1e2f2562a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMImporter.java
@@ -34,9 +34,9 @@ public class LightGBMImporter extends ModelImporter {
     private boolean probe(File modelFile) {
         try (JsonParser parser = Jackson.mapper().createParser(modelFile)) {
             while (parser.nextToken() != null) {
-                JsonToken token = parser.getCurrentToken();
+                JsonToken token = parser.currentToken();
                 if (token == JsonToken.FIELD_NAME) {
-                    if ("tree_info".equals(parser.getCurrentName())) return true;
+                    if ("tree_info".equals(parser.currentName())) return true;
                 }
             }
             return false;
diff --git a/model-integration/src/main/resources/configdefinitions/llm-local-client.def b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
index 4823a53ec46..6b83ffd0751 100755
--- a/model-integration/src/main/resources/configdefinitions/llm-local-client.def
+++ b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
@@ -8,7 +8,10 @@ model model
 parallelRequests int default=1
 
 # Additional number of requests to put in queue for processing before starting to reject new requests
-maxQueueSize int default=10
+maxQueueSize int default=100
+
+# Max number of milliseconds to wait in the queue before rejecting a request
+maxQueueWait int default=10000
 
 # Use GPU
 useGpu bool default=true
@@ -24,6 +27,6 @@ threads int default=-1
 
 # Context is divided between parallel requests. So for 10 parallel requests, each "slot" gets 1/10 of the context
 contextSize int default=4096
 
-# Maximum number of tokens to process in one request - overriden by inference parameters
+# Maximum number of tokens to process in one request - overridden by inference parameters
 maxTokens int default=512
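The new maxQueueWait setting (milliseconds) complements maxQueueSize: requests beyond parallelRequests may now wait in a much larger queue, but give up after the configured time instead of being rejected outright. A minimal sketch of wiring these values programmatically, following the builder usage in LocalLLMTest; the package declaration, import path and model path are assumptions for illustration, not confirmed by this diff.

package ai.vespa.llm.clients;

import com.yahoo.config.ModelReference;

public class LocalLlmConfigExample {

    public static void main(String[] args) {
        // The model path below is a placeholder; point it at any local GGUF model
        var config = new LlmLocalClientConfig.Builder()
                .parallelRequests(2)      // concurrent generation slots
                .maxQueueSize(100)        // additional requests allowed to wait in the queue
                .maxQueueWait(10_000)     // max milliseconds to wait in the queue before failing with 504
                .model(ModelReference.valueOf("/path/to/model.gguf"));
        var llm = new LocalLLM(config.build());
        // ... use llm.completeAsync(...) ...
        llm.deconstruct();
    }
}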
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
index a3b260f3fb5..4db1140d171 100644
--- a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
@@ -2,6 +2,7 @@
 package ai.vespa.llm.clients;
 
 import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModelException;
 import ai.vespa.llm.completion.Completion;
 import ai.vespa.llm.completion.Prompt;
 import ai.vespa.llm.completion.StringPrompt;
@@ -96,7 +97,6 @@ public class LocalLLMTest {
         try {
             for (int i = 0; i < promptsToUse; i++) {
                 final var seq = i;
-                completions.set(seq, new StringBuilder());
 
                 futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
                     completions.get(seq).append(completion.text());
@@ -122,8 +122,9 @@
         var prompts = testPrompts();
         var promptsToUse = prompts.size();
         var parallelRequests = 2;
-        var additionalQueue = 1;
-        // 7 should be rejected
+        var additionalQueue = 100;
+        var queueWaitTime = 10;
+        // 8 should be rejected due to queue wait time
 
         var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null));
         var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null));
@@ -131,10 +132,12 @@
         var config = new LlmLocalClientConfig.Builder()
                 .parallelRequests(parallelRequests)
                 .maxQueueSize(additionalQueue)
+                .maxQueueWait(queueWaitTime)
                 .model(ModelReference.valueOf(model));
         var llm = new LocalLLM(config.build());
 
         var rejected = new AtomicInteger(0);
+        var timedOut = new AtomicInteger(0);
         try {
             for (int i = 0; i < promptsToUse; i++) {
                 final var seq = i;
@@ -143,7 +146,14 @@
                 try {
                     var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
                         completions.get(seq).append(completion.text());
-                    }).exceptionally(exception -> Completion.FinishReason.error);
+                    }).exceptionally(exception -> {
+                        if (exception instanceof LanguageModelException lme) {
+                            if (lme.code() == 504) {
+                                timedOut.incrementAndGet();
+                            }
+                        }
+                        return Completion.FinishReason.error;
+                    });
                     futures.set(seq, future);
                 } catch (RejectedExecutionException e) {
                     rejected.incrementAndGet();
@@ -151,13 +161,14 @@
             }
             for (int i = 0; i < promptsToUse; i++) {
                 if (futures.get(i) != null) {
-                    assertNotEquals(futures.get(i).join(), Completion.FinishReason.error);
+                    futures.get(i).join();
                }
            }
        } finally {
            llm.deconstruct();
        }
-        assertEquals(7, rejected.get());
+        assertEquals(0, rejected.get());
+        assertEquals(8, timedOut.get());
    }
 
    private static InferenceParameters defaultOptions() {
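After this change a caller sees two distinct overload signals: a RejectedExecutionException thrown synchronously when the queue itself is full, and a LanguageModelException with code 504 delivered through the returned future when the request times out while waiting in the queue. The sketch below shows one way to handle both, mirroring the exception handling added to the test; the class name, method shape and log messages are illustrative, not part of this patch.

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.LanguageModelException;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.StringPrompt;
import java.util.concurrent.RejectedExecutionException;

public class OverloadHandlingExample {

    // Returns the finish reason, reporting which of the two overload signals was seen, if any
    static Completion.FinishReason complete(LanguageModel llm, String prompt, InferenceParameters options) {
        try {
            return llm.completeAsync(StringPrompt.from(prompt), options,
                                     completion -> System.out.print(completion.text()))
                      .exceptionally(exception -> {
                          // Queue-wait timeout arrives asynchronously as a 504
                          if (exception instanceof LanguageModelException lme && lme.code() == 504)
                              System.err.println("Timed out waiting in queue: " + lme.getMessage());
                          return Completion.FinishReason.error;
                      })
                      .join();
        } catch (RejectedExecutionException e) {
            // Thrown synchronously when active requests plus the queue exceed capacity
            System.err.println("Queue full: " + e.getMessage());
            return Completion.FinishReason.error;
        }
    }
}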