summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lester.solbakken@gmail.com>2024-04-22 11:35:09 +0200
committerLester Solbakken <lester.solbakken@gmail.com>2024-04-22 11:35:09 +0200
commit5eb8a49dd81cb79654bc1d113bc81fec9d38745c (patch)
tree76598ba460cbba4d54a0a18daa75bc95947e96bd
parent67b83227941ad3327207b5dbbdc9ebbf72f684f6 (diff)
Specifically set number of threads to use in llama unit test
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java9
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
index a3b260f3fb5..95bcfb985bd 100644
--- a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
@@ -6,7 +6,6 @@ import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.config.ModelReference;
-import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
@@ -33,10 +32,10 @@ public class LocalLLMTest {
private static Prompt prompt = StringPrompt.from("A random prompt");
@Test
- @Disabled
public void testGeneration() {
var config = new LlmLocalClientConfig.Builder()
.parallelRequests(1)
+ .threads(1)
.model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());
@@ -50,12 +49,12 @@ public class LocalLLMTest {
}
@Test
- @Disabled
public void testAsyncGeneration() {
var sb = new StringBuilder();
var tokenCount = new AtomicInteger(0);
var config = new LlmLocalClientConfig.Builder()
.parallelRequests(1)
+ .threads(1)
.model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());
@@ -78,7 +77,6 @@ public class LocalLLMTest {
}
@Test
- @Disabled
public void testParallelGeneration() {
var prompts = testPrompts();
var promptsToUse = prompts.size();
@@ -90,6 +88,7 @@ public class LocalLLMTest {
var config = new LlmLocalClientConfig.Builder()
.parallelRequests(parallelRequests)
+ .threads(1)
.model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());
@@ -117,7 +116,6 @@ public class LocalLLMTest {
}
@Test
- @Disabled
public void testRejection() {
var prompts = testPrompts();
var promptsToUse = prompts.size();
@@ -130,6 +128,7 @@ public class LocalLLMTest {
var config = new LlmLocalClientConfig.Builder()
.parallelRequests(parallelRequests)
+ .threads(1)
.maxQueueSize(additionalQueue)
.model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());