aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/test/java/ai
diff options
context:
space:
mode:
authorLester Solbakken <lester.solbakken@gmail.com>2024-04-12 12:19:35 +0200
committerLester Solbakken <lester.solbakken@gmail.com>2024-04-12 12:19:35 +0200
commite11438c6335038f6d99ea50eef086511eb204d43 (patch)
treeafa8afdc747abd3a3a75cf873d95bad52b8d3b2e /container-search/src/test/java/ai
parent446333e197bfb274ae48c173f25c4ad7e8d76a0f (diff)
Clean up test code after move from container-search to model-integration
Diffstat (limited to 'container-search/src/test/java/ai')
-rwxr-xr-xcontainer-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java141
-rw-r--r--container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java82
2 files changed, 87 insertions, 136 deletions
diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
index f5971aa55ff..3baa9715c34 100755
--- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
@@ -3,14 +3,11 @@ package ai.vespa.search.llm;
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
-import ai.vespa.llm.clients.LlmClientConfig;
+import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.component.ComponentId;
import com.yahoo.component.chain.Chain;
import com.yahoo.component.provider.ComponentRegistry;
-import com.yahoo.container.di.componentgraph.Provider;
-import com.yahoo.container.jdisc.SecretStoreProvider;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
@@ -21,10 +18,13 @@ import org.junit.jupiter.api.Test;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
+import java.util.List;
import java.util.Map;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.BiFunction;
+import java.util.function.Consumer;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
@@ -37,10 +37,10 @@ public class LLMSearcherTest {
@Test
public void testLLMSelection() {
- var llm1 = createLLMClient("mock1");
- var llm2 = createLLMClient("mock2");
+ var client1 = createLLMClient("mock1");
+ var client2 = createLLMClient("mock2");
var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build();
- var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2));
+ var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2));
var result = runMockSearch(searcher, Map.of("prompt", "what is your id?"));
assertEquals(1, result.getHitCount());
assertEquals("My id is mock2", getCompletion(result));
@@ -49,14 +49,15 @@ public class LLMSearcherTest {
@Test
public void testGeneration() {
var client = createLLMClient();
- var searcher = createLLMSearcher(Map.of("mock", client));
+ var searcher = createLLMSearcher(client);
var params = Map.of("prompt", "why are ducks better than cats");
assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params)));
}
@Test
public void testPrompting() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
// Prompt with prefix
assertEquals("Ducks have adorable waddling walks.",
@@ -73,7 +74,8 @@ public class LLMSearcherTest {
@Test
public void testPromptEvent() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
var params = Map.of(
"prompt", "why are ducks better than cats",
"traceLevel", "1");
@@ -92,7 +94,8 @@ public class LLMSearcherTest {
@Test
public void testParameters() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
var params = Map.of(
"llm.prompt", "why are ducks better than cats",
"llm.temperature", "1.0",
@@ -109,16 +112,18 @@ public class LLMSearcherTest {
"foo.maxTokens", "5"
);
var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build();
- var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(config, client);
assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params)));
}
@Test
public void testApiKeyFromHeader() {
var properties = Map.of("prompt", "why are ducks better than cats");
- var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore()));
- assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm"));
- assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm"));
+ var client = createLLMClient(createApiKeyGenerator("a_valid_key"));
+ var searcher = createLLMSearcher(client);
+ assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key"));
+ assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key"));
}
@Test
@@ -131,7 +136,8 @@ public class LLMSearcherTest {
"llm.stream", "true", // ... but inference parameters says do it anyway
"llm.prompt", "why are ducks better than cats?"
);
- var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor)));
+ var client = createLLMClient(executor);
+ var searcher = createLLMSearcher(config, client);
Result result = runMockSearch(searcher, params);
assertEquals(1, result.getHitCount());
@@ -164,6 +170,10 @@ public class LLMSearcherTest {
return runMockSearch(searcher, parameters, null, "");
}
+ static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey) {
+ return runMockSearch(searcher, parameters, apiKey, "llm");
+ }
+
static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) {
Chain<Searcher> chain = new Chain<>(searcher);
Execution execution = new Execution(chain, Execution.Context.createContextStub());
@@ -210,59 +220,42 @@ public class LLMSearcherTest {
};
}
- public static SecretStore createSecretStore(Map<String, String> secrets) {
- Provider<SecretStore> secretStore = new Provider<>() {
- public SecretStore get() {
- return new SecretStore() {
- public String getSecret(String key) {
- return secrets.get(key);
- }
- public String getSecret(String key, int version) {
- return secrets.get(key);
- }
- };
- }
- public void deconstruct() {
+ private static BiFunction<Prompt, InferenceParameters, String> createApiKeyGenerator(String validApiKey) {
+ return (prompt, options) -> {
+ if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) {
+ throw new IllegalArgumentException("Invalid API key");
}
+ return "Ok";
};
- return secretStore.get();
}
+ static MockLLM createLLMClient() {
+ return new MockLLM(createGenerator(), null);
+ }
- static MockLLMClient createLLMClient() {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
- var generator = createGenerator();
- return new MockLLMClient(config, secretStore, generator, null);
+ static MockLLM createLLMClient(String id) {
+ return new MockLLM(createIdGenerator(id), null);
}
- static MockLLMClient createLLMClient(String id) {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
- var generator = createIdGenerator(id);
- return new MockLLMClient(config, secretStore, generator, null);
+ static MockLLM createLLMClient(BiFunction<Prompt, InferenceParameters, String> generator) {
+ return new MockLLM(generator, null);
}
- static MockLLMClient createLLMClient(ExecutorService executor) {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
- var generator = createGenerator();
- return new MockLLMClient(config, secretStore, generator, executor);
+ static MockLLM createLLMClient(ExecutorService executor) {
+ return new MockLLM(createGenerator(), executor);
}
- static MockLLMClient createLLMClientWithoutSecretStore() {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = new SecretStoreProvider();
- var generator = createGenerator();
- return new MockLLMClient(config, secretStore.get(), generator, null);
+ private static Searcher createLLMSearcher(LanguageModel llm) {
+ return createLLMSearcher(Map.of("mock", llm));
}
private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) {
var config = new LlmSearcherConfig.Builder().stream(false).build();
- ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
- llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
- models.freeze();
- return new LLMSearcher(config, models);
+ return createLLMSearcher(config, llms);
+ }
+
+ private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) {
+ return createLLMSearcher(config, Map.of("mock", llm));
}
private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) {
@@ -272,4 +265,44 @@ public class LLMSearcherTest {
return new LLMSearcher(config, models);
}
+ private static class MockLLM implements LanguageModel {
+
+ private final ExecutorService executor;
+ private final BiFunction<Prompt, InferenceParameters, String> generator;
+
+ public MockLLM(BiFunction<Prompt, InferenceParameters, String> generator, ExecutorService executor) {
+ this.executor = executor;
+ this.generator = generator;
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters params) {
+ return List.of(Completion.from(this.generator.apply(prompt, params)));
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters params,
+ Consumer<Completion> consumer) {
+ var completionFuture = new CompletableFuture<Completion.FinishReason>();
+ var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization
+
+ long sleep = 1;
+ executor.submit(() -> {
+ try {
+ for (int i = 0; i < completions.length; ++i) {
+ String completion = (i > 0 ? " " : "") + completions[i];
+ consumer.accept(Completion.from(completion, Completion.FinishReason.none));
+ Thread.sleep(sleep);
+ }
+ completionFuture.complete(Completion.FinishReason.stop);
+ } catch (InterruptedException e) {
+ // Do nothing
+ }
+ });
+ return completionFuture;
+ }
+
+ }
+
}
diff --git a/container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java b/container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java
deleted file mode 100644
index 4411e0cab70..00000000000
--- a/container-search/src/test/java/ai/vespa/search/llm/MockLLMClient.java
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.search.llm;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.clients.ConfigurableLanguageModel;
-import ai.vespa.llm.clients.LlmClientConfig;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.function.BiFunction;
-import java.util.function.Consumer;
-
-public class MockLLMClient extends ConfigurableLanguageModel {
-
- public final static String ACCEPTED_API_KEY = "sesame";
-
- private final ExecutorService executor;
- private final BiFunction<Prompt, InferenceParameters, String> generator;
-
- private Prompt lastPrompt;
-
- public MockLLMClient(LlmClientConfig config,
- SecretStore secretStore,
- BiFunction<Prompt, InferenceParameters, String> generator,
- ExecutorService executor) {
- super(config, secretStore);
- this.generator = generator;
- this.executor = executor;
- }
-
- private void checkApiKey(InferenceParameters options) {
- var apiKey = getApiKey(options);
- if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) {
- throw new IllegalArgumentException("Invalid API key");
- }
- }
-
- private void setPrompt(Prompt prompt) {
- this.lastPrompt = prompt;
- }
-
- public Prompt getPrompt() {
- return this.lastPrompt;
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters params) {
- setApiKey(params);
- checkApiKey(params);
- setPrompt(prompt);
- return List.of(Completion.from(this.generator.apply(prompt, params)));
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
- InferenceParameters params,
- Consumer<Completion> consumer) {
- setPrompt(prompt);
- var completionFuture = new CompletableFuture<Completion.FinishReason>();
- var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization
-
- long sleep = 1;
- executor.submit(() -> {
- try {
- for (int i=0; i < completions.length; ++i) {
- String completion = (i > 0 ? " " : "") + completions[i];
- consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep);
- }
- completionFuture.complete(Completion.FinishReason.stop);
- } catch (InterruptedException e) {
- // Do nothing
- }
- });
-
- return completionFuture;
- }
-
-}