Diffstat (limited to 'model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java')
-rw-r--r-- model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java | 23 +++++++++++++++++------
1 file changed, 17 insertions(+), 6 deletions(-)
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
index a3b260f3fb5..4db1140d171 100644
--- a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
@@ -2,6 +2,7 @@
 package ai.vespa.llm.clients;
 
 import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModelException;
 import ai.vespa.llm.completion.Completion;
 import ai.vespa.llm.completion.Prompt;
 import ai.vespa.llm.completion.StringPrompt;
@@ -96,7 +97,6 @@ public class LocalLLMTest {
         try {
             for (int i = 0; i < promptsToUse; i++) {
                 final var seq = i;
-
                 completions.set(seq, new StringBuilder());
                 futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
                     completions.get(seq).append(completion.text());
@@ -122,8 +122,9 @@
         var prompts = testPrompts();
         var promptsToUse = prompts.size();
         var parallelRequests = 2;
-        var additionalQueue = 1;
-        // 7 should be rejected
+        var additionalQueue = 100;
+        var queueWaitTime = 10;
+        // 8 should be rejected due to queue wait time
 
         var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null));
         var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null));
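Note: the arithmetic behind the new expectations. The old comment ("7 should be rejected") together with parallelRequests = 2 and the old queue of 1 implies the test uses ten prompts, of which only three could be accepted and seven were rejected synchronously. With the queue widened to 100 and a bounded queue wait, nothing is rejected up front; instead every prompt beyond the two parallel slots is expected to time out. An accounting sketch (illustrative only, assuming testPrompts() still returns ten prompts):

    // Illustrative accounting; assumes testPrompts() returns 10 prompts.
    int total = 10;
    int oldAccepted = 2 + 1;                // parallelRequests + old maxQueueSize
    int oldRejected = total - oldAccepted;  // 7, the old expectation
    int newRejected = 0;                    // a queue of 100 never fills up here
    int newTimedOut = total - 2;            // 8: all prompts beyond the parallel slots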
@@ -131,10 +132,12 @@
         var config = new LlmLocalClientConfig.Builder()
                 .parallelRequests(parallelRequests)
                 .maxQueueSize(additionalQueue)
+                .maxQueueWait(queueWaitTime)
                 .model(ModelReference.valueOf(model));
         var llm = new LocalLLM(config.build());
 
         var rejected = new AtomicInteger(0);
+        var timedOut = new AtomicInteger(0);
         try {
             for (int i = 0; i < promptsToUse; i++) {
                 final var seq = i;
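Note: a minimal, self-contained sketch of the admission behaviour that maxQueueWait appears to introduce. This is an illustration, not the actual LocalLLM internals, and the time unit of queueWaitTime is an assumption here: requests beyond parallelRequests wait for a worker slot and give up once the wait bound elapses, rather than being rejected immediately.

    import java.util.concurrent.Semaphore;
    import java.util.concurrent.TimeUnit;

    class QueueWaitSketch {
        private final Semaphore workers = new Semaphore(2);  // parallelRequests

        // Returns false when no worker slot frees up within the bound; the test
        // expects LocalLLM to surface that as a LanguageModelException with code 504.
        boolean tryAdmit(long maxQueueWait, TimeUnit unit) throws InterruptedException {
            return workers.tryAcquire(maxQueueWait, unit);
        }

        public static void main(String[] args) throws InterruptedException {
            var sketch = new QueueWaitSketch();
            System.out.println("admitted: " + sketch.tryAdmit(10, TimeUnit.SECONDS));
        }
    }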
@@ -143,7 +146,14 @@
                 try {
                     var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
                         completions.get(seq).append(completion.text());
-                    }).exceptionally(exception -> Completion.FinishReason.error);
+                    }).exceptionally(exception -> {
+                        if (exception instanceof LanguageModelException lme) {
+                            if (lme.code() == 504) {
+                                timedOut.incrementAndGet();
+                            }
+                        }
+                        return Completion.FinishReason.error;
+                    });
                     futures.set(seq, future);
                 } catch (RejectedExecutionException e) {
                     rejected.incrementAndGet();
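Note: the bare instanceof check is safe here because exceptionally(...) is attached directly to the future that LocalLLM completes, so the throwable arrives unwrapped. Had another stage been interposed, the handler would receive a CompletionException and would have to unwrap the cause first. A self-contained JDK demonstration (hypothetical class name):

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.CompletionException;

    class UnwrapSketch {
        public static void main(String[] args) {
            var future = new CompletableFuture<String>();
            future.completeExceptionally(new IllegalStateException("boom"));

            // Attached directly: the handler sees IllegalStateException itself.
            System.out.println(future.exceptionally(ex -> ex.getClass().getSimpleName()).join());

            // After an interposed stage the handler sees a CompletionException,
            // so an instanceof check must inspect getCause() instead.
            System.out.println(future.thenApply(s -> s)
                    .exceptionally(ex -> ex instanceof CompletionException ce
                            ? ce.getCause().getClass().getSimpleName()
                            : ex.getClass().getSimpleName())
                    .join());
        }
    }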
@@ -151,13 +161,14 @@
             }
             for (int i = 0; i < promptsToUse; i++) {
                 if (futures.get(i) != null) {
-                    assertNotEquals(futures.get(i).join(), Completion.FinishReason.error);
+                    futures.get(i).join();
                 }
             }
         } finally {
             llm.deconstruct();
         }
-        assertEquals(7, rejected.get());
+        assertEquals(0, rejected.get());
+        assertEquals(8, timedOut.get());
     }
 
     private static InferenceParameters defaultOptions() {
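Note: the assertion change at the end follows from the same shift. join() can no longer be asserted to return a non-error finish reason, since the eight timed-out requests are expected to finish with Completion.FinishReason.error; the loop now just drains the futures, and the accounting happens in the exceptionally handler above: zero synchronous rejections, eight asynchronous timeouts.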