diff options
Diffstat (limited to 'container-search/src/test/java/ai/vespa')
5 files changed, 674 insertions, 0 deletions
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java new file mode 100644 index 00000000000..1f2a12322a1 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java @@ -0,0 +1,176 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.container.di.componentgraph.Provider; +import com.yahoo.container.jdisc.SecretStoreProvider; +import com.yahoo.container.jdisc.secretstore.SecretStore; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConfigurableLanguageModelTest { + + @Test + public void testSyncGeneration() { + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey()); + assertEquals(1, result.size()); + assertEquals("Ducks have adorable waddling walks.", result.get(0).text()); + } + + @Test + public void testAsyncGeneration() { + var executor = Executors.newFixedThreadPool(1); + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var sb = new StringBuilder(); + try { + var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> { + sb.append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); + + var reason = future.join(); + assertTrue(future.isDone()); + assertNotEquals(reason, Completion.FinishReason.error); + } finally { + executor.shutdownNow(); + } + + assertEquals("Ducks have adorable waddling walks.", sb.toString()); + } + + @Test + public void testInferenceParameters() { + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4")); + var result = createLLM().complete(prompt, params); + assertEquals("Random text about ducks", result.get(0).text()); + } + + @Test + public void testNoApiKey() { + var prompt = StringPrompt.from(""); + var config = modelParams("api-key", null); + var secrets = createSecretStore(Map.of()); + assertThrows(IllegalArgumentException.class, () -> { + createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); + }); + } + + @Test + public void testApiKeyFromSecretStore() { + var prompt = StringPrompt.from(""); + var config = modelParams("api-key-in-secret-store", null); + var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY)); + assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); }); + } + + private static String lookupParameter(String parameter, Map<String, String> params) { + return params.get(parameter); + } + + private static InferenceParameters inferenceParams() { + return new InferenceParameters(s -> lookupParameter(s, Collections.emptyMap())); + } + + private static InferenceParameters inferenceParams(Map<String, String> params) { + return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params)); + } + + private static InferenceParameters inferenceParamsWithDefaultKey() { + return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Collections.emptyMap())); + } + + private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) { + var config = new LlmClientConfig.Builder(); + if (apiKeySecretName != null) { + config.apiKeySecretName(apiKeySecretName); + } + if (endpoint != null) { + config.endpoint(endpoint); + } + return config.build(); + } + + public static SecretStore createSecretStore(Map<String, String> secrets) { + Provider<SecretStore> secretStore = new Provider<>() { + public SecretStore get() { + return new SecretStore() { + public String getSecret(String key) { + return secrets.get(key); + } + public String getSecret(String key, int version) { + return secrets.get(key); + } + }; + } + public void deconstruct() { + } + }; + return secretStore.get(); + } + + public static BiFunction<Prompt, InferenceParameters, String> createGenerator() { + return (prompt, options) -> { + String answer = "I have no opinion on the matter"; + if (prompt.asString().contains("ducks")) { + answer = "Ducks have adorable waddling walks."; + var temperature = options.getDouble("temperature"); + if (temperature.isPresent() && temperature.get() > 0.5) { + answer = "Random text about ducks vs cats that makes no sense whatsoever."; + } + } + var maxTokens = options.getInt("maxTokens"); + if (maxTokens.isPresent()) { + return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); + } + return answer; + }; + } + + private static MockLLMClient createLLM() { + LlmClientConfig config = new LlmClientConfig.Builder().build(); + return createLLM(config, null); + } + + private static MockLLMClient createLLM(ExecutorService executor) { + LlmClientConfig config = new LlmClientConfig.Builder().build(); + return createLLM(config, executor); + } + + private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) { + var generator = createGenerator(); + var secretStore = new SecretStoreProvider(); // throws exception on use + return createLLM(config, generator, secretStore.get(), executor); + } + + private static MockLLMClient createLLM(LlmClientConfig config, + BiFunction<Prompt, InferenceParameters, String> generator, + SecretStore secretStore) { + return createLLM(config, generator, secretStore, null); + } + + private static MockLLMClient createLLM(LlmClientConfig config, + BiFunction<Prompt, InferenceParameters, String> generator, + SecretStore secretStore, + ExecutorService executor) { + return new MockLLMClient(config, secretStore, generator, executor); + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java new file mode 100644 index 00000000000..cfb6a43984f --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java @@ -0,0 +1,81 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public class MockLLMClient extends ConfigurableLanguageModel { + + public final static String ACCEPTED_API_KEY = "sesame"; + + private final ExecutorService executor; + private final BiFunction<Prompt, InferenceParameters, String> generator; + + private Prompt lastPrompt; + + public MockLLMClient(LlmClientConfig config, + SecretStore secretStore, + BiFunction<Prompt, InferenceParameters, String> generator, + ExecutorService executor) { + super(config, secretStore); + this.generator = generator; + this.executor = executor; + } + + private void checkApiKey(InferenceParameters options) { + var apiKey = getApiKey(options); + if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) { + throw new IllegalArgumentException("Invalid API key"); + } + } + + private void setPrompt(Prompt prompt) { + this.lastPrompt = prompt; + } + + public Prompt getPrompt() { + return this.lastPrompt; + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters params) { + setApiKey(params); + checkApiKey(params); + setPrompt(prompt); + return List.of(Completion.from(this.generator.apply(prompt, params))); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters params, + Consumer<Completion> consumer) { + setPrompt(prompt); + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization + + long sleep = 1; + executor.submit(() -> { + try { + for (int i=0; i < completions.length; ++i) { + String completion = (i > 0 ? " " : "") + completions[i]; + consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep); + } + completionFuture.complete(Completion.FinishReason.stop); + } catch (InterruptedException e) { + // Do nothing + } + }); + + return completionFuture; + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java new file mode 100644 index 00000000000..1111a9824f5 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java @@ -0,0 +1,36 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.container.jdisc.SecretStoreProvider; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +public class OpenAITest { + + private static final String apiKey = "<your-api-key>"; + + @Test + @Disabled + public void testOpenAIGeneration() { + var config = new LlmClientConfig.Builder().build(); + var openai = new OpenAI(config, new SecretStoreProvider().get()); + var options = Map.of( + "maxTokens", "10" + ); + + var prompt = StringPrompt.from("why are ducks better than cats?"); + var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> { + System.out.print(completion.text()); + }).exceptionally(exception -> { + System.out.println("Error: " + exception); + return null; + }); + future.join(); + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java new file mode 100755 index 00000000000..d4f1dbc00a4 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java @@ -0,0 +1,254 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.search; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.LlmSearcherConfig; +import ai.vespa.llm.clients.ConfigurableLanguageModelTest; +import ai.vespa.llm.clients.MockLLMClient; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.component.ComponentId; +import com.yahoo.component.chain.Chain; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.container.jdisc.SecretStoreProvider; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.result.EventStream; +import com.yahoo.search.searchchain.Execution; +import org.junit.jupiter.api.Test; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +public class LLMSearcherTest { + + @Test + public void testLLMSelection() { + var llm1 = createLLMClient("mock1"); + var llm2 = createLLMClient("mock2"); + var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build(); + var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2)); + var result = runMockSearch(searcher, Map.of("prompt", "what is your id?")); + assertEquals(1, result.getHitCount()); + assertEquals("My id is mock2", getCompletion(result)); + } + + @Test + public void testGeneration() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var params = Map.of("prompt", "why are ducks better than cats"); + assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params))); + } + + @Test + public void testPrompting() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + + // Prompt with prefix + assertEquals("Ducks have adorable waddling walks.", + getCompletion(runMockSearch(searcher, Map.of("llm.prompt", "why are ducks better than cats")))); + + // Prompt without prefix + assertEquals("Ducks have adorable waddling walks.", + getCompletion(runMockSearch(searcher, Map.of("prompt", "why are ducks better than cats")))); + + // Fallback to query if not given + assertEquals("Ducks have adorable waddling walks.", + getCompletion(runMockSearch(searcher, Map.of("query", "why are ducks better than cats")))); + } + + @Test + public void testPromptEvent() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var params = Map.of( + "prompt", "why are ducks better than cats", + "traceLevel", "1"); + var result = runMockSearch(searcher, params); + var events = ((EventStream) result.hits().get(0)).incoming().drain(); + assertEquals(2, events.size()); + + var promptEvent = (EventStream.Event) events.get(0); + assertEquals("prompt", promptEvent.type()); + assertEquals("why are ducks better than cats", promptEvent.toString()); + + var completionEvent = (EventStream.Event) events.get(1); + assertEquals("completion", completionEvent.type()); + assertEquals("Ducks have adorable waddling walks.", completionEvent.toString()); + } + + @Test + public void testParameters() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var params = Map.of( + "llm.prompt", "why are ducks better than cats", + "llm.temperature", "1.0", + "llm.maxTokens", "5" + ); + assertEquals("Random text about ducks vs", getCompletion(runMockSearch(searcher, params))); + } + + @Test + public void testParameterPrefix() { + var prefix = "foo"; + var params = Map.of( + "foo.prompt", "what is your opinion on cats", + "foo.maxTokens", "5" + ); + var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build(); + var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient())); + assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params))); + } + + @Test + public void testApiKeyFromHeader() { + var properties = Map.of("prompt", "why are ducks better than cats"); + var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore())); + assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm")); + assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm")); + } + + @Test + public void testAsyncGeneration() { + var executor = Executors.newFixedThreadPool(1); + var sb = new StringBuilder(); + try { + var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock").build(); // config says don't stream... + var params = Map.of( + "llm.stream", "true", // ... but inference parameters says do it anyway + "llm.prompt", "why are ducks better than cats?" + ); + var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor))); + Result result = runMockSearch(searcher, params); + + assertEquals(1, result.getHitCount()); + assertTrue(result.hits().get(0) instanceof EventStream); + EventStream eventStream = (EventStream) result.hits().get(0); + + var incoming = eventStream.incoming(); + incoming.addNewDataListener(() -> { + incoming.drain().forEach(event -> sb.append(event.toString())); + }, executor); + + incoming.completedFuture().join(); + assertTrue(incoming.isComplete()); + + // Ensure incoming has been fully drained to avoid race condition in this test + incoming.drain().forEach(event -> sb.append(event.toString())); + + } finally { + executor.shutdownNow(); + } + assertEquals("Ducks have adorable waddling walks.", sb.toString()); + } + + private static String getCompletion(Result result) { + assertTrue(result.hits().size() >= 1); + return ((EventStream) result.hits().get(0)).incoming().drain().get(0).toString(); + } + + static Result runMockSearch(Searcher searcher, Map<String, String> parameters) { + return runMockSearch(searcher, parameters, null, ""); + } + + static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) { + Chain<Searcher> chain = new Chain<>(searcher); + Execution execution = new Execution(chain, Execution.Context.createContextStub()); + Query query = new Query("?" + toUrlParams(parameters)); + if (apiKey != null) { + String headerKey = "X-LLM-API-KEY"; + if (prefix != null && ! prefix.isEmpty()) { + headerKey = prefix + "." + headerKey; + } + query.getHttpRequest().getJDiscRequest().headers().add(headerKey, apiKey); + } + return execution.search(query); + } + + public static String toUrlParams(Map<String, String> parameters) { + return parameters.entrySet().stream().map( + e -> e.getKey() + "=" + URLEncoder.encode(e.getValue(), StandardCharsets.UTF_8) + ).collect(Collectors.joining("&")); + } + + private static BiFunction<Prompt, InferenceParameters, String> createIdGenerator(String id) { + return (prompt, options) -> { + if (id == null || id.isEmpty()) + return "I have no ID"; + return "My id is " + id; + }; + } + + private static BiFunction<Prompt, InferenceParameters, String> createGenerator() { + return ConfigurableLanguageModelTest.createGenerator(); + } + + static MockLLMClient createLLMClient() { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore, generator, null); + } + + static MockLLMClient createLLMClient(String id) { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createIdGenerator(id); + return new MockLLMClient(config, secretStore, generator, null); + } + + static MockLLMClient createLLMClient(ExecutorService executor) { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore, generator, executor); + } + + static MockLLMClient createLLMClientWithoutSecretStore() { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = new SecretStoreProvider(); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore.get(), generator, null); + } + + private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) { + var config = new LlmSearcherConfig.Builder().stream(false).build(); + ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); + llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new LLMSearcherImpl(config, models); + } + + private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) { + ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); + llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new LLMSearcherImpl(config, models); + } + + public static class LLMSearcherImpl extends LLMSearcher { + + public LLMSearcherImpl(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { + super(config, languageModels); + } + + @Override + public Result search(Query query, Execution execution) { + return complete(query, StringPrompt.from(getPrompt(query))); + } + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java new file mode 100755 index 00000000000..ccf9a4a6401 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java @@ -0,0 +1,127 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.search; + +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LlmSearcherConfig; +import com.yahoo.component.ComponentId; +import com.yahoo.component.chain.Chain; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.result.EventStream; +import com.yahoo.search.result.Hit; +import com.yahoo.search.searchchain.Execution; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +public class RAGSearcherTest { + + private static final String DOC1_TITLE = "Exploring the Delightful Qualities of Ducks"; + private static final String DOC1_CONTENT = "Ducks, with their gentle quacks and adorable waddling walks, possess a unique " + + "charm that sets them apart as extraordinary pets."; + private static final String DOC2_TITLE = "Why Cats Reign Supreme"; + private static final String DOC2_CONTENT = "Cats bring an enchanting allure to households with their independent " + + "companionship, playful nature, natural hunting abilities, low-maintenance grooming, and the " + + "emotional support they offer."; + + @Test + public void testRAGGeneration() { + var eventStream = runRAGQuery(Map.of( + "prompt", "why are ducks better than cats?", + "traceLevel", "1")); + var events = eventStream.incoming().drain(); + assertEquals(2, events.size()); + + // Generated prompt + var promptEvent = (EventStream.Event) events.get(0); + assertEquals("prompt", promptEvent.type()); + assertEquals("title: " + DOC1_TITLE + "\n" + + "content: " + DOC1_CONTENT + "\n\n" + + "title: " + DOC2_TITLE + "\n" + + "content: " + DOC2_CONTENT + "\n\n\n" + + "why are ducks better than cats?", promptEvent.toString()); + + // Generated completion + var completionEvent = (EventStream.Event) events.get(1); + assertEquals("completion", completionEvent.type()); + assertEquals("Ducks have adorable waddling walks.", completionEvent.toString()); + } + + @Test + public void testPromptGeneration() { + var eventStream = runRAGQuery(Map.of( + "query", "why are ducks better than cats?", + "prompt", "{context}\nGiven these documents, answer this query as concisely as possible: @query", + "traceLevel", "1")); + var events = eventStream.incoming().drain(); + + var promptEvent = (EventStream.Event) events.get(0); + assertEquals("prompt", promptEvent.type()); + assertEquals("title: " + DOC1_TITLE + "\n" + + "content: " + DOC1_CONTENT + "\n\n" + + "title: " + DOC2_TITLE + "\n" + + "content: " + DOC2_CONTENT + "\n\n\n" + + "Given these documents, answer this query as concisely as possible: " + + "why are ducks better than cats?", promptEvent.toString()); + } + + @Test + public void testSkipContextInPrompt() { + var eventStream = runRAGQuery(Map.of( + "query", "why are ducks better than cats?", + "llm.context", "skip", + "traceLevel", "1")); + var events = eventStream.incoming().drain(); + + var promptEvent = (EventStream.Event) events.get(0); + assertEquals("prompt", promptEvent.type()); + assertEquals("why are ducks better than cats?", promptEvent.toString()); + } + + public static class MockSearchResults extends Searcher { + + @Override + public Result search(Query query, Execution execution) { + Hit hit1 = new Hit("1"); + hit1.setField("title", DOC1_TITLE); + hit1.setField("content", DOC1_CONTENT); + + Hit hit2 = new Hit("2"); + hit2.setField("title", DOC2_TITLE); + hit2.setField("content", DOC2_CONTENT); + + Result r = new Result(query); + r.hits().add(hit1); + r.hits().add(hit2); + return r; + } + } + + private EventStream runRAGQuery(Map<String, String> params) { + var llm = LLMSearcherTest.createLLMClient(); + var searcher = createRAGSearcher(Map.of("mock", llm)); + var result = runMockSearch(searcher, params); + return (EventStream) result.hits().get(0); + } + + static Result runMockSearch(Searcher searcher, Map<String, String> parameters) { + Chain<Searcher> chain = new Chain<>(searcher, new MockSearchResults()); + Execution execution = new Execution(chain, Execution.Context.createContextStub()); + Query query = new Query("?" + LLMSearcherTest.toUrlParams(parameters)); + return execution.search(query); + } + + private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) { + var config = new LlmSearcherConfig.Builder().stream(false).build(); + ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); + llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new RAGSearcher(config, models); + } + +} |