aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/test
diff options
context:
space:
mode:
author Harald Musum <musum@vespa.ai> 2024-04-15 20:44:10 +0200
committer GitHub <noreply@github.com> 2024-04-15 20:44:10 +0200
commit ed62b750494822cc67a328390178754512baf032 (patch)
tree 3f94e22d38d63c456a3bb202b4f3787ecff5ded6 /container-search/src/test
parent 44a866c0d648543c04567503990c03c36403d86d (diff)
Revert "Lesters/add local llms 2"
Diffstat (limited to 'container-search/src/test')
-rw-r--r-- container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java 174
-rw-r--r-- container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java 80
-rw-r--r-- container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java 35
-rwxr-xr-x container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java 150
4 files changed, 330 insertions, 109 deletions
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
new file mode 100644
index 00000000000..35d5cfd3855
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
@@ -0,0 +1,174 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.container.di.componentgraph.Provider;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ConfigurableLanguageModelTest {
+
+ @Test
+ public void testSyncGeneration() {
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey());
+ assertEquals(1, result.size());
+ assertEquals("Ducks have adorable waddling walks.", result.get(0).text());
+ }
+
+ @Test
+ public void testAsyncGeneration() {
+ var executor = Executors.newFixedThreadPool(1);
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var sb = new StringBuilder();
+ try {
+ var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> {
+ sb.append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+
+ var reason = future.join();
+ assertTrue(future.isDone());
+ assertNotEquals(reason, Completion.FinishReason.error);
+ } finally {
+ executor.shutdownNow();
+ }
+
+ assertEquals("Ducks have adorable waddling walks.", sb.toString());
+ }
+
+ @Test
+ public void testInferenceParameters() {
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4"));
+ var result = createLLM().complete(prompt, params);
+ assertEquals("Random text about ducks", result.get(0).text());
+ }
+
+ @Test
+ public void testNoApiKey() {
+ var prompt = StringPrompt.from("");
+ var config = modelParams("api-key", null);
+ var secrets = createSecretStore(Map.of());
+ assertThrows(IllegalArgumentException.class, () -> {
+ createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams());
+ });
+ }
+
+ @Test
+ public void testApiKeyFromSecretStore() {
+ var prompt = StringPrompt.from("");
+ var config = modelParams("api-key-in-secret-store", null);
+ var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY));
+ assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); });
+ }
+
+ private static String lookupParameter(String parameter, Map<String, String> params) {
+ return params.get(parameter);
+ }
+
+ private static InferenceParameters inferenceParams() {
+ return new InferenceParameters(s -> lookupParameter(s, Map.of()));
+ }
+
+ private static InferenceParameters inferenceParams(Map<String, String> params) {
+ return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params));
+ }
+
+ private static InferenceParameters inferenceParamsWithDefaultKey() {
+ return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Map.of()));
+ }
+
+ private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) {
+ var config = new LlmClientConfig.Builder();
+ if (apiKeySecretName != null) {
+ config.apiKeySecretName(apiKeySecretName);
+ }
+ if (endpoint != null) {
+ config.endpoint(endpoint);
+ }
+ return config.build();
+ }
+
+ public static SecretStore createSecretStore(Map<String, String> secrets) {
+ Provider<SecretStore> secretStore = new Provider<>() {
+ public SecretStore get() {
+ return new SecretStore() {
+ public String getSecret(String key) {
+ return secrets.get(key);
+ }
+ public String getSecret(String key, int version) {
+ return secrets.get(key);
+ }
+ };
+ }
+ public void deconstruct() {
+ }
+ };
+ return secretStore.get();
+ }
+
+ public static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
+ return (prompt, options) -> {
+ String answer = "I have no opinion on the matter";
+ if (prompt.asString().contains("ducks")) {
+ answer = "Ducks have adorable waddling walks.";
+ var temperature = options.getDouble("temperature");
+ if (temperature.isPresent() && temperature.get() > 0.5) {
+ answer = "Random text about ducks vs cats that makes no sense whatsoever.";
+ }
+ }
+ var maxTokens = options.getInt("maxTokens");
+ if (maxTokens.isPresent()) {
+ return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
+ }
+ return answer;
+ };
+ }
+
+ private static MockLLMClient createLLM() {
+ LlmClientConfig config = new LlmClientConfig.Builder().build();
+ return createLLM(config, null);
+ }
+
+ private static MockLLMClient createLLM(ExecutorService executor) {
+ LlmClientConfig config = new LlmClientConfig.Builder().build();
+ return createLLM(config, executor);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) {
+ var generator = createGenerator();
+ var secretStore = new SecretStoreProvider(); // throws exception on use
+ return createLLM(config, generator, secretStore.get(), executor);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ SecretStore secretStore) {
+ return createLLM(config, generator, secretStore, null);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ SecretStore secretStore,
+ ExecutorService executor) {
+ return new MockLLMClient(config, secretStore, generator, executor);
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
new file mode 100644
index 00000000000..4d0073f1cbe
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
@@ -0,0 +1,80 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
+
+/**
+ * A language model client for tests. Text is produced by a caller-supplied generator function
+ * instead of a remote service, and the last prompt received is recorded for inspection.
+ */
+public class MockLLMClient extends ConfigurableLanguageModel {
+
+    /** The only API key {@link #complete} accepts. */
+    public final static String ACCEPTED_API_KEY = "sesame";
+
+    /** Runs the streaming loop of {@link #completeAsync}; may be null if only {@link #complete} is used. */
+    private final ExecutorService executor;
+    private final BiFunction<Prompt, InferenceParameters, String> generator;
+
+    private Prompt lastPrompt;
+
+    /**
+     * @param config      passed on to the superclass
+     * @param secretStore passed on to the superclass
+     * @param generator   produces the full answer text for a prompt and its inference parameters
+     * @param executor    used by {@link #completeAsync}, which dereferences it unconditionally
+     */
+    public MockLLMClient(LlmClientConfig config,
+                         SecretStore secretStore,
+                         BiFunction<Prompt, InferenceParameters, String> generator,
+                         ExecutorService executor) {
+        super(config, secretStore);
+        this.generator = generator;
+        this.executor = executor;
+    }
+
+    /** Throws IllegalArgumentException unless the resolved API key equals {@link #ACCEPTED_API_KEY}. */
+    private void checkApiKey(InferenceParameters options) {
+        var apiKey = getApiKey(options);
+        if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) {
+            throw new IllegalArgumentException("Invalid API key");
+        }
+    }
+
+    private void setPrompt(Prompt prompt) {
+        this.lastPrompt = prompt;
+    }
+
+    /** Returns the prompt given to the most recent completion call, or null if there has been none. */
+    public Prompt getPrompt() {
+        return this.lastPrompt;
+    }
+
+    /**
+     * Validates the API key, records the prompt and returns the generator output as a single completion.
+     *
+     * @throws IllegalArgumentException if the API key is missing or not the accepted one
+     */
+    @Override
+    public List<Completion> complete(Prompt prompt, InferenceParameters params) {
+        setApiKey(params);
+        checkApiKey(params);
+        setPrompt(prompt);
+        return List.of(Completion.from(this.generator.apply(prompt, params)));
+    }
+
+    /**
+     * Records the prompt, generates the full answer on the calling thread, then streams it to the
+     * consumer one space-separated word at a time on the executor.
+     *
+     * The returned future completes with {@code FinishReason.stop} once every word is delivered.
+     * It completes exceptionally if the worker is interrupted or the consumer throws, so a caller
+     * blocked in {@code join()} is always released.
+     *
+     * NOTE(review): unlike {@link #complete}, this path does not call setApiKey/checkApiKey -
+     * confirm whether that is intentional.
+     */
+    @Override
+    public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+                                                                    InferenceParameters params,
+                                                                    Consumer<Completion> consumer) {
+        setPrompt(prompt);
+        var completionFuture = new CompletableFuture<Completion.FinishReason>();
+        var completions = this.generator.apply(prompt, params).split(" ");  // Simple tokenization
+
+        long sleep = 1;
+        executor.submit(() -> {
+            try {
+                for (int i = 0; i < completions.length; ++i) {
+                    // Re-insert the separator removed by split() on every word but the first
+                    String completion = (i > 0 ? " " : "") + completions[i];
+                    consumer.accept(Completion.from(completion, Completion.FinishReason.none));
+                    Thread.sleep(sleep);
+                }
+                completionFuture.complete(Completion.FinishReason.stop);
+            } catch (InterruptedException e) {
+                // Keep the interrupt visible to the executor and release anyone waiting on the future
+                Thread.currentThread().interrupt();
+                completionFuture.completeExceptionally(e);
+            } catch (RuntimeException e) {
+                // Otherwise the failure would only be stored in the ignored Future returned by submit()
+                completionFuture.completeExceptionally(e);
+            }
+        });
+
+        return completionFuture;
+    }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
new file mode 100644
index 00000000000..57339f6ad49
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
@@ -0,0 +1,35 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+public class OpenAITest {
+
+ private static final String apiKey = "<your-api-key>";
+
+ @Test
+ @Disabled
+ public void testOpenAIGeneration() {
+ var config = new LlmClientConfig.Builder().build();
+ var openai = new OpenAI(config, new SecretStoreProvider().get());
+ var options = Map.of(
+ "maxTokens", "10"
+ );
+
+ var prompt = StringPrompt.from("why are ducks better than cats?");
+ var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> {
+ System.out.print(completion.text());
+ }).exceptionally(exception -> {
+ System.out.println("Error: " + exception);
+ return null;
+ });
+ future.join();
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
index 3baa9715c34..1efcf1c736a 100755
--- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
@@ -3,11 +3,14 @@ package ai.vespa.search.llm;
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
-import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.clients.ConfigurableLanguageModelTest;
+import ai.vespa.llm.clients.LlmClientConfig;
+import ai.vespa.llm.clients.MockLLMClient;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.component.ComponentId;
import com.yahoo.component.chain.Chain;
import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.container.jdisc.SecretStoreProvider;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
@@ -17,14 +20,10 @@ import org.junit.jupiter.api.Test;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
-import java.util.List;
import java.util.Map;
-import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.BiFunction;
-import java.util.function.Consumer;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
@@ -37,10 +36,10 @@ public class LLMSearcherTest {
@Test
public void testLLMSelection() {
- var client1 = createLLMClient("mock1");
- var client2 = createLLMClient("mock2");
+ var llm1 = createLLMClient("mock1");
+ var llm2 = createLLMClient("mock2");
var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build();
- var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2));
+ var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2));
var result = runMockSearch(searcher, Map.of("prompt", "what is your id?"));
assertEquals(1, result.getHitCount());
assertEquals("My id is mock2", getCompletion(result));
@@ -48,16 +47,14 @@ public class LLMSearcherTest {
@Test
public void testGeneration() {
- var client = createLLMClient();
- var searcher = createLLMSearcher(client);
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
var params = Map.of("prompt", "why are ducks better than cats");
assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params)));
}
@Test
public void testPrompting() {
- var client = createLLMClient();
- var searcher = createLLMSearcher(client);
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
// Prompt with prefix
assertEquals("Ducks have adorable waddling walks.",
@@ -74,8 +71,7 @@ public class LLMSearcherTest {
@Test
public void testPromptEvent() {
- var client = createLLMClient();
- var searcher = createLLMSearcher(client);
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
var params = Map.of(
"prompt", "why are ducks better than cats",
"traceLevel", "1");
@@ -94,8 +90,7 @@ public class LLMSearcherTest {
@Test
public void testParameters() {
- var client = createLLMClient();
- var searcher = createLLMSearcher(client);
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
var params = Map.of(
"llm.prompt", "why are ducks better than cats",
"llm.temperature", "1.0",
@@ -112,18 +107,16 @@ public class LLMSearcherTest {
"foo.maxTokens", "5"
);
var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build();
- var client = createLLMClient();
- var searcher = createLLMSearcher(config, client);
+ var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient()));
assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params)));
}
@Test
public void testApiKeyFromHeader() {
var properties = Map.of("prompt", "why are ducks better than cats");
- var client = createLLMClient(createApiKeyGenerator("a_valid_key"));
- var searcher = createLLMSearcher(client);
- assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key"));
- assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key"));
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore()));
+ assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm"));
+ assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm"));
}
@Test
@@ -136,8 +129,7 @@ public class LLMSearcherTest {
"llm.stream", "true", // ... but inference parameters says do it anyway
"llm.prompt", "why are ducks better than cats?"
);
- var client = createLLMClient(executor);
- var searcher = createLLMSearcher(config, client);
+ var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor)));
Result result = runMockSearch(searcher, params);
assertEquals(1, result.getHitCount());
@@ -170,10 +162,6 @@ public class LLMSearcherTest {
return runMockSearch(searcher, parameters, null, "");
}
- static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey) {
- return runMockSearch(searcher, parameters, apiKey, "llm");
- }
-
static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) {
Chain<Searcher> chain = new Chain<>(searcher);
Execution execution = new Execution(chain, Execution.Context.createContextStub());
@@ -203,59 +191,43 @@ public class LLMSearcherTest {
}
private static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
- return (prompt, options) -> {
- String answer = "I have no opinion on the matter";
- if (prompt.asString().contains("ducks")) {
- answer = "Ducks have adorable waddling walks.";
- var temperature = options.getDouble("temperature");
- if (temperature.isPresent() && temperature.get() > 0.5) {
- answer = "Random text about ducks vs cats that makes no sense whatsoever.";
- }
- }
- var maxTokens = options.getInt("maxTokens");
- if (maxTokens.isPresent()) {
- return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
- }
- return answer;
- };
+ return ConfigurableLanguageModelTest.createGenerator();
}
- private static BiFunction<Prompt, InferenceParameters, String> createApiKeyGenerator(String validApiKey) {
- return (prompt, options) -> {
- if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) {
- throw new IllegalArgumentException("Invalid API key");
- }
- return "Ok";
- };
- }
-
- static MockLLM createLLMClient() {
- return new MockLLM(createGenerator(), null);
+ static MockLLMClient createLLMClient() {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore, generator, null);
}
- static MockLLM createLLMClient(String id) {
- return new MockLLM(createIdGenerator(id), null);
+ static MockLLMClient createLLMClient(String id) {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createIdGenerator(id);
+ return new MockLLMClient(config, secretStore, generator, null);
}
- static MockLLM createLLMClient(BiFunction<Prompt, InferenceParameters, String> generator) {
- return new MockLLM(generator, null);
+ static MockLLMClient createLLMClient(ExecutorService executor) {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore, generator, executor);
}
- static MockLLM createLLMClient(ExecutorService executor) {
- return new MockLLM(createGenerator(), executor);
- }
-
- private static Searcher createLLMSearcher(LanguageModel llm) {
- return createLLMSearcher(Map.of("mock", llm));
+ static MockLLMClient createLLMClientWithoutSecretStore() {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = new SecretStoreProvider();
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore.get(), generator, null);
}
private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) {
var config = new LlmSearcherConfig.Builder().stream(false).build();
- return createLLMSearcher(config, llms);
- }
-
- private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) {
- return createLLMSearcher(config, Map.of("mock", llm));
+ ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
+ llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
+ models.freeze();
+ return new LLMSearcher(config, models);
}
private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) {
@@ -265,44 +237,4 @@ public class LLMSearcherTest {
return new LLMSearcher(config, models);
}
- private static class MockLLM implements LanguageModel {
-
- private final ExecutorService executor;
- private final BiFunction<Prompt, InferenceParameters, String> generator;
-
- public MockLLM(BiFunction<Prompt, InferenceParameters, String> generator, ExecutorService executor) {
- this.executor = executor;
- this.generator = generator;
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters params) {
- return List.of(Completion.from(this.generator.apply(prompt, params)));
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
- InferenceParameters params,
- Consumer<Completion> consumer) {
- var completionFuture = new CompletableFuture<Completion.FinishReason>();
- var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization
-
- long sleep = 1;
- executor.submit(() -> {
- try {
- for (int i = 0; i < completions.length; ++i) {
- String completion = (i > 0 ? " " : "") + completions[i];
- consumer.accept(Completion.from(completion, Completion.FinishReason.none));
- Thread.sleep(sleep);
- }
- completionFuture.complete(Completion.FinishReason.stop);
- } catch (InterruptedException e) {
- // Do nothing
- }
- });
- return completionFuture;
- }
-
- }
-
}