diff options
Diffstat (limited to 'model-integration/src/main')
6 files changed, 293 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java new file mode 100644 index 00000000000..761fdf0af93 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java @@ -0,0 +1,75 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.llm.clients;

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.container.jdisc.secretstore.SecretStore;

import java.util.logging.Logger;


/**
 * Base class for language models that can be configured with config definitions.
 * Holds an optional API key (looked up in the secret store) and an optional
 * endpoint override, and copies them into per-request {@link InferenceParameters}
 * when the request itself does not supply them.
 *
 * @author lesters
 */
@Beta
public abstract class ConfigurableLanguageModel implements LanguageModel {

    // final: logger is a constant-per-class resource
    private static final Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName());

    private final String apiKey;
    private final String endpoint;

    /** Constructor for subclasses that do not use config-driven key/endpoint resolution. */
    public ConfigurableLanguageModel() {
        this.apiKey = null;
        this.endpoint = null;
    }

    @Inject
    public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) {
        this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore);
        this.endpoint = config.endpoint();
    }

    /**
     * Looks up the API key in the secret store.
     * Returns "" when no secret name is configured or the lookup fails;
     * in that case the API key is expected in the request header instead.
     */
    private static String findApiKeyInSecretStore(String property, SecretStore secretStore) {
        String apiKey = "";
        if (property != null && ! property.isEmpty()) {
            try {
                apiKey = secretStore.getSecret(property);
            } catch (UnsupportedOperationException e) {
                // Secret store is not available - silently ignore this
            } catch (Exception e) {
                log.warning("Secret store look up failed: " + e.getMessage() + "\n" +
                        "Will expect API key in request header");
            }
        }
        return apiKey;
    }

    protected String getApiKey(InferenceParameters params) {
        return params.getApiKey().orElse(null);
    }

    /**
     * Set the API key as retrieved from secret store if it is not already set
     */
    protected void setApiKey(InferenceParameters params) {
        // Fix: also skip the empty-string sentinel stored when no secret was found,
        // so a missing/failed secret lookup does not override the documented
        // "expect API key in request header" fallback with a blank key.
        if (params.getApiKey().isEmpty() && apiKey != null && ! apiKey.isEmpty()) {
            params.setApiKey(apiKey);
        }
    }

    protected String getEndpoint() {
        return endpoint;
    }

    /** Sets the configured endpoint override on the request parameters, if one is configured. */
    protected void setEndpoint(InferenceParameters params) {
        if (endpoint != null && ! endpoint.isEmpty()) {
            params.setEndpoint(endpoint);
        }
    }

}
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java new file mode 100644 index 00000000000..fd1b8b700c8 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java @@ -0,0 +1,126 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import de.kherud.llama.LlamaModel; +import de.kherud.llama.ModelParameters; +import de.kherud.llama.args.LogFormat; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.logging.Logger; + +/** + * A language model running locally on the container node. + * + * @author lesters + */ +public class LocalLLM extends AbstractComponent implements LanguageModel { + + private final static Logger logger = Logger.getLogger(LocalLLM.class.getName()); + private final LlamaModel model; + private final ThreadPoolExecutor executor; + private final int contextSize; + private final int maxTokens; + + @Inject + public LocalLLM(LlmLocalClientConfig config) { + executor = createExecutor(config); + + // Maximum number of tokens to generate - need this since some models can just generate infinitely + maxTokens = config.maxTokens(); + + // Only used if GPU is not used + var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2; + + var modelFile = config.model().toFile().getAbsolutePath(); + var modelParams = new ModelParameters() + .setModelFilePath(modelFile) + .setContinuousBatching(true) + .setNParallel(config.parallelRequests()) + .setNThreads(config.threads() <= 0 ? defaultThreadCount : config.threads()) + .setNCtx(config.contextSize()) + .setNGpuLayers(config.useGpu() ? 
config.gpuLayers() : 0); + + long startLoad = System.nanoTime(); + model = new LlamaModel(modelParams); + long loadTime = System.nanoTime() - startLoad; + logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000))); + + // Todo: handle prompt context size - such as give a warning when prompt exceeds context size + contextSize = config.contextSize(); + } + + private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) { + return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(), + 0L, TimeUnit.MILLISECONDS, + config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(), + new ThreadPoolExecutor.AbortPolicy()); + } + + @Override + public void deconstruct() { + logger.info("Closing LLM model..."); + model.close(); + executor.shutdownNow(); + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters options) { + StringBuilder result = new StringBuilder(); + var future = completeAsync(prompt, options, completion -> { + result.append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); + var reason = future.join(); + + List<Completion> completions = new ArrayList<>(); + completions.add(new Completion(result.toString(), reason)); + return completions; + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) { + var inferParams = new de.kherud.llama.InferenceParameters(prompt.asString().stripLeading()); + + // We always set this to some value to avoid infinite token generation + inferParams.setNPredict(maxTokens); + + options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v))); + options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v))); + options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v))); + options.ifPresent("npredict", (v) -> 
inferParams.setNPredict(Integer.parseInt(v))); + options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v))); + // Todo: more options? + + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + try { + executor.submit(() -> { + for (LlamaModel.Output output : model.generate(inferParams)) { + consumer.accept(Completion.from(output.text, Completion.FinishReason.none)); + } + completionFuture.complete(Completion.FinishReason.stop); + }); + } catch (RejectedExecutionException e) { + // If we have too many requests (active + any waiting in queue), we reject the completion + int activeCount = executor.getActiveCount(); + int queueSize = executor.getQueue().size(); + String error = String.format("Rejected completion due to too many requests, " + + "%d active, %d in queue", activeCount, queueSize); + throw new RejectedExecutionException(error); + } + return completionFuture; + } + +} diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java new file mode 100644 index 00000000000..82e19d47c92 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -0,0 +1,48 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.client.openai.OpenAiClient; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * A configurable OpenAI client. 
+ * + * @author lesters + */ +@Beta +public class OpenAI extends ConfigurableLanguageModel { + + private final OpenAiClient client; + + @Inject + public OpenAI(LlmClientConfig config, SecretStore secretStore) { + super(config, secretStore); + client = new OpenAiClient(); + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters parameters) { + setApiKey(parameters); + setEndpoint(parameters); + return client.complete(prompt, parameters); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters parameters, + Consumer<Completion> consumer) { + setApiKey(parameters); + setEndpoint(parameters); + return client.completeAsync(prompt, parameters, consumer); + } +} + diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java new file mode 100644 index 00000000000..c360245901c --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java @@ -0,0 +1,7 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.llm.clients; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/model-integration/src/main/resources/configdefinitions/llm-client.def b/model-integration/src/main/resources/configdefinitions/llm-client.def new file mode 100755 index 00000000000..0866459166a --- /dev/null +++ b/model-integration/src/main/resources/configdefinitions/llm-client.def @@ -0,0 +1,8 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package=ai.vespa.llm.clients

# The name of the secret containing the api key
apiKeySecretName string default=""

# Endpoint for LLM client - if not set reverts to default for client
endpoint string default=""
diff --git a/model-integration/src/main/resources/configdefinitions/llm-local-client.def b/model-integration/src/main/resources/configdefinitions/llm-local-client.def new file mode 100755 index 00000000000..c06c24b33e5 --- /dev/null +++ b/model-integration/src/main/resources/configdefinitions/llm-local-client.def @@ -0,0 +1,29 @@
# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package=ai.vespa.llm.clients

# The LLM model to use
model model

# Maximum number of requests to handle in parallel per container node
parallelRequests int default=10

# Additional number of requests to put in queue for processing before starting to reject new requests
maxQueueSize int default=10

# Use GPU
useGpu bool default=false

# Maximum number of model layers to run on GPU
gpuLayers int default=1000000

# Number of threads to use for CPU processing - -1 means use all available cores
# Not used for GPU processing
threads int default=-1

# Context size for the model
# Context is divided between parallel requests. So for 10 parallel requests, each "slot" gets 1/10 of the context
contextSize int default=512

# Maximum number of tokens to process in one request - overridden by inference parameters
maxTokens int default=512