author    Lester Solbakken <lesters@users.noreply.github.com>  2024-04-15 14:51:27 +0200
committer GitHub <noreply@github.com>  2024-04-15 14:51:27 +0200
commit    7518d93961ac7c5c5da1cd41717d42f600dae647 (patch)
tree      63e2811a56e6bf6b2bed5e65e15c98458cfb357f /model-integration/src/main/java
parent    f7fd3dd205912c0100786e86d78b6de93d667bfa (diff)
Revert "Lesters/add local llms"
Diffstat (limited to 'model-integration/src/main/java')
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java   75
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java                    126
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java                       48
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/clients/package-info.java                  7
4 files changed, 0 insertions, 256 deletions
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
deleted file mode 100644
index 761fdf0af93..00000000000
--- a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.LanguageModel;
-import com.yahoo.api.annotations.Beta;
-import com.yahoo.component.annotation.Inject;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-
-import java.util.logging.Logger;
-
-
-/**
- * Base class for language models that can be configured with config definitions.
- *
- * @author lesters
- */
-@Beta
-public abstract class ConfigurableLanguageModel implements LanguageModel {
-
-    private static final Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName());
-
- private final String apiKey;
- private final String endpoint;
-
- public ConfigurableLanguageModel() {
- this.apiKey = null;
- this.endpoint = null;
- }
-
- @Inject
- public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) {
- this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore);
- this.endpoint = config.endpoint();
- }
-
- private static String findApiKeyInSecretStore(String property, SecretStore secretStore) {
- String apiKey = "";
- if (property != null && ! property.isEmpty()) {
- try {
- apiKey = secretStore.getSecret(property);
- } catch (UnsupportedOperationException e) {
- // Secret store is not available - silently ignore this
- } catch (Exception e) {
- log.warning("Secret store look up failed: " + e.getMessage() + "\n" +
- "Will expect API key in request header");
- }
- }
- return apiKey;
- }
-
- protected String getApiKey(InferenceParameters params) {
- return params.getApiKey().orElse(null);
- }
-
-    /**
-     * Sets the API key, as retrieved from the secret store, if it is not already set.
-     */
- protected void setApiKey(InferenceParameters params) {
- if (params.getApiKey().isEmpty() && apiKey != null) {
- params.setApiKey(apiKey);
- }
- }
-
- protected String getEndpoint() {
- return endpoint;
- }
-
- protected void setEndpoint(InferenceParameters params) {
- if (endpoint != null && ! endpoint.isEmpty()) {
- params.setEndpoint(endpoint);
- }
- }
-
-}
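
For reference, the pattern this deleted base class implemented: the API key is resolved from the container's SecretStore once, at construction, and setApiKey then only fills it into a request's InferenceParameters when the request carries no key of its own, so per-request keys always win. A minimal sketch of a subclass using these hooks follows; EchoModel and its trivial echo completion are hypothetical and only illustrate where the hooks are called.

// Hypothetical subclass of the reverted ConfigurableLanguageModel, showing
// where the setApiKey/setEndpoint hooks are meant to be invoked.
package ai.vespa.llm.clients;

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.component.annotation.Inject;
import com.yahoo.container.jdisc.secretstore.SecretStore;

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Consumer;

public class EchoModel extends ConfigurableLanguageModel {

    @Inject
    public EchoModel(LlmClientConfig config, SecretStore secretStore) {
        super(config, secretStore);
    }

    @Override
    public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
        setApiKey(parameters);    // falls back to the secret-store key if the request has none
        setEndpoint(parameters);  // applies the configured endpoint, if any
        return List.of(Completion.from(prompt.asString(), Completion.FinishReason.stop));
    }

    @Override
    public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
                                                                    InferenceParameters parameters,
                                                                    Consumer<Completion> consumer) {
        setApiKey(parameters);
        setEndpoint(parameters);
        consumer.accept(Completion.from(prompt.asString(), Completion.FinishReason.stop));
        return CompletableFuture.completedFuture(Completion.FinishReason.stop);
    }
}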
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
deleted file mode 100644
index fd1b8b700c8..00000000000
--- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.LanguageModel;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import com.yahoo.component.AbstractComponent;
-import com.yahoo.component.annotation.Inject;
-import de.kherud.llama.LlamaModel;
-import de.kherud.llama.ModelParameters;
-import de.kherud.llama.args.LogFormat;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.ArrayBlockingQueue;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.RejectedExecutionException;
-import java.util.concurrent.SynchronousQueue;
-import java.util.concurrent.ThreadPoolExecutor;
-import java.util.concurrent.TimeUnit;
-import java.util.function.Consumer;
-import java.util.logging.Logger;
-
-/**
- * A language model running locally on the container node.
- *
- * @author lesters
- */
-public class LocalLLM extends AbstractComponent implements LanguageModel {
-
-    private static final Logger logger = Logger.getLogger(LocalLLM.class.getName());
- private final LlamaModel model;
- private final ThreadPoolExecutor executor;
- private final int contextSize;
- private final int maxTokens;
-
- @Inject
- public LocalLLM(LlmLocalClientConfig config) {
- executor = createExecutor(config);
-
-        // Maximum number of tokens to generate - needed since some models can otherwise generate indefinitely
- maxTokens = config.maxTokens();
-
-        // Thread count only applies when not running on the GPU
-        var defaultThreadCount = Math.max(1, Runtime.getRuntime().availableProcessors() - 2);
-
- var modelFile = config.model().toFile().getAbsolutePath();
- var modelParams = new ModelParameters()
- .setModelFilePath(modelFile)
- .setContinuousBatching(true)
- .setNParallel(config.parallelRequests())
- .setNThreads(config.threads() <= 0 ? defaultThreadCount : config.threads())
- .setNCtx(config.contextSize())
- .setNGpuLayers(config.useGpu() ? config.gpuLayers() : 0);
-
- long startLoad = System.nanoTime();
- model = new LlamaModel(modelParams);
- long loadTime = System.nanoTime() - startLoad;
- logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000)));
-
-        // TODO: handle prompt context size, e.g. warn when the prompt exceeds the context size
- contextSize = config.contextSize();
- }
-
- private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) {
- return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(),
- 0L, TimeUnit.MILLISECONDS,
- config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(),
- new ThreadPoolExecutor.AbortPolicy());
- }
-
- @Override
- public void deconstruct() {
- logger.info("Closing LLM model...");
- model.close();
- executor.shutdownNow();
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters options) {
- StringBuilder result = new StringBuilder();
- var future = completeAsync(prompt, options, completion -> {
- result.append(completion.text());
- }).exceptionally(exception -> Completion.FinishReason.error);
- var reason = future.join();
-
- List<Completion> completions = new ArrayList<>();
- completions.add(new Completion(result.toString(), reason));
- return completions;
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
- var inferParams = new de.kherud.llama.InferenceParameters(prompt.asString().stripLeading());
-
- // We always set this to some value to avoid infinite token generation
- inferParams.setNPredict(maxTokens);
-
- options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v)));
- options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v)));
- options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v)));
- options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v)));
- options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v)));
-        // TODO: support more options?
-
- var completionFuture = new CompletableFuture<Completion.FinishReason>();
- try {
- executor.submit(() -> {
- for (LlamaModel.Output output : model.generate(inferParams)) {
- consumer.accept(Completion.from(output.text, Completion.FinishReason.none));
- }
- completionFuture.complete(Completion.FinishReason.stop);
- });
- } catch (RejectedExecutionException e) {
- // If we have too many requests (active + any waiting in queue), we reject the completion
- int activeCount = executor.getActiveCount();
- int queueSize = executor.getQueue().size();
- String error = String.format("Rejected completion due to too many requests, " +
- "%d active, %d in queue", activeCount, queueSize);
- throw new RejectedExecutionException(error);
- }
- return completionFuture;
- }
-
-}
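
Caller-side usage of LocalLLM, for reference: sampling options travel as strings in InferenceParameters, matching the ifPresent lookups in completeAsync above. The sketch below rests on two assumptions not visible in this diff - that InferenceParameters can be constructed from a string-keyed lookup such as Map::get, and that StringPrompt.from builds a Prompt.

// Usage sketch for the reverted LocalLLM. The InferenceParameters(Map::get)
// constructor and StringPrompt.from(...) are assumed ai.vespa.llm helpers.
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.StringPrompt;

import java.util.List;
import java.util.Map;

public class LocalLLMSketch {

    static String completeSync(LocalLLM llm, String promptText) {
        var options = new InferenceParameters(Map.of(
                "temperature", "0.7",  // parsed with Float.parseFloat in completeAsync
                "npredict", "128"      // overrides the configured maxTokens cap
        )::get);
        // complete() blocks until generation finishes, collecting the streamed tokens
        List<Completion> completions = llm.complete(StringPrompt.from(promptText), options);
        return completions.get(0).text();
    }
}

Note the executor design: a fixed pool of parallelRequests threads backed by an ArrayBlockingQueue of maxQueueSize (or a SynchronousQueue when queueing is disabled), combined with AbortPolicy, means a request beyond the active plus queued capacity fails fast with RejectedExecutionException instead of queueing unboundedly.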
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
deleted file mode 100644
index 82e19d47c92..00000000000
--- a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.client.openai.OpenAiClient;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import com.yahoo.api.annotations.Beta;
-import com.yahoo.component.annotation.Inject;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.function.Consumer;
-
-/**
- * A configurable OpenAI client.
- *
- * @author lesters
- */
-@Beta
-public class OpenAI extends ConfigurableLanguageModel {
-
- private final OpenAiClient client;
-
- @Inject
- public OpenAI(LlmClientConfig config, SecretStore secretStore) {
- super(config, secretStore);
- client = new OpenAiClient();
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
- setApiKey(parameters);
- setEndpoint(parameters);
- return client.complete(prompt, parameters);
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
- InferenceParameters parameters,
- Consumer<Completion> consumer) {
- setApiKey(parameters);
- setEndpoint(parameters);
- return client.completeAsync(prompt, parameters, consumer);
- }
-}
-
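
The per-request key flow in the OpenAI client above: since ConfigurableLanguageModel.setApiKey leaves an already-set key untouched, a caller can override the secret-store key on a single request. A short sketch, again assuming the Map::get-based InferenceParameters constructor:

// Sketch of a per-request API key override for the reverted OpenAI client.
// The InferenceParameters(Map::get) constructor is an assumed helper.
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;

import java.util.List;
import java.util.Map;

public class OpenAISketch {

    static List<Completion> completeWithKey(OpenAI openai, Prompt prompt, String requestKey) {
        var params = new InferenceParameters(Map.of()::get);
        if (requestKey != null)
            params.setApiKey(requestKey);  // setApiKey in complete() will then keep this key
        return openai.complete(prompt, params);
    }
}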
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java
deleted file mode 100644
index c360245901c..00000000000
--- a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java
+++ /dev/null
@@ -1,7 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-@ExportPackage
-@PublicApi
-package ai.vespa.llm.clients;
-
-import com.yahoo.api.annotations.PublicApi;
-import com.yahoo.osgi.annotation.ExportPackage;