aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main
diff options
context:
space:
mode:
authorHarald Musum <musum@vespa.ai>2024-04-15 20:44:10 +0200
committerGitHub <noreply@github.com>2024-04-15 20:44:10 +0200
commited62b750494822cc67a328390178754512baf032 (patch)
tree3f94e22d38d63c456a3bb202b4f3787ecff5ded6 /container-search/src/main
parent44a866c0d648543c04567503990c03c36403d86d (diff)
Revert "Lesters/add local llms 2"
Diffstat (limited to 'container-search/src/main')
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java75
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java48
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/package-info.java7
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java84
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java25
-rw-r--r--container-search/src/main/java/com/yahoo/search/result/EventStream.java36
-rwxr-xr-xcontainer-search/src/main/resources/configdefinitions/llm-client.def8
7 files changed, 164 insertions, 119 deletions
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
new file mode 100644
index 00000000000..761fdf0af93
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
@@ -0,0 +1,75 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.logging.Logger;
+
+
+/**
+ * Base class for language models that can be configured with config definitions.
+ *
+ * @author lesters
+ */
+@Beta
+public abstract class ConfigurableLanguageModel implements LanguageModel {
+
+ private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName());
+
+ private final String apiKey;
+ private final String endpoint;
+
+ public ConfigurableLanguageModel() {
+ this.apiKey = null;
+ this.endpoint = null;
+ }
+
+ @Inject
+ public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) {
+ this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore);
+ this.endpoint = config.endpoint();
+ }
+
+ private static String findApiKeyInSecretStore(String property, SecretStore secretStore) {
+ String apiKey = "";
+ if (property != null && ! property.isEmpty()) {
+ try {
+ apiKey = secretStore.getSecret(property);
+ } catch (UnsupportedOperationException e) {
+ // Secret store is not available - silently ignore this
+ } catch (Exception e) {
+ log.warning("Secret store look up failed: " + e.getMessage() + "\n" +
+ "Will expect API key in request header");
+ }
+ }
+ return apiKey;
+ }
+
+ protected String getApiKey(InferenceParameters params) {
+ return params.getApiKey().orElse(null);
+ }
+
+ /**
+ * Set the API key as retrieved from secret store if it is not already set
+ */
+ protected void setApiKey(InferenceParameters params) {
+ if (params.getApiKey().isEmpty() && apiKey != null) {
+ params.setApiKey(apiKey);
+ }
+ }
+
+ protected String getEndpoint() {
+ return endpoint;
+ }
+
+ protected void setEndpoint(InferenceParameters params) {
+ if (endpoint != null && ! endpoint.isEmpty()) {
+ params.setEndpoint(endpoint);
+ }
+ }
+
+}
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
new file mode 100644
index 00000000000..82e19d47c92
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
@@ -0,0 +1,48 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.client.openai.OpenAiClient;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+
+/**
+ * A configurable OpenAI client.
+ *
+ * @author lesters
+ */
+@Beta
+public class OpenAI extends ConfigurableLanguageModel {
+
+ private final OpenAiClient client;
+
+ @Inject
+ public OpenAI(LlmClientConfig config, SecretStore secretStore) {
+ super(config, secretStore);
+ client = new OpenAiClient();
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
+ setApiKey(parameters);
+ setEndpoint(parameters);
+ return client.complete(prompt, parameters);
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters parameters,
+ Consumer<Completion> consumer) {
+ setApiKey(parameters);
+ setEndpoint(parameters);
+ return client.completeAsync(prompt, parameters, consumer);
+ }
+}
+
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
new file mode 100644
index 00000000000..c360245901c
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
@@ -0,0 +1,7 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.clients;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
index f565315b775..860fc69af91 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
@@ -20,7 +20,6 @@ import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
import java.util.List;
-import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;
@@ -84,41 +83,27 @@ public class LLMSearcher extends Searcher {
protected Result complete(Query query, Prompt prompt) {
var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
- try {
- return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
- } catch (RejectedExecutionException e) {
- return new Result(query, new ErrorMessage(429, e.getMessage()));
- }
- }
-
- private boolean shouldAddPrompt(Query query) {
- return query.getTrace().getLevel() >= 1;
- }
-
- private boolean shouldAddTokenStats(Query query) {
- return query.getTrace().getLevel() >= 1;
+ return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
}
private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
- final EventStream eventStream = new EventStream();
+ EventStream eventStream = new EventStream();
- if (shouldAddPrompt(query)) {
+ if (query.getTrace().getLevel() >= 1) {
eventStream.add(prompt.asString(), "prompt");
}
- final TokenStats tokenStats = new TokenStats();
- languageModel.completeAsync(prompt, options, completion -> {
- tokenStats.onToken();
- handleCompletion(eventStream, completion);
+ languageModel.completeAsync(prompt, options, token -> {
+ eventStream.add(token.text());
}).exceptionally(exception -> {
- handleException(eventStream, exception);
+ int errorCode = 400;
+ if (exception instanceof LanguageModelException languageModelException) {
+ errorCode = languageModelException.code();
+ }
+ eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
eventStream.markComplete();
return Completion.FinishReason.error;
}).thenAccept(finishReason -> {
- tokenStats.onCompletion();
- if (shouldAddTokenStats(query)) {
- eventStream.add(tokenStats.report(), "stats");
- }
eventStream.markComplete();
});
@@ -127,26 +112,10 @@ public class LLMSearcher extends Searcher {
return new Result(query, hitGroup);
}
- private void handleCompletion(EventStream eventStream, Completion completion) {
- if (completion.finishReason() == Completion.FinishReason.error) {
- eventStream.add(completion.text(), "error");
- } else {
- eventStream.add(completion.text());
- }
- }
-
- private void handleException(EventStream eventStream, Throwable exception) {
- int errorCode = 400;
- if (exception instanceof LanguageModelException languageModelException) {
- errorCode = languageModelException.code();
- }
- eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
- }
-
private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
EventStream eventStream = new EventStream();
- if (shouldAddPrompt(query)) {
+ if (query.getTrace().getLevel() >= 1) {
eventStream.add(prompt.asString(), "prompt");
}
@@ -200,35 +169,4 @@ public class LLMSearcher extends Searcher {
return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
}
- private static class TokenStats {
-
- private long start;
- private long timeToFirstToken;
- private long timeToLastToken;
- private long tokens = 0;
-
- TokenStats() {
- start = System.currentTimeMillis();
- }
-
- void onToken() {
- if (tokens == 0) {
- timeToFirstToken = System.currentTimeMillis() - start;
- }
- tokens++;
- }
-
- void onCompletion() {
- timeToLastToken = System.currentTimeMillis() - start;
- }
-
- String report() {
- return "Time to first token: " + timeToFirstToken + " ms, " +
- "Generation time: " + timeToLastToken + " ms, " +
- "Generated tokens: " + tokens + " " +
- String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0));
- }
-
- }
-
}
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
index 88a1e6c1485..83ae349f5a0 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -64,17 +64,7 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void data(Data data) throws IOException {
- if (data instanceof EventStream.ErrorEvent error) {
- generator.writeRaw("event: error\n");
- generator.writeRaw("data: ");
- generator.writeStartObject();
- generator.writeStringField("source", error.source());
- generator.writeNumberField("error", error.code());
- generator.writeStringField("message", error.message());
- generator.writeEndObject();
- generator.writeRaw("\n\n");
- generator.flush();
- } else if (data instanceof EventStream.Event event) {
+ if (data instanceof EventStream.Event event) {
if (RENDER_EVENT_HEADER) {
generator.writeRaw("event: " + event.type() + "\n");
}
@@ -85,6 +75,19 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
generator.writeRaw("\n\n");
generator.flush();
}
+ else if (data instanceof ErrorHit) {
+ for (ErrorMessage error : ((ErrorHit) data).errors()) {
+ generator.writeRaw("event: error\n");
+ generator.writeRaw("data: ");
+ generator.writeStartObject();
+ generator.writeStringField("source", error.getSource());
+ generator.writeNumberField("error", error.getCode());
+ generator.writeStringField("message", error.getMessage());
+ generator.writeEndObject();
+ generator.writeRaw("\n\n");
+ generator.flush();
+ }
+ }
// Todo: support other types of data such as search results (hits), timing and trace
}
diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
index 8e6f7977d55..b393a91e6d0 100644
--- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java
+++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
@@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList<Data> {
}
public void error(String source, ErrorMessage message) {
- incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message));
+ incoming().add(new DefaultErrorHit(source, message));
}
public void markComplete() {
@@ -117,38 +117,4 @@ public class EventStream extends Hit implements DataList<Data> {
}
- public static class ErrorEvent extends Event {
-
- private final String source;
- private final ErrorMessage message;
-
- public ErrorEvent(int eventNumber, String source, ErrorMessage message) {
- super(eventNumber, message.getMessage(), "error");
- this.source = source;
- this.message = message;
- }
-
- public String source() {
- return source;
- }
-
- public int code() {
- return message.getCode();
- }
-
- public String message() {
- return message.getMessage();
- }
-
- @Override
- public Hit asHit() {
- Hit hit = super.asHit();
- hit.setField("source", source);
- hit.setField("code", message.getCode());
- return hit;
- }
-
-
- }
-
}
diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def
new file mode 100755
index 00000000000..0866459166a
--- /dev/null
+++ b/container-search/src/main/resources/configdefinitions/llm-client.def
@@ -0,0 +1,8 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm.clients
+
+# The name of the secret containing the api key
+apiKeySecretName string default=""
+
+# Endpoint for LLM client - if not set reverts to default for client
+endpoint string default=""