diff options
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java | 23 |
1 files changed, 17 insertions, 6 deletions
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java index a3b260f3fb5..4db1140d171 100644 --- a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java +++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java @@ -2,6 +2,7 @@ package ai.vespa.llm.clients; import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModelException; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; import ai.vespa.llm.completion.StringPrompt; @@ -96,7 +97,6 @@ public class LocalLLMTest { try { for (int i = 0; i < promptsToUse; i++) { final var seq = i; - completions.set(seq, new StringBuilder()); futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { completions.get(seq).append(completion.text()); @@ -122,8 +122,9 @@ public class LocalLLMTest { var prompts = testPrompts(); var promptsToUse = prompts.size(); var parallelRequests = 2; - var additionalQueue = 1; - // 7 should be rejected + var additionalQueue = 100; + var queueWaitTime = 10; + // 8 should be rejected due to queue wait time var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null)); var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null)); @@ -131,10 +132,12 @@ public class LocalLLMTest { var config = new LlmLocalClientConfig.Builder() .parallelRequests(parallelRequests) .maxQueueSize(additionalQueue) + .maxQueueWait(queueWaitTime) .model(ModelReference.valueOf(model)); var llm = new LocalLLM(config.build()); var rejected = new AtomicInteger(0); + var timedOut = new AtomicInteger(0); try { for (int i = 0; i < promptsToUse; i++) { final var seq = i; @@ -143,7 +146,14 @@ public class LocalLLMTest { try { var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { completions.get(seq).append(completion.text()); - }).exceptionally(exception -> Completion.FinishReason.error); + }).exceptionally(exception -> { + if (exception instanceof LanguageModelException lme) { + if (lme.code() == 504) { + timedOut.incrementAndGet(); + } + } + return Completion.FinishReason.error; + }); futures.set(seq, future); } catch (RejectedExecutionException e) { rejected.incrementAndGet(); @@ -151,13 +161,14 @@ public class LocalLLMTest { } for (int i = 0; i < promptsToUse; i++) { if (futures.get(i) != null) { - assertNotEquals(futures.get(i).join(), Completion.FinishReason.error); + futures.get(i).join(); } } } finally { llm.deconstruct(); } - assertEquals(7, rejected.get()); + assertEquals(0, rejected.get()); + assertEquals(8, timedOut.get()); } private static InferenceParameters defaultOptions() { |