diff options
Diffstat (limited to 'container-search/src/main/java')
6 files changed, 119 insertions, 156 deletions
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java deleted file mode 100644 index 761fdf0af93..00000000000 --- a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.LanguageModel; -import com.yahoo.api.annotations.Beta; -import com.yahoo.component.annotation.Inject; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.logging.Logger; - - -/** - * Base class for language models that can be configured with config definitions. - * - * @author lesters - */ -@Beta -public abstract class ConfigurableLanguageModel implements LanguageModel { - - private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); - - private final String apiKey; - private final String endpoint; - - public ConfigurableLanguageModel() { - this.apiKey = null; - this.endpoint = null; - } - - @Inject - public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) { - this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore); - this.endpoint = config.endpoint(); - } - - private static String findApiKeyInSecretStore(String property, SecretStore secretStore) { - String apiKey = ""; - if (property != null && ! property.isEmpty()) { - try { - apiKey = secretStore.getSecret(property); - } catch (UnsupportedOperationException e) { - // Secret store is not available - silently ignore this - } catch (Exception e) { - log.warning("Secret store look up failed: " + e.getMessage() + "\n" + - "Will expect API key in request header"); - } - } - return apiKey; - } - - protected String getApiKey(InferenceParameters params) { - return params.getApiKey().orElse(null); - } - - /** - * Set the API key as retrieved from secret store if it is not already set - */ - protected void setApiKey(InferenceParameters params) { - if (params.getApiKey().isEmpty() && apiKey != null) { - params.setApiKey(apiKey); - } - } - - protected String getEndpoint() { - return endpoint; - } - - protected void setEndpoint(InferenceParameters params) { - if (endpoint != null && ! endpoint.isEmpty()) { - params.setEndpoint(endpoint); - } - } - -} diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java deleted file mode 100644 index 82e19d47c92..00000000000 --- a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.client.openai.OpenAiClient; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.api.annotations.Beta; -import com.yahoo.component.annotation.Inject; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; - -/** - * A configurable OpenAI client. - * - * @author lesters - */ -@Beta -public class OpenAI extends ConfigurableLanguageModel { - - private final OpenAiClient client; - - @Inject - public OpenAI(LlmClientConfig config, SecretStore secretStore) { - super(config, secretStore); - client = new OpenAiClient(); - } - - @Override - public List<Completion> complete(Prompt prompt, InferenceParameters parameters) { - setApiKey(parameters); - setEndpoint(parameters); - return client.complete(prompt, parameters); - } - - @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, - InferenceParameters parameters, - Consumer<Completion> consumer) { - setApiKey(parameters); - setEndpoint(parameters); - return client.completeAsync(prompt, parameters, consumer); - } -} - diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java deleted file mode 100644 index c360245901c..00000000000 --- a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -@PublicApi -package ai.vespa.llm.clients; - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java index 860fc69af91..f565315b775 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java @@ -20,6 +20,7 @@ import com.yahoo.search.result.HitGroup; import com.yahoo.search.searchchain.Execution; import java.util.List; +import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -83,27 +84,41 @@ public class LLMSearcher extends Searcher { protected Result complete(Query query, Prompt prompt) { var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config - return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); + try { + return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); + } catch (RejectedExecutionException e) { + return new Result(query, new ErrorMessage(429, e.getMessage())); + } + } + + private boolean shouldAddPrompt(Query query) { + return query.getTrace().getLevel() >= 1; + } + + private boolean shouldAddTokenStats(Query query) { + return query.getTrace().getLevel() >= 1; } private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { - EventStream eventStream = new EventStream(); + final EventStream eventStream = new EventStream(); - if (query.getTrace().getLevel() >= 1) { + if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } - languageModel.completeAsync(prompt, options, token -> { - eventStream.add(token.text()); + final TokenStats tokenStats = new TokenStats(); + languageModel.completeAsync(prompt, options, completion -> { + tokenStats.onToken(); + handleCompletion(eventStream, completion); }).exceptionally(exception -> { - int errorCode = 400; - if (exception instanceof LanguageModelException languageModelException) { - errorCode = languageModelException.code(); - } - eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); + handleException(eventStream, exception); eventStream.markComplete(); return Completion.FinishReason.error; }).thenAccept(finishReason -> { + tokenStats.onCompletion(); + if (shouldAddTokenStats(query)) { + eventStream.add(tokenStats.report(), "stats"); + } eventStream.markComplete(); }); @@ -112,10 +127,26 @@ public class LLMSearcher extends Searcher { return new Result(query, hitGroup); } + private void handleCompletion(EventStream eventStream, Completion completion) { + if (completion.finishReason() == Completion.FinishReason.error) { + eventStream.add(completion.text(), "error"); + } else { + eventStream.add(completion.text()); + } + } + + private void handleException(EventStream eventStream, Throwable exception) { + int errorCode = 400; + if (exception instanceof LanguageModelException languageModelException) { + errorCode = languageModelException.code(); + } + eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); + } + private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { EventStream eventStream = new EventStream(); - if (query.getTrace().getLevel() >= 1) { + if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } @@ -169,4 +200,35 @@ public class LLMSearcher extends Searcher { return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); } + private static class TokenStats { + + private long start; + private long timeToFirstToken; + private long timeToLastToken; + private long tokens = 0; + + TokenStats() { + start = System.currentTimeMillis(); + } + + void onToken() { + if (tokens == 0) { + timeToFirstToken = System.currentTimeMillis() - start; + } + tokens++; + } + + void onCompletion() { + timeToLastToken = System.currentTimeMillis() - start; + } + + String report() { + return "Time to first token: " + timeToFirstToken + " ms, " + + "Generation time: " + timeToLastToken + " ms, " + + "Generated tokens: " + tokens + " " + + String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0)); + } + + } + } diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java index 83ae349f5a0..88a1e6c1485 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -64,7 +64,17 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { @Override public void data(Data data) throws IOException { - if (data instanceof EventStream.Event event) { + if (data instanceof EventStream.ErrorEvent error) { + generator.writeRaw("event: error\n"); + generator.writeRaw("data: "); + generator.writeStartObject(); + generator.writeStringField("source", error.source()); + generator.writeNumberField("error", error.code()); + generator.writeStringField("message", error.message()); + generator.writeEndObject(); + generator.writeRaw("\n\n"); + generator.flush(); + } else if (data instanceof EventStream.Event event) { if (RENDER_EVENT_HEADER) { generator.writeRaw("event: " + event.type() + "\n"); } @@ -75,19 +85,6 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { generator.writeRaw("\n\n"); generator.flush(); } - else if (data instanceof ErrorHit) { - for (ErrorMessage error : ((ErrorHit) data).errors()) { - generator.writeRaw("event: error\n"); - generator.writeRaw("data: "); - generator.writeStartObject(); - generator.writeStringField("source", error.getSource()); - generator.writeNumberField("error", error.getCode()); - generator.writeStringField("message", error.getMessage()); - generator.writeEndObject(); - generator.writeRaw("\n\n"); - generator.flush(); - } - } // Todo: support other types of data such as search results (hits), timing and trace } diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java index b393a91e6d0..8e6f7977d55 100644 --- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java +++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java @@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList<Data> { } public void error(String source, ErrorMessage message) { - incoming().add(new DefaultErrorHit(source, message)); + incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message)); } public void markComplete() { @@ -117,4 +117,38 @@ public class EventStream extends Hit implements DataList<Data> { } + public static class ErrorEvent extends Event { + + private final String source; + private final ErrorMessage message; + + public ErrorEvent(int eventNumber, String source, ErrorMessage message) { + super(eventNumber, message.getMessage(), "error"); + this.source = source; + this.message = message; + } + + public String source() { + return source; + } + + public int code() { + return message.getCode(); + } + + public String message() { + return message.getMessage(); + } + + @Override + public Hit asHit() { + Hit hit = super.asHit(); + hit.setField("source", source); + hit.setField("code", message.getCode()); + return hit; + } + + + } + } |