author    Lester Solbakken <lester.solbakken@gmail.com>  2024-04-16 13:07:31 +0200
committer Lester Solbakken <lester.solbakken@gmail.com>  2024-04-16 13:07:31 +0200
commit    780bc7cbe8fb67ae712fcf278f8900c8f32e14a6 (patch)
tree      d37ae6302fb72f5b6f5e33a8a45d968ffc505fdf /model-integration
parent    fca990d5ed32c408df42bbe178b174711fa54a08 (diff)
Reapply "Lesters/add local llms 2"
This reverts commit ed62b750494822cc67a328390178754512baf032.
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/abi-spec.json | 182
-rw-r--r--  model-integration/pom.xml | 12
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java | 75
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java | 126
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java | 48
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/clients/package-info.java | 7
-rwxr-xr-x  model-integration/src/main/resources/configdefinitions/llm-client.def | 8
-rwxr-xr-x  model-integration/src/main/resources/configdefinitions/llm-local-client.def | 29
-rw-r--r--  model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java | 174
-rw-r--r--  model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java | 181
-rw-r--r--  model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java | 80
-rw-r--r--  model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java | 35
-rw-r--r--  model-integration/src/test/models/llm/tinyllm.gguf | bin 0 -> 1185376 bytes
13 files changed, 957 insertions, 0 deletions
diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json
index d3c472778e6..e7130d9c777 100644
--- a/model-integration/abi-spec.json
+++ b/model-integration/abi-spec.json
@@ -1,4 +1,186 @@
{
+ "ai.vespa.llm.clients.ConfigurableLanguageModel" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "ai.vespa.llm.LanguageModel"
+ ],
+ "attributes" : [
+ "public",
+ "abstract"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
+ "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)",
+ "protected void setApiKey(ai.vespa.llm.InferenceParameters)",
+ "protected java.lang.String getEndpoint()",
+ "protected void setEndpoint(ai.vespa.llm.InferenceParameters)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmClientConfig$Builder" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Builder"
+ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig)",
+ "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)",
+ "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)",
+ "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
+ "public final java.lang.String getDefMd5()",
+ "public final java.lang.String getDefName()",
+ "public final java.lang.String getDefNamespace()",
+ "public final boolean getApplyOnRestart()",
+ "public final void setApplyOnRestart(boolean)",
+ "public ai.vespa.llm.clients.LlmClientConfig build()"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmClientConfig$Producer" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Producer"
+ ],
+ "attributes" : [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods" : [
+ "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmClientConfig" : {
+ "superClass" : "com.yahoo.config.ConfigInstance",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public static java.lang.String getDefMd5()",
+ "public static java.lang.String getDefName()",
+ "public static java.lang.String getDefNamespace()",
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)",
+ "public java.lang.String apiKeySecretName()",
+ "public java.lang.String endpoint()"
+ ],
+ "fields" : [
+ "public static final java.lang.String CONFIG_DEF_MD5",
+ "public static final java.lang.String CONFIG_DEF_NAME",
+ "public static final java.lang.String CONFIG_DEF_NAMESPACE",
+ "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
+ ]
+ },
+ "ai.vespa.llm.clients.LlmLocalClientConfig$Builder" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Builder"
+ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder model(com.yahoo.config.ModelReference)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder parallelRequests(int)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxQueueSize(int)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder useGpu(boolean)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder gpuLayers(int)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder threads(int)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder contextSize(int)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxTokens(int)",
+ "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
+ "public final java.lang.String getDefMd5()",
+ "public final java.lang.String getDefName()",
+ "public final java.lang.String getDefNamespace()",
+ "public final boolean getApplyOnRestart()",
+ "public final void setApplyOnRestart(boolean)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig build()"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmLocalClientConfig$Producer" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Producer"
+ ],
+ "attributes" : [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods" : [
+ "public abstract void getConfig(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmLocalClientConfig" : {
+ "superClass" : "com.yahoo.config.ConfigInstance",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public static java.lang.String getDefMd5()",
+ "public static java.lang.String getDefName()",
+ "public static java.lang.String getDefNamespace()",
+ "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)",
+ "public java.nio.file.Path model()",
+ "public int parallelRequests()",
+ "public int maxQueueSize()",
+ "public boolean useGpu()",
+ "public int gpuLayers()",
+ "public int threads()",
+ "public int contextSize()",
+ "public int maxTokens()"
+ ],
+ "fields" : [
+ "public static final java.lang.String CONFIG_DEF_MD5",
+ "public static final java.lang.String CONFIG_DEF_NAME",
+ "public static final java.lang.String CONFIG_DEF_NAMESPACE",
+ "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
+ ]
+ },
+ "ai.vespa.llm.clients.LocalLLM" : {
+ "superClass" : "com.yahoo.component.AbstractComponent",
+ "interfaces" : [
+ "ai.vespa.llm.LanguageModel"
+ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig)",
+ "public void deconstruct()",
+ "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.OpenAI" : {
+ "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
+ "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
+ ],
+ "fields" : [ ]
+ },
"ai.vespa.llm.generation.Generator" : {
"superClass" : "com.yahoo.component.AbstractComponent",
"interfaces" : [ ],
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 0bab30e1453..d92fa319251 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -40,6 +40,12 @@
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
+ <artifactId>container-disc</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
<artifactId>searchcore</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
@@ -76,6 +82,12 @@
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
+ <artifactId>container-llama</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
<artifactId>component</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
new file mode 100644
index 00000000000..761fdf0af93
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
@@ -0,0 +1,75 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.logging.Logger;
+
+
+/**
+ * Base class for language models that can be configured with config definitions.
+ *
+ * @author lesters
+ */
+@Beta
+public abstract class ConfigurableLanguageModel implements LanguageModel {
+
+ private static final Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName());
+
+ private final String apiKey;
+ private final String endpoint;
+
+ public ConfigurableLanguageModel() {
+ this.apiKey = null;
+ this.endpoint = null;
+ }
+
+ @Inject
+ public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) {
+ this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore);
+ this.endpoint = config.endpoint();
+ }
+
+ private static String findApiKeyInSecretStore(String property, SecretStore secretStore) {
+ String apiKey = "";
+ if (property != null && ! property.isEmpty()) {
+ try {
+ apiKey = secretStore.getSecret(property);
+ } catch (UnsupportedOperationException e) {
+ // Secret store is not available - silently ignore this
+ } catch (Exception e) {
+ log.warning("Secret store look up failed: " + e.getMessage() + "\n" +
+ "Will expect API key in request header");
+ }
+ }
+ return apiKey;
+ }
+
+ protected String getApiKey(InferenceParameters params) {
+ return params.getApiKey().orElse(null);
+ }
+
+ /**
+ * Set the API key as retrieved from secret store if it is not already set
+ */
+ protected void setApiKey(InferenceParameters params) {
+ if (params.getApiKey().isEmpty() && apiKey != null) {
+ params.setApiKey(apiKey);
+ }
+ }
+
+ protected String getEndpoint() {
+ return endpoint;
+ }
+
+ protected void setEndpoint(InferenceParameters params) {
+ if (endpoint != null && ! endpoint.isEmpty()) {
+ params.setEndpoint(endpoint);
+ }
+ }
+
+}
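
A subclass is expected to call setApiKey and setEndpoint before dispatching a request, so that a key supplied in the request takes precedence over the one resolved from the secret store (setApiKey only fills in the stored key when the request carries none). A minimal sketch of such a subclass follows; the EchoLanguageModel name and its echo behavior are hypothetical, not part of this commit:

    package ai.vespa.llm.clients;

    import ai.vespa.llm.InferenceParameters;
    import ai.vespa.llm.completion.Completion;
    import ai.vespa.llm.completion.Prompt;
    import com.yahoo.component.annotation.Inject;
    import com.yahoo.container.jdisc.secretstore.SecretStore;

    import java.util.List;
    import java.util.concurrent.CompletableFuture;
    import java.util.function.Consumer;

    // Hypothetical subclass: echoes the prompt back as its "completion"
    public class EchoLanguageModel extends ConfigurableLanguageModel {

        @Inject
        public EchoLanguageModel(LlmClientConfig config, SecretStore secretStore) {
            super(config, secretStore);
        }

        @Override
        public List<Completion> complete(Prompt prompt, InferenceParameters params) {
            setApiKey(params);    // use the secret-store key unless the request already has one
            setEndpoint(params);  // apply the configured endpoint, if any
            return List.of(Completion.from(prompt.asString()));
        }

        @Override
        public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
                                                                        InferenceParameters params,
                                                                        Consumer<Completion> consumer) {
            setApiKey(params);
            setEndpoint(params);
            consumer.accept(Completion.from(prompt.asString(), Completion.FinishReason.none));
            return CompletableFuture.completedFuture(Completion.FinishReason.stop);
        }
    }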
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
new file mode 100644
index 00000000000..fd1b8b700c8
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
@@ -0,0 +1,126 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.component.AbstractComponent;
+import com.yahoo.component.annotation.Inject;
+import de.kherud.llama.LlamaModel;
+import de.kherud.llama.ModelParameters;
+import de.kherud.llama.args.LogFormat;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+import java.util.logging.Logger;
+
+/**
+ * A language model running locally on the container node.
+ *
+ * @author lesters
+ */
+public class LocalLLM extends AbstractComponent implements LanguageModel {
+
+ private final static Logger logger = Logger.getLogger(LocalLLM.class.getName());
+ private final LlamaModel model;
+ private final ThreadPoolExecutor executor;
+ private final int contextSize;
+ private final int maxTokens;
+
+ @Inject
+ public LocalLLM(LlmLocalClientConfig config) {
+ executor = createExecutor(config);
+
+ // Maximum number of tokens to generate - needed since some models otherwise generate indefinitely
+ maxTokens = config.maxTokens();
+
+ // Thread count is only used when not running on the GPU
+ var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2;
+
+ var modelFile = config.model().toFile().getAbsolutePath();
+ var modelParams = new ModelParameters()
+ .setModelFilePath(modelFile)
+ .setContinuousBatching(true)
+ .setNParallel(config.parallelRequests())
+ .setNThreads(config.threads() <= 0 ? defaultThreadCount : config.threads())
+ .setNCtx(config.contextSize())
+ .setNGpuLayers(config.useGpu() ? config.gpuLayers() : 0);
+
+ long startLoad = System.nanoTime();
+ model = new LlamaModel(modelParams);
+ long loadTime = System.nanoTime() - startLoad;
+ logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000)));
+
+ // TODO: handle prompt context size - e.g. warn when the prompt exceeds the context size
+ contextSize = config.contextSize();
+ }
+
+ private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) {
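+ // Fixed pool of parallelRequests threads. A positive maxQueueSize buffers
+ // waiting requests in a bounded queue; otherwise a SynchronousQueue makes
+ // requests beyond the active ones fail immediately (AbortPolicy).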
+ return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(),
+ 0L, TimeUnit.MILLISECONDS,
+ config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(),
+ new ThreadPoolExecutor.AbortPolicy());
+ }
+
+ @Override
+ public void deconstruct() {
+ logger.info("Closing LLM model...");
+ model.close();
+ executor.shutdownNow();
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters options) {
+ StringBuilder result = new StringBuilder();
+ var future = completeAsync(prompt, options, completion -> {
+ result.append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+ var reason = future.join();
+
+ List<Completion> completions = new ArrayList<>();
+ completions.add(new Completion(result.toString(), reason));
+ return completions;
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
+ var inferParams = new de.kherud.llama.InferenceParameters(prompt.asString().stripLeading());
+
+ // We always set this to some value to avoid infinite token generation
+ inferParams.setNPredict(maxTokens);
+
+ options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v)));
+ options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v)));
+ options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v)));
+ options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v)));
+ options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v)));
+ // TODO: expose more inference options?
+
+ var completionFuture = new CompletableFuture<Completion.FinishReason>();
+ try {
+ executor.submit(() -> {
+ for (LlamaModel.Output output : model.generate(inferParams)) {
+ consumer.accept(Completion.from(output.text, Completion.FinishReason.none));
+ }
+ completionFuture.complete(Completion.FinishReason.stop);
+ });
+ } catch (RejectedExecutionException e) {
+ // If we have too many requests (active + any waiting in queue), we reject the completion
+ int activeCount = executor.getActiveCount();
+ int queueSize = executor.getQueue().size();
+ String error = String.format("Rejected completion due to too many requests, " +
+ "%d active, %d in queue", activeCount, queueSize);
+ throw new RejectedExecutionException(error);
+ }
+ return completionFuture;
+ }
+
+}
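
Standalone usage mirrors LocalLLMTest below: build an LlmLocalClientConfig, construct the component, and call deconstruct() when done to free the native model. A minimal sketch, reusing the tiny test model added by this commit and assuming the same imports as the test:

    var config = new LlmLocalClientConfig.Builder()
            .parallelRequests(1)
            .model(ModelReference.valueOf("src/test/models/llm/tinyllm.gguf"))
            .build();
    var llm = new LocalLLM(config);
    try {
        var options = new InferenceParameters(Map.of("temperature", "0.1", "npredict", "10")::get);
        for (Completion completion : llm.complete(StringPrompt.from("A random prompt"), options))
            System.out.println(completion.text());
    } finally {
        llm.deconstruct();  // closes the llama.cpp model and shuts down the executor
    }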
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
new file mode 100644
index 00000000000..82e19d47c92
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
@@ -0,0 +1,48 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.client.openai.OpenAiClient;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+
+/**
+ * A configurable OpenAI client.
+ *
+ * @author lesters
+ */
+@Beta
+public class OpenAI extends ConfigurableLanguageModel {
+
+ private final OpenAiClient client;
+
+ @Inject
+ public OpenAI(LlmClientConfig config, SecretStore secretStore) {
+ super(config, secretStore);
+ client = new OpenAiClient();
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
+ setApiKey(parameters);
+ setEndpoint(parameters);
+ return client.complete(prompt, parameters);
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters parameters,
+ Consumer<Completion> consumer) {
+ setApiKey(parameters);
+ setEndpoint(parameters);
+ return client.completeAsync(prompt, parameters, consumer);
+ }
+}
+
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java
new file mode 100644
index 00000000000..c360245901c
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java
@@ -0,0 +1,7 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.clients;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/model-integration/src/main/resources/configdefinitions/llm-client.def b/model-integration/src/main/resources/configdefinitions/llm-client.def
new file mode 100755
index 00000000000..0866459166a
--- /dev/null
+++ b/model-integration/src/main/resources/configdefinitions/llm-client.def
@@ -0,0 +1,8 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm.clients
+
+# The name of the secret containing the API key
+apiKeySecretName string default=""
+
+# Endpoint for the LLM client - if not set, the client's default endpoint is used
+endpoint string default=""
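
These two settings surface on the generated LlmClientConfig builder listed in the abi-spec above. A minimal sketch for building the config by hand, e.g. in a test; the secret name and endpoint are placeholder values:

    var config = new LlmClientConfig.Builder()
            .apiKeySecretName("openai-api-key")              // resolved against the SecretStore at construction
            .endpoint("https://example.com/v1/completions")  // empty means the client's own default
            .build();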
diff --git a/model-integration/src/main/resources/configdefinitions/llm-local-client.def b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
new file mode 100755
index 00000000000..c06c24b33e5
--- /dev/null
+++ b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
@@ -0,0 +1,29 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm.clients
+
+# The LLM model to use
+model model
+
+# Maximum number of requests to handle in parallel per container node
+parallelRequests int default=10
+
+# Number of additional requests to queue for processing before new requests are rejected
+maxQueueSize int default=10
+
+# Use GPU
+useGpu bool default=false
+
+# Maximum number of model layers to run on GPU
+gpuLayers int default=1000000
+
+# Number of threads to use for CPU inference - values <= 0 mean the number of available cores minus 2
+# Not used for GPU processing
+threads int default=-1
+
+# Context size for the model
+# Context is divided between parallel requests. So for 10 parallel requests, each "slot" gets 1/10 of the context
+contextSize int default=512
+
+# Maximum number of tokens to generate in one request - overridden by inference parameters
+maxTokens int default=512
+
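
These settings map onto the LlmLocalClientConfig builder listed in the abi-spec above. Note how contextSize interacts with parallelRequests: with the defaults, 512 context tokens shared by 10 parallel slots leave roughly 51 tokens per request, so the context size usually needs to grow with the request parallelism. A minimal sketch with illustrative values; the model path is a placeholder:

    var config = new LlmLocalClientConfig.Builder()
            .model(ModelReference.valueOf("models/mymodel.gguf"))  // placeholder path
            .parallelRequests(4)
            .maxQueueSize(10)    // beyond 4 active + 10 queued requests, new ones are rejected
            .contextSize(4096)   // about 1024 context tokens per parallel slot
            .maxTokens(256)
            .build();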
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
new file mode 100644
index 00000000000..35d5cfd3855
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
@@ -0,0 +1,174 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.container.di.componentgraph.Provider;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ConfigurableLanguageModelTest {
+
+ @Test
+ public void testSyncGeneration() {
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey());
+ assertEquals(1, result.size());
+ assertEquals("Ducks have adorable waddling walks.", result.get(0).text());
+ }
+
+ @Test
+ public void testAsyncGeneration() {
+ var executor = Executors.newFixedThreadPool(1);
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var sb = new StringBuilder();
+ try {
+ var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> {
+ sb.append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+
+ var reason = future.join();
+ assertTrue(future.isDone());
+ assertNotEquals(reason, Completion.FinishReason.error);
+ } finally {
+ executor.shutdownNow();
+ }
+
+ assertEquals("Ducks have adorable waddling walks.", sb.toString());
+ }
+
+ @Test
+ public void testInferenceParameters() {
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4"));
+ var result = createLLM().complete(prompt, params);
+ assertEquals("Random text about ducks", result.get(0).text());
+ }
+
+ @Test
+ public void testNoApiKey() {
+ var prompt = StringPrompt.from("");
+ var config = modelParams("api-key", null);
+ var secrets = createSecretStore(Map.of());
+ assertThrows(IllegalArgumentException.class, () -> {
+ createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams());
+ });
+ }
+
+ @Test
+ public void testApiKeyFromSecretStore() {
+ var prompt = StringPrompt.from("");
+ var config = modelParams("api-key-in-secret-store", null);
+ var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY));
+ assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); });
+ }
+
+ private static String lookupParameter(String parameter, Map<String, String> params) {
+ return params.get(parameter);
+ }
+
+ private static InferenceParameters inferenceParams() {
+ return new InferenceParameters(s -> lookupParameter(s, Map.of()));
+ }
+
+ private static InferenceParameters inferenceParams(Map<String, String> params) {
+ return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params));
+ }
+
+ private static InferenceParameters inferenceParamsWithDefaultKey() {
+ return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Map.of()));
+ }
+
+ private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) {
+ var config = new LlmClientConfig.Builder();
+ if (apiKeySecretName != null) {
+ config.apiKeySecretName(apiKeySecretName);
+ }
+ if (endpoint != null) {
+ config.endpoint(endpoint);
+ }
+ return config.build();
+ }
+
+ public static SecretStore createSecretStore(Map<String, String> secrets) {
+ Provider<SecretStore> secretStore = new Provider<>() {
+ public SecretStore get() {
+ return new SecretStore() {
+ public String getSecret(String key) {
+ return secrets.get(key);
+ }
+ public String getSecret(String key, int version) {
+ return secrets.get(key);
+ }
+ };
+ }
+ public void deconstruct() {
+ }
+ };
+ return secretStore.get();
+ }
+
+ public static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
+ return (prompt, options) -> {
+ String answer = "I have no opinion on the matter";
+ if (prompt.asString().contains("ducks")) {
+ answer = "Ducks have adorable waddling walks.";
+ var temperature = options.getDouble("temperature");
+ if (temperature.isPresent() && temperature.get() > 0.5) {
+ answer = "Random text about ducks vs cats that makes no sense whatsoever.";
+ }
+ }
+ var maxTokens = options.getInt("maxTokens");
+ if (maxTokens.isPresent()) {
+ return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
+ }
+ return answer;
+ };
+ }
+
+ private static MockLLMClient createLLM() {
+ LlmClientConfig config = new LlmClientConfig.Builder().build();
+ return createLLM(config, null);
+ }
+
+ private static MockLLMClient createLLM(ExecutorService executor) {
+ LlmClientConfig config = new LlmClientConfig.Builder().build();
+ return createLLM(config, executor);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) {
+ var generator = createGenerator();
+ var secretStore = new SecretStoreProvider(); // throws exception on use
+ return createLLM(config, generator, secretStore.get(), executor);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ SecretStore secretStore) {
+ return createLLM(config, generator, secretStore, null);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ SecretStore secretStore,
+ ExecutorService executor) {
+ return new MockLLMClient(config, secretStore, generator, executor);
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
new file mode 100644
index 00000000000..e85e397b7ff
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
@@ -0,0 +1,181 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.config.ModelReference;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Tests for LocalLLM.
+ *
+ * @author lesters
+ */
+public class LocalLLMTest {
+
+ private static String model = "src/test/models/llm/tinyllm.gguf";
+ private static Prompt prompt = StringPrompt.from("A random prompt");
+
+ @Test
+ public void testGeneration() {
+ var config = new LlmLocalClientConfig.Builder()
+ .parallelRequests(1)
+ .model(ModelReference.valueOf(model));
+ var llm = new LocalLLM(config.build());
+
+ try {
+ var result = llm.complete(prompt, defaultOptions());
+ assertEquals(Completion.FinishReason.stop, result.get(0).finishReason());
+ assertTrue(result.get(0).text().length() > 10);
+ } finally {
+ llm.deconstruct();
+ }
+ }
+
+ @Test
+ public void testAsyncGeneration() {
+ var sb = new StringBuilder();
+ var tokenCount = new AtomicInteger(0);
+ var config = new LlmLocalClientConfig.Builder()
+ .parallelRequests(1)
+ .model(ModelReference.valueOf(model));
+ var llm = new LocalLLM(config.build());
+
+ try {
+ var future = llm.completeAsync(prompt, defaultOptions(), completion -> {
+ sb.append(completion.text());
+ tokenCount.incrementAndGet();
+ }).exceptionally(exception -> Completion.FinishReason.error);
+
+ assertFalse(future.isDone());
+ var reason = future.join();
+ assertTrue(future.isDone());
+ assertNotEquals(reason, Completion.FinishReason.error);
+
+ } finally {
+ llm.deconstruct();
+ }
+ assertTrue(tokenCount.get() > 0);
+// System.out.println(sb);
+ }
+
+ @Test
+ public void testParallelGeneration() {
+ var prompts = testPrompts();
+ var promptsToUse = prompts.size();
+ var parallelRequests = 10;
+
+ var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null));
+ var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null));
+ var tokenCounts = new ArrayList<>(Collections.nCopies(promptsToUse, 0));
+
+ var config = new LlmLocalClientConfig.Builder()
+ .parallelRequests(parallelRequests)
+ .model(ModelReference.valueOf(model));
+ var llm = new LocalLLM(config.build());
+
+ try {
+ for (int i = 0; i < promptsToUse; i++) {
+ final var seq = i;
+
+ completions.set(seq, new StringBuilder());
+ futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
+ completions.get(seq).append(completion.text());
+ tokenCounts.set(seq, tokenCounts.get(seq) + 1);
+ }).exceptionally(exception -> Completion.FinishReason.error));
+ }
+ for (int i = 0; i < promptsToUse; i++) {
+ var reason = futures.get(i).join();
+ assertNotEquals(reason, Completion.FinishReason.error);
+ }
+ } finally {
+ llm.deconstruct();
+ }
+ for (int i = 0; i < promptsToUse; i++) {
+ assertFalse(completions.get(i).isEmpty());
+ assertTrue(tokenCounts.get(i) > 0);
+ }
+ }
+
+ @Test
+ public void testRejection() {
+ var prompts = testPrompts();
+ var promptsToUse = prompts.size();
+ var parallelRequests = 2;
+ var additionalQueue = 1;
+ // With 10 test prompts, 2 parallel requests and 1 queue slot, 7 should be rejected
+
+ var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null));
+ var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null));
+
+ var config = new LlmLocalClientConfig.Builder()
+ .parallelRequests(parallelRequests)
+ .maxQueueSize(additionalQueue)
+ .model(ModelReference.valueOf(model));
+ var llm = new LocalLLM(config.build());
+
+ var rejected = new AtomicInteger(0);
+ try {
+ for (int i = 0; i < promptsToUse; i++) {
+ final var seq = i;
+
+ completions.set(seq, new StringBuilder());
+ try {
+ var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
+ completions.get(seq).append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+ futures.set(seq, future);
+ } catch (RejectedExecutionException e) {
+ rejected.incrementAndGet();
+ }
+ }
+ for (int i = 0; i < promptsToUse; i++) {
+ if (futures.get(i) != null) {
+ assertNotEquals(futures.get(i).join(), Completion.FinishReason.error);
+ }
+ }
+ } finally {
+ llm.deconstruct();
+ }
+ assertEquals(7, rejected.get());
+ }
+
+ private static InferenceParameters defaultOptions() {
+ final Map<String, String> options = Map.of(
+ "temperature", "0.1",
+ "npredict", "100"
+ );
+ return new InferenceParameters(options::get);
+ }
+
+ private List<String> testPrompts() {
+ List<String> prompts = new ArrayList<>();
+ prompts.add("Write a short story about a time-traveling detective who must solve a mystery that spans multiple centuries.");
+ prompts.add("Explain the concept of blockchain technology and its implications for data security in layman's terms.");
+ prompts.add("Discuss the socio-economic impacts of the Industrial Revolution in 19th century Europe.");
+ prompts.add("Describe a future where humans have colonized Mars, focusing on daily life and societal structure.");
+ prompts.add("Analyze the statement 'If a tree falls in a forest and no one is around to hear it, does it make a sound?' from both a philosophical and a physics perspective.");
+ prompts.add("Translate the following sentence into French: 'The quick brown fox jumps over the lazy dog.'");
+ prompts.add("Explain what the following Python code does: `print([x for x in range(10) if x % 2 == 0])`.");
+ prompts.add("Provide general guidelines for maintaining a healthy lifestyle to reduce the risk of developing heart disease.");
+ prompts.add("Create a detailed description of a fictional planet, including its ecosystem, dominant species, and technology level.");
+ prompts.add("Discuss the impact of social media on interpersonal communication in the 21st century.");
+ return prompts;
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
new file mode 100644
index 00000000000..4d0073f1cbe
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
@@ -0,0 +1,80 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
+
+public class MockLLMClient extends ConfigurableLanguageModel {
+
+ public final static String ACCEPTED_API_KEY = "sesame";
+
+ private final ExecutorService executor;
+ private final BiFunction<Prompt, InferenceParameters, String> generator;
+
+ private Prompt lastPrompt;
+
+ public MockLLMClient(LlmClientConfig config,
+ SecretStore secretStore,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ ExecutorService executor) {
+ super(config, secretStore);
+ this.generator = generator;
+ this.executor = executor;
+ }
+
+ private void checkApiKey(InferenceParameters options) {
+ var apiKey = getApiKey(options);
+ if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) {
+ throw new IllegalArgumentException("Invalid API key");
+ }
+ }
+
+ private void setPrompt(Prompt prompt) {
+ this.lastPrompt = prompt;
+ }
+
+ public Prompt getPrompt() {
+ return this.lastPrompt;
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters params) {
+ setApiKey(params);
+ checkApiKey(params);
+ setPrompt(prompt);
+ return List.of(Completion.from(this.generator.apply(prompt, params)));
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters params,
+ Consumer<Completion> consumer) {
+ setPrompt(prompt);
+ var completionFuture = new CompletableFuture<Completion.FinishReason>();
+ var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization
+
+ long sleep = 1;
+ executor.submit(() -> {
+ try {
+ for (int i=0; i < completions.length; ++i) {
+ String completion = (i > 0 ? " " : "") + completions[i];
+ consumer.accept(Completion.from(completion, Completion.FinishReason.none));
+ Thread.sleep(sleep);
+ }
+ completionFuture.complete(Completion.FinishReason.stop);
+ } catch (InterruptedException e) {
+ // Do nothing
+ }
+ });
+
+ return completionFuture;
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java
new file mode 100644
index 00000000000..57339f6ad49
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java
@@ -0,0 +1,35 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+public class OpenAITest {
+
+ private static final String apiKey = "<your-api-key>";
+
+ @Test
+ @Disabled
+ public void testOpenAIGeneration() {
+ var config = new LlmClientConfig.Builder().build();
+ var openai = new OpenAI(config, new SecretStoreProvider().get());
+ var options = Map.of(
+ "maxTokens", "10"
+ );
+
+ var prompt = StringPrompt.from("why are ducks better than cats?");
+ var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> {
+ System.out.print(completion.text());
+ }).exceptionally(exception -> {
+ System.out.println("Error: " + exception);
+ return null;
+ });
+ future.join();
+ }
+
+}
diff --git a/model-integration/src/test/models/llm/tinyllm.gguf b/model-integration/src/test/models/llm/tinyllm.gguf
new file mode 100644
index 00000000000..34367b6b57b
--- /dev/null
+++ b/model-integration/src/test/models/llm/tinyllm.gguf
Binary files differ