diff options
36 files changed, 1560 insertions, 181 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index bdb6cd9e7a5..257dd364000 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -9148,5 +9148,60 @@ "public int getTo()" ], "fields" : [ ] + }, + "ai.vespa.llm.clients.ConfigurableLanguageModel" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "ai.vespa.llm.LanguageModel" + ], + "attributes" : [ + "public", + "abstract" + ], + "methods" : [ + "public void <init>()", + "public void <init>(ai.vespa.llm.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", + "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)", + "protected void setApiKey(ai.vespa.llm.InferenceParameters)", + "protected java.lang.String getEndpoint()", + "protected void setEndpoint(ai.vespa.llm.InferenceParameters)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.OpenAI" : { + "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(ai.vespa.llm.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", + "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.search.LLMSearcher" : { + "superClass" : "com.yahoo.search.Searcher", + "interfaces" : [ ], + "attributes" : [ + "public", + "abstract" + ], + "methods" : [ ], + "fields" : [ ] + }, + "ai.vespa.llm.search.RAGSearcher" : { + "superClass" : "ai.vespa.llm.search.LLMSearcher", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(ai.vespa.llm.LlmSearcherConfig, com.yahoo.component.provider.ComponentRegistry)", + "public com.yahoo.search.Result search(com.yahoo.search.Query, com.yahoo.search.searchchain.Execution)", + "protected ai.vespa.llm.completion.Prompt buildPrompt(com.yahoo.search.Query, com.yahoo.search.Result)" + ], + "fields" : [ ] } }
\ No newline at end of file diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java new file mode 100644 index 00000000000..662d73d4e01 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java @@ -0,0 +1,74 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LlmClientConfig; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.logging.Logger; + + +/** + * Base class for language models that can be configured with config definitions. + * + * @author lesters + */ +@Beta +public abstract class ConfigurableLanguageModel implements LanguageModel { + + private static Logger log = Logger.getLogger(ai.vespa.llm.clients.ConfigurableLanguageModel.class.getName()); + + private final String apiKey; + private final String endpoint; + + public ConfigurableLanguageModel() { + this.apiKey = null; + this.endpoint = null; + } + + @Inject + public ConfigurableLanguageModel(LlmClientConfig config, SecretStore secretStore) { + this.apiKey = findApiKeyInSecretStore(config.apiKeySecretName(), secretStore); + this.endpoint = config.endpoint(); + } + + private static String findApiKeyInSecretStore(String property, SecretStore secretStore) { + String apiKey = ""; + if (property != null && ! property.isEmpty()) { + try { + apiKey = secretStore.getSecret(property); + } catch (UnsupportedOperationException e) { + // Secret store is not available - silently ignore this + } catch (Exception e) { + log.warning("Secret store look up failed: " + e.getMessage() + "\n" + + "Will expect API key in request header"); + } + } + return apiKey; + } + + protected String getApiKey(InferenceParameters params) { + return params.getApiKey().orElse(null); + } + + /** + * Set the API key as retrieved from secret store if it is not already set + */ + protected void setApiKey(InferenceParameters params) { + if (params.getApiKey().isEmpty() && apiKey != null) { + params.setApiKey(apiKey); + } + } + + protected String getEndpoint() { + return endpoint; + } + + protected void setEndpoint(InferenceParameters params) { + params.setEndpoint(endpoint); + } + +} diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java new file mode 100644 index 00000000000..f6092f51948 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java @@ -0,0 +1,49 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.client.openai.OpenAiClient; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * A configurable OpenAI client. + * + * @author lesters + */ +@Beta +public class OpenAI extends ConfigurableLanguageModel { + + private final OpenAiClient client; + + @Inject + public OpenAI(LlmClientConfig config, SecretStore secretStore) { + super(config, secretStore); + client = new OpenAiClient(); + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters parameters) { + setApiKey(parameters); + setEndpoint(parameters); + return client.complete(prompt, parameters); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters parameters, + Consumer<Completion> consumer) { + setApiKey(parameters); + setEndpoint(parameters); + return client.completeAsync(prompt, parameters, consumer); + } +} + diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java new file mode 100644 index 00000000000..c360245901c --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java @@ -0,0 +1,7 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.llm.clients; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java new file mode 100755 index 00000000000..6ff40401a8f --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/search/LLMSearcher.java @@ -0,0 +1,166 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.search; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LanguageModelException; +import ai.vespa.llm.LlmSearcherConfig; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.ComponentId; +import com.yahoo.component.annotation.Inject; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.result.EventStream; +import com.yahoo.search.result.HitGroup; + +import java.util.List; +import java.util.function.Function; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +/** + * Base class for LLM searchers. Provides utilities for calling LLMs and handling properties. + * + * @author lesters + */ +@Beta +public abstract class LLMSearcher extends Searcher { + + private static Logger log = Logger.getLogger(LLMSearcher.class.getName()); + + private static final String API_KEY_HEADER = "X-LLM-API-KEY"; + private static final String STREAM_PROPERTY = "stream"; + private static final String PROMPT_PROPERTY = "prompt"; + + private final String propertyPrefix; + private final boolean stream; + private final LanguageModel languageModel; + private final String languageModelId; + + @Inject + LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { + this.stream = config.stream(); + this.languageModelId = config.providerId(); + this.languageModel = findLanguageModel(languageModelId, languageModels); + this.propertyPrefix = config.propertyPrefix(); + } + + private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels) + throws IllegalArgumentException + { + if (languageModels.allComponents().isEmpty()) { + throw new IllegalArgumentException("No language models were found"); + } + if (providerId == null || providerId.isEmpty()) { + var entry = languageModels.allComponentsById().entrySet().stream().findFirst(); + if (entry.isEmpty()) { + throw new IllegalArgumentException("No language models were found"); // shouldn't happen given check above + } + log.info("Language model provider was not found in config. " + + "Fallback to using first available language model: " + entry.get().getKey()); + return entry.get().getValue(); + } + final LanguageModel languageModel = languageModels.getComponent(providerId); + if (languageModel == null) { + throw new IllegalArgumentException("No component with id '" + providerId + "' was found. " + + "Available LLM components are: " + languageModels.allComponentsById().keySet().stream() + .map(ComponentId::toString).collect(Collectors.joining(","))); + } + return languageModel; + } + + Result complete(Query query, Prompt prompt) { + var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); + var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config + return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); + } + + private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { + EventStream eventStream = new EventStream(); + + if (query.getTrace().getLevel() >= 1) { + eventStream.add(prompt.asString(), "prompt"); + } + + languageModel.completeAsync(prompt, options, token -> { + eventStream.add(token.text()); + }).exceptionally(exception -> { + int errorCode = 400; + if (exception instanceof LanguageModelException languageModelException) { + errorCode = languageModelException.code(); + } + eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); + eventStream.markComplete(); + return Completion.FinishReason.error; + }).thenAccept(finishReason -> { + eventStream.markComplete(); + }); + + HitGroup hitGroup = new HitGroup("token_stream"); + hitGroup.add(eventStream); + return new Result(query, hitGroup); + } + + private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { + EventStream eventStream = new EventStream(); + + if (query.getTrace().getLevel() >= 1) { + eventStream.add(prompt.asString(), "prompt"); + } + + List<Completion> completions = languageModel.complete(prompt, options); + eventStream.add(completions.get(0).text(), "completion"); + eventStream.markComplete(); + + HitGroup hitGroup = new HitGroup("completion"); + hitGroup.add(eventStream); + return new Result(query, hitGroup); + } + + String getPrompt(Query query) { + // Look for prompt with or without prefix + String prompt = lookupPropertyWithOrWithoutPrefix(PROMPT_PROPERTY, p -> query.properties().getString(p)); + if (prompt != null) + return prompt; + + // If not found, use query directly + prompt = query.getModel().getQueryString(); + if (prompt != null) + return prompt; + + // If not, throw exception + throw new IllegalArgumentException("Could not find prompt found for query. Tried looking for " + + "'" + propertyPrefix + "." + PROMPT_PROPERTY + "', '" + PROMPT_PROPERTY + "' or '@query'."); + } + + String getPropertyPrefix() { + return this.propertyPrefix; + } + + String lookupProperty(String property, Query query) { + String propertyWithPrefix = this.propertyPrefix + "." + property; + return query.properties().getString(propertyWithPrefix, null); + } + + Boolean lookupPropertyBool(String property, Query query, boolean defaultValue) { + String propertyWithPrefix = this.propertyPrefix + "." + property; + return query.properties().getBoolean(propertyWithPrefix, defaultValue); + } + + String lookupPropertyWithOrWithoutPrefix(String property, Function<String, String> lookup) { + String value = lookup.apply(getPropertyPrefix() + "." + property); + if (value != null) + return value; + return lookup.apply(property); + } + + String getApiKeyHeader(Query query) { + return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); + } + +} diff --git a/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java new file mode 100755 index 00000000000..b8e33778ced --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/search/RAGSearcher.java @@ -0,0 +1,75 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.search; + +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LlmSearcherConfig; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.searchchain.Execution; + +import java.util.logging.Logger; + +/** + * An LLM searcher that uses the RAG (Retrieval-Augmented Generation) model to generate completions. + * Prompts are generated based on the search result context. + * By default, the context is a concatenation of the fields of the search result hits. + * + * @author lesters + */ +@Beta +public class RAGSearcher extends LLMSearcher { + + private static Logger log = Logger.getLogger(RAGSearcher.class.getName()); + + private static final String CONTEXT_PROPERTY = "context"; + + @Inject + public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { + super(config, languageModels); + log.info("Starting " + RAGSearcher.class.getName() + " with language model " + config.providerId()); + } + + @Override + public Result search(Query query, Execution execution) { + Result result = execution.search(query); + execution.fill(result); + return complete(query, buildPrompt(query, result)); + } + + protected Prompt buildPrompt(Query query, Result result) { + String prompt = getPrompt(query); + + // Replace @query with the actual query + if (prompt.contains("@query")) { + prompt = prompt.replace("@query", query.getModel().getQueryString()); + } + + String context = lookupProperty(CONTEXT_PROPERTY, query); + if (context == null || !context.equals("skip")) { + if ( !prompt.contains("{context}")) { + prompt = "{context}\n" + prompt; + } + prompt = prompt.replace("{context}", buildContext(result)); + } + return StringPrompt.from(prompt); + } + + private String buildContext(Result result) { + StringBuilder sb = new StringBuilder(); + var hits = result.hits(); + hits.forEach(hit -> { + hit.fields().forEach((key, value) -> { + sb.append(key).append(": ").append(value).append("\n"); + }); + sb.append("\n"); + }); + var context = sb.toString(); + return context; + } + +} diff --git a/container-search/src/main/java/ai/vespa/llm/search/package-info.java b/container-search/src/main/java/ai/vespa/llm/search/package-info.java new file mode 100644 index 00000000000..6a8975fd2fa --- /dev/null +++ b/container-search/src/main/java/ai/vespa/llm/search/package-info.java @@ -0,0 +1,7 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.llm.search; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def new file mode 100755 index 00000000000..6bfd95c3cf2 --- /dev/null +++ b/container-search/src/main/resources/configdefinitions/llm-client.def @@ -0,0 +1,8 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package=ai.vespa.llm + +# The name of the secret containing the api key +apiKeySecretName string default="" + +# Endpoint for LLM client - if not set reverts to default for client +endpoint string default="" diff --git a/container-search/src/main/resources/configdefinitions/llm-searcher.def b/container-search/src/main/resources/configdefinitions/llm-searcher.def new file mode 100755 index 00000000000..918a6e6e8b1 --- /dev/null +++ b/container-search/src/main/resources/configdefinitions/llm-searcher.def @@ -0,0 +1,11 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package=ai.vespa.llm + +# Query propertry prefix for options +propertyPrefix string default="llm" + +# Should the searcher stream tokens or wait for the entire thing? +stream bool default=true + +# The external LLM provider - the id of a LanguageModel component +providerId string default="" diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java new file mode 100644 index 00000000000..1f2a12322a1 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java @@ -0,0 +1,176 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.container.di.componentgraph.Provider; +import com.yahoo.container.jdisc.SecretStoreProvider; +import com.yahoo.container.jdisc.secretstore.SecretStore; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ConfigurableLanguageModelTest { + + @Test + public void testSyncGeneration() { + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var result = createLLM().complete(prompt, inferenceParamsWithDefaultKey()); + assertEquals(1, result.size()); + assertEquals("Ducks have adorable waddling walks.", result.get(0).text()); + } + + @Test + public void testAsyncGeneration() { + var executor = Executors.newFixedThreadPool(1); + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var sb = new StringBuilder(); + try { + var future = createLLM(executor).completeAsync(prompt, inferenceParamsWithDefaultKey(), completion -> { + sb.append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); + + var reason = future.join(); + assertTrue(future.isDone()); + assertNotEquals(reason, Completion.FinishReason.error); + } finally { + executor.shutdownNow(); + } + + assertEquals("Ducks have adorable waddling walks.", sb.toString()); + } + + @Test + public void testInferenceParameters() { + var prompt = StringPrompt.from("Why are ducks better than cats?"); + var params = inferenceParams(Map.of("temperature", "1.0", "maxTokens", "4")); + var result = createLLM().complete(prompt, params); + assertEquals("Random text about ducks", result.get(0).text()); + } + + @Test + public void testNoApiKey() { + var prompt = StringPrompt.from(""); + var config = modelParams("api-key", null); + var secrets = createSecretStore(Map.of()); + assertThrows(IllegalArgumentException.class, () -> { + createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); + }); + } + + @Test + public void testApiKeyFromSecretStore() { + var prompt = StringPrompt.from(""); + var config = modelParams("api-key-in-secret-store", null); + var secrets = createSecretStore(Map.of("api-key-in-secret-store", MockLLMClient.ACCEPTED_API_KEY)); + assertDoesNotThrow(() -> { createLLM(config, createGenerator(), secrets).complete(prompt, inferenceParams()); }); + } + + private static String lookupParameter(String parameter, Map<String, String> params) { + return params.get(parameter); + } + + private static InferenceParameters inferenceParams() { + return new InferenceParameters(s -> lookupParameter(s, Collections.emptyMap())); + } + + private static InferenceParameters inferenceParams(Map<String, String> params) { + return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, params)); + } + + private static InferenceParameters inferenceParamsWithDefaultKey() { + return new InferenceParameters(MockLLMClient.ACCEPTED_API_KEY, s -> lookupParameter(s, Collections.emptyMap())); + } + + private LlmClientConfig modelParams(String apiKeySecretName, String endpoint) { + var config = new LlmClientConfig.Builder(); + if (apiKeySecretName != null) { + config.apiKeySecretName(apiKeySecretName); + } + if (endpoint != null) { + config.endpoint(endpoint); + } + return config.build(); + } + + public static SecretStore createSecretStore(Map<String, String> secrets) { + Provider<SecretStore> secretStore = new Provider<>() { + public SecretStore get() { + return new SecretStore() { + public String getSecret(String key) { + return secrets.get(key); + } + public String getSecret(String key, int version) { + return secrets.get(key); + } + }; + } + public void deconstruct() { + } + }; + return secretStore.get(); + } + + public static BiFunction<Prompt, InferenceParameters, String> createGenerator() { + return (prompt, options) -> { + String answer = "I have no opinion on the matter"; + if (prompt.asString().contains("ducks")) { + answer = "Ducks have adorable waddling walks."; + var temperature = options.getDouble("temperature"); + if (temperature.isPresent() && temperature.get() > 0.5) { + answer = "Random text about ducks vs cats that makes no sense whatsoever."; + } + } + var maxTokens = options.getInt("maxTokens"); + if (maxTokens.isPresent()) { + return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); + } + return answer; + }; + } + + private static MockLLMClient createLLM() { + LlmClientConfig config = new LlmClientConfig.Builder().build(); + return createLLM(config, null); + } + + private static MockLLMClient createLLM(ExecutorService executor) { + LlmClientConfig config = new LlmClientConfig.Builder().build(); + return createLLM(config, executor); + } + + private static MockLLMClient createLLM(LlmClientConfig config, ExecutorService executor) { + var generator = createGenerator(); + var secretStore = new SecretStoreProvider(); // throws exception on use + return createLLM(config, generator, secretStore.get(), executor); + } + + private static MockLLMClient createLLM(LlmClientConfig config, + BiFunction<Prompt, InferenceParameters, String> generator, + SecretStore secretStore) { + return createLLM(config, generator, secretStore, null); + } + + private static MockLLMClient createLLM(LlmClientConfig config, + BiFunction<Prompt, InferenceParameters, String> generator, + SecretStore secretStore, + ExecutorService executor) { + return new MockLLMClient(config, secretStore, generator, executor); + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java new file mode 100644 index 00000000000..cfb6a43984f --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java @@ -0,0 +1,81 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.container.jdisc.secretstore.SecretStore; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +public class MockLLMClient extends ConfigurableLanguageModel { + + public final static String ACCEPTED_API_KEY = "sesame"; + + private final ExecutorService executor; + private final BiFunction<Prompt, InferenceParameters, String> generator; + + private Prompt lastPrompt; + + public MockLLMClient(LlmClientConfig config, + SecretStore secretStore, + BiFunction<Prompt, InferenceParameters, String> generator, + ExecutorService executor) { + super(config, secretStore); + this.generator = generator; + this.executor = executor; + } + + private void checkApiKey(InferenceParameters options) { + var apiKey = getApiKey(options); + if (apiKey == null || ! apiKey.equals(ACCEPTED_API_KEY)) { + throw new IllegalArgumentException("Invalid API key"); + } + } + + private void setPrompt(Prompt prompt) { + this.lastPrompt = prompt; + } + + public Prompt getPrompt() { + return this.lastPrompt; + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters params) { + setApiKey(params); + checkApiKey(params); + setPrompt(prompt); + return List.of(Completion.from(this.generator.apply(prompt, params))); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters params, + Consumer<Completion> consumer) { + setPrompt(prompt); + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization + + long sleep = 1; + executor.submit(() -> { + try { + for (int i=0; i < completions.length; ++i) { + String completion = (i > 0 ? " " : "") + completions[i]; + consumer.accept(Completion.from(completion, Completion.FinishReason.none)); Thread.sleep(sleep); + } + completionFuture.complete(Completion.FinishReason.stop); + } catch (InterruptedException e) { + // Do nothing + } + }); + + return completionFuture; + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java new file mode 100644 index 00000000000..1111a9824f5 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java @@ -0,0 +1,36 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.container.jdisc.SecretStoreProvider; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +public class OpenAITest { + + private static final String apiKey = "<your-api-key>"; + + @Test + @Disabled + public void testOpenAIGeneration() { + var config = new LlmClientConfig.Builder().build(); + var openai = new OpenAI(config, new SecretStoreProvider().get()); + var options = Map.of( + "maxTokens", "10" + ); + + var prompt = StringPrompt.from("why are ducks better than cats?"); + var future = openai.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> { + System.out.print(completion.text()); + }).exceptionally(exception -> { + System.out.println("Error: " + exception); + return null; + }); + future.join(); + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java new file mode 100755 index 00000000000..d4f1dbc00a4 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/search/LLMSearcherTest.java @@ -0,0 +1,254 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.search; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LlmClientConfig; +import ai.vespa.llm.LlmSearcherConfig; +import ai.vespa.llm.clients.ConfigurableLanguageModelTest; +import ai.vespa.llm.clients.MockLLMClient; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.component.ComponentId; +import com.yahoo.component.chain.Chain; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.container.jdisc.SecretStoreProvider; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.result.EventStream; +import com.yahoo.search.searchchain.Execution; +import org.junit.jupiter.api.Test; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + + +public class LLMSearcherTest { + + @Test + public void testLLMSelection() { + var llm1 = createLLMClient("mock1"); + var llm2 = createLLMClient("mock2"); + var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build(); + var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2)); + var result = runMockSearch(searcher, Map.of("prompt", "what is your id?")); + assertEquals(1, result.getHitCount()); + assertEquals("My id is mock2", getCompletion(result)); + } + + @Test + public void testGeneration() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var params = Map.of("prompt", "why are ducks better than cats"); + assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params))); + } + + @Test + public void testPrompting() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + + // Prompt with prefix + assertEquals("Ducks have adorable waddling walks.", + getCompletion(runMockSearch(searcher, Map.of("llm.prompt", "why are ducks better than cats")))); + + // Prompt without prefix + assertEquals("Ducks have adorable waddling walks.", + getCompletion(runMockSearch(searcher, Map.of("prompt", "why are ducks better than cats")))); + + // Fallback to query if not given + assertEquals("Ducks have adorable waddling walks.", + getCompletion(runMockSearch(searcher, Map.of("query", "why are ducks better than cats")))); + } + + @Test + public void testPromptEvent() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var params = Map.of( + "prompt", "why are ducks better than cats", + "traceLevel", "1"); + var result = runMockSearch(searcher, params); + var events = ((EventStream) result.hits().get(0)).incoming().drain(); + assertEquals(2, events.size()); + + var promptEvent = (EventStream.Event) events.get(0); + assertEquals("prompt", promptEvent.type()); + assertEquals("why are ducks better than cats", promptEvent.toString()); + + var completionEvent = (EventStream.Event) events.get(1); + assertEquals("completion", completionEvent.type()); + assertEquals("Ducks have adorable waddling walks.", completionEvent.toString()); + } + + @Test + public void testParameters() { + var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var params = Map.of( + "llm.prompt", "why are ducks better than cats", + "llm.temperature", "1.0", + "llm.maxTokens", "5" + ); + assertEquals("Random text about ducks vs", getCompletion(runMockSearch(searcher, params))); + } + + @Test + public void testParameterPrefix() { + var prefix = "foo"; + var params = Map.of( + "foo.prompt", "what is your opinion on cats", + "foo.maxTokens", "5" + ); + var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build(); + var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient())); + assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params))); + } + + @Test + public void testApiKeyFromHeader() { + var properties = Map.of("prompt", "why are ducks better than cats"); + var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore())); + assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm")); + assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm")); + } + + @Test + public void testAsyncGeneration() { + var executor = Executors.newFixedThreadPool(1); + var sb = new StringBuilder(); + try { + var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock").build(); // config says don't stream... + var params = Map.of( + "llm.stream", "true", // ... but inference parameters says do it anyway + "llm.prompt", "why are ducks better than cats?" + ); + var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor))); + Result result = runMockSearch(searcher, params); + + assertEquals(1, result.getHitCount()); + assertTrue(result.hits().get(0) instanceof EventStream); + EventStream eventStream = (EventStream) result.hits().get(0); + + var incoming = eventStream.incoming(); + incoming.addNewDataListener(() -> { + incoming.drain().forEach(event -> sb.append(event.toString())); + }, executor); + + incoming.completedFuture().join(); + assertTrue(incoming.isComplete()); + + // Ensure incoming has been fully drained to avoid race condition in this test + incoming.drain().forEach(event -> sb.append(event.toString())); + + } finally { + executor.shutdownNow(); + } + assertEquals("Ducks have adorable waddling walks.", sb.toString()); + } + + private static String getCompletion(Result result) { + assertTrue(result.hits().size() >= 1); + return ((EventStream) result.hits().get(0)).incoming().drain().get(0).toString(); + } + + static Result runMockSearch(Searcher searcher, Map<String, String> parameters) { + return runMockSearch(searcher, parameters, null, ""); + } + + static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) { + Chain<Searcher> chain = new Chain<>(searcher); + Execution execution = new Execution(chain, Execution.Context.createContextStub()); + Query query = new Query("?" + toUrlParams(parameters)); + if (apiKey != null) { + String headerKey = "X-LLM-API-KEY"; + if (prefix != null && ! prefix.isEmpty()) { + headerKey = prefix + "." + headerKey; + } + query.getHttpRequest().getJDiscRequest().headers().add(headerKey, apiKey); + } + return execution.search(query); + } + + public static String toUrlParams(Map<String, String> parameters) { + return parameters.entrySet().stream().map( + e -> e.getKey() + "=" + URLEncoder.encode(e.getValue(), StandardCharsets.UTF_8) + ).collect(Collectors.joining("&")); + } + + private static BiFunction<Prompt, InferenceParameters, String> createIdGenerator(String id) { + return (prompt, options) -> { + if (id == null || id.isEmpty()) + return "I have no ID"; + return "My id is " + id; + }; + } + + private static BiFunction<Prompt, InferenceParameters, String> createGenerator() { + return ConfigurableLanguageModelTest.createGenerator(); + } + + static MockLLMClient createLLMClient() { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore, generator, null); + } + + static MockLLMClient createLLMClient(String id) { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createIdGenerator(id); + return new MockLLMClient(config, secretStore, generator, null); + } + + static MockLLMClient createLLMClient(ExecutorService executor) { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore, generator, executor); + } + + static MockLLMClient createLLMClientWithoutSecretStore() { + var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); + var secretStore = new SecretStoreProvider(); + var generator = createGenerator(); + return new MockLLMClient(config, secretStore.get(), generator, null); + } + + private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) { + var config = new LlmSearcherConfig.Builder().stream(false).build(); + ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); + llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new LLMSearcherImpl(config, models); + } + + private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) { + ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); + llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new LLMSearcherImpl(config, models); + } + + public static class LLMSearcherImpl extends LLMSearcher { + + public LLMSearcherImpl(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { + super(config, languageModels); + } + + @Override + public Result search(Query query, Execution execution) { + return complete(query, StringPrompt.from(getPrompt(query))); + } + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java new file mode 100755 index 00000000000..ccf9a4a6401 --- /dev/null +++ b/container-search/src/test/java/ai/vespa/llm/search/RAGSearcherTest.java @@ -0,0 +1,127 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.search; + +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.LlmSearcherConfig; +import com.yahoo.component.ComponentId; +import com.yahoo.component.chain.Chain; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.result.EventStream; +import com.yahoo.search.result.Hit; +import com.yahoo.search.searchchain.Execution; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; + + +public class RAGSearcherTest { + + private static final String DOC1_TITLE = "Exploring the Delightful Qualities of Ducks"; + private static final String DOC1_CONTENT = "Ducks, with their gentle quacks and adorable waddling walks, possess a unique " + + "charm that sets them apart as extraordinary pets."; + private static final String DOC2_TITLE = "Why Cats Reign Supreme"; + private static final String DOC2_CONTENT = "Cats bring an enchanting allure to households with their independent " + + "companionship, playful nature, natural hunting abilities, low-maintenance grooming, and the " + + "emotional support they offer."; + + @Test + public void testRAGGeneration() { + var eventStream = runRAGQuery(Map.of( + "prompt", "why are ducks better than cats?", + "traceLevel", "1")); + var events = eventStream.incoming().drain(); + assertEquals(2, events.size()); + + // Generated prompt + var promptEvent = (EventStream.Event) events.get(0); + assertEquals("prompt", promptEvent.type()); + assertEquals("title: " + DOC1_TITLE + "\n" + + "content: " + DOC1_CONTENT + "\n\n" + + "title: " + DOC2_TITLE + "\n" + + "content: " + DOC2_CONTENT + "\n\n\n" + + "why are ducks better than cats?", promptEvent.toString()); + + // Generated completion + var completionEvent = (EventStream.Event) events.get(1); + assertEquals("completion", completionEvent.type()); + assertEquals("Ducks have adorable waddling walks.", completionEvent.toString()); + } + + @Test + public void testPromptGeneration() { + var eventStream = runRAGQuery(Map.of( + "query", "why are ducks better than cats?", + "prompt", "{context}\nGiven these documents, answer this query as concisely as possible: @query", + "traceLevel", "1")); + var events = eventStream.incoming().drain(); + + var promptEvent = (EventStream.Event) events.get(0); + assertEquals("prompt", promptEvent.type()); + assertEquals("title: " + DOC1_TITLE + "\n" + + "content: " + DOC1_CONTENT + "\n\n" + + "title: " + DOC2_TITLE + "\n" + + "content: " + DOC2_CONTENT + "\n\n\n" + + "Given these documents, answer this query as concisely as possible: " + + "why are ducks better than cats?", promptEvent.toString()); + } + + @Test + public void testSkipContextInPrompt() { + var eventStream = runRAGQuery(Map.of( + "query", "why are ducks better than cats?", + "llm.context", "skip", + "traceLevel", "1")); + var events = eventStream.incoming().drain(); + + var promptEvent = (EventStream.Event) events.get(0); + assertEquals("prompt", promptEvent.type()); + assertEquals("why are ducks better than cats?", promptEvent.toString()); + } + + public static class MockSearchResults extends Searcher { + + @Override + public Result search(Query query, Execution execution) { + Hit hit1 = new Hit("1"); + hit1.setField("title", DOC1_TITLE); + hit1.setField("content", DOC1_CONTENT); + + Hit hit2 = new Hit("2"); + hit2.setField("title", DOC2_TITLE); + hit2.setField("content", DOC2_CONTENT); + + Result r = new Result(query); + r.hits().add(hit1); + r.hits().add(hit2); + return r; + } + } + + private EventStream runRAGQuery(Map<String, String> params) { + var llm = LLMSearcherTest.createLLMClient(); + var searcher = createRAGSearcher(Map.of("mock", llm)); + var result = runMockSearch(searcher, params); + return (EventStream) result.hits().get(0); + } + + static Result runMockSearch(Searcher searcher, Map<String, String> parameters) { + Chain<Searcher> chain = new Chain<>(searcher, new MockSearchResults()); + Execution execution = new Execution(chain, Execution.Context.createContextStub()); + Query query = new Query("?" + LLMSearcherTest.toUrlParams(parameters)); + return execution.search(query); + } + + private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) { + var config = new LlmSearcherConfig.Builder().stream(false).build(); + ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); + llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); + models.freeze(); + return new RAGSearcher(config, models); + } + +} diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 37fe29c2eb6..0b30901bb89 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -406,7 +406,7 @@ public class Flags { "Takes effect immediately"); public static UnboundBooleanFlag CALYPSO_ENABLED = defineFeatureFlag( - "calypso-enabled", false, + "calypso-enabled", true, List.of("mortent"), "2024-02-19", "2024-05-01", "Whether to enable calypso for host", "Takes effect immediately", HOSTNAME); diff --git a/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java b/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java index 38a1b252df9..9479c814e89 100644 --- a/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java +++ b/metrics/src/main/java/ai/vespa/metrics/set/InfrastructureMetricSet.java @@ -81,6 +81,7 @@ public class InfrastructureMetricSet { addMetric(metrics, ConfigServerMetrics.HAS_WIRE_GUARD_KEY.max()); addMetric(metrics, ConfigServerMetrics.WANT_TO_DEPROVISION.max()); addMetric(metrics, ConfigServerMetrics.SUSPENDED.max()); + addMetric(metrics, ConfigServerMetrics.SUSPENDED_SECONDS.count()); addMetric(metrics, ConfigServerMetrics.SOME_SERVICES_DOWN.max()); addMetric(metrics, ConfigServerMetrics.NODE_FAILER_BAD_NODE.max()); addMetric(metrics, ConfigServerMetrics.LOCK_ATTEMPT_LOCKED_LOAD, EnumSet.of(max,average)); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java index 69ae4fddb63..e3d72d1189e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java @@ -169,8 +169,6 @@ public class MetricsReporter extends NodeRepositoryMaintainer { * NB: Keep this metric set in sync with internal configserver metric pre-aggregation */ private void updateNodeMetrics(Node node, ServiceModel serviceModel) { - if (node.state() != State.active) - return; Metric.Context context; Optional<Allocation> allocation = node.allocation(); if (allocation.isPresent()) { @@ -235,7 +233,7 @@ public class MetricsReporter extends NodeRepositoryMaintainer { long suspendedSeconds = info.suspendedSince() .map(suspendedSince -> Duration.between(suspendedSince, clock().instant()).getSeconds()) .orElse(0L); - metric.set(ConfigServerMetrics.SUSPENDED_SECONDS.baseName(), suspendedSeconds, context); + metric.add(ConfigServerMetrics.SUSPENDED_SECONDS.baseName(), suspendedSeconds, context); }); long numberOfServices; diff --git a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp index 4562f8cb50e..a4b05d6540c 100644 --- a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp +++ b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp @@ -10,17 +10,17 @@ constexpr size_t loop_cnt = 64; using namespace search::queryeval; template <typename FLOW> -double ordered_cost_of(const std::vector<FlowStats> &data, bool strict) { - return flow::ordered_cost_of(flow::DirectAdapter(), data, FLOW(strict)); +double ordered_cost_of(const std::vector<FlowStats> &data, InFlow in_flow, bool allow_force_strict) { + return flow::ordered_cost_of(flow::DirectAdapter(), data, FLOW(in_flow), allow_force_strict); } template <typename FLOW> -double dual_ordered_cost_of(const std::vector<FlowStats> &data, bool strict) { - double result = flow::ordered_cost_of(flow::DirectAdapter(), data, FLOW(strict)); - AnyFlow any_flow = AnyFlow::create<FLOW>(strict); +double dual_ordered_cost_of(const std::vector<FlowStats> &data, InFlow in_flow, bool allow_force_strict) { + double result = flow::ordered_cost_of(flow::DirectAdapter(), data, FLOW(in_flow), allow_force_strict); + AnyFlow any_flow = AnyFlow::create<FLOW>(in_flow); double total_cost = 0.0; for (const auto &item: data) { - double child_cost = any_flow.strict() ? item.strict_cost : any_flow.flow() * item.cost; + double child_cost = flow::min_child_cost(InFlow(any_flow.strict(), any_flow.flow()), item, allow_force_strict); any_flow.update_cost(total_cost, child_cost); any_flow.add(item.estimate); } @@ -38,8 +38,12 @@ std::vector<FlowStats> gen_data(size_t size) { for (size_t i = 0; i < size; ++i) { result.emplace_back(estimate(gen), cost(gen), strict_cost(gen)); } + if (size == 0) { + gen.seed(gen.default_seed); + } return result; } +void re_seed() { gen_data(0); } template <typename T, typename F> void each_perm(std::vector<T> &data, size_t k, F fun) { @@ -275,16 +279,16 @@ TEST(FlowTest, in_flow_strict_vs_rate_interaction) { TEST(FlowTest, flow_cost) { std::vector<FlowStats> data = {{0.4, 1.1, 0.6}, {0.7, 1.2, 0.5}, {0.2, 1.3, 0.4}}; - EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.7*1.3); - EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.7*1.3); - EXPECT_DOUBLE_EQ(dual_ordered_cost_of<OrFlow>(data, false), 1.1 + 0.6*1.2 + 0.6*0.3*1.3); - EXPECT_DOUBLE_EQ(dual_ordered_cost_of<OrFlow>(data, true), 0.6 + 0.5 + 0.4); - EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndNotFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.3*1.3); - EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndNotFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.3*1.3); - EXPECT_DOUBLE_EQ(dual_ordered_cost_of<RankFlow>(data, false), 1.1); - EXPECT_DOUBLE_EQ(dual_ordered_cost_of<RankFlow>(data, true), 0.6); - EXPECT_DOUBLE_EQ(dual_ordered_cost_of<BlenderFlow>(data, false), 1.3); - EXPECT_DOUBLE_EQ(dual_ordered_cost_of<BlenderFlow>(data, true), 0.6); + EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndFlow>(data, false, false), 1.1 + 0.4*1.2 + 0.4*0.7*1.3); + EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndFlow>(data, true, false), 0.6 + 0.4*1.2 + 0.4*0.7*1.3); + EXPECT_DOUBLE_EQ(dual_ordered_cost_of<OrFlow>(data, false, false), 1.1 + 0.6*1.2 + 0.6*0.3*1.3); + EXPECT_DOUBLE_EQ(dual_ordered_cost_of<OrFlow>(data, true, false), 0.6 + 0.5 + 0.4); + EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndNotFlow>(data, false, false), 1.1 + 0.4*1.2 + 0.4*0.3*1.3); + EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndNotFlow>(data, true, false), 0.6 + 0.4*1.2 + 0.4*0.3*1.3); + EXPECT_DOUBLE_EQ(dual_ordered_cost_of<RankFlow>(data, false, false), 1.1); + EXPECT_DOUBLE_EQ(dual_ordered_cost_of<RankFlow>(data, true, false), 0.6); + EXPECT_DOUBLE_EQ(dual_ordered_cost_of<BlenderFlow>(data, false, false), 1.3); + EXPECT_DOUBLE_EQ(dual_ordered_cost_of<BlenderFlow>(data, true, false), 0.6); } TEST(FlowTest, rank_flow_cost_accumulation_is_first) { @@ -322,9 +326,9 @@ TEST(FlowTest, optimal_and_flow) { double min_cost = AndFlow::cost_of(data, strict); double max_cost = 0.0; AndFlow::sort(data, strict); - EXPECT_EQ(ordered_cost_of<AndFlow>(data, strict), min_cost); + EXPECT_EQ(ordered_cost_of<AndFlow>(data, strict, false), min_cost); auto check = [&](const std::vector<FlowStats> &my_data) noexcept { - double my_cost = ordered_cost_of<AndFlow>(my_data, strict); + double my_cost = ordered_cost_of<AndFlow>(my_data, strict, false); EXPECT_LE(min_cost, my_cost); max_cost = std::max(max_cost, my_cost); }; @@ -345,9 +349,9 @@ TEST(FlowTest, optimal_or_flow) { double min_cost = OrFlow::cost_of(data, strict); double max_cost = 0.0; OrFlow::sort(data, strict); - EXPECT_EQ(ordered_cost_of<OrFlow>(data, strict), min_cost); + EXPECT_EQ(ordered_cost_of<OrFlow>(data, strict, false), min_cost); auto check = [&](const std::vector<FlowStats> &my_data) noexcept { - double my_cost = ordered_cost_of<OrFlow>(my_data, strict); + double my_cost = ordered_cost_of<OrFlow>(my_data, strict, false); EXPECT_LE(min_cost, my_cost + 1e-9); max_cost = std::max(max_cost, my_cost); }; @@ -369,10 +373,10 @@ TEST(FlowTest, optimal_and_not_flow) { double max_cost = 0.0; AndNotFlow::sort(data, strict); EXPECT_EQ(data[0], first); - EXPECT_DOUBLE_EQ(ordered_cost_of<AndNotFlow>(data, strict), min_cost); + EXPECT_DOUBLE_EQ(ordered_cost_of<AndNotFlow>(data, strict, false), min_cost); auto check = [&](const std::vector<FlowStats> &my_data) noexcept { if (my_data[0] == first) { - double my_cost = ordered_cost_of<AndNotFlow>(my_data, strict); + double my_cost = ordered_cost_of<AndNotFlow>(my_data, strict, false); EXPECT_LE(min_cost, my_cost + 1e-9); max_cost = std::max(max_cost, my_cost); } @@ -386,4 +390,106 @@ TEST(FlowTest, optimal_and_not_flow) { } } +void test_strict_AND_sort_strategy(auto my_sort) { + re_seed(); + size_t cnt = 64; + double max_rel_err = 0.0; + double sum_rel_err = 0.0; + for (size_t i = 0; i < cnt; ++i) { + auto data = gen_data(7); + double ref_est = AndFlow::estimate_of(data); + double min_cost = 1'000'000.0; + double max_cost = 0.0; + my_sort(data); + double est_cost = ordered_cost_of<AndFlow>(data, true, true); + auto check = [&](const std::vector<FlowStats> &my_data) noexcept { + double my_cost = ordered_cost_of<AndFlow>(my_data, true, true); + min_cost = std::min(min_cost, my_cost); + max_cost = std::max(max_cost, my_cost); + }; + each_perm(data, check); + double rel_err = 0.0; + double cost_range = (max_cost - min_cost); + if (cost_range > 1e-9) { + rel_err = (est_cost - min_cost) / cost_range; + } + max_rel_err = std::max(max_rel_err, rel_err); + sum_rel_err += rel_err; + EXPECT_NEAR(ref_est, AndFlow::estimate_of(data), 1e-9); + } + fprintf(stderr, " strict AND allow_force_strict: avg rel_err: %g, max rel_err: %g\n", + sum_rel_err / cnt, max_rel_err); +} + +TEST(FlowTest, strict_and_with_allow_force_strict_basic_order) { + auto my_sort = [](auto &data){ AndFlow::sort(data, true); }; + test_strict_AND_sort_strategy(my_sort); +} + +void greedy_sort(std::vector<FlowStats> &data, auto flow, auto score_of) { + for (size_t next = 0; (next + 1) < data.size(); ++next) { + InFlow in_flow = InFlow(flow.strict(), flow.flow()); + size_t best_idx = next; + double best_score = score_of(in_flow, data[next]); + for (size_t i = next + 1; i < data.size(); ++i) { + double score = score_of(in_flow, data[i]); + if (score > best_score) { + best_score = score; + best_idx = i; + } + } + std::swap(data[next], data[best_idx]); + flow.add(data[next].estimate); + } +} + +TEST(FlowTest, strict_and_with_allow_force_strict_greedy_reduction_efficiency) { + auto score_of = [](InFlow in_flow, const FlowStats &stats) { + double child_cost = flow::min_child_cost(in_flow, stats, true); + double reduction = (1.0 - stats.estimate); + return reduction / child_cost; + }; + auto my_sort = [&](auto &data){ greedy_sort(data, AndFlow(true), score_of); }; + test_strict_AND_sort_strategy(my_sort); +} + +TEST(FlowTest, strict_and_with_allow_force_strict_incremental_strict_selection) { + auto my_sort = [](auto &data) { + AndFlow::sort(data, true); + for (size_t next = 1; (next + 1) < data.size(); ++next) { + auto [idx, score] = flow::select_forced_strict_and_child(flow::DirectAdapter(), data, next); + if (score >= 0.0) { + break; + } + auto the_one = data[idx]; + for (; idx > next; --idx) { + data[idx] = data[idx-1]; + } + data[next] = the_one; + } + }; + test_strict_AND_sort_strategy(my_sort); +} + +TEST(FlowTest, strict_and_with_allow_force_strict_incremental_strict_selection_with_strict_re_sorting) { + auto my_sort = [](auto &data) { + AndFlow::sort(data, true); + size_t strict_cnt = 1; + while (strict_cnt < data.size()) { + auto [idx, score] = flow::select_forced_strict_and_child(flow::DirectAdapter(), data, strict_cnt); + if (score >= 0.0) { + break; + } + auto the_one = data[idx]; + for (; idx > strict_cnt; --idx) { + data[idx] = data[idx-1]; + } + data[strict_cnt++] = the_one; + } + std::sort(data.begin(), data.begin() + strict_cnt, + [](const auto &a, const auto &b){ return (a.estimate < b.estimate); }); + }; + test_strict_AND_sort_strategy(my_sort); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/queryeval/flow.h b/searchlib/src/vespa/searchlib/queryeval/flow.h index ba0235991a8..f38d447d3b0 100644 --- a/searchlib/src/vespa/searchlib/queryeval/flow.h +++ b/searchlib/src/vespa/searchlib/queryeval/flow.h @@ -5,6 +5,7 @@ #include <cstddef> #include <algorithm> #include <functional> +#include <limits> // Model how boolean result decisions flow through intermediate nodes // of different types based on relative estimates for sub-expressions @@ -34,7 +35,10 @@ struct FlowStats { double strict_cost; constexpr FlowStats(double estimate_in, double cost_in, double strict_cost_in) noexcept : estimate(estimate_in), cost(cost_in), strict_cost(strict_cost_in) {} - auto operator <=>(const FlowStats &rhs) const noexcept = default; + constexpr auto operator <=>(const FlowStats &rhs) const noexcept = default; + static constexpr FlowStats from(auto adapter, const auto &child) noexcept { + return {adapter.estimate(child), adapter.cost(child), adapter.strict_cost(child)}; + } }; namespace flow { @@ -117,6 +121,23 @@ struct MinOrCost { } }; +// estimate the cost of evaluating a strict child in a non-strict context +inline double forced_strict_cost(double estimate, double strict_cost, double rate) { + return 0.2 * (rate - estimate) + strict_cost; +} + +// estimate the absolute cost of evaluating a child with a specific in flow +inline double min_child_cost(InFlow in_flow, const FlowStats &stats, bool allow_force_strict) { + if (in_flow.strict()) { + return stats.strict_cost; + } + if (!allow_force_strict) { + return stats.cost * in_flow.rate(); + } + return std::min(forced_strict_cost(stats.estimate, stats.strict_cost, in_flow.rate()), + stats.cost * in_flow.rate()); +} + template <typename ADAPTER, typename T> double estimate_of_and(ADAPTER adapter, const T &children) { double flow = children.empty() ? 0.0 : adapter.estimate(children[0]); @@ -157,43 +178,59 @@ void sort_partial(ADAPTER adapter, T &children, size_t offset) { } template <typename ADAPTER, typename T, typename F> -double ordered_cost_of(ADAPTER adapter, const T &children, F flow) { +double ordered_cost_of(ADAPTER adapter, const T &children, F flow, bool allow_force_strict) { double total_cost = 0.0; for (const auto &child: children) { - double child_cost = flow.strict() ? adapter.strict_cost(child) : (flow.flow() * adapter.cost(child)); + FlowStats stats(adapter.estimate(child), adapter.cost(child), adapter.strict_cost(child)); + double child_cost = min_child_cost(InFlow(flow.strict(), flow.flow()), stats, allow_force_strict); flow.update_cost(total_cost, child_cost); - flow.add(adapter.estimate(child)); + flow.add(stats.estimate); } return total_cost; } -template <typename ADAPTER, typename T> -size_t select_strict_and_child(ADAPTER adapter, const T &children) { - size_t idx = 0; +size_t select_strict_and_child(auto adapter, const auto &children) { + double est = 1.0; double cost = 0.0; size_t best_idx = 0; - double best_diff = 0.0; - double est = 1.0; - for (const auto &child: children) { - double child_cost = est * adapter.cost(child); - double child_strict_cost = adapter.strict_cost(child); - double child_est = adapter.estimate(child); - if (idx == 0) { - best_diff = child_strict_cost - child_cost; - } else { - double my_diff = (child_strict_cost + child_est * cost) - (cost + child_cost); - if (my_diff < best_diff) { - best_diff = my_diff; - best_idx = idx; - } + double best_diff = std::numeric_limits<double>::max(); + for (size_t idx = 0; idx < children.size(); ++idx) { + auto child = FlowStats::from(adapter, children[idx]); + double child_abs_cost = est * child.cost; + double my_diff = (child.strict_cost + child.estimate * cost) - (cost + child_abs_cost); + if (my_diff < best_diff) { + best_diff = my_diff; + best_idx = idx; } - cost += child_cost; - est *= child_est; - ++idx; + cost += child_abs_cost; + est *= child.estimate; } return best_idx; } +auto select_forced_strict_and_child(auto adapter, const auto &children, size_t first) { + double est = 1.0; + double cost = 0.0; + size_t best_idx = 0; + double best_diff = std::numeric_limits<double>::max(); + for (size_t idx = 0; idx < first && idx < children.size(); ++idx) { + est *= adapter.estimate(children[idx]); + } + for (size_t idx = first; idx < children.size(); ++idx) { + auto child = FlowStats::from(adapter, children[idx]); + double child_abs_cost = est * child.cost; + double forced_cost = forced_strict_cost(child.estimate, child.strict_cost, est); + double my_diff = (forced_cost + child.estimate * cost) - (cost + child_abs_cost); + if (my_diff < best_diff) { + best_diff = my_diff; + best_idx = idx; + } + cost += child_abs_cost; + est *= child.estimate; + } + return std::make_pair(best_idx, best_diff); +} + } // flow template <typename FLOW> @@ -202,7 +239,7 @@ struct FlowMixin { auto my_adapter = flow::IndirectAdapter(adapter, children); auto order = flow::make_index(children.size()); FLOW::sort(my_adapter, order, strict); - return flow::ordered_cost_of(my_adapter, order, FLOW(strict)); + return flow::ordered_cost_of(my_adapter, order, FLOW(strict), false); } static double cost_of(const auto &children, bool strict) { return cost_of(flow::make_adapter(children), children, strict); diff --git a/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.cpp b/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.cpp index d8a6091fe11..5988bdd912f 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.cpp @@ -4,6 +4,19 @@ namespace vsm { +namespace { + +template <bool exact_match> inline bool is_word_char(ucs4_t c); + +template <> +inline bool is_word_char<false>(ucs4_t c) { return Fast_UnicodeUtil::IsWordChar(c); } + +// All characters are treated as word characters for exact match +template <> +inline constexpr bool is_word_char<true>(ucs4_t) { return true; } + +} + void TokenizeReader::fold(ucs4_t c) { const char *repl = Fast_NormalizeWordFolder::ReplacementString(c); @@ -18,4 +31,24 @@ TokenizeReader::fold(ucs4_t c) { } } +template <bool exact_match> +size_t +TokenizeReader::tokenize_helper(Normalizing norm_mode) +{ + ucs4_t c(0); + while (hasNext()) { + if (is_word_char<exact_match>(c = next())) { + normalize(c, norm_mode); + while (hasNext() && is_word_char<exact_match>(c = next())) { + normalize(c, norm_mode); + } + break; + } + } + return complete(); +} + +template size_t TokenizeReader::tokenize_helper<false>(Normalizing); +template size_t TokenizeReader::tokenize_helper<true>(Normalizing); + } diff --git a/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.h b/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.h index 2bb5e62e0aa..f680d9b6c47 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.h +++ b/streamingvisitors/src/vespa/vsm/searcher/tokenizereader.h @@ -43,6 +43,10 @@ public: _q = _q_start; return token_len; } + template <bool exact_match> + size_t tokenize_helper(Normalizing norm_mode); + size_t tokenize(Normalizing norm_mode) { return tokenize_helper<false>(norm_mode); } + size_t tokenize_exact_match(Normalizing norm_mode) { return tokenize_helper<true>(norm_mode); } private: void fold(ucs4_t c); const byte *_p; diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8strchrfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/utf8strchrfieldsearcher.cpp index 37dc4ffb99c..c860178d583 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8strchrfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8strchrfieldsearcher.cpp @@ -26,8 +26,7 @@ UTF8StrChrFieldSearcher::matchTerms(const FieldRef & f, size_t mintsz) TokenizeReader reader(reinterpret_cast<const byte *> (f.data()), f.size(), fn); while ( reader.hasNext() ) { - tokenize(reader); - size_t fl = reader.complete(); + size_t fl = reader.tokenize(normalize_mode()); for (auto qt : _qtl) { const cmptype_t * term; termsize_t tsz = qt->term(term); diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.cpp b/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.cpp index 5036e9bedb1..f016d08ece8 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.cpp @@ -10,21 +10,6 @@ using search::byte; namespace vsm { -template<typename Reader> -void -UTF8StringFieldSearcherBase::tokenize(Reader & reader) { - ucs4_t c(0); - Normalizing norm_mode = normalize_mode(); - while (reader.hasNext() && ! Fast_UnicodeUtil::IsWordChar(c = reader.next())); - - if (Fast_UnicodeUtil::IsWordChar(c)) { - reader.normalize(c, norm_mode); - while (reader.hasNext() && Fast_UnicodeUtil::IsWordChar(c = reader.next())) { - reader.normalize(c, norm_mode); - } - } -} - size_t UTF8StringFieldSearcherBase::matchTermRegular(const FieldRef & f, QueryTerm & qt) { @@ -38,8 +23,7 @@ UTF8StringFieldSearcherBase::matchTermRegular(const FieldRef & f, QueryTerm & qt TokenizeReader reader(reinterpret_cast<const byte *> (f.data()), f.size(), fn); while ( reader.hasNext() ) { - tokenize(reader); - size_t fl = reader.complete(); + size_t fl = reader.tokenize(normalize_mode()); if ((tsz <= fl) && (prefix() || qt.isPrefix() || (tsz == fl))) { const cmptype_t *tt=term, *et=term+tsz; for (const cmptype_t *fnt=fn; (tt < et) && (*tt == *fnt); tt++, fnt++); @@ -127,8 +111,7 @@ UTF8StringFieldSearcherBase::matchTermSuffix(const FieldRef & f, QueryTerm & qt) TokenizeReader reader(reinterpret_cast<const byte *> (f.data()), f.size(), dstbuf); while ( reader.hasNext() ) { - tokenize(reader); - size_t tokenlen = reader.complete(); + size_t tokenlen = reader.tokenize(normalize_mode()); if (matchTermSuffix(term, tsz, dstbuf, tokenlen)) { addHit(qt, words); } diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.h b/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.h index b196f2795a4..c217a7b8866 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.h +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8stringfieldsearcherbase.h @@ -60,9 +60,6 @@ public: protected: SharedSearcherBuf _buf; - template<typename Reader> - void tokenize(Reader & reader); - /** * Matches the given query term against the words in the given field reference * using exact or prefix match strategy. diff --git a/streamingvisitors/src/vespa/vsm/searcher/utf8suffixstringfieldsearcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/utf8suffixstringfieldsearcher.cpp index 8bbacf168cf..d5bf4e4238a 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/utf8suffixstringfieldsearcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/utf8suffixstringfieldsearcher.cpp @@ -26,8 +26,7 @@ UTF8SuffixStringFieldSearcher::matchTerms(const FieldRef & f, size_t mintsz) TokenizeReader reader(reinterpret_cast<const byte *> (f.data()), f.size(), dstbuf); while ( reader.hasNext() ) { - tokenize(reader); - size_t tokenlen = reader.complete(); + size_t tokenlen = reader.tokenize(normalize_mode()); for (auto qt : _qtl) { const cmptype_t * term; termsize_t tsz = qt->term(term); diff --git a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp index 1ab1b16cb86..1986db79972 100644 --- a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp +++ b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp @@ -273,8 +273,11 @@ buildFieldSet(const VsmfieldsConfig::Documenttype::Index & ci, const FieldSearch return ifm; } +} + search::Normalizing -normalize_mode(VsmfieldsConfig::Fieldspec::Normalize normalize_mode) { +FieldSearchSpecMap::convert_normalize_mode(VsmfieldsConfig::Fieldspec::Normalize normalize_mode) +{ switch (normalize_mode) { case VsmfieldsConfig::Fieldspec::Normalize::NONE: return search::Normalizing::NONE; case VsmfieldsConfig::Fieldspec::Normalize::LOWERCASE: return search::Normalizing::LOWERCASE; @@ -283,8 +286,6 @@ normalize_mode(VsmfieldsConfig::Fieldspec::Normalize normalize_mode) { return search::Normalizing::LOWERCASE_AND_FOLD; } -} - void FieldSearchSpecMap::buildFromConfig(const VsmfieldsHandle & conf, const search::fef::IIndexEnvironment& index_env) { @@ -292,7 +293,7 @@ FieldSearchSpecMap::buildFromConfig(const VsmfieldsHandle & conf, const search:: for(const VsmfieldsConfig::Fieldspec & cfs : conf->fieldspec) { LOG(spam, "Parsing %s", cfs.name.c_str()); FieldIdT fieldId = specMap().size(); - FieldSearchSpec fss(fieldId, cfs.name, cfs.searchmethod, normalize_mode(cfs.normalize), cfs.arg1, cfs.maxlength); + FieldSearchSpec fss(fieldId, cfs.name, cfs.searchmethod, convert_normalize_mode(cfs.normalize), cfs.arg1, cfs.maxlength); _specMap[fieldId] = std::move(fss); _nameIdMap.add(cfs.name, fieldId); LOG(spam, "M in %d = %s", fieldId, cfs.name.c_str()); diff --git a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h index 5b5a6b9a783..8bab0cad3b6 100644 --- a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h +++ b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h @@ -101,6 +101,7 @@ public: static vespalib::string stripNonFields(vespalib::stringref rawIndex); search::attribute::DistanceMetric get_distance_metric(const vespalib::string& name) const; + static search::Normalizing convert_normalize_mode(VsmfieldsConfig::Fieldspec::Normalize normalize_mode); private: FieldSearchSpecMapT _specMap; // mapping from field id to field search spec diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 754352a45a4..45e88ac2e94 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -4077,6 +4077,26 @@ ], "fields" : [ ] }, + "ai.vespa.llm.InferenceParameters" : { + "superClass" : "java.lang.Object", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(java.util.function.Function)", + "public void <init>(java.lang.String, java.util.function.Function)", + "public void setApiKey(java.lang.String)", + "public java.util.Optional getApiKey()", + "public void setEndpoint(java.lang.String)", + "public java.util.Optional getEndpoint()", + "public java.util.Optional get(java.lang.String)", + "public java.util.Optional getDouble(java.lang.String)", + "public java.util.Optional getInt(java.lang.String)", + "public void ifPresent(java.lang.String, java.util.function.Consumer)" + ], + "fields" : [ ] + }, "ai.vespa.llm.LanguageModel" : { "superClass" : "java.lang.Object", "interfaces" : [ ], @@ -4086,23 +4106,20 @@ "abstract" ], "methods" : [ - "public abstract java.util.List complete(ai.vespa.llm.completion.Prompt)", - "public abstract java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, java.util.function.Consumer)" + "public abstract java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", + "public abstract java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" ], "fields" : [ ] }, - "ai.vespa.llm.client.openai.OpenAiClient$Builder" : { - "superClass" : "java.lang.Object", + "ai.vespa.llm.LanguageModelException" : { + "superClass" : "java.lang.RuntimeException", "interfaces" : [ ], "attributes" : [ "public" ], "methods" : [ - "public void <init>(java.lang.String)", - "public ai.vespa.llm.client.openai.OpenAiClient$Builder model(java.lang.String)", - "public ai.vespa.llm.client.openai.OpenAiClient$Builder temperature(double)", - "public ai.vespa.llm.client.openai.OpenAiClient$Builder maxTokens(long)", - "public ai.vespa.llm.client.openai.OpenAiClient build()" + "public void <init>(int, java.lang.String)", + "public int code()" ], "fields" : [ ] }, @@ -4115,8 +4132,9 @@ "public" ], "methods" : [ - "public java.util.List complete(ai.vespa.llm.completion.Prompt)", - "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, java.util.function.Consumer)" + "public void <init>()", + "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" ], "fields" : [ ] }, @@ -4135,7 +4153,8 @@ "fields" : [ "public static final enum ai.vespa.llm.completion.Completion$FinishReason length", "public static final enum ai.vespa.llm.completion.Completion$FinishReason stop", - "public static final enum ai.vespa.llm.completion.Completion$FinishReason none" + "public static final enum ai.vespa.llm.completion.Completion$FinishReason none", + "public static final enum ai.vespa.llm.completion.Completion$FinishReason error" ] }, "ai.vespa.llm.completion.Completion" : { @@ -4151,6 +4170,7 @@ "public java.lang.String text()", "public ai.vespa.llm.completion.Completion$FinishReason finishReason()", "public static ai.vespa.llm.completion.Completion from(java.lang.String)", + "public static ai.vespa.llm.completion.Completion from(java.lang.String, ai.vespa.llm.completion.Completion$FinishReason)", "public final java.lang.String toString()", "public final int hashCode()", "public final boolean equals(java.lang.Object)" @@ -4212,8 +4232,8 @@ ], "methods" : [ "public void <init>(ai.vespa.llm.test.MockLanguageModel$Builder)", - "public java.util.List complete(ai.vespa.llm.completion.Prompt)", - "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, java.util.function.Consumer)" + "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" ], "fields" : [ ] } diff --git a/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java new file mode 100755 index 00000000000..a942e5090e5 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java @@ -0,0 +1,76 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm; + +import com.yahoo.api.annotations.Beta; + +import java.util.Objects; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * Parameters for inference to language models. Parameters are typically + * supplied from searchers or processors and comes from query strings, + * headers, or other sources. Which parameters are available depends on + * the language model used. + * + * author lesters + */ +@Beta +public class InferenceParameters { + + private String apiKey; + private String endpoint; + private final Function<String, String> options; + + public InferenceParameters(Function<String, String> options) { + this(null, options); + } + + public InferenceParameters(String apiKey, Function<String, String> options) { + this.apiKey = apiKey; + this.options = Objects.requireNonNull(options); + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public Optional<String> getApiKey() { + return Optional.ofNullable(apiKey); + } + + public void setEndpoint(String endpoint) { + this.endpoint = endpoint; + } + + public Optional<String> getEndpoint() { + return Optional.ofNullable(endpoint); + } + + public Optional<String> get(String option) { + return Optional.ofNullable(options.apply(option)); + } + + public Optional<Double> getDouble(String option) { + try { + return Optional.of(Double.parseDouble(options.apply(option))); + } catch (Exception e) { + return Optional.empty(); + } + } + + public Optional<Integer> getInt(String option) { + try { + return Optional.of(Integer.parseInt(options.apply(option))); + } catch (Exception e) { + return Optional.empty(); + } + } + + public void ifPresent(String option, Consumer<String> func) { + get(option).ifPresent(func); + } + +} + diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java index f4b8938934b..059f25fadb4 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java +++ b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java @@ -17,8 +17,10 @@ import java.util.function.Consumer; @Beta public interface LanguageModel { - List<Completion> complete(Prompt prompt); + List<Completion> complete(Prompt prompt, InferenceParameters options); - CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action); + CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters options, + Consumer<Completion> consumer); } diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java new file mode 100755 index 00000000000..b5dbf615c08 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java @@ -0,0 +1,19 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm; + +import com.yahoo.api.annotations.Beta; + +@Beta +public class LanguageModelException extends RuntimeException { + + private final int code; + + public LanguageModelException(int code, String message) { + super(message); + this.code = code; + } + + public int code() { + return code; + } +} diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java index d7334b40963..75308a84faa 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java +++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java @@ -1,6 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.client.openai; +import ai.vespa.llm.LanguageModelException; +import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.LanguageModel; import ai.vespa.llm.completion.Prompt; @@ -28,31 +30,28 @@ import java.util.stream.Stream; * Currently, only completions are implemented. * * @author bratseth + * @author lesters */ @Beta public class OpenAiClient implements LanguageModel { + private static final String DEFAULT_MODEL = "gpt-3.5-turbo"; private static final String DATA_FIELD = "data: "; - private final String token; - private final String model; - private final double temperature; - private final long maxTokens; + private static final String OPTION_MODEL = "model"; + private static final String OPTION_TEMPERATURE = "temperature"; + private static final String OPTION_MAX_TOKENS = "maxTokens"; private final HttpClient httpClient; - private OpenAiClient(Builder builder) { - this.token = builder.token; - this.model = builder.model; - this.temperature = builder.temperature; - this.maxTokens = builder.maxTokens; + public OpenAiClient() { this.httpClient = HttpClient.newBuilder().build(); } @Override - public List<Completion> complete(Prompt prompt) { + public List<Completion> complete(Prompt prompt, InferenceParameters options) { try { - HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt), HttpResponse.BodyHandlers.ofByteArray()); + HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt, options, false), HttpResponse.BodyHandlers.ofByteArray()); var response = SlimeUtils.jsonToSlime(httpResponse.body()).get(); if ( httpResponse.statusCode() != 200) throw new IllegalArgumentException(SlimeUtils.toJson(response)); @@ -64,9 +63,11 @@ public class OpenAiClient implements LanguageModel { } @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> consumer) { + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters options, + Consumer<Completion> consumer) { try { - var request = toRequest(prompt, true); + var request = toRequest(prompt, options, true); var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines()); var completionFuture = new CompletableFuture<Completion.FinishReason>(); @@ -74,8 +75,7 @@ public class OpenAiClient implements LanguageModel { try { int responseCode = response.statusCode(); if (responseCode != 200) { - throw new IllegalArgumentException("Received code " + responseCode + ": " + - response.body().collect(Collectors.joining())); + throw new LanguageModelException(responseCode, response.body().collect(Collectors.joining())); } Stream<String> lines = response.body(); @@ -100,28 +100,28 @@ public class OpenAiClient implements LanguageModel { } } - private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException { - return toRequest(prompt, false); - } - - private HttpRequest toRequest(Prompt prompt, boolean stream) throws IOException, URISyntaxException { + private HttpRequest toRequest(Prompt prompt, InferenceParameters options, boolean stream) throws IOException, URISyntaxException { var slime = new Slime(); var root = slime.setObject(); - root.setString("model", model); - root.setDouble("temperature", temperature); + root.setString("model", options.get(OPTION_MODEL).orElse(DEFAULT_MODEL)); root.setBool("stream", stream); root.setLong("n", 1); - if (maxTokens > 0) { - root.setLong("max_tokens", maxTokens); - } + + if (options.getDouble(OPTION_TEMPERATURE).isPresent()) + root.setDouble("temperature", options.getDouble(OPTION_TEMPERATURE).get()); + if (options.getInt(OPTION_MAX_TOKENS).isPresent()) + root.setLong("max_tokens", options.getInt(OPTION_MAX_TOKENS).get()); + // Others? + var messagesArray = root.setArray("messages"); var messagesObject = messagesArray.addObject(); messagesObject.setString("role", "user"); messagesObject.setString("content", prompt.asString()); - return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/chat/completions")) + var endpoint = options.getEndpoint().orElse("https://api.openai.com/v1/chat/completions"); + return HttpRequest.newBuilder(new URI(endpoint)) .header("Content-Type", "application/json") - .header("Authorization", "Bearer " + token) + .header("Authorization", "Bearer " + options.getApiKey().orElse("")) .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime))) .build(); } @@ -152,39 +152,4 @@ public class OpenAiClient implements LanguageModel { }; } - public static class Builder { - - private final String token; - private String model = "gpt-3.5-turbo"; - private double temperature = 0.0; - private long maxTokens = 0; - - public Builder(String token) { - this.token = token; - } - - /** One of the language models listed at https://platform.openai.com/docs/models */ - public Builder model(String model) { - this.model = model; - return this; - } - - /** A value between 0 and 2 - higher gives more random/creative output. */ - public Builder temperature(double temperature) { - this.temperature = temperature; - return this; - } - - /** Maximum number of tokens to generate */ - public Builder maxTokens(long maxTokens) { - this.maxTokens = maxTokens; - return this; - } - - public OpenAiClient build() { - return new OpenAiClient(this); - } - - } - } diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java index ea784013812..91d0ad9bd02 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java +++ b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java @@ -22,7 +22,10 @@ public record Completion(String text, FinishReason finishReason) { stop, /** The completion is not finished yet, more tokens are incoming. */ - none + none, + + /** An error occurred while generating the completion */ + error } public Completion(String text, FinishReason finishReason) { @@ -37,7 +40,11 @@ public record Completion(String text, FinishReason finishReason) { public FinishReason finishReason() { return finishReason; } public static Completion from(String text) { - return new Completion(text, FinishReason.stop); + return from(text, FinishReason.stop); + } + + public static Completion from(String text, FinishReason reason) { + return new Completion(text, reason); } } diff --git a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java index db1b42fbbac..0e757a1f1e7 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java +++ b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java @@ -2,6 +2,7 @@ package ai.vespa.llm.test; import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; import com.yahoo.api.annotations.Beta; @@ -24,12 +25,14 @@ public class MockLanguageModel implements LanguageModel { } @Override - public List<Completion> complete(Prompt prompt) { + public List<Completion> complete(Prompt prompt, InferenceParameters options) { return completer.apply(prompt); } @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action) { + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters options, + Consumer<Completion> action) { throw new RuntimeException("Not implemented"); } diff --git a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java index 45ef7e270aa..1baab26f496 100644 --- a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java +++ b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java @@ -1,46 +1,46 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.client.openai; +import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.StringPrompt; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import java.util.Map; + /** * @author bratseth */ public class OpenAiClientCompletionTest { - private static final String apiKey = "your-api-key-here"; + private static final String apiKey = "<your-api-key-here>"; @Test @Disabled public void testClient() { - var client = new OpenAiClient.Builder(apiKey).maxTokens(10).build(); - String input = "You are an unhelpful assistant who never answers questions straightforwardly. " + - "Be as long-winded as possible. Are humans smarter than cats?\n\n"; - StringPrompt prompt = StringPrompt.from(input); + var client = new OpenAiClient(); + var options = Map.of("maxTokens", "10"); + var prompt = StringPrompt.from("You are an unhelpful assistant who never answers questions straightforwardly. " + + "Be as long-winded as possible. Are humans smarter than cats?"); + System.out.print(prompt); - for (int i = 0; i < 10; i++) { - var completion = client.complete(prompt).get(0); - System.out.print(completion.text()); - if (completion.finishReason() == Completion.FinishReason.stop) break; - prompt = prompt.append(completion.text()); - } + var completion = client.complete(prompt, new InferenceParameters(apiKey, options::get)).get(0); + System.out.print(completion.text()); } @Test @Disabled public void testAsyncClient() { - var client = new OpenAiClient.Builder(apiKey).build(); - String input = "You are an unhelpful assistant who never answers questions straightforwardly. " + - "Be as long-winded as possible. Are humans smarter than cats?\n\n"; - StringPrompt prompt = StringPrompt.from(input); + var client = new OpenAiClient(); + var options = Map.of("maxTokens", "10"); + var prompt = StringPrompt.from("You are an unhelpful assistant who never answers questions straightforwardly. " + + "Be as long-winded as possible. Are humans smarter than cats?"); System.out.print(prompt); - var future = client.completeAsync(prompt, completion -> { + var future = client.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> { System.out.print(completion.text()); }); - System.out.println("Waiting for completion..."); + System.out.println("\nWaiting for completion...\n\n"); System.out.println("\nFinished streaming because of " + future.join()); } diff --git a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java index 7407eb526e7..24c496a3d2c 100644 --- a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java +++ b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.completion; +import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.test.MockLanguageModel; import org.junit.jupiter.api.Test; @@ -27,8 +28,9 @@ public class CompletionTest { String input = "Complete this: "; StringPrompt prompt = StringPrompt.from(input); + InferenceParameters options = new InferenceParameters(s -> ""); for (int i = 0; i < 10; i++) { - var completion = llm.complete(prompt).get(0); + var completion = llm.complete(prompt, options).get(0); prompt = prompt.append(completion); if (completion.finishReason() == Completion.FinishReason.stop) break; } |