From 780bc7cbe8fb67ae712fcf278f8900c8f32e14a6 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Tue, 16 Apr 2024 13:07:31 +0200 Subject: Reapply "Lesters/add local llms 2" This reverts commit ed62b750494822cc67a328390178754512baf032. --- model-integration/abi-spec.json | 182 +++++++++++++++++++++ model-integration/pom.xml | 12 ++ .../llm/clients/ConfigurableLanguageModel.java | 75 +++++++++ .../main/java/ai/vespa/llm/clients/LocalLLM.java | 126 ++++++++++++++ .../src/main/java/ai/vespa/llm/clients/OpenAI.java | 48 ++++++ .../java/ai/vespa/llm/clients/package-info.java | 7 + .../resources/configdefinitions/llm-client.def | 8 + .../configdefinitions/llm-local-client.def | 29 ++++ .../llm/clients/ConfigurableLanguageModelTest.java | 174 ++++++++++++++++++++ .../java/ai/vespa/llm/clients/LocalLLMTest.java | 181 ++++++++++++++++++++ .../java/ai/vespa/llm/clients/MockLLMClient.java | 80 +++++++++ .../test/java/ai/vespa/llm/clients/OpenAITest.java | 35 ++++ model-integration/src/test/models/llm/tinyllm.gguf | Bin 0 -> 1185376 bytes 13 files changed, 957 insertions(+) create mode 100644 model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java create mode 100644 model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java create mode 100644 model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java create mode 100644 model-integration/src/main/java/ai/vespa/llm/clients/package-info.java create mode 100755 model-integration/src/main/resources/configdefinitions/llm-client.def create mode 100755 model-integration/src/main/resources/configdefinitions/llm-local-client.def create mode 100644 model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java create mode 100644 model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java create mode 100644 model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java create mode 100644 model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java create mode 100644 model-integration/src/test/models/llm/tinyllm.gguf (limited to 'model-integration') diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json index d3c472778e6..e7130d9c777 100644 --- a/model-integration/abi-spec.json +++ b/model-integration/abi-spec.json @@ -1,4 +1,186 @@ { + "ai.vespa.llm.clients.ConfigurableLanguageModel" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "ai.vespa.llm.LanguageModel" + ], + "attributes" : [ + "public", + "abstract" + ], + "methods" : [ + "public void ()", + "public void (ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", + "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)", + "protected void setApiKey(ai.vespa.llm.InferenceParameters)", + "protected java.lang.String getEndpoint()", + "protected void setEndpoint(ai.vespa.llm.InferenceParameters)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig$Builder" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Builder" + ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void ()", + "public void (ai.vespa.llm.clients.LlmClientConfig)", + "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)", + "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)", + "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", + "public final java.lang.String getDefMd5()", + "public 
final java.lang.String getDefName()", + "public final java.lang.String getDefNamespace()", + "public final boolean getApplyOnRestart()", + "public final void setApplyOnRestart(boolean)", + "public ai.vespa.llm.clients.LlmClientConfig build()" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig$Producer" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Producer" + ], + "attributes" : [ + "public", + "interface", + "abstract" + ], + "methods" : [ + "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig" : { + "superClass" : "com.yahoo.config.ConfigInstance", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public static java.lang.String getDefMd5()", + "public static java.lang.String getDefName()", + "public static java.lang.String getDefNamespace()", + "public void (ai.vespa.llm.clients.LlmClientConfig$Builder)", + "public java.lang.String apiKeySecretName()", + "public java.lang.String endpoint()" + ], + "fields" : [ + "public static final java.lang.String CONFIG_DEF_MD5", + "public static final java.lang.String CONFIG_DEF_NAME", + "public static final java.lang.String CONFIG_DEF_NAMESPACE", + "public static final java.lang.String[] CONFIG_DEF_SCHEMA" + ] + }, + "ai.vespa.llm.clients.LlmLocalClientConfig$Builder" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Builder" + ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void ()", + "public void (ai.vespa.llm.clients.LlmLocalClientConfig)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder model(com.yahoo.config.ModelReference)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder parallelRequests(int)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxQueueSize(int)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder useGpu(boolean)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder gpuLayers(int)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder threads(int)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder contextSize(int)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxTokens(int)", + "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", + "public final java.lang.String getDefMd5()", + "public final java.lang.String getDefName()", + "public final java.lang.String getDefNamespace()", + "public final boolean getApplyOnRestart()", + "public final void setApplyOnRestart(boolean)", + "public ai.vespa.llm.clients.LlmLocalClientConfig build()" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmLocalClientConfig$Producer" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Producer" + ], + "attributes" : [ + "public", + "interface", + "abstract" + ], + "methods" : [ + "public abstract void getConfig(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmLocalClientConfig" : { + "superClass" : "com.yahoo.config.ConfigInstance", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public static java.lang.String getDefMd5()", + "public static java.lang.String getDefName()", + "public static java.lang.String getDefNamespace()", + "public void (ai.vespa.llm.clients.LlmLocalClientConfig$Builder)", + "public java.nio.file.Path model()", + 
"public int parallelRequests()", + "public int maxQueueSize()", + "public boolean useGpu()", + "public int gpuLayers()", + "public int threads()", + "public int contextSize()", + "public int maxTokens()" + ], + "fields" : [ + "public static final java.lang.String CONFIG_DEF_MD5", + "public static final java.lang.String CONFIG_DEF_NAME", + "public static final java.lang.String CONFIG_DEF_NAMESPACE", + "public static final java.lang.String[] CONFIG_DEF_SCHEMA" + ] + }, + "ai.vespa.llm.clients.LocalLLM" : { + "superClass" : "com.yahoo.component.AbstractComponent", + "interfaces" : [ + "ai.vespa.llm.LanguageModel" + ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void (ai.vespa.llm.clients.LlmLocalClientConfig)", + "public void deconstruct()", + "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.OpenAI" : { + "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void (ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", + "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + ], + "fields" : [ ] + }, "ai.vespa.llm.generation.Generator" : { "superClass" : "com.yahoo.component.AbstractComponent", "interfaces" : [ ], diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 0bab30e1453..d92fa319251 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -38,6 +38,12 @@ ${project.version} provided + + com.yahoo.vespa + container-disc + ${project.version} + provided + com.yahoo.vespa searchcore @@ -74,6 +80,12 @@ ${project.version} provided + + com.yahoo.vespa + container-llama + ${project.version} + provided + com.yahoo.vespa component diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java new file mode 100644 index 00000000000..761fdf0af93 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java @@ -0,0 +1,75 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.logging.Logger; + + +/** + * Base class for language models that can be configured with config definitions. 
+ * + * @author lesters + */ +@Beta +public abstract class ConfigurableLanguageModel implements LanguageModel { + + private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); + + private final String apiKey; + private final String endpoint; + + public ConfigurableLanguageModel() { + this.apiKey = null; + this.endpoint = null; + } + + @Inject + public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) { + this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore); + this.endpoint = config.endpoint(); + } + + private static String findApiKeyInSecretStore(String property, SecretStore secretStore) { + String apiKey = ""; + if (property != null && ! property.isEmpty()) { + try { + apiKey = secretStore.getSecret(property); + } catch (UnsupportedOperationException e) { + // Secret store is not available - silently ignore this + } catch (Exception e) { + log.warning("Secret store look up failed: " + e.getMessage() + "\n" + + "Will expect API key in request header"); + } + } + return apiKey; + } + + protected String getApiKey(InferenceParameters params) { + return params.getApiKey().orElse(null); + } + + /** + * Set the API key as retrieved from secret store if it is not already set + */ + protected void setApiKey(InferenceParameters params) { + if (params.getApiKey().isEmpty() && apiKey != null) { + params.setApiKey(apiKey); + } + } + + protected String getEndpoint() { + return endpoint; + } + + protected void setEndpoint(InferenceParameters params) { + if (endpoint != null && ! endpoint.isEmpty()) { + params.setEndpoint(endpoint); + } + } + +} diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java new file mode 100644 index 00000000000..fd1b8b700c8 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java @@ -0,0 +1,126 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import de.kherud.llama.LlamaModel; +import de.kherud.llama.ModelParameters; +import de.kherud.llama.args.LogFormat; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.logging.Logger; + +/** + * A language model running locally on the container node. 
+ *
+ * @author lesters
+ */
+public class LocalLLM extends AbstractComponent implements LanguageModel {
+
+    private final static Logger logger = Logger.getLogger(LocalLLM.class.getName());
+    private final LlamaModel model;
+    private final ThreadPoolExecutor executor;
+    private final int contextSize;
+    private final int maxTokens;
+
+    @Inject
+    public LocalLLM(LlmLocalClientConfig config) {
+        executor = createExecutor(config);
+
+        // Maximum number of tokens to generate - need this since some models can just generate infinitely
+        maxTokens = config.maxTokens();
+
+        // Only used if GPU is not used
+        var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2;
+
+        var modelFile = config.model().toFile().getAbsolutePath();
+        var modelParams = new ModelParameters()
+                .setModelFilePath(modelFile)
+                .setContinuousBatching(true)
+                .setNParallel(config.parallelRequests())
+                .setNThreads(config.threads() <= 0 ? defaultThreadCount : config.threads())
+                .setNCtx(config.contextSize())
+                .setNGpuLayers(config.useGpu() ? config.gpuLayers() : 0);
+
+        long startLoad = System.nanoTime();
+        model = new LlamaModel(modelParams);
+        long loadTime = System.nanoTime() - startLoad;
+        logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000)));
+
+        // Todo: handle prompt context size - such as give a warning when prompt exceeds context size
+        contextSize = config.contextSize();
+    }
+
+    private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) {
+        return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(),
+                0L, TimeUnit.MILLISECONDS,
+                config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(),
+                new ThreadPoolExecutor.AbortPolicy());
+    }
+
+    @Override
+    public void deconstruct() {
+        logger.info("Closing LLM model...");
+        model.close();
+        executor.shutdownNow();
+    }
+
+    @Override
+    public List<Completion> complete(Prompt prompt, InferenceParameters options) {
+        StringBuilder result = new StringBuilder();
+        var future = completeAsync(prompt, options, completion -> {
+            result.append(completion.text());
+        }).exceptionally(exception -> Completion.FinishReason.error);
+        var reason = future.join();
+
+        List<Completion> completions = new ArrayList<>();
+        completions.add(new Completion(result.toString(), reason));
+        return completions;
+    }
+
+    @Override
+    public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
+        var inferParams = new de.kherud.llama.InferenceParameters(prompt.asString().stripLeading());
+
+        // We always set this to some value to avoid infinite token generation
+        inferParams.setNPredict(maxTokens);
+
+        options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v)));
+        options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v)));
+        options.ifPresent("topp", (v) -> inferParams.setTopP(Float.parseFloat(v)));
+        options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v)));
+        options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v)));
+        // Todo: more options?
+
+        var completionFuture = new CompletableFuture<Completion.FinishReason>();
+        try {
+            executor.submit(() -> {
+                for (LlamaModel.Output output : model.generate(inferParams)) {
+                    consumer.accept(Completion.from(output.text, Completion.FinishReason.none));
+                }
+                completionFuture.complete(Completion.FinishReason.stop);
+            });
+        } catch (RejectedExecutionException e) {
+            // If we have too many requests (active + any waiting in queue), we reject the completion
+            int activeCount = executor.getActiveCount();
+            int queueSize = executor.getQueue().size();
+            String error = String.format("Rejected completion due to too many requests, " +
+                    "%d active, %d in queue", activeCount, queueSize);
+            throw new RejectedExecutionException(error);
+        }
+        return completionFuture;
+    }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
new file mode 100644
index 00000000000..82e19d47c92
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
@@ -0,0 +1,48 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.client.openai.OpenAiClient;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+
+/**
+ * A configurable OpenAI client.
+ *
+ * @author lesters
+ */
+@Beta
+public class OpenAI extends ConfigurableLanguageModel {
+
+    private final OpenAiClient client;
+
+    @Inject
+    public OpenAI(LlmClientConfig config, SecretStore secretStore) {
+        super(config, secretStore);
+        client = new OpenAiClient();
+    }
+
+    @Override
+    public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
+        setApiKey(parameters);
+        setEndpoint(parameters);
+        return client.complete(prompt, parameters);
+    }
+
+    @Override
+    public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+                                                                    InferenceParameters parameters,
+                                                                    Consumer<Completion> consumer) {
+        setApiKey(parameters);
+        setEndpoint(parameters);
+        return client.completeAsync(prompt, parameters, consumer);
+    }
+}
+
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java
new file mode 100644
index 00000000000..c360245901c
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java
@@ -0,0 +1,7 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.clients;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/model-integration/src/main/resources/configdefinitions/llm-client.def b/model-integration/src/main/resources/configdefinitions/llm-client.def
new file mode 100755
index 00000000000..0866459166a
--- /dev/null
+++ b/model-integration/src/main/resources/configdefinitions/llm-client.def
@@ -0,0 +1,8 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm.clients
+
+# The name of the secret containing the API key
+apiKeySecretName string default=""
+
+# Endpoint for LLM client - if not set, reverts to the default for the client
+endpoint string default=""
diff --git a/model-integration/src/main/resources/configdefinitions/llm-local-client.def b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
new file mode 100755
index 00000000000..c06c24b33e5
--- /dev/null
+++ b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
@@ -0,0 +1,29 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm.clients
+
+# The LLM model to use
+model model
+
+# Maximum number of requests to handle in parallel per container node
+parallelRequests int default=10
+
+# Additional number of requests to put in queue for processing before starting to reject new requests
+maxQueueSize int default=10
+
+# Use GPU
+useGpu bool default=false
+
+# Maximum number of model layers to run on GPU
+gpuLayers int default=1000000
+
+# Number of threads to use for CPU processing - -1 means use all available cores
+# Not used for GPU processing
+threads int default=-1
+
+# Context size for the model
+# Context is divided between parallel requests. So for 10 parallel requests, each "slot" gets 1/10 of the context
+contextSize int default=512
+
+# Maximum number of tokens to process in one request - overridden by inference parameters
+maxTokens int default=512
+
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
new file mode 100644
index 00000000000..35d5cfd3855
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
@@ -0,0 +1,174 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.container.di.componentgraph.Provider; +import com.yahoo.container.jdisc.SecretStoreProvider; +import com.yahoo.container.jdisc.secretstore.SecretStore; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConfigurableLanguageModelTest { + + @Test + public void testSyncGeneration() { + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey()); + assertEquals(1, result.size()); + assertEquals("Ducks have adorable waddling walks.", result.get(0).text()); + } + + @Test + public void testAsyncGeneration() { + var executor = Executors.newFixedThreadPool(1); + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var sb = new StringBuilder(); + try { + var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> { + sb.append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); + + var reason = future.join(); + assertTrue(future.isDone()); + assertNotEquals(reason, Completion.FinishReason.error); + } finally { + executor.shutdownNow(); + } + + assertEquals("Ducks have adorable waddling walks.", sb.toString()); + } + + @Test + public void testInferenceParameters() { + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4")); + var result = createLLM().complete(prompt, params); + assertEquals("Random text about ducks", result.get(0).text()); + } + + @Test + public void testNoApiKey() { + var prompt = StringPrompt.from(""); + var config = modelParams("api-key", null); + var secrets = createSecretStore(Map.of()); + assertThrows(IllegalArgumentException.class, () -> { + createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); + }); + } + + @Test + public void testApiKeyFromSecretStore() { + var prompt = StringPrompt.from(""); + var config = modelParams("api-key-in-secret-store", null); + var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY)); + assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); }); + } + + private static String lookupParameter(String parameter, Map params) { + return params.get(parameter); + } + + private static InferenceParameters inferenceParams() { + return new InferenceParameters(s -> lookupParameter(s, Map.of())); + } + + private static InferenceParameters inferenceParams(Map params) { + return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params)); + } + + private static InferenceParameters inferenceParamsWithDefaultKey() { + return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Map.of())); 
+ } + + private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) { + var config = new LlmClientConfig.Builder(); + if (apiKeySecretName != null) { + config.apiKeySecretName(apiKeySecretName); + } + if (endpoint != null) { + config.endpoint(endpoint); + } + return config.build(); + } + + public static SecretStore createSecretStore(Map secrets) { + Provider secretStore = new Provider<>() { + public SecretStore get() { + return new SecretStore() { + public String getSecret(String key) { + return secrets.get(key); + } + public String getSecret(String key, int version) { + return secrets.get(key); + } + }; + } + public void deconstruct() { + } + }; + return secretStore.get(); + } + + public static BiFunction createGenerator() { + return (prompt, options) -> { + String answer = "I have no opinion on the matter"; + if (prompt.asString().contains("ducks")) { + answer = "Ducks have adorable waddling walks."; + var temperature = options.getDouble("temperature"); + if (temperature.isPresent() && temperature.get() > 0.5) { + answer = "Random text about ducks vs cats that makes no sense whatsoever."; + } + } + var maxTokens = options.getInt("maxTokens"); + if (maxTokens.isPresent()) { + return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); + } + return answer; + }; + } + + private static MockLLMClient createLLM() { + LlmClientConfig config = new LlmClientConfig.Builder().build(); + return createLLM(config, null); + } + + private static MockLLMClient createLLM(ExecutorService executor) { + LlmClientConfig config = new LlmClientConfig.Builder().build(); + return createLLM(config, executor); + } + + private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) { + var generator = createGenerator(); + var secretStore = new SecretStoreProvider(); // throws exception on use + return createLLM(config, generator, secretStore.get(), executor); + } + + private static MockLLMClient createLLM(LlmClientConfig config, + BiFunction generator, + SecretStore secretStore) { + return createLLM(config, generator, secretStore, null); + } + + private static MockLLMClient createLLM(LlmClientConfig config, + BiFunction generator, + SecretStore secretStore, + ExecutorService executor) { + return new MockLLMClient(config, secretStore, generator, executor); + } + +} diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java new file mode 100644 index 00000000000..e85e397b7ff --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java @@ -0,0 +1,181 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.config.ModelReference; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests for LocalLLM. + * + * @author lesters + */ +public class LocalLLMTest { + + private static String model = "src/test/models/llm/tinyllm.gguf"; + private static Prompt prompt = StringPrompt.from("A random prompt"); + + @Test + public void testGeneration() { + var config = new LlmLocalClientConfig.Builder() + .parallelRequests(1) + .model(ModelReference.valueOf(model)); + var llm = new LocalLLM(config.build()); + + try { + var result = llm.complete(prompt, defaultOptions()); + assertEquals(Completion.FinishReason.stop, result.get(0).finishReason()); + assertTrue(result.get(0).text().length() > 10); + } finally { + llm.deconstruct(); + } + } + + @Test + public void testAsyncGeneration() { + var sb = new StringBuilder(); + var tokenCount = new AtomicInteger(0); + var config = new LlmLocalClientConfig.Builder() + .parallelRequests(1) + .model(ModelReference.valueOf(model)); + var llm = new LocalLLM(config.build()); + + try { + var future = llm.completeAsync(prompt, defaultOptions(), completion -> { + sb.append(completion.text()); + tokenCount.incrementAndGet(); + }).exceptionally(exception -> Completion.FinishReason.error); + + assertFalse(future.isDone()); + var reason = future.join(); + assertTrue(future.isDone()); + assertNotEquals(reason, Completion.FinishReason.error); + + } finally { + llm.deconstruct(); + } + assertTrue(tokenCount.get() > 0); +// System.out.println(sb); + } + + @Test + public void testParallelGeneration() { + var prompts = testPrompts(); + var promptsToUse = prompts.size(); + var parallelRequests = 10; + + var futures = new ArrayList>(Collections.nCopies(promptsToUse, null)); + var completions = new ArrayList(Collections.nCopies(promptsToUse, null)); + var tokenCounts = new ArrayList<>(Collections.nCopies(promptsToUse, 0)); + + var config = new LlmLocalClientConfig.Builder() + .parallelRequests(parallelRequests) + .model(ModelReference.valueOf(model)); + var llm = new LocalLLM(config.build()); + + try { + for (int i = 0; i < promptsToUse; i++) { + final var seq = i; + + completions.set(seq, new StringBuilder()); + futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { + completions.get(seq).append(completion.text()); + tokenCounts.set(seq, tokenCounts.get(seq) + 1); + }).exceptionally(exception -> Completion.FinishReason.error)); + } + for (int i = 0; i < promptsToUse; i++) { + var reason = futures.get(i).join(); + assertNotEquals(reason, Completion.FinishReason.error); + } + } finally { + llm.deconstruct(); + } + for (int i = 0; i < promptsToUse; i++) { + assertFalse(completions.get(i).isEmpty()); + assertTrue(tokenCounts.get(i) > 0); + } + } + + @Test + public void testRejection() { + var prompts = testPrompts(); + var 
promptsToUse = prompts.size(); + var parallelRequests = 2; + var additionalQueue = 1; + // 7 should be rejected + + var futures = new ArrayList>(Collections.nCopies(promptsToUse, null)); + var completions = new ArrayList(Collections.nCopies(promptsToUse, null)); + + var config = new LlmLocalClientConfig.Builder() + .parallelRequests(parallelRequests) + .maxQueueSize(additionalQueue) + .model(ModelReference.valueOf(model)); + var llm = new LocalLLM(config.build()); + + var rejected = new AtomicInteger(0); + try { + for (int i = 0; i < promptsToUse; i++) { + final var seq = i; + + completions.set(seq, new StringBuilder()); + try { + var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { + completions.get(seq).append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); + futures.set(seq, future); + } catch (RejectedExecutionException e) { + rejected.incrementAndGet(); + } + } + for (int i = 0; i < promptsToUse; i++) { + if (futures.get(i) != null) { + assertNotEquals(futures.get(i).join(), Completion.FinishReason.error); + } + } + } finally { + llm.deconstruct(); + } + assertEquals(7, rejected.get()); + } + + private static InferenceParameters defaultOptions() { + final Map options = Map.of( + "temperature", "0.1", + "npredict", "100" + ); + return new InferenceParameters(options::get); + } + + private List testPrompts() { + List prompts = new ArrayList<>(); + prompts.add("Write a short story about a time-traveling detective who must solve a mystery that spans multiple centuries."); + prompts.add("Explain the concept of blockchain technology and its implications for data security in layman's terms."); + prompts.add("Discuss the socio-economic impacts of the Industrial Revolution in 19th century Europe."); + prompts.add("Describe a future where humans have colonized Mars, focusing on daily life and societal structure."); + prompts.add("Analyze the statement 'If a tree falls in a forest and no one is around to hear it, does it make a sound?' from both a philosophical and a physics perspective."); + prompts.add("Translate the following sentence into French: 'The quick brown fox jumps over the lazy dog.'"); + prompts.add("Explain what the following Python code does: `print([x for x in range(10) if x % 2 == 0])`."); + prompts.add("Provide general guidelines for maintaining a healthy lifestyle to reduce the risk of developing heart disease."); + prompts.add("Create a detailed description of a fictional planet, including its ecosystem, dominant species, and technology level."); + prompts.add("Discuss the impact of social media on interpersonal communication in the 21st century."); + return prompts; + } + +} diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java new file mode 100644 index 00000000000..4d0073f1cbe --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java @@ -0,0 +1,80 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public class MockLLMClient extends ConfigurableLanguageModel { + + public final static String ACCEPTED_API_KEY = "sesame"; + + private final ExecutorService executor; + private final BiFunction generator; + + private Prompt lastPrompt; + + public MockLLMClient(LlmClientConfig config, + SecretStore secretStore, + BiFunction generator, + ExecutorService executor) { + super(config, secretStore); + this.generator = generator; + this.executor = executor; + } + + private void checkApiKey(InferenceParameters options) { + var apiKey = getApiKey(options); + if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) { + throw new IllegalArgumentException("Invalid API key"); + } + } + + private void setPrompt(Prompt prompt) { + this.lastPrompt = prompt; + } + + public Prompt getPrompt() { + return this.lastPrompt; + } + + @Override + public List complete(Prompt prompt, InferenceParameters params) { + setApiKey(params); + checkApiKey(params); + setPrompt(prompt); + return List.of(Completion.from(this.generator.apply(prompt, params))); + } + + @Override + public CompletableFuture completeAsync(Prompt prompt, + InferenceParameters params, + Consumer consumer) { + setPrompt(prompt); + var completionFuture = new CompletableFuture(); + var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization + + long sleep = 1; + executor.submit(() -> { + try { + for (int i=0; i < completions.length; ++i) { + String completion = (i > 0 ? " " : "") + completions[i]; + consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep); + } + completionFuture.complete(Completion.FinishReason.stop); + } catch (InterruptedException e) { + // Do nothing + } + }); + + return completionFuture; + } + +} diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java new file mode 100644 index 00000000000..57339f6ad49 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java @@ -0,0 +1,35 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+public class OpenAITest {
+
+    private static final String apiKey = "";
+
+    @Test
+    @Disabled
+    public void testOpenAIGeneration() {
+        var config = new LlmClientConfig.Builder().build();
+        var openai = new OpenAI(config, new SecretStoreProvider().get());
+        var options = Map.of(
+                "maxTokens", "10"
+        );
+
+        var prompt = StringPrompt.from("why are ducks better than cats?");
+        var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> {
+            System.out.print(completion.text());
+        }).exceptionally(exception -> {
+            System.out.println("Error: " + exception);
+            return null;
+        });
+        future.join();
+    }
+
+}
diff --git a/model-integration/src/test/models/llm/tinyllm.gguf b/model-integration/src/test/models/llm/tinyllm.gguf
new file mode 100644
index 00000000000..34367b6b57b
Binary files /dev/null and b/model-integration/src/test/models/llm/tinyllm.gguf differ
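
Taken together, the new classes are used roughly as in the sketch below. This is illustrative only and not part of the patch: it builds an LlmLocalClientConfig by hand (in a deployed application the config system produces it from llm-local-client.def), points it at a GGUF model file whose path is a placeholder, and runs one completion through LocalLLM, mirroring what LocalLLMTest does.

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.clients.LlmLocalClientConfig;
import ai.vespa.llm.clients.LocalLLM;
import ai.vespa.llm.completion.StringPrompt;
import com.yahoo.config.ModelReference;

import java.util.Map;

public class LocalLLMExample {

    public static void main(String[] args) {
        // Placeholder path - any GGUF model works, e.g. the tinyllm.gguf test model in this patch
        var config = new LlmLocalClientConfig.Builder()
                .parallelRequests(1)
                .model(ModelReference.valueOf("src/test/models/llm/tinyllm.gguf"))
                .build();
        var llm = new LocalLLM(config);
        try {
            // Inference options are the ones handled in LocalLLM.completeAsync
            var options = Map.of("temperature", "0.1", "npredict", "100");
            var completions = llm.complete(StringPrompt.from("A random prompt"),
                    new InferenceParameters(options::get));
            System.out.println(completions.get(0).text());
        } finally {
            llm.deconstruct();   // releases the native llama.cpp model and the request executor
        }
    }
}

In an actual Vespa application the component would instead be declared in the application package and configured through the llm-local-client config definition; constructing the config directly, as above and in the tests, is only for illustration.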