From e11438c6335038f6d99ea50eef086511eb204d43 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Fri, 12 Apr 2024 12:19:35 +0200 Subject: Clean up test code after move from container-search to model-integration --- .../java/ai/vespa/search/llm/LLMSearcherTest.java | 141 +++++++++++++-------- .../java/ai/vespa/search/llm/MockLLMClient.java | 82 ------------ 2 files changed, 87 insertions(+), 136 deletions(-) delete mode 100644 container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java (limited to 'container-search/src/test/java/ai') diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java index f5971aa55ff..3baa9715c34 100755 --- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java +++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java @@ -3,14 +3,11 @@ package ai.vespa.search.llm; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.clients.LlmClientConfig; +import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; import com.yahoo.component.ComponentId; import com.yahoo.component.chain.Chain; import com.yahoo.component.provider.ComponentRegistry; -import com.yahoo.container.di.componentgraph.Provider; -import com.yahoo.container.jdisc.SecretStoreProvider; -import com.yahoo.container.jdisc.secretstore.SecretStore; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; @@ -21,10 +18,13 @@ import org.junit.jupiter.api.Test; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -37,10 +37,10 @@ public class LLMSearcherTest { @Test public void testLLMSelection() { - var llm1 = createLLMClient("mock1"); - var llm2 = createLLMClient("mock2"); + var client1 = createLLMClient("mock1"); + var client2 = createLLMClient("mock2"); var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build(); - var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2)); + var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2)); var result = runMockSearch(searcher, Map.of("prompt", "what is your id?")); assertEquals(1, result.getHitCount()); assertEquals("My id is mock2", getCompletion(result)); @@ -49,14 +49,15 @@ public class LLMSearcherTest { @Test public void testGeneration() { var client = createLLMClient(); - var searcher = createLLMSearcher(Map.of("mock", client)); + var searcher = createLLMSearcher(client); var params = Map.of("prompt", "why are ducks better than cats"); assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params))); } @Test public void testPrompting() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); // Prompt with prefix assertEquals("Ducks have adorable waddling walks.", @@ -73,7 +74,8 @@ public class LLMSearcherTest { @Test public void testPromptEvent() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); var params = Map.of( "prompt", "why are ducks better than cats", "traceLevel", "1"); @@ -92,7 +94,8 @@ public class LLMSearcherTest { @Test public void testParameters() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); var params = Map.of( "llm.prompt", "why are ducks better than cats", "llm.temperature", "1.0", @@ -109,16 +112,18 @@ public class LLMSearcherTest { "foo.maxTokens", "5" ); var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build(); - var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(config, client); assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params))); } @Test public void testApiKeyFromHeader() { var properties = Map.of("prompt", "why are ducks better than cats"); - var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore())); - assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm")); - assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm")); + var client = createLLMClient(createApiKeyGenerator("a_valid_key")); + var searcher = createLLMSearcher(client); + assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key")); + assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key")); } @Test @@ -131,7 +136,8 @@ public class LLMSearcherTest { "llm.stream", "true", // ... but inference parameters says do it anyway "llm.prompt", "why are ducks better than cats?" ); - var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor))); + var client = createLLMClient(executor); + var searcher = createLLMSearcher(config, client); Result result = runMockSearch(searcher, params); assertEquals(1, result.getHitCount()); @@ -164,6 +170,10 @@ public class LLMSearcherTest { return runMockSearch(searcher, parameters, null, ""); } + static Result runMockSearch(Searcher searcher, Map parameters, String apiKey) { + return runMockSearch(searcher, parameters, apiKey, "llm"); + } + static Result runMockSearch(Searcher searcher, Map parameters, String apiKey, String prefix) { Chain chain = new Chain<>(searcher); Execution execution = new Execution(chain, Execution.Context.createContextStub()); @@ -210,59 +220,42 @@ public class LLMSearcherTest { }; } - public static SecretStore createSecretStore(Map secrets) { - Provider secretStore = new Provider<>() { - public SecretStore get() { - return new SecretStore() { - public String getSecret(String key) { - return secrets.get(key); - } - public String getSecret(String key, int version) { - return secrets.get(key); - } - }; - } - public void deconstruct() { + private static BiFunction createApiKeyGenerator(String validApiKey) { + return (prompt, options) -> { + if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) { + throw new IllegalArgumentException("Invalid API key"); } + return "Ok"; }; - return secretStore.get(); } + static MockLLM createLLMClient() { + return new MockLLM(createGenerator(), null); + } - static MockLLMClient createLLMClient() { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore, generator, null); + static MockLLM createLLMClient(String id) { + return new MockLLM(createIdGenerator(id), null); } - static MockLLMClient createLLMClient(String id) { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createIdGenerator(id); - return new MockLLMClient(config, secretStore, generator, null); + static MockLLM createLLMClient(BiFunction generator) { + return new MockLLM(generator, null); } - static MockLLMClient createLLMClient(ExecutorService executor) { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore, generator, executor); + static MockLLM createLLMClient(ExecutorService executor) { + return new MockLLM(createGenerator(), executor); } - static MockLLMClient createLLMClientWithoutSecretStore() { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = new SecretStoreProvider(); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore.get(), generator, null); + private static Searcher createLLMSearcher(LanguageModel llm) { + return createLLMSearcher(Map.of("mock", llm)); } private static Searcher createLLMSearcher(Map llms) { var config = new LlmSearcherConfig.Builder().stream(false).build(); - ComponentRegistry models = new ComponentRegistry<>(); - llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); - models.freeze(); - return new LLMSearcher(config, models); + return createLLMSearcher(config, llms); + } + + private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) { + return createLLMSearcher(config, Map.of("mock", llm)); } private static Searcher createLLMSearcher(LlmSearcherConfig config, Map llms) { @@ -272,4 +265,44 @@ public class LLMSearcherTest { return new LLMSearcher(config, models); } + private static class MockLLM implements LanguageModel { + + private final ExecutorService executor; + private final BiFunction generator; + + public MockLLM(BiFunction generator, ExecutorService executor) { + this.executor = executor; + this.generator = generator; + } + + @Override + public List complete(Prompt prompt, InferenceParameters params) { + return List.of(Completion.from(this.generator.apply(prompt, params))); + } + + @Override + public CompletableFuture completeAsync(Prompt prompt, + InferenceParameters params, + Consumer consumer) { + var completionFuture = new CompletableFuture(); + var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization + + long sleep = 1; + executor.submit(() -> { + try { + for (int i = 0; i < completions.length; ++i) { + String completion = (i > 0 ? " " : "") + completions[i]; + consumer.accept(Completion.from(completion, Completion.FinishReason.none)); + Thread.sleep(sleep); + } + completionFuture.complete(Completion.FinishReason.stop); + } catch (InterruptedException e) { + // Do nothing + } + }); + return completionFuture; + } + + } + } diff --git a/container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java b/container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java deleted file mode 100644 index 4411e0cab70..00000000000 --- a/container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.search.llm; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.clients.ConfigurableLanguageModel; -import ai.vespa.llm.clients.LlmClientConfig; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; -import java.util.function.BiFunction; -import java.util.function.Consumer; - -public class MockLLMClient extends ConfigurableLanguageModel { - - public final static String ACCEPTED_API_KEY = "sesame"; - - private final ExecutorService executor; - private final BiFunction generator; - - private Prompt lastPrompt; - - public MockLLMClient(LlmClientConfig config, - SecretStore secretStore, - BiFunction generator, - ExecutorService executor) { - super(config, secretStore); - this.generator = generator; - this.executor = executor; - } - - private void checkApiKey(InferenceParameters options) { - var apiKey = getApiKey(options); - if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) { - throw new IllegalArgumentException("Invalid API key"); - } - } - - private void setPrompt(Prompt prompt) { - this.lastPrompt = prompt; - } - - public Prompt getPrompt() { - return this.lastPrompt; - } - - @Override - public List complete(Prompt prompt, InferenceParameters params) { - setApiKey(params); - checkApiKey(params); - setPrompt(prompt); - return List.of(Completion.from(this.generator.apply(prompt, params))); - } - - @Override - public CompletableFuture completeAsync(Prompt prompt, - InferenceParameters params, - Consumer consumer) { - setPrompt(prompt); - var completionFuture = new CompletableFuture(); - var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization - - long sleep = 1; - executor.submit(() -> { - try { - for (int i=0; i < completions.length; ++i) { - String completion = (i > 0 ? " " : "") + completions[i]; - consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep); - } - completionFuture.complete(Completion.FinishReason.stop); - } catch (InterruptedException e) { - // Do nothing - } - }); - - return completionFuture; - } - -} -- cgit v1.2.3