about | summary | refs | log | tree | commit | diff | stats
path: root/container-search
diff options
context:
space:
mode:
author Lester Solbakken <lester.solbakken@gmail.com> 2024-04-12 11:34:47 +0200
committer Lester Solbakken <lester.solbakken@gmail.com> 2024-04-12 11:34:47 +0200
commit 446333e197bfb274ae48c173f25c4ad7e8d76a0f (patch)
tree 7c487fcb53ce956e4454604825bc17b89db97a60 /container-search
parent a11f45f8f3e39f7bd3595abec02eee385514b6a3 (diff)
Move LLM client stuff from container-search to model-integration
Diffstat (limited to 'container-search')
-rw-r--r-- container-search/abi-spec.json 199
-rw-r--r-- container-search/pom.xml 6
-rw-r--r-- container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java 75
-rw-r--r-- container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java 125
-rw-r--r-- container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java 48
-rw-r--r-- container-search/src/main/java/ai/vespa/llm/clients/package-info.java 7
-rwxr-xr-x container-search/src/main/resources/configdefinitions/llm-client.def 8
-rwxr-xr-x container-search/src/main/resources/configdefinitions/llm-local-client.def 29
-rw-r--r-- container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java 175
-rw-r--r-- container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java 181
-rw-r--r-- container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java 35
-rwxr-xr-x container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java 49
-rw-r--r-- container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java (renamed from container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java) 4
-rw-r--r-- container-search/src/test/resources/llms/tinyllm.gguf bin 1185376 -> 0 bytes
14 files changed, 60 insertions, 881 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 8170d6bd9a8..07f0449e61a 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -7842,6 +7842,21 @@
"public static final int emptyDocsumsCode"
]
},
+ "com.yahoo.search.result.EventStream$ErrorEvent" : {
+ "superClass" : "com.yahoo.search.result.EventStream$Event",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(int, java.lang.String, com.yahoo.search.result.ErrorMessage)",
+ "public java.lang.String source()",
+ "public int code()",
+ "public java.lang.String message()",
+ "public com.yahoo.search.result.Hit asHit()"
+ ],
+ "fields" : [ ]
+ },
"com.yahoo.search.result.EventStream$Event" : {
"superClass" : "com.yahoo.component.provider.ListenableFreezableClass",
"interfaces" : [
@@ -9149,190 +9164,6 @@
],
"fields" : [ ]
},
- "ai.vespa.llm.clients.ConfigurableLanguageModel" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "ai.vespa.llm.LanguageModel"
- ],
- "attributes" : [
- "public",
- "abstract"
- ],
- "methods" : [
- "public void <init>()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
- "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)",
- "protected void setApiKey(ai.vespa.llm.InferenceParameters)",
- "protected java.lang.String getEndpoint()",
- "protected void setEndpoint(ai.vespa.llm.InferenceParameters)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig$Builder" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Builder"
- ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public void <init>()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig)",
- "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)",
- "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)",
- "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
- "public final java.lang.String getDefMd5()",
- "public final java.lang.String getDefName()",
- "public final java.lang.String getDefNamespace()",
- "public final boolean getApplyOnRestart()",
- "public final void setApplyOnRestart(boolean)",
- "public ai.vespa.llm.clients.LlmClientConfig build()"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig$Producer" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Producer"
- ],
- "attributes" : [
- "public",
- "interface",
- "abstract"
- ],
- "methods" : [
- "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig" : {
- "superClass" : "com.yahoo.config.ConfigInstance",
- "interfaces" : [ ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public static java.lang.String getDefMd5()",
- "public static java.lang.String getDefName()",
- "public static java.lang.String getDefNamespace()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)",
- "public java.lang.String apiKeySecretName()",
- "public java.lang.String endpoint()"
- ],
- "fields" : [
- "public static final java.lang.String CONFIG_DEF_MD5",
- "public static final java.lang.String CONFIG_DEF_NAME",
- "public static final java.lang.String CONFIG_DEF_NAMESPACE",
- "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
- ]
- },
- "ai.vespa.llm.clients.LlmLocalClientConfig$Builder" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Builder"
- ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public void <init>()",
- "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder modelUrl(com.yahoo.config.UrlReference)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder localLlmFile(java.lang.String)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder parallelRequests(int)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxQueueSize(int)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder useGpu(boolean)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder gpuLayers(int)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder threads(int)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder contextSize(int)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxTokens(int)",
- "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
- "public final java.lang.String getDefMd5()",
- "public final java.lang.String getDefName()",
- "public final java.lang.String getDefNamespace()",
- "public final boolean getApplyOnRestart()",
- "public final void setApplyOnRestart(boolean)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig build()"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmLocalClientConfig$Producer" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Producer"
- ],
- "attributes" : [
- "public",
- "interface",
- "abstract"
- ],
- "methods" : [
- "public abstract void getConfig(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmLocalClientConfig" : {
- "superClass" : "com.yahoo.config.ConfigInstance",
- "interfaces" : [ ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public static java.lang.String getDefMd5()",
- "public static java.lang.String getDefName()",
- "public static java.lang.String getDefNamespace()",
- "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)",
- "public java.io.File modelUrl()",
- "public java.lang.String localLlmFile()",
- "public int parallelRequests()",
- "public int maxQueueSize()",
- "public boolean useGpu()",
- "public int gpuLayers()",
- "public int threads()",
- "public int contextSize()",
- "public int maxTokens()"
- ],
- "fields" : [
- "public static final java.lang.String CONFIG_DEF_MD5",
- "public static final java.lang.String CONFIG_DEF_NAME",
- "public static final java.lang.String CONFIG_DEF_NAMESPACE",
- "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
- ]
- },
- "ai.vespa.llm.clients.LocalLLM" : {
- "superClass" : "com.yahoo.component.AbstractComponent",
- "interfaces" : [
- "ai.vespa.llm.LanguageModel"
- ],
- "attributes" : [
- "public"
- ],
- "methods" : [
- "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig)",
- "public void deconstruct()",
- "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
- "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.OpenAI" : {
- "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel",
- "interfaces" : [ ],
- "attributes" : [
- "public"
- ],
- "methods" : [
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
- "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
- "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
- ],
- "fields" : [ ]
- },
"ai.vespa.search.llm.LLMSearcher" : {
"superClass" : "com.yahoo.search.Searcher",
"interfaces" : [ ],
diff --git a/container-search/pom.xml b/container-search/pom.xml
index 38a4cb0ac2d..5e7c60d49c3 100644
--- a/container-search/pom.xml
+++ b/container-search/pom.xml
@@ -87,12 +87,6 @@
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
- <dependency>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>container-llama</artifactId>
- <version>${project.version}</version>
- <scope>provided</scope>
- </dependency>
<dependency>
<groupId>xerces</groupId>
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
deleted file mode 100644
index 761fdf0af93..00000000000
--- a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.LanguageModel;
-import com.yahoo.api.annotations.Beta;
-import com.yahoo.component.annotation.Inject;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-
-import java.util.logging.Logger;
-
-
-/**
- * Base class for language models that can be configured with config definitions.
- *
- * @author lesters
- */
-@Beta
-public abstract class ConfigurableLanguageModel implements LanguageModel {
-
- private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName());
-
- private final String apiKey;
- private final String endpoint;
-
- public ConfigurableLanguageModel() {
- this.apiKey = null;
- this.endpoint = null;
- }
-
- @Inject
- public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) {
- this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore);
- this.endpoint = config.endpoint();
- }
-
- private static String findApiKeyInSecretStore(String property, SecretStore secretStore) {
- String apiKey = "";
- if (property != null && ! property.isEmpty()) {
- try {
- apiKey = secretStore.getSecret(property);
- } catch (UnsupportedOperationException e) {
- // Secret store is not available - silently ignore this
- } catch (Exception e) {
- log.warning("Secret store look up failed: " + e.getMessage() + "\n" +
- "Will expect API key in request header");
- }
- }
- return apiKey;
- }
-
- protected String getApiKey(InferenceParameters params) {
- return params.getApiKey().orElse(null);
- }
-
- /**
- * Set the API key as retrieved from secret store if it is not already set
- */
- protected void setApiKey(InferenceParameters params) {
- if (params.getApiKey().isEmpty() && apiKey != null) {
- params.setApiKey(apiKey);
- }
- }
-
- protected String getEndpoint() {
- return endpoint;
- }
-
- protected void setEndpoint(InferenceParameters params) {
- if (endpoint != null && ! endpoint.isEmpty()) {
- params.setEndpoint(endpoint);
- }
- }
-
-}
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java
deleted file mode 100644
index 1e204d29a19..00000000000
--- a/container-search/src/main/java/ai/vespa/llm/clients/LocalLLM.java
+++ /dev/null
@@ -1,125 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.LanguageModel;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import com.yahoo.component.AbstractComponent;
-import com.yahoo.component.annotation.Inject;
-import de.kherud.llama.LlamaModel;
-import de.kherud.llama.ModelParameters;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.ArrayBlockingQueue;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.RejectedExecutionException;
-import java.util.concurrent.SynchronousQueue;
-import java.util.concurrent.ThreadPoolExecutor;
-import java.util.concurrent.TimeUnit;
-import java.util.function.Consumer;
-import java.util.logging.Logger;
-
-/**
- * A language model running locally on the container node.
- *
- * @author lesters
- */
-public class LocalLLM extends AbstractComponent implements LanguageModel {
-
- private final static Logger logger = Logger.getLogger(LocalLLM.class.getName());
- private final LlamaModel model;
- private final ThreadPoolExecutor executor;
- private final int contextSize;
- private final int maxTokens;
-
- @Inject
- public LocalLLM(LlmLocalClientConfig config) {
- executor = createExecutor(config);
-
- // Maximum number of tokens to generate - need this since some models can just generate infinitely
- maxTokens = config.maxTokens();
-
- // Only used if GPU is not used
- var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2;
-
- var modelFile = config.model().toFile().getAbsolutePath();
- var modelParams = new ModelParameters()
- .setModelFilePath(modelFile)
- .setContinuousBatching(true)
- .setNParallel(config.parallelRequests())
- .setNThreads(config.threads() <= 0 ? defaultThreadCount : config.threads())
- .setNCtx(config.contextSize())
- .setNGpuLayers(config.useGpu() ? config.gpuLayers() : 0);
-
- long startLoad = System.nanoTime();
- model = new LlamaModel(modelParams);
- long loadTime = System.nanoTime() - startLoad;
- logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000)));
-
- // Todo: handle prompt context size - such as give a warning when prompt exceeds context size
- contextSize = config.contextSize();
- }
-
- private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) {
- return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(),
- 0L, TimeUnit.MILLISECONDS,
- config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(),
- new ThreadPoolExecutor.AbortPolicy());
- }
-
- @Override
- public void deconstruct() {
- logger.info("Closing LLM model...");
- model.close();
- executor.shutdownNow();
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters options) {
- StringBuilder result = new StringBuilder();
- var future = completeAsync(prompt, options, completion -> {
- result.append(completion.text());
- }).exceptionally(exception -> Completion.FinishReason.error);
- var reason = future.join();
-
- List<Completion> completions = new ArrayList<>();
- completions.add(new Completion(result.toString(), reason));
- return completions;
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
- var inferParams = new de.kherud.llama.InferenceParameters(prompt.asString().stripLeading());
-
- // We always set this to some value to avoid infinite token generation
- inferParams.setNPredict(maxTokens);
-
- options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v)));
- options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v)));
- options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v)));
- options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v)));
- options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v)));
- // Todo: more options?
-
- var completionFuture = new CompletableFuture<Completion.FinishReason>();
- try {
- executor.submit(() -> {
- for (LlamaModel.Output output : model.generate(inferParams)) {
- consumer.accept(Completion.from(output.text, Completion.FinishReason.none));
- }
- completionFuture.complete(Completion.FinishReason.stop);
- });
- } catch (RejectedExecutionException e) {
- // If we have too many requests (active + any waiting in queue), we reject the completion
- int activeCount = executor.getActiveCount();
- int queueSize = executor.getQueue().size();
- String error = String.format("Rejected completion due to too many requests, " +
- "%d active, %d in queue", activeCount, queueSize);
- throw new RejectedExecutionException(error);
- }
- return completionFuture;
- }
-
-}
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
deleted file mode 100644
index 82e19d47c92..00000000000
--- a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.client.openai.OpenAiClient;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import com.yahoo.api.annotations.Beta;
-import com.yahoo.component.annotation.Inject;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.function.Consumer;
-
-/**
- * A configurable OpenAI client.
- *
- * @author lesters
- */
-@Beta
-public class OpenAI extends ConfigurableLanguageModel {
-
- private final OpenAiClient client;
-
- @Inject
- public OpenAI(LlmClientConfig config, SecretStore secretStore) {
- super(config, secretStore);
- client = new OpenAiClient();
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
- setApiKey(parameters);
- setEndpoint(parameters);
- return client.complete(prompt, parameters);
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
- InferenceParameters parameters,
- Consumer<Completion> consumer) {
- setApiKey(parameters);
- setEndpoint(parameters);
- return client.completeAsync(prompt, parameters, consumer);
- }
-}
-
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
deleted file mode 100644
index c360245901c..00000000000
--- a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
+++ /dev/null
@@ -1,7 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-@ExportPackage
-@PublicApi
-package ai.vespa.llm.clients;
-
-import com.yahoo.api.annotations.PublicApi;
-import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def
deleted file mode 100755
index 0866459166a..00000000000
--- a/container-search/src/main/resources/configdefinitions/llm-client.def
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package=ai.vespa.llm.clients
-
-# The name of the secret containing the api key
-apiKeySecretName string default=""
-
-# Endpoint for LLM client - if not set reverts to default for client
-endpoint string default=""
diff --git a/container-search/src/main/resources/configdefinitions/llm-local-client.def b/container-search/src/main/resources/configdefinitions/llm-local-client.def
deleted file mode 100755
index c06c24b33e5..00000000000
--- a/container-search/src/main/resources/configdefinitions/llm-local-client.def
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package=ai.vespa.llm.clients
-
-# The LLM model to use
-model model
-
-# Maximum number of requests to handle in parallel pr container node
-parallelRequests int default=10
-
-# Additional number of requests to put in queue for processing before starting to reject new requests
-maxQueueSize int default=10
-
-# Use GPU
-useGpu bool default=false
-
-# Maximum number of model layers to run on GPU
-gpuLayers int default=1000000
-
-# Number of threads to use for CPU processing - -1 means use all available cores
-# Not used for GPU processing
-threads int default=-1
-
-# Context size for the model
-# Context is divided between parallel requests. So for 10 parallel requests, each "slot" gets 1/10 of the context
-contextSize int default=512
-
-# Maximum number of tokens to process in one request - overriden by inference parameters
-maxTokens int default=512
-
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
deleted file mode 100644
index a9f4c3dfac5..00000000000
--- a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
+++ /dev/null
@@ -1,175 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import ai.vespa.llm.completion.StringPrompt;
-import com.yahoo.container.di.componentgraph.Provider;
-import com.yahoo.container.jdisc.SecretStoreProvider;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-import org.junit.jupiter.api.Test;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Map;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.function.BiFunction;
-import java.util.stream.Collectors;
-
-import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-public class ConfigurableLanguageModelTest {
-
- @Test
- public void testSyncGeneration() {
- var prompt = StringPrompt.from("Why are ducks better than cats?");
- var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey());
- assertEquals(1, result.size());
- assertEquals("Ducks have adorable waddling walks.", result.get(0).text());
- }
-
- @Test
- public void testAsyncGeneration() {
- var executor = Executors.newFixedThreadPool(1);
- var prompt = StringPrompt.from("Why are ducks better than cats?");
- var sb = new StringBuilder();
- try {
- var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> {
- sb.append(completion.text());
- }).exceptionally(exception -> Completion.FinishReason.error);
-
- var reason = future.join();
- assertTrue(future.isDone());
- assertNotEquals(reason, Completion.FinishReason.error);
- } finally {
- executor.shutdownNow();
- }
-
- assertEquals("Ducks have adorable waddling walks.", sb.toString());
- }
-
- @Test
- public void testInferenceParameters() {
- var prompt = StringPrompt.from("Why are ducks better than cats?");
- var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4"));
- var result = createLLM().complete(prompt, params);
- assertEquals("Random text about ducks", result.get(0).text());
- }
-
- @Test
- public void testNoApiKey() {
- var prompt = StringPrompt.from("");
- var config = modelParams("api-key", null);
- var secrets = createSecretStore(Map.of());
- assertThrows(IllegalArgumentException.class, () -> {
- createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams());
- });
- }
-
- @Test
- public void testApiKeyFromSecretStore() {
- var prompt = StringPrompt.from("");
- var config = modelParams("api-key-in-secret-store", null);
- var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY));
- assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); });
- }
-
- private static String lookupParameter(String parameter, Map<String, String> params) {
- return params.get(parameter);
- }
-
- private static InferenceParameters inferenceParams() {
- return new InferenceParameters(s -> lookupParameter(s, Collections.emptyMap()));
- }
-
- private static InferenceParameters inferenceParams(Map<String, String> params) {
- return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params));
- }
-
- private static InferenceParameters inferenceParamsWithDefaultKey() {
- return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Collections.emptyMap()));
- }
-
- private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) {
- var config = new LlmClientConfig.Builder();
- if (apiKeySecretName != null) {
- config.apiKeySecretName(apiKeySecretName);
- }
- if (endpoint != null) {
- config.endpoint(endpoint);
- }
- return config.build();
- }
-
- public static SecretStore createSecretStore(Map<String, String> secrets) {
- Provider<SecretStore> secretStore = new Provider<>() {
- public SecretStore get() {
- return new SecretStore() {
- public String getSecret(String key) {
- return secrets.get(key);
- }
- public String getSecret(String key, int version) {
- return secrets.get(key);
- }
- };
- }
- public void deconstruct() {
- }
- };
- return secretStore.get();
- }
-
- public static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
- return (prompt, options) -> {
- String answer = "I have no opinion on the matter";
- if (prompt.asString().contains("ducks")) {
- answer = "Ducks have adorable waddling walks.";
- var temperature = options.getDouble("temperature");
- if (temperature.isPresent() && temperature.get() > 0.5) {
- answer = "Random text about ducks vs cats that makes no sense whatsoever.";
- }
- }
- var maxTokens = options.getInt("maxTokens");
- if (maxTokens.isPresent()) {
- return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
- }
- return answer;
- };
- }
-
- private static MockLLMClient createLLM() {
- LlmClientConfig config = new LlmClientConfig.Builder().build();
- return createLLM(config, null);
- }
-
- private static MockLLMClient createLLM(ExecutorService executor) {
- LlmClientConfig config = new LlmClientConfig.Builder().build();
- return createLLM(config, executor);
- }
-
- private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) {
- var generator = createGenerator();
- var secretStore = new SecretStoreProvider(); // throws exception on use
- return createLLM(config, generator, secretStore.get(), executor);
- }
-
- private static MockLLMClient createLLM(LlmClientConfig config,
- BiFunction<Prompt, InferenceParameters, String> generator,
- SecretStore secretStore) {
- return createLLM(config, generator, secretStore, null);
- }
-
- private static MockLLMClient createLLM(LlmClientConfig config,
- BiFunction<Prompt, InferenceParameters, String> generator,
- SecretStore secretStore,
- ExecutorService executor) {
- return new MockLLMClient(config, secretStore, generator, executor);
- }
-
-}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
deleted file mode 100644
index 72b64cc0a0c..00000000000
--- a/container-search/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
+++ /dev/null
@@ -1,181 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import ai.vespa.llm.completion.StringPrompt;
-import com.yahoo.config.ModelReference;
-import org.junit.jupiter.api.Test;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.RejectedExecutionException;
-import java.util.concurrent.atomic.AtomicInteger;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertNotEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-/**
- * Tests for LocalLLM.
- *
- * @author lesters
- */
-public class LocalLLMTest {
-
- private static String model = "src/test/resources/llms/tinyllm.gguf";
- private static Prompt prompt = StringPrompt.from("A random prompt");
-
- @Test
- public void testGeneration() {
- var config = new LlmLocalClientConfig.Builder()
- .parallelRequests(1)
- .model(ModelReference.valueOf(model));
- var llm = new LocalLLM(config.build());
-
- try {
- var result = llm.complete(prompt, defaultOptions());
- assertEquals(Completion.FinishReason.stop, result.get(0).finishReason());
- assertTrue(result.get(0).text().length() > 10);
- } finally {
- llm.deconstruct();
- }
- }
-
- @Test
- public void testAsyncGeneration() {
- var sb = new StringBuilder();
- var tokenCount = new AtomicInteger(0);
- var config = new LlmLocalClientConfig.Builder()
- .parallelRequests(1)
- .model(ModelReference.valueOf(model));
- var llm = new LocalLLM(config.build());
-
- try {
- var future = llm.completeAsync(prompt, defaultOptions(), completion -> {
- sb.append(completion.text());
- tokenCount.incrementAndGet();
- }).exceptionally(exception -> Completion.FinishReason.error);
-
- assertFalse(future.isDone());
- var reason = future.join();
- assertTrue(future.isDone());
- assertNotEquals(reason, Completion.FinishReason.error);
-
- } finally {
- llm.deconstruct();
- }
- assertTrue(tokenCount.get() > 0);
-// System.out.println(sb);
- }
-
- @Test
- public void testParallelGeneration() {
- var prompts = testPrompts();
- var promptsToUse = prompts.size();
- var parallelRequests = 10;
-
- var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null));
- var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null));
- var tokenCounts = new ArrayList<>(Collections.nCopies(promptsToUse, 0));
-
- var config = new LlmLocalClientConfig.Builder()
- .parallelRequests(parallelRequests)
- .model(ModelReference.valueOf(model));
- var llm = new LocalLLM(config.build());
-
- try {
- for (int i = 0; i < promptsToUse; i++) {
- final var seq = i;
-
- completions.set(seq, new StringBuilder());
- futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
- completions.get(seq).append(completion.text());
- tokenCounts.set(seq, tokenCounts.get(seq) + 1);
- }).exceptionally(exception -> Completion.FinishReason.error));
- }
- for (int i = 0; i < promptsToUse; i++) {
- var reason = futures.get(i).join();
- assertNotEquals(reason, Completion.FinishReason.error);
- }
- } finally {
- llm.deconstruct();
- }
- for (int i = 0; i < promptsToUse; i++) {
- assertFalse(completions.get(i).isEmpty());
- assertTrue(tokenCounts.get(i) > 0);
- }
- }
-
- @Test
- public void testRejection() {
- var prompts = testPrompts();
- var promptsToUse = prompts.size();
- var parallelRequests = 2;
- var additionalQueue = 1;
- // 7 should be rejected
-
- var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null));
- var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null));
-
- var config = new LlmLocalClientConfig.Builder()
- .parallelRequests(parallelRequests)
- .maxQueueSize(additionalQueue)
- .model(ModelReference.valueOf(model));
- var llm = new LocalLLM(config.build());
-
- var rejected = new AtomicInteger(0);
- try {
- for (int i = 0; i < promptsToUse; i++) {
- final var seq = i;
-
- completions.set(seq, new StringBuilder());
- try {
- var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
- completions.get(seq).append(completion.text());
- }).exceptionally(exception -> Completion.FinishReason.error);
- futures.set(seq, future);
- } catch (RejectedExecutionException e) {
- rejected.incrementAndGet();
- }
- }
- for (int i = 0; i < promptsToUse; i++) {
- if (futures.get(i) != null) {
- assertNotEquals(futures.get(i).join(), Completion.FinishReason.error);
- }
- }
- } finally {
- llm.deconstruct();
- }
- assertEquals(7, rejected.get());
- }
-
- private static InferenceParameters defaultOptions() {
- final Map<String, String> options = Map.of(
- "temperature", "0.1",
- "npredict", "100"
- );
- return new InferenceParameters(options::get);
- }
-
- private List<String> testPrompts() {
- List<String> prompts = new ArrayList<>();
- prompts.add("Write a short story about a time-traveling detective who must solve a mystery that spans multiple centuries.");
- prompts.add("Explain the concept of blockchain technology and its implications for data security in layman's terms.");
- prompts.add("Discuss the socio-economic impacts of the Industrial Revolution in 19th century Europe.");
- prompts.add("Describe a future where humans have colonized Mars, focusing on daily life and societal structure.");
- prompts.add("Analyze the statement 'If a tree falls in a forest and no one is around to hear it, does it make a sound?' from both a philosophical and a physics perspective.");
- prompts.add("Translate the following sentence into French: 'The quick brown fox jumps over the lazy dog.'");
- prompts.add("Explain what the following Python code does: `print([x for x in range(10) if x % 2 == 0])`.");
- prompts.add("Provide general guidelines for maintaining a healthy lifestyle to reduce the risk of developing heart disease.");
- prompts.add("Create a detailed description of a fictional planet, including its ecosystem, dominant species, and technology level.");
- prompts.add("Discuss the impact of social media on interpersonal communication in the 21st century.");
- return prompts;
- }
-
-}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
deleted file mode 100644
index 57339f6ad49..00000000000
--- a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.completion.StringPrompt;
-import com.yahoo.container.jdisc.SecretStoreProvider;
-import org.junit.jupiter.api.Disabled;
-import org.junit.jupiter.api.Test;
-
-import java.util.Map;
-
-public class OpenAITest {
-
- private static final String apiKey = "<your-api-key>";
-
- @Test
- @Disabled
- public void testOpenAIGeneration() {
- var config = new LlmClientConfig.Builder().build();
- var openai = new OpenAI(config, new SecretStoreProvider().get());
- var options = Map.of(
- "maxTokens", "10"
- );
-
- var prompt = StringPrompt.from("why are ducks better than cats?");
- var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> {
- System.out.print(completion.text());
- }).exceptionally(exception -> {
- System.out.println("Error: " + exception);
- return null;
- });
- future.join();
- }
-
-}
diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
index 1efcf1c736a..f5971aa55ff 100755
--- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
@@ -3,14 +3,14 @@ package ai.vespa.search.llm;
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
-import ai.vespa.llm.clients.ConfigurableLanguageModelTest;
import ai.vespa.llm.clients.LlmClientConfig;
-import ai.vespa.llm.clients.MockLLMClient;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.component.ComponentId;
import com.yahoo.component.chain.Chain;
import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.container.di.componentgraph.Provider;
import com.yahoo.container.jdisc.SecretStoreProvider;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
@@ -20,6 +20,7 @@ import org.junit.jupiter.api.Test;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
@@ -47,7 +48,8 @@ public class LLMSearcherTest {
@Test
public void testGeneration() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(Map.of("mock", client));
var params = Map.of("prompt", "why are ducks better than cats");
assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params)));
}
@@ -191,26 +193,59 @@ public class LLMSearcherTest {
}
private static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
- return ConfigurableLanguageModelTest.createGenerator();
+ return (prompt, options) -> {
+ String answer = "I have no opinion on the matter";
+ if (prompt.asString().contains("ducks")) {
+ answer = "Ducks have adorable waddling walks.";
+ var temperature = options.getDouble("temperature");
+ if (temperature.isPresent() && temperature.get() > 0.5) {
+ answer = "Random text about ducks vs cats that makes no sense whatsoever.";
+ }
+ }
+ var maxTokens = options.getInt("maxTokens");
+ if (maxTokens.isPresent()) {
+ return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
+ }
+ return answer;
+ };
}
+ public static SecretStore createSecretStore(Map<String, String> secrets) {
+ Provider<SecretStore> secretStore = new Provider<>() {
+ public SecretStore get() {
+ return new SecretStore() {
+ public String getSecret(String key) {
+ return secrets.get(key);
+ }
+ public String getSecret(String key, int version) {
+ return secrets.get(key);
+ }
+ };
+ }
+ public void deconstruct() {
+ }
+ };
+ return secretStore.get();
+ }
+
+
static MockLLMClient createLLMClient() {
var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var secretStore = createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
var generator = createGenerator();
return new MockLLMClient(config, secretStore, generator, null);
}
static MockLLMClient createLLMClient(String id) {
var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var secretStore = createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
var generator = createIdGenerator(id);
return new MockLLMClient(config, secretStore, generator, null);
}
static MockLLMClient createLLMClient(ExecutorService executor) {
var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var secretStore = createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
var generator = createGenerator();
return new MockLLMClient(config, secretStore, generator, executor);
}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java
index 4d0073f1cbe..4411e0cab70 100644
--- a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java
@@ -1,7 +1,9 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
+package ai.vespa.search.llm;
import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.clients.ConfigurableLanguageModel;
+import ai.vespa.llm.clients.LlmClientConfig;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.container.jdisc.secretstore.SecretStore;
diff --git a/container-search/src/test/resources/llms/tinyllm.gguf b/container-search/src/test/resources/llms/tinyllm.gguf
deleted file mode 100644
index 34367b6b57b..00000000000
--- a/container-search/src/test/resources/llms/tinyllm.gguf
+++ /dev/null
Binary files differ