diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2024-04-15 14:51:27 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-15 14:51:27 +0200 |
commit | 7518d93961ac7c5c5da1cd41717d42f600dae647 (patch) | |
tree | 63e2811a56e6bf6b2bed5e65e15c98458cfb357f /container-search | |
parent | f7fd3dd205912c0100786e86d78b6de93d667bfa (diff) |
Revert "Lesters/add local llms"
Diffstat (limited to 'container-search')
12 files changed, 587 insertions, 243 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 07f0449e61a..e74fe22c588 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -7842,21 +7842,6 @@ "public static final int emptyDocsumsCode" ] }, - "com.yahoo.search.result.EventStream$ErrorEvent" : { - "superClass" : "com.yahoo.search.result.EventStream$Event", - "interfaces" : [ ], - "attributes" : [ - "public" - ], - "methods" : [ - "public void <init>(int, java.lang.String, com.yahoo.search.result.ErrorMessage)", - "public java.lang.String source()", - "public int code()", - "public java.lang.String message()", - "public com.yahoo.search.result.Hit asHit()" - ], - "fields" : [ ] - }, "com.yahoo.search.result.EventStream$Event" : { "superClass" : "com.yahoo.component.provider.ListenableFreezableClass", "interfaces" : [ @@ -9164,6 +9149,99 @@ ], "fields" : [ ] }, + "ai.vespa.llm.clients.ConfigurableLanguageModel" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "ai.vespa.llm.LanguageModel" + ], + "attributes" : [ + "public", + "abstract" + ], + "methods" : [ + "public void <init>()", + "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", + "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)", + "protected void setApiKey(ai.vespa.llm.InferenceParameters)", + "protected java.lang.String getEndpoint()", + "protected void setEndpoint(ai.vespa.llm.InferenceParameters)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig$Builder" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Builder" + ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void <init>()", + "public void <init>(ai.vespa.llm.clients.LlmClientConfig)", + "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)", + "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)", + "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", + "public final java.lang.String getDefMd5()", + "public final java.lang.String getDefName()", + "public final java.lang.String getDefNamespace()", + "public final boolean getApplyOnRestart()", + "public final void setApplyOnRestart(boolean)", + "public ai.vespa.llm.clients.LlmClientConfig build()" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig$Producer" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Producer" + ], + "attributes" : [ + "public", + "interface", + "abstract" + ], + "methods" : [ + "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig" : { + "superClass" : "com.yahoo.config.ConfigInstance", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public static java.lang.String getDefMd5()", + "public static java.lang.String getDefName()", + "public static java.lang.String getDefNamespace()", + "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)", + "public java.lang.String apiKeySecretName()", + "public java.lang.String endpoint()" + ], + "fields" : [ + "public static final java.lang.String CONFIG_DEF_MD5", + "public static final java.lang.String CONFIG_DEF_NAME", + "public static final java.lang.String CONFIG_DEF_NAMESPACE", + "public static final java.lang.String[] CONFIG_DEF_SCHEMA" + ] + }, + "ai.vespa.llm.clients.OpenAI" : { + "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", + "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + ], + "fields" : [ ] + }, "ai.vespa.search.llm.LLMSearcher" : { "superClass" : "com.yahoo.search.Searcher", "interfaces" : [ ], diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java new file mode 100644 index 00000000000..761fdf0af93 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java @@ -0,0 +1,75 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.logging.Logger; + + +/** + * Base class for language models that can be configured with config definitions. + * + * @author lesters + */ +@Beta +public abstract class ConfigurableLanguageModel implements LanguageModel { + + private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); + + private final String apiKey; + private final String endpoint; + + public ConfigurableLanguageModel() { + this.apiKey = null; + this.endpoint = null; + } + + @Inject + public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) { + this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore); + this.endpoint = config.endpoint(); + } + + private static String findApiKeyInSecretStore(String property, SecretStore secretStore) { + String apiKey = ""; + if (property != null && ! property.isEmpty()) { + try { + apiKey = secretStore.getSecret(property); + } catch (UnsupportedOperationException e) { + // Secret store is not available - silently ignore this + } catch (Exception e) { + log.warning("Secret store look up failed: " + e.getMessage() + "\n" + + "Will expect API key in request header"); + } + } + return apiKey; + } + + protected String getApiKey(InferenceParameters params) { + return params.getApiKey().orElse(null); + } + + /** + * Set the API key as retrieved from secret store if it is not already set + */ + protected void setApiKey(InferenceParameters params) { + if (params.getApiKey().isEmpty() && apiKey != null) { + params.setApiKey(apiKey); + } + } + + protected String getEndpoint() { + return endpoint; + } + + protected void setEndpoint(InferenceParameters params) { + if (endpoint != null && ! endpoint.isEmpty()) { + params.setEndpoint(endpoint); + } + } + +} diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java new file mode 100644 index 00000000000..82e19d47c92 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -0,0 +1,48 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.client.openai.OpenAiClient; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * A configurable OpenAI client. + * + * @author lesters + */ +@Beta +public class OpenAI extends ConfigurableLanguageModel { + + private final OpenAiClient client; + + @Inject + public OpenAI(LlmClientConfig config, SecretStore secretStore) { + super(config, secretStore); + client = new OpenAiClient(); + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters parameters) { + setApiKey(parameters); + setEndpoint(parameters); + return client.complete(prompt, parameters); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters parameters, + Consumer<Completion> consumer) { + setApiKey(parameters); + setEndpoint(parameters); + return client.completeAsync(prompt, parameters, consumer); + } +} + diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java new file mode 100644 index 00000000000..c360245901c --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java @@ -0,0 +1,7 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.llm.clients; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java index f565315b775..860fc69af91 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java @@ -20,7 +20,6 @@ import com.yahoo.search.result.HitGroup; import com.yahoo.search.searchchain.Execution; import java.util.List; -import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -84,41 +83,27 @@ public class LLMSearcher extends Searcher { protected Result complete(Query query, Prompt prompt) { var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config - try { - return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); - } catch (RejectedExecutionException e) { - return new Result(query, new ErrorMessage(429, e.getMessage())); - } - } - - private boolean shouldAddPrompt(Query query) { - return query.getTrace().getLevel() >= 1; - } - - private boolean shouldAddTokenStats(Query query) { - return query.getTrace().getLevel() >= 1; + return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); } private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { - final EventStream eventStream = new EventStream(); + EventStream eventStream = new EventStream(); - if (shouldAddPrompt(query)) { + if (query.getTrace().getLevel() >= 1) { eventStream.add(prompt.asString(), "prompt"); } - final TokenStats tokenStats = new TokenStats(); - languageModel.completeAsync(prompt, options, completion -> { - tokenStats.onToken(); - handleCompletion(eventStream, completion); + languageModel.completeAsync(prompt, options, token -> { + eventStream.add(token.text()); }).exceptionally(exception -> { - handleException(eventStream, exception); + int errorCode = 400; + if (exception instanceof LanguageModelException languageModelException) { + errorCode = languageModelException.code(); + } + eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); eventStream.markComplete(); return Completion.FinishReason.error; }).thenAccept(finishReason -> { - tokenStats.onCompletion(); - if (shouldAddTokenStats(query)) { - eventStream.add(tokenStats.report(), "stats"); - } eventStream.markComplete(); }); @@ -127,26 +112,10 @@ public class LLMSearcher extends Searcher { return new Result(query, hitGroup); } - private void handleCompletion(EventStream eventStream, Completion completion) { - if (completion.finishReason() == Completion.FinishReason.error) { - eventStream.add(completion.text(), "error"); - } else { - eventStream.add(completion.text()); - } - } - - private void handleException(EventStream eventStream, Throwable exception) { - int errorCode = 400; - if (exception instanceof LanguageModelException languageModelException) { - errorCode = languageModelException.code(); - } - eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); - } - private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { EventStream eventStream = new EventStream(); - if (shouldAddPrompt(query)) { + if (query.getTrace().getLevel() >= 1) { eventStream.add(prompt.asString(), "prompt"); } @@ -200,35 +169,4 @@ public class LLMSearcher extends Searcher { return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); } - private static class TokenStats { - - private long start; - private long timeToFirstToken; - private long timeToLastToken; - private long tokens = 0; - - TokenStats() { - start = System.currentTimeMillis(); - } - - void onToken() { - if (tokens == 0) { - timeToFirstToken = System.currentTimeMillis() - start; - } - tokens++; - } - - void onCompletion() { - timeToLastToken = System.currentTimeMillis() - start; - } - - String report() { - return "Time to first token: " + timeToFirstToken + " ms, " + - "Generation time: " + timeToLastToken + " ms, " + - "Generated tokens: " + tokens + " " + - String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0)); - } - - } - } diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java index 88a1e6c1485..83ae349f5a0 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -64,17 +64,7 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { @Override public void data(Data data) throws IOException { - if (data instanceof EventStream.ErrorEvent error) { - generator.writeRaw("event: error\n"); - generator.writeRaw("data: "); - generator.writeStartObject(); - generator.writeStringField("source", error.source()); - generator.writeNumberField("error", error.code()); - generator.writeStringField("message", error.message()); - generator.writeEndObject(); - generator.writeRaw("\n\n"); - generator.flush(); - } else if (data instanceof EventStream.Event event) { + if (data instanceof EventStream.Event event) { if (RENDER_EVENT_HEADER) { generator.writeRaw("event: " + event.type() + "\n"); } @@ -85,6 +75,19 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { generator.writeRaw("\n\n"); generator.flush(); } + else if (data instanceof ErrorHit) { + for (ErrorMessage error : ((ErrorHit) data).errors()) { + generator.writeRaw("event: error\n"); + generator.writeRaw("data: "); + generator.writeStartObject(); + generator.writeStringField("source", error.getSource()); + generator.writeNumberField("error", error.getCode()); + generator.writeStringField("message", error.getMessage()); + generator.writeEndObject(); + generator.writeRaw("\n\n"); + generator.flush(); + } + } // Todo: support other types of data such as search results (hits), timing and trace } diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java index 8e6f7977d55..b393a91e6d0 100644 --- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java +++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java @@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList<Data> { } public void error(String source, ErrorMessage message) { - incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message)); + incoming().add(new DefaultErrorHit(source, message)); } public void markComplete() { @@ -117,38 +117,4 @@ public class EventStream extends Hit implements DataList<Data> { } - public static class ErrorEvent extends Event { - - private final String source; - private final ErrorMessage message; - - public ErrorEvent(int eventNumber, String source, ErrorMessage message) { - super(eventNumber, message.getMessage(), "error"); - this.source = source; - this.message = message; - } - - public String source() { - return source; - } - - public int code() { - return message.getCode(); - } - - public String message() { - return message.getMessage(); - } - - @Override - public Hit asHit() { - Hit hit = super.asHit(); - hit.setField("source", source); - hit.setField("code", message.getCode()); - return hit; - } - - - } - } diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def new file mode 100755 index 00000000000..0866459166a --- /dev/null +++ b/container-search/src/main/resources/configdefinitions/llm-client.def @@ -0,0 +1,8 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package=ai.vespa.llm.clients + +# The name of the secret containing the api key +apiKeySecretName string default="" + +# Endpoint for LLM client - if not set reverts to default for client +endpoint string default="" diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java new file mode 100644 index 00000000000..35d5cfd3855 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java @@ -0,0 +1,174 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.container.di.componentgraph.Provider; +import com.yahoo.container.jdisc.SecretStoreProvider; +import com.yahoo.container.jdisc.secretstore.SecretStore; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConfigurableLanguageModelTest { + + @Test + public void testSyncGeneration() { + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey()); + assertEquals(1, result.size()); + assertEquals("Ducks have adorable waddling walks.", result.get(0).text()); + } + + @Test + public void testAsyncGeneration() { + var executor = Executors.newFixedThreadPool(1); + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var sb = new StringBuilder(); + try { + var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> { + sb.append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); + + var reason = future.join(); + assertTrue(future.isDone()); + assertNotEquals(reason, Completion.FinishReason.error); + } finally { + executor.shutdownNow(); + } + + assertEquals("Ducks have adorable waddling walks.", sb.toString()); + } + + @Test + public void testInferenceParameters() { + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4")); + var result = createLLM().complete(prompt, params); + assertEquals("Random text about ducks", result.get(0).text()); + } + + @Test + public void testNoApiKey() { + var prompt = StringPrompt.from(""); + var config = modelParams("api-key", null); + var secrets = createSecretStore(Map.of()); + assertThrows(IllegalArgumentException.class, () -> { + createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); + }); + } + + @Test + public void testApiKeyFromSecretStore() { + var prompt = StringPrompt.from(""); + var config = modelParams("api-key-in-secret-store", null); + var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY)); + assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); }); + } + + private static String lookupParameter(String parameter, Map<String, String> params) { + return params.get(parameter); + } + + private static InferenceParameters inferenceParams() { + return new InferenceParameters(s -> lookupParameter(s, Map.of())); + } + + private static InferenceParameters inferenceParams(Map<String, String> params) { + return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params)); + } + + private static InferenceParameters inferenceParamsWithDefaultKey() { + return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Map.of())); + } + + private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) { + var config = new LlmClientConfig.Builder(); + if (apiKeySecretName != null) { + config.apiKeySecretName(apiKeySecretName); + } + if (endpoint != null) { + config.endpoint(endpoint); + } + return config.build(); + } + + public static SecretStore createSecretStore(Map<String, String> secrets) { + Provider<SecretStore> secretStore = new Provider<>() { + public SecretStore get() { + return new SecretStore() { + public String getSecret(String key) { + return secrets.get(key); + } + public String getSecret(String key, int version) { + return secrets.get(key); + } + }; + } + public void deconstruct() { + } + }; + return secretStore.get(); + } + + public static BiFunction<Prompt, InferenceParameters, String> createGenerator() { + return (prompt, options) -> { + String answer = "I have no opinion on the matter"; + if (prompt.asString().contains("ducks")) { + answer = "Ducks have adorable waddling walks."; + var temperature = options.getDouble("temperature"); + if (temperature.isPresent() && temperature.get() > 0.5) { + answer = "Random text about ducks vs cats that makes no sense whatsoever."; + } + } + var maxTokens = options.getInt("maxTokens"); + if (maxTokens.isPresent()) { + return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); + } + return answer; + }; + } + + private static MockLLMClient createLLM() { + LlmClientConfig config = new LlmClientConfig.Builder().build(); + return createLLM(config, null); + } + + private static MockLLMClient createLLM(ExecutorService executor) { + LlmClientConfig config = new LlmClientConfig.Builder().build(); + return createLLM(config, executor); + } + + private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) { + var generator = createGenerator(); + var secretStore = new SecretStoreProvider(); // throws exception on use + return createLLM(config, generator, secretStore.get(), executor); + } + + private static MockLLMClient createLLM(LlmClientConfig config, + BiFunction<Prompt, InferenceParameters, String> generator, + SecretStore secretStore) { + return createLLM(config, generator, secretStore, null); + } + + private static MockLLMClient createLLM(LlmClientConfig config, + BiFunction<Prompt, InferenceParameters, String> generator, + SecretStore secretStore, + ExecutorService executor) { + return new MockLLMClient(config, secretStore, generator, executor); + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java new file mode 100644 index 00000000000..4d0073f1cbe --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java @@ -0,0 +1,80 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public class MockLLMClient extends ConfigurableLanguageModel { + + public final static String ACCEPTED_API_KEY = "sesame"; + + private final ExecutorService executor; + private final BiFunction<Prompt, InferenceParameters, String> generator; + + private Prompt lastPrompt; + + public MockLLMClient(LlmClientConfig config, + SecretStore secretStore, + BiFunction<Prompt, InferenceParameters, String> generator, + ExecutorService executor) { + super(config, secretStore); + this.generator = generator; + this.executor = executor; + } + + private void checkApiKey(InferenceParameters options) { + var apiKey = getApiKey(options); + if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) { + throw new IllegalArgumentException("Invalid API key"); + } + } + + private void setPrompt(Prompt prompt) { + this.lastPrompt = prompt; + } + + public Prompt getPrompt() { + return this.lastPrompt; + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters params) { + setApiKey(params); + checkApiKey(params); + setPrompt(prompt); + return List.of(Completion.from(this.generator.apply(prompt, params))); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters params, + Consumer<Completion> consumer) { + setPrompt(prompt); + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization + + long sleep = 1; + executor.submit(() -> { + try { + for (int i=0; i < completions.length; ++i) { + String completion = (i > 0 ? " " : "") + completions[i]; + consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep); + } + completionFuture.complete(Completion.FinishReason.stop); + } catch (InterruptedException e) { + // Do nothing + } + }); + + return completionFuture; + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java new file mode 100644 index 00000000000..57339f6ad49 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java @@ -0,0 +1,35 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.container.jdisc.SecretStoreProvider; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +public class OpenAITest { + + private static final String apiKey = "<your-api-key>"; + + @Test + @Disabled + public void testOpenAIGeneration() { + var config = new LlmClientConfig.Builder().build(); + var openai = new OpenAI(config, new SecretStoreProvider().get()); + var options = Map.of( + "maxTokens", "10" + ); + + var prompt = StringPrompt.from("why are ducks better than cats?"); + var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> { + System.out.print(completion.text()); + }).exceptionally(exception -> { + System.out.println("Error: " + exception); + return null; + }); + future.join(); + } + +} diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java index 3baa9715c34..1efcf1c736a 100755 --- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java +++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java @@ -3,11 +3,14 @@ package ai.vespa.search.llm; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.clients.ConfigurableLanguageModelTest; +import ai.vespa.llm.clients.LlmClientConfig; +import ai.vespa.llm.clients.MockLLMClient; import ai.vespa.llm.completion.Prompt; import com.yahoo.component.ComponentId; import com.yahoo.component.chain.Chain; import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.container.jdisc.SecretStoreProvider; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; @@ -17,14 +20,10 @@ import org.junit.jupiter.api.Test; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.BiFunction; -import java.util.function.Consumer; import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -37,10 +36,10 @@ public class LLMSearcherTest { @Test public void testLLMSelection() { - var client1 = createLLMClient("mock1"); - var client2 = createLLMClient("mock2"); + var llm1 = createLLMClient("mock1"); + var llm2 = createLLMClient("mock2"); var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build(); - var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2)); + var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2)); var result = runMockSearch(searcher, Map.of("prompt", "what is your id?")); assertEquals(1, result.getHitCount()); assertEquals("My id is mock2", getCompletion(result)); @@ -48,16 +47,14 @@ public class LLMSearcherTest { @Test public void testGeneration() { - var client = createLLMClient(); - var searcher = createLLMSearcher(client); + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); var params = Map.of("prompt", "why are ducks better than cats"); assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params))); } @Test public void testPrompting() { - var client = createLLMClient(); - var searcher = createLLMSearcher(client); + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); // Prompt with prefix assertEquals("Ducks have adorable waddling walks.", @@ -74,8 +71,7 @@ public class LLMSearcherTest { @Test public void testPromptEvent() { - var client = createLLMClient(); - var searcher = createLLMSearcher(client); + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); var params = Map.of( "prompt", "why are ducks better than cats", "traceLevel", "1"); @@ -94,8 +90,7 @@ public class LLMSearcherTest { @Test public void testParameters() { - var client = createLLMClient(); - var searcher = createLLMSearcher(client); + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); var params = Map.of( "llm.prompt", "why are ducks better than cats", "llm.temperature", "1.0", @@ -112,18 +107,16 @@ public class LLMSearcherTest { "foo.maxTokens", "5" ); var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build(); - var client = createLLMClient(); - var searcher = createLLMSearcher(config, client); + var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient())); assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params))); } @Test public void testApiKeyFromHeader() { var properties = Map.of("prompt", "why are ducks better than cats"); - var client = createLLMClient(createApiKeyGenerator("a_valid_key")); - var searcher = createLLMSearcher(client); - assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key")); - assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key")); + var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore())); + assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm")); + assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm")); } @Test @@ -136,8 +129,7 @@ public class LLMSearcherTest { "llm.stream", "true", // ... but inference parameters says do it anyway "llm.prompt", "why are ducks better than cats?" ); - var client = createLLMClient(executor); - var searcher = createLLMSearcher(config, client); + var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor))); Result result = runMockSearch(searcher, params); assertEquals(1, result.getHitCount()); @@ -170,10 +162,6 @@ public class LLMSearcherTest { return runMockSearch(searcher, parameters, null, ""); } - static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey) { - return runMockSearch(searcher, parameters, apiKey, "llm"); - } - static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) { Chain<Searcher> chain = new Chain<>(searcher); Execution execution = new Execution(chain, Execution.Context.createContextStub()); @@ -203,59 +191,43 @@ public class LLMSearcherTest { } private static BiFunction<Prompt, InferenceParameters, String> createGenerator() { - return (prompt, options) -> { - String answer = "I have no opinion on the matter"; - if (prompt.asString().contains("ducks")) { - answer = "Ducks have adorable waddling walks."; - var temperature = options.getDouble("temperature"); - if (temperature.isPresent() && temperature.get() > 0.5) { - answer = "Random text about ducks vs cats that makes no sense whatsoever."; - } - } - var maxTokens = options.getInt("maxTokens"); - if (maxTokens.isPresent()) { - return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); - } - return answer; - }; + return ConfigurableLanguageModelTest.createGenerator(); } - private static BiFunction<Prompt, InferenceParameters, String> createApiKeyGenerator(String validApiKey) { - return (prompt, options) -> { - if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) { - throw new IllegalArgumentException("Invalid API key"); - } - return "Ok"; - }; - } - - static MockLLM createLLMClient() { - return new MockLLM(createGenerator(), null); + static MockLLMClient createLLMClient() { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore, generator, null); } - static MockLLM createLLMClient(String id) { - return new MockLLM(createIdGenerator(id), null); + static MockLLMClient createLLMClient(String id) { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createIdGenerator(id); + return new MockLLMClient(config, secretStore, generator, null); } - static MockLLM createLLMClient(BiFunction<Prompt, InferenceParameters, String> generator) { - return new MockLLM(generator, null); + static MockLLMClient createLLMClient(ExecutorService executor) { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore, generator, executor); } - static MockLLM createLLMClient(ExecutorService executor) { - return new MockLLM(createGenerator(), executor); - } - - private static Searcher createLLMSearcher(LanguageModel llm) { - return createLLMSearcher(Map.of("mock", llm)); + static MockLLMClient createLLMClientWithoutSecretStore() { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = new SecretStoreProvider(); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore.get(), generator, null); } private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) { var config = new LlmSearcherConfig.Builder().stream(false).build(); - return createLLMSearcher(config, llms); - } - - private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) { - return createLLMSearcher(config, Map.of("mock", llm)); + ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); + llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new LLMSearcher(config, models); } private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) { @@ -265,44 +237,4 @@ public class LLMSearcherTest { return new LLMSearcher(config, models); } - private static class MockLLM implements LanguageModel { - - private final ExecutorService executor; - private final BiFunction<Prompt, InferenceParameters, String> generator; - - public MockLLM(BiFunction<Prompt, InferenceParameters, String> generator, ExecutorService executor) { - this.executor = executor; - this.generator = generator; - } - - @Override - public List<Completion> complete(Prompt prompt, InferenceParameters params) { - return List.of(Completion.from(this.generator.apply(prompt, params))); - } - - @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, - InferenceParameters params, - Consumer<Completion> consumer) { - var completionFuture = new CompletableFuture<Completion.FinishReason>(); - var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization - - long sleep = 1; - executor.submit(() -> { - try { - for (int i = 0; i < completions.length; ++i) { - String completion = (i > 0 ? " " : "") + completions[i]; - consumer.accept(Completion.from(completion, Completion.FinishReason.none)); - Thread.sleep(sleep); - } - completionFuture.complete(Completion.FinishReason.stop); - } catch (InterruptedException e) { - // Do nothing - } - }); - return completionFuture; - } - - } - } |