about | summary | refs | log | tree | commit | diff | stats
path: root/container-search/src/test/java/ai/vespa/llm/search
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/test/java/ai/vespa/llm/search')
-rwxr-xr-x container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java 254
-rwxr-xr-x container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java 127
2 files changed, 381 insertions, 0 deletions
diff --git a/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java
new file mode 100755
index 00000000000..d4f1dbc00a4
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java
@@ -0,0 +1,254 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.LlmSearcherConfig;
+import ai.vespa.llm.clients.ConfigurableLanguageModelTest;
+import ai.vespa.llm.clients.MockLLMClient;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.component.ComponentId;
+import com.yahoo.component.chain.Chain;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.result.EventStream;
+import com.yahoo.search.searchchain.Execution;
+import org.junit.jupiter.api.Test;
+
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+
+/**
+ * Tests of {@link LLMSearcher}, run through the {@link LLMSearcherImpl} subclass defined at the bottom
+ * of this file and backed by {@link MockLLMClient} instances instead of a real language model.
+ *
+ * The package-private and public static helpers ({@code createLLMClient()}, {@code runMockSearch},
+ * {@code toUrlParams}) are also meant to be reusable from other tests in this package.
+ */
+public class LLMSearcherTest {
+
+    /** Prompt shared by most tests. */
+    private static final String DUCK_PROMPT = "why are ducks better than cats";
+
+    /** The completion these tests expect the mock generator to produce for {@link #DUCK_PROMPT}. */
+    private static final String DUCK_COMPLETION = "Ducks have adorable waddling walks.";
+
+    /** With two registered models, the one named by providerId in the searcher config must answer. */
+    @Test
+    public void testLLMSelection() {
+        var llm1 = createLLMClient("mock1");
+        var llm2 = createLLMClient("mock2");
+        var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build();
+        var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2));
+        var result = runMockSearch(searcher, Map.of("prompt", "what is your id?"));
+        assertEquals(1, result.getHitCount());
+        assertEquals("My id is mock2", getCompletion(result));
+    }
+
+    /** A plain prompt yields the expected completion as the first event of the first hit. */
+    @Test
+    public void testGeneration() {
+        var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+        var params = Map.of("prompt", DUCK_PROMPT);
+        assertEquals(DUCK_COMPLETION, getCompletion(runMockSearch(searcher, params)));
+    }
+
+    /** The prompt may be given as "llm.prompt", as "prompt", or implicitly through "query". */
+    @Test
+    public void testPrompting() {
+        var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+
+        // Prompt with prefix
+        assertEquals(DUCK_COMPLETION,
+                     getCompletion(runMockSearch(searcher, Map.of("llm.prompt", DUCK_PROMPT))));
+
+        // Prompt without prefix
+        assertEquals(DUCK_COMPLETION,
+                     getCompletion(runMockSearch(searcher, Map.of("prompt", DUCK_PROMPT))));
+
+        // Fallback to query if not given
+        assertEquals(DUCK_COMPLETION,
+                     getCompletion(runMockSearch(searcher, Map.of("query", DUCK_PROMPT))));
+    }
+
+    /** With traceLevel=1 a "prompt" event is expected ahead of the "completion" event. */
+    @Test
+    public void testPromptEvent() {
+        var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+        var params = Map.of(
+                "prompt", DUCK_PROMPT,
+                "traceLevel", "1");
+        var result = runMockSearch(searcher, params);
+        var events = ((EventStream) result.hits().get(0)).incoming().drain();
+        assertEquals(2, events.size());
+
+        var promptEvent = (EventStream.Event) events.get(0);
+        assertEquals("prompt", promptEvent.type());
+        assertEquals(DUCK_PROMPT, promptEvent.toString());
+
+        var completionEvent = (EventStream.Event) events.get(1);
+        assertEquals("completion", completionEvent.type());
+        assertEquals(DUCK_COMPLETION, completionEvent.toString());
+    }
+
+    /** Inference parameters given with the default "llm." prefix change the generated completion. */
+    @Test
+    public void testParameters() {
+        var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+        var params = Map.of(
+                "llm.prompt", DUCK_PROMPT,
+                "llm.temperature", "1.0",
+                "llm.maxTokens", "5"
+        );
+        assertEquals("Random text about ducks vs", getCompletion(runMockSearch(searcher, params)));
+    }
+
+    /** A custom propertyPrefix in the searcher config replaces the default parameter prefix. */
+    @Test
+    public void testParameterPrefix() {
+        var prefix = "foo";
+        var params = Map.of(
+                "foo.prompt", "what is your opinion on cats",
+                "foo.maxTokens", "5"
+        );
+        var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build();
+        var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient()));
+        assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params)));
+    }
+
+    /**
+     * When the client has no secret available, the API key passed in the request header decides the outcome:
+     * an unknown key is rejected with IllegalArgumentException, the accepted key goes through.
+     */
+    @Test
+    public void testApiKeyFromHeader() {
+        var properties = Map.of("prompt", DUCK_PROMPT);
+        var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore()));
+        assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm"));
+        assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm"));
+    }
+
+    /** Streaming requested through the "llm.stream" request parameter overrides stream(false) in the config. */
+    @Test
+    public void testAsyncGeneration() {
+        var executor = Executors.newFixedThreadPool(1);
+        // StringBuffer rather than StringBuilder: it is appended to both by the listener, which is registered
+        // with the executor and so presumably runs off the test thread, and by the test thread itself below.
+        var completion = new StringBuffer();
+        try {
+            var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock").build(); // config says don't stream...
+            var params = Map.of(
+                    "llm.stream", "true", // ... but inference parameters says do it anyway
+                    "llm.prompt", "why are ducks better than cats?"
+            );
+            var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor)));
+            Result result = runMockSearch(searcher, params);
+
+            assertEquals(1, result.getHitCount());
+            assertTrue(result.hits().get(0) instanceof EventStream);
+            EventStream eventStream = (EventStream) result.hits().get(0);
+
+            var incoming = eventStream.incoming();
+            incoming.addNewDataListener(() -> {
+                incoming.drain().forEach(event -> completion.append(event.toString()));
+            }, executor);
+
+            incoming.completedFuture().join();
+            assertTrue(incoming.isComplete());
+
+            // Ensure incoming has been fully drained to avoid race condition in this test
+            incoming.drain().forEach(event -> completion.append(event.toString()));
+
+        } finally {
+            executor.shutdownNow();
+        }
+        assertEquals(DUCK_COMPLETION, completion.toString());
+    }
+
+    /** Returns the string form of the first event drained from the first hit, which must be an EventStream. */
+    private static String getCompletion(Result result) {
+        assertTrue(result.hits().size() >= 1, "Expected at least one hit in the result");
+        return ((EventStream) result.hits().get(0)).incoming().drain().get(0).toString();
+    }
+
+    /** Runs a search with the given request parameters and no API key header. */
+    static Result runMockSearch(Searcher searcher, Map<String, String> parameters) {
+        return runMockSearch(searcher, parameters, null, "");
+    }
+
+    /**
+     * Runs the searcher alone in a chain, with a query built from the given request parameters.
+     *
+     * @param searcher   the searcher under test
+     * @param parameters request parameters, turned into the query string by {@link #toUrlParams(Map)}
+     * @param apiKey     value for the API key request header, or null to send no such header
+     * @param prefix     if non-empty, the header is named prefix + ".X-LLM-API-KEY" instead of "X-LLM-API-KEY"
+     * @return the result of executing the query
+     */
+    static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) {
+        Chain<Searcher> chain = new Chain<>(searcher);
+        Execution execution = new Execution(chain, Execution.Context.createContextStub());
+        Query query = new Query("?" + toUrlParams(parameters));
+        if (apiKey != null) {
+            String headerKey = "X-LLM-API-KEY";
+            if (prefix != null && ! prefix.isEmpty()) {
+                headerKey = prefix + "." + headerKey;
+            }
+            query.getHttpRequest().getJDiscRequest().headers().add(headerKey, apiKey);
+        }
+        return execution.search(query);
+    }
+
+    /**
+     * Renders the parameters as a URL query string (without the leading '?'): "k1=v1&k2=v2...".
+     * Both keys and values are URL-encoded as UTF-8, so neither can break the string apart by
+     * containing '&', '=' or whitespace. Pairs appear in the iteration order of the given map.
+     */
+    public static String toUrlParams(Map<String, String> parameters) {
+        return parameters.entrySet().stream()
+                .map(e -> urlEncode(e.getKey()) + "=" + urlEncode(e.getValue()))
+                .collect(Collectors.joining("&"));
+    }
+
+    /** URL-encodes a single query string component as UTF-8. */
+    private static String urlEncode(String component) {
+        return URLEncoder.encode(component, StandardCharsets.UTF_8);
+    }
+
+    /** Returns a generator which ignores its input and reports the given id, or that it has none. */
+    private static BiFunction<Prompt, InferenceParameters, String> createIdGenerator(String id) {
+        return (prompt, options) -> {
+            if (id == null || id.isEmpty())
+                return "I have no ID";
+            return "My id is " + id;
+        };
+    }
+
+    /** Returns the shared mock generator defined by ConfigurableLanguageModelTest. */
+    private static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
+        return ConfigurableLanguageModelTest.createGenerator();
+    }
+
+    /** Creates a mock client using the shared generator and no executor. */
+    static MockLLMClient createLLMClient() {
+        return createLLMClient(createGenerator(), null);
+    }
+
+    /** Creates a mock client which answers every prompt with its own id, see {@link #createIdGenerator(String)}. */
+    static MockLLMClient createLLMClient(String id) {
+        return createLLMClient(createIdGenerator(id), null);
+    }
+
+    /** Creates a mock client using the shared generator and the given executor. */
+    static MockLLMClient createLLMClient(ExecutorService executor) {
+        return createLLMClient(createGenerator(), executor);
+    }
+
+    /**
+     * Common implementation of the createLLMClient overloads: a mock client whose "api-key" secret
+     * resolves to the key the mock accepts.
+     *
+     * @param generator produces the completion text
+     * @param executor  passed on to the mock client; may be null
+     */
+    private static MockLLMClient createLLMClient(BiFunction<Prompt, InferenceParameters, String> generator,
+                                                 ExecutorService executor) {
+        var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+        var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+        return new MockLLMClient(config, secretStore, generator, executor);
+    }
+
+    /**
+     * Creates a mock client which is given whatever a default-constructed SecretStoreProvider supplies
+     * instead of a store holding the "api-key" secret.
+     */
+    static MockLLMClient createLLMClientWithoutSecretStore() {
+        var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+        var secretStore = new SecretStoreProvider();
+        var generator = createGenerator();
+        return new MockLLMClient(config, secretStore.get(), generator, null);
+    }
+
+    /** Creates a searcher over the given models with a config that only sets stream(false). */
+    private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) {
+        return createLLMSearcher(new LlmSearcherConfig.Builder().stream(false).build(), llms);
+    }
+
+    /** Creates a searcher with the given config, over a frozen registry holding each model under its map key. */
+    private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) {
+        ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
+        llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
+        models.freeze();
+        return new LLMSearcherImpl(config, models);
+    }
+
+    /** The LLMSearcher subclass these tests run: it completes the prompt taken from the query as-is. */
+    public static class LLMSearcherImpl extends LLMSearcher {
+
+        public LLMSearcherImpl(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
+            super(config, languageModels);
+        }
+
+        @Override
+        public Result search(Query query, Execution execution) {
+            return complete(query, StringPrompt.from(getPrompt(query)));
+        }
+    }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java
new file mode 100755
index 00000000000..ccf9a4a6401
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java
@@ -0,0 +1,127 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmSearcherConfig;
+import com.yahoo.component.ComponentId;
+import com.yahoo.component.chain.Chain;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.result.EventStream;
+import com.yahoo.search.result.Hit;
+import com.yahoo.search.searchchain.Execution;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+
+/**
+ * Tests of {@link RAGSearcher}: the searcher is placed in a chain ahead of {@link MockSearchResults},
+ * which always returns the same two documents, and the prompt/completion events of the first hit are inspected.
+ */
+public class RAGSearcherTest {
+
+    private static final String DOC1_TITLE = "Exploring the Delightful Qualities of Ducks";
+    private static final String DOC1_CONTENT = "Ducks, with their gentle quacks and adorable waddling walks, possess a unique " +
+            "charm that sets them apart as extraordinary pets.";
+    private static final String DOC2_TITLE = "Why Cats Reign Supreme";
+    private static final String DOC2_CONTENT = "Cats bring an enchanting allure to households with their independent " +
+            "companionship, playful nature, natural hunting abilities, low-maintenance grooming, and the " +
+            "emotional support they offer.";
+
+    /** How the two mock documents are expected to appear at the start of a generated prompt, blank lines included. */
+    private static final String EXPECTED_CONTEXT =
+            "title: " + DOC1_TITLE + "\n" +
+            "content: " + DOC1_CONTENT + "\n\n" +
+            "title: " + DOC2_TITLE + "\n" +
+            "content: " + DOC2_CONTENT + "\n\n\n";
+
+    /** Without a prompt template, the documents are expected in front of the prompt, followed by the completion. */
+    @Test
+    public void testRAGGeneration() {
+        var parameters = Map.of(
+                "prompt", "why are ducks better than cats?",
+                "traceLevel", "1");
+        var drained = runRAGQuery(parameters).incoming().drain();
+        assertEquals(2, drained.size());
+
+        // First event: the prompt as generated from the search result
+        var prompt = (EventStream.Event) drained.get(0);
+        assertEquals("prompt", prompt.type());
+        assertEquals(EXPECTED_CONTEXT + "why are ducks better than cats?", prompt.toString());
+
+        // Second event: what the mock model answered
+        var completion = (EventStream.Event) drained.get(1);
+        assertEquals("completion", completion.type());
+        assertEquals("Ducks have adorable waddling walks.", completion.toString());
+    }
+
+    /** In a prompt template, {context} is expected to become the documents and @query the query text. */
+    @Test
+    public void testPromptGeneration() {
+        var parameters = Map.of(
+                "query", "why are ducks better than cats?",
+                "prompt", "{context}\nGiven these documents, answer this query as concisely as possible: @query",
+                "traceLevel", "1");
+        var drained = runRAGQuery(parameters).incoming().drain();
+
+        var prompt = (EventStream.Event) drained.get(0);
+        assertEquals("prompt", prompt.type());
+        assertEquals(EXPECTED_CONTEXT +
+                     "Given these documents, answer this query as concisely as possible: " +
+                     "why are ducks better than cats?",
+                     prompt.toString());
+    }
+
+    /** With llm.context=skip the prompt is expected to be the bare query, without any documents. */
+    @Test
+    public void testSkipContextInPrompt() {
+        var parameters = Map.of(
+                "query", "why are ducks better than cats?",
+                "llm.context", "skip",
+                "traceLevel", "1");
+        var drained = runRAGQuery(parameters).incoming().drain();
+
+        var prompt = (EventStream.Event) drained.get(0);
+        assertEquals("prompt", prompt.type());
+        assertEquals("why are ducks better than cats?", prompt.toString());
+    }
+
+    /** Stand-in for a search backend: answers every query with the same two title/content hits. */
+    public static class MockSearchResults extends Searcher {
+
+        @Override
+        public Result search(Query query, Execution execution) {
+            Result result = new Result(query);
+            result.hits().add(document("1", DOC1_TITLE, DOC1_CONTENT));
+            result.hits().add(document("2", DOC2_TITLE, DOC2_CONTENT));
+            return result;
+        }
+
+        /** Creates a hit with the given id and its "title" and "content" fields set. */
+        private static Hit document(String id, String title, String content) {
+            Hit hit = new Hit(id);
+            hit.setField("title", title);
+            hit.setField("content", content);
+            return hit;
+        }
+    }
+
+    /** Runs the parameters through a RAG searcher backed by one mock model and returns the first hit as an event stream. */
+    private EventStream runRAGQuery(Map<String, String> params) {
+        var ragSearcher = createRAGSearcher(Map.of("mock", LLMSearcherTest.createLLMClient()));
+        return (EventStream) runMockSearch(ragSearcher, params).hits().get(0);
+    }
+
+    /** Executes a query built from the parameters on a chain of the given searcher followed by MockSearchResults. */
+    static Result runMockSearch(Searcher searcher, Map<String, String> parameters) {
+        Chain<Searcher> searchChain = new Chain<>(searcher, new MockSearchResults());
+        Query request = new Query("?" + LLMSearcherTest.toUrlParams(parameters));
+        return new Execution(searchChain, Execution.Context.createContextStub()).search(request);
+    }
+
+    /** Creates a RAGSearcher with stream(false), over a frozen registry holding each model under its map key. */
+    private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
+        var searcherConfig = new LlmSearcherConfig.Builder().stream(false).build();
+        ComponentRegistry<LanguageModel> registry = new ComponentRegistry<>();
+        for (var entry : llms.entrySet()) {
+            registry.register(ComponentId.fromString(entry.getKey()), entry.getValue());
+        }
+        registry.freeze();
+        return new RAGSearcher(searcherConfig, registry);
+    }
+
+}