diff options
Diffstat (limited to 'container-search')
12 files changed, 243 insertions, 587 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index e74fe22c588..07f0449e61a 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -7842,6 +7842,21 @@ "public static final int emptyDocsumsCode" ] }, + "com.yahoo.search.result.EventStream$ErrorEvent" : { + "superClass" : "com.yahoo.search.result.EventStream$Event", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(int, java.lang.String, com.yahoo.search.result.ErrorMessage)", + "public java.lang.String source()", + "public int code()", + "public java.lang.String message()", + "public com.yahoo.search.result.Hit asHit()" + ], + "fields" : [ ] + }, "com.yahoo.search.result.EventStream$Event" : { "superClass" : "com.yahoo.component.provider.ListenableFreezableClass", "interfaces" : [ @@ -9149,99 +9164,6 @@ ], "fields" : [ ] }, - "ai.vespa.llm.clients.ConfigurableLanguageModel" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "ai.vespa.llm.LanguageModel" - ], - "attributes" : [ - "public", - "abstract" - ], - "methods" : [ - "public void <init>()", - "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", - "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)", - "protected void setApiKey(ai.vespa.llm.InferenceParameters)", - "protected java.lang.String getEndpoint()", - "protected void setEndpoint(ai.vespa.llm.InferenceParameters)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmClientConfig$Builder" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "com.yahoo.config.ConfigInstance$Builder" - ], - "attributes" : [ - "public", - "final" - ], - "methods" : [ - "public void <init>()", - "public void <init>(ai.vespa.llm.clients.LlmClientConfig)", - "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)", - "public ai.vespa.llm.clients.LlmClientConfig$Builder 
endpoint(java.lang.String)", - "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", - "public final java.lang.String getDefMd5()", - "public final java.lang.String getDefName()", - "public final java.lang.String getDefNamespace()", - "public final boolean getApplyOnRestart()", - "public final void setApplyOnRestart(boolean)", - "public ai.vespa.llm.clients.LlmClientConfig build()" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmClientConfig$Producer" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "com.yahoo.config.ConfigInstance$Producer" - ], - "attributes" : [ - "public", - "interface", - "abstract" - ], - "methods" : [ - "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmClientConfig" : { - "superClass" : "com.yahoo.config.ConfigInstance", - "interfaces" : [ ], - "attributes" : [ - "public", - "final" - ], - "methods" : [ - "public static java.lang.String getDefMd5()", - "public static java.lang.String getDefName()", - "public static java.lang.String getDefNamespace()", - "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)", - "public java.lang.String apiKeySecretName()", - "public java.lang.String endpoint()" - ], - "fields" : [ - "public static final java.lang.String CONFIG_DEF_MD5", - "public static final java.lang.String CONFIG_DEF_NAME", - "public static final java.lang.String CONFIG_DEF_NAMESPACE", - "public static final java.lang.String[] CONFIG_DEF_SCHEMA" - ] - }, - "ai.vespa.llm.clients.OpenAI" : { - "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel", - "interfaces" : [ ], - "attributes" : [ - "public" - ], - "methods" : [ - "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", - "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public java.util.concurrent.CompletableFuture 
completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" - ], - "fields" : [ ] - }, "ai.vespa.search.llm.LLMSearcher" : { "superClass" : "com.yahoo.search.Searcher", "interfaces" : [ ], diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java deleted file mode 100644 index 761fdf0af93..00000000000 --- a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.LanguageModel; -import com.yahoo.api.annotations.Beta; -import com.yahoo.component.annotation.Inject; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.logging.Logger; - - -/** - * Base class for language models that can be configured with config definitions. - * - * @author lesters - */ -@Beta -public abstract class ConfigurableLanguageModel implements LanguageModel { - - private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); - - private final String apiKey; - private final String endpoint; - - public ConfigurableLanguageModel() { - this.apiKey = null; - this.endpoint = null; - } - - @Inject - public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) { - this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore); - this.endpoint = config.endpoint(); - } - - private static String findApiKeyInSecretStore(String property, SecretStore secretStore) { - String apiKey = ""; - if (property != null && ! 
property.isEmpty()) { - try { - apiKey = secretStore.getSecret(property); - } catch (UnsupportedOperationException e) { - // Secret store is not available - silently ignore this - } catch (Exception e) { - log.warning("Secret store look up failed: " + e.getMessage() + "\n" + - "Will expect API key in request header"); - } - } - return apiKey; - } - - protected String getApiKey(InferenceParameters params) { - return params.getApiKey().orElse(null); - } - - /** - * Set the API key as retrieved from secret store if it is not already set - */ - protected void setApiKey(InferenceParameters params) { - if (params.getApiKey().isEmpty() && apiKey != null) { - params.setApiKey(apiKey); - } - } - - protected String getEndpoint() { - return endpoint; - } - - protected void setEndpoint(InferenceParameters params) { - if (endpoint != null && ! endpoint.isEmpty()) { - params.setEndpoint(endpoint); - } - } - -} diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java deleted file mode 100644 index 82e19d47c92..00000000000 --- a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.client.openai.OpenAiClient; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.api.annotations.Beta; -import com.yahoo.component.annotation.Inject; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; - -/** - * A configurable OpenAI client. 
- * - * @author lesters - */ -@Beta -public class OpenAI extends ConfigurableLanguageModel { - - private final OpenAiClient client; - - @Inject - public OpenAI(LlmClientConfig config, SecretStore secretStore) { - super(config, secretStore); - client = new OpenAiClient(); - } - - @Override - public List<Completion> complete(Prompt prompt, InferenceParameters parameters) { - setApiKey(parameters); - setEndpoint(parameters); - return client.complete(prompt, parameters); - } - - @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, - InferenceParameters parameters, - Consumer<Completion> consumer) { - setApiKey(parameters); - setEndpoint(parameters); - return client.completeAsync(prompt, parameters, consumer); - } -} - diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java deleted file mode 100644 index c360245901c..00000000000 --- a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-@ExportPackage -@PublicApi -package ai.vespa.llm.clients; - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java index 860fc69af91..f565315b775 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java @@ -20,6 +20,7 @@ import com.yahoo.search.result.HitGroup; import com.yahoo.search.searchchain.Execution; import java.util.List; +import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -83,27 +84,41 @@ public class LLMSearcher extends Searcher { protected Result complete(Query query, Prompt prompt) { var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config - return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); + try { + return stream ? 
completeAsync(query, prompt, options) : completeSync(query, prompt, options); + } catch (RejectedExecutionException e) { + return new Result(query, new ErrorMessage(429, e.getMessage())); + } + } + + private boolean shouldAddPrompt(Query query) { + return query.getTrace().getLevel() >= 1; + } + + private boolean shouldAddTokenStats(Query query) { + return query.getTrace().getLevel() >= 1; } private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { - EventStream eventStream = new EventStream(); + final EventStream eventStream = new EventStream(); - if (query.getTrace().getLevel() >= 1) { + if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } - languageModel.completeAsync(prompt, options, token -> { - eventStream.add(token.text()); + final TokenStats tokenStats = new TokenStats(); + languageModel.completeAsync(prompt, options, completion -> { + tokenStats.onToken(); + handleCompletion(eventStream, completion); }).exceptionally(exception -> { - int errorCode = 400; - if (exception instanceof LanguageModelException languageModelException) { - errorCode = languageModelException.code(); - } - eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); + handleException(eventStream, exception); eventStream.markComplete(); return Completion.FinishReason.error; }).thenAccept(finishReason -> { + tokenStats.onCompletion(); + if (shouldAddTokenStats(query)) { + eventStream.add(tokenStats.report(), "stats"); + } eventStream.markComplete(); }); @@ -112,10 +127,26 @@ public class LLMSearcher extends Searcher { return new Result(query, hitGroup); } + private void handleCompletion(EventStream eventStream, Completion completion) { + if (completion.finishReason() == Completion.FinishReason.error) { + eventStream.add(completion.text(), "error"); + } else { + eventStream.add(completion.text()); + } + } + + private void handleException(EventStream eventStream, Throwable exception) { + int errorCode = 
400; + if (exception instanceof LanguageModelException languageModelException) { + errorCode = languageModelException.code(); + } + eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); + } + private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { EventStream eventStream = new EventStream(); - if (query.getTrace().getLevel() >= 1) { + if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } @@ -169,4 +200,35 @@ public class LLMSearcher extends Searcher { return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); } + private static class TokenStats { + + private long start; + private long timeToFirstToken; + private long timeToLastToken; + private long tokens = 0; + + TokenStats() { + start = System.currentTimeMillis(); + } + + void onToken() { + if (tokens == 0) { + timeToFirstToken = System.currentTimeMillis() - start; + } + tokens++; + } + + void onCompletion() { + timeToLastToken = System.currentTimeMillis() - start; + } + + String report() { + return "Time to first token: " + timeToFirstToken + " ms, " + + "Generation time: " + timeToLastToken + " ms, " + + "Generated tokens: " + tokens + " " + + String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0)); + } + + } + } diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java index 83ae349f5a0..88a1e6c1485 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -64,7 +64,17 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { @Override public void data(Data data) throws IOException { - if (data instanceof EventStream.Event event) { + if (data instanceof EventStream.ErrorEvent error) { + generator.writeRaw("event: 
error\n"); + generator.writeRaw("data: "); + generator.writeStartObject(); + generator.writeStringField("source", error.source()); + generator.writeNumberField("error", error.code()); + generator.writeStringField("message", error.message()); + generator.writeEndObject(); + generator.writeRaw("\n\n"); + generator.flush(); + } else if (data instanceof EventStream.Event event) { if (RENDER_EVENT_HEADER) { generator.writeRaw("event: " + event.type() + "\n"); } @@ -75,19 +85,6 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { generator.writeRaw("\n\n"); generator.flush(); } - else if (data instanceof ErrorHit) { - for (ErrorMessage error : ((ErrorHit) data).errors()) { - generator.writeRaw("event: error\n"); - generator.writeRaw("data: "); - generator.writeStartObject(); - generator.writeStringField("source", error.getSource()); - generator.writeNumberField("error", error.getCode()); - generator.writeStringField("message", error.getMessage()); - generator.writeEndObject(); - generator.writeRaw("\n\n"); - generator.flush(); - } - } // Todo: support other types of data such as search results (hits), timing and trace } diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java index b393a91e6d0..8e6f7977d55 100644 --- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java +++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java @@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList<Data> { } public void error(String source, ErrorMessage message) { - incoming().add(new DefaultErrorHit(source, message)); + incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message)); } public void markComplete() { @@ -117,4 +117,38 @@ public class EventStream extends Hit implements DataList<Data> { } + public static class ErrorEvent extends Event { + + private final String source; + private 
final ErrorMessage message; + + public ErrorEvent(int eventNumber, String source, ErrorMessage message) { + super(eventNumber, message.getMessage(), "error"); + this.source = source; + this.message = message; + } + + public String source() { + return source; + } + + public int code() { + return message.getCode(); + } + + public String message() { + return message.getMessage(); + } + + @Override + public Hit asHit() { + Hit hit = super.asHit(); + hit.setField("source", source); + hit.setField("code", message.getCode()); + return hit; + } + + + } + } diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def deleted file mode 100755 index 0866459166a..00000000000 --- a/container-search/src/main/resources/configdefinitions/llm-client.def +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package=ai.vespa.llm.clients - -# The name of the secret containing the api key -apiKeySecretName string default="" - -# Endpoint for LLM client - if not set reverts to default for client -endpoint string default="" diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java deleted file mode 100644 index 35d5cfd3855..00000000000 --- a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import ai.vespa.llm.completion.StringPrompt; -import com.yahoo.container.di.componentgraph.Provider; -import com.yahoo.container.jdisc.SecretStoreProvider; -import com.yahoo.container.jdisc.secretstore.SecretStore; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.Map; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.function.BiFunction; -import java.util.stream.Collectors; - -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -public class ConfigurableLanguageModelTest { - - @Test - public void testSyncGeneration() { - var prompt = StringPrompt.from("Why are ducks better than cats?"); - var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey()); - assertEquals(1, result.size()); - assertEquals("Ducks have adorable waddling walks.", result.get(0).text()); - } - - @Test - public void testAsyncGeneration() { - var executor = Executors.newFixedThreadPool(1); - var prompt = StringPrompt.from("Why are ducks better than cats?"); - var sb = new StringBuilder(); - try { - var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> { - sb.append(completion.text()); - }).exceptionally(exception -> Completion.FinishReason.error); - - var reason = future.join(); - assertTrue(future.isDone()); - assertNotEquals(reason, Completion.FinishReason.error); - } finally { - executor.shutdownNow(); - } - - assertEquals("Ducks have adorable waddling walks.", sb.toString()); - } - - @Test - public void testInferenceParameters() 
{ - var prompt = StringPrompt.from("Why are ducks better than cats?"); - var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4")); - var result = createLLM().complete(prompt, params); - assertEquals("Random text about ducks", result.get(0).text()); - } - - @Test - public void testNoApiKey() { - var prompt = StringPrompt.from(""); - var config = modelParams("api-key", null); - var secrets = createSecretStore(Map.of()); - assertThrows(IllegalArgumentException.class, () -> { - createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); - }); - } - - @Test - public void testApiKeyFromSecretStore() { - var prompt = StringPrompt.from(""); - var config = modelParams("api-key-in-secret-store", null); - var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY)); - assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); }); - } - - private static String lookupParameter(String parameter, Map<String, String> params) { - return params.get(parameter); - } - - private static InferenceParameters inferenceParams() { - return new InferenceParameters(s -> lookupParameter(s, Map.of())); - } - - private static InferenceParameters inferenceParams(Map<String, String> params) { - return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params)); - } - - private static InferenceParameters inferenceParamsWithDefaultKey() { - return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Map.of())); - } - - private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) { - var config = new LlmClientConfig.Builder(); - if (apiKeySecretName != null) { - config.apiKeySecretName(apiKeySecretName); - } - if (endpoint != null) { - config.endpoint(endpoint); - } - return config.build(); - } - - public static SecretStore createSecretStore(Map<String, String> secrets) { - Provider<SecretStore> 
secretStore = new Provider<>() { - public SecretStore get() { - return new SecretStore() { - public String getSecret(String key) { - return secrets.get(key); - } - public String getSecret(String key, int version) { - return secrets.get(key); - } - }; - } - public void deconstruct() { - } - }; - return secretStore.get(); - } - - public static BiFunction<Prompt, InferenceParameters, String> createGenerator() { - return (prompt, options) -> { - String answer = "I have no opinion on the matter"; - if (prompt.asString().contains("ducks")) { - answer = "Ducks have adorable waddling walks."; - var temperature = options.getDouble("temperature"); - if (temperature.isPresent() && temperature.get() > 0.5) { - answer = "Random text about ducks vs cats that makes no sense whatsoever."; - } - } - var maxTokens = options.getInt("maxTokens"); - if (maxTokens.isPresent()) { - return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); - } - return answer; - }; - } - - private static MockLLMClient createLLM() { - LlmClientConfig config = new LlmClientConfig.Builder().build(); - return createLLM(config, null); - } - - private static MockLLMClient createLLM(ExecutorService executor) { - LlmClientConfig config = new LlmClientConfig.Builder().build(); - return createLLM(config, executor); - } - - private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) { - var generator = createGenerator(); - var secretStore = new SecretStoreProvider(); // throws exception on use - return createLLM(config, generator, secretStore.get(), executor); - } - - private static MockLLMClient createLLM(LlmClientConfig config, - BiFunction<Prompt, InferenceParameters, String> generator, - SecretStore secretStore) { - return createLLM(config, generator, secretStore, null); - } - - private static MockLLMClient createLLM(LlmClientConfig config, - BiFunction<Prompt, InferenceParameters, String> generator, - SecretStore secretStore, - ExecutorService 
executor) { - return new MockLLMClient(config, secretStore, generator, executor); - } - -} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java deleted file mode 100644 index 4d0073f1cbe..00000000000 --- a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; -import java.util.function.BiFunction; -import java.util.function.Consumer; - -public class MockLLMClient extends ConfigurableLanguageModel { - - public final static String ACCEPTED_API_KEY = "sesame"; - - private final ExecutorService executor; - private final BiFunction<Prompt, InferenceParameters, String> generator; - - private Prompt lastPrompt; - - public MockLLMClient(LlmClientConfig config, - SecretStore secretStore, - BiFunction<Prompt, InferenceParameters, String> generator, - ExecutorService executor) { - super(config, secretStore); - this.generator = generator; - this.executor = executor; - } - - private void checkApiKey(InferenceParameters options) { - var apiKey = getApiKey(options); - if (apiKey == null || ! 
apiKey.equals(ACCEPTED_API_KEY)) { - throw new IllegalArgumentException("Invalid API key"); - } - } - - private void setPrompt(Prompt prompt) { - this.lastPrompt = prompt; - } - - public Prompt getPrompt() { - return this.lastPrompt; - } - - @Override - public List<Completion> complete(Prompt prompt, InferenceParameters params) { - setApiKey(params); - checkApiKey(params); - setPrompt(prompt); - return List.of(Completion.from(this.generator.apply(prompt, params))); - } - - @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, - InferenceParameters params, - Consumer<Completion> consumer) { - setPrompt(prompt); - var completionFuture = new CompletableFuture<Completion.FinishReason>(); - var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization - - long sleep = 1; - executor.submit(() -> { - try { - for (int i=0; i < completions.length; ++i) { - String completion = (i > 0 ? " " : "") + completions[i]; - consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep); - } - completionFuture.complete(Completion.FinishReason.stop); - } catch (InterruptedException e) { - // Do nothing - } - }); - - return completionFuture; - } - -} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java deleted file mode 100644 index 57339f6ad49..00000000000 --- a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.completion.StringPrompt; -import com.yahoo.container.jdisc.SecretStoreProvider; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -import java.util.Map; - -public class OpenAITest { - - private static final String apiKey = "<your-api-key>"; - - @Test - @Disabled - public void testOpenAIGeneration() { - var config = new LlmClientConfig.Builder().build(); - var openai = new OpenAI(config, new SecretStoreProvider().get()); - var options = Map.of( - "maxTokens", "10" - ); - - var prompt = StringPrompt.from("why are ducks better than cats?"); - var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> { - System.out.print(completion.text()); - }).exceptionally(exception -> { - System.out.println("Error: " + exception); - return null; - }); - future.join(); - } - -} diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java index 1efcf1c736a..3baa9715c34 100755 --- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java +++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java @@ -3,14 +3,11 @@ package ai.vespa.search.llm; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.clients.ConfigurableLanguageModelTest; -import ai.vespa.llm.clients.LlmClientConfig; -import ai.vespa.llm.clients.MockLLMClient; +import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; import com.yahoo.component.ComponentId; import com.yahoo.component.chain.Chain; import com.yahoo.component.provider.ComponentRegistry; -import com.yahoo.container.jdisc.SecretStoreProvider; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; @@ -20,10 +17,14 @@ import org.junit.jupiter.api.Test; import 
java.net.URLEncoder; import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -36,10 +37,10 @@ public class LLMSearcherTest { @Test public void testLLMSelection() { - var llm1 = createLLMClient("mock1"); - var llm2 = createLLMClient("mock2"); + var client1 = createLLMClient("mock1"); + var client2 = createLLMClient("mock2"); var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build(); - var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2)); + var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2)); var result = runMockSearch(searcher, Map.of("prompt", "what is your id?")); assertEquals(1, result.getHitCount()); assertEquals("My id is mock2", getCompletion(result)); @@ -47,14 +48,16 @@ public class LLMSearcherTest { @Test public void testGeneration() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); var params = Map.of("prompt", "why are ducks better than cats"); assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params))); } @Test public void testPrompting() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); // Prompt with prefix assertEquals("Ducks have adorable waddling walks.", @@ -71,7 +74,8 @@ public class LLMSearcherTest { @Test public void testPromptEvent() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = 
createLLMSearcher(client); var params = Map.of( "prompt", "why are ducks better than cats", "traceLevel", "1"); @@ -90,7 +94,8 @@ public class LLMSearcherTest { @Test public void testParameters() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); var params = Map.of( "llm.prompt", "why are ducks better than cats", "llm.temperature", "1.0", @@ -107,16 +112,18 @@ public class LLMSearcherTest { "foo.maxTokens", "5" ); var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build(); - var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(config, client); assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params))); } @Test public void testApiKeyFromHeader() { var properties = Map.of("prompt", "why are ducks better than cats"); - var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore())); - assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm")); - assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm")); + var client = createLLMClient(createApiKeyGenerator("a_valid_key")); + var searcher = createLLMSearcher(client); + assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key")); + assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key")); } @Test @@ -129,7 +136,8 @@ public class LLMSearcherTest { "llm.stream", "true", // ... but inference parameters says do it anyway "llm.prompt", "why are ducks better than cats?" 
); - var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor))); + var client = createLLMClient(executor); + var searcher = createLLMSearcher(config, client); Result result = runMockSearch(searcher, params); assertEquals(1, result.getHitCount()); @@ -162,6 +170,10 @@ public class LLMSearcherTest { return runMockSearch(searcher, parameters, null, ""); } + static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey) { + return runMockSearch(searcher, parameters, apiKey, "llm"); + } + static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) { Chain<Searcher> chain = new Chain<>(searcher); Execution execution = new Execution(chain, Execution.Context.createContextStub()); @@ -191,43 +203,59 @@ public class LLMSearcherTest { } private static BiFunction<Prompt, InferenceParameters, String> createGenerator() { - return ConfigurableLanguageModelTest.createGenerator(); + return (prompt, options) -> { + String answer = "I have no opinion on the matter"; + if (prompt.asString().contains("ducks")) { + answer = "Ducks have adorable waddling walks."; + var temperature = options.getDouble("temperature"); + if (temperature.isPresent() && temperature.get() > 0.5) { + answer = "Random text about ducks vs cats that makes no sense whatsoever."; + } + } + var maxTokens = options.getInt("maxTokens"); + if (maxTokens.isPresent()) { + return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); + } + return answer; + }; } - static MockLLMClient createLLMClient() { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore, generator, null); + private static BiFunction<Prompt, InferenceParameters, String> 
createApiKeyGenerator(String validApiKey) { + return (prompt, options) -> { + if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) { + throw new IllegalArgumentException("Invalid API key"); + } + return "Ok"; + }; + } + + static MockLLM createLLMClient() { + return new MockLLM(createGenerator(), null); } - static MockLLMClient createLLMClient(String id) { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createIdGenerator(id); - return new MockLLMClient(config, secretStore, generator, null); + static MockLLM createLLMClient(String id) { + return new MockLLM(createIdGenerator(id), null); } - static MockLLMClient createLLMClient(ExecutorService executor) { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore, generator, executor); + static MockLLM createLLMClient(BiFunction<Prompt, InferenceParameters, String> generator) { + return new MockLLM(generator, null); } - static MockLLMClient createLLMClientWithoutSecretStore() { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = new SecretStoreProvider(); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore.get(), generator, null); + static MockLLM createLLMClient(ExecutorService executor) { + return new MockLLM(createGenerator(), executor); + } + + private static Searcher createLLMSearcher(LanguageModel llm) { + return createLLMSearcher(Map.of("mock", llm)); } private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) { var config = new LlmSearcherConfig.Builder().stream(false).build(); - 
ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); - llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); - models.freeze(); - return new LLMSearcher(config, models); + return createLLMSearcher(config, llms); + } + + private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) { + return createLLMSearcher(config, Map.of("mock", llm)); } private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) { @@ -237,4 +265,44 @@ public class LLMSearcherTest { return new LLMSearcher(config, models); } + private static class MockLLM implements LanguageModel { + + private final ExecutorService executor; + private final BiFunction<Prompt, InferenceParameters, String> generator; + + public MockLLM(BiFunction<Prompt, InferenceParameters, String> generator, ExecutorService executor) { + this.executor = executor; + this.generator = generator; + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters params) { + return List.of(Completion.from(this.generator.apply(prompt, params))); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters params, + Consumer<Completion> consumer) { + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization + + long sleep = 1; + executor.submit(() -> { + try { + for (int i = 0; i < completions.length; ++i) { + String completion = (i > 0 ? " " : "") + completions[i]; + consumer.accept(Completion.from(completion, Completion.FinishReason.none)); + Thread.sleep(sleep); + } + completionFuture.complete(Completion.FinishReason.stop); + } catch (InterruptedException e) { + // Do nothing + } + }); + return completionFuture; + } + + } + }