diff options
Diffstat (limited to 'container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java')
-rwxr-xr-x | container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java | 254 |
1 files changed, 254 insertions, 0 deletions
diff --git a/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java new file mode 100755 index 00000000000..d4f1dbc00a4 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java @@ -0,0 +1,254 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.search; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.LlmSearcherConfig; +import ai.vespa.llm.clients.ConfigurableLanguageModelTest; +import ai.vespa.llm.clients.MockLLMClient; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.component.ComponentId; +import com.yahoo.component.chain.Chain; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.container.jdisc.SecretStoreProvider; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.result.EventStream; +import com.yahoo.search.searchchain.Execution; +import org.junit.jupiter.api.Test; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +public class LLMSearcherTest { + + @Test + public void testLLMSelection() { + var llm1 = createLLMClient("mock1"); + var llm2 = createLLMClient("mock2"); + var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build(); + var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2)); + var result = runMockSearch(searcher, Map.of("prompt", "what is your id?")); + assertEquals(1, result.getHitCount()); + assertEquals("My id is mock2", getCompletion(result)); + } + + @Test + public void testGeneration() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var params = Map.of("prompt", "why are ducks better than cats"); + assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params))); + } + + @Test + public void testPrompting() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + + // Prompt with prefix + assertEquals("Ducks have adorable waddling walks.", + getCompletion(runMockSearch(searcher, Map.of("llm.prompt", "why are ducks better than cats")))); + + // Prompt without prefix + assertEquals("Ducks have adorable waddling walks.", + getCompletion(runMockSearch(searcher, Map.of("prompt", "why are ducks better than cats")))); + + // Fallback to query if not given + assertEquals("Ducks have adorable waddling walks.", + getCompletion(runMockSearch(searcher, Map.of("query", "why are ducks better than cats")))); + } + + @Test + public void testPromptEvent() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var params = Map.of( + "prompt", "why are ducks better than cats", + "traceLevel", "1"); + var result = runMockSearch(searcher, params); + var events = ((EventStream) result.hits().get(0)).incoming().drain(); + assertEquals(2, events.size()); + + var promptEvent = (EventStream.Event) events.get(0); + assertEquals("prompt", promptEvent.type()); + assertEquals("why are ducks better than cats", promptEvent.toString()); + + var completionEvent = (EventStream.Event) events.get(1); + assertEquals("completion", completionEvent.type()); + assertEquals("Ducks have adorable waddling walks.", completionEvent.toString()); + } + + @Test + public void testParameters() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var params = Map.of( + "llm.prompt", "why are ducks better than cats", + "llm.temperature", "1.0", + "llm.maxTokens", "5" + ); + assertEquals("Random text about ducks vs", getCompletion(runMockSearch(searcher, params))); + } + + @Test + public void testParameterPrefix() { + var prefix = "foo"; + var params = Map.of( + "foo.prompt", "what is your opinion on cats", + "foo.maxTokens", "5" + ); + var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build(); + var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient())); + assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params))); + } + + @Test + public void testApiKeyFromHeader() { + var properties = Map.of("prompt", "why are ducks better than cats"); + var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore())); + assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm")); + assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm")); + } + + @Test + public void testAsyncGeneration() { + var executor = Executors.newFixedThreadPool(1); + var sb = new StringBuilder(); + try { + var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock").build(); // config says don't stream... + var params = Map.of( + "llm.stream", "true", // ... but inference parameters says do it anyway + "llm.prompt", "why are ducks better than cats?" + ); + var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor))); + Result result = runMockSearch(searcher, params); + + assertEquals(1, result.getHitCount()); + assertTrue(result.hits().get(0) instanceof EventStream); + EventStream eventStream = (EventStream) result.hits().get(0); + + var incoming = eventStream.incoming(); + incoming.addNewDataListener(() -> { + incoming.drain().forEach(event -> sb.append(event.toString())); + }, executor); + + incoming.completedFuture().join(); + assertTrue(incoming.isComplete()); + + // Ensure incoming has been fully drained to avoid race condition in this test + incoming.drain().forEach(event -> sb.append(event.toString())); + + } finally { + executor.shutdownNow(); + } + assertEquals("Ducks have adorable waddling walks.", sb.toString()); + } + + private static String getCompletion(Result result) { + assertTrue(result.hits().size() >= 1); + return ((EventStream) result.hits().get(0)).incoming().drain().get(0).toString(); + } + + static Result runMockSearch(Searcher searcher, Map<String, String> parameters) { + return runMockSearch(searcher, parameters, null, ""); + } + + static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) { + Chain<Searcher> chain = new Chain<>(searcher); + Execution execution = new Execution(chain, Execution.Context.createContextStub()); + Query query = new Query("?" + toUrlParams(parameters)); + if (apiKey != null) { + String headerKey = "X-LLM-API-KEY"; + if (prefix != null && ! prefix.isEmpty()) { + headerKey = prefix + "." + headerKey; + } + query.getHttpRequest().getJDiscRequest().headers().add(headerKey, apiKey); + } + return execution.search(query); + } + + public static String toUrlParams(Map<String, String> parameters) { + return parameters.entrySet().stream().map( + e -> e.getKey() + "=" + URLEncoder.encode(e.getValue(), StandardCharsets.UTF_8) + ).collect(Collectors.joining("&")); + } + + private static BiFunction<Prompt, InferenceParameters, String> createIdGenerator(String id) { + return (prompt, options) -> { + if (id == null || id.isEmpty()) + return "I have no ID"; + return "My id is " + id; + }; + } + + private static BiFunction<Prompt, InferenceParameters, String> createGenerator() { + return ConfigurableLanguageModelTest.createGenerator(); + } + + static MockLLMClient createLLMClient() { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore, generator, null); + } + + static MockLLMClient createLLMClient(String id) { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createIdGenerator(id); + return new MockLLMClient(config, secretStore, generator, null); + } + + static MockLLMClient createLLMClient(ExecutorService executor) { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore, generator, executor); + } + + static MockLLMClient createLLMClientWithoutSecretStore() { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = new SecretStoreProvider(); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore.get(), generator, null); + } + + private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) { + var config = new LlmSearcherConfig.Builder().stream(false).build(); + ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); + llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new LLMSearcherImpl(config, models); + } + + private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) { + ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); + llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new LLMSearcherImpl(config, models); + } + + public static class LLMSearcherImpl extends LLMSearcher { + + public LLMSearcherImpl(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { + super(config, languageModels); + } + + @Override + public Result search(Query query, Execution execution) { + return complete(query, StringPrompt.from(getPrompt(query))); + } + } + +} |