diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2024-04-15 14:51:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-15 14:51:27 +0200 |
commit | 7518d93961ac7c5c5da1cd41717d42f600dae647 (patch) | |
tree | 63e2811a56e6bf6b2bed5e65e15c98458cfb357f /model-integration/src/main/java | |
parent | f7fd3dd205912c0100786e86d78b6de93d667bfa (diff) |
Revert "Lesters/add local llms"
Diffstat (limited to 'model-integration/src/main/java')
4 files changed, 0 insertions, 256 deletions
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java deleted file mode 100644 index 761fdf0af93..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.LanguageModel; -import com.yahoo.api.annotations.Beta; -import com.yahoo.component.annotation.Inject; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.logging.Logger; - - -/** - * Base class for language models that can be configured with config definitions. - * - * @author lesters - */ -@Beta -public abstract class ConfigurableLanguageModel implements LanguageModel { - - private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); - - private final String apiKey; - private final String endpoint; - - public ConfigurableLanguageModel() { - this.apiKey = null; - this.endpoint = null; - } - - @Inject - public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) { - this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore); - this.endpoint = config.endpoint(); - } - - private static String findApiKeyInSecretStore(String property, SecretStore secretStore) { - String apiKey = ""; - if (property != null && ! property.isEmpty()) { - try { - apiKey = secretStore.getSecret(property); - } catch (UnsupportedOperationException e) { - // Secret store is not available - silently ignore this - } catch (Exception e) { - log.warning("Secret store look up failed: " + e.getMessage() + "\n" + - "Will expect API key in request header"); - } - } - return apiKey; - } - - protected String getApiKey(InferenceParameters params) { - return params.getApiKey().orElse(null); - } - - /** - * Set the API key as retrieved from secret store if it is not already set - */ - protected void setApiKey(InferenceParameters params) { - if (params.getApiKey().isEmpty() && apiKey != null) { - params.setApiKey(apiKey); - } - } - - protected String getEndpoint() { - return endpoint; - } - - protected void setEndpoint(InferenceParameters params) { - if (endpoint != null && ! endpoint.isEmpty()) { - params.setEndpoint(endpoint); - } - } - -} diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java deleted file mode 100644 index fd1b8b700c8..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.component.AbstractComponent; -import com.yahoo.component.annotation.Inject; -import de.kherud.llama.LlamaModel; -import de.kherud.llama.ModelParameters; -import de.kherud.llama.args.LogFormat; - -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.SynchronousQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; -import java.util.logging.Logger; - -/** - * A language model running locally on the container node. - * - * @author lesters - */ -public class LocalLLM extends AbstractComponent implements LanguageModel { - - private final static Logger logger = Logger.getLogger(LocalLLM.class.getName()); - private final LlamaModel model; - private final ThreadPoolExecutor executor; - private final int contextSize; - private final int maxTokens; - - @Inject - public LocalLLM(LlmLocalClientConfig config) { - executor = createExecutor(config); - - // Maximum number of tokens to generate - need this since some models can just generate infinitely - maxTokens = config.maxTokens(); - - // Only used if GPU is not used - var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2; - - var modelFile = config.model().toFile().getAbsolutePath(); - var modelParams = new ModelParameters() - .setModelFilePath(modelFile) - .setContinuousBatching(true) - .setNParallel(config.parallelRequests()) - .setNThreads(config.threads() <= 0 ? defaultThreadCount : config.threads()) - .setNCtx(config.contextSize()) - .setNGpuLayers(config.useGpu() ? config.gpuLayers() : 0); - - long startLoad = System.nanoTime(); - model = new LlamaModel(modelParams); - long loadTime = System.nanoTime() - startLoad; - logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000))); - - // Todo: handle prompt context size - such as give a warning when prompt exceeds context size - contextSize = config.contextSize(); - } - - private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) { - return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(), - 0L, TimeUnit.MILLISECONDS, - config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(), - new ThreadPoolExecutor.AbortPolicy()); - } - - @Override - public void deconstruct() { - logger.info("Closing LLM model..."); - model.close(); - executor.shutdownNow(); - } - - @Override - public List<Completion> complete(Prompt prompt, InferenceParameters options) { - StringBuilder result = new StringBuilder(); - var future = completeAsync(prompt, options, completion -> { - result.append(completion.text()); - }).exceptionally(exception -> Completion.FinishReason.error); - var reason = future.join(); - - List<Completion> completions = new ArrayList<>(); - completions.add(new Completion(result.toString(), reason)); - return completions; - } - - @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) { - var inferParams = new de.kherud.llama.InferenceParameters(prompt.asString().stripLeading()); - - // We always set this to some value to avoid infinite token generation - inferParams.setNPredict(maxTokens); - - options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v))); - options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v))); - options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v))); - options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v))); - options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v))); - // Todo: more options? - - var completionFuture = new CompletableFuture<Completion.FinishReason>(); - try { - executor.submit(() -> { - for (LlamaModel.Output output : model.generate(inferParams)) { - consumer.accept(Completion.from(output.text, Completion.FinishReason.none)); - } - completionFuture.complete(Completion.FinishReason.stop); - }); - } catch (RejectedExecutionException e) { - // If we have too many requests (active + any waiting in queue), we reject the completion - int activeCount = executor.getActiveCount(); - int queueSize = executor.getQueue().size(); - String error = String.format("Rejected completion due to too many requests, " + - "%d active, %d in queue", activeCount, queueSize); - throw new RejectedExecutionException(error); - } - return completionFuture; - } - -} diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java deleted file mode 100644 index 82e19d47c92..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.client.openai.OpenAiClient; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.api.annotations.Beta; -import com.yahoo.component.annotation.Inject; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; - -/** - * A configurable OpenAI client. - * - * @author lesters - */ -@Beta -public class OpenAI extends ConfigurableLanguageModel { - - private final OpenAiClient client; - - @Inject - public OpenAI(LlmClientConfig config, SecretStore secretStore) { - super(config, secretStore); - client = new OpenAiClient(); - } - - @Override - public List<Completion> complete(Prompt prompt, InferenceParameters parameters) { - setApiKey(parameters); - setEndpoint(parameters); - return client.complete(prompt, parameters); - } - - @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, - InferenceParameters parameters, - Consumer<Completion> consumer) { - setApiKey(parameters); - setEndpoint(parameters); - return client.completeAsync(prompt, parameters, consumer); - } -} - diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java deleted file mode 100644 index c360245901c..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -@PublicApi -package ai.vespa.llm.clients; - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage; |