diff options
author | Harald Musum <musum@vespa.ai> | 2024-04-15 20:44:10 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-04-15 20:44:10 +0200 |
commit | ed62b750494822cc67a328390178754512baf032 (patch) | |
tree | 3f94e22d38d63c456a3bb202b4f3787ecff5ded6 /container-search/src/main | |
parent | 44a866c0d648543c04567503990c03c36403d86d (diff) |
Revert "Lesters/add local llms 2"
Diffstat (limited to 'container-search/src/main')
7 files changed, 164 insertions, 119 deletions
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java new file mode 100644 index 00000000000..761fdf0af93 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java @@ -0,0 +1,75 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.logging.Logger; + + +/** + * Base class for language models that can be configured with config definitions. + * + * @author lesters + */ +@Beta +public abstract class ConfigurableLanguageModel implements LanguageModel { + + private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); + + private final String apiKey; + private final String endpoint; + + public ConfigurableLanguageModel() { + this.apiKey = null; + this.endpoint = null; + } + + @Inject + public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) { + this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore); + this.endpoint = config.endpoint(); + } + + private static String findApiKeyInSecretStore(String property, SecretStore secretStore) { + String apiKey = ""; + if (property != null && ! property.isEmpty()) { + try { + apiKey = secretStore.getSecret(property); + } catch (UnsupportedOperationException e) { + // Secret store is not available - silently ignore this + } catch (Exception e) { + log.warning("Secret store look up failed: " + e.getMessage() + "\n" + + "Will expect API key in request header"); + } + } + return apiKey; + } + + protected String getApiKey(InferenceParameters params) { + return params.getApiKey().orElse(null); + } + + /** + * Set the API key as retrieved from secret store if it is not already set + */ + protected void setApiKey(InferenceParameters params) { + if (params.getApiKey().isEmpty() && apiKey != null) { + params.setApiKey(apiKey); + } + } + + protected String getEndpoint() { + return endpoint; + } + + protected void setEndpoint(InferenceParameters params) { + if (endpoint != null && ! endpoint.isEmpty()) { + params.setEndpoint(endpoint); + } + } + +} diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java new file mode 100644 index 00000000000..82e19d47c92 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -0,0 +1,48 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.client.openai.OpenAiClient; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * A configurable OpenAI client. + * + * @author lesters + */ +@Beta +public class OpenAI extends ConfigurableLanguageModel { + + private final OpenAiClient client; + + @Inject + public OpenAI(LlmClientConfig config, SecretStore secretStore) { + super(config, secretStore); + client = new OpenAiClient(); + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters parameters) { + setApiKey(parameters); + setEndpoint(parameters); + return client.complete(prompt, parameters); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters parameters, + Consumer<Completion> consumer) { + setApiKey(parameters); + setEndpoint(parameters); + return client.completeAsync(prompt, parameters, consumer); + } +} + diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java new file mode 100644 index 00000000000..c360245901c --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java @@ -0,0 +1,7 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.llm.clients; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java index f565315b775..860fc69af91 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java @@ -20,7 +20,6 @@ import com.yahoo.search.result.HitGroup; import com.yahoo.search.searchchain.Execution; import java.util.List; -import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -84,41 +83,27 @@ public class LLMSearcher extends Searcher { protected Result complete(Query query, Prompt prompt) { var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config - try { - return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); - } catch (RejectedExecutionException e) { - return new Result(query, new ErrorMessage(429, e.getMessage())); - } - } - - private boolean shouldAddPrompt(Query query) { - return query.getTrace().getLevel() >= 1; - } - - private boolean shouldAddTokenStats(Query query) { - return query.getTrace().getLevel() >= 1; + return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); } private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { - final EventStream eventStream = new EventStream(); + EventStream eventStream = new EventStream(); - if (shouldAddPrompt(query)) { + if (query.getTrace().getLevel() >= 1) { eventStream.add(prompt.asString(), "prompt"); } - final TokenStats tokenStats = new TokenStats(); - languageModel.completeAsync(prompt, options, completion -> { - tokenStats.onToken(); - handleCompletion(eventStream, completion); + languageModel.completeAsync(prompt, options, token -> { + eventStream.add(token.text()); }).exceptionally(exception -> { - handleException(eventStream, exception); + int errorCode = 400; + if (exception instanceof LanguageModelException languageModelException) { + errorCode = languageModelException.code(); + } + eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); eventStream.markComplete(); return Completion.FinishReason.error; }).thenAccept(finishReason -> { - tokenStats.onCompletion(); - if (shouldAddTokenStats(query)) { - eventStream.add(tokenStats.report(), "stats"); - } eventStream.markComplete(); }); @@ -127,26 +112,10 @@ public class LLMSearcher extends Searcher { return new Result(query, hitGroup); } - private void handleCompletion(EventStream eventStream, Completion completion) { - if (completion.finishReason() == Completion.FinishReason.error) { - eventStream.add(completion.text(), "error"); - } else { - eventStream.add(completion.text()); - } - } - - private void handleException(EventStream eventStream, Throwable exception) { - int errorCode = 400; - if (exception instanceof LanguageModelException languageModelException) { - errorCode = languageModelException.code(); - } - eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); - } - private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { EventStream eventStream = new EventStream(); - if (shouldAddPrompt(query)) { + if (query.getTrace().getLevel() >= 1) { eventStream.add(prompt.asString(), "prompt"); } @@ -200,35 +169,4 @@ public class LLMSearcher extends Searcher { return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); } - private static class TokenStats { - - private long start; - private long timeToFirstToken; - private long timeToLastToken; - private long tokens = 0; - - TokenStats() { - start = System.currentTimeMillis(); - } - - void onToken() { - if (tokens == 0) { - timeToFirstToken = System.currentTimeMillis() - start; - } - tokens++; - } - - void onCompletion() { - timeToLastToken = System.currentTimeMillis() - start; - } - - String report() { - return "Time to first token: " + timeToFirstToken + " ms, " + - "Generation time: " + timeToLastToken + " ms, " + - "Generated tokens: " + tokens + " " + - String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0)); - } - - } - } diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java index 88a1e6c1485..83ae349f5a0 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -64,17 +64,7 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { @Override public void data(Data data) throws IOException { - if (data instanceof EventStream.ErrorEvent error) { - generator.writeRaw("event: error\n"); - generator.writeRaw("data: "); - generator.writeStartObject(); - generator.writeStringField("source", error.source()); - generator.writeNumberField("error", error.code()); - generator.writeStringField("message", error.message()); - generator.writeEndObject(); - generator.writeRaw("\n\n"); - generator.flush(); - } else if (data instanceof EventStream.Event event) { + if (data instanceof EventStream.Event event) { if (RENDER_EVENT_HEADER) { generator.writeRaw("event: " + event.type() + "\n"); } @@ -85,6 +75,19 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { generator.writeRaw("\n\n"); generator.flush(); } + else if (data instanceof ErrorHit) { + for (ErrorMessage error : ((ErrorHit) data).errors()) { + generator.writeRaw("event: error\n"); + generator.writeRaw("data: "); + generator.writeStartObject(); + generator.writeStringField("source", error.getSource()); + generator.writeNumberField("error", error.getCode()); + generator.writeStringField("message", error.getMessage()); + generator.writeEndObject(); + generator.writeRaw("\n\n"); + generator.flush(); + } + } // Todo: support other types of data such as search results (hits), timing and trace } diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java index 8e6f7977d55..b393a91e6d0 100644 --- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java +++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java @@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList<Data> { } public void error(String source, ErrorMessage message) { - incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message)); + incoming().add(new DefaultErrorHit(source, message)); } public void markComplete() { @@ -117,38 +117,4 @@ public class EventStream extends Hit implements DataList<Data> { } - public static class ErrorEvent extends Event { - - private final String source; - private final ErrorMessage message; - - public ErrorEvent(int eventNumber, String source, ErrorMessage message) { - super(eventNumber, message.getMessage(), "error"); - this.source = source; - this.message = message; - } - - public String source() { - return source; - } - - public int code() { - return message.getCode(); - } - - public String message() { - return message.getMessage(); - } - - @Override - public Hit asHit() { - Hit hit = super.asHit(); - hit.setField("source", source); - hit.setField("code", message.getCode()); - return hit; - } - - - } - } diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def new file mode 100755 index 00000000000..0866459166a --- /dev/null +++ b/container-search/src/main/resources/configdefinitions/llm-client.def @@ -0,0 +1,8 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package=ai.vespa.llm.clients + +# The name of the secret containing the api key +apiKeySecretName string default="" + +# Endpoint for LLM client - if not set reverts to default for client +endpoint string default="" |