diff options
author | Lester Solbakken <lester.solbakken@gmail.com> | 2024-04-22 11:35:09 +0200 |
---|---|---|
committer | Lester Solbakken <lester.solbakken@gmail.com> | 2024-04-22 11:35:09 +0200 |
commit | 5eb8a49dd81cb79654bc1d113bc81fec9d38745c (patch) | |
tree | 76598ba460cbba4d54a0a18daa75bc95947e96bd | |
parent | 67b83227941ad3327207b5dbbdc9ebbf72f684f6 (diff) |
Specifically set number of threads to use in llama unit test
-rw-r--r-- | model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java | 9 |
1 files changed, 4 insertions, 5 deletions
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java index a3b260f3fb5..95bcfb985bd 100644 --- a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java +++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java @@ -6,7 +6,6 @@ import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; import ai.vespa.llm.completion.StringPrompt; import com.yahoo.config.ModelReference; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import java.util.ArrayList; @@ -33,10 +32,10 @@ public class LocalLLMTest { private static Prompt prompt = StringPrompt.from("A random prompt"); @Test - @Disabled public void testGeneration() { var config = new LlmLocalClientConfig.Builder() .parallelRequests(1) + .threads(1) .model(ModelReference.valueOf(model)); var llm = new LocalLLM(config.build()); @@ -50,12 +49,12 @@ public class LocalLLMTest { } @Test - @Disabled public void testAsyncGeneration() { var sb = new StringBuilder(); var tokenCount = new AtomicInteger(0); var config = new LlmLocalClientConfig.Builder() .parallelRequests(1) + .threads(1) .model(ModelReference.valueOf(model)); var llm = new LocalLLM(config.build()); @@ -78,7 +77,6 @@ public class LocalLLMTest { } @Test - @Disabled public void testParallelGeneration() { var prompts = testPrompts(); var promptsToUse = prompts.size(); @@ -90,6 +88,7 @@ public class LocalLLMTest { var config = new LlmLocalClientConfig.Builder() .parallelRequests(parallelRequests) + .threads(1) .model(ModelReference.valueOf(model)); var llm = new LocalLLM(config.build()); @@ -117,7 +116,6 @@ public class LocalLLMTest { } @Test - @Disabled public void testRejection() { var prompts = testPrompts(); var promptsToUse = prompts.size(); @@ -130,6 +128,7 @@ public class LocalLLMTest { var config = new LlmLocalClientConfig.Builder() .parallelRequests(parallelRequests) + .threads(1) .maxQueueSize(additionalQueue) .model(ModelReference.valueOf(model)); var llm = new LocalLLM(config.build()); |