aboutsummaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorLester Solbakken <lester.solbakken@gmail.com>2024-03-26 15:52:29 +0100
committerLester Solbakken <lester.solbakken@gmail.com>2024-03-26 15:52:29 +0100
commitbf13cde9e2c1e5243e7cee1b9f2a9f5d915a96ac (patch)
tree53a593f052cf2cad0c01707d8b024749519061c5 /container-search
parent32ad1a732962c9ea9d7b0693dfa880ef74e9ddd0 (diff)
Add RAG searcher
Diffstat (limited to 'container-search')
-rw-r--r--container-search/abi-spec.json55
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java72
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java46
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/package-info.java7
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java166
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java76
-rw-r--r--container-search/src/main/java/ai/vespa/llm/search/package-info.java7
-rwxr-xr-xcontainer-search/src/main/resources/configdefinitions/llm-client.def8
-rwxr-xr-xcontainer-search/src/main/resources/configdefinitions/llm-searcher.def11
-rw-r--r--container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java175
-rw-r--r--container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java80
-rw-r--r--container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java35
-rwxr-xr-xcontainer-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java254
-rwxr-xr-xcontainer-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java127
14 files changed, 1119 insertions, 0 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index bdb6cd9e7a5..257dd364000 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -9148,5 +9148,60 @@
"public int getTo()"
],
"fields" : [ ]
+ },
+ "ai.vespa.llm.clients.ConfigurableLanguageModel" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "ai.vespa.llm.LanguageModel"
+ ],
+ "attributes" : [
+ "public",
+ "abstract"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.llm.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
+ "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)",
+ "protected void setApiKey(ai.vespa.llm.InferenceParameters)",
+ "protected java.lang.String getEndpoint()",
+ "protected void setEndpoint(ai.vespa.llm.InferenceParameters)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.OpenAI" : {
+ "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(ai.vespa.llm.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
+ "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.search.LLMSearcher" : {
+ "superClass" : "com.yahoo.search.Searcher",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "abstract"
+ ],
+ "methods" : [ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.search.RAGSearcher" : {
+ "superClass" : "ai.vespa.llm.search.LLMSearcher",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(ai.vespa.llm.LlmSearcherConfig, com.yahoo.component.provider.ComponentRegistry)",
+ "public com.yahoo.search.Result search(com.yahoo.search.Query, com.yahoo.search.searchchain.Execution)",
+ "protected ai.vespa.llm.completion.Prompt buildPrompt(com.yahoo.search.Query, com.yahoo.search.Result)"
+ ],
+ "fields" : [ ]
}
} \ No newline at end of file
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
new file mode 100644
index 00000000000..c4bba632127
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
@@ -0,0 +1,72 @@
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmClientConfig;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.logging.Logger;
+
+
+/**
+ * Base class for language models that can be configured with config definitions.
+ *
+ * @author lesters
+ */
+public abstract class ConfigurableLanguageModel implements LanguageModel {
+
+ private static Logger log = Logger.getLogger(ai.vespa.llm.clients.ConfigurableLanguageModel.class.getName());
+
+ private final String apiKey;
+ private final String endpoint;
+
+ public ConfigurableLanguageModel() {
+ this.apiKey = null;
+ this.endpoint = null;
+ }
+
+ @Inject
+ public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) {
+ this.apiKey = findApiKeyInSecretStore(config.apiKey(), secretStore); // is this implicitly assuming external store?
+ this.endpoint = config.endpoint();
+ }
+
+ private static String findApiKeyInSecretStore(String property, SecretStore secretStore) {
+ String apiKey = "";
+ if (property != null && ! property.isEmpty()) {
+ try {
+ apiKey = secretStore.getSecret(property);
+ } catch (UnsupportedOperationException e) {
+ // Secret store is not available - silently ignore this
+ } catch (Exception e) {
+                log.warning("Secret store lookup failed: " + e.getMessage() + "\n" +
+ "Will expect API key in request header");
+ }
+ }
+ return apiKey;
+ }
+
+ protected String getApiKey(InferenceParameters params) {
+ return params.getApiKey().orElse(null);
+ }
+
+ /**
+ * Set the API key as retrieved from secret store if it is not already set
+ */
+ protected void setApiKey(InferenceParameters params) {
+ if (params.getApiKey().isEmpty() && apiKey != null) {
+ params.setApiKey(apiKey);
+ }
+ }
+
+ protected String getEndpoint() {
+ return endpoint;
+ }
+
+ protected void setEndpoint(InferenceParameters params) {
+ params.setEndpoint(endpoint);
+ }
+
+
+}
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
new file mode 100644
index 00000000000..0414fdd2e1b
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
@@ -0,0 +1,46 @@
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.client.openai.OpenAiClient;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+
+/**
+ * A configurable OpenAI client.
+ *
+ * @author lesters
+ */
+public class OpenAI extends ConfigurableLanguageModel {
+
+ private final OpenAiClient client;
+
+ @Inject
+ public OpenAI(LlmClientConfig config, SecretStore secretStore) {
+ super(config, secretStore);
+ client = new OpenAiClient();
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
+ setApiKey(parameters);
+ setEndpoint(parameters);
+ return client.complete(prompt, parameters);
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters parameters,
+ Consumer<Completion> consumer) {
+ setApiKey(parameters);
+ setEndpoint(parameters);
+ return client.completeAsync(prompt, parameters, consumer);
+ }
+}
+
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
new file mode 100644
index 00000000000..c360245901c
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
@@ -0,0 +1,7 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.clients;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java
new file mode 100755
index 00000000000..6ff40401a8f
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java
@@ -0,0 +1,166 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LanguageModelException;
+import ai.vespa.llm.LlmSearcherConfig;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.ComponentId;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.EventStream;
+import com.yahoo.search.result.HitGroup;
+
+import java.util.List;
+import java.util.function.Function;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+/**
+ * Base class for LLM searchers. Provides utilities for calling LLMs and handling properties.
+ *
+ * @author lesters
+ */
+@Beta
+public abstract class LLMSearcher extends Searcher {
+
+ private static Logger log = Logger.getLogger(LLMSearcher.class.getName());
+
+ private static final String API_KEY_HEADER = "X-LLM-API-KEY";
+ private static final String STREAM_PROPERTY = "stream";
+ private static final String PROMPT_PROPERTY = "prompt";
+
+ private final String propertyPrefix;
+ private final boolean stream;
+ private final LanguageModel languageModel;
+ private final String languageModelId;
+
+ @Inject
+ LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
+ this.stream = config.stream();
+ this.languageModelId = config.providerId();
+ this.languageModel = findLanguageModel(languageModelId, languageModels);
+ this.propertyPrefix = config.propertyPrefix();
+ }
+
+ private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels)
+ throws IllegalArgumentException
+ {
+ if (languageModels.allComponents().isEmpty()) {
+ throw new IllegalArgumentException("No language models were found");
+ }
+ if (providerId == null || providerId.isEmpty()) {
+ var entry = languageModels.allComponentsById().entrySet().stream().findFirst();
+ if (entry.isEmpty()) {
+ throw new IllegalArgumentException("No language models were found"); // shouldn't happen given check above
+ }
+ log.info("Language model provider was not found in config. " +
+ "Fallback to using first available language model: " + entry.get().getKey());
+ return entry.get().getValue();
+ }
+ final LanguageModel languageModel = languageModels.getComponent(providerId);
+ if (languageModel == null) {
+ throw new IllegalArgumentException("No component with id '" + providerId + "' was found. " +
+ "Available LLM components are: " + languageModels.allComponentsById().keySet().stream()
+ .map(ComponentId::toString).collect(Collectors.joining(",")));
+ }
+ return languageModel;
+ }
+
+ Result complete(Query query, Prompt prompt) {
+ var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
+ var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
+ return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ }
+
+ private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
+ EventStream eventStream = new EventStream();
+
+ if (query.getTrace().getLevel() >= 1) {
+ eventStream.add(prompt.asString(), "prompt");
+ }
+
+ languageModel.completeAsync(prompt, options, token -> {
+ eventStream.add(token.text());
+ }).exceptionally(exception -> {
+ int errorCode = 400;
+ if (exception instanceof LanguageModelException languageModelException) {
+ errorCode = languageModelException.code();
+ }
+ eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ eventStream.markComplete();
+ return Completion.FinishReason.error;
+ }).thenAccept(finishReason -> {
+ eventStream.markComplete();
+ });
+
+ HitGroup hitGroup = new HitGroup("token_stream");
+ hitGroup.add(eventStream);
+ return new Result(query, hitGroup);
+ }
+
+ private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
+ EventStream eventStream = new EventStream();
+
+ if (query.getTrace().getLevel() >= 1) {
+ eventStream.add(prompt.asString(), "prompt");
+ }
+
+ List<Completion> completions = languageModel.complete(prompt, options);
+ eventStream.add(completions.get(0).text(), "completion");
+ eventStream.markComplete();
+
+ HitGroup hitGroup = new HitGroup("completion");
+ hitGroup.add(eventStream);
+ return new Result(query, hitGroup);
+ }
+
+ String getPrompt(Query query) {
+ // Look for prompt with or without prefix
+ String prompt = lookupPropertyWithOrWithoutPrefix(PROMPT_PROPERTY, p -> query.properties().getString(p));
+ if (prompt != null)
+ return prompt;
+
+ // If not found, use query directly
+ prompt = query.getModel().getQueryString();
+ if (prompt != null)
+ return prompt;
+
+ // If not, throw exception
+        throw new IllegalArgumentException("Could not find a prompt for the query. Tried looking for " +
+ "'" + propertyPrefix + "." + PROMPT_PROPERTY + "', '" + PROMPT_PROPERTY + "' or '@query'.");
+ }
+
+ String getPropertyPrefix() {
+ return this.propertyPrefix;
+ }
+
+ String lookupProperty(String property, Query query) {
+ String propertyWithPrefix = this.propertyPrefix + "." + property;
+ return query.properties().getString(propertyWithPrefix, null);
+ }
+
+ Boolean lookupPropertyBool(String property, Query query, boolean defaultValue) {
+ String propertyWithPrefix = this.propertyPrefix + "." + property;
+ return query.properties().getBoolean(propertyWithPrefix, defaultValue);
+ }
+
+ String lookupPropertyWithOrWithoutPrefix(String property, Function<String, String> lookup) {
+ String value = lookup.apply(getPropertyPrefix() + "." + property);
+ if (value != null)
+ return value;
+ return lookup.apply(property);
+ }
+
+ String getApiKeyHeader(Query query) {
+ return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
+ }
+
+}
diff --git a/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java
new file mode 100755
index 00000000000..e297359a6a6
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java
@@ -0,0 +1,76 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmSearcherConfig;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.searchchain.Execution;
+
+import java.util.logging.Logger;
+
+/**
+ * An LLM searcher that uses retrieval-augmented generation (RAG) to produce completions.
+ * Prompts are generated based on the search result context.
+ * By default, the context is a concatenation of the fields of the search result hits.
+ *
+ * @author lesters
+ */
+@Beta
+public class RAGSearcher extends LLMSearcher {
+
+ private static Logger log = Logger.getLogger(RAGSearcher.class.getName());
+
+ private static final String CONTEXT_PROPERTY = "context";
+
+ @Inject
+ public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
+ super(config, languageModels);
+ log.info("Starting " + RAGSearcher.class.getName() + " with language model " + config.providerId());
+ }
+
+ @Override
+ public Result search(Query query, Execution execution) {
+ Result result = execution.search(query);
+ execution.fill(result);
+ return complete(query, buildPrompt(query, result));
+ }
+
+ protected Prompt buildPrompt(Query query, Result result) {
+ String prompt = getPrompt(query);
+
+ // Replace @query with the actual query
+ if (prompt.contains("@query")) {
+ prompt = prompt.replace("@query", query.getModel().getQueryString());
+ }
+
+ String context = lookupProperty(CONTEXT_PROPERTY, query);
+ if (context == null || !context.equals("skip")) {
+ if ( !prompt.contains("{context}")) {
+ prompt = "{context}\n" + prompt;
+ }
+ prompt = prompt.replace("{context}", buildContext(result));
+ }
+ return StringPrompt.from(prompt);
+ }
+
+ private String buildContext(Result result) {
+ StringBuilder sb = new StringBuilder();
+ var hits = result.hits();
+ hits.forEach(hit -> {
+ hit.fields().forEach((key, value) -> {
+ sb.append(key).append(": ").append(value).append("\n");
+ });
+ sb.append("\n");
+ });
+ var context = sb.toString();
+ return context;
+ }
+
+}
diff --git a/container-search/src/main/java/ai/vespa/llm/search/package-info.java b/container-search/src/main/java/ai/vespa/llm/search/package-info.java
new file mode 100644
index 00000000000..6a8975fd2fa
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/search/package-info.java
@@ -0,0 +1,7 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.search;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def
new file mode 100755
index 00000000000..009c5253082
--- /dev/null
+++ b/container-search/src/main/resources/configdefinitions/llm-client.def
@@ -0,0 +1,8 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm
+
+# The name of the secret containing the api key
+apiKey string default=""
+
+# Endpoint for LLM client - if not set reverts to default for client
+endpoint string default=""
diff --git a/container-search/src/main/resources/configdefinitions/llm-searcher.def b/container-search/src/main/resources/configdefinitions/llm-searcher.def
new file mode 100755
index 00000000000..918a6e6e8b1
--- /dev/null
+++ b/container-search/src/main/resources/configdefinitions/llm-searcher.def
@@ -0,0 +1,11 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm
+
+# Query property prefix for options
+propertyPrefix string default="llm"
+
+# Whether the searcher should stream tokens as they are generated or wait for the complete response
+stream bool default=true
+
+# The external LLM provider - the id of a LanguageModel component
+providerId string default=""
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
new file mode 100644
index 00000000000..9c7cd6ad064
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
@@ -0,0 +1,175 @@
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.container.di.componentgraph.Provider;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ConfigurableLanguageModelTest {
+
+ @Test
+ public void testSyncGeneration() {
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey());
+ assertEquals(1, result.size());
+ assertEquals("Ducks have adorable waddling walks.", result.get(0).text());
+ }
+
+ @Test
+ public void testAsyncGeneration() {
+ var executor = Executors.newFixedThreadPool(1);
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var sb = new StringBuilder();
+ try {
+ var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> {
+ sb.append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+
+ var reason = future.join();
+ assertTrue(future.isDone());
+ assertNotEquals(reason, Completion.FinishReason.error);
+ } finally {
+ executor.shutdownNow();
+ }
+
+ assertEquals("Ducks have adorable waddling walks.", sb.toString());
+ }
+
+ @Test
+ public void testInferenceParameters() {
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4"));
+ var result = createLLM().complete(prompt, params);
+ assertEquals("Random text about ducks", result.get(0).text());
+ }
+
+ @Test
+ public void testNoApiKey() {
+ var prompt = StringPrompt.from("");
+ var config = modelParams("api-key", null);
+ var secrets = createSecretStore(Map.of());
+ assertThrows(IllegalArgumentException.class, () -> {
+ createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams());
+ });
+ }
+
+ @Test
+ public void testApiKeyFromSecretStore() {
+ var prompt = StringPrompt.from("");
+ var config = modelParams("api-key-in-secret-store", null);
+ var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY));
+ assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); });
+ }
+
+ private static String lookupParameter(String parameter, Map<String, String> params) {
+ return params.get(parameter);
+ }
+
+ private static InferenceParameters inferenceParams() {
+ return new InferenceParameters(s -> lookupParameter(s, Collections.emptyMap()));
+ }
+
+ private static InferenceParameters inferenceParams(Map<String, String> params) {
+ return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params));
+ }
+
+ private static InferenceParameters inferenceParamsWithDefaultKey() {
+ return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Collections.emptyMap()));
+ }
+
+ private LlmClientConfig modelParams(String apiKey, String endpoint) {
+ var config = new LlmClientConfig.Builder();
+ if (apiKey != null) {
+ config.apiKey(apiKey);
+ }
+ if (endpoint != null) {
+ config.endpoint(endpoint);
+ }
+ return config.build();
+ }
+
+ public static SecretStore createSecretStore(Map<String, String> secrets) {
+ Provider<SecretStore> secretStore = new Provider<>() {
+ public SecretStore get() {
+ return new SecretStore() {
+ public String getSecret(String key) {
+ return secrets.get(key);
+ }
+ public String getSecret(String key, int version) {
+ return secrets.get(key);
+ }
+ };
+ }
+ public void deconstruct() {
+ }
+ };
+ return secretStore.get();
+ }
+
+ public static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
+ return (prompt, options) -> {
+ String answer = "I have no opinion on the matter";
+ if (prompt.asString().contains("ducks")) {
+ answer = "Ducks have adorable waddling walks.";
+ var temperature = options.getDouble("temperature");
+ if (temperature.isPresent() && temperature.get() > 0.5) {
+ answer = "Random text about ducks vs cats that makes no sense whatsoever.";
+ }
+ }
+ var maxTokens = options.getInt("maxTokens");
+ if (maxTokens.isPresent()) {
+ return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
+ }
+ return answer;
+ };
+ }
+
+ private static MockLLMClient createLLM() {
+ LlmClientConfig config = new LlmClientConfig.Builder().build();
+ return createLLM(config, null);
+ }
+
+ private static MockLLMClient createLLM(ExecutorService executor) {
+ LlmClientConfig config = new LlmClientConfig.Builder().build();
+ return createLLM(config, executor);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) {
+ var generator = createGenerator();
+ var secretStore = new SecretStoreProvider(); // throws exception on use
+ return createLLM(config, generator, secretStore.get(), executor);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ SecretStore secretStore) {
+ return createLLM(config, generator, secretStore, null);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ SecretStore secretStore,
+ ExecutorService executor) {
+ return new MockLLMClient(config, secretStore, generator, executor);
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
new file mode 100644
index 00000000000..f6132f58cbb
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
@@ -0,0 +1,80 @@
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
+
+public class MockLLMClient extends ConfigurableLanguageModel {
+
+ public final static String ACCEPTED_API_KEY = "sesame";
+
+ private final ExecutorService executor;
+ private final BiFunction<Prompt, InferenceParameters, String> generator;
+
+ private Prompt lastPrompt;
+
+ public MockLLMClient(LlmClientConfig config,
+ SecretStore secretStore,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ ExecutorService executor) {
+ super(config, secretStore);
+ this.generator = generator;
+ this.executor = executor;
+ }
+
+ private void checkApiKey(InferenceParameters options) {
+ var apiKey = getApiKey(options);
+ if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) {
+ throw new IllegalArgumentException("Invalid API key");
+ }
+ }
+
+ private void setPrompt(Prompt prompt) {
+ this.lastPrompt = prompt;
+ }
+
+ public Prompt getPrompt() {
+ return this.lastPrompt;
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters params) {
+ setApiKey(params);
+ checkApiKey(params);
+ setPrompt(prompt);
+ return List.of(Completion.from(this.generator.apply(prompt, params)));
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters params,
+ Consumer<Completion> consumer) {
+ setPrompt(prompt);
+ var completionFuture = new CompletableFuture<Completion.FinishReason>();
+ var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization
+
+ long sleep = 1;
+ executor.submit(() -> {
+ try {
+ for (int i=0; i < completions.length; ++i) {
+ String completion = (i > 0 ? " " : "") + completions[i];
+ consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep);
+ }
+ completionFuture.complete(Completion.FinishReason.stop);
+ } catch (InterruptedException e) {
+ // Do nothing
+ }
+ });
+
+ return completionFuture;
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
new file mode 100644
index 00000000000..9207047425b
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
@@ -0,0 +1,35 @@
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+public class OpenAITest {
+
+ private static final String apiKey = "<your-api-key>";
+
+ @Test
+ @Disabled
+ public void testOpenAIGeneration() {
+ var config = new LlmClientConfig.Builder().build();
+ var openai = new OpenAI(config, new SecretStoreProvider().get());
+ var options = Map.of(
+ "maxTokens", "10"
+ );
+
+ var prompt = StringPrompt.from("why are ducks better than cats?");
+ var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> {
+ System.out.print(completion.text());
+ }).exceptionally(exception -> {
+ System.out.println("Error: " + exception);
+ return null;
+ });
+ future.join();
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java
new file mode 100755
index 00000000000..ec5617891e6
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java
@@ -0,0 +1,254 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.LlmSearcherConfig;
+import ai.vespa.llm.clients.ConfigurableLanguageModelTest;
+import ai.vespa.llm.clients.MockLLMClient;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.component.ComponentId;
+import com.yahoo.component.chain.Chain;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.result.EventStream;
+import com.yahoo.search.searchchain.Execution;
+import org.junit.jupiter.api.Test;
+
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+
+public class LLMSearcherTest {
+
+ // Verifies that providerId in the searcher config selects the matching
+ // language model component from the registry.
+ @Test
+ public void testLLMSelection() {
+ var llm1 = createLLMClient("mock1");
+ var llm2 = createLLMClient("mock2");
+ var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build();
+ var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2));
+ var result = runMockSearch(searcher, Map.of("prompt", "what is your id?"));
+ assertEquals(1, result.getHitCount());
+ assertEquals("My id is mock2", getCompletion(result));
+ }
+
+ // Basic completion flow: the mock client returns canned text for the prompt.
+ @Test
+ public void testGeneration() {
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var params = Map.of("prompt", "why are ducks better than cats");
+ assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params)));
+ }
+
+ // The prompt can be supplied with the "llm." property prefix, without it,
+ // or fall back to the plain "query" parameter.
+ @Test
+ public void testPrompting() {
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+
+ // Prompt with prefix
+ assertEquals("Ducks have adorable waddling walks.",
+ getCompletion(runMockSearch(searcher, Map.of("llm.prompt", "why are ducks better than cats"))));
+
+ // Prompt without prefix
+ assertEquals("Ducks have adorable waddling walks.",
+ getCompletion(runMockSearch(searcher, Map.of("prompt", "why are ducks better than cats"))));
+
+ // Fallback to query if not given
+ assertEquals("Ducks have adorable waddling walks.",
+ getCompletion(runMockSearch(searcher, Map.of("query", "why are ducks better than cats"))));
+ }
+
+ // With traceLevel=1 the event stream carries both a "prompt" event and a
+ // "completion" event, in that order.
+ @Test
+ public void testPromptEvent() {
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var params = Map.of(
+ "prompt", "why are ducks better than cats",
+ "traceLevel", "1");
+ var result = runMockSearch(searcher, params);
+ var events = ((EventStream) result.hits().get(0)).incoming().drain();
+ assertEquals(2, events.size());
+
+ var promptEvent = (EventStream.Event) events.get(0);
+ assertEquals("prompt", promptEvent.type());
+ assertEquals("why are ducks better than cats", promptEvent.toString());
+
+ var completionEvent = (EventStream.Event) events.get(1);
+ assertEquals("completion", completionEvent.type());
+ assertEquals("Ducks have adorable waddling walks.", completionEvent.toString());
+ }
+
+ // Inference parameters such as temperature and maxTokens are forwarded to
+ // the model; maxTokens=5 truncates the mock completion to five tokens.
+ @Test
+ public void testParameters() {
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var params = Map.of(
+ "llm.prompt", "why are ducks better than cats",
+ "llm.temperature", "1.0",
+ "llm.maxTokens", "5"
+ );
+ assertEquals("Random text about ducks vs", getCompletion(runMockSearch(searcher, params)));
+ }
+
+ // A custom propertyPrefix ("foo") replaces the default "llm" prefix for
+ // query properties.
+ @Test
+ public void testParameterPrefix() {
+ var prefix = "foo";
+ var params = Map.of(
+ "foo.prompt", "what is your opinion on cats",
+ "foo.maxTokens", "5"
+ );
+ var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build();
+ var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient()));
+ assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params)));
+ }
+
+ // Without a secret store the API key must come from the HTTP header; an
+ // invalid key is rejected, the accepted key passes.
+ @Test
+ public void testApiKeyFromHeader() {
+ var properties = Map.of("prompt", "why are ducks better than cats");
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore()));
+ assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm"));
+ assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm"));
+ }
+
+ // Streaming mode: llm.stream=true in the query overrides stream=false in
+ // the config; tokens arrive asynchronously on the event stream.
+ @Test
+ public void testAsyncGeneration() {
+ var executor = Executors.newFixedThreadPool(1);
+ var sb = new StringBuilder();
+ try {
+ var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock").build(); // config says don't stream...
+ var params = Map.of(
+ "llm.stream", "true", // ... but inference parameters says do it anyway
+ "llm.prompt", "why are ducks better than cats?"
+ );
+ var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor)));
+ Result result = runMockSearch(searcher, params);
+
+ assertEquals(1, result.getHitCount());
+ assertTrue(result.hits().get(0) instanceof EventStream);
+ EventStream eventStream = (EventStream) result.hits().get(0);
+
+ var incoming = eventStream.incoming();
+ incoming.addNewDataListener(() -> {
+ incoming.drain().forEach(event -> sb.append(event.toString()));
+ }, executor);
+
+ incoming.completedFuture().join();
+ assertTrue(incoming.isComplete());
+
+ // Ensure incoming has been fully drained to avoid race condition in this test
+ incoming.drain().forEach(event -> sb.append(event.toString()));
+
+ } finally {
+ executor.shutdownNow();
+ }
+ assertEquals("Ducks have adorable waddling walks.", sb.toString());
+ }
+
+ // Drains the first hit's event stream and returns the first event as text.
+ private static String getCompletion(Result result) {
+ assertTrue(result.hits().size() >= 1);
+ return ((EventStream) result.hits().get(0)).incoming().drain().get(0).toString();
+ }
+
+ // Convenience overload: no API-key header, empty property prefix.
+ static Result runMockSearch(Searcher searcher, Map<String, String> parameters) {
+ return runMockSearch(searcher, parameters, null, "");
+ }
+
+ // Runs the searcher in a one-element chain; when apiKey is non-null it is
+ // attached as an (optionally prefixed) X-LLM-API-KEY request header.
+ static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) {
+ Chain<Searcher> chain = new Chain<>(searcher);
+ Execution execution = new Execution(chain, Execution.Context.createContextStub());
+ Query query = new Query("?" + toUrlParams(parameters));
+ if (apiKey != null) {
+ String headerKey = "X-LLM-API-KEY";
+ if (prefix != null && ! prefix.isEmpty()) {
+ headerKey = prefix + "." + headerKey;
+ }
+ query.getHttpRequest().getJDiscRequest().headers().add(headerKey, apiKey);
+ }
+ return execution.search(query);
+ }
+
+ // URL-encodes the values and joins the entries as a query string.
+ public static String toUrlParams(Map<String, String> parameters) {
+ return parameters.entrySet().stream().map(
+ e -> e.getKey() + "=" + URLEncoder.encode(e.getValue(), StandardCharsets.UTF_8)
+ ).collect(Collectors.joining("&"));
+ }
+
+ // Generator that echoes the client's id, used to test provider selection.
+ private static BiFunction<Prompt, InferenceParameters, String> createIdGenerator(String id) {
+ return (prompt, options) -> {
+ if (id == null || id.isEmpty())
+ return "I have no ID";
+ return "My id is " + id;
+ };
+ }
+
+ // Shared canned-text generator from ConfigurableLanguageModelTest.
+ private static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
+ return ConfigurableLanguageModelTest.createGenerator();
+ }
+
+ // Mock client with a secret store holding the accepted API key; synchronous
+ // (no executor).
+ static MockLLMClient createLLMClient() {
+ var config = new LlmClientConfig.Builder().apiKey("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore, generator, null);
+ }
+
+ // Mock client whose generator reports the given id.
+ static MockLLMClient createLLMClient(String id) {
+ var config = new LlmClientConfig.Builder().apiKey("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createIdGenerator(id);
+ return new MockLLMClient(config, secretStore, generator, null);
+ }
+
+ // Mock client that streams completions on the given executor.
+ static MockLLMClient createLLMClient(ExecutorService executor) {
+ var config = new LlmClientConfig.Builder().apiKey("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore, generator, executor);
+ }
+
+ // Mock client backed by the throwing SecretStoreProvider, forcing the API
+ // key to be supplied through the request header instead.
+ static MockLLMClient createLLMClientWithoutSecretStore() {
+ var config = new LlmClientConfig.Builder().apiKey("api-key").build();
+ var secretStore = new SecretStoreProvider();
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore.get(), generator, null);
+ }
+
+ // Builds a non-streaming searcher over a frozen registry of the given models.
+ private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) {
+ var config = new LlmSearcherConfig.Builder().stream(false).build();
+ ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
+ llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
+ models.freeze();
+ return new LLMSearcherImpl(config, models);
+ }
+
+ // Same as above but with an explicit searcher config.
+ private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) {
+ ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
+ llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
+ models.freeze();
+ return new LLMSearcherImpl(config, models);
+ }
+
+ // Minimal concrete subclass used to exercise the abstract LLMSearcher:
+ // completes the query's prompt directly without any retrieval step.
+ public static class LLMSearcherImpl extends LLMSearcher {
+
+ public LLMSearcherImpl(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
+ super(config, languageModels);
+ }
+
+ @Override
+ public Result search(Query query, Execution execution) {
+ return complete(query, StringPrompt.from(getPrompt(query)));
+ }
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java
new file mode 100755
index 00000000000..ccf9a4a6401
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java
@@ -0,0 +1,127 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmSearcherConfig;
+import com.yahoo.component.ComponentId;
+import com.yahoo.component.chain.Chain;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.result.EventStream;
+import com.yahoo.search.result.Hit;
+import com.yahoo.search.searchchain.Execution;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+
+public class RAGSearcherTest {
+
+ // Fixture documents returned by MockSearchResults and expected to appear
+ // verbatim in the generated prompt context.
+ private static final String DOC1_TITLE = "Exploring the Delightful Qualities of Ducks";
+ private static final String DOC1_CONTENT = "Ducks, with their gentle quacks and adorable waddling walks, possess a unique " +
+ "charm that sets them apart as extraordinary pets.";
+ private static final String DOC2_TITLE = "Why Cats Reign Supreme";
+ private static final String DOC2_CONTENT = "Cats bring an enchanting allure to households with their independent " +
+ "companionship, playful nature, natural hunting abilities, low-maintenance grooming, and the " +
+ "emotional support they offer.";
+
+ // Default prompt construction: the hit fields are rendered as context ahead
+ // of the user prompt, followed by the model completion.
+ @Test
+ public void testRAGGeneration() {
+ var eventStream = runRAGQuery(Map.of(
+ "prompt", "why are ducks better than cats?",
+ "traceLevel", "1"));
+ var events = eventStream.incoming().drain();
+ assertEquals(2, events.size());
+
+ // Generated prompt
+ var promptEvent = (EventStream.Event) events.get(0);
+ assertEquals("prompt", promptEvent.type());
+ assertEquals("title: " + DOC1_TITLE + "\n" +
+ "content: " + DOC1_CONTENT + "\n\n" +
+ "title: " + DOC2_TITLE + "\n" +
+ "content: " + DOC2_CONTENT + "\n\n\n" +
+ "why are ducks better than cats?", promptEvent.toString());
+
+ // Generated completion
+ var completionEvent = (EventStream.Event) events.get(1);
+ assertEquals("completion", completionEvent.type());
+ assertEquals("Ducks have adorable waddling walks.", completionEvent.toString());
+ }
+
+ // A prompt template may reference the retrieved context via {context} and
+ // the user query via @query; both are substituted in the final prompt.
+ @Test
+ public void testPromptGeneration() {
+ var eventStream = runRAGQuery(Map.of(
+ "query", "why are ducks better than cats?",
+ "prompt", "{context}\nGiven these documents, answer this query as concisely as possible: @query",
+ "traceLevel", "1"));
+ var events = eventStream.incoming().drain();
+
+ var promptEvent = (EventStream.Event) events.get(0);
+ assertEquals("prompt", promptEvent.type());
+ assertEquals("title: " + DOC1_TITLE + "\n" +
+ "content: " + DOC1_CONTENT + "\n\n" +
+ "title: " + DOC2_TITLE + "\n" +
+ "content: " + DOC2_CONTENT + "\n\n\n" +
+ "Given these documents, answer this query as concisely as possible: " +
+ "why are ducks better than cats?", promptEvent.toString());
+ }
+
+ // llm.context=skip removes the retrieved documents from the prompt entirely.
+ @Test
+ public void testSkipContextInPrompt() {
+ var eventStream = runRAGQuery(Map.of(
+ "query", "why are ducks better than cats?",
+ "llm.context", "skip",
+ "traceLevel", "1"));
+ var events = eventStream.incoming().drain();
+
+ var promptEvent = (EventStream.Event) events.get(0);
+ assertEquals("prompt", promptEvent.type());
+ assertEquals("why are ducks better than cats?", promptEvent.toString());
+ }
+
+ // Downstream searcher that always returns the two fixture documents,
+ // standing in for an actual retrieval backend.
+ public static class MockSearchResults extends Searcher {
+
+ @Override
+ public Result search(Query query, Execution execution) {
+ Hit hit1 = new Hit("1");
+ hit1.setField("title", DOC1_TITLE);
+ hit1.setField("content", DOC1_CONTENT);
+
+ Hit hit2 = new Hit("2");
+ hit2.setField("title", DOC2_TITLE);
+ hit2.setField("content", DOC2_CONTENT);
+
+ Result r = new Result(query);
+ r.hits().add(hit1);
+ r.hits().add(hit2);
+ return r;
+ }
+ }
+
+ // Runs a query through the RAG searcher and returns the event stream hit.
+ private EventStream runRAGQuery(Map<String, String> params) {
+ var llm = LLMSearcherTest.createLLMClient();
+ var searcher = createRAGSearcher(Map.of("mock", llm));
+ var result = runMockSearch(searcher, params);
+ return (EventStream) result.hits().get(0);
+ }
+
+ // Executes a two-searcher chain: RAG searcher followed by the mock backend.
+ static Result runMockSearch(Searcher searcher, Map<String, String> parameters) {
+ Chain<Searcher> chain = new Chain<>(searcher, new MockSearchResults());
+ Execution execution = new Execution(chain, Execution.Context.createContextStub());
+ Query query = new Query("?" + LLMSearcherTest.toUrlParams(parameters));
+ return execution.search(query);
+ }
+
+ // Builds a non-streaming RAGSearcher over a frozen registry of the models.
+ private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
+ var config = new LlmSearcherConfig.Builder().stream(false).build();
+ ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
+ llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
+ models.freeze();
+ return new RAGSearcher(config, models);
+ }
+
+}