aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java
diff options
context:
space:
mode:
authorLester Solbakken <lester.solbakken@gmail.com>2024-04-16 13:07:31 +0200
committerLester Solbakken <lester.solbakken@gmail.com>2024-04-16 13:07:31 +0200
commit780bc7cbe8fb67ae712fcf278f8900c8f32e14a6 (patch)
treed37ae6302fb72f5b6f5e33a8a45d968ffc505fdf /container-search/src/main/java
parentfca990d5ed32c408df42bbe178b174711fa54a08 (diff)
Reapply "Lesters/add local llms 2"
This reverts commit ed62b750494822cc67a328390178754512baf032.
Diffstat (limited to 'container-search/src/main/java')
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java75
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java48
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/package-info.java7
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java84
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java25
-rw-r--r--container-search/src/main/java/com/yahoo/search/result/EventStream.java36
6 files changed, 119 insertions, 156 deletions
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
deleted file mode 100644
index 761fdf0af93..00000000000
--- a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.LanguageModel;
-import com.yahoo.api.annotations.Beta;
-import com.yahoo.component.annotation.Inject;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-
-import java.util.logging.Logger;
-
-
-/**
- * Base class for language models that can be configured with config definitions.
- *
- * @author lesters
- */
-@Beta
-public abstract class ConfigurableLanguageModel implements LanguageModel {
-
- private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName());
-
- private final String apiKey;
- private final String endpoint;
-
- public ConfigurableLanguageModel() {
- this.apiKey = null;
- this.endpoint = null;
- }
-
- @Inject
- public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) {
- this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore);
- this.endpoint = config.endpoint();
- }
-
- private static String findApiKeyInSecretStore(String property, SecretStore secretStore) {
- String apiKey = "";
- if (property != null && ! property.isEmpty()) {
- try {
- apiKey = secretStore.getSecret(property);
- } catch (UnsupportedOperationException e) {
- // Secret store is not available - silently ignore this
- } catch (Exception e) {
- log.warning("Secret store look up failed: " + e.getMessage() + "\n" +
- "Will expect API key in request header");
- }
- }
- return apiKey;
- }
-
- protected String getApiKey(InferenceParameters params) {
- return params.getApiKey().orElse(null);
- }
-
- /**
- * Set the API key as retrieved from secret store if it is not already set
- */
- protected void setApiKey(InferenceParameters params) {
- if (params.getApiKey().isEmpty() && apiKey != null) {
- params.setApiKey(apiKey);
- }
- }
-
- protected String getEndpoint() {
- return endpoint;
- }
-
- protected void setEndpoint(InferenceParameters params) {
- if (endpoint != null && ! endpoint.isEmpty()) {
- params.setEndpoint(endpoint);
- }
- }
-
-}
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
deleted file mode 100644
index 82e19d47c92..00000000000
--- a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.client.openai.OpenAiClient;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import com.yahoo.api.annotations.Beta;
-import com.yahoo.component.annotation.Inject;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.function.Consumer;
-
-/**
- * A configurable OpenAI client.
- *
- * @author lesters
- */
-@Beta
-public class OpenAI extends ConfigurableLanguageModel {
-
- private final OpenAiClient client;
-
- @Inject
- public OpenAI(LlmClientConfig config, SecretStore secretStore) {
- super(config, secretStore);
- client = new OpenAiClient();
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
- setApiKey(parameters);
- setEndpoint(parameters);
- return client.complete(prompt, parameters);
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
- InferenceParameters parameters,
- Consumer<Completion> consumer) {
- setApiKey(parameters);
- setEndpoint(parameters);
- return client.completeAsync(prompt, parameters, consumer);
- }
-}
-
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
deleted file mode 100644
index c360245901c..00000000000
--- a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
+++ /dev/null
@@ -1,7 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-@ExportPackage
-@PublicApi
-package ai.vespa.llm.clients;
-
-import com.yahoo.api.annotations.PublicApi;
-import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
index 860fc69af91..f565315b775 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
@@ -20,6 +20,7 @@ import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
import java.util.List;
+import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;
@@ -83,27 +84,41 @@ public class LLMSearcher extends Searcher {
protected Result complete(Query query, Prompt prompt) {
var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
- return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ try {
+ return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ } catch (RejectedExecutionException e) {
+ return new Result(query, new ErrorMessage(429, e.getMessage()));
+ }
+ }
+
+ private boolean shouldAddPrompt(Query query) {
+ return query.getTrace().getLevel() >= 1;
+ }
+
+ private boolean shouldAddTokenStats(Query query) {
+ return query.getTrace().getLevel() >= 1;
}
private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
- EventStream eventStream = new EventStream();
+ final EventStream eventStream = new EventStream();
- if (query.getTrace().getLevel() >= 1) {
+ if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
- languageModel.completeAsync(prompt, options, token -> {
- eventStream.add(token.text());
+ final TokenStats tokenStats = new TokenStats();
+ languageModel.completeAsync(prompt, options, completion -> {
+ tokenStats.onToken();
+ handleCompletion(eventStream, completion);
}).exceptionally(exception -> {
- int errorCode = 400;
- if (exception instanceof LanguageModelException languageModelException) {
- errorCode = languageModelException.code();
- }
- eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ handleException(eventStream, exception);
eventStream.markComplete();
return Completion.FinishReason.error;
}).thenAccept(finishReason -> {
+ tokenStats.onCompletion();
+ if (shouldAddTokenStats(query)) {
+ eventStream.add(tokenStats.report(), "stats");
+ }
eventStream.markComplete();
});
@@ -112,10 +127,26 @@ public class LLMSearcher extends Searcher {
return new Result(query, hitGroup);
}
+ private void handleCompletion(EventStream eventStream, Completion completion) {
+ if (completion.finishReason() == Completion.FinishReason.error) {
+ eventStream.add(completion.text(), "error");
+ } else {
+ eventStream.add(completion.text());
+ }
+ }
+
+ private void handleException(EventStream eventStream, Throwable exception) {
+ int errorCode = 400;
+ if (exception instanceof LanguageModelException languageModelException) {
+ errorCode = languageModelException.code();
+ }
+ eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ }
+
private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
EventStream eventStream = new EventStream();
- if (query.getTrace().getLevel() >= 1) {
+ if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
@@ -169,4 +200,35 @@ public class LLMSearcher extends Searcher {
return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
}
+ private static class TokenStats {
+
+ private long start;
+ private long timeToFirstToken;
+ private long timeToLastToken;
+ private long tokens = 0;
+
+ TokenStats() {
+ start = System.currentTimeMillis();
+ }
+
+ void onToken() {
+ if (tokens == 0) {
+ timeToFirstToken = System.currentTimeMillis() - start;
+ }
+ tokens++;
+ }
+
+ void onCompletion() {
+ timeToLastToken = System.currentTimeMillis() - start;
+ }
+
+ String report() {
+ return "Time to first token: " + timeToFirstToken + " ms, " +
+ "Generation time: " + timeToLastToken + " ms, " +
+ "Generated tokens: " + tokens + " " +
+ String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0));
+ }
+
+ }
+
}
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
index 83ae349f5a0..88a1e6c1485 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -64,7 +64,17 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void data(Data data) throws IOException {
- if (data instanceof EventStream.Event event) {
+ if (data instanceof EventStream.ErrorEvent error) {
+ generator.writeRaw("event: error\n");
+ generator.writeRaw("data: ");
+ generator.writeStartObject();
+ generator.writeStringField("source", error.source());
+ generator.writeNumberField("error", error.code());
+ generator.writeStringField("message", error.message());
+ generator.writeEndObject();
+ generator.writeRaw("\n\n");
+ generator.flush();
+ } else if (data instanceof EventStream.Event event) {
if (RENDER_EVENT_HEADER) {
generator.writeRaw("event: " + event.type() + "\n");
}
@@ -75,19 +85,6 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
generator.writeRaw("\n\n");
generator.flush();
}
- else if (data instanceof ErrorHit) {
- for (ErrorMessage error : ((ErrorHit) data).errors()) {
- generator.writeRaw("event: error\n");
- generator.writeRaw("data: ");
- generator.writeStartObject();
- generator.writeStringField("source", error.getSource());
- generator.writeNumberField("error", error.getCode());
- generator.writeStringField("message", error.getMessage());
- generator.writeEndObject();
- generator.writeRaw("\n\n");
- generator.flush();
- }
- }
// Todo: support other types of data such as search results (hits), timing and trace
}
diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
index b393a91e6d0..8e6f7977d55 100644
--- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java
+++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
@@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList<Data> {
}
public void error(String source, ErrorMessage message) {
- incoming().add(new DefaultErrorHit(source, message));
+ incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message));
}
public void markComplete() {
@@ -117,4 +117,38 @@ public class EventStream extends Hit implements DataList<Data> {
}
+ public static class ErrorEvent extends Event {
+
+ private final String source;
+ private final ErrorMessage message;
+
+ public ErrorEvent(int eventNumber, String source, ErrorMessage message) {
+ super(eventNumber, message.getMessage(), "error");
+ this.source = source;
+ this.message = message;
+ }
+
+ public String source() {
+ return source;
+ }
+
+ public int code() {
+ return message.getCode();
+ }
+
+ public String message() {
+ return message.getMessage();
+ }
+
+ @Override
+ public Hit asHit() {
+ Hit hit = super.asHit();
+ hit.setField("source", source);
+ hit.setField("code", message.getCode());
+ return hit;
+ }
+
+
+ }
+
}