aboutsummaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
Diffstat (limited to 'container-search')
-rw-r--r--container-search/abi-spec.json108
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java75
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java48
-rw-r--r--container-search/src/main/java/ai/vespa/llm/clients/package-info.java7
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java84
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java25
-rw-r--r--container-search/src/main/java/com/yahoo/search/result/EventStream.java36
-rwxr-xr-xcontainer-search/src/main/resources/configdefinitions/llm-client.def8
-rw-r--r--container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java174
-rw-r--r--container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java80
-rw-r--r--container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java35
-rwxr-xr-xcontainer-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java150
12 files changed, 243 insertions, 587 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index e74fe22c588..07f0449e61a 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -7842,6 +7842,21 @@
"public static final int emptyDocsumsCode"
]
},
+ "com.yahoo.search.result.EventStream$ErrorEvent" : {
+ "superClass" : "com.yahoo.search.result.EventStream$Event",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(int, java.lang.String, com.yahoo.search.result.ErrorMessage)",
+ "public java.lang.String source()",
+ "public int code()",
+ "public java.lang.String message()",
+ "public com.yahoo.search.result.Hit asHit()"
+ ],
+ "fields" : [ ]
+ },
"com.yahoo.search.result.EventStream$Event" : {
"superClass" : "com.yahoo.component.provider.ListenableFreezableClass",
"interfaces" : [
@@ -9149,99 +9164,6 @@
],
"fields" : [ ]
},
- "ai.vespa.llm.clients.ConfigurableLanguageModel" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "ai.vespa.llm.LanguageModel"
- ],
- "attributes" : [
- "public",
- "abstract"
- ],
- "methods" : [
- "public void <init>()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
- "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)",
- "protected void setApiKey(ai.vespa.llm.InferenceParameters)",
- "protected java.lang.String getEndpoint()",
- "protected void setEndpoint(ai.vespa.llm.InferenceParameters)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig$Builder" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Builder"
- ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public void <init>()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig)",
- "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)",
- "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)",
- "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
- "public final java.lang.String getDefMd5()",
- "public final java.lang.String getDefName()",
- "public final java.lang.String getDefNamespace()",
- "public final boolean getApplyOnRestart()",
- "public final void setApplyOnRestart(boolean)",
- "public ai.vespa.llm.clients.LlmClientConfig build()"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig$Producer" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Producer"
- ],
- "attributes" : [
- "public",
- "interface",
- "abstract"
- ],
- "methods" : [
- "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig" : {
- "superClass" : "com.yahoo.config.ConfigInstance",
- "interfaces" : [ ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public static java.lang.String getDefMd5()",
- "public static java.lang.String getDefName()",
- "public static java.lang.String getDefNamespace()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)",
- "public java.lang.String apiKeySecretName()",
- "public java.lang.String endpoint()"
- ],
- "fields" : [
- "public static final java.lang.String CONFIG_DEF_MD5",
- "public static final java.lang.String CONFIG_DEF_NAME",
- "public static final java.lang.String CONFIG_DEF_NAMESPACE",
- "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
- ]
- },
- "ai.vespa.llm.clients.OpenAI" : {
- "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel",
- "interfaces" : [ ],
- "attributes" : [
- "public"
- ],
- "methods" : [
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
- "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
- "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
- ],
- "fields" : [ ]
- },
"ai.vespa.search.llm.LLMSearcher" : {
"superClass" : "com.yahoo.search.Searcher",
"interfaces" : [ ],
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
deleted file mode 100644
index 761fdf0af93..00000000000
--- a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.LanguageModel;
-import com.yahoo.api.annotations.Beta;
-import com.yahoo.component.annotation.Inject;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-
-import java.util.logging.Logger;
-
-
-/**
- * Base class for language models that can be configured with config definitions.
- *
- * @author lesters
- */
-@Beta
-public abstract class ConfigurableLanguageModel implements LanguageModel {
-
- private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName());
-
- private final String apiKey;
- private final String endpoint;
-
- public ConfigurableLanguageModel() {
- this.apiKey = null;
- this.endpoint = null;
- }
-
- @Inject
- public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) {
- this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore);
- this.endpoint = config.endpoint();
- }
-
- private static String findApiKeyInSecretStore(String property, SecretStore secretStore) {
- String apiKey = "";
- if (property != null && ! property.isEmpty()) {
- try {
- apiKey = secretStore.getSecret(property);
- } catch (UnsupportedOperationException e) {
- // Secret store is not available - silently ignore this
- } catch (Exception e) {
- log.warning("Secret store look up failed: " + e.getMessage() + "\n" +
- "Will expect API key in request header");
- }
- }
- return apiKey;
- }
-
- protected String getApiKey(InferenceParameters params) {
- return params.getApiKey().orElse(null);
- }
-
- /**
- * Set the API key as retrieved from secret store if it is not already set
- */
- protected void setApiKey(InferenceParameters params) {
- if (params.getApiKey().isEmpty() && apiKey != null) {
- params.setApiKey(apiKey);
- }
- }
-
- protected String getEndpoint() {
- return endpoint;
- }
-
- protected void setEndpoint(InferenceParameters params) {
- if (endpoint != null && ! endpoint.isEmpty()) {
- params.setEndpoint(endpoint);
- }
- }
-
-}
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
deleted file mode 100644
index 82e19d47c92..00000000000
--- a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.client.openai.OpenAiClient;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import com.yahoo.api.annotations.Beta;
-import com.yahoo.component.annotation.Inject;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.function.Consumer;
-
-/**
- * A configurable OpenAI client.
- *
- * @author lesters
- */
-@Beta
-public class OpenAI extends ConfigurableLanguageModel {
-
- private final OpenAiClient client;
-
- @Inject
- public OpenAI(LlmClientConfig config, SecretStore secretStore) {
- super(config, secretStore);
- client = new OpenAiClient();
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
- setApiKey(parameters);
- setEndpoint(parameters);
- return client.complete(prompt, parameters);
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
- InferenceParameters parameters,
- Consumer<Completion> consumer) {
- setApiKey(parameters);
- setEndpoint(parameters);
- return client.completeAsync(prompt, parameters, consumer);
- }
-}
-
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
deleted file mode 100644
index c360245901c..00000000000
--- a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
+++ /dev/null
@@ -1,7 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-@ExportPackage
-@PublicApi
-package ai.vespa.llm.clients;
-
-import com.yahoo.api.annotations.PublicApi;
-import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
index 860fc69af91..f565315b775 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
@@ -20,6 +20,7 @@ import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
import java.util.List;
+import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;
@@ -83,27 +84,41 @@ public class LLMSearcher extends Searcher {
protected Result complete(Query query, Prompt prompt) {
var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
- return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ try {
+ return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ } catch (RejectedExecutionException e) {
+ return new Result(query, new ErrorMessage(429, e.getMessage()));
+ }
+ }
+
+ private boolean shouldAddPrompt(Query query) {
+ return query.getTrace().getLevel() >= 1;
+ }
+
+ private boolean shouldAddTokenStats(Query query) {
+ return query.getTrace().getLevel() >= 1;
}
private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
- EventStream eventStream = new EventStream();
+ final EventStream eventStream = new EventStream();
- if (query.getTrace().getLevel() >= 1) {
+ if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
- languageModel.completeAsync(prompt, options, token -> {
- eventStream.add(token.text());
+ final TokenStats tokenStats = new TokenStats();
+ languageModel.completeAsync(prompt, options, completion -> {
+ tokenStats.onToken();
+ handleCompletion(eventStream, completion);
}).exceptionally(exception -> {
- int errorCode = 400;
- if (exception instanceof LanguageModelException languageModelException) {
- errorCode = languageModelException.code();
- }
- eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ handleException(eventStream, exception);
eventStream.markComplete();
return Completion.FinishReason.error;
}).thenAccept(finishReason -> {
+ tokenStats.onCompletion();
+ if (shouldAddTokenStats(query)) {
+ eventStream.add(tokenStats.report(), "stats");
+ }
eventStream.markComplete();
});
@@ -112,10 +127,26 @@ public class LLMSearcher extends Searcher {
return new Result(query, hitGroup);
}
+ private void handleCompletion(EventStream eventStream, Completion completion) {
+ if (completion.finishReason() == Completion.FinishReason.error) {
+ eventStream.add(completion.text(), "error");
+ } else {
+ eventStream.add(completion.text());
+ }
+ }
+
+ private void handleException(EventStream eventStream, Throwable exception) {
+ int errorCode = 400;
+ if (exception instanceof LanguageModelException languageModelException) {
+ errorCode = languageModelException.code();
+ }
+ eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ }
+
private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
EventStream eventStream = new EventStream();
- if (query.getTrace().getLevel() >= 1) {
+ if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
@@ -169,4 +200,35 @@ public class LLMSearcher extends Searcher {
return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
}
+ private static class TokenStats {
+
+ private long start;
+ private long timeToFirstToken;
+ private long timeToLastToken;
+ private long tokens = 0;
+
+ TokenStats() {
+ start = System.currentTimeMillis();
+ }
+
+ void onToken() {
+ if (tokens == 0) {
+ timeToFirstToken = System.currentTimeMillis() - start;
+ }
+ tokens++;
+ }
+
+ void onCompletion() {
+ timeToLastToken = System.currentTimeMillis() - start;
+ }
+
+ String report() {
+ return "Time to first token: " + timeToFirstToken + " ms, " +
+ "Generation time: " + timeToLastToken + " ms, " +
+ "Generated tokens: " + tokens + " " +
+ String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0));
+ }
+
+ }
+
}
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
index 83ae349f5a0..88a1e6c1485 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -64,7 +64,17 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void data(Data data) throws IOException {
- if (data instanceof EventStream.Event event) {
+ if (data instanceof EventStream.ErrorEvent error) {
+ generator.writeRaw("event: error\n");
+ generator.writeRaw("data: ");
+ generator.writeStartObject();
+ generator.writeStringField("source", error.source());
+ generator.writeNumberField("error", error.code());
+ generator.writeStringField("message", error.message());
+ generator.writeEndObject();
+ generator.writeRaw("\n\n");
+ generator.flush();
+ } else if (data instanceof EventStream.Event event) {
if (RENDER_EVENT_HEADER) {
generator.writeRaw("event: " + event.type() + "\n");
}
@@ -75,19 +85,6 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
generator.writeRaw("\n\n");
generator.flush();
}
- else if (data instanceof ErrorHit) {
- for (ErrorMessage error : ((ErrorHit) data).errors()) {
- generator.writeRaw("event: error\n");
- generator.writeRaw("data: ");
- generator.writeStartObject();
- generator.writeStringField("source", error.getSource());
- generator.writeNumberField("error", error.getCode());
- generator.writeStringField("message", error.getMessage());
- generator.writeEndObject();
- generator.writeRaw("\n\n");
- generator.flush();
- }
- }
// Todo: support other types of data such as search results (hits), timing and trace
}
diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
index b393a91e6d0..8e6f7977d55 100644
--- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java
+++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
@@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList<Data> {
}
public void error(String source, ErrorMessage message) {
- incoming().add(new DefaultErrorHit(source, message));
+ incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message));
}
public void markComplete() {
@@ -117,4 +117,38 @@ public class EventStream extends Hit implements DataList<Data> {
}
+ public static class ErrorEvent extends Event {
+
+ private final String source;
+ private final ErrorMessage message;
+
+ public ErrorEvent(int eventNumber, String source, ErrorMessage message) {
+ super(eventNumber, message.getMessage(), "error");
+ this.source = source;
+ this.message = message;
+ }
+
+ public String source() {
+ return source;
+ }
+
+ public int code() {
+ return message.getCode();
+ }
+
+ public String message() {
+ return message.getMessage();
+ }
+
+ @Override
+ public Hit asHit() {
+ Hit hit = super.asHit();
+ hit.setField("source", source);
+ hit.setField("code", message.getCode());
+ return hit;
+ }
+
+
+ }
+
}
diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def
deleted file mode 100755
index 0866459166a..00000000000
--- a/container-search/src/main/resources/configdefinitions/llm-client.def
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package=ai.vespa.llm.clients
-
-# The name of the secret containing the api key
-apiKeySecretName string default=""
-
-# Endpoint for LLM client - if not set reverts to default for client
-endpoint string default=""
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
deleted file mode 100644
index 35d5cfd3855..00000000000
--- a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
+++ /dev/null
@@ -1,174 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import ai.vespa.llm.completion.StringPrompt;
-import com.yahoo.container.di.componentgraph.Provider;
-import com.yahoo.container.jdisc.SecretStoreProvider;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-import org.junit.jupiter.api.Test;
-
-import java.util.Arrays;
-import java.util.Map;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Executors;
-import java.util.function.BiFunction;
-import java.util.stream.Collectors;
-
-import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertNotEquals;
-import static org.junit.jupiter.api.Assertions.assertThrows;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-public class ConfigurableLanguageModelTest {
-
- @Test
- public void testSyncGeneration() {
- var prompt = StringPrompt.from("Why are ducks better than cats?");
- var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey());
- assertEquals(1, result.size());
- assertEquals("Ducks have adorable waddling walks.", result.get(0).text());
- }
-
- @Test
- public void testAsyncGeneration() {
- var executor = Executors.newFixedThreadPool(1);
- var prompt = StringPrompt.from("Why are ducks better than cats?");
- var sb = new StringBuilder();
- try {
- var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> {
- sb.append(completion.text());
- }).exceptionally(exception -> Completion.FinishReason.error);
-
- var reason = future.join();
- assertTrue(future.isDone());
- assertNotEquals(reason, Completion.FinishReason.error);
- } finally {
- executor.shutdownNow();
- }
-
- assertEquals("Ducks have adorable waddling walks.", sb.toString());
- }
-
- @Test
- public void testInferenceParameters() {
- var prompt = StringPrompt.from("Why are ducks better than cats?");
- var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4"));
- var result = createLLM().complete(prompt, params);
- assertEquals("Random text about ducks", result.get(0).text());
- }
-
- @Test
- public void testNoApiKey() {
- var prompt = StringPrompt.from("");
- var config = modelParams("api-key", null);
- var secrets = createSecretStore(Map.of());
- assertThrows(IllegalArgumentException.class, () -> {
- createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams());
- });
- }
-
- @Test
- public void testApiKeyFromSecretStore() {
- var prompt = StringPrompt.from("");
- var config = modelParams("api-key-in-secret-store", null);
- var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY));
- assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); });
- }
-
- private static String lookupParameter(String parameter, Map<String, String> params) {
- return params.get(parameter);
- }
-
- private static InferenceParameters inferenceParams() {
- return new InferenceParameters(s -> lookupParameter(s, Map.of()));
- }
-
- private static InferenceParameters inferenceParams(Map<String, String> params) {
- return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params));
- }
-
- private static InferenceParameters inferenceParamsWithDefaultKey() {
- return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Map.of()));
- }
-
- private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) {
- var config = new LlmClientConfig.Builder();
- if (apiKeySecretName != null) {
- config.apiKeySecretName(apiKeySecretName);
- }
- if (endpoint != null) {
- config.endpoint(endpoint);
- }
- return config.build();
- }
-
- public static SecretStore createSecretStore(Map<String, String> secrets) {
- Provider<SecretStore> secretStore = new Provider<>() {
- public SecretStore get() {
- return new SecretStore() {
- public String getSecret(String key) {
- return secrets.get(key);
- }
- public String getSecret(String key, int version) {
- return secrets.get(key);
- }
- };
- }
- public void deconstruct() {
- }
- };
- return secretStore.get();
- }
-
- public static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
- return (prompt, options) -> {
- String answer = "I have no opinion on the matter";
- if (prompt.asString().contains("ducks")) {
- answer = "Ducks have adorable waddling walks.";
- var temperature = options.getDouble("temperature");
- if (temperature.isPresent() && temperature.get() > 0.5) {
- answer = "Random text about ducks vs cats that makes no sense whatsoever.";
- }
- }
- var maxTokens = options.getInt("maxTokens");
- if (maxTokens.isPresent()) {
- return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
- }
- return answer;
- };
- }
-
- private static MockLLMClient createLLM() {
- LlmClientConfig config = new LlmClientConfig.Builder().build();
- return createLLM(config, null);
- }
-
- private static MockLLMClient createLLM(ExecutorService executor) {
- LlmClientConfig config = new LlmClientConfig.Builder().build();
- return createLLM(config, executor);
- }
-
- private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) {
- var generator = createGenerator();
- var secretStore = new SecretStoreProvider(); // throws exception on use
- return createLLM(config, generator, secretStore.get(), executor);
- }
-
- private static MockLLMClient createLLM(LlmClientConfig config,
- BiFunction<Prompt, InferenceParameters, String> generator,
- SecretStore secretStore) {
- return createLLM(config, generator, secretStore, null);
- }
-
- private static MockLLMClient createLLM(LlmClientConfig config,
- BiFunction<Prompt, InferenceParameters, String> generator,
- SecretStore secretStore,
- ExecutorService executor) {
- return new MockLLMClient(config, secretStore, generator, executor);
- }
-
-}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
deleted file mode 100644
index 4d0073f1cbe..00000000000
--- a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
+++ /dev/null
@@ -1,80 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import com.yahoo.container.jdisc.secretstore.SecretStore;
-
-import java.util.List;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutorService;
-import java.util.function.BiFunction;
-import java.util.function.Consumer;
-
-public class MockLLMClient extends ConfigurableLanguageModel {
-
- public final static String ACCEPTED_API_KEY = "sesame";
-
- private final ExecutorService executor;
- private final BiFunction<Prompt, InferenceParameters, String> generator;
-
- private Prompt lastPrompt;
-
- public MockLLMClient(LlmClientConfig config,
- SecretStore secretStore,
- BiFunction<Prompt, InferenceParameters, String> generator,
- ExecutorService executor) {
- super(config, secretStore);
- this.generator = generator;
- this.executor = executor;
- }
-
- private void checkApiKey(InferenceParameters options) {
- var apiKey = getApiKey(options);
- if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) {
- throw new IllegalArgumentException("Invalid API key");
- }
- }
-
- private void setPrompt(Prompt prompt) {
- this.lastPrompt = prompt;
- }
-
- public Prompt getPrompt() {
- return this.lastPrompt;
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters params) {
- setApiKey(params);
- checkApiKey(params);
- setPrompt(prompt);
- return List.of(Completion.from(this.generator.apply(prompt, params)));
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
- InferenceParameters params,
- Consumer<Completion> consumer) {
- setPrompt(prompt);
- var completionFuture = new CompletableFuture<Completion.FinishReason>();
- var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization
-
- long sleep = 1;
- executor.submit(() -> {
- try {
- for (int i=0; i < completions.length; ++i) {
- String completion = (i > 0 ? " " : "") + completions[i];
- consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep);
- }
- completionFuture.complete(Completion.FinishReason.stop);
- } catch (InterruptedException e) {
- // Do nothing
- }
- });
-
- return completionFuture;
- }
-
-}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
deleted file mode 100644
index 57339f6ad49..00000000000
--- a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.completion.StringPrompt;
-import com.yahoo.container.jdisc.SecretStoreProvider;
-import org.junit.jupiter.api.Disabled;
-import org.junit.jupiter.api.Test;
-
-import java.util.Map;
-
-public class OpenAITest {
-
- private static final String apiKey = "<your-api-key>";
-
- @Test
- @Disabled
- public void testOpenAIGeneration() {
- var config = new LlmClientConfig.Builder().build();
- var openai = new OpenAI(config, new SecretStoreProvider().get());
- var options = Map.of(
- "maxTokens", "10"
- );
-
- var prompt = StringPrompt.from("why are ducks better than cats?");
- var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> {
- System.out.print(completion.text());
- }).exceptionally(exception -> {
- System.out.println("Error: " + exception);
- return null;
- });
- future.join();
- }
-
-}
diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
index 1efcf1c736a..3baa9715c34 100755
--- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
@@ -3,14 +3,11 @@ package ai.vespa.search.llm;
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
-import ai.vespa.llm.clients.ConfigurableLanguageModelTest;
-import ai.vespa.llm.clients.LlmClientConfig;
-import ai.vespa.llm.clients.MockLLMClient;
+import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.component.ComponentId;
import com.yahoo.component.chain.Chain;
import com.yahoo.component.provider.ComponentRegistry;
-import com.yahoo.container.jdisc.SecretStoreProvider;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
@@ -20,10 +17,14 @@ import org.junit.jupiter.api.Test;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.List;
import java.util.Map;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.BiFunction;
+import java.util.function.Consumer;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
@@ -36,10 +37,10 @@ public class LLMSearcherTest {
@Test
public void testLLMSelection() {
- var llm1 = createLLMClient("mock1");
- var llm2 = createLLMClient("mock2");
+ var client1 = createLLMClient("mock1");
+ var client2 = createLLMClient("mock2");
var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build();
- var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2));
+ var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2));
var result = runMockSearch(searcher, Map.of("prompt", "what is your id?"));
assertEquals(1, result.getHitCount());
assertEquals("My id is mock2", getCompletion(result));
@@ -47,14 +48,16 @@ public class LLMSearcherTest {
@Test
public void testGeneration() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
var params = Map.of("prompt", "why are ducks better than cats");
assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params)));
}
@Test
public void testPrompting() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
// Prompt with prefix
assertEquals("Ducks have adorable waddling walks.",
@@ -71,7 +74,8 @@ public class LLMSearcherTest {
@Test
public void testPromptEvent() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
var params = Map.of(
"prompt", "why are ducks better than cats",
"traceLevel", "1");
@@ -90,7 +94,8 @@ public class LLMSearcherTest {
@Test
public void testParameters() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
var params = Map.of(
"llm.prompt", "why are ducks better than cats",
"llm.temperature", "1.0",
@@ -107,16 +112,18 @@ public class LLMSearcherTest {
"foo.maxTokens", "5"
);
var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build();
- var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(config, client);
assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params)));
}
@Test
public void testApiKeyFromHeader() {
var properties = Map.of("prompt", "why are ducks better than cats");
- var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore()));
- assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm"));
- assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm"));
+ var client = createLLMClient(createApiKeyGenerator("a_valid_key"));
+ var searcher = createLLMSearcher(client);
+ assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key"));
+ assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key"));
}
@Test
@@ -129,7 +136,8 @@ public class LLMSearcherTest {
"llm.stream", "true", // ... but inference parameters says do it anyway
"llm.prompt", "why are ducks better than cats?"
);
- var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor)));
+ var client = createLLMClient(executor);
+ var searcher = createLLMSearcher(config, client);
Result result = runMockSearch(searcher, params);
assertEquals(1, result.getHitCount());
@@ -162,6 +170,10 @@ public class LLMSearcherTest {
return runMockSearch(searcher, parameters, null, "");
}
+ static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey) {
+ return runMockSearch(searcher, parameters, apiKey, "llm");
+ }
+
static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) {
Chain<Searcher> chain = new Chain<>(searcher);
Execution execution = new Execution(chain, Execution.Context.createContextStub());
@@ -191,43 +203,59 @@ public class LLMSearcherTest {
}
private static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
- return ConfigurableLanguageModelTest.createGenerator();
+ return (prompt, options) -> {
+ String answer = "I have no opinion on the matter";
+ if (prompt.asString().contains("ducks")) {
+ answer = "Ducks have adorable waddling walks.";
+ var temperature = options.getDouble("temperature");
+ if (temperature.isPresent() && temperature.get() > 0.5) {
+ answer = "Random text about ducks vs cats that makes no sense whatsoever.";
+ }
+ }
+ var maxTokens = options.getInt("maxTokens");
+ if (maxTokens.isPresent()) {
+ return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
+ }
+ return answer;
+ };
}
- static MockLLMClient createLLMClient() {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
- var generator = createGenerator();
- return new MockLLMClient(config, secretStore, generator, null);
+ private static BiFunction<Prompt, InferenceParameters, String> createApiKeyGenerator(String validApiKey) {
+ return (prompt, options) -> {
+ if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) {
+ throw new IllegalArgumentException("Invalid API key");
+ }
+ return "Ok";
+ };
+ }
+
+ static MockLLM createLLMClient() {
+ return new MockLLM(createGenerator(), null);
}
- static MockLLMClient createLLMClient(String id) {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
- var generator = createIdGenerator(id);
- return new MockLLMClient(config, secretStore, generator, null);
+ static MockLLM createLLMClient(String id) {
+ return new MockLLM(createIdGenerator(id), null);
}
- static MockLLMClient createLLMClient(ExecutorService executor) {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
- var generator = createGenerator();
- return new MockLLMClient(config, secretStore, generator, executor);
+ static MockLLM createLLMClient(BiFunction<Prompt, InferenceParameters, String> generator) {
+ return new MockLLM(generator, null);
}
- static MockLLMClient createLLMClientWithoutSecretStore() {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = new SecretStoreProvider();
- var generator = createGenerator();
- return new MockLLMClient(config, secretStore.get(), generator, null);
+ static MockLLM createLLMClient(ExecutorService executor) {
+ return new MockLLM(createGenerator(), executor);
+ }
+
+ private static Searcher createLLMSearcher(LanguageModel llm) {
+ return createLLMSearcher(Map.of("mock", llm));
}
private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) {
var config = new LlmSearcherConfig.Builder().stream(false).build();
- ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
- llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
- models.freeze();
- return new LLMSearcher(config, models);
+ return createLLMSearcher(config, llms);
+ }
+
+ private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) {
+ return createLLMSearcher(config, Map.of("mock", llm));
}
private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) {
@@ -237,4 +265,44 @@ public class LLMSearcherTest {
return new LLMSearcher(config, models);
}
+ private static class MockLLM implements LanguageModel {
+
+ private final ExecutorService executor;
+ private final BiFunction<Prompt, InferenceParameters, String> generator;
+
+ public MockLLM(BiFunction<Prompt, InferenceParameters, String> generator, ExecutorService executor) {
+ this.executor = executor;
+ this.generator = generator;
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters params) {
+ return List.of(Completion.from(this.generator.apply(prompt, params)));
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters params,
+ Consumer<Completion> consumer) {
+ var completionFuture = new CompletableFuture<Completion.FinishReason>();
+ var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization
+
+ long sleep = 1;
+ executor.submit(() -> {
+ try {
+ for (int i = 0; i < completions.length; ++i) {
+ String completion = (i > 0 ? " " : "") + completions[i];
+ consumer.accept(Completion.from(completion, Completion.FinishReason.none));
+ Thread.sleep(sleep);
+ }
+ completionFuture.complete(Completion.FinishReason.stop);
+ } catch (InterruptedException e) {
+ // Do nothing
+ }
+ });
+ return completionFuture;
+ }
+
+ }
+
}