From 7518d93961ac7c5c5da1cd41717d42f600dae647 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 15 Apr 2024 14:51:27 +0200 Subject: Revert "Lesters/add local llms" --- application/pom.xml | 10 -- cloud-tenant-base-dependencies-enforcer/pom.xml | 3 +- .../model/container/ContainerModelEvaluation.java | 1 - .../vespa/model/container/PlatformBundles.java | 7 +- container-dependencies-enforcer/pom.xml | 3 +- .../main/java/de/kherud/llama/package-info.java | 8 - container-search-and-docproc/pom.xml | 22 --- container-search/abi-spec.json | 108 ++++++++++-- .../llm/clients/ConfigurableLanguageModel.java | 75 +++++++++ .../src/main/java/ai/vespa/llm/clients/OpenAI.java | 48 ++++++ .../java/ai/vespa/llm/clients/package-info.java | 7 + .../main/java/ai/vespa/search/llm/LLMSearcher.java | 84 ++-------- .../com/yahoo/search/rendering/EventRenderer.java | 25 +-- .../java/com/yahoo/search/result/EventStream.java | 36 +--- .../resources/configdefinitions/llm-client.def | 8 + .../llm/clients/ConfigurableLanguageModelTest.java | 174 ++++++++++++++++++++ .../java/ai/vespa/llm/clients/MockLLMClient.java | 80 +++++++++ .../test/java/ai/vespa/llm/clients/OpenAITest.java | 35 ++++ .../java/ai/vespa/search/llm/LLMSearcherTest.java | 150 +++++------------ container-test/pom.xml | 10 -- fat-model-dependencies/pom.xml | 14 -- model-integration/abi-spec.json | 182 --------------------- model-integration/pom.xml | 12 -- .../llm/clients/ConfigurableLanguageModel.java | 75 --------- .../main/java/ai/vespa/llm/clients/LocalLLM.java | 126 -------------- .../src/main/java/ai/vespa/llm/clients/OpenAI.java | 48 ------ .../java/ai/vespa/llm/clients/package-info.java | 7 - .../resources/configdefinitions/llm-client.def | 8 - .../configdefinitions/llm-local-client.def | 29 ---- .../llm/clients/ConfigurableLanguageModelTest.java | 174 -------------------- .../java/ai/vespa/llm/clients/LocalLLMTest.java | 181 -------------------- .../java/ai/vespa/llm/clients/MockLLMClient.java | 80 --------- .../test/java/ai/vespa/llm/clients/OpenAITest.java | 35 ---- model-integration/src/test/models/llm/tinyllm.gguf | Bin 1185376 -> 0 bytes standalone-container/pom.xml | 1 - 35 files changed, 591 insertions(+), 1275 deletions(-) delete mode 100644 container-llama/src/main/java/de/kherud/llama/package-info.java create mode 100644 container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java create mode 100644 container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java create mode 100644 container-search/src/main/java/ai/vespa/llm/clients/package-info.java create mode 100755 container-search/src/main/resources/configdefinitions/llm-client.def create mode 100644 container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java create mode 100644 container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java create mode 100644 container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java delete mode 100644 model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java delete mode 100644 model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java delete mode 100644 model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java delete mode 100644 model-integration/src/main/java/ai/vespa/llm/clients/package-info.java delete mode 100755 model-integration/src/main/resources/configdefinitions/llm-client.def delete mode 100755 model-integration/src/main/resources/configdefinitions/llm-local-client.def delete mode 100644 model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java delete mode 100644 model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java delete mode 100644 model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java delete mode 100644 model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java delete mode 100644 model-integration/src/test/models/llm/tinyllm.gguf diff --git a/application/pom.xml b/application/pom.xml index 7c34cd966eb..f5704541308 100644 --- a/application/pom.xml +++ b/application/pom.xml @@ -43,16 +43,6 @@ com.yahoo.vespa model-integration ${project.version} - - - com.google.protobuf - protobuf-java - - - org.lz4 - lz4-java - - com.yahoo.vespa diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml index 6d6ac4a81e7..eff4e4125e9 100644 --- a/cloud-tenant-base-dependencies-enforcer/pom.xml +++ b/cloud-tenant-base-dependencies-enforcer/pom.xml @@ -97,7 +97,6 @@ com.yahoo.vespa:messagebus:*:provided com.yahoo.vespa:metrics:*:provided com.yahoo.vespa:model-evaluation:*:provided - com.yahoo.vespa:model-integration:*:provided com.yahoo.vespa:opennlp-linguistics:*:provided com.yahoo.vespa:predicate-search-core:*:provided com.yahoo.vespa:provided-dependencies:*:provided @@ -123,6 +122,7 @@ com.yahoo.vespa:indexinglanguage:*:test com.yahoo.vespa:logd:*:test com.yahoo.vespa:metrics-proxy:*:test + com.yahoo.vespa:model-integration:*:test com.yahoo.vespa:searchsummary:*:test com.yahoo.vespa:standalone-container:*:test com.yahoo.vespa:storage:*:test @@ -141,7 +141,6 @@ com.microsoft.onnxruntime:onnxruntime:jar:${onnxruntime.vespa.version}:test com.thaiopensource:jing:20091111:test commons-codec:commons-codec:${commons-codec.vespa.version}:test - de.kherud:llama:${kherud.llama.vespa.version}:test io.airlift:aircompressor:${aircompressor.vespa.version}:test io.airlift:airline:${airline.vespa.version}:test io.prometheus:simpleclient:${prometheus.client.vespa.version}:test diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java index 5be1690f0dc..62979404025 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java @@ -31,7 +31,6 @@ public class ContainerModelEvaluation implements public final static String EVALUATION_BUNDLE_NAME = "model-evaluation"; public final static String INTEGRATION_BUNDLE_NAME = "model-integration"; public final static String ONNXRUNTIME_BUNDLE_NAME = "container-onnxruntime.jar"; - public final static String LLAMA_BUNDLE_NAME = "container-llama.jar"; public final static String ONNX_RUNTIME_CLASS = "ai.vespa.modelintegration.evaluator.OnnxRuntime"; private final static String EVALUATOR_NAME = ModelsEvaluator.class.getName(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java index e801884a73a..9f91f6bf5e1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java @@ -14,7 +14,6 @@ import java.util.stream.Stream; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.EVALUATION_BUNDLE_NAME; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME; -import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LLAMA_BUNDLE_NAME; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.ONNXRUNTIME_BUNDLE_NAME; /** @@ -64,8 +63,7 @@ public class PlatformBundles { "lucene-linguistics", EVALUATION_BUNDLE_NAME, INTEGRATION_BUNDLE_NAME, - ONNXRUNTIME_BUNDLE_NAME, - LLAMA_BUNDLE_NAME + ONNXRUNTIME_BUNDLE_NAME ); private static Set toBundlePaths(String... bundleNames) { @@ -150,8 +148,7 @@ public class PlatformBundles { com.yahoo.vespa.streamingvisitors.StreamingBackend.class.getName(), ai.vespa.search.llm.LLMSearcher.class.getName(), ai.vespa.search.llm.RAGSearcher.class.getName(), - ai.vespa.llm.clients.OpenAI.class.getName(), - ai.vespa.llm.clients.LocalLLM.class.getName() + ai.vespa.llm.clients.OpenAI.class.getName() ); } diff --git a/container-dependencies-enforcer/pom.xml b/container-dependencies-enforcer/pom.xml index 6451e32941c..a06365abbeb 100644 --- a/container-dependencies-enforcer/pom.xml +++ b/container-dependencies-enforcer/pom.xml @@ -116,7 +116,6 @@ com.yahoo.vespa:messagebus:*:provided com.yahoo.vespa:metrics:*:provided com.yahoo.vespa:model-evaluation:*:provided - com.yahoo.vespa:model-integration:*:provided com.yahoo.vespa:opennlp-linguistics:*:provided com.yahoo.vespa:predicate-search-core:*:provided com.yahoo.vespa:provided-dependencies:*:provided @@ -140,6 +139,7 @@ com.yahoo.vespa:indexinglanguage:*:test com.yahoo.vespa:logd:*:test com.yahoo.vespa:metrics-proxy:*:test + com.yahoo.vespa:model-integration:*:test com.yahoo.vespa:searchsummary:*:test com.yahoo.vespa:standalone-container:*:test com.yahoo.vespa:storage:*:test @@ -152,7 +152,6 @@ com.google.protobuf:protobuf-java:${protobuf.vespa.version}:test com.ibm.icu:icu4j:${icu4j.vespa.version}:test com.microsoft.onnxruntime:onnxruntime:${onnxruntime.vespa.version}:test - de.kherud:llama:${kherud.llama.vespa.version}:test com.thaiopensource:jing:20091111:test commons-codec:commons-codec:${commons-codec.vespa.version}:test io.airlift:aircompressor:${aircompressor.vespa.version}:test diff --git a/container-llama/src/main/java/de/kherud/llama/package-info.java b/container-llama/src/main/java/de/kherud/llama/package-info.java deleted file mode 100644 index 3c9773762b4..00000000000 --- a/container-llama/src/main/java/de/kherud/llama/package-info.java +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -/** - * @author lesters - */ -@ExportPackage -package de.kherud.llama; - -import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search-and-docproc/pom.xml b/container-search-and-docproc/pom.xml index ee74b86dbec..9554b517586 100644 --- a/container-search-and-docproc/pom.xml +++ b/container-search-and-docproc/pom.xml @@ -96,22 +96,6 @@ - - com.yahoo.vespa - model-integration - ${project.version} - compile - - - com.google.protobuf - protobuf-java - - - org.lz4 - lz4-java - - - @@ -226,12 +210,6 @@ ${project.version} provided - - com.yahoo.vespa - container-llama - ${project.version} - provided - diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 07f0449e61a..e74fe22c588 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -7842,21 +7842,6 @@ "public static final int emptyDocsumsCode" ] }, - "com.yahoo.search.result.EventStream$ErrorEvent" : { - "superClass" : "com.yahoo.search.result.EventStream$Event", - "interfaces" : [ ], - "attributes" : [ - "public" - ], - "methods" : [ - "public void (int, java.lang.String, com.yahoo.search.result.ErrorMessage)", - "public java.lang.String source()", - "public int code()", - "public java.lang.String message()", - "public com.yahoo.search.result.Hit asHit()" - ], - "fields" : [ ] - }, "com.yahoo.search.result.EventStream$Event" : { "superClass" : "com.yahoo.component.provider.ListenableFreezableClass", "interfaces" : [ @@ -9164,6 +9149,99 @@ ], "fields" : [ ] }, + "ai.vespa.llm.clients.ConfigurableLanguageModel" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "ai.vespa.llm.LanguageModel" + ], + "attributes" : [ + "public", + "abstract" + ], + "methods" : [ + "public void ()", + "public void (ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", + "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)", + "protected void setApiKey(ai.vespa.llm.InferenceParameters)", + "protected java.lang.String getEndpoint()", + "protected void setEndpoint(ai.vespa.llm.InferenceParameters)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig$Builder" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Builder" + ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void ()", + "public void (ai.vespa.llm.clients.LlmClientConfig)", + "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)", + "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)", + "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", + "public final java.lang.String getDefMd5()", + "public final java.lang.String getDefName()", + "public final java.lang.String getDefNamespace()", + "public final boolean getApplyOnRestart()", + "public final void setApplyOnRestart(boolean)", + "public ai.vespa.llm.clients.LlmClientConfig build()" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig$Producer" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Producer" + ], + "attributes" : [ + "public", + "interface", + "abstract" + ], + "methods" : [ + "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig" : { + "superClass" : "com.yahoo.config.ConfigInstance", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public static java.lang.String getDefMd5()", + "public static java.lang.String getDefName()", + "public static java.lang.String getDefNamespace()", + "public void (ai.vespa.llm.clients.LlmClientConfig$Builder)", + "public java.lang.String apiKeySecretName()", + "public java.lang.String endpoint()" + ], + "fields" : [ + "public static final java.lang.String CONFIG_DEF_MD5", + "public static final java.lang.String CONFIG_DEF_NAME", + "public static final java.lang.String CONFIG_DEF_NAMESPACE", + "public static final java.lang.String[] CONFIG_DEF_SCHEMA" + ] + }, + "ai.vespa.llm.clients.OpenAI" : { + "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void (ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", + "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + ], + "fields" : [ ] + }, "ai.vespa.search.llm.LLMSearcher" : { "superClass" : "com.yahoo.search.Searcher", "interfaces" : [ ], diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java new file mode 100644 index 00000000000..761fdf0af93 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java @@ -0,0 +1,75 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.logging.Logger; + + +/** + * Base class for language models that can be configured with config definitions. + * + * @author lesters + */ +@Beta +public abstract class ConfigurableLanguageModel implements LanguageModel { + + private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); + + private final String apiKey; + private final String endpoint; + + public ConfigurableLanguageModel() { + this.apiKey = null; + this.endpoint = null; + } + + @Inject + public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) { + this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore); + this.endpoint = config.endpoint(); + } + + private static String findApiKeyInSecretStore(String property, SecretStore secretStore) { + String apiKey = ""; + if (property != null && ! property.isEmpty()) { + try { + apiKey = secretStore.getSecret(property); + } catch (UnsupportedOperationException e) { + // Secret store is not available - silently ignore this + } catch (Exception e) { + log.warning("Secret store look up failed: " + e.getMessage() + "\n" + + "Will expect API key in request header"); + } + } + return apiKey; + } + + protected String getApiKey(InferenceParameters params) { + return params.getApiKey().orElse(null); + } + + /** + * Set the API key as retrieved from secret store if it is not already set + */ + protected void setApiKey(InferenceParameters params) { + if (params.getApiKey().isEmpty() && apiKey != null) { + params.setApiKey(apiKey); + } + } + + protected String getEndpoint() { + return endpoint; + } + + protected void setEndpoint(InferenceParameters params) { + if (endpoint != null && ! endpoint.isEmpty()) { + params.setEndpoint(endpoint); + } + } + +} diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java new file mode 100644 index 00000000000..82e19d47c92 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -0,0 +1,48 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.client.openai.OpenAiClient; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * A configurable OpenAI client. + * + * @author lesters + */ +@Beta +public class OpenAI extends ConfigurableLanguageModel { + + private final OpenAiClient client; + + @Inject + public OpenAI(LlmClientConfig config, SecretStore secretStore) { + super(config, secretStore); + client = new OpenAiClient(); + } + + @Override + public List complete(Prompt prompt, InferenceParameters parameters) { + setApiKey(parameters); + setEndpoint(parameters); + return client.complete(prompt, parameters); + } + + @Override + public CompletableFuture completeAsync(Prompt prompt, + InferenceParameters parameters, + Consumer consumer) { + setApiKey(parameters); + setEndpoint(parameters); + return client.completeAsync(prompt, parameters, consumer); + } +} + diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java new file mode 100644 index 00000000000..c360245901c --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java @@ -0,0 +1,7 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.llm.clients; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java index f565315b775..860fc69af91 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java @@ -20,7 +20,6 @@ import com.yahoo.search.result.HitGroup; import com.yahoo.search.searchchain.Execution; import java.util.List; -import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -84,41 +83,27 @@ public class LLMSearcher extends Searcher { protected Result complete(Query query, Prompt prompt) { var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config - try { - return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); - } catch (RejectedExecutionException e) { - return new Result(query, new ErrorMessage(429, e.getMessage())); - } - } - - private boolean shouldAddPrompt(Query query) { - return query.getTrace().getLevel() >= 1; - } - - private boolean shouldAddTokenStats(Query query) { - return query.getTrace().getLevel() >= 1; + return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); } private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { - final EventStream eventStream = new EventStream(); + EventStream eventStream = new EventStream(); - if (shouldAddPrompt(query)) { + if (query.getTrace().getLevel() >= 1) { eventStream.add(prompt.asString(), "prompt"); } - final TokenStats tokenStats = new TokenStats(); - languageModel.completeAsync(prompt, options, completion -> { - tokenStats.onToken(); - handleCompletion(eventStream, completion); + languageModel.completeAsync(prompt, options, token -> { + eventStream.add(token.text()); }).exceptionally(exception -> { - handleException(eventStream, exception); + int errorCode = 400; + if (exception instanceof LanguageModelException languageModelException) { + errorCode = languageModelException.code(); + } + eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); eventStream.markComplete(); return Completion.FinishReason.error; }).thenAccept(finishReason -> { - tokenStats.onCompletion(); - if (shouldAddTokenStats(query)) { - eventStream.add(tokenStats.report(), "stats"); - } eventStream.markComplete(); }); @@ -127,26 +112,10 @@ public class LLMSearcher extends Searcher { return new Result(query, hitGroup); } - private void handleCompletion(EventStream eventStream, Completion completion) { - if (completion.finishReason() == Completion.FinishReason.error) { - eventStream.add(completion.text(), "error"); - } else { - eventStream.add(completion.text()); - } - } - - private void handleException(EventStream eventStream, Throwable exception) { - int errorCode = 400; - if (exception instanceof LanguageModelException languageModelException) { - errorCode = languageModelException.code(); - } - eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); - } - private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { EventStream eventStream = new EventStream(); - if (shouldAddPrompt(query)) { + if (query.getTrace().getLevel() >= 1) { eventStream.add(prompt.asString(), "prompt"); } @@ -200,35 +169,4 @@ public class LLMSearcher extends Searcher { return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); } - private static class TokenStats { - - private long start; - private long timeToFirstToken; - private long timeToLastToken; - private long tokens = 0; - - TokenStats() { - start = System.currentTimeMillis(); - } - - void onToken() { - if (tokens == 0) { - timeToFirstToken = System.currentTimeMillis() - start; - } - tokens++; - } - - void onCompletion() { - timeToLastToken = System.currentTimeMillis() - start; - } - - String report() { - return "Time to first token: " + timeToFirstToken + " ms, " + - "Generation time: " + timeToLastToken + " ms, " + - "Generated tokens: " + tokens + " " + - String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0)); - } - - } - } diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java index 88a1e6c1485..83ae349f5a0 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -64,17 +64,7 @@ public class EventRenderer extends AsynchronousSectionedRenderer { @Override public void data(Data data) throws IOException { - if (data instanceof EventStream.ErrorEvent error) { - generator.writeRaw("event: error\n"); - generator.writeRaw("data: "); - generator.writeStartObject(); - generator.writeStringField("source", error.source()); - generator.writeNumberField("error", error.code()); - generator.writeStringField("message", error.message()); - generator.writeEndObject(); - generator.writeRaw("\n\n"); - generator.flush(); - } else if (data instanceof EventStream.Event event) { + if (data instanceof EventStream.Event event) { if (RENDER_EVENT_HEADER) { generator.writeRaw("event: " + event.type() + "\n"); } @@ -85,6 +75,19 @@ public class EventRenderer extends AsynchronousSectionedRenderer { generator.writeRaw("\n\n"); generator.flush(); } + else if (data instanceof ErrorHit) { + for (ErrorMessage error : ((ErrorHit) data).errors()) { + generator.writeRaw("event: error\n"); + generator.writeRaw("data: "); + generator.writeStartObject(); + generator.writeStringField("source", error.getSource()); + generator.writeNumberField("error", error.getCode()); + generator.writeStringField("message", error.getMessage()); + generator.writeEndObject(); + generator.writeRaw("\n\n"); + generator.flush(); + } + } // Todo: support other types of data such as search results (hits), timing and trace } diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java index 8e6f7977d55..b393a91e6d0 100644 --- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java +++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java @@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList { } public void error(String source, ErrorMessage message) { - incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message)); + incoming().add(new DefaultErrorHit(source, message)); } public void markComplete() { @@ -117,38 +117,4 @@ public class EventStream extends Hit implements DataList { } - public static class ErrorEvent extends Event { - - private final String source; - private final ErrorMessage message; - - public ErrorEvent(int eventNumber, String source, ErrorMessage message) { - super(eventNumber, message.getMessage(), "error"); - this.source = source; - this.message = message; - } - - public String source() { - return source; - } - - public int code() { - return message.getCode(); - } - - public String message() { - return message.getMessage(); - } - - @Override - public Hit asHit() { - Hit hit = super.asHit(); - hit.setField("source", source); - hit.setField("code", message.getCode()); - return hit; - } - - - } - } diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def new file mode 100755 index 00000000000..0866459166a --- /dev/null +++ b/container-search/src/main/resources/configdefinitions/llm-client.def @@ -0,0 +1,8 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package=ai.vespa.llm.clients + +# The name of the secret containing the api key +apiKeySecretName string default="" + +# Endpoint for LLM client - if not set reverts to default for client +endpoint string default="" diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java new file mode 100644 index 00000000000..35d5cfd3855 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java @@ -0,0 +1,174 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.container.di.componentgraph.Provider; +import com.yahoo.container.jdisc.SecretStoreProvider; +import com.yahoo.container.jdisc.secretstore.SecretStore; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConfigurableLanguageModelTest { + + @Test + public void testSyncGeneration() { + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey()); + assertEquals(1, result.size()); + assertEquals("Ducks have adorable waddling walks.", result.get(0).text()); + } + + @Test + public void testAsyncGeneration() { + var executor = Executors.newFixedThreadPool(1); + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var sb = new StringBuilder(); + try { + var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> { + sb.append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); + + var reason = future.join(); + assertTrue(future.isDone()); + assertNotEquals(reason, Completion.FinishReason.error); + } finally { + executor.shutdownNow(); + } + + assertEquals("Ducks have adorable waddling walks.", sb.toString()); + } + + @Test + public void testInferenceParameters() { + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4")); + var result = createLLM().complete(prompt, params); + assertEquals("Random text about ducks", result.get(0).text()); + } + + @Test + public void testNoApiKey() { + var prompt = StringPrompt.from(""); + var config = modelParams("api-key", null); + var secrets = createSecretStore(Map.of()); + assertThrows(IllegalArgumentException.class, () -> { + createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); + }); + } + + @Test + public void testApiKeyFromSecretStore() { + var prompt = StringPrompt.from(""); + var config = modelParams("api-key-in-secret-store", null); + var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY)); + assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); }); + } + + private static String lookupParameter(String parameter, Map params) { + return params.get(parameter); + } + + private static InferenceParameters inferenceParams() { + return new InferenceParameters(s -> lookupParameter(s, Map.of())); + } + + private static InferenceParameters inferenceParams(Map params) { + return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params)); + } + + private static InferenceParameters inferenceParamsWithDefaultKey() { + return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Map.of())); + } + + private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) { + var config = new LlmClientConfig.Builder(); + if (apiKeySecretName != null) { + config.apiKeySecretName(apiKeySecretName); + } + if (endpoint != null) { + config.endpoint(endpoint); + } + return config.build(); + } + + public static SecretStore createSecretStore(Map secrets) { + Provider secretStore = new Provider<>() { + public SecretStore get() { + return new SecretStore() { + public String getSecret(String key) { + return secrets.get(key); + } + public String getSecret(String key, int version) { + return secrets.get(key); + } + }; + } + public void deconstruct() { + } + }; + return secretStore.get(); + } + + public static BiFunction createGenerator() { + return (prompt, options) -> { + String answer = "I have no opinion on the matter"; + if (prompt.asString().contains("ducks")) { + answer = "Ducks have adorable waddling walks."; + var temperature = options.getDouble("temperature"); + if (temperature.isPresent() && temperature.get() > 0.5) { + answer = "Random text about ducks vs cats that makes no sense whatsoever."; + } + } + var maxTokens = options.getInt("maxTokens"); + if (maxTokens.isPresent()) { + return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); + } + return answer; + }; + } + + private static MockLLMClient createLLM() { + LlmClientConfig config = new LlmClientConfig.Builder().build(); + return createLLM(config, null); + } + + private static MockLLMClient createLLM(ExecutorService executor) { + LlmClientConfig config = new LlmClientConfig.Builder().build(); + return createLLM(config, executor); + } + + private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) { + var generator = createGenerator(); + var secretStore = new SecretStoreProvider(); // throws exception on use + return createLLM(config, generator, secretStore.get(), executor); + } + + private static MockLLMClient createLLM(LlmClientConfig config, + BiFunction generator, + SecretStore secretStore) { + return createLLM(config, generator, secretStore, null); + } + + private static MockLLMClient createLLM(LlmClientConfig config, + BiFunction generator, + SecretStore secretStore, + ExecutorService executor) { + return new MockLLMClient(config, secretStore, generator, executor); + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java new file mode 100644 index 00000000000..4d0073f1cbe --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java @@ -0,0 +1,80 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public class MockLLMClient extends ConfigurableLanguageModel { + + public final static String ACCEPTED_API_KEY = "sesame"; + + private final ExecutorService executor; + private final BiFunction generator; + + private Prompt lastPrompt; + + public MockLLMClient(LlmClientConfig config, + SecretStore secretStore, + BiFunction generator, + ExecutorService executor) { + super(config, secretStore); + this.generator = generator; + this.executor = executor; + } + + private void checkApiKey(InferenceParameters options) { + var apiKey = getApiKey(options); + if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) { + throw new IllegalArgumentException("Invalid API key"); + } + } + + private void setPrompt(Prompt prompt) { + this.lastPrompt = prompt; + } + + public Prompt getPrompt() { + return this.lastPrompt; + } + + @Override + public List complete(Prompt prompt, InferenceParameters params) { + setApiKey(params); + checkApiKey(params); + setPrompt(prompt); + return List.of(Completion.from(this.generator.apply(prompt, params))); + } + + @Override + public CompletableFuture completeAsync(Prompt prompt, + InferenceParameters params, + Consumer consumer) { + setPrompt(prompt); + var completionFuture = new CompletableFuture(); + var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization + + long sleep = 1; + executor.submit(() -> { + try { + for (int i=0; i < completions.length; ++i) { + String completion = (i > 0 ? " " : "") + completions[i]; + consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep); + } + completionFuture.complete(Completion.FinishReason.stop); + } catch (InterruptedException e) { + // Do nothing + } + }); + + return completionFuture; + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java new file mode 100644 index 00000000000..57339f6ad49 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java @@ -0,0 +1,35 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.container.jdisc.SecretStoreProvider; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +public class OpenAITest { + + private static final String apiKey = ""; + + @Test + @Disabled + public void testOpenAIGeneration() { + var config = new LlmClientConfig.Builder().build(); + var openai = new OpenAI(config, new SecretStoreProvider().get()); + var options = Map.of( + "maxTokens", "10" + ); + + var prompt = StringPrompt.from("why are ducks better than cats?"); + var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> { + System.out.print(completion.text()); + }).exceptionally(exception -> { + System.out.println("Error: " + exception); + return null; + }); + future.join(); + } + +} diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java index 3baa9715c34..1efcf1c736a 100755 --- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java +++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java @@ -3,11 +3,14 @@ package ai.vespa.search.llm; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.clients.ConfigurableLanguageModelTest; +import ai.vespa.llm.clients.LlmClientConfig; +import ai.vespa.llm.clients.MockLLMClient; import ai.vespa.llm.completion.Prompt; import com.yahoo.component.ComponentId; import com.yahoo.component.chain.Chain; import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.container.jdisc.SecretStoreProvider; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; @@ -17,14 +20,10 @@ import org.junit.jupiter.api.Test; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.BiFunction; -import java.util.function.Consumer; import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -37,10 +36,10 @@ public class LLMSearcherTest { @Test public void testLLMSelection() { - var client1 = createLLMClient("mock1"); - var client2 = createLLMClient("mock2"); + var llm1 = createLLMClient("mock1"); + var llm2 = createLLMClient("mock2"); var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build(); - var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2)); + var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2)); var result = runMockSearch(searcher, Map.of("prompt", "what is your id?")); assertEquals(1, result.getHitCount()); assertEquals("My id is mock2", getCompletion(result)); @@ -48,16 +47,14 @@ public class LLMSearcherTest { @Test public void testGeneration() { - var client = createLLMClient(); - var searcher = createLLMSearcher(client); + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); var params = Map.of("prompt", "why are ducks better than cats"); assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params))); } @Test public void testPrompting() { - var client = createLLMClient(); - var searcher = createLLMSearcher(client); + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); // Prompt with prefix assertEquals("Ducks have adorable waddling walks.", @@ -74,8 +71,7 @@ public class LLMSearcherTest { @Test public void testPromptEvent() { - var client = createLLMClient(); - var searcher = createLLMSearcher(client); + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); var params = Map.of( "prompt", "why are ducks better than cats", "traceLevel", "1"); @@ -94,8 +90,7 @@ public class LLMSearcherTest { @Test public void testParameters() { - var client = createLLMClient(); - var searcher = createLLMSearcher(client); + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); var params = Map.of( "llm.prompt", "why are ducks better than cats", "llm.temperature", "1.0", @@ -112,18 +107,16 @@ public class LLMSearcherTest { "foo.maxTokens", "5" ); var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build(); - var client = createLLMClient(); - var searcher = createLLMSearcher(config, client); + var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient())); assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params))); } @Test public void testApiKeyFromHeader() { var properties = Map.of("prompt", "why are ducks better than cats"); - var client = createLLMClient(createApiKeyGenerator("a_valid_key")); - var searcher = createLLMSearcher(client); - assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key")); - assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key")); + var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore())); + assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm")); + assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm")); } @Test @@ -136,8 +129,7 @@ public class LLMSearcherTest { "llm.stream", "true", // ... but inference parameters says do it anyway "llm.prompt", "why are ducks better than cats?" ); - var client = createLLMClient(executor); - var searcher = createLLMSearcher(config, client); + var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor))); Result result = runMockSearch(searcher, params); assertEquals(1, result.getHitCount()); @@ -170,10 +162,6 @@ public class LLMSearcherTest { return runMockSearch(searcher, parameters, null, ""); } - static Result runMockSearch(Searcher searcher, Map parameters, String apiKey) { - return runMockSearch(searcher, parameters, apiKey, "llm"); - } - static Result runMockSearch(Searcher searcher, Map parameters, String apiKey, String prefix) { Chain chain = new Chain<>(searcher); Execution execution = new Execution(chain, Execution.Context.createContextStub()); @@ -203,59 +191,43 @@ public class LLMSearcherTest { } private static BiFunction createGenerator() { - return (prompt, options) -> { - String answer = "I have no opinion on the matter"; - if (prompt.asString().contains("ducks")) { - answer = "Ducks have adorable waddling walks."; - var temperature = options.getDouble("temperature"); - if (temperature.isPresent() && temperature.get() > 0.5) { - answer = "Random text about ducks vs cats that makes no sense whatsoever."; - } - } - var maxTokens = options.getInt("maxTokens"); - if (maxTokens.isPresent()) { - return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); - } - return answer; - }; + return ConfigurableLanguageModelTest.createGenerator(); } - private static BiFunction createApiKeyGenerator(String validApiKey) { - return (prompt, options) -> { - if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) { - throw new IllegalArgumentException("Invalid API key"); - } - return "Ok"; - }; - } - - static MockLLM createLLMClient() { - return new MockLLM(createGenerator(), null); + static MockLLMClient createLLMClient() { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore, generator, null); } - static MockLLM createLLMClient(String id) { - return new MockLLM(createIdGenerator(id), null); + static MockLLMClient createLLMClient(String id) { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createIdGenerator(id); + return new MockLLMClient(config, secretStore, generator, null); } - static MockLLM createLLMClient(BiFunction generator) { - return new MockLLM(generator, null); + static MockLLMClient createLLMClient(ExecutorService executor) { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore, generator, executor); } - static MockLLM createLLMClient(ExecutorService executor) { - return new MockLLM(createGenerator(), executor); - } - - private static Searcher createLLMSearcher(LanguageModel llm) { - return createLLMSearcher(Map.of("mock", llm)); + static MockLLMClient createLLMClientWithoutSecretStore() { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = new SecretStoreProvider(); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore.get(), generator, null); } private static Searcher createLLMSearcher(Map llms) { var config = new LlmSearcherConfig.Builder().stream(false).build(); - return createLLMSearcher(config, llms); - } - - private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) { - return createLLMSearcher(config, Map.of("mock", llm)); + ComponentRegistry models = new ComponentRegistry<>(); + llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new LLMSearcher(config, models); } private static Searcher createLLMSearcher(LlmSearcherConfig config, Map llms) { @@ -265,44 +237,4 @@ public class LLMSearcherTest { return new LLMSearcher(config, models); } - private static class MockLLM implements LanguageModel { - - private final ExecutorService executor; - private final BiFunction generator; - - public MockLLM(BiFunction generator, ExecutorService executor) { - this.executor = executor; - this.generator = generator; - } - - @Override - public List complete(Prompt prompt, InferenceParameters params) { - return List.of(Completion.from(this.generator.apply(prompt, params))); - } - - @Override - public CompletableFuture completeAsync(Prompt prompt, - InferenceParameters params, - Consumer consumer) { - var completionFuture = new CompletableFuture(); - var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization - - long sleep = 1; - executor.submit(() -> { - try { - for (int i = 0; i < completions.length; ++i) { - String completion = (i > 0 ? " " : "") + completions[i]; - consumer.accept(Completion.from(completion, Completion.FinishReason.none)); - Thread.sleep(sleep); - } - completionFuture.complete(Completion.FinishReason.stop); - } catch (InterruptedException e) { - // Do nothing - } - }); - return completionFuture; - } - - } - } diff --git a/container-test/pom.xml b/container-test/pom.xml index d6be6946208..8e1b4870665 100644 --- a/container-test/pom.xml +++ b/container-test/pom.xml @@ -60,16 +60,6 @@ com.microsoft.onnxruntime onnxruntime - - de.kherud - llama - - - org.jetbrains - annotations - - - io.airlift airline diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml index c4ecd3d2e1c..069b53a0dcd 100644 --- a/fat-model-dependencies/pom.xml +++ b/fat-model-dependencies/pom.xml @@ -96,14 +96,6 @@ com.theokanning.openai-gpt3-java service - - com.google.protobuf - protobuf-java - - - org.lz4 - lz4-java - @@ -238,12 +230,6 @@ com.yahoo.vespa container-search-and-docproc ${project.version} - - - com.yahoo.vespa - model-integration - - com.yahoo.vespa diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json index e7130d9c777..d3c472778e6 100644 --- a/model-integration/abi-spec.json +++ b/model-integration/abi-spec.json @@ -1,186 +1,4 @@ { - "ai.vespa.llm.clients.ConfigurableLanguageModel" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "ai.vespa.llm.LanguageModel" - ], - "attributes" : [ - "public", - "abstract" - ], - "methods" : [ - "public void ()", - "public void (ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", - "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)", - "protected void setApiKey(ai.vespa.llm.InferenceParameters)", - "protected java.lang.String getEndpoint()", - "protected void setEndpoint(ai.vespa.llm.InferenceParameters)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmClientConfig$Builder" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "com.yahoo.config.ConfigInstance$Builder" - ], - "attributes" : [ - "public", - "final" - ], - "methods" : [ - "public void ()", - "public void (ai.vespa.llm.clients.LlmClientConfig)", - "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)", - "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)", - "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", - "public final java.lang.String getDefMd5()", - "public final java.lang.String getDefName()", - "public final java.lang.String getDefNamespace()", - "public final boolean getApplyOnRestart()", - "public final void setApplyOnRestart(boolean)", - "public ai.vespa.llm.clients.LlmClientConfig build()" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmClientConfig$Producer" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "com.yahoo.config.ConfigInstance$Producer" - ], - "attributes" : [ - "public", - "interface", - "abstract" - ], - "methods" : [ - "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmClientConfig" : { - "superClass" : "com.yahoo.config.ConfigInstance", - "interfaces" : [ ], - "attributes" : [ - "public", - "final" - ], - "methods" : [ - "public static java.lang.String getDefMd5()", - "public static java.lang.String getDefName()", - "public static java.lang.String getDefNamespace()", - "public void (ai.vespa.llm.clients.LlmClientConfig$Builder)", - "public java.lang.String apiKeySecretName()", - "public java.lang.String endpoint()" - ], - "fields" : [ - "public static final java.lang.String CONFIG_DEF_MD5", - "public static final java.lang.String CONFIG_DEF_NAME", - "public static final java.lang.String CONFIG_DEF_NAMESPACE", - "public static final java.lang.String[] CONFIG_DEF_SCHEMA" - ] - }, - "ai.vespa.llm.clients.LlmLocalClientConfig$Builder" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "com.yahoo.config.ConfigInstance$Builder" - ], - "attributes" : [ - "public", - "final" - ], - "methods" : [ - "public void ()", - "public void (ai.vespa.llm.clients.LlmLocalClientConfig)", - "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder model(com.yahoo.config.ModelReference)", - "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder parallelRequests(int)", - "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxQueueSize(int)", - "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder useGpu(boolean)", - "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder gpuLayers(int)", - "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder threads(int)", - "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder contextSize(int)", - "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxTokens(int)", - "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", - "public final java.lang.String getDefMd5()", - "public final java.lang.String getDefName()", - "public final java.lang.String getDefNamespace()", - "public final boolean getApplyOnRestart()", - "public final void setApplyOnRestart(boolean)", - "public ai.vespa.llm.clients.LlmLocalClientConfig build()" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmLocalClientConfig$Producer" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "com.yahoo.config.ConfigInstance$Producer" - ], - "attributes" : [ - "public", - "interface", - "abstract" - ], - "methods" : [ - "public abstract void getConfig(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmLocalClientConfig" : { - "superClass" : "com.yahoo.config.ConfigInstance", - "interfaces" : [ ], - "attributes" : [ - "public", - "final" - ], - "methods" : [ - "public static java.lang.String getDefMd5()", - "public static java.lang.String getDefName()", - "public static java.lang.String getDefNamespace()", - "public void (ai.vespa.llm.clients.LlmLocalClientConfig$Builder)", - "public java.nio.file.Path model()", - "public int parallelRequests()", - "public int maxQueueSize()", - "public boolean useGpu()", - "public int gpuLayers()", - "public int threads()", - "public int contextSize()", - "public int maxTokens()" - ], - "fields" : [ - "public static final java.lang.String CONFIG_DEF_MD5", - "public static final java.lang.String CONFIG_DEF_NAME", - "public static final java.lang.String CONFIG_DEF_NAMESPACE", - "public static final java.lang.String[] CONFIG_DEF_SCHEMA" - ] - }, - "ai.vespa.llm.clients.LocalLLM" : { - "superClass" : "com.yahoo.component.AbstractComponent", - "interfaces" : [ - "ai.vespa.llm.LanguageModel" - ], - "attributes" : [ - "public" - ], - "methods" : [ - "public void (ai.vespa.llm.clients.LlmLocalClientConfig)", - "public void deconstruct()", - "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.OpenAI" : { - "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel", - "interfaces" : [ ], - "attributes" : [ - "public" - ], - "methods" : [ - "public void (ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", - "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" - ], - "fields" : [ ] - }, "ai.vespa.llm.generation.Generator" : { "superClass" : "com.yahoo.component.AbstractComponent", "interfaces" : [ ], diff --git a/model-integration/pom.xml b/model-integration/pom.xml index d92fa319251..0bab30e1453 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -38,12 +38,6 @@ ${project.version} provided - - com.yahoo.vespa - container-disc - ${project.version} - provided - com.yahoo.vespa searchcore @@ -80,12 +74,6 @@ ${project.version} provided - - com.yahoo.vespa - container-llama - ${project.version} - provided - com.yahoo.vespa component diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java deleted file mode 100644 index 761fdf0af93..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.LanguageModel; -import com.yahoo.api.annotations.Beta; -import com.yahoo.component.annotation.Inject; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.logging.Logger; - - -/** - * Base class for language models that can be configured with config definitions. - * - * @author lesters - */ -@Beta -public abstract class ConfigurableLanguageModel implements LanguageModel { - - private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); - - private final String apiKey; - private final String endpoint; - - public ConfigurableLanguageModel() { - this.apiKey = null; - this.endpoint = null; - } - - @Inject - public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) { - this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore); - this.endpoint = config.endpoint(); - } - - private static String findApiKeyInSecretStore(String property, SecretStore secretStore) { - String apiKey = ""; - if (property != null && ! property.isEmpty()) { - try { - apiKey = secretStore.getSecret(property); - } catch (UnsupportedOperationException e) { - // Secret store is not available - silently ignore this - } catch (Exception e) { - log.warning("Secret store look up failed: " + e.getMessage() + "\n" + - "Will expect API key in request header"); - } - } - return apiKey; - } - - protected String getApiKey(InferenceParameters params) { - return params.getApiKey().orElse(null); - } - - /** - * Set the API key as retrieved from secret store if it is not already set - */ - protected void setApiKey(InferenceParameters params) { - if (params.getApiKey().isEmpty() && apiKey != null) { - params.setApiKey(apiKey); - } - } - - protected String getEndpoint() { - return endpoint; - } - - protected void setEndpoint(InferenceParameters params) { - if (endpoint != null && ! endpoint.isEmpty()) { - params.setEndpoint(endpoint); - } - } - -} diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java deleted file mode 100644 index fd1b8b700c8..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.component.AbstractComponent; -import com.yahoo.component.annotation.Inject; -import de.kherud.llama.LlamaModel; -import de.kherud.llama.ModelParameters; -import de.kherud.llama.args.LogFormat; - -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.ArrayBlockingQueue; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.SynchronousQueue; -import java.util.concurrent.ThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; -import java.util.logging.Logger; - -/** - * A language model running locally on the container node. - * - * @author lesters - */ -public class LocalLLM extends AbstractComponent implements LanguageModel { - - private final static Logger logger = Logger.getLogger(LocalLLM.class.getName()); - private final LlamaModel model; - private final ThreadPoolExecutor executor; - private final int contextSize; - private final int maxTokens; - - @Inject - public LocalLLM(LlmLocalClientConfig config) { - executor = createExecutor(config); - - // Maximum number of tokens to generate - need this since some models can just generate infinitely - maxTokens = config.maxTokens(); - - // Only used if GPU is not used - var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2; - - var modelFile = config.model().toFile().getAbsolutePath(); - var modelParams = new ModelParameters() - .setModelFilePath(modelFile) - .setContinuousBatching(true) - .setNParallel(config.parallelRequests()) - .setNThreads(config.threads() <= 0 ? defaultThreadCount : config.threads()) - .setNCtx(config.contextSize()) - .setNGpuLayers(config.useGpu() ? config.gpuLayers() : 0); - - long startLoad = System.nanoTime(); - model = new LlamaModel(modelParams); - long loadTime = System.nanoTime() - startLoad; - logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000))); - - // Todo: handle prompt context size - such as give a warning when prompt exceeds context size - contextSize = config.contextSize(); - } - - private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) { - return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(), - 0L, TimeUnit.MILLISECONDS, - config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(), - new ThreadPoolExecutor.AbortPolicy()); - } - - @Override - public void deconstruct() { - logger.info("Closing LLM model..."); - model.close(); - executor.shutdownNow(); - } - - @Override - public List complete(Prompt prompt, InferenceParameters options) { - StringBuilder result = new StringBuilder(); - var future = completeAsync(prompt, options, completion -> { - result.append(completion.text()); - }).exceptionally(exception -> Completion.FinishReason.error); - var reason = future.join(); - - List completions = new ArrayList<>(); - completions.add(new Completion(result.toString(), reason)); - return completions; - } - - @Override - public CompletableFuture completeAsync(Prompt prompt, InferenceParameters options, Consumer consumer) { - var inferParams = new de.kherud.llama.InferenceParameters(prompt.asString().stripLeading()); - - // We always set this to some value to avoid infinite token generation - inferParams.setNPredict(maxTokens); - - options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v))); - options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v))); - options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v))); - options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v))); - options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v))); - // Todo: more options? - - var completionFuture = new CompletableFuture(); - try { - executor.submit(() -> { - for (LlamaModel.Output output : model.generate(inferParams)) { - consumer.accept(Completion.from(output.text, Completion.FinishReason.none)); - } - completionFuture.complete(Completion.FinishReason.stop); - }); - } catch (RejectedExecutionException e) { - // If we have too many requests (active + any waiting in queue), we reject the completion - int activeCount = executor.getActiveCount(); - int queueSize = executor.getQueue().size(); - String error = String.format("Rejected completion due to too many requests, " + - "%d active, %d in queue", activeCount, queueSize); - throw new RejectedExecutionException(error); - } - return completionFuture; - } - -} diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java deleted file mode 100644 index 82e19d47c92..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.client.openai.OpenAiClient; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.api.annotations.Beta; -import com.yahoo.component.annotation.Inject; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; - -/** - * A configurable OpenAI client. - * - * @author lesters - */ -@Beta -public class OpenAI extends ConfigurableLanguageModel { - - private final OpenAiClient client; - - @Inject - public OpenAI(LlmClientConfig config, SecretStore secretStore) { - super(config, secretStore); - client = new OpenAiClient(); - } - - @Override - public List complete(Prompt prompt, InferenceParameters parameters) { - setApiKey(parameters); - setEndpoint(parameters); - return client.complete(prompt, parameters); - } - - @Override - public CompletableFuture completeAsync(Prompt prompt, - InferenceParameters parameters, - Consumer consumer) { - setApiKey(parameters); - setEndpoint(parameters); - return client.completeAsync(prompt, parameters, consumer); - } -} - diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java deleted file mode 100644 index c360245901c..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -@PublicApi -package ai.vespa.llm.clients; - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage; diff --git a/model-integration/src/main/resources/configdefinitions/llm-client.def b/model-integration/src/main/resources/configdefinitions/llm-client.def deleted file mode 100755 index 0866459166a..00000000000 --- a/model-integration/src/main/resources/configdefinitions/llm-client.def +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package=ai.vespa.llm.clients - -# The name of the secret containing the api key -apiKeySecretName string default="" - -# Endpoint for LLM client - if not set reverts to default for client -endpoint string default="" diff --git a/model-integration/src/main/resources/configdefinitions/llm-local-client.def b/model-integration/src/main/resources/configdefinitions/llm-local-client.def deleted file mode 100755 index c06c24b33e5..00000000000 --- a/model-integration/src/main/resources/configdefinitions/llm-local-client.def +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package=ai.vespa.llm.clients - -# The LLM model to use -model model - -# Maximum number of requests to handle in parallel pr container node -parallelRequests int default=10 - -# Additional number of requests to put in queue for processing before starting to reject new requests -maxQueueSize int default=10 - -# Use GPU -useGpu bool default=false - -# Maximum number of model layers to run on GPU -gpuLayers int default=1000000 - -# Number of threads to use for CPU processing - -1 means use all available cores -# Not used for GPU processing -threads int default=-1 - -# Context size for the model -# Context is divided between parallel requests. So for 10 parallel requests, each "slot" gets 1/10 of the context -contextSize int default=512 - -# Maximum number of tokens to process in one request - overriden by inference parameters -maxTokens int default=512 - diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java deleted file mode 100644 index 35d5cfd3855..00000000000 --- a/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java +++ /dev/null @@ -1,174 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import ai.vespa.llm.completion.StringPrompt; -import com.yahoo.container.di.componentgraph.Provider; -import com.yahoo.container.jdisc.SecretStoreProvider; -import com.yahoo.container.jdisc.secretstore.SecretStore; -import org.junit.jupiter.api.Test; - -import java.util.Arrays; -import java.util.Map; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.function.BiFunction; -import java.util.stream.Collectors; - -import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -public class ConfigurableLanguageModelTest { - - @Test - public void testSyncGeneration() { - var prompt = StringPrompt.from("Why are ducks better than cats?"); - var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey()); - assertEquals(1, result.size()); - assertEquals("Ducks have adorable waddling walks.", result.get(0).text()); - } - - @Test - public void testAsyncGeneration() { - var executor = Executors.newFixedThreadPool(1); - var prompt = StringPrompt.from("Why are ducks better than cats?"); - var sb = new StringBuilder(); - try { - var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> { - sb.append(completion.text()); - }).exceptionally(exception -> Completion.FinishReason.error); - - var reason = future.join(); - assertTrue(future.isDone()); - assertNotEquals(reason, Completion.FinishReason.error); - } finally { - executor.shutdownNow(); - } - - assertEquals("Ducks have adorable waddling walks.", sb.toString()); - } - - @Test - public void testInferenceParameters() { - var prompt = StringPrompt.from("Why are ducks better than cats?"); - var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4")); - var result = createLLM().complete(prompt, params); - assertEquals("Random text about ducks", result.get(0).text()); - } - - @Test - public void testNoApiKey() { - var prompt = StringPrompt.from(""); - var config = modelParams("api-key", null); - var secrets = createSecretStore(Map.of()); - assertThrows(IllegalArgumentException.class, () -> { - createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); - }); - } - - @Test - public void testApiKeyFromSecretStore() { - var prompt = StringPrompt.from(""); - var config = modelParams("api-key-in-secret-store", null); - var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY)); - assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); }); - } - - private static String lookupParameter(String parameter, Map params) { - return params.get(parameter); - } - - private static InferenceParameters inferenceParams() { - return new InferenceParameters(s -> lookupParameter(s, Map.of())); - } - - private static InferenceParameters inferenceParams(Map params) { - return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params)); - } - - private static InferenceParameters inferenceParamsWithDefaultKey() { - return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Map.of())); - } - - private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) { - var config = new LlmClientConfig.Builder(); - if (apiKeySecretName != null) { - config.apiKeySecretName(apiKeySecretName); - } - if (endpoint != null) { - config.endpoint(endpoint); - } - return config.build(); - } - - public static SecretStore createSecretStore(Map secrets) { - Provider secretStore = new Provider<>() { - public SecretStore get() { - return new SecretStore() { - public String getSecret(String key) { - return secrets.get(key); - } - public String getSecret(String key, int version) { - return secrets.get(key); - } - }; - } - public void deconstruct() { - } - }; - return secretStore.get(); - } - - public static BiFunction createGenerator() { - return (prompt, options) -> { - String answer = "I have no opinion on the matter"; - if (prompt.asString().contains("ducks")) { - answer = "Ducks have adorable waddling walks."; - var temperature = options.getDouble("temperature"); - if (temperature.isPresent() && temperature.get() > 0.5) { - answer = "Random text about ducks vs cats that makes no sense whatsoever."; - } - } - var maxTokens = options.getInt("maxTokens"); - if (maxTokens.isPresent()) { - return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); - } - return answer; - }; - } - - private static MockLLMClient createLLM() { - LlmClientConfig config = new LlmClientConfig.Builder().build(); - return createLLM(config, null); - } - - private static MockLLMClient createLLM(ExecutorService executor) { - LlmClientConfig config = new LlmClientConfig.Builder().build(); - return createLLM(config, executor); - } - - private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) { - var generator = createGenerator(); - var secretStore = new SecretStoreProvider(); // throws exception on use - return createLLM(config, generator, secretStore.get(), executor); - } - - private static MockLLMClient createLLM(LlmClientConfig config, - BiFunction generator, - SecretStore secretStore) { - return createLLM(config, generator, secretStore, null); - } - - private static MockLLMClient createLLM(LlmClientConfig config, - BiFunction generator, - SecretStore secretStore, - ExecutorService executor) { - return new MockLLMClient(config, secretStore, generator, executor); - } - -} diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java deleted file mode 100644 index e85e397b7ff..00000000000 --- a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java +++ /dev/null @@ -1,181 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import ai.vespa.llm.completion.StringPrompt; -import com.yahoo.config.ModelReference; -import org.junit.jupiter.api.Test; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.RejectedExecutionException; -import java.util.concurrent.atomic.AtomicInteger; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; - -/** - * Tests for LocalLLM. - * - * @author lesters - */ -public class LocalLLMTest { - - private static String model = "src/test/models/llm/tinyllm.gguf"; - private static Prompt prompt = StringPrompt.from("A random prompt"); - - @Test - public void testGeneration() { - var config = new LlmLocalClientConfig.Builder() - .parallelRequests(1) - .model(ModelReference.valueOf(model)); - var llm = new LocalLLM(config.build()); - - try { - var result = llm.complete(prompt, defaultOptions()); - assertEquals(Completion.FinishReason.stop, result.get(0).finishReason()); - assertTrue(result.get(0).text().length() > 10); - } finally { - llm.deconstruct(); - } - } - - @Test - public void testAsyncGeneration() { - var sb = new StringBuilder(); - var tokenCount = new AtomicInteger(0); - var config = new LlmLocalClientConfig.Builder() - .parallelRequests(1) - .model(ModelReference.valueOf(model)); - var llm = new LocalLLM(config.build()); - - try { - var future = llm.completeAsync(prompt, defaultOptions(), completion -> { - sb.append(completion.text()); - tokenCount.incrementAndGet(); - }).exceptionally(exception -> Completion.FinishReason.error); - - assertFalse(future.isDone()); - var reason = future.join(); - assertTrue(future.isDone()); - assertNotEquals(reason, Completion.FinishReason.error); - - } finally { - llm.deconstruct(); - } - assertTrue(tokenCount.get() > 0); -// System.out.println(sb); - } - - @Test - public void testParallelGeneration() { - var prompts = testPrompts(); - var promptsToUse = prompts.size(); - var parallelRequests = 10; - - var futures = new ArrayList>(Collections.nCopies(promptsToUse, null)); - var completions = new ArrayList(Collections.nCopies(promptsToUse, null)); - var tokenCounts = new ArrayList<>(Collections.nCopies(promptsToUse, 0)); - - var config = new LlmLocalClientConfig.Builder() - .parallelRequests(parallelRequests) - .model(ModelReference.valueOf(model)); - var llm = new LocalLLM(config.build()); - - try { - for (int i = 0; i < promptsToUse; i++) { - final var seq = i; - - completions.set(seq, new StringBuilder()); - futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { - completions.get(seq).append(completion.text()); - tokenCounts.set(seq, tokenCounts.get(seq) + 1); - }).exceptionally(exception -> Completion.FinishReason.error)); - } - for (int i = 0; i < promptsToUse; i++) { - var reason = futures.get(i).join(); - assertNotEquals(reason, Completion.FinishReason.error); - } - } finally { - llm.deconstruct(); - } - for (int i = 0; i < promptsToUse; i++) { - assertFalse(completions.get(i).isEmpty()); - assertTrue(tokenCounts.get(i) > 0); - } - } - - @Test - public void testRejection() { - var prompts = testPrompts(); - var promptsToUse = prompts.size(); - var parallelRequests = 2; - var additionalQueue = 1; - // 7 should be rejected - - var futures = new ArrayList>(Collections.nCopies(promptsToUse, null)); - var completions = new ArrayList(Collections.nCopies(promptsToUse, null)); - - var config = new LlmLocalClientConfig.Builder() - .parallelRequests(parallelRequests) - .maxQueueSize(additionalQueue) - .model(ModelReference.valueOf(model)); - var llm = new LocalLLM(config.build()); - - var rejected = new AtomicInteger(0); - try { - for (int i = 0; i < promptsToUse; i++) { - final var seq = i; - - completions.set(seq, new StringBuilder()); - try { - var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { - completions.get(seq).append(completion.text()); - }).exceptionally(exception -> Completion.FinishReason.error); - futures.set(seq, future); - } catch (RejectedExecutionException e) { - rejected.incrementAndGet(); - } - } - for (int i = 0; i < promptsToUse; i++) { - if (futures.get(i) != null) { - assertNotEquals(futures.get(i).join(), Completion.FinishReason.error); - } - } - } finally { - llm.deconstruct(); - } - assertEquals(7, rejected.get()); - } - - private static InferenceParameters defaultOptions() { - final Map options = Map.of( - "temperature", "0.1", - "npredict", "100" - ); - return new InferenceParameters(options::get); - } - - private List testPrompts() { - List prompts = new ArrayList<>(); - prompts.add("Write a short story about a time-traveling detective who must solve a mystery that spans multiple centuries."); - prompts.add("Explain the concept of blockchain technology and its implications for data security in layman's terms."); - prompts.add("Discuss the socio-economic impacts of the Industrial Revolution in 19th century Europe."); - prompts.add("Describe a future where humans have colonized Mars, focusing on daily life and societal structure."); - prompts.add("Analyze the statement 'If a tree falls in a forest and no one is around to hear it, does it make a sound?' from both a philosophical and a physics perspective."); - prompts.add("Translate the following sentence into French: 'The quick brown fox jumps over the lazy dog.'"); - prompts.add("Explain what the following Python code does: `print([x for x in range(10) if x % 2 == 0])`."); - prompts.add("Provide general guidelines for maintaining a healthy lifestyle to reduce the risk of developing heart disease."); - prompts.add("Create a detailed description of a fictional planet, including its ecosystem, dominant species, and technology level."); - prompts.add("Discuss the impact of social media on interpersonal communication in the 21st century."); - return prompts; - } - -} diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java deleted file mode 100644 index 4d0073f1cbe..00000000000 --- a/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.container.jdisc.secretstore.SecretStore; - -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; -import java.util.function.BiFunction; -import java.util.function.Consumer; - -public class MockLLMClient extends ConfigurableLanguageModel { - - public final static String ACCEPTED_API_KEY = "sesame"; - - private final ExecutorService executor; - private final BiFunction generator; - - private Prompt lastPrompt; - - public MockLLMClient(LlmClientConfig config, - SecretStore secretStore, - BiFunction generator, - ExecutorService executor) { - super(config, secretStore); - this.generator = generator; - this.executor = executor; - } - - private void checkApiKey(InferenceParameters options) { - var apiKey = getApiKey(options); - if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) { - throw new IllegalArgumentException("Invalid API key"); - } - } - - private void setPrompt(Prompt prompt) { - this.lastPrompt = prompt; - } - - public Prompt getPrompt() { - return this.lastPrompt; - } - - @Override - public List complete(Prompt prompt, InferenceParameters params) { - setApiKey(params); - checkApiKey(params); - setPrompt(prompt); - return List.of(Completion.from(this.generator.apply(prompt, params))); - } - - @Override - public CompletableFuture completeAsync(Prompt prompt, - InferenceParameters params, - Consumer consumer) { - setPrompt(prompt); - var completionFuture = new CompletableFuture(); - var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization - - long sleep = 1; - executor.submit(() -> { - try { - for (int i=0; i < completions.length; ++i) { - String completion = (i > 0 ? " " : "") + completions[i]; - consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep); - } - completionFuture.complete(Completion.FinishReason.stop); - } catch (InterruptedException e) { - // Do nothing - } - }); - - return completionFuture; - } - -} diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java deleted file mode 100644 index 57339f6ad49..00000000000 --- a/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.clients; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.completion.StringPrompt; -import com.yahoo.container.jdisc.SecretStoreProvider; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -import java.util.Map; - -public class OpenAITest { - - private static final String apiKey = ""; - - @Test - @Disabled - public void testOpenAIGeneration() { - var config = new LlmClientConfig.Builder().build(); - var openai = new OpenAI(config, new SecretStoreProvider().get()); - var options = Map.of( - "maxTokens", "10" - ); - - var prompt = StringPrompt.from("why are ducks better than cats?"); - var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> { - System.out.print(completion.text()); - }).exceptionally(exception -> { - System.out.println("Error: " + exception); - return null; - }); - future.join(); - } - -} diff --git a/model-integration/src/test/models/llm/tinyllm.gguf b/model-integration/src/test/models/llm/tinyllm.gguf deleted file mode 100644 index 34367b6b57b..00000000000 Binary files a/model-integration/src/test/models/llm/tinyllm.gguf and /dev/null differ diff --git a/standalone-container/pom.xml b/standalone-container/pom.xml index 92faa1ae670..844b912543c 100644 --- a/standalone-container/pom.xml +++ b/standalone-container/pom.xml @@ -113,7 +113,6 @@ model-evaluation-jar-with-dependencies.jar, model-integration-jar-with-dependencies.jar, container-onnxruntime.jar, - container-llama.jar, -- cgit v1.2.3