diff options
Diffstat (limited to 'container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java')
-rwxr-xr-x | container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java | 150 |
1 files changed, 109 insertions, 41 deletions
diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java index 1efcf1c736a..3baa9715c34 100755 --- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java +++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java @@ -3,14 +3,11 @@ package ai.vespa.search.llm; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.clients.ConfigurableLanguageModelTest; -import ai.vespa.llm.clients.LlmClientConfig; -import ai.vespa.llm.clients.MockLLMClient; +import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; import com.yahoo.component.ComponentId; import com.yahoo.component.chain.Chain; import com.yahoo.component.provider.ComponentRegistry; -import com.yahoo.container.jdisc.SecretStoreProvider; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; @@ -20,10 +17,14 @@ import org.junit.jupiter.api.Test; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -36,10 +37,10 @@ public class LLMSearcherTest { @Test public void testLLMSelection() { - var llm1 = createLLMClient("mock1"); - var llm2 = createLLMClient("mock2"); + var client1 = createLLMClient("mock1"); + var client2 = createLLMClient("mock2"); var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build(); - var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2)); + var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2)); var result = runMockSearch(searcher, Map.of("prompt", "what is your id?")); assertEquals(1, result.getHitCount()); assertEquals("My id is mock2", getCompletion(result)); @@ -47,14 +48,16 @@ public class LLMSearcherTest { @Test public void testGeneration() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); var params = Map.of("prompt", "why are ducks better than cats"); assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params))); } @Test public void testPrompting() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); // Prompt with prefix assertEquals("Ducks have adorable waddling walks.", @@ -71,7 +74,8 @@ public class LLMSearcherTest { @Test public void testPromptEvent() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); var params = Map.of( "prompt", "why are ducks better than cats", "traceLevel", "1"); @@ -90,7 +94,8 @@ public class LLMSearcherTest { @Test public void testParameters() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); var params = Map.of( "llm.prompt", "why are ducks better than cats", "llm.temperature", "1.0", @@ -107,16 +112,18 @@ public class LLMSearcherTest { "foo.maxTokens", "5" ); var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build(); - var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(config, client); assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params))); } @Test public void testApiKeyFromHeader() { var properties = Map.of("prompt", "why are ducks better than cats"); - var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore())); - assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm")); - assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm")); + var client = createLLMClient(createApiKeyGenerator("a_valid_key")); + var searcher = createLLMSearcher(client); + assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key")); + assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key")); } @Test @@ -129,7 +136,8 @@ public class LLMSearcherTest { "llm.stream", "true", // ... but inference parameters says do it anyway "llm.prompt", "why are ducks better than cats?" ); - var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor))); + var client = createLLMClient(executor); + var searcher = createLLMSearcher(config, client); Result result = runMockSearch(searcher, params); assertEquals(1, result.getHitCount()); @@ -162,6 +170,10 @@ public class LLMSearcherTest { return runMockSearch(searcher, parameters, null, ""); } + static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey) { + return runMockSearch(searcher, parameters, apiKey, "llm"); + } + static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) { Chain<Searcher> chain = new Chain<>(searcher); Execution execution = new Execution(chain, Execution.Context.createContextStub()); @@ -191,43 +203,59 @@ public class LLMSearcherTest { } private static BiFunction<Prompt, InferenceParameters, String> createGenerator() { - return ConfigurableLanguageModelTest.createGenerator(); + return (prompt, options) -> { + String answer = "I have no opinion on the matter"; + if (prompt.asString().contains("ducks")) { + answer = "Ducks have adorable waddling walks."; + var temperature = options.getDouble("temperature"); + if (temperature.isPresent() && temperature.get() > 0.5) { + answer = "Random text about ducks vs cats that makes no sense whatsoever."; + } + } + var maxTokens = options.getInt("maxTokens"); + if (maxTokens.isPresent()) { + return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); + } + return answer; + }; } - static MockLLMClient createLLMClient() { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore, generator, null); + private static BiFunction<Prompt, InferenceParameters, String> createApiKeyGenerator(String validApiKey) { + return (prompt, options) -> { + if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) { + throw new IllegalArgumentException("Invalid API key"); + } + return "Ok"; + }; + } + + static MockLLM createLLMClient() { + return new MockLLM(createGenerator(), null); } - static MockLLMClient createLLMClient(String id) { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createIdGenerator(id); - return new MockLLMClient(config, secretStore, generator, null); + static MockLLM createLLMClient(String id) { + return new MockLLM(createIdGenerator(id), null); } - static MockLLMClient createLLMClient(ExecutorService executor) { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore, generator, executor); + static MockLLM createLLMClient(BiFunction<Prompt, InferenceParameters, String> generator) { + return new MockLLM(generator, null); } - static MockLLMClient createLLMClientWithoutSecretStore() { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = new SecretStoreProvider(); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore.get(), generator, null); + static MockLLM createLLMClient(ExecutorService executor) { + return new MockLLM(createGenerator(), executor); + } + + private static Searcher createLLMSearcher(LanguageModel llm) { + return createLLMSearcher(Map.of("mock", llm)); } private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) { var config = new LlmSearcherConfig.Builder().stream(false).build(); - ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); - llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); - models.freeze(); - return new LLMSearcher(config, models); + return createLLMSearcher(config, llms); + } + + private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) { + return createLLMSearcher(config, Map.of("mock", llm)); } private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) { @@ -237,4 +265,44 @@ public class LLMSearcherTest { return new LLMSearcher(config, models); } + private static class MockLLM implements LanguageModel { + + private final ExecutorService executor; + private final BiFunction<Prompt, InferenceParameters, String> generator; + + public MockLLM(BiFunction<Prompt, InferenceParameters, String> generator, ExecutorService executor) { + this.executor = executor; + this.generator = generator; + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters params) { + return List.of(Completion.from(this.generator.apply(prompt, params))); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters params, + Consumer<Completion> consumer) { + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization + + long sleep = 1; + executor.submit(() -> { + try { + for (int i = 0; i < completions.length; ++i) { + String completion = (i > 0 ? " " : "") + completions[i]; + consumer.accept(Completion.from(completion, Completion.FinishReason.none)); + Thread.sleep(sleep); + } + completionFuture.complete(Completion.FinishReason.stop); + } catch (InterruptedException e) { + // Do nothing + } + }); + return completionFuture; + } + + } + } |