diff options
Diffstat (limited to 'container-search/src/main/java/ai')
6 files changed, 378 insertions, 0 deletions
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java new file mode 100644 index 00000000000..662d73d4e01 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java @@ -0,0 +1,74 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LlmClientConfig; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.logging.Logger; + + +/** + * Base class for language models that can be configured with config definitions. + * + * @author lesters + */ +@Beta +public abstract class ConfigurableLanguageModel implements LanguageModel { + + private static Logger log = Logger.getLogger(ai.vespa.llm.clients.ConfigurableLanguageModel.class.getName()); + + private final String apiKey; + private final String endpoint; + + public ConfigurableLanguageModel() { + this.apiKey = null; + this.endpoint = null; + } + + @Inject + public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) { + this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore); + this.endpoint = config.endpoint(); + } + + private static String findApiKeyInSecretStore(String property, SecretStore secretStore) { + String apiKey = ""; + if (property != null && ! property.isEmpty()) { + try { + apiKey = secretStore.getSecret(property); + } catch (UnsupportedOperationException e) { + // Secret store is not available - silently ignore this + } catch (Exception e) { + log.warning("Secret store look up failed: " + e.getMessage() + "\n" + + "Will expect API key in request header"); + } + } + return apiKey; + } + + protected String getApiKey(InferenceParameters params) { + return params.getApiKey().orElse(null); + } + + /** + * Set the API key as retrieved from secret store if it is not already set + */ + protected void setApiKey(InferenceParameters params) { + if (params.getApiKey().isEmpty() && apiKey != null) { + params.setApiKey(apiKey); + } + } + + protected String getEndpoint() { + return endpoint; + } + + protected void setEndpoint(InferenceParameters params) { + params.setEndpoint(endpoint); + } + +} diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java new file mode 100644 index 00000000000..f6092f51948 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -0,0 +1,49 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.client.openai.OpenAiClient; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * A configurable OpenAI client. + * + * @author lesters + */ +@Beta +public class OpenAI extends ConfigurableLanguageModel { + + private final OpenAiClient client; + + @Inject + public OpenAI(LlmClientConfig config, SecretStore secretStore) { + super(config, secretStore); + client = new OpenAiClient(); + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters parameters) { + setApiKey(parameters); + setEndpoint(parameters); + return client.complete(prompt, parameters); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters parameters, + Consumer<Completion> consumer) { + setApiKey(parameters); + setEndpoint(parameters); + return client.completeAsync(prompt, parameters, consumer); + } +} + diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java new file mode 100644 index 00000000000..c360245901c --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java @@ -0,0 +1,7 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.llm.clients; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java new file mode 100755 index 00000000000..6ff40401a8f --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java @@ -0,0 +1,166 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.search; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LanguageModelException; +import ai.vespa.llm.LlmSearcherConfig; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.ComponentId; +import com.yahoo.component.annotation.Inject; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.result.EventStream; +import com.yahoo.search.result.HitGroup; + +import java.util.List; +import java.util.function.Function; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +/** + * Base class for LLM searchers. Provides utilities for calling LLMs and handling properties. + * + * @author lesters + */ +@Beta +public abstract class LLMSearcher extends Searcher { + + private static Logger log = Logger.getLogger(LLMSearcher.class.getName()); + + private static final String API_KEY_HEADER = "X-LLM-API-KEY"; + private static final String STREAM_PROPERTY = "stream"; + private static final String PROMPT_PROPERTY = "prompt"; + + private final String propertyPrefix; + private final boolean stream; + private final LanguageModel languageModel; + private final String languageModelId; + + @Inject + LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { + this.stream = config.stream(); + this.languageModelId = config.providerId(); + this.languageModel = findLanguageModel(languageModelId, languageModels); + this.propertyPrefix = config.propertyPrefix(); + } + + private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels) + throws IllegalArgumentException + { + if (languageModels.allComponents().isEmpty()) { + throw new IllegalArgumentException("No language models were found"); + } + if (providerId == null || providerId.isEmpty()) { + var entry = languageModels.allComponentsById().entrySet().stream().findFirst(); + if (entry.isEmpty()) { + throw new IllegalArgumentException("No language models were found"); // shouldn't happen given check above + } + log.info("Language model provider was not found in config. " + + "Fallback to using first available language model: " + entry.get().getKey()); + return entry.get().getValue(); + } + final LanguageModel languageModel = languageModels.getComponent(providerId); + if (languageModel == null) { + throw new IllegalArgumentException("No component with id '" + providerId + "' was found. " + + "Available LLM components are: " + languageModels.allComponentsById().keySet().stream() + .map(ComponentId::toString).collect(Collectors.joining(","))); + } + return languageModel; + } + + Result complete(Query query, Prompt prompt) { + var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); + var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config + return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); + } + + private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { + EventStream eventStream = new EventStream(); + + if (query.getTrace().getLevel() >= 1) { + eventStream.add(prompt.asString(), "prompt"); + } + + languageModel.completeAsync(prompt, options, token -> { + eventStream.add(token.text()); + }).exceptionally(exception -> { + int errorCode = 400; + if (exception instanceof LanguageModelException languageModelException) { + errorCode = languageModelException.code(); + } + eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); + eventStream.markComplete(); + return Completion.FinishReason.error; + }).thenAccept(finishReason -> { + eventStream.markComplete(); + }); + + HitGroup hitGroup = new HitGroup("token_stream"); + hitGroup.add(eventStream); + return new Result(query, hitGroup); + } + + private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { + EventStream eventStream = new EventStream(); + + if (query.getTrace().getLevel() >= 1) { + eventStream.add(prompt.asString(), "prompt"); + } + + List<Completion> completions = languageModel.complete(prompt, options); + eventStream.add(completions.get(0).text(), "completion"); + eventStream.markComplete(); + + HitGroup hitGroup = new HitGroup("completion"); + hitGroup.add(eventStream); + return new Result(query, hitGroup); + } + + String getPrompt(Query query) { + // Look for prompt with or without prefix + String prompt = lookupPropertyWithOrWithoutPrefix(PROMPT_PROPERTY, p -> query.properties().getString(p)); + if (prompt != null) + return prompt; + + // If not found, use query directly + prompt = query.getModel().getQueryString(); + if (prompt != null) + return prompt; + + // If not, throw exception + throw new IllegalArgumentException("Could not find prompt found for query. Tried looking for " + + "'" + propertyPrefix + "." + PROMPT_PROPERTY + "', '" + PROMPT_PROPERTY + "' or '@query'."); + } + + String getPropertyPrefix() { + return this.propertyPrefix; + } + + String lookupProperty(String property, Query query) { + String propertyWithPrefix = this.propertyPrefix + "." + property; + return query.properties().getString(propertyWithPrefix, null); + } + + Boolean lookupPropertyBool(String property, Query query, boolean defaultValue) { + String propertyWithPrefix = this.propertyPrefix + "." + property; + return query.properties().getBoolean(propertyWithPrefix, defaultValue); + } + + String lookupPropertyWithOrWithoutPrefix(String property, Function<String, String> lookup) { + String value = lookup.apply(getPropertyPrefix() + "." + property); + if (value != null) + return value; + return lookup.apply(property); + } + + String getApiKeyHeader(Query query) { + return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); + } + +} diff --git a/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java new file mode 100755 index 00000000000..b8e33778ced --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java @@ -0,0 +1,75 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.search; + +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LlmSearcherConfig; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.searchchain.Execution; + +import java.util.logging.Logger; + +/** + * An LLM searcher that uses the RAG (Retrieval-Augmented Generation) model to generate completions. + * Prompts are generated based on the search result context. + * By default, the context is a concatenation of the fields of the search result hits. + * + * @author lesters + */ +@Beta +public class RAGSearcher extends LLMSearcher { + + private static Logger log = Logger.getLogger(RAGSearcher.class.getName()); + + private static final String CONTEXT_PROPERTY = "context"; + + @Inject + public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { + super(config, languageModels); + log.info("Starting " + RAGSearcher.class.getName() + " with language model " + config.providerId()); + } + + @Override + public Result search(Query query, Execution execution) { + Result result = execution.search(query); + execution.fill(result); + return complete(query, buildPrompt(query, result)); + } + + protected Prompt buildPrompt(Query query, Result result) { + String prompt = getPrompt(query); + + // Replace @query with the actual query + if (prompt.contains("@query")) { + prompt = prompt.replace("@query", query.getModel().getQueryString()); + } + + String context = lookupProperty(CONTEXT_PROPERTY, query); + if (context == null || !context.equals("skip")) { + if ( !prompt.contains("{context}")) { + prompt = "{context}\n" + prompt; + } + prompt = prompt.replace("{context}", buildContext(result)); + } + return StringPrompt.from(prompt); + } + + private String buildContext(Result result) { + StringBuilder sb = new StringBuilder(); + var hits = result.hits(); + hits.forEach(hit -> { + hit.fields().forEach((key, value) -> { + sb.append(key).append(": ").append(value).append("\n"); + }); + sb.append("\n"); + }); + var context = sb.toString(); + return context; + } + +} diff --git a/container-search/src/main/java/ai/vespa/llm/search/package-info.java b/container-search/src/main/java/ai/vespa/llm/search/package-info.java new file mode 100644 index 00000000000..6a8975fd2fa --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/search/package-info.java @@ -0,0 +1,7 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.llm.search; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; |