aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/main')
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java74
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java49
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/package-info.java7
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java166
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java75
-rw-r--r--container-search/src/main/java/ai/vespa/llm/search/package-info.java7
-rwxr-xr-xcontainer-search/src/main/resources/configdefinitions/llm-client.def8
-rwxr-xr-xcontainer-search/src/main/resources/configdefinitions/llm-searcher.def11
8 files changed, 397 insertions, 0 deletions
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
new file mode 100644
index 00000000000..662d73d4e01
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
@@ -0,0 +1,74 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmClientConfig;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.logging.Logger;
+
+
+/**
+ * Base class for language models that can be configured with config definitions.
+ *
+ * @author lesters
+ */
+@Beta
+public abstract class ConfigurableLanguageModel implements LanguageModel {
+
+ private static Logger log = Logger.getLogger(ai.vespa.llm.clients.ConfigurableLanguageModel.class.getName());
+
+ private final String apiKey;
+ private final String endpoint;
+
+ public ConfigurableLanguageModel() {
+ this.apiKey = null;
+ this.endpoint = null;
+ }
+
+ @Inject
+ public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) {
+ this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore);
+ this.endpoint = config.endpoint();
+ }
+
+ private static String findApiKeyInSecretStore(String property, SecretStore secretStore) {
+ String apiKey = "";
+ if (property != null && ! property.isEmpty()) {
+ try {
+ apiKey = secretStore.getSecret(property);
+ } catch (UnsupportedOperationException e) {
+ // Secret store is not available - silently ignore this
+ } catch (Exception e) {
+ log.warning("Secret store look up failed: " + e.getMessage() + "\n" +
+ "Will expect API key in request header");
+ }
+ }
+ return apiKey;
+ }
+
+ protected String getApiKey(InferenceParameters params) {
+ return params.getApiKey().orElse(null);
+ }
+
+ /**
+ * Set the API key as retrieved from secret store if it is not already set
+ */
+ protected void setApiKey(InferenceParameters params) {
+ if (params.getApiKey().isEmpty() && apiKey != null) {
+ params.setApiKey(apiKey);
+ }
+ }
+
+ protected String getEndpoint() {
+ return endpoint;
+ }
+
+ protected void setEndpoint(InferenceParameters params) {
+ params.setEndpoint(endpoint);
+ }
+
+}
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
new file mode 100644
index 00000000000..f6092f51948
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
@@ -0,0 +1,49 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.client.openai.OpenAiClient;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+
+/**
+ * A configurable OpenAI client.
+ *
+ * @author lesters
+ */
+@Beta
+public class OpenAI extends ConfigurableLanguageModel {
+
+ private final OpenAiClient client;
+
+ @Inject
+ public OpenAI(LlmClientConfig config, SecretStore secretStore) {
+ super(config, secretStore);
+ client = new OpenAiClient();
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
+ setApiKey(parameters);
+ setEndpoint(parameters);
+ return client.complete(prompt, parameters);
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters parameters,
+ Consumer<Completion> consumer) {
+ setApiKey(parameters);
+ setEndpoint(parameters);
+ return client.completeAsync(prompt, parameters, consumer);
+ }
+}
+
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
new file mode 100644
index 00000000000..c360245901c
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
@@ -0,0 +1,7 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.clients;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java
new file mode 100755
index 00000000000..6ff40401a8f
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java
@@ -0,0 +1,166 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LanguageModelException;
+import ai.vespa.llm.LlmSearcherConfig;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.ComponentId;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.EventStream;
+import com.yahoo.search.result.HitGroup;
+
+import java.util.List;
+import java.util.function.Function;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+/**
+ * Base class for LLM searchers. Provides utilities for calling LLMs and handling properties.
+ *
+ * @author lesters
+ */
+@Beta
+public abstract class LLMSearcher extends Searcher {
+
+ private static Logger log = Logger.getLogger(LLMSearcher.class.getName());
+
+ private static final String API_KEY_HEADER = "X-LLM-API-KEY";
+ private static final String STREAM_PROPERTY = "stream";
+ private static final String PROMPT_PROPERTY = "prompt";
+
+ private final String propertyPrefix;
+ private final boolean stream;
+ private final LanguageModel languageModel;
+ private final String languageModelId;
+
+ @Inject
+ LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
+ this.stream = config.stream();
+ this.languageModelId = config.providerId();
+ this.languageModel = findLanguageModel(languageModelId, languageModels);
+ this.propertyPrefix = config.propertyPrefix();
+ }
+
+ private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels)
+ throws IllegalArgumentException
+ {
+ if (languageModels.allComponents().isEmpty()) {
+ throw new IllegalArgumentException("No language models were found");
+ }
+ if (providerId == null || providerId.isEmpty()) {
+ var entry = languageModels.allComponentsById().entrySet().stream().findFirst();
+ if (entry.isEmpty()) {
+ throw new IllegalArgumentException("No language models were found"); // shouldn't happen given check above
+ }
+ log.info("Language model provider was not found in config. " +
+ "Fallback to using first available language model: " + entry.get().getKey());
+ return entry.get().getValue();
+ }
+ final LanguageModel languageModel = languageModels.getComponent(providerId);
+ if (languageModel == null) {
+ throw new IllegalArgumentException("No component with id '" + providerId + "' was found. " +
+ "Available LLM components are: " + languageModels.allComponentsById().keySet().stream()
+ .map(ComponentId::toString).collect(Collectors.joining(",")));
+ }
+ return languageModel;
+ }
+
+ Result complete(Query query, Prompt prompt) {
+ var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
+ var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
+ return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ }
+
+ private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
+ EventStream eventStream = new EventStream();
+
+ if (query.getTrace().getLevel() >= 1) {
+ eventStream.add(prompt.asString(), "prompt");
+ }
+
+ languageModel.completeAsync(prompt, options, token -> {
+ eventStream.add(token.text());
+ }).exceptionally(exception -> {
+ int errorCode = 400;
+ if (exception instanceof LanguageModelException languageModelException) {
+ errorCode = languageModelException.code();
+ }
+ eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ eventStream.markComplete();
+ return Completion.FinishReason.error;
+ }).thenAccept(finishReason -> {
+ eventStream.markComplete();
+ });
+
+ HitGroup hitGroup = new HitGroup("token_stream");
+ hitGroup.add(eventStream);
+ return new Result(query, hitGroup);
+ }
+
+ private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
+ EventStream eventStream = new EventStream();
+
+ if (query.getTrace().getLevel() >= 1) {
+ eventStream.add(prompt.asString(), "prompt");
+ }
+
+ List<Completion> completions = languageModel.complete(prompt, options);
+ eventStream.add(completions.get(0).text(), "completion");
+ eventStream.markComplete();
+
+ HitGroup hitGroup = new HitGroup("completion");
+ hitGroup.add(eventStream);
+ return new Result(query, hitGroup);
+ }
+
+ String getPrompt(Query query) {
+ // Look for prompt with or without prefix
+ String prompt = lookupPropertyWithOrWithoutPrefix(PROMPT_PROPERTY, p -> query.properties().getString(p));
+ if (prompt != null)
+ return prompt;
+
+ // If not found, use query directly
+ prompt = query.getModel().getQueryString();
+ if (prompt != null)
+ return prompt;
+
+ // If not, throw exception
+ throw new IllegalArgumentException("Could not find prompt found for query. Tried looking for " +
+ "'" + propertyPrefix + "." + PROMPT_PROPERTY + "', '" + PROMPT_PROPERTY + "' or '@query'.");
+ }
+
+ String getPropertyPrefix() {
+ return this.propertyPrefix;
+ }
+
+ String lookupProperty(String property, Query query) {
+ String propertyWithPrefix = this.propertyPrefix + "." + property;
+ return query.properties().getString(propertyWithPrefix, null);
+ }
+
+ Boolean lookupPropertyBool(String property, Query query, boolean defaultValue) {
+ String propertyWithPrefix = this.propertyPrefix + "." + property;
+ return query.properties().getBoolean(propertyWithPrefix, defaultValue);
+ }
+
+ String lookupPropertyWithOrWithoutPrefix(String property, Function<String, String> lookup) {
+ String value = lookup.apply(getPropertyPrefix() + "." + property);
+ if (value != null)
+ return value;
+ return lookup.apply(property);
+ }
+
+ String getApiKeyHeader(Query query) {
+ return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
+ }
+
+}
diff --git a/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java
new file mode 100755
index 00000000000..b8e33778ced
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java
@@ -0,0 +1,75 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmSearcherConfig;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.searchchain.Execution;
+
+import java.util.logging.Logger;
+
+/**
+ * An LLM searcher that uses the RAG (Retrieval-Augmented Generation) model to generate completions.
+ * Prompts are generated based on the search result context.
+ * By default, the context is a concatenation of the fields of the search result hits.
+ *
+ * @author lesters
+ */
+@Beta
+public class RAGSearcher extends LLMSearcher {
+
+ private static Logger log = Logger.getLogger(RAGSearcher.class.getName());
+
+ private static final String CONTEXT_PROPERTY = "context";
+
+ @Inject
+ public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
+ super(config, languageModels);
+ log.info("Starting " + RAGSearcher.class.getName() + " with language model " + config.providerId());
+ }
+
+ @Override
+ public Result search(Query query, Execution execution) {
+ Result result = execution.search(query);
+ execution.fill(result);
+ return complete(query, buildPrompt(query, result));
+ }
+
+ protected Prompt buildPrompt(Query query, Result result) {
+ String prompt = getPrompt(query);
+
+ // Replace @query with the actual query
+ if (prompt.contains("@query")) {
+ prompt = prompt.replace("@query", query.getModel().getQueryString());
+ }
+
+ String context = lookupProperty(CONTEXT_PROPERTY, query);
+ if (context == null || !context.equals("skip")) {
+ if ( !prompt.contains("{context}")) {
+ prompt = "{context}\n" + prompt;
+ }
+ prompt = prompt.replace("{context}", buildContext(result));
+ }
+ return StringPrompt.from(prompt);
+ }
+
+ private String buildContext(Result result) {
+ StringBuilder sb = new StringBuilder();
+ var hits = result.hits();
+ hits.forEach(hit -> {
+ hit.fields().forEach((key, value) -> {
+ sb.append(key).append(": ").append(value).append("\n");
+ });
+ sb.append("\n");
+ });
+ var context = sb.toString();
+ return context;
+ }
+
+}
diff --git a/container-search/src/main/java/ai/vespa/llm/search/package-info.java b/container-search/src/main/java/ai/vespa/llm/search/package-info.java
new file mode 100644
index 00000000000..6a8975fd2fa
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/search/package-info.java
@@ -0,0 +1,7 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.search;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def
new file mode 100755
index 00000000000..6bfd95c3cf2
--- /dev/null
+++ b/container-search/src/main/resources/configdefinitions/llm-client.def
@@ -0,0 +1,8 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm
+
+# The name of the secret containing the api key
+apiKeySecretName string default=""
+
+# Endpoint for LLM client - if not set reverts to default for client
+endpoint string default=""
diff --git a/container-search/src/main/resources/configdefinitions/llm-searcher.def b/container-search/src/main/resources/configdefinitions/llm-searcher.def
new file mode 100755
index 00000000000..918a6e6e8b1
--- /dev/null
+++ b/container-search/src/main/resources/configdefinitions/llm-searcher.def
@@ -0,0 +1,11 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm
+
+# Query property prefix for options
+propertyPrefix string default="llm"
+
+# Whether the searcher should stream tokens as they arrive rather than wait for the full completion
+stream bool default=true
+
+# The external LLM provider - the id of a LanguageModel component
+providerId string default=""