-rw-r--r--  container-search/abi-spec.json                                                              |  55
-rw-r--r--  container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java          |  74
-rw-r--r--  container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java                             |  49
-rw-r--r--  container-search/src/main/java/ai/vespa/llm/clients/package-info.java                       |   7
-rwxr-xr-x  container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java                         | 166
-rwxr-xr-x  container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java                         |  75
-rw-r--r--  container-search/src/main/java/ai/vespa/llm/search/package-info.java                        |   7
-rwxr-xr-x  container-search/src/main/resources/configdefinitions/llm-client.def                        |   8
-rwxr-xr-x  container-search/src/main/resources/configdefinitions/llm-searcher.def                      |  11
-rw-r--r--  container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java      | 176
-rw-r--r--  container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java                      |  81
-rw-r--r--  container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java                         |  36
-rwxr-xr-x  container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java                     | 254
-rwxr-xr-x  container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java                     | 127
-rw-r--r--  flags/src/main/java/com/yahoo/vespa/flags/Flags.java                                        |   2
-rw-r--r--  metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java                     |   1
-rw-r--r--  node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java |   4
-rw-r--r--  searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp                                  | 150
-rw-r--r--  searchlib/src/vespa/searchlib/queryeval/flow.h                                               |  87
-rw-r--r--  streamingvisitors/src/vespa/vsm/searcher/tokenizereader.cpp                                 |  33
-rw-r--r--  streamingvisitors/src/vespa/vsm/searcher/tokenizereader.h                                   |   4
-rw-r--r--  streamingvisitors/src/vespa/vsm/searcher/utf8strchrfieldsearcher.cpp                        |   3
-rw-r--r--  streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.cpp                    |  21
-rw-r--r--  streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.h                      |   3
-rw-r--r--  streamingvisitors/src/vespa/vsm/searcher/utf8suffixstringfieldsearcher.cpp                  |   3
-rw-r--r--  streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp                                     |   9
-rw-r--r--  streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h                                       |   1
-rw-r--r--  vespajlib/abi-spec.json                                                                     |  48
-rwxr-xr-x  vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java                               |  76
-rw-r--r--  vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java                                     |   6
-rwxr-xr-x  vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java                            |  19
-rw-r--r--  vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java                        |  89
-rw-r--r--  vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java                             |  11
-rw-r--r--  vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java                            |   7
-rw-r--r--  vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java          |  34
-rw-r--r--  vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java                         |   4
36 files changed, 1560 insertions(+), 181 deletions(-)
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index bdb6cd9e7a5..257dd364000 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -9148,5 +9148,60 @@
"public int getTo()"
],
"fields" : [ ]
+ },
+ "ai.vespa.llm.clients.ConfigurableLanguageModel" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "ai.vespa.llm.LanguageModel"
+ ],
+ "attributes" : [
+ "public",
+ "abstract"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.llm.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
+ "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)",
+ "protected void setApiKey(ai.vespa.llm.InferenceParameters)",
+ "protected java.lang.String getEndpoint()",
+ "protected void setEndpoint(ai.vespa.llm.InferenceParameters)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.OpenAI" : {
+ "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(ai.vespa.llm.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
+ "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.search.LLMSearcher" : {
+ "superClass" : "com.yahoo.search.Searcher",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "abstract"
+ ],
+ "methods" : [ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.search.RAGSearcher" : {
+ "superClass" : "ai.vespa.llm.search.LLMSearcher",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(ai.vespa.llm.LlmSearcherConfig, com.yahoo.component.provider.ComponentRegistry)",
+ "public com.yahoo.search.Result search(com.yahoo.search.Query, com.yahoo.search.searchchain.Execution)",
+ "protected ai.vespa.llm.completion.Prompt buildPrompt(com.yahoo.search.Query, com.yahoo.search.Result)"
+ ],
+ "fields" : [ ]
}
 }
\ No newline at end of file
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
new file mode 100644
index 00000000000..662d73d4e01
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
@@ -0,0 +1,74 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmClientConfig;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.logging.Logger;
+
+
+/**
+ * Base class for language models that can be configured with config definitions.
+ *
+ * @author lesters
+ */
+@Beta
+public abstract class ConfigurableLanguageModel implements LanguageModel {
+
+    private static Logger log = Logger.getLogger(ConfigurableLanguageModel.class.getName());
+
+ private final String apiKey;
+ private final String endpoint;
+
+ public ConfigurableLanguageModel() {
+ this.apiKey = null;
+ this.endpoint = null;
+ }
+
+ @Inject
+ public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) {
+ this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore);
+ this.endpoint = config.endpoint();
+ }
+
+ private static String findApiKeyInSecretStore(String property, SecretStore secretStore) {
+ String apiKey = "";
+ if (property != null && ! property.isEmpty()) {
+ try {
+ apiKey = secretStore.getSecret(property);
+ } catch (UnsupportedOperationException e) {
+ // Secret store is not available - silently ignore this
+ } catch (Exception e) {
+                log.warning("Secret store lookup failed: " + e.getMessage() + "\n" +
+ "Will expect API key in request header");
+ }
+ }
+ return apiKey;
+ }
+
+ protected String getApiKey(InferenceParameters params) {
+ return params.getApiKey().orElse(null);
+ }
+
+ /**
+     * Sets the API key retrieved from the secret store, unless one is already set.
+ */
+ protected void setApiKey(InferenceParameters params) {
+ if (params.getApiKey().isEmpty() && apiKey != null) {
+ params.setApiKey(apiKey);
+ }
+ }
+
+ protected String getEndpoint() {
+ return endpoint;
+ }
+
+ protected void setEndpoint(InferenceParameters params) {
+ params.setEndpoint(endpoint);
+ }
+
+}
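
As a sketch of how this base class is meant to be extended (illustrative only, not part of this change; "EchoLanguageModel" and its trivial completion logic are made up), a subclass resolves the configured credentials into the request parameters and then delegates to its client:

    // Hypothetical subclass, for illustration only.
    package ai.vespa.llm.clients;

    import ai.vespa.llm.InferenceParameters;
    import ai.vespa.llm.LlmClientConfig;
    import ai.vespa.llm.completion.Completion;
    import ai.vespa.llm.completion.Prompt;
    import com.yahoo.component.annotation.Inject;
    import com.yahoo.container.jdisc.secretstore.SecretStore;

    import java.util.List;
    import java.util.concurrent.CompletableFuture;
    import java.util.function.Consumer;

    public class EchoLanguageModel extends ConfigurableLanguageModel {

        @Inject
        public EchoLanguageModel(LlmClientConfig config, SecretStore secretStore) {
            super(config, secretStore);
        }

        @Override
        public List<Completion> complete(Prompt prompt, InferenceParameters params) {
            setApiKey(params);    // use the secret store key unless the request already supplied one
            setEndpoint(params);  // point the client at the configured endpoint
            return List.of(Completion.from(prompt.asString()));  // a real client would call out here
        }

        @Override
        public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
                                                                        InferenceParameters params,
                                                                        Consumer<Completion> consumer) {
            setApiKey(params);
            setEndpoint(params);
            consumer.accept(Completion.from(prompt.asString()));
            return CompletableFuture.completedFuture(Completion.FinishReason.stop);
        }
    }
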
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
new file mode 100644
index 00000000000..f6092f51948
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
@@ -0,0 +1,49 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.client.openai.OpenAiClient;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+
+/**
+ * A configurable OpenAI client.
+ *
+ * @author lesters
+ */
+@Beta
+public class OpenAI extends ConfigurableLanguageModel {
+
+ private final OpenAiClient client;
+
+ @Inject
+ public OpenAI(LlmClientConfig config, SecretStore secretStore) {
+ super(config, secretStore);
+ client = new OpenAiClient();
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters parameters) {
+ setApiKey(parameters);
+ setEndpoint(parameters);
+ return client.complete(prompt, parameters);
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters parameters,
+ Consumer<Completion> consumer) {
+ setApiKey(parameters);
+ setEndpoint(parameters);
+ return client.completeAsync(prompt, parameters, consumer);
+ }
+}
+
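
A minimal usage sketch (not from the patch; the key is a placeholder and is passed per request instead of through the secret store):

    // Illustrative only: synchronous completion with a per-request API key.
    var config = new LlmClientConfig.Builder().build();
    var openai = new OpenAI(config, new SecretStoreProvider().get());
    var params = new InferenceParameters("<your-api-key>", Map.of("maxTokens", "10")::get);
    var completions = openai.complete(StringPrompt.from("Why are ducks better than cats?"), params);
    System.out.println(completions.get(0).text());
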
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
new file mode 100644
index 00000000000..c360245901c
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
@@ -0,0 +1,7 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.clients;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java
new file mode 100755
index 00000000000..6ff40401a8f
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java
@@ -0,0 +1,166 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LanguageModelException;
+import ai.vespa.llm.LlmSearcherConfig;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.ComponentId;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.EventStream;
+import com.yahoo.search.result.HitGroup;
+
+import java.util.List;
+import java.util.function.Function;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+/**
+ * Base class for LLM searchers. Provides utilities for calling LLMs and handling properties.
+ *
+ * @author lesters
+ */
+@Beta
+public abstract class LLMSearcher extends Searcher {
+
+ private static Logger log = Logger.getLogger(LLMSearcher.class.getName());
+
+ private static final String API_KEY_HEADER = "X-LLM-API-KEY";
+ private static final String STREAM_PROPERTY = "stream";
+ private static final String PROMPT_PROPERTY = "prompt";
+
+ private final String propertyPrefix;
+ private final boolean stream;
+ private final LanguageModel languageModel;
+ private final String languageModelId;
+
+ @Inject
+ LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
+ this.stream = config.stream();
+ this.languageModelId = config.providerId();
+ this.languageModel = findLanguageModel(languageModelId, languageModels);
+ this.propertyPrefix = config.propertyPrefix();
+ }
+
+ private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels)
+ throws IllegalArgumentException
+ {
+ if (languageModels.allComponents().isEmpty()) {
+ throw new IllegalArgumentException("No language models were found");
+ }
+ if (providerId == null || providerId.isEmpty()) {
+ var entry = languageModels.allComponentsById().entrySet().stream().findFirst();
+ if (entry.isEmpty()) {
+ throw new IllegalArgumentException("No language models were found"); // shouldn't happen given check above
+ }
+            log.info("Language model provider was not found in config. " +
+                     "Falling back to the first available language model: " + entry.get().getKey());
+ return entry.get().getValue();
+ }
+ final LanguageModel languageModel = languageModels.getComponent(providerId);
+ if (languageModel == null) {
+ throw new IllegalArgumentException("No component with id '" + providerId + "' was found. " +
+ "Available LLM components are: " + languageModels.allComponentsById().keySet().stream()
+ .map(ComponentId::toString).collect(Collectors.joining(",")));
+ }
+ return languageModel;
+ }
+
+ Result complete(Query query, Prompt prompt) {
+ var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
+        var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream);  // a query property overrides the config value
+ return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ }
+
+ private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
+ EventStream eventStream = new EventStream();
+
+ if (query.getTrace().getLevel() >= 1) {
+ eventStream.add(prompt.asString(), "prompt");
+ }
+
+ languageModel.completeAsync(prompt, options, token -> {
+ eventStream.add(token.text());
+ }).exceptionally(exception -> {
+ int errorCode = 400;
+ if (exception instanceof LanguageModelException languageModelException) {
+ errorCode = languageModelException.code();
+ }
+ eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ eventStream.markComplete();
+ return Completion.FinishReason.error;
+ }).thenAccept(finishReason -> {
+ eventStream.markComplete();
+ });
+
+ HitGroup hitGroup = new HitGroup("token_stream");
+ hitGroup.add(eventStream);
+ return new Result(query, hitGroup);
+ }
+
+ private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
+ EventStream eventStream = new EventStream();
+
+ if (query.getTrace().getLevel() >= 1) {
+ eventStream.add(prompt.asString(), "prompt");
+ }
+
+ List<Completion> completions = languageModel.complete(prompt, options);
+ eventStream.add(completions.get(0).text(), "completion");
+ eventStream.markComplete();
+
+ HitGroup hitGroup = new HitGroup("completion");
+ hitGroup.add(eventStream);
+ return new Result(query, hitGroup);
+ }
+
+ String getPrompt(Query query) {
+ // Look for prompt with or without prefix
+ String prompt = lookupPropertyWithOrWithoutPrefix(PROMPT_PROPERTY, p -> query.properties().getString(p));
+ if (prompt != null)
+ return prompt;
+
+ // If not found, use query directly
+ prompt = query.getModel().getQueryString();
+ if (prompt != null)
+ return prompt;
+
+ // If not, throw exception
+        throw new IllegalArgumentException("Could not find a prompt for the query. Tried looking for " +
+ "'" + propertyPrefix + "." + PROMPT_PROPERTY + "', '" + PROMPT_PROPERTY + "' or '@query'.");
+ }
+
+ String getPropertyPrefix() {
+ return this.propertyPrefix;
+ }
+
+ String lookupProperty(String property, Query query) {
+ String propertyWithPrefix = this.propertyPrefix + "." + property;
+ return query.properties().getString(propertyWithPrefix, null);
+ }
+
+ Boolean lookupPropertyBool(String property, Query query, boolean defaultValue) {
+ String propertyWithPrefix = this.propertyPrefix + "." + property;
+ return query.properties().getBoolean(propertyWithPrefix, defaultValue);
+ }
+
+ String lookupPropertyWithOrWithoutPrefix(String property, Function<String, String> lookup) {
+ String value = lookup.apply(getPropertyPrefix() + "." + property);
+ if (value != null)
+ return value;
+ return lookup.apply(property);
+ }
+
+ String getApiKeyHeader(Query query) {
+ return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
+ }
+
+}
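
To illustrate the property scheme above (a sketch with placeholder values): with the default "llm" prefix, the prompt, the streaming flag and the API key can all be supplied per request, the first two as query properties and the last as a header:

    // Illustrative only: prefixed query properties plus the API key header.
    Query query = new Query("?llm.prompt=" +
            URLEncoder.encode("why are ducks better than cats", StandardCharsets.UTF_8) +
            "&llm.stream=true");
    query.getHttpRequest().getJDiscRequest().headers().add("X-LLM-API-KEY", "<key>");
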
diff --git a/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java
new file mode 100755
index 00000000000..b8e33778ced
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java
@@ -0,0 +1,75 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmSearcherConfig;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.searchchain.Execution;
+
+import java.util.logging.Logger;
+
+/**
+ * An LLM searcher that generates completions using retrieval-augmented generation (RAG).
+ * Prompts are generated based on the search result context.
+ * By default, the context is a concatenation of the fields of the search result hits.
+ *
+ * @author lesters
+ */
+@Beta
+public class RAGSearcher extends LLMSearcher {
+
+ private static Logger log = Logger.getLogger(RAGSearcher.class.getName());
+
+ private static final String CONTEXT_PROPERTY = "context";
+
+ @Inject
+ public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
+ super(config, languageModels);
+ log.info("Starting " + RAGSearcher.class.getName() + " with language model " + config.providerId());
+ }
+
+ @Override
+ public Result search(Query query, Execution execution) {
+ Result result = execution.search(query);
+ execution.fill(result);
+ return complete(query, buildPrompt(query, result));
+ }
+
+ protected Prompt buildPrompt(Query query, Result result) {
+ String prompt = getPrompt(query);
+
+ // Replace @query with the actual query
+ if (prompt.contains("@query")) {
+ prompt = prompt.replace("@query", query.getModel().getQueryString());
+ }
+
+ String context = lookupProperty(CONTEXT_PROPERTY, query);
+ if (context == null || !context.equals("skip")) {
+ if ( !prompt.contains("{context}")) {
+ prompt = "{context}\n" + prompt;
+ }
+ prompt = prompt.replace("{context}", buildContext(result));
+ }
+ return StringPrompt.from(prompt);
+ }
+
+ private String buildContext(Result result) {
+ StringBuilder sb = new StringBuilder();
+ var hits = result.hits();
+ hits.forEach(hit -> {
+ hit.fields().forEach((key, value) -> {
+ sb.append(key).append(": ").append(value).append("\n");
+ });
+ sb.append("\n");
+ });
+        return sb.toString();
+ }
+
+}
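
A worked illustration of buildPrompt (the template and field values here are hypothetical): given llm.prompt={context}\nAnswer the question: @query and a single hit with fields title and content, the generated prompt has the shape

    title: <title value>
    content: <content value>

    Answer the question: <the query string>

while llm.context=skip suppresses the context block and sends the prompt through as-is.
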
diff --git a/container-search/src/main/java/ai/vespa/llm/search/package-info.java b/container-search/src/main/java/ai/vespa/llm/search/package-info.java
new file mode 100644
index 00000000000..6a8975fd2fa
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/llm/search/package-info.java
@@ -0,0 +1,7 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.search;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def
new file mode 100755
index 00000000000..6bfd95c3cf2
--- /dev/null
+++ b/container-search/src/main/resources/configdefinitions/llm-client.def
@@ -0,0 +1,8 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm
+
+# The name of the secret containing the API key
+apiKeySecretName string default=""
+
+# Endpoint for the LLM client - if not set, the client's default endpoint is used
+endpoint string default=""
diff --git a/container-search/src/main/resources/configdefinitions/llm-searcher.def b/container-search/src/main/resources/configdefinitions/llm-searcher.def
new file mode 100755
index 00000000000..918a6e6e8b1
--- /dev/null
+++ b/container-search/src/main/resources/configdefinitions/llm-searcher.def
@@ -0,0 +1,11 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm
+
+# Query property prefix for options
+propertyPrefix string default="llm"
+
+# Should the searcher stream tokens as they are generated, or wait for the entire completion?
+stream bool default=true
+
+# The external LLM provider - the id of a LanguageModel component
+providerId string default=""
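
A sketch of how these definitions could be wired up in services.xml (assuming standard Vespa component configuration; the ids, chain setup and secret name are placeholders, and the config names follow from package=ai.vespa.llm plus the .def file names):

    <container version="1.0">
        <component id="openai" class="ai.vespa.llm.clients.OpenAI">
            <config name="ai.vespa.llm.llm-client">
                <apiKeySecretName>openai-api-key</apiKeySecretName>
            </config>
        </component>
        <search>
            <chain id="rag" inherits="vespa">
                <searcher id="ai.vespa.llm.search.RAGSearcher">
                    <config name="ai.vespa.llm.llm-searcher">
                        <providerId>openai</providerId>
                    </config>
                </searcher>
            </chain>
        </search>
    </container>
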
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
new file mode 100644
index 00000000000..1f2a12322a1
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
@@ -0,0 +1,176 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.container.di.componentgraph.Provider;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class ConfigurableLanguageModelTest {
+
+ @Test
+ public void testSyncGeneration() {
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey());
+ assertEquals(1, result.size());
+ assertEquals("Ducks have adorable waddling walks.", result.get(0).text());
+ }
+
+ @Test
+ public void testAsyncGeneration() {
+ var executor = Executors.newFixedThreadPool(1);
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var sb = new StringBuilder();
+ try {
+ var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> {
+ sb.append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+
+ var reason = future.join();
+ assertTrue(future.isDone());
+ assertNotEquals(reason, Completion.FinishReason.error);
+ } finally {
+ executor.shutdownNow();
+ }
+
+ assertEquals("Ducks have adorable waddling walks.", sb.toString());
+ }
+
+ @Test
+ public void testInferenceParameters() {
+ var prompt = StringPrompt.from("Why are ducks better than cats?");
+ var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4"));
+ var result = createLLM().complete(prompt, params);
+ assertEquals("Random text about ducks", result.get(0).text());
+ }
+
+ @Test
+ public void testNoApiKey() {
+ var prompt = StringPrompt.from("");
+ var config = modelParams("api-key", null);
+ var secrets = createSecretStore(Map.of());
+ assertThrows(IllegalArgumentException.class, () -> {
+ createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams());
+ });
+ }
+
+ @Test
+ public void testApiKeyFromSecretStore() {
+ var prompt = StringPrompt.from("");
+ var config = modelParams("api-key-in-secret-store", null);
+ var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY));
+ assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); });
+ }
+
+ private static String lookupParameter(String parameter, Map<String, String> params) {
+ return params.get(parameter);
+ }
+
+ private static InferenceParameters inferenceParams() {
+ return new InferenceParameters(s -> lookupParameter(s, Collections.emptyMap()));
+ }
+
+ private static InferenceParameters inferenceParams(Map<String, String> params) {
+ return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params));
+ }
+
+ private static InferenceParameters inferenceParamsWithDefaultKey() {
+ return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Collections.emptyMap()));
+ }
+
+ private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) {
+ var config = new LlmClientConfig.Builder();
+ if (apiKeySecretName != null) {
+ config.apiKeySecretName(apiKeySecretName);
+ }
+ if (endpoint != null) {
+ config.endpoint(endpoint);
+ }
+ return config.build();
+ }
+
+ public static SecretStore createSecretStore(Map<String, String> secrets) {
+ Provider<SecretStore> secretStore = new Provider<>() {
+ public SecretStore get() {
+ return new SecretStore() {
+ public String getSecret(String key) {
+ return secrets.get(key);
+ }
+ public String getSecret(String key, int version) {
+ return secrets.get(key);
+ }
+ };
+ }
+ public void deconstruct() {
+ }
+ };
+ return secretStore.get();
+ }
+
+ public static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
+ return (prompt, options) -> {
+ String answer = "I have no opinion on the matter";
+ if (prompt.asString().contains("ducks")) {
+ answer = "Ducks have adorable waddling walks.";
+ var temperature = options.getDouble("temperature");
+ if (temperature.isPresent() && temperature.get() > 0.5) {
+ answer = "Random text about ducks vs cats that makes no sense whatsoever.";
+ }
+ }
+ var maxTokens = options.getInt("maxTokens");
+ if (maxTokens.isPresent()) {
+ return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
+ }
+ return answer;
+ };
+ }
+
+ private static MockLLMClient createLLM() {
+ LlmClientConfig config = new LlmClientConfig.Builder().build();
+ return createLLM(config, null);
+ }
+
+ private static MockLLMClient createLLM(ExecutorService executor) {
+ LlmClientConfig config = new LlmClientConfig.Builder().build();
+ return createLLM(config, executor);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) {
+ var generator = createGenerator();
+ var secretStore = new SecretStoreProvider(); // throws exception on use
+ return createLLM(config, generator, secretStore.get(), executor);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ SecretStore secretStore) {
+ return createLLM(config, generator, secretStore, null);
+ }
+
+ private static MockLLMClient createLLM(LlmClientConfig config,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ SecretStore secretStore,
+ ExecutorService executor) {
+ return new MockLLMClient(config, secretStore, generator, executor);
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
new file mode 100644
index 00000000000..cfb6a43984f
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
@@ -0,0 +1,81 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.container.jdisc.secretstore.SecretStore;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
+
+public class MockLLMClient extends ConfigurableLanguageModel {
+
+ public final static String ACCEPTED_API_KEY = "sesame";
+
+ private final ExecutorService executor;
+ private final BiFunction<Prompt, InferenceParameters, String> generator;
+
+ private Prompt lastPrompt;
+
+ public MockLLMClient(LlmClientConfig config,
+ SecretStore secretStore,
+ BiFunction<Prompt, InferenceParameters, String> generator,
+ ExecutorService executor) {
+ super(config, secretStore);
+ this.generator = generator;
+ this.executor = executor;
+ }
+
+ private void checkApiKey(InferenceParameters options) {
+ var apiKey = getApiKey(options);
+ if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) {
+ throw new IllegalArgumentException("Invalid API key");
+ }
+ }
+
+ private void setPrompt(Prompt prompt) {
+ this.lastPrompt = prompt;
+ }
+
+ public Prompt getPrompt() {
+ return this.lastPrompt;
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters params) {
+ setApiKey(params);
+ checkApiKey(params);
+ setPrompt(prompt);
+ return List.of(Completion.from(this.generator.apply(prompt, params)));
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters params,
+ Consumer<Completion> consumer) {
+ setPrompt(prompt);
+ var completionFuture = new CompletableFuture<Completion.FinishReason>();
+ var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization
+
+ long sleep = 1;
+ executor.submit(() -> {
+ try {
+ for (int i=0; i < completions.length; ++i) {
+ String completion = (i > 0 ? " " : "") + completions[i];
+                    consumer.accept(Completion.from(completion, Completion.FinishReason.none));
+                    Thread.sleep(sleep);
+ }
+ completionFuture.complete(Completion.FinishReason.stop);
+ } catch (InterruptedException e) {
+ // Do nothing
+ }
+ });
+
+ return completionFuture;
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
new file mode 100644
index 00000000000..1111a9824f5
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
@@ -0,0 +1,36 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+public class OpenAITest {
+
+ private static final String apiKey = "<your-api-key>";
+
+ @Test
+ @Disabled
+ public void testOpenAIGeneration() {
+ var config = new LlmClientConfig.Builder().build();
+ var openai = new OpenAI(config, new SecretStoreProvider().get());
+ var options = Map.of(
+ "maxTokens", "10"
+ );
+
+ var prompt = StringPrompt.from("why are ducks better than cats?");
+ var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> {
+ System.out.print(completion.text());
+ }).exceptionally(exception -> {
+ System.out.println("Error: " + exception);
+ return null;
+ });
+ future.join();
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java
new file mode 100755
index 00000000000..d4f1dbc00a4
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java
@@ -0,0 +1,254 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmClientConfig;
+import ai.vespa.llm.LlmSearcherConfig;
+import ai.vespa.llm.clients.ConfigurableLanguageModelTest;
+import ai.vespa.llm.clients.MockLLMClient;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.component.ComponentId;
+import com.yahoo.component.chain.Chain;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.container.jdisc.SecretStoreProvider;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.result.EventStream;
+import com.yahoo.search.searchchain.Execution;
+import org.junit.jupiter.api.Test;
+
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+
+public class LLMSearcherTest {
+
+ @Test
+ public void testLLMSelection() {
+ var llm1 = createLLMClient("mock1");
+ var llm2 = createLLMClient("mock2");
+ var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build();
+ var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2));
+ var result = runMockSearch(searcher, Map.of("prompt", "what is your id?"));
+ assertEquals(1, result.getHitCount());
+ assertEquals("My id is mock2", getCompletion(result));
+ }
+
+ @Test
+ public void testGeneration() {
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var params = Map.of("prompt", "why are ducks better than cats");
+ assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params)));
+ }
+
+ @Test
+ public void testPrompting() {
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+
+ // Prompt with prefix
+ assertEquals("Ducks have adorable waddling walks.",
+ getCompletion(runMockSearch(searcher, Map.of("llm.prompt", "why are ducks better than cats"))));
+
+ // Prompt without prefix
+ assertEquals("Ducks have adorable waddling walks.",
+ getCompletion(runMockSearch(searcher, Map.of("prompt", "why are ducks better than cats"))));
+
+ // Fallback to query if not given
+ assertEquals("Ducks have adorable waddling walks.",
+ getCompletion(runMockSearch(searcher, Map.of("query", "why are ducks better than cats"))));
+ }
+
+ @Test
+ public void testPromptEvent() {
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var params = Map.of(
+ "prompt", "why are ducks better than cats",
+ "traceLevel", "1");
+ var result = runMockSearch(searcher, params);
+ var events = ((EventStream) result.hits().get(0)).incoming().drain();
+ assertEquals(2, events.size());
+
+ var promptEvent = (EventStream.Event) events.get(0);
+ assertEquals("prompt", promptEvent.type());
+ assertEquals("why are ducks better than cats", promptEvent.toString());
+
+ var completionEvent = (EventStream.Event) events.get(1);
+ assertEquals("completion", completionEvent.type());
+ assertEquals("Ducks have adorable waddling walks.", completionEvent.toString());
+ }
+
+ @Test
+ public void testParameters() {
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var params = Map.of(
+ "llm.prompt", "why are ducks better than cats",
+ "llm.temperature", "1.0",
+ "llm.maxTokens", "5"
+ );
+ assertEquals("Random text about ducks vs", getCompletion(runMockSearch(searcher, params)));
+ }
+
+ @Test
+ public void testParameterPrefix() {
+ var prefix = "foo";
+ var params = Map.of(
+ "foo.prompt", "what is your opinion on cats",
+ "foo.maxTokens", "5"
+ );
+ var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build();
+ var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient()));
+ assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params)));
+ }
+
+ @Test
+ public void testApiKeyFromHeader() {
+ var properties = Map.of("prompt", "why are ducks better than cats");
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore()));
+ assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm"));
+ assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm"));
+ }
+
+ @Test
+ public void testAsyncGeneration() {
+ var executor = Executors.newFixedThreadPool(1);
+ var sb = new StringBuilder();
+ try {
+ var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock").build(); // config says don't stream...
+ var params = Map.of(
+                "llm.stream", "true", // ... but the query parameters say to stream anyway
+ "llm.prompt", "why are ducks better than cats?"
+ );
+ var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor)));
+ Result result = runMockSearch(searcher, params);
+
+ assertEquals(1, result.getHitCount());
+ assertTrue(result.hits().get(0) instanceof EventStream);
+ EventStream eventStream = (EventStream) result.hits().get(0);
+
+ var incoming = eventStream.incoming();
+ incoming.addNewDataListener(() -> {
+ incoming.drain().forEach(event -> sb.append(event.toString()));
+ }, executor);
+
+ incoming.completedFuture().join();
+ assertTrue(incoming.isComplete());
+
+ // Ensure incoming has been fully drained to avoid race condition in this test
+ incoming.drain().forEach(event -> sb.append(event.toString()));
+
+ } finally {
+ executor.shutdownNow();
+ }
+ assertEquals("Ducks have adorable waddling walks.", sb.toString());
+ }
+
+ private static String getCompletion(Result result) {
+ assertTrue(result.hits().size() >= 1);
+ return ((EventStream) result.hits().get(0)).incoming().drain().get(0).toString();
+ }
+
+ static Result runMockSearch(Searcher searcher, Map<String, String> parameters) {
+ return runMockSearch(searcher, parameters, null, "");
+ }
+
+ static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) {
+ Chain<Searcher> chain = new Chain<>(searcher);
+ Execution execution = new Execution(chain, Execution.Context.createContextStub());
+ Query query = new Query("?" + toUrlParams(parameters));
+ if (apiKey != null) {
+ String headerKey = "X-LLM-API-KEY";
+ if (prefix != null && ! prefix.isEmpty()) {
+ headerKey = prefix + "." + headerKey;
+ }
+ query.getHttpRequest().getJDiscRequest().headers().add(headerKey, apiKey);
+ }
+ return execution.search(query);
+ }
+
+ public static String toUrlParams(Map<String, String> parameters) {
+ return parameters.entrySet().stream().map(
+ e -> e.getKey() + "=" + URLEncoder.encode(e.getValue(), StandardCharsets.UTF_8)
+ ).collect(Collectors.joining("&"));
+ }
+
+ private static BiFunction<Prompt, InferenceParameters, String> createIdGenerator(String id) {
+ return (prompt, options) -> {
+ if (id == null || id.isEmpty())
+ return "I have no ID";
+ return "My id is " + id;
+ };
+ }
+
+ private static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
+ return ConfigurableLanguageModelTest.createGenerator();
+ }
+
+ static MockLLMClient createLLMClient() {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore, generator, null);
+ }
+
+ static MockLLMClient createLLMClient(String id) {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createIdGenerator(id);
+ return new MockLLMClient(config, secretStore, generator, null);
+ }
+
+ static MockLLMClient createLLMClient(ExecutorService executor) {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore, generator, executor);
+ }
+
+ static MockLLMClient createLLMClientWithoutSecretStore() {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = new SecretStoreProvider();
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore.get(), generator, null);
+ }
+
+ private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) {
+ var config = new LlmSearcherConfig.Builder().stream(false).build();
+        return createLLMSearcher(config, llms);
+ }
+
+ private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) {
+ ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
+ llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
+ models.freeze();
+ return new LLMSearcherImpl(config, models);
+ }
+
+ public static class LLMSearcherImpl extends LLMSearcher {
+
+ public LLMSearcherImpl(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
+ super(config, languageModels);
+ }
+
+ @Override
+ public Result search(Query query, Execution execution) {
+ return complete(query, StringPrompt.from(getPrompt(query)));
+ }
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java
new file mode 100755
index 00000000000..ccf9a4a6401
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java
@@ -0,0 +1,127 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.search;
+
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.LlmSearcherConfig;
+import com.yahoo.component.ComponentId;
+import com.yahoo.component.chain.Chain;
+import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.Searcher;
+import com.yahoo.search.result.EventStream;
+import com.yahoo.search.result.Hit;
+import com.yahoo.search.searchchain.Execution;
+import org.junit.jupiter.api.Test;
+
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+
+public class RAGSearcherTest {
+
+ private static final String DOC1_TITLE = "Exploring the Delightful Qualities of Ducks";
+ private static final String DOC1_CONTENT = "Ducks, with their gentle quacks and adorable waddling walks, possess a unique " +
+ "charm that sets them apart as extraordinary pets.";
+ private static final String DOC2_TITLE = "Why Cats Reign Supreme";
+ private static final String DOC2_CONTENT = "Cats bring an enchanting allure to households with their independent " +
+ "companionship, playful nature, natural hunting abilities, low-maintenance grooming, and the " +
+ "emotional support they offer.";
+
+ @Test
+ public void testRAGGeneration() {
+ var eventStream = runRAGQuery(Map.of(
+ "prompt", "why are ducks better than cats?",
+ "traceLevel", "1"));
+ var events = eventStream.incoming().drain();
+ assertEquals(2, events.size());
+
+ // Generated prompt
+ var promptEvent = (EventStream.Event) events.get(0);
+ assertEquals("prompt", promptEvent.type());
+ assertEquals("title: " + DOC1_TITLE + "\n" +
+ "content: " + DOC1_CONTENT + "\n\n" +
+ "title: " + DOC2_TITLE + "\n" +
+ "content: " + DOC2_CONTENT + "\n\n\n" +
+ "why are ducks better than cats?", promptEvent.toString());
+
+ // Generated completion
+ var completionEvent = (EventStream.Event) events.get(1);
+ assertEquals("completion", completionEvent.type());
+ assertEquals("Ducks have adorable waddling walks.", completionEvent.toString());
+ }
+
+ @Test
+ public void testPromptGeneration() {
+ var eventStream = runRAGQuery(Map.of(
+ "query", "why are ducks better than cats?",
+ "prompt", "{context}\nGiven these documents, answer this query as concisely as possible: @query",
+ "traceLevel", "1"));
+ var events = eventStream.incoming().drain();
+
+ var promptEvent = (EventStream.Event) events.get(0);
+ assertEquals("prompt", promptEvent.type());
+ assertEquals("title: " + DOC1_TITLE + "\n" +
+ "content: " + DOC1_CONTENT + "\n\n" +
+ "title: " + DOC2_TITLE + "\n" +
+ "content: " + DOC2_CONTENT + "\n\n\n" +
+ "Given these documents, answer this query as concisely as possible: " +
+ "why are ducks better than cats?", promptEvent.toString());
+ }
+
+ @Test
+ public void testSkipContextInPrompt() {
+ var eventStream = runRAGQuery(Map.of(
+ "query", "why are ducks better than cats?",
+ "llm.context", "skip",
+ "traceLevel", "1"));
+ var events = eventStream.incoming().drain();
+
+ var promptEvent = (EventStream.Event) events.get(0);
+ assertEquals("prompt", promptEvent.type());
+ assertEquals("why are ducks better than cats?", promptEvent.toString());
+ }
+
+ public static class MockSearchResults extends Searcher {
+
+ @Override
+ public Result search(Query query, Execution execution) {
+ Hit hit1 = new Hit("1");
+ hit1.setField("title", DOC1_TITLE);
+ hit1.setField("content", DOC1_CONTENT);
+
+ Hit hit2 = new Hit("2");
+ hit2.setField("title", DOC2_TITLE);
+ hit2.setField("content", DOC2_CONTENT);
+
+ Result r = new Result(query);
+ r.hits().add(hit1);
+ r.hits().add(hit2);
+ return r;
+ }
+ }
+
+ private EventStream runRAGQuery(Map<String, String> params) {
+ var llm = LLMSearcherTest.createLLMClient();
+ var searcher = createRAGSearcher(Map.of("mock", llm));
+ var result = runMockSearch(searcher, params);
+ return (EventStream) result.hits().get(0);
+ }
+
+ static Result runMockSearch(Searcher searcher, Map<String, String> parameters) {
+ Chain<Searcher> chain = new Chain<>(searcher, new MockSearchResults());
+ Execution execution = new Execution(chain, Execution.Context.createContextStub());
+ Query query = new Query("?" + LLMSearcherTest.toUrlParams(parameters));
+ return execution.search(query);
+ }
+
+ private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
+ var config = new LlmSearcherConfig.Builder().stream(false).build();
+ ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
+ llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
+ models.freeze();
+ return new RAGSearcher(config, models);
+ }
+
+}
diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
index 37fe29c2eb6..0b30901bb89 100644
--- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
+++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
@@ -406,7 +406,7 @@ public class Flags {
"Takes effect immediately");
public static UnboundBooleanFlag CALYPSO_ENABLED = defineFeatureFlag(
- "calypso-enabled", false,
+ "calypso-enabled", true,
List.of("mortent"), "2024-02-19", "2024-05-01",
"Whether to enable calypso for host",
"Takes effect immediately", HOSTNAME);
diff --git a/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java b/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java
index 38a1b252df9..9479c814e89 100644
--- a/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java
+++ b/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java
@@ -81,6 +81,7 @@ public class InfrastructureMetricSet {
addMetric(metrics, ConfigServerMetrics.HAS_WIRE_GUARD_KEY.max());
addMetric(metrics, ConfigServerMetrics.WANT_TO_DEPROVISION.max());
addMetric(metrics, ConfigServerMetrics.SUSPENDED.max());
+ addMetric(metrics, ConfigServerMetrics.SUSPENDED_SECONDS.count());
addMetric(metrics, ConfigServerMetrics.SOME_SERVICES_DOWN.max());
addMetric(metrics, ConfigServerMetrics.NODE_FAILER_BAD_NODE.max());
addMetric(metrics, ConfigServerMetrics.LOCK_ATTEMPT_LOCKED_LOAD, EnumSet.of(max,average));
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java
index 69ae4fddb63..e3d72d1189e 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java
@@ -169,8 +169,6 @@ public class MetricsReporter extends NodeRepositoryMaintainer {
* NB: Keep this metric set in sync with internal configserver metric pre-aggregation
*/
private void updateNodeMetrics(Node node, ServiceModel serviceModel) {
- if (node.state() != State.active)
- return;
Metric.Context context;
Optional<Allocation> allocation = node.allocation();
if (allocation.isPresent()) {
@@ -235,7 +233,7 @@ public class MetricsReporter extends NodeRepositoryMaintainer {
long suspendedSeconds = info.suspendedSince()
.map(suspendedSince -> Duration.between(suspendedSince, clock().instant()).getSeconds())
.orElse(0L);
- metric.set(ConfigServerMetrics.SUSPENDED_SECONDS.baseName(), suspendedSeconds, context);
+ metric.add(ConfigServerMetrics.SUSPENDED_SECONDS.baseName(), suspendedSeconds, context);
});
long numberOfServices;
diff --git a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp
index 4562f8cb50e..a4b05d6540c 100644
--- a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp
+++ b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp
@@ -10,17 +10,17 @@ constexpr size_t loop_cnt = 64;
using namespace search::queryeval;
template <typename FLOW>
-double ordered_cost_of(const std::vector<FlowStats> &data, bool strict) {
- return flow::ordered_cost_of(flow::DirectAdapter(), data, FLOW(strict));
+double ordered_cost_of(const std::vector<FlowStats> &data, InFlow in_flow, bool allow_force_strict) {
+ return flow::ordered_cost_of(flow::DirectAdapter(), data, FLOW(in_flow), allow_force_strict);
}
template <typename FLOW>
-double dual_ordered_cost_of(const std::vector<FlowStats> &data, bool strict) {
- double result = flow::ordered_cost_of(flow::DirectAdapter(), data, FLOW(strict));
- AnyFlow any_flow = AnyFlow::create<FLOW>(strict);
+double dual_ordered_cost_of(const std::vector<FlowStats> &data, InFlow in_flow, bool allow_force_strict) {
+ double result = flow::ordered_cost_of(flow::DirectAdapter(), data, FLOW(in_flow), allow_force_strict);
+ AnyFlow any_flow = AnyFlow::create<FLOW>(in_flow);
double total_cost = 0.0;
for (const auto &item: data) {
- double child_cost = any_flow.strict() ? item.strict_cost : any_flow.flow() * item.cost;
+ double child_cost = flow::min_child_cost(InFlow(any_flow.strict(), any_flow.flow()), item, allow_force_strict);
any_flow.update_cost(total_cost, child_cost);
any_flow.add(item.estimate);
}
@@ -38,8 +38,12 @@ std::vector<FlowStats> gen_data(size_t size) {
for (size_t i = 0; i < size; ++i) {
result.emplace_back(estimate(gen), cost(gen), strict_cost(gen));
}
+ if (size == 0) {
+ gen.seed(gen.default_seed);
+ }
return result;
}
+// calling gen_data(0) resets the generator seed, giving reproducible data across tests
+void re_seed() { gen_data(0); }
template <typename T, typename F>
void each_perm(std::vector<T> &data, size_t k, F fun) {
@@ -275,16 +279,16 @@ TEST(FlowTest, in_flow_strict_vs_rate_interaction) {
TEST(FlowTest, flow_cost) {
std::vector<FlowStats> data = {{0.4, 1.1, 0.6}, {0.7, 1.2, 0.5}, {0.2, 1.3, 0.4}};
- EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.7*1.3);
- EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.7*1.3);
- EXPECT_DOUBLE_EQ(dual_ordered_cost_of<OrFlow>(data, false), 1.1 + 0.6*1.2 + 0.6*0.3*1.3);
- EXPECT_DOUBLE_EQ(dual_ordered_cost_of<OrFlow>(data, true), 0.6 + 0.5 + 0.4);
- EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndNotFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.3*1.3);
- EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndNotFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.3*1.3);
- EXPECT_DOUBLE_EQ(dual_ordered_cost_of<RankFlow>(data, false), 1.1);
- EXPECT_DOUBLE_EQ(dual_ordered_cost_of<RankFlow>(data, true), 0.6);
- EXPECT_DOUBLE_EQ(dual_ordered_cost_of<BlenderFlow>(data, false), 1.3);
- EXPECT_DOUBLE_EQ(dual_ordered_cost_of<BlenderFlow>(data, true), 0.6);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndFlow>(data, false, false), 1.1 + 0.4*1.2 + 0.4*0.7*1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndFlow>(data, true, false), 0.6 + 0.4*1.2 + 0.4*0.7*1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<OrFlow>(data, false, false), 1.1 + 0.6*1.2 + 0.6*0.3*1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<OrFlow>(data, true, false), 0.6 + 0.5 + 0.4);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndNotFlow>(data, false, false), 1.1 + 0.4*1.2 + 0.4*0.3*1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndNotFlow>(data, true, false), 0.6 + 0.4*1.2 + 0.4*0.3*1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<RankFlow>(data, false, false), 1.1);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<RankFlow>(data, true, false), 0.6);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<BlenderFlow>(data, false, false), 1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<BlenderFlow>(data, true, false), 0.6);
}
TEST(FlowTest, rank_flow_cost_accumulation_is_first) {
@@ -322,9 +326,9 @@ TEST(FlowTest, optimal_and_flow) {
double min_cost = AndFlow::cost_of(data, strict);
double max_cost = 0.0;
AndFlow::sort(data, strict);
- EXPECT_EQ(ordered_cost_of<AndFlow>(data, strict), min_cost);
+ EXPECT_EQ(ordered_cost_of<AndFlow>(data, strict, false), min_cost);
auto check = [&](const std::vector<FlowStats> &my_data) noexcept {
- double my_cost = ordered_cost_of<AndFlow>(my_data, strict);
+ double my_cost = ordered_cost_of<AndFlow>(my_data, strict, false);
EXPECT_LE(min_cost, my_cost);
max_cost = std::max(max_cost, my_cost);
};
@@ -345,9 +349,9 @@ TEST(FlowTest, optimal_or_flow) {
double min_cost = OrFlow::cost_of(data, strict);
double max_cost = 0.0;
OrFlow::sort(data, strict);
- EXPECT_EQ(ordered_cost_of<OrFlow>(data, strict), min_cost);
+ EXPECT_EQ(ordered_cost_of<OrFlow>(data, strict, false), min_cost);
auto check = [&](const std::vector<FlowStats> &my_data) noexcept {
- double my_cost = ordered_cost_of<OrFlow>(my_data, strict);
+ double my_cost = ordered_cost_of<OrFlow>(my_data, strict, false);
EXPECT_LE(min_cost, my_cost + 1e-9);
max_cost = std::max(max_cost, my_cost);
};
@@ -369,10 +373,10 @@ TEST(FlowTest, optimal_and_not_flow) {
double max_cost = 0.0;
AndNotFlow::sort(data, strict);
EXPECT_EQ(data[0], first);
- EXPECT_DOUBLE_EQ(ordered_cost_of<AndNotFlow>(data, strict), min_cost);
+ EXPECT_DOUBLE_EQ(ordered_cost_of<AndNotFlow>(data, strict, false), min_cost);
auto check = [&](const std::vector<FlowStats> &my_data) noexcept {
if (my_data[0] == first) {
- double my_cost = ordered_cost_of<AndNotFlow>(my_data, strict);
+ double my_cost = ordered_cost_of<AndNotFlow>(my_data, strict, false);
EXPECT_LE(min_cost, my_cost + 1e-9);
max_cost = std::max(max_cost, my_cost);
}
@@ -386,4 +390,106 @@ TEST(FlowTest, optimal_and_not_flow) {
}
}
+void test_strict_AND_sort_strategy(auto my_sort) {
+ re_seed();
+ size_t cnt = 64;
+ double max_rel_err = 0.0;
+ double sum_rel_err = 0.0;
+ for (size_t i = 0; i < cnt; ++i) {
+ auto data = gen_data(7);
+ double ref_est = AndFlow::estimate_of(data);
+ double min_cost = 1'000'000.0;
+ double max_cost = 0.0;
+ my_sort(data);
+ double est_cost = ordered_cost_of<AndFlow>(data, true, true);
+ auto check = [&](const std::vector<FlowStats> &my_data) noexcept {
+ double my_cost = ordered_cost_of<AndFlow>(my_data, true, true);
+ min_cost = std::min(min_cost, my_cost);
+ max_cost = std::max(max_cost, my_cost);
+ };
+ each_perm(data, check);
+ double rel_err = 0.0;
+ double cost_range = (max_cost - min_cost);
+ if (cost_range > 1e-9) {
+ rel_err = (est_cost - min_cost) / cost_range;
+ }
+ max_rel_err = std::max(max_rel_err, rel_err);
+ sum_rel_err += rel_err;
+ EXPECT_NEAR(ref_est, AndFlow::estimate_of(data), 1e-9);
+ }
+ fprintf(stderr, " strict AND allow_force_strict: avg rel_err: %g, max rel_err: %g\n",
+ sum_rel_err / cnt, max_rel_err);
+}
+
+TEST(FlowTest, strict_and_with_allow_force_strict_basic_order) {
+ auto my_sort = [](auto &data){ AndFlow::sort(data, true); };
+ test_strict_AND_sort_strategy(my_sort);
+}
+
+void greedy_sort(std::vector<FlowStats> &data, auto flow, auto score_of) {
+ for (size_t next = 0; (next + 1) < data.size(); ++next) {
+ InFlow in_flow = InFlow(flow.strict(), flow.flow());
+ size_t best_idx = next;
+ double best_score = score_of(in_flow, data[next]);
+ for (size_t i = next + 1; i < data.size(); ++i) {
+ double score = score_of(in_flow, data[i]);
+ if (score > best_score) {
+ best_score = score;
+ best_idx = i;
+ }
+ }
+ std::swap(data[next], data[best_idx]);
+ flow.add(data[next].estimate);
+ }
+}
+
+TEST(FlowTest, strict_and_with_allow_force_strict_greedy_reduction_efficiency) {
+ auto score_of = [](InFlow in_flow, const FlowStats &stats) {
+ double child_cost = flow::min_child_cost(in_flow, stats, true);
+ double reduction = (1.0 - stats.estimate);
+ return reduction / child_cost;
+ };
+ auto my_sort = [&](auto &data){ greedy_sort(data, AndFlow(true), score_of); };
+ test_strict_AND_sort_strategy(my_sort);
+}
+
+TEST(FlowTest, strict_and_with_allow_force_strict_incremental_strict_selection) {
+ auto my_sort = [](auto &data) {
+ AndFlow::sort(data, true);
+ for (size_t next = 1; (next + 1) < data.size(); ++next) {
+ auto [idx, score] = flow::select_forced_strict_and_child(flow::DirectAdapter(), data, next);
+ if (score >= 0.0) {
+ break;
+ }
+ auto the_one = data[idx];
+ for (; idx > next; --idx) {
+ data[idx] = data[idx-1];
+ }
+ data[next] = the_one;
+ }
+ };
+ test_strict_AND_sort_strategy(my_sort);
+}
+
+TEST(FlowTest, strict_and_with_allow_force_strict_incremental_strict_selection_with_strict_re_sorting) {
+ auto my_sort = [](auto &data) {
+ AndFlow::sort(data, true);
+ size_t strict_cnt = 1;
+ while (strict_cnt < data.size()) {
+ auto [idx, score] = flow::select_forced_strict_and_child(flow::DirectAdapter(), data, strict_cnt);
+ if (score >= 0.0) {
+ break;
+ }
+ auto the_one = data[idx];
+ for (; idx > strict_cnt; --idx) {
+ data[idx] = data[idx-1];
+ }
+ data[strict_cnt++] = the_one;
+ }
+ std::sort(data.begin(), data.begin() + strict_cnt,
+ [](const auto &a, const auto &b){ return (a.estimate < b.estimate); });
+ };
+ test_strict_AND_sort_strategy(my_sort);
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/queryeval/flow.h b/searchlib/src/vespa/searchlib/queryeval/flow.h
index ba0235991a8..f38d447d3b0 100644
--- a/searchlib/src/vespa/searchlib/queryeval/flow.h
+++ b/searchlib/src/vespa/searchlib/queryeval/flow.h
@@ -5,6 +5,7 @@
#include <cstddef>
#include <algorithm>
#include <functional>
+#include <limits>
// Model how boolean result decisions flow through intermediate nodes
// of different types based on relative estimates for sub-expressions
@@ -34,7 +35,10 @@ struct FlowStats {
double strict_cost;
constexpr FlowStats(double estimate_in, double cost_in, double strict_cost_in) noexcept
: estimate(estimate_in), cost(cost_in), strict_cost(strict_cost_in) {}
- auto operator <=>(const FlowStats &rhs) const noexcept = default;
+ constexpr auto operator <=>(const FlowStats &rhs) const noexcept = default;
+ static constexpr FlowStats from(auto adapter, const auto &child) noexcept {
+ return {adapter.estimate(child), adapter.cost(child), adapter.strict_cost(child)};
+ }
};
namespace flow {
@@ -117,6 +121,23 @@ struct MinOrCost {
}
};
+// estimate the cost of evaluating a strict child in a non-strict context
+inline double forced_strict_cost(double estimate, double strict_cost, double rate) {
+ return 0.2 * (rate - estimate) + strict_cost;
+}
+
+// estimate the absolute cost of evaluating a child with a specific in-flow
+inline double min_child_cost(InFlow in_flow, const FlowStats &stats, bool allow_force_strict) {
+ if (in_flow.strict()) {
+ return stats.strict_cost;
+ }
+ if (!allow_force_strict) {
+ return stats.cost * in_flow.rate();
+ }
+ return std::min(forced_strict_cost(stats.estimate, stats.strict_cost, in_flow.rate()),
+ stats.cost * in_flow.rate());
+}
+
template <typename ADAPTER, typename T>
double estimate_of_and(ADAPTER adapter, const T &children) {
double flow = children.empty() ? 0.0 : adapter.estimate(children[0]);
@@ -157,43 +178,59 @@ void sort_partial(ADAPTER adapter, T &children, size_t offset) {
}
template <typename ADAPTER, typename T, typename F>
-double ordered_cost_of(ADAPTER adapter, const T &children, F flow) {
+double ordered_cost_of(ADAPTER adapter, const T &children, F flow, bool allow_force_strict) {
double total_cost = 0.0;
for (const auto &child: children) {
- double child_cost = flow.strict() ? adapter.strict_cost(child) : (flow.flow() * adapter.cost(child));
+ FlowStats stats(adapter.estimate(child), adapter.cost(child), adapter.strict_cost(child));
+ double child_cost = min_child_cost(InFlow(flow.strict(), flow.flow()), stats, allow_force_strict);
flow.update_cost(total_cost, child_cost);
- flow.add(adapter.estimate(child));
+ flow.add(stats.estimate);
}
return total_cost;
}
-template <typename ADAPTER, typename T>
-size_t select_strict_and_child(ADAPTER adapter, const T &children) {
- size_t idx = 0;
+size_t select_strict_and_child(auto adapter, const auto &children) {
+ double est = 1.0;
double cost = 0.0;
size_t best_idx = 0;
- double best_diff = 0.0;
- double est = 1.0;
- for (const auto &child: children) {
- double child_cost = est * adapter.cost(child);
- double child_strict_cost = adapter.strict_cost(child);
- double child_est = adapter.estimate(child);
- if (idx == 0) {
- best_diff = child_strict_cost - child_cost;
- } else {
- double my_diff = (child_strict_cost + child_est * cost) - (cost + child_cost);
- if (my_diff < best_diff) {
- best_diff = my_diff;
- best_idx = idx;
- }
+ double best_diff = std::numeric_limits<double>::max();
+ for (size_t idx = 0; idx < children.size(); ++idx) {
+ auto child = FlowStats::from(adapter, children[idx]);
+ double child_abs_cost = est * child.cost;
+ double my_diff = (child.strict_cost + child.estimate * cost) - (cost + child_abs_cost);
+ if (my_diff < best_diff) {
+ best_diff = my_diff;
+ best_idx = idx;
}
- cost += child_cost;
- est *= child_est;
- ++idx;
+ cost += child_abs_cost;
+ est *= child.estimate;
}
return best_idx;
}
+auto select_forced_strict_and_child(auto adapter, const auto &children, size_t first) {
+ double est = 1.0;
+ double cost = 0.0;
+ size_t best_idx = 0;
+ double best_diff = std::numeric_limits<double>::max();
+ for (size_t idx = 0; idx < first && idx < children.size(); ++idx) {
+ est *= adapter.estimate(children[idx]);
+ }
+ for (size_t idx = first; idx < children.size(); ++idx) {
+ auto child = FlowStats::from(adapter, children[idx]);
+ double child_abs_cost = est * child.cost;
+ double forced_cost = forced_strict_cost(child.estimate, child.strict_cost, est);
+ double my_diff = (forced_cost + child.estimate * cost) - (cost + child_abs_cost);
+ if (my_diff < best_diff) {
+ best_diff = my_diff;
+ best_idx = idx;
+ }
+ cost += child_abs_cost;
+ est *= child.estimate;
+ }
+ return std::make_pair(best_idx, best_diff);
+}
+
} // flow
template <typename FLOW>
@@ -202,7 +239,7 @@ struct FlowMixin {
auto my_adapter = flow::IndirectAdapter(adapter, children);
auto order = flow::make_index(children.size());
FLOW::sort(my_adapter, order, strict);
- return flow::ordered_cost_of(my_adapter, order, FLOW(strict));
+ return flow::ordered_cost_of(my_adapter, order, FLOW(strict), false);
}
static double cost_of(const auto &children, bool strict) {
return cost_of(flow::make_adapter(children), children, strict);
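The key behavioral change in flow.h is that a non-strict AND child may now be evaluated strict anyway ("forced strict") whenever that is estimated to be cheaper: forced_strict_cost models the filtering overhead as 0.2 * (rate - estimate) on top of the strict cost, and min_child_cost picks the cheaper of the two alternatives. A minimal sketch of the arithmetic, re-expressed in Java purely for illustration (the class and the numbers are hypothetical, not part of the patch):

    public final class FlowCostSketch {

        // cost of evaluating a strict child in a non-strict context (mirrors flow::forced_strict_cost)
        static double forcedStrictCost(double estimate, double strictCost, double rate) {
            return 0.2 * (rate - estimate) + strictCost;
        }

        // cheapest way to evaluate a child when forcing strictness is allowed (mirrors flow::min_child_cost)
        static double minChildCost(double rate, double estimate, double cost, double strictCost) {
            return Math.min(forcedStrictCost(estimate, strictCost, rate), cost * rate);
        }

        public static void main(String[] args) {
            // hypothetical child: estimate 0.1, cost 1.0, strict_cost 0.3; in-flow rate 0.5
            double nonStrict = 1.0 * 0.5;                          // 0.50
            double forced = forcedStrictCost(0.1, 0.3, 0.5);       // 0.2 * (0.5 - 0.1) + 0.3 = 0.38
            System.out.println(minChildCost(0.5, 0.1, 1.0, 0.3));  // 0.38 -> forcing strict wins
            System.out.println(nonStrict + " vs " + forced);
        }
    }

With a higher strict cost (say 0.5) the forced alternative would cost 0.58 and the child stays non-strict; this is exactly the trade-off the new strict-AND sort strategies in the test above try to exploit.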
diff --git a/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.cpp b/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.cpp
index d8a6091fe11..5988bdd912f 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.cpp
+++ b/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.cpp
@@ -4,6 +4,19 @@
namespace vsm {
+namespace {
+
+template <bool exact_match> inline bool is_word_char(ucs4_t c);
+
+template <>
+inline bool is_word_char<false>(ucs4_t c) { return Fast_UnicodeUtil::IsWordChar(c); }
+
+// All characters are treated as word characters for exact match
+template <>
+inline constexpr bool is_word_char<true>(ucs4_t) { return true; }
+
+}
+
void
TokenizeReader::fold(ucs4_t c) {
const char *repl = Fast_NormalizeWordFolder::ReplacementString(c);
@@ -18,4 +31,24 @@ TokenizeReader::fold(ucs4_t c) {
}
}
+template <bool exact_match>
+size_t
+TokenizeReader::tokenize_helper(Normalizing norm_mode)
+{
+ ucs4_t c(0);
+ while (hasNext()) {
+ if (is_word_char<exact_match>(c = next())) {
+ normalize(c, norm_mode);
+ while (hasNext() && is_word_char<exact_match>(c = next())) {
+ normalize(c, norm_mode);
+ }
+ break;
+ }
+ }
+ return complete();
+}
+
+template size_t TokenizeReader::tokenize_helper<false>(Normalizing);
+template size_t TokenizeReader::tokenize_helper<true>(Normalizing);
+
}
diff --git a/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.h b/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.h
index 2bb5e62e0aa..f680d9b6c47 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.h
+++ b/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.h
@@ -43,6 +43,10 @@ public:
_q = _q_start;
return token_len;
}
+ template <bool exact_match>
+ size_t tokenize_helper(Normalizing norm_mode);
+ size_t tokenize(Normalizing norm_mode) { return tokenize_helper<false>(norm_mode); }
+ size_t tokenize_exact_match(Normalizing norm_mode) { return tokenize_helper<true>(norm_mode); }
private:
void fold(ucs4_t c);
const byte *_p;
diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8strchrfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/utf8strchrfieldsearcher.cpp
index 37dc4ffb99c..c860178d583 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/utf8strchrfieldsearcher.cpp
+++ b/streamingvisitors/src/vespa/vsm/searcher/utf8strchrfieldsearcher.cpp
@@ -26,8 +26,7 @@ UTF8StrChrFieldSearcher::matchTerms(const FieldRef & f, size_t mintsz)
TokenizeReader reader(reinterpret_cast<const byte *> (f.data()), f.size(), fn);
while ( reader.hasNext() ) {
- tokenize(reader);
- size_t fl = reader.complete();
+ size_t fl = reader.tokenize(normalize_mode());
for (auto qt : _qtl) {
const cmptype_t * term;
termsize_t tsz = qt->term(term);
diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.cpp b/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.cpp
index 5036e9bedb1..f016d08ece8 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.cpp
+++ b/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.cpp
@@ -10,21 +10,6 @@ using search::byte;
namespace vsm {
-template<typename Reader>
-void
-UTF8StringFieldSearcherBase::tokenize(Reader & reader) {
- ucs4_t c(0);
- Normalizing norm_mode = normalize_mode();
- while (reader.hasNext() && ! Fast_UnicodeUtil::IsWordChar(c = reader.next()));
-
- if (Fast_UnicodeUtil::IsWordChar(c)) {
- reader.normalize(c, norm_mode);
- while (reader.hasNext() && Fast_UnicodeUtil::IsWordChar(c = reader.next())) {
- reader.normalize(c, norm_mode);
- }
- }
-}
-
size_t
UTF8StringFieldSearcherBase::matchTermRegular(const FieldRef & f, QueryTerm & qt)
{
@@ -38,8 +23,7 @@ UTF8StringFieldSearcherBase::matchTermRegular(const FieldRef & f, QueryTerm & qt
TokenizeReader reader(reinterpret_cast<const byte *> (f.data()), f.size(), fn);
while ( reader.hasNext() ) {
- tokenize(reader);
- size_t fl = reader.complete();
+ size_t fl = reader.tokenize(normalize_mode());
if ((tsz <= fl) && (prefix() || qt.isPrefix() || (tsz == fl))) {
const cmptype_t *tt=term, *et=term+tsz;
for (const cmptype_t *fnt=fn; (tt < et) && (*tt == *fnt); tt++, fnt++);
@@ -127,8 +111,7 @@ UTF8StringFieldSearcherBase::matchTermSuffix(const FieldRef & f, QueryTerm & qt)
TokenizeReader reader(reinterpret_cast<const byte *> (f.data()), f.size(), dstbuf);
while ( reader.hasNext() ) {
- tokenize(reader);
- size_t tokenlen = reader.complete();
+ size_t tokenlen = reader.tokenize(normalize_mode());
if (matchTermSuffix(term, tsz, dstbuf, tokenlen)) {
addHit(qt, words);
}
diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.h b/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.h
index b196f2795a4..c217a7b8866 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.h
+++ b/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.h
@@ -60,9 +60,6 @@ public:
protected:
SharedSearcherBuf _buf;
- template<typename Reader>
- void tokenize(Reader & reader);
-
/**
* Matches the given query term against the words in the given field reference
* using exact or prefix match strategy.
diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8suffixstringfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/utf8suffixstringfieldsearcher.cpp
index 8bbacf168cf..d5bf4e4238a 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/utf8suffixstringfieldsearcher.cpp
+++ b/streamingvisitors/src/vespa/vsm/searcher/utf8suffixstringfieldsearcher.cpp
@@ -26,8 +26,7 @@ UTF8SuffixStringFieldSearcher::matchTerms(const FieldRef & f, size_t mintsz)
TokenizeReader reader(reinterpret_cast<const byte *> (f.data()), f.size(), dstbuf);
while ( reader.hasNext() ) {
- tokenize(reader);
- size_t tokenlen = reader.complete();
+ size_t tokenlen = reader.tokenize(normalize_mode());
for (auto qt : _qtl) {
const cmptype_t * term;
termsize_t tsz = qt->term(term);
diff --git a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp
index 1ab1b16cb86..1986db79972 100644
--- a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp
+++ b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp
@@ -273,8 +273,11 @@ buildFieldSet(const VsmfieldsConfig::Documenttype::Index & ci, const FieldSearch
return ifm;
}
+}
+
search::Normalizing
-normalize_mode(VsmfieldsConfig::Fieldspec::Normalize normalize_mode) {
+FieldSearchSpecMap::convert_normalize_mode(VsmfieldsConfig::Fieldspec::Normalize normalize_mode)
+{
switch (normalize_mode) {
case VsmfieldsConfig::Fieldspec::Normalize::NONE: return search::Normalizing::NONE;
case VsmfieldsConfig::Fieldspec::Normalize::LOWERCASE: return search::Normalizing::LOWERCASE;
@@ -283,8 +286,6 @@ normalize_mode(VsmfieldsConfig::Fieldspec::Normalize normalize_mode) {
return search::Normalizing::LOWERCASE_AND_FOLD;
}
-}
-
void
FieldSearchSpecMap::buildFromConfig(const VsmfieldsHandle & conf, const search::fef::IIndexEnvironment& index_env)
{
@@ -292,7 +293,7 @@ FieldSearchSpecMap::buildFromConfig(const VsmfieldsHandle & conf, const search::
for(const VsmfieldsConfig::Fieldspec & cfs : conf->fieldspec) {
LOG(spam, "Parsing %s", cfs.name.c_str());
FieldIdT fieldId = specMap().size();
- FieldSearchSpec fss(fieldId, cfs.name, cfs.searchmethod, normalize_mode(cfs.normalize), cfs.arg1, cfs.maxlength);
+ FieldSearchSpec fss(fieldId, cfs.name, cfs.searchmethod, convert_normalize_mode(cfs.normalize), cfs.arg1, cfs.maxlength);
_specMap[fieldId] = std::move(fss);
_nameIdMap.add(cfs.name, fieldId);
LOG(spam, "M in %d = %s", fieldId, cfs.name.c_str());
diff --git a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h
index 5b5a6b9a783..8bab0cad3b6 100644
--- a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h
+++ b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h
@@ -101,6 +101,7 @@ public:
static vespalib::string stripNonFields(vespalib::stringref rawIndex);
search::attribute::DistanceMetric get_distance_metric(const vespalib::string& name) const;
+ static search::Normalizing convert_normalize_mode(VsmfieldsConfig::Fieldspec::Normalize normalize_mode);
private:
FieldSearchSpecMapT _specMap; // mapping from field id to field search spec
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 754352a45a4..45e88ac2e94 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -4077,6 +4077,26 @@
],
"fields" : [ ]
},
+ "ai.vespa.llm.InferenceParameters" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(java.util.function.Function)",
+ "public void <init>(java.lang.String, java.util.function.Function)",
+ "public void setApiKey(java.lang.String)",
+ "public java.util.Optional getApiKey()",
+ "public void setEndpoint(java.lang.String)",
+ "public java.util.Optional getEndpoint()",
+ "public java.util.Optional get(java.lang.String)",
+ "public java.util.Optional getDouble(java.lang.String)",
+ "public java.util.Optional getInt(java.lang.String)",
+ "public void ifPresent(java.lang.String, java.util.function.Consumer)"
+ ],
+ "fields" : [ ]
+ },
"ai.vespa.llm.LanguageModel" : {
"superClass" : "java.lang.Object",
"interfaces" : [ ],
@@ -4086,23 +4106,20 @@
"abstract"
],
"methods" : [
- "public abstract java.util.List complete(ai.vespa.llm.completion.Prompt)",
- "public abstract java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, java.util.function.Consumer)"
+ "public abstract java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public abstract java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
],
"fields" : [ ]
},
- "ai.vespa.llm.client.openai.OpenAiClient$Builder" : {
- "superClass" : "java.lang.Object",
+ "ai.vespa.llm.LanguageModelException" : {
+ "superClass" : "java.lang.RuntimeException",
"interfaces" : [ ],
"attributes" : [
"public"
],
"methods" : [
- "public void <init>(java.lang.String)",
- "public ai.vespa.llm.client.openai.OpenAiClient$Builder model(java.lang.String)",
- "public ai.vespa.llm.client.openai.OpenAiClient$Builder temperature(double)",
- "public ai.vespa.llm.client.openai.OpenAiClient$Builder maxTokens(long)",
- "public ai.vespa.llm.client.openai.OpenAiClient build()"
+ "public void <init>(int, java.lang.String)",
+ "public int code()"
],
"fields" : [ ]
},
@@ -4115,8 +4132,9 @@
"public"
],
"methods" : [
- "public java.util.List complete(ai.vespa.llm.completion.Prompt)",
- "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, java.util.function.Consumer)"
+ "public void <init>()",
+ "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
],
"fields" : [ ]
},
@@ -4135,7 +4153,8 @@
"fields" : [
"public static final enum ai.vespa.llm.completion.Completion$FinishReason length",
"public static final enum ai.vespa.llm.completion.Completion$FinishReason stop",
- "public static final enum ai.vespa.llm.completion.Completion$FinishReason none"
+ "public static final enum ai.vespa.llm.completion.Completion$FinishReason none",
+ "public static final enum ai.vespa.llm.completion.Completion$FinishReason error"
]
},
"ai.vespa.llm.completion.Completion" : {
@@ -4151,6 +4170,7 @@
"public java.lang.String text()",
"public ai.vespa.llm.completion.Completion$FinishReason finishReason()",
"public static ai.vespa.llm.completion.Completion from(java.lang.String)",
+ "public static ai.vespa.llm.completion.Completion from(java.lang.String, ai.vespa.llm.completion.Completion$FinishReason)",
"public final java.lang.String toString()",
"public final int hashCode()",
"public final boolean equals(java.lang.Object)"
@@ -4212,8 +4232,8 @@
],
"methods" : [
"public void <init>(ai.vespa.llm.test.MockLanguageModel$Builder)",
- "public java.util.List complete(ai.vespa.llm.completion.Prompt)",
- "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, java.util.function.Consumer)"
+ "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
],
"fields" : [ ]
}
diff --git a/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
new file mode 100755
index 00000000000..a942e5090e5
--- /dev/null
+++ b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
@@ -0,0 +1,76 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.Objects;
+import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+/**
+ * Parameters for inference requests to language models. Parameters are
+ * typically supplied by searchers or processors and come from query
+ * strings, headers, or other sources. Which parameters are available
+ * depends on the language model used.
+ *
+ * @author lesters
+ */
+@Beta
+public class InferenceParameters {
+
+ private String apiKey;
+ private String endpoint;
+ private final Function<String, String> options;
+
+ public InferenceParameters(Function<String, String> options) {
+ this(null, options);
+ }
+
+ public InferenceParameters(String apiKey, Function<String, String> options) {
+ this.apiKey = apiKey;
+ this.options = Objects.requireNonNull(options);
+ }
+
+ public void setApiKey(String apiKey) {
+ this.apiKey = apiKey;
+ }
+
+ public Optional<String> getApiKey() {
+ return Optional.ofNullable(apiKey);
+ }
+
+ public void setEndpoint(String endpoint) {
+ this.endpoint = endpoint;
+ }
+
+ public Optional<String> getEndpoint() {
+ return Optional.ofNullable(endpoint);
+ }
+
+ public Optional<String> get(String option) {
+ return Optional.ofNullable(options.apply(option));
+ }
+
+ public Optional<Double> getDouble(String option) {
+ try {
+ return Optional.of(Double.parseDouble(options.apply(option)));
+ } catch (Exception e) {
+ return Optional.empty();
+ }
+ }
+
+ public Optional<Integer> getInt(String option) {
+ try {
+ return Optional.of(Integer.parseInt(options.apply(option)));
+ } catch (Exception e) {
+ return Optional.empty();
+ }
+ }
+
+ public void ifPresent(String option, Consumer<String> func) {
+ get(option).ifPresent(func);
+ }
+
+}
+
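A minimal usage sketch of the new InferenceParameters class (the option names and values below are hypothetical; in real use the lookup function would typically be backed by query properties):

    import ai.vespa.llm.InferenceParameters;

    import java.util.Map;

    public class InferenceParametersSketch {

        public static void main(String[] args) {
            var options = Map.of("model", "gpt-3.5-turbo",
                                 "temperature", "0.7",
                                 "maxTokens", "10");
            var params = new InferenceParameters("hypothetical-api-key", options::get);

            params.getDouble("temperature").ifPresent(t -> System.out.println("temperature: " + t)); // 0.7
            params.getInt("maxTokens").ifPresent(m -> System.out.println("max tokens: " + m));       // 10
            System.out.println(params.get("model").orElse("default"));  // gpt-3.5-turbo
            System.out.println(params.getEndpoint().isPresent());       // false until setEndpoint is called
        }
    }

Note that getDouble and getInt swallow parse failures and return Optional.empty(), so malformed query-supplied values simply fall back to the caller's defaults.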
diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java
index f4b8938934b..059f25fadb4 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java
@@ -17,8 +17,10 @@ import java.util.function.Consumer;
@Beta
public interface LanguageModel {
- List<Completion> complete(Prompt prompt);
+ List<Completion> complete(Prompt prompt, InferenceParameters options);
- CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action);
+ CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters options,
+ Consumer<Completion> consumer);
}
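To make the revised contract concrete, here is a hedged sketch of a trivial implementation; the echo behavior is invented for illustration, but the signatures match the interface above:

    import ai.vespa.llm.InferenceParameters;
    import ai.vespa.llm.LanguageModel;
    import ai.vespa.llm.completion.Completion;
    import ai.vespa.llm.completion.Prompt;

    import java.util.List;
    import java.util.concurrent.CompletableFuture;
    import java.util.function.Consumer;

    public class EchoLanguageModel implements LanguageModel {

        @Override
        public List<Completion> complete(Prompt prompt, InferenceParameters options) {
            return List.of(Completion.from(prompt.asString())); // one completion, finish reason 'stop'
        }

        @Override
        public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
                                                                        InferenceParameters options,
                                                                        Consumer<Completion> consumer) {
            // stream a single chunk, then report why streaming stopped
            consumer.accept(Completion.from(prompt.asString(), Completion.FinishReason.none));
            return CompletableFuture.completedFuture(Completion.FinishReason.stop);
        }
    }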
diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java
new file mode 100755
index 00000000000..b5dbf615c08
--- /dev/null
+++ b/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java
@@ -0,0 +1,19 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+@Beta
+public class LanguageModelException extends RuntimeException {
+
+ private final int code;
+
+ public LanguageModelException(int code, String message) {
+ super(message);
+ this.code = code;
+ }
+
+ public int code() {
+ return code;
+ }
+}
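A short sketch of how a caller might surface this exception, assuming it propagates through the future returned by completeAsync (the helper below is hypothetical):

    import ai.vespa.llm.LanguageModelException;
    import ai.vespa.llm.completion.Completion;

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.CompletionException;

    public class LanguageModelErrorSketch {

        static Completion.FinishReason joinSafely(CompletableFuture<Completion.FinishReason> future) {
            try {
                return future.join();
            } catch (CompletionException e) {
                if (e.getCause() instanceof LanguageModelException lme) {
                    System.err.println("LLM request failed with status " + lme.code() + ": " + lme.getMessage());
                }
                return Completion.FinishReason.error;
            }
        }
    }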
diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
index d7334b40963..75308a84faa 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
@@ -1,6 +1,8 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.llm.client.openai;
+import ai.vespa.llm.LanguageModelException;
+import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.completion.Prompt;
@@ -28,31 +30,28 @@ import java.util.stream.Stream;
* Currently, only completions are implemented.
*
* @author bratseth
+ * @author lesters
*/
@Beta
public class OpenAiClient implements LanguageModel {
+ private static final String DEFAULT_MODEL = "gpt-3.5-turbo";
private static final String DATA_FIELD = "data: ";
- private final String token;
- private final String model;
- private final double temperature;
- private final long maxTokens;
+ private static final String OPTION_MODEL = "model";
+ private static final String OPTION_TEMPERATURE = "temperature";
+ private static final String OPTION_MAX_TOKENS = "maxTokens";
private final HttpClient httpClient;
- private OpenAiClient(Builder builder) {
- this.token = builder.token;
- this.model = builder.model;
- this.temperature = builder.temperature;
- this.maxTokens = builder.maxTokens;
+ public OpenAiClient() {
this.httpClient = HttpClient.newBuilder().build();
}
@Override
- public List<Completion> complete(Prompt prompt) {
+ public List<Completion> complete(Prompt prompt, InferenceParameters options) {
try {
- HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt), HttpResponse.BodyHandlers.ofByteArray());
+ HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt, options, false), HttpResponse.BodyHandlers.ofByteArray());
var response = SlimeUtils.jsonToSlime(httpResponse.body()).get();
if ( httpResponse.statusCode() != 200)
throw new IllegalArgumentException(SlimeUtils.toJson(response));
@@ -64,9 +63,11 @@ public class OpenAiClient implements LanguageModel {
}
@Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> consumer) {
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters options,
+ Consumer<Completion> consumer) {
try {
- var request = toRequest(prompt, true);
+ var request = toRequest(prompt, options, true);
var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines());
var completionFuture = new CompletableFuture<Completion.FinishReason>();
@@ -74,8 +75,7 @@ public class OpenAiClient implements LanguageModel {
try {
int responseCode = response.statusCode();
if (responseCode != 200) {
- throw new IllegalArgumentException("Received code " + responseCode + ": " +
- response.body().collect(Collectors.joining()));
+ throw new LanguageModelException(responseCode, response.body().collect(Collectors.joining()));
}
Stream<String> lines = response.body();
@@ -100,28 +100,28 @@ public class OpenAiClient implements LanguageModel {
}
}
- private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException {
- return toRequest(prompt, false);
- }
-
- private HttpRequest toRequest(Prompt prompt, boolean stream) throws IOException, URISyntaxException {
+ private HttpRequest toRequest(Prompt prompt, InferenceParameters options, boolean stream) throws IOException, URISyntaxException {
var slime = new Slime();
var root = slime.setObject();
- root.setString("model", model);
- root.setDouble("temperature", temperature);
+ root.setString("model", options.get(OPTION_MODEL).orElse(DEFAULT_MODEL));
root.setBool("stream", stream);
root.setLong("n", 1);
- if (maxTokens > 0) {
- root.setLong("max_tokens", maxTokens);
- }
+
+ if (options.getDouble(OPTION_TEMPERATURE).isPresent())
+ root.setDouble("temperature", options.getDouble(OPTION_TEMPERATURE).get());
+ if (options.getInt(OPTION_MAX_TOKENS).isPresent())
+ root.setLong("max_tokens", options.getInt(OPTION_MAX_TOKENS).get());
+        // Other generation options could be forwarded here as needed.
+
var messagesArray = root.setArray("messages");
var messagesObject = messagesArray.addObject();
messagesObject.setString("role", "user");
messagesObject.setString("content", prompt.asString());
- return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/chat/completions"))
+ var endpoint = options.getEndpoint().orElse("https://api.openai.com/v1/chat/completions");
+ return HttpRequest.newBuilder(new URI(endpoint))
.header("Content-Type", "application/json")
- .header("Authorization", "Bearer " + token)
+ .header("Authorization", "Bearer " + options.getApiKey().orElse(""))
.POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime)))
.build();
}
@@ -152,39 +152,4 @@ public class OpenAiClient implements LanguageModel {
};
}
- public static class Builder {
-
- private final String token;
- private String model = "gpt-3.5-turbo";
- private double temperature = 0.0;
- private long maxTokens = 0;
-
- public Builder(String token) {
- this.token = token;
- }
-
- /** One of the language models listed at https://platform.openai.com/docs/models */
- public Builder model(String model) {
- this.model = model;
- return this;
- }
-
- /** A value between 0 and 2 - higher gives more random/creative output. */
- public Builder temperature(double temperature) {
- this.temperature = temperature;
- return this;
- }
-
- /** Maximum number of tokens to generate */
- public Builder maxTokens(long maxTokens) {
- this.maxTokens = maxTokens;
- return this;
- }
-
- public OpenAiClient build() {
- return new OpenAiClient(this);
- }
-
- }
-
}
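Since the endpoint now comes from InferenceParameters rather than being hardcoded, pointing the client at any OpenAI-compatible server is a matter of calling setEndpoint. A hedged sketch with a hypothetical local URL and model name:

    import ai.vespa.llm.InferenceParameters;
    import ai.vespa.llm.client.openai.OpenAiClient;
    import ai.vespa.llm.completion.StringPrompt;

    import java.util.Map;

    public class LocalEndpointSketch {

        public static void main(String[] args) {
            var client = new OpenAiClient();
            var options = Map.of("model", "my-local-model", "maxTokens", "10");
            var params = new InferenceParameters("unused-key", options::get);
            params.setEndpoint("http://localhost:8080/v1/chat/completions"); // hypothetical server

            var completion = client.complete(StringPrompt.from("Hello"), params).get(0);
            System.out.println(completion.text());
        }
    }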
diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java
index ea784013812..91d0ad9bd02 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java
@@ -22,7 +22,10 @@ public record Completion(String text, FinishReason finishReason) {
stop,
/** The completion is not finished yet, more tokens are incoming. */
- none
+ none,
+
+        /** An error occurred while generating the completion. */
+ error
}
public Completion(String text, FinishReason finishReason) {
@@ -37,7 +40,11 @@ public record Completion(String text, FinishReason finishReason) {
public FinishReason finishReason() { return finishReason; }
public static Completion from(String text) {
- return new Completion(text, FinishReason.stop);
+ return from(text, FinishReason.stop);
+ }
+
+ public static Completion from(String text, FinishReason reason) {
+ return new Completion(text, reason);
}
}
diff --git a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java
index db1b42fbbac..0e757a1f1e7 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java
@@ -2,6 +2,7 @@
package ai.vespa.llm.test;
import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.api.annotations.Beta;
@@ -24,12 +25,14 @@ public class MockLanguageModel implements LanguageModel {
}
@Override
- public List<Completion> complete(Prompt prompt) {
+ public List<Completion> complete(Prompt prompt, InferenceParameters options) {
return completer.apply(prompt);
}
@Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action) {
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters options,
+ Consumer<Completion> action) {
throw new RuntimeException("Not implemented");
}
diff --git a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java
index 45ef7e270aa..1baab26f496 100644
--- a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java
+++ b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java
@@ -1,46 +1,46 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.llm.client.openai;
+import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.StringPrompt;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
+import java.util.Map;
+
/**
* @author bratseth
*/
public class OpenAiClientCompletionTest {
- private static final String apiKey = "your-api-key-here";
+ private static final String apiKey = "<your-api-key-here>";
@Test
@Disabled
public void testClient() {
- var client = new OpenAiClient.Builder(apiKey).maxTokens(10).build();
- String input = "You are an unhelpful assistant who never answers questions straightforwardly. " +
- "Be as long-winded as possible. Are humans smarter than cats?\n\n";
- StringPrompt prompt = StringPrompt.from(input);
+ var client = new OpenAiClient();
+ var options = Map.of("maxTokens", "10");
+ var prompt = StringPrompt.from("You are an unhelpful assistant who never answers questions straightforwardly. " +
+ "Be as long-winded as possible. Are humans smarter than cats?");
+
System.out.print(prompt);
- for (int i = 0; i < 10; i++) {
- var completion = client.complete(prompt).get(0);
- System.out.print(completion.text());
- if (completion.finishReason() == Completion.FinishReason.stop) break;
- prompt = prompt.append(completion.text());
- }
+ var completion = client.complete(prompt, new InferenceParameters(apiKey, options::get)).get(0);
+ System.out.print(completion.text());
}
@Test
@Disabled
public void testAsyncClient() {
- var client = new OpenAiClient.Builder(apiKey).build();
- String input = "You are an unhelpful assistant who never answers questions straightforwardly. " +
- "Be as long-winded as possible. Are humans smarter than cats?\n\n";
- StringPrompt prompt = StringPrompt.from(input);
+ var client = new OpenAiClient();
+ var options = Map.of("maxTokens", "10");
+ var prompt = StringPrompt.from("You are an unhelpful assistant who never answers questions straightforwardly. " +
+ "Be as long-winded as possible. Are humans smarter than cats?");
System.out.print(prompt);
- var future = client.completeAsync(prompt, completion -> {
+ var future = client.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> {
System.out.print(completion.text());
});
- System.out.println("Waiting for completion...");
+ System.out.println("\nWaiting for completion...\n\n");
System.out.println("\nFinished streaming because of " + future.join());
}
diff --git a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
index 7407eb526e7..24c496a3d2c 100644
--- a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
+++ b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.llm.completion;
+import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.test.MockLanguageModel;
import org.junit.jupiter.api.Test;
@@ -27,8 +28,9 @@ public class CompletionTest {
String input = "Complete this: ";
StringPrompt prompt = StringPrompt.from(input);
+ InferenceParameters options = new InferenceParameters(s -> "");
for (int i = 0; i < 10; i++) {
- var completion = llm.complete(prompt).get(0);
+ var completion = llm.complete(prompt, options).get(0);
prompt = prompt.append(completion);
if (completion.finishReason() == Completion.FinishReason.stop) break;
}