 model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java | 190 ++++++++++++++
 1 file changed, 190 insertions(+)
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
new file mode 100644
index 00000000000..35d5cfd3855
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
@@ -0,0 +1,190 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.container.di.componentgraph.Provider;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
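+/**
+ * Tests of the configurable language model client. {@link MockLLMClient} stands in for a real
+ * LLM connection, so completions come from a local deterministic generator instead of an external service.
+ */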
+public class ConfigurableLanguageModelTest {
+
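+    // Completes a prompt synchronously and verifies the single returned completion.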
+    @Test
+    public void testSyncGeneration() {
+        var prompt = StringPrompt.from("Why are ducks better than cats?");
+        var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey());
+        assertEquals(1, result.size());
+        assertEquals("Ducks have adorable waddling walks.", result.get(0).text());
+    }
+
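+    // Streams a completion asynchronously on a dedicated executor, collecting text chunks as they arrive.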
+    @Test
+    public void testAsyncGeneration() {
+        var executor = Executors.newFixedThreadPool(1);
+        var prompt = StringPrompt.from("Why are ducks better than cats?");
+        var sb = new StringBuilder();
+        try {
+            var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> {
+                sb.append(completion.text());
+            }).exceptionally(exception -> Completion.FinishReason.error);
+
+            var reason = future.join();
+            assertTrue(future.isDone());
+            assertNotEquals(Completion.FinishReason.error, reason);
+        } finally {
+            executor.shutdownNow();
+        }
+
+        assertEquals("Ducks have adorable waddling walks.", sb.toString());
+    }
+
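+    // A temperature above 0.5 makes the mock generator return its "random" answer,
+    // and maxTokens=4 truncates that answer to its first four words.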
+    @Test
+    public void testInferenceParameters() {
+        var prompt = StringPrompt.from("Why are ducks better than cats?");
+        var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4"));
+        var result = createLLM().complete(prompt, params);
+        assertEquals("Random text about ducks", result.get(0).text());
+    }
+
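+    // With no API key in either the inference parameters or the secret store,
+    // completion should fail with IllegalArgumentException.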
+    @Test
+    public void testNoApiKey() {
+        var prompt = StringPrompt.from("");
+        var config = modelParams("api-key", null);
+        var secrets = createSecretStore(Map.of());
+        assertThrows(IllegalArgumentException.class, () -> {
+            createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams());
+        });
+    }
+
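+    // The API key can also be resolved from the secret store via the configured secret name.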
+    @Test
+    public void testApiKeyFromSecretStore() {
+        var prompt = StringPrompt.from("");
+        var config = modelParams("api-key-in-secret-store", null);
+        var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY));
+        assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); });
+    }
+
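+    // Helpers for building InferenceParameters backed by a plain map.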
+    private static String lookupParameter(String parameter, Map<String, String> params) {
+        return params.get(parameter);
+    }
+
+    private static InferenceParameters inferenceParams() {
+        return new InferenceParameters(s -> lookupParameter(s, Map.of()));
+    }
+
+    private static InferenceParameters inferenceParams(Map<String, String> params) {
+        return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params));
+    }
+
+    private static InferenceParameters inferenceParamsWithDefaultKey() {
+        return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Map.of()));
+    }
+
+    private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) {
+        var config = new LlmClientConfig.Builder();
+        if (apiKeySecretName != null) {
+            config.apiKeySecretName(apiKeySecretName);
+        }
+        if (endpoint != null) {
+            config.endpoint(endpoint);
+        }
+        return config.build();
+    }
+
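+    // Minimal in-memory SecretStore; the versioned lookup ignores the version argument.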
+    public static SecretStore createSecretStore(Map<String, String> secrets) {
+        Provider<SecretStore> secretStore = new Provider<>() {
+            public SecretStore get() {
+                return new SecretStore() {
+                    public String getSecret(String key) {
+                        return secrets.get(key);
+                    }
+                    public String getSecret(String key, int version) {
+                        return secrets.get(key);
+                    }
+                };
+            }
+            public void deconstruct() {
+            }
+        };
+        return secretStore.get();
+    }
+
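+    // Deterministic stand-in for an LLM: the answer depends on the prompt text, temperature
+    // and maxTokens, so tests can assert exact output.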
+    public static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
+        return (prompt, options) -> {
+            String answer = "I have no opinion on the matter";
+            if (prompt.asString().contains("ducks")) {
+                answer = "Ducks have adorable waddling walks.";
+                var temperature = options.getDouble("temperature");
+                if (temperature.isPresent() && temperature.get() > 0.5) {
+                    answer = "Random text about ducks vs cats that makes no sense whatsoever.";
+                }
+            }
+            var maxTokens = options.getInt("maxTokens");
+            if (maxTokens.isPresent()) {
+                return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
+            }
+            return answer;
+        };
+    }
+
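+    // Factory overloads for MockLLMClient, with increasingly explicit dependencies.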
+    private static MockLLMClient createLLM() {
+        LlmClientConfig config = new LlmClientConfig.Builder().build();
+        return createLLM(config, null);
+    }
+
+    private static MockLLMClient createLLM(ExecutorService executor) {
+        LlmClientConfig config = new LlmClientConfig.Builder().build();
+        return createLLM(config, executor);
+    }
+
+    private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) {
+        var generator = createGenerator();
+        var secretStore = new SecretStoreProvider();  // throws exception on use
+        return createLLM(config, generator, secretStore.get(), executor);
+    }
+
+    private static MockLLMClient createLLM(LlmClientConfig config,
+                                           BiFunction<Prompt, InferenceParameters, String> generator,
+                                           SecretStore secretStore) {
+        return createLLM(config, generator, secretStore, null);
+    }
+
+    private static MockLLMClient createLLM(LlmClientConfig config,
+                                           BiFunction<Prompt, InferenceParameters, String> generator,
+                                           SecretStore secretStore,
+                                           ExecutorService executor) {
+        return new MockLLMClient(config, secretStore, generator, executor);
+    }
+
+}