Diffstat (limited to 'model-integration/src')
 model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java     | 3 +--
 model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java | 9 ++++-----
 2 files changed, 5 insertions(+), 7 deletions(-)
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
index fd1b8b700c8..aa7c071b93a 100644
--- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
@@ -9,7 +9,6 @@ import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import de.kherud.llama.LlamaModel;
import de.kherud.llama.ModelParameters;
-import de.kherud.llama.args.LogFormat;
import java.util.ArrayList;
import java.util.List;
@@ -43,7 +42,7 @@ public class LocalLLM extends AbstractComponent implements LanguageModel {
maxTokens = config.maxTokens();
// Only used if GPU is not used
- var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2;
+ var defaultThreadCount = Math.max(Runtime.getRuntime().availableProcessors() - 2, 1);
var modelFile = config.model().toFile().getAbsolutePath();
var modelParams = new ModelParameters()
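Why the clamp above matters: on a host with only one or two available processors, availableProcessors() - 2 evaluates to zero or a negative number, which is not a usable thread count for the native inference backend. A standalone sketch of the clamped default (illustration only, not part of the patch):

    // Illustrates the defaultThreadCount computation after this change.
    public class DefaultThreads {
        public static void main(String[] args) {
            int cores = Runtime.getRuntime().availableProcessors(); // e.g. 1 on a small CI runner
            int before = cores - 2;               // 1 core -> -1: invalid thread count
            int after = Math.max(cores - 2, 1);   // 1 core ->  1: always at least one thread
            System.out.println("before=" + before + ", after=" + after);
        }
    }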
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
index a3b260f3fb5..95bcfb985bd 100644
--- a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
@@ -6,7 +6,6 @@ import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.config.ModelReference;
-import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
@@ -33,10 +32,10 @@ public class LocalLLMTest {
private static Prompt prompt = StringPrompt.from("A random prompt");
@Test
- @Disabled
public void testGeneration() {
var config = new LlmLocalClientConfig.Builder()
.parallelRequests(1)
+ .threads(1)
.model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());
@@ -50,12 +49,12 @@ public class LocalLLMTest {
}
@Test
- @Disabled
public void testAsyncGeneration() {
var sb = new StringBuilder();
var tokenCount = new AtomicInteger(0);
var config = new LlmLocalClientConfig.Builder()
.parallelRequests(1)
+ .threads(1)
.model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());
@@ -78,7 +77,6 @@ public class LocalLLMTest {
}
@Test
- @Disabled
public void testParallelGeneration() {
var prompts = testPrompts();
var promptsToUse = prompts.size();
@@ -90,6 +88,7 @@ public class LocalLLMTest {
var config = new LlmLocalClientConfig.Builder()
.parallelRequests(parallelRequests)
+ .threads(1)
.model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());
@@ -117,7 +116,6 @@ public class LocalLLMTest {
}
@Test
- @Disabled
public void testRejection() {
var prompts = testPrompts();
var promptsToUse = prompts.size();
@@ -130,6 +128,7 @@ public class LocalLLMTest {
var config = new LlmLocalClientConfig.Builder()
.parallelRequests(parallelRequests)
+ .threads(1)
.maxQueueSize(additionalQueue)
.model(ModelReference.valueOf(model));
var llm = new LocalLLM(config.build());
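Taken together, the test changes re-enable the four previously @Disabled tests and pin each LocalLLM under test to a single inference thread, so the suite runs the same way regardless of the runner's core count. A minimal sketch of the construction pattern the tests now share (assuming it lives in the same package as the test so LocalLLM and the generated LlmLocalClientConfig need no imports, and that modelPath points at the small test model):

    // Sketch only: the setup pattern shared by the re-enabled tests.
    package ai.vespa.llm.clients;

    import com.yahoo.config.ModelReference;

    class LocalLlmTestSetup {
        static LocalLLM singleThreadedLlm(String modelPath) {
            var config = new LlmLocalClientConfig.Builder()
                    .parallelRequests(1)   // serialize requests, as in the simplest tests
                    .threads(1)            // pinned thread count, independent of the host
                    .model(ModelReference.valueOf(modelPath));
            return new LocalLLM(config.build());
        }
    }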