Diffstat (limited to 'model-integration/src')
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java     | 3 +--
-rw-r--r--  model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java | 9 ++++-----
2 files changed, 5 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
index fd1b8b700c8..aa7c071b93a 100644
--- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
@@ -9,7 +9,6 @@ import com.yahoo.component.AbstractComponent;
 import com.yahoo.component.annotation.Inject;
 import de.kherud.llama.LlamaModel;
 import de.kherud.llama.ModelParameters;
-import de.kherud.llama.args.LogFormat;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -43,7 +42,7 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
         maxTokens = config.maxTokens();
 
         // Only used if GPU is not used
-        var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2;
+        var defaultThreadCount = Math.max(Runtime.getRuntime().availableProcessors() - 2, 1);
 
         var modelFile = config.model().toFile().getAbsolutePath();
         var modelParams = new ModelParameters()
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
index a3b260f3fb5..95bcfb985bd 100644
--- a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
@@ -6,7 +6,6 @@ import ai.vespa.llm.completion.Completion;
 import ai.vespa.llm.completion.Prompt;
 import ai.vespa.llm.completion.StringPrompt;
 import com.yahoo.config.ModelReference;
-import org.junit.jupiter.api.Disabled;
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
@@ -33,10 +32,10 @@ public class LocalLLMTest {
     private static Prompt prompt = StringPrompt.from("A random prompt");
 
     @Test
-    @Disabled
     public void testGeneration() {
         var config = new LlmLocalClientConfig.Builder()
                 .parallelRequests(1)
+                .threads(1)
                 .model(ModelReference.valueOf(model));
         var llm = new LocalLLM(config.build());
 
@@ -50,12 +49,12 @@ public class LocalLLMTest {
     }
 
     @Test
-    @Disabled
     public void testAsyncGeneration() {
         var sb = new StringBuilder();
         var tokenCount = new AtomicInteger(0);
         var config = new LlmLocalClientConfig.Builder()
                 .parallelRequests(1)
+                .threads(1)
                 .model(ModelReference.valueOf(model));
         var llm = new LocalLLM(config.build());
 
@@ -78,7 +77,6 @@ public class LocalLLMTest {
     }
 
     @Test
-    @Disabled
     public void testParallelGeneration() {
         var prompts = testPrompts();
         var promptsToUse = prompts.size();
@@ -90,6 +88,7 @@ public class LocalLLMTest {
 
         var config = new LlmLocalClientConfig.Builder()
                 .parallelRequests(parallelRequests)
+                .threads(1)
                 .model(ModelReference.valueOf(model));
         var llm = new LocalLLM(config.build());
 
@@ -117,7 +116,6 @@ public class LocalLLMTest {
     }
 
     @Test
-    @Disabled
     public void testRejection() {
         var prompts = testPrompts();
         var promptsToUse = prompts.size();
@@ -130,6 +128,7 @@ public class LocalLLMTest {
 
         var config = new LlmLocalClientConfig.Builder()
                 .parallelRequests(parallelRequests)
+                .threads(1)
                 .maxQueueSize(additionalQueue)
                 .model(ModelReference.valueOf(model));
         var llm = new LocalLLM(config.build());
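
The functional change worth noting is the clamp on the default thread count: on a host with two or fewer cores, Runtime.getRuntime().availableProcessors() - 2 evaluates to zero or a negative number, which is not a usable thread count. Math.max(..., 1) guarantees at least one inference thread. A minimal standalone sketch of the before/after arithmetic (the class name and printout are illustrative, not part of the commit):

public class ThreadCountDemo {
    public static void main(String[] args) {
        int cores = Runtime.getRuntime().availableProcessors();
        // Pre-patch default: 0 on a 2-core machine, -1 on a single core.
        int oldDefault = cores - 2;
        // Post-patch default: clamped so at least one thread is configured.
        int newDefault = Math.max(cores - 2, 1);
        System.out.println("cores=" + cores + " old=" + oldDefault + " new=" + newDefault);
    }
}

The test-side edits complement this: removing @Disabled re-enables the four tests, and pinning .threads(1) presumably keeps them cheap and predictable on small CI machines where the old default could have been invalid.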