diff options
author | Lester Solbakken <lester.solbakken@gmail.com> | 2024-04-02 09:23:54 +0200 |
---|---|---|
committer | Lester Solbakken <lester.solbakken@gmail.com> | 2024-04-02 09:23:54 +0200 |
commit | 98ccab8d442bc3b13de47746bbd265a08b319add (patch) | |
tree | 6d0120815163943b1d6428585c1fc52efd9dc21d /container-search/src | |
parent | bc79caf0b639bc451c4630d49f6d9ac2a53dcc39 (diff) |
Move LLM classes in vespajlib from ai.vespa.llm to ai.vespa.languagemodels
Diffstat (limited to 'container-search/src')
9 files changed, 49 insertions, 53 deletions
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java index 662d73d4e01..f3bb29552d6 100644 --- a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java +++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java @@ -1,8 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.clients; -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.LanguageModel; +import ai.vespa.languagemodels.InferenceParameters; +import ai.vespa.languagemodels.LanguageModel; import ai.vespa.llm.LlmClientConfig; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; @@ -19,7 +19,7 @@ import java.util.logging.Logger; @Beta public abstract class ConfigurableLanguageModel implements LanguageModel { - private static Logger log = Logger.getLogger(ai.vespa.llm.clients.ConfigurableLanguageModel.class.getName()); + private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName()); private final String apiKey; private final String endpoint; @@ -68,7 +68,9 @@ public abstract class ConfigurableLanguageModel implements LanguageModel { } protected void setEndpoint(InferenceParameters params) { - params.setEndpoint(endpoint); + if (endpoint != null && ! endpoint.isEmpty()) { + params.setEndpoint(endpoint); + } } } diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java index f6092f51948..bc99aa51097 100644 --- a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -1,11 +1,11 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.clients; -import ai.vespa.llm.InferenceParameters; +import ai.vespa.languagemodels.InferenceParameters; import ai.vespa.llm.LlmClientConfig; -import ai.vespa.llm.client.openai.OpenAiClient; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; +import ai.vespa.languagemodels.client.openai.OpenAiClient; +import ai.vespa.languagemodels.completion.Completion; +import ai.vespa.languagemodels.completion.Prompt; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; import com.yahoo.container.jdisc.secretstore.SecretStore; diff --git a/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java index 6ff40401a8f..040393083b8 100755 --- a/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java @@ -1,12 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.search; -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.LanguageModelException; +import ai.vespa.languagemodels.InferenceParameters; +import ai.vespa.languagemodels.LanguageModel; +import ai.vespa.languagemodels.LanguageModelException; import ai.vespa.llm.LlmSearcherConfig; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; +import ai.vespa.languagemodels.completion.Completion; +import ai.vespa.languagemodels.completion.Prompt; +import ai.vespa.languagemodels.completion.StringPrompt; import com.yahoo.api.annotations.Beta; import com.yahoo.component.ComponentId; import com.yahoo.component.annotation.Inject; @@ -17,6 +18,7 @@ import com.yahoo.search.Searcher; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.EventStream; import com.yahoo.search.result.HitGroup; +import com.yahoo.search.searchchain.Execution; import java.util.List; import java.util.function.Function; @@ -29,7 +31,7 @@ import java.util.stream.Collectors; * @author lesters */ @Beta -public abstract class LLMSearcher extends Searcher { +public class LLMSearcher extends Searcher { private static Logger log = Logger.getLogger(LLMSearcher.class.getName()); @@ -43,13 +45,18 @@ public abstract class LLMSearcher extends Searcher { private final String languageModelId; @Inject - LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { + public LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { this.stream = config.stream(); this.languageModelId = config.providerId(); this.languageModel = findLanguageModel(languageModelId, languageModels); this.propertyPrefix = config.propertyPrefix(); } + @Override + public Result search(Query query, Execution execution) { + return complete(query, StringPrompt.from(getPrompt(query))); + } + private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels) throws IllegalArgumentException { @@ -74,7 +81,7 @@ public abstract class LLMSearcher extends Searcher { return languageModel; } - Result complete(Query query, Prompt prompt) { + protected Result complete(Query query, Prompt prompt) { var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); @@ -122,7 +129,7 @@ public abstract class LLMSearcher extends Searcher { return new Result(query, hitGroup); } - String getPrompt(Query query) { + public String getPrompt(Query query) { // Look for prompt with or without prefix String prompt = lookupPropertyWithOrWithoutPrefix(PROMPT_PROPERTY, p -> query.properties().getString(p)); if (prompt != null) @@ -138,28 +145,28 @@ public abstract class LLMSearcher extends Searcher { "'" + propertyPrefix + "." + PROMPT_PROPERTY + "', '" + PROMPT_PROPERTY + "' or '@query'."); } - String getPropertyPrefix() { + public String getPropertyPrefix() { return this.propertyPrefix; } - String lookupProperty(String property, Query query) { + public String lookupProperty(String property, Query query) { String propertyWithPrefix = this.propertyPrefix + "." + property; return query.properties().getString(propertyWithPrefix, null); } - Boolean lookupPropertyBool(String property, Query query, boolean defaultValue) { + public Boolean lookupPropertyBool(String property, Query query, boolean defaultValue) { String propertyWithPrefix = this.propertyPrefix + "." + property; return query.properties().getBoolean(propertyWithPrefix, defaultValue); } - String lookupPropertyWithOrWithoutPrefix(String property, Function<String, String> lookup) { + public String lookupPropertyWithOrWithoutPrefix(String property, Function<String, String> lookup) { String value = lookup.apply(getPropertyPrefix() + "." + property); if (value != null) return value; return lookup.apply(property); } - String getApiKeyHeader(Query query) { + public String getApiKeyHeader(Query query) { return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); } diff --git a/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java index b8e33778ced..ac3cdb04749 100755 --- a/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java +++ b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java @@ -1,10 +1,10 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.search; -import ai.vespa.llm.LanguageModel; +import ai.vespa.languagemodels.LanguageModel; import ai.vespa.llm.LlmSearcherConfig; -import ai.vespa.llm.completion.Prompt; -import ai.vespa.llm.completion.StringPrompt; +import ai.vespa.languagemodels.completion.Prompt; +import ai.vespa.languagemodels.completion.StringPrompt; import com.yahoo.api.annotations.Beta; import com.yahoo.component.annotation.Inject; import com.yahoo.component.provider.ComponentRegistry; diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java index 1f2a12322a1..ed786a0f372 100644 --- a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java +++ b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java @@ -1,11 +1,11 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.clients; -import ai.vespa.llm.InferenceParameters; +import ai.vespa.languagemodels.InferenceParameters; import ai.vespa.llm.LlmClientConfig; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import ai.vespa.llm.completion.StringPrompt; +import ai.vespa.languagemodels.completion.Completion; +import ai.vespa.languagemodels.completion.Prompt; +import ai.vespa.languagemodels.completion.StringPrompt; import com.yahoo.container.di.componentgraph.Provider; import com.yahoo.container.jdisc.SecretStoreProvider; import com.yahoo.container.jdisc.secretstore.SecretStore; diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java index cfb6a43984f..45a36bd1b6f 100644 --- a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java +++ b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java @@ -1,10 +1,10 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.clients; -import ai.vespa.llm.InferenceParameters; +import ai.vespa.languagemodels.InferenceParameters; import ai.vespa.llm.LlmClientConfig; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; +import ai.vespa.languagemodels.completion.Completion; +import ai.vespa.languagemodels.completion.Prompt; import com.yahoo.container.jdisc.secretstore.SecretStore; import java.util.List; diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java index 1111a9824f5..363833cd0c1 100644 --- a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java +++ b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java @@ -1,9 +1,9 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.clients; -import ai.vespa.llm.InferenceParameters; +import ai.vespa.languagemodels.InferenceParameters; import ai.vespa.llm.LlmClientConfig; -import ai.vespa.llm.completion.StringPrompt; +import ai.vespa.languagemodels.completion.StringPrompt; import com.yahoo.container.jdisc.SecretStoreProvider; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; diff --git a/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java index d4f1dbc00a4..0b4d334e4be 100755 --- a/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java +++ b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java @@ -1,14 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.search; -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.LanguageModel; +import ai.vespa.languagemodels.InferenceParameters; +import ai.vespa.languagemodels.LanguageModel; import ai.vespa.llm.LlmClientConfig; import ai.vespa.llm.LlmSearcherConfig; import ai.vespa.llm.clients.ConfigurableLanguageModelTest; import ai.vespa.llm.clients.MockLLMClient; -import ai.vespa.llm.completion.Prompt; -import ai.vespa.llm.completion.StringPrompt; +import ai.vespa.languagemodels.completion.Prompt; import com.yahoo.component.ComponentId; import com.yahoo.component.chain.Chain; import com.yahoo.component.provider.ComponentRegistry; @@ -229,26 +228,14 @@ public class LLMSearcherTest { ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); models.freeze(); - return new LLMSearcherImpl(config, models); + return new LLMSearcher(config, models); } private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) { ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); models.freeze(); - return new LLMSearcherImpl(config, models); - } - - public static class LLMSearcherImpl extends LLMSearcher { - - public LLMSearcherImpl(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { - super(config, languageModels); - } - - @Override - public Result search(Query query, Execution execution) { - return complete(query, StringPrompt.from(getPrompt(query))); - } + return new LLMSearcher(config, models); } } diff --git a/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java index ccf9a4a6401..41d999794cb 100755 --- a/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java +++ b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java @@ -1,7 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.search; -import ai.vespa.llm.LanguageModel; +import ai.vespa.languagemodels.LanguageModel; import ai.vespa.llm.LlmSearcherConfig; import com.yahoo.component.ComponentId; import com.yahoo.component.chain.Chain; |