diff options
50 files changed, 976 insertions, 236 deletions
diff --git a/client/go/internal/admin/vespa-wrapper/standalone/start.go b/client/go/internal/admin/vespa-wrapper/standalone/start.go index a3703ce930c..16e76562b99 100644 --- a/client/go/internal/admin/vespa-wrapper/standalone/start.go +++ b/client/go/internal/admin/vespa-wrapper/standalone/start.go @@ -41,6 +41,7 @@ func StartStandaloneContainer(extraArgs []string) int { c := jvm.NewStandaloneContainer(serviceName) jvmOpts := c.JvmOptions() jvmOpts.AddOption("-DOnnxBundleActivator.skip=true") + jvmOpts.AddOption("-DLlamaBundleActivator.skip=true") for _, extra := range extraArgs { jvmOpts.AddOption(extra) } diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml index eff4e4125e9..98bef7df402 100644 --- a/cloud-tenant-base-dependencies-enforcer/pom.xml +++ b/cloud-tenant-base-dependencies-enforcer/pom.xml @@ -141,6 +141,7 @@ <include>com.microsoft.onnxruntime:onnxruntime:jar:${onnxruntime.vespa.version}:test</include> <include>com.thaiopensource:jing:20091111:test</include> <include>commons-codec:commons-codec:${commons-codec.vespa.version}:test</include> + <include>de.kherud:llama:${kherud.llama.vespa.version}:test</include> <include>io.airlift:aircompressor:${aircompressor.vespa.version}:test</include> <include>io.airlift:airline:${airline.vespa.version}:test</include> <include>io.prometheus:simpleclient:${prometheus.client.vespa.version}:test</include> diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json index c416a5e3a0b..42e7e23dfcc 100644 --- a/config-model-api/abi-spec.json +++ b/config-model-api/abi-spec.json @@ -1862,4 +1862,4 @@ "public final java.lang.String serviceName" ] } -} +}
\ No newline at end of file diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java index 62979404025..5be1690f0dc 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java @@ -31,6 +31,7 @@ public class ContainerModelEvaluation implements public final static String EVALUATION_BUNDLE_NAME = "model-evaluation"; public final static String INTEGRATION_BUNDLE_NAME = "model-integration"; public final static String ONNXRUNTIME_BUNDLE_NAME = "container-onnxruntime.jar"; + public final static String LLAMA_BUNDLE_NAME = "container-llama.jar"; public final static String ONNX_RUNTIME_CLASS = "ai.vespa.modelintegration.evaluator.OnnxRuntime"; private final static String EVALUATOR_NAME = ModelsEvaluator.class.getName(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java index 9f91f6bf5e1..468cf8dd961 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java @@ -14,6 +14,7 @@ import java.util.stream.Stream; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.EVALUATION_BUNDLE_NAME; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME; +import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LLAMA_BUNDLE_NAME; import static com.yahoo.vespa.model.container.ContainerModelEvaluation.ONNXRUNTIME_BUNDLE_NAME; /** @@ -37,6 +38,7 @@ public class PlatformBundles { public static final Path LIBRARY_PATH = Paths.get(Defaults.getDefaults().underVespaHome("lib/jars")); public static final String SEARCH_AND_DOCPROC_BUNDLE = BundleInstantiationSpecification.CONTAINER_SEARCH_AND_DOCPROC; + public static final String MODEL_INTEGRATION_BUNDLE = BundleInstantiationSpecification.MODEL_INTEGRATION; // Bundles that must be loaded for all container types. public static final Set<Path> COMMON_VESPA_BUNDLES = toBundlePaths( @@ -63,7 +65,8 @@ public class PlatformBundles { "lucene-linguistics", EVALUATION_BUNDLE_NAME, INTEGRATION_BUNDLE_NAME, - ONNXRUNTIME_BUNDLE_NAME + ONNXRUNTIME_BUNDLE_NAME, + LLAMA_BUNDLE_NAME ); private static Set<Path> toBundlePaths(String... bundleNames) { @@ -86,6 +89,10 @@ public class PlatformBundles { return searchAndDocprocComponents.contains(className); } + public static boolean isModelIntegrationClass(String className) { + return modelIntegrationComponents.contains(className); + } + // This is a hack to allow users to declare components from the search-and-docproc bundle without naming the bundle. private static final Set<String> searchAndDocprocComponents = Set.of( com.yahoo.docproc.AbstractConcreteDocumentFactory.class.getName(), @@ -147,8 +154,13 @@ public class PlatformBundles { com.yahoo.vespa.streamingvisitors.MetricsSearcher.class.getName(), com.yahoo.vespa.streamingvisitors.StreamingBackend.class.getName(), ai.vespa.search.llm.LLMSearcher.class.getName(), - ai.vespa.search.llm.RAGSearcher.class.getName(), - ai.vespa.llm.clients.OpenAI.class.getName() + ai.vespa.search.llm.RAGSearcher.class.getName() + ); + + // This is a hack to allow users to declare components from the model-integration bundle without naming the bundle. + private static final Set<String> modelIntegrationComponents = Set.of( + ai.vespa.llm.clients.OpenAI.class.getName(), + ai.vespa.llm.clients.LocalLLM.class.getName() ); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/BundleInstantiationSpecificationBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/BundleInstantiationSpecificationBuilder.java index 7e14eafc2ee..1323506eaeb 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/BundleInstantiationSpecificationBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/BundleInstantiationSpecificationBuilder.java @@ -26,17 +26,18 @@ public class BundleInstantiationSpecificationBuilder { BundleInstantiationSpecification instSpec = new BundleInstantiationSpecification(id, classId, bundle); validate(instSpec); - return bundle == null ? setBundleForSearchAndDocprocComponents(instSpec) : instSpec; + return bundle == null ? setBundleForComponent(instSpec) : instSpec; } - private static BundleInstantiationSpecification setBundleForSearchAndDocprocComponents(BundleInstantiationSpecification spec) { + private static BundleInstantiationSpecification setBundleForComponent(BundleInstantiationSpecification spec) { if (PlatformBundles.isSearchAndDocprocClass(spec.getClassName())) return spec.inBundle(PlatformBundles.SEARCH_AND_DOCPROC_BUNDLE); + else if (PlatformBundles.isModelIntegrationClass(spec.getClassName())) + return spec.inBundle(PlatformBundles.MODEL_INTEGRATION_BUNDLE); else return spec; } - private static void validate(BundleInstantiationSpecification instSpec) { List<String> forbiddenClasses = List.of(SearchHandler.HANDLER_CLASSNAME, PROCESSING_HANDLER_CLASS); @@ -47,7 +48,7 @@ public class BundleInstantiationSpecificationBuilder { } } - //null if missing + // null if missing private static ComponentSpecification getComponentSpecification(Element spec, String attributeName) { return (spec.hasAttribute(attributeName)) ? new ComponentSpecification(spec.getAttribute(attributeName)) : diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java index f9993b770e5..867ac86f8d5 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java @@ -26,6 +26,7 @@ public class ModelIdResolver { public static final String ONNX_MODEL = "onnx-model"; public static final String BERT_VOCAB = "bert-vocabulary"; public static final String SIGNIFICANCE_MODEL = "significance-model"; + public static final String GGUF_MODEL = "gguf-model"; private static Map<String, ProvidedModel> setupProvidedModels() { var m = new HashMap<String, ProvidedModel>(); @@ -60,6 +61,9 @@ public class ModelIdResolver { register(m, "e5-large-v2", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/model.onnx", Set.of(ONNX_MODEL)); register(m, "e5-large-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/tokenizer.json", Set.of(HF_TOKENIZER)); + + register(m, "mistral-7b", "https://data.vespa.oath.cloud/gguf_models/mistral-7b-instruct-v0.1.Q6_K.gguf", Set.of(GGUF_MODEL)); + register(m, "mistral-7b-q8", "https://data.vespa.oath.cloud/gguf_models/mistral-7b-instruct-v0.1.Q8_0.gguf", Set.of(GGUF_MODEL)); return Map.copyOf(m); } @@ -124,7 +128,7 @@ public class ModelIdResolver { throw new IllegalArgumentException("Unknown model id '" + modelId + "' on '" + valueName + "'. Available models are [" + providedModels.keySet().stream().sorted().collect(Collectors.joining(", ")) + "]"); var providedModel = providedModels.get(modelId); - if (!providedModel.tags().containsAll(requiredTags)) { + if ( ! providedModel.tags().containsAll(requiredTags)) { throw new IllegalArgumentException( "Model '%s' on '%s' has tags %s but are missing required tags %s" .formatted(modelId, valueName, providedModel.tags(), requiredTags)); diff --git a/container-core/src/main/java/com/yahoo/container/bundle/BundleInstantiationSpecification.java b/container-core/src/main/java/com/yahoo/container/bundle/BundleInstantiationSpecification.java index bd35d257813..b49f519906f 100644 --- a/container-core/src/main/java/com/yahoo/container/bundle/BundleInstantiationSpecification.java +++ b/container-core/src/main/java/com/yahoo/container/bundle/BundleInstantiationSpecification.java @@ -15,6 +15,7 @@ import com.yahoo.component.ComponentSpecification; public final class BundleInstantiationSpecification { public static final String CONTAINER_SEARCH_AND_DOCPROC = "container-search-and-docproc"; + public static final String MODEL_INTEGRATION = "model-integration"; public final ComponentId id; public final ComponentSpecification classId; diff --git a/container-dependencies-enforcer/pom.xml b/container-dependencies-enforcer/pom.xml index a06365abbeb..f67f33a3b05 100644 --- a/container-dependencies-enforcer/pom.xml +++ b/container-dependencies-enforcer/pom.xml @@ -154,6 +154,7 @@ <include>com.microsoft.onnxruntime:onnxruntime:${onnxruntime.vespa.version}:test</include> <include>com.thaiopensource:jing:20091111:test</include> <include>commons-codec:commons-codec:${commons-codec.vespa.version}:test</include> + <include>de.kherud:llama:${kherud.llama.vespa.version}:test</include> <include>io.airlift:aircompressor:${aircompressor.vespa.version}:test</include> <include>io.airlift:airline:${airline.vespa.version}:test</include> <include>io.prometheus:simpleclient:${prometheus.client.vespa.version}:test</include> diff --git a/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java b/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java index 11ba05e363d..846a2008858 100644 --- a/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java +++ b/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java @@ -13,12 +13,19 @@ import java.util.logging.Logger; **/ public class LlamaBundleActivator implements BundleActivator { + private static final String SKIP_SUFFIX = ".skip"; + private static final String SKIP_VALUE = "true"; private static final String PATH_PROPNAME = "de.kherud.llama.lib.path"; private static final Logger log = Logger.getLogger(LlamaBundleActivator.class.getName()); @Override public void start(BundleContext ctx) { log.fine("start bundle"); + String skipAll = LlamaBundleActivator.class.getSimpleName() + SKIP_SUFFIX; + if (SKIP_VALUE.equals(System.getProperty(skipAll))) { + log.info("skip loading of native libraries"); + return; + } if (checkFilenames( "/dev/nvidia0", "/opt/vespa-deps/lib64/cuda/libllama.so", diff --git a/container-llama/src/main/java/de/kherud/llama/package-info.java b/container-llama/src/main/java/de/kherud/llama/package-info.java new file mode 100644 index 00000000000..3c9773762b4 --- /dev/null +++ b/container-llama/src/main/java/de/kherud/llama/package-info.java @@ -0,0 +1,8 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +/** + * @author lesters + */ +@ExportPackage +package de.kherud.llama; + +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search-and-docproc/pom.xml b/container-search-and-docproc/pom.xml index 9554b517586..e2afa0e91f4 100644 --- a/container-search-and-docproc/pom.xml +++ b/container-search-and-docproc/pom.xml @@ -210,6 +210,18 @@ <version>${project.version}</version> <scope>provided</scope> </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>model-integration</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>container-llama</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> <!-- TEST scope --> <dependency> diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index e74fe22c588..07f0449e61a 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -7842,6 +7842,21 @@ "public static final int emptyDocsumsCode" ] }, + "com.yahoo.search.result.EventStream$ErrorEvent" : { + "superClass" : "com.yahoo.search.result.EventStream$Event", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(int, java.lang.String, com.yahoo.search.result.ErrorMessage)", + "public java.lang.String source()", + "public int code()", + "public java.lang.String message()", + "public com.yahoo.search.result.Hit asHit()" + ], + "fields" : [ ] + }, "com.yahoo.search.result.EventStream$Event" : { "superClass" : "com.yahoo.component.provider.ListenableFreezableClass", "interfaces" : [ @@ -9149,99 +9164,6 @@ ], "fields" : [ ] }, - "ai.vespa.llm.clients.ConfigurableLanguageModel" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "ai.vespa.llm.LanguageModel" - ], - "attributes" : [ - "public", - "abstract" - ], - "methods" : [ - "public void <init>()", - "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", - "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)", - "protected void setApiKey(ai.vespa.llm.InferenceParameters)", - "protected java.lang.String getEndpoint()", - "protected void setEndpoint(ai.vespa.llm.InferenceParameters)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmClientConfig$Builder" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "com.yahoo.config.ConfigInstance$Builder" - ], - "attributes" : [ - "public", - "final" - ], - "methods" : [ - "public void <init>()", - "public void <init>(ai.vespa.llm.clients.LlmClientConfig)", - "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)", - "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)", - "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", - "public final java.lang.String getDefMd5()", - "public final java.lang.String getDefName()", - "public final java.lang.String getDefNamespace()", - "public final boolean getApplyOnRestart()", - "public final void setApplyOnRestart(boolean)", - "public ai.vespa.llm.clients.LlmClientConfig build()" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmClientConfig$Producer" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "com.yahoo.config.ConfigInstance$Producer" - ], - "attributes" : [ - "public", - "interface", - "abstract" - ], - "methods" : [ - "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.clients.LlmClientConfig" : { - "superClass" : "com.yahoo.config.ConfigInstance", - "interfaces" : [ ], - "attributes" : [ - "public", - "final" - ], - "methods" : [ - "public static java.lang.String getDefMd5()", - "public static java.lang.String getDefName()", - "public static java.lang.String getDefNamespace()", - "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)", - "public java.lang.String apiKeySecretName()", - "public java.lang.String endpoint()" - ], - "fields" : [ - "public static final java.lang.String CONFIG_DEF_MD5", - "public static final java.lang.String CONFIG_DEF_NAME", - "public static final java.lang.String CONFIG_DEF_NAMESPACE", - "public static final java.lang.String[] CONFIG_DEF_SCHEMA" - ] - }, - "ai.vespa.llm.clients.OpenAI" : { - "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel", - "interfaces" : [ ], - "attributes" : [ - "public" - ], - "methods" : [ - "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", - "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" - ], - "fields" : [ ] - }, "ai.vespa.search.llm.LLMSearcher" : { "superClass" : "com.yahoo.search.Searcher", "interfaces" : [ ], diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java index 860fc69af91..f565315b775 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java @@ -20,6 +20,7 @@ import com.yahoo.search.result.HitGroup; import com.yahoo.search.searchchain.Execution; import java.util.List; +import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -83,27 +84,41 @@ public class LLMSearcher extends Searcher { protected Result complete(Query query, Prompt prompt) { var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config - return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); + try { + return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); + } catch (RejectedExecutionException e) { + return new Result(query, new ErrorMessage(429, e.getMessage())); + } + } + + private boolean shouldAddPrompt(Query query) { + return query.getTrace().getLevel() >= 1; + } + + private boolean shouldAddTokenStats(Query query) { + return query.getTrace().getLevel() >= 1; } private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { - EventStream eventStream = new EventStream(); + final EventStream eventStream = new EventStream(); - if (query.getTrace().getLevel() >= 1) { + if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } - languageModel.completeAsync(prompt, options, token -> { - eventStream.add(token.text()); + final TokenStats tokenStats = new TokenStats(); + languageModel.completeAsync(prompt, options, completion -> { + tokenStats.onToken(); + handleCompletion(eventStream, completion); }).exceptionally(exception -> { - int errorCode = 400; - if (exception instanceof LanguageModelException languageModelException) { - errorCode = languageModelException.code(); - } - eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); + handleException(eventStream, exception); eventStream.markComplete(); return Completion.FinishReason.error; }).thenAccept(finishReason -> { + tokenStats.onCompletion(); + if (shouldAddTokenStats(query)) { + eventStream.add(tokenStats.report(), "stats"); + } eventStream.markComplete(); }); @@ -112,10 +127,26 @@ public class LLMSearcher extends Searcher { return new Result(query, hitGroup); } + private void handleCompletion(EventStream eventStream, Completion completion) { + if (completion.finishReason() == Completion.FinishReason.error) { + eventStream.add(completion.text(), "error"); + } else { + eventStream.add(completion.text()); + } + } + + private void handleException(EventStream eventStream, Throwable exception) { + int errorCode = 400; + if (exception instanceof LanguageModelException languageModelException) { + errorCode = languageModelException.code(); + } + eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); + } + private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { EventStream eventStream = new EventStream(); - if (query.getTrace().getLevel() >= 1) { + if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } @@ -169,4 +200,35 @@ public class LLMSearcher extends Searcher { return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); } + private static class TokenStats { + + private long start; + private long timeToFirstToken; + private long timeToLastToken; + private long tokens = 0; + + TokenStats() { + start = System.currentTimeMillis(); + } + + void onToken() { + if (tokens == 0) { + timeToFirstToken = System.currentTimeMillis() - start; + } + tokens++; + } + + void onCompletion() { + timeToLastToken = System.currentTimeMillis() - start; + } + + String report() { + return "Time to first token: " + timeToFirstToken + " ms, " + + "Generation time: " + timeToLastToken + " ms, " + + "Generated tokens: " + tokens + " " + + String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0)); + } + + } + } diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java index 83ae349f5a0..88a1e6c1485 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -64,7 +64,17 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { @Override public void data(Data data) throws IOException { - if (data instanceof EventStream.Event event) { + if (data instanceof EventStream.ErrorEvent error) { + generator.writeRaw("event: error\n"); + generator.writeRaw("data: "); + generator.writeStartObject(); + generator.writeStringField("source", error.source()); + generator.writeNumberField("error", error.code()); + generator.writeStringField("message", error.message()); + generator.writeEndObject(); + generator.writeRaw("\n\n"); + generator.flush(); + } else if (data instanceof EventStream.Event event) { if (RENDER_EVENT_HEADER) { generator.writeRaw("event: " + event.type() + "\n"); } @@ -75,19 +85,6 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { generator.writeRaw("\n\n"); generator.flush(); } - else if (data instanceof ErrorHit) { - for (ErrorMessage error : ((ErrorHit) data).errors()) { - generator.writeRaw("event: error\n"); - generator.writeRaw("data: "); - generator.writeStartObject(); - generator.writeStringField("source", error.getSource()); - generator.writeNumberField("error", error.getCode()); - generator.writeStringField("message", error.getMessage()); - generator.writeEndObject(); - generator.writeRaw("\n\n"); - generator.flush(); - } - } // Todo: support other types of data such as search results (hits), timing and trace } diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java index b393a91e6d0..8e6f7977d55 100644 --- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java +++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java @@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList<Data> { } public void error(String source, ErrorMessage message) { - incoming().add(new DefaultErrorHit(source, message)); + incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message)); } public void markComplete() { @@ -117,4 +117,38 @@ public class EventStream extends Hit implements DataList<Data> { } + public static class ErrorEvent extends Event { + + private final String source; + private final ErrorMessage message; + + public ErrorEvent(int eventNumber, String source, ErrorMessage message) { + super(eventNumber, message.getMessage(), "error"); + this.source = source; + this.message = message; + } + + public String source() { + return source; + } + + public int code() { + return message.getCode(); + } + + public String message() { + return message.getMessage(); + } + + @Override + public Hit asHit() { + Hit hit = super.asHit(); + hit.setField("source", source); + hit.setField("code", message.getCode()); + return hit; + } + + + } + } diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java index 1efcf1c736a..3baa9715c34 100755 --- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java +++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java @@ -3,14 +3,11 @@ package ai.vespa.search.llm; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.clients.ConfigurableLanguageModelTest; -import ai.vespa.llm.clients.LlmClientConfig; -import ai.vespa.llm.clients.MockLLMClient; +import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; import com.yahoo.component.ComponentId; import com.yahoo.component.chain.Chain; import com.yahoo.component.provider.ComponentRegistry; -import com.yahoo.container.jdisc.SecretStoreProvider; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; @@ -20,10 +17,14 @@ import org.junit.jupiter.api.Test; import java.net.URLEncoder; import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.BiFunction; +import java.util.function.Consumer; import java.util.stream.Collectors; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; @@ -36,10 +37,10 @@ public class LLMSearcherTest { @Test public void testLLMSelection() { - var llm1 = createLLMClient("mock1"); - var llm2 = createLLMClient("mock2"); + var client1 = createLLMClient("mock1"); + var client2 = createLLMClient("mock2"); var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build(); - var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2)); + var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2)); var result = runMockSearch(searcher, Map.of("prompt", "what is your id?")); assertEquals(1, result.getHitCount()); assertEquals("My id is mock2", getCompletion(result)); @@ -47,14 +48,16 @@ public class LLMSearcherTest { @Test public void testGeneration() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); var params = Map.of("prompt", "why are ducks better than cats"); assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params))); } @Test public void testPrompting() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); // Prompt with prefix assertEquals("Ducks have adorable waddling walks.", @@ -71,7 +74,8 @@ public class LLMSearcherTest { @Test public void testPromptEvent() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); var params = Map.of( "prompt", "why are ducks better than cats", "traceLevel", "1"); @@ -90,7 +94,8 @@ public class LLMSearcherTest { @Test public void testParameters() { - var searcher = createLLMSearcher(Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(client); var params = Map.of( "llm.prompt", "why are ducks better than cats", "llm.temperature", "1.0", @@ -107,16 +112,18 @@ public class LLMSearcherTest { "foo.maxTokens", "5" ); var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build(); - var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient())); + var client = createLLMClient(); + var searcher = createLLMSearcher(config, client); assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params))); } @Test public void testApiKeyFromHeader() { var properties = Map.of("prompt", "why are ducks better than cats"); - var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore())); - assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm")); - assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm")); + var client = createLLMClient(createApiKeyGenerator("a_valid_key")); + var searcher = createLLMSearcher(client); + assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key")); + assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key")); } @Test @@ -129,7 +136,8 @@ public class LLMSearcherTest { "llm.stream", "true", // ... but inference parameters says do it anyway "llm.prompt", "why are ducks better than cats?" ); - var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor))); + var client = createLLMClient(executor); + var searcher = createLLMSearcher(config, client); Result result = runMockSearch(searcher, params); assertEquals(1, result.getHitCount()); @@ -162,6 +170,10 @@ public class LLMSearcherTest { return runMockSearch(searcher, parameters, null, ""); } + static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey) { + return runMockSearch(searcher, parameters, apiKey, "llm"); + } + static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) { Chain<Searcher> chain = new Chain<>(searcher); Execution execution = new Execution(chain, Execution.Context.createContextStub()); @@ -191,43 +203,59 @@ public class LLMSearcherTest { } private static BiFunction<Prompt, InferenceParameters, String> createGenerator() { - return ConfigurableLanguageModelTest.createGenerator(); + return (prompt, options) -> { + String answer = "I have no opinion on the matter"; + if (prompt.asString().contains("ducks")) { + answer = "Ducks have adorable waddling walks."; + var temperature = options.getDouble("temperature"); + if (temperature.isPresent() && temperature.get() > 0.5) { + answer = "Random text about ducks vs cats that makes no sense whatsoever."; + } + } + var maxTokens = options.getInt("maxTokens"); + if (maxTokens.isPresent()) { + return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" ")); + } + return answer; + }; } - static MockLLMClient createLLMClient() { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore, generator, null); + private static BiFunction<Prompt, InferenceParameters, String> createApiKeyGenerator(String validApiKey) { + return (prompt, options) -> { + if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) { + throw new IllegalArgumentException("Invalid API key"); + } + return "Ok"; + }; + } + + static MockLLM createLLMClient() { + return new MockLLM(createGenerator(), null); } - static MockLLMClient createLLMClient(String id) { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createIdGenerator(id); - return new MockLLMClient(config, secretStore, generator, null); + static MockLLM createLLMClient(String id) { + return new MockLLM(createIdGenerator(id), null); } - static MockLLMClient createLLMClient(ExecutorService executor) { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY)); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore, generator, executor); + static MockLLM createLLMClient(BiFunction<Prompt, InferenceParameters, String> generator) { + return new MockLLM(generator, null); } - static MockLLMClient createLLMClientWithoutSecretStore() { - var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build(); - var secretStore = new SecretStoreProvider(); - var generator = createGenerator(); - return new MockLLMClient(config, secretStore.get(), generator, null); + static MockLLM createLLMClient(ExecutorService executor) { + return new MockLLM(createGenerator(), executor); + } + + private static Searcher createLLMSearcher(LanguageModel llm) { + return createLLMSearcher(Map.of("mock", llm)); } private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) { var config = new LlmSearcherConfig.Builder().stream(false).build(); - ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); - llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); - models.freeze(); - return new LLMSearcher(config, models); + return createLLMSearcher(config, llms); + } + + private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) { + return createLLMSearcher(config, Map.of("mock", llm)); } private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) { @@ -237,4 +265,44 @@ public class LLMSearcherTest { return new LLMSearcher(config, models); } + private static class MockLLM implements LanguageModel { + + private final ExecutorService executor; + private final BiFunction<Prompt, InferenceParameters, String> generator; + + public MockLLM(BiFunction<Prompt, InferenceParameters, String> generator, ExecutorService executor) { + this.executor = executor; + this.generator = generator; + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters params) { + return List.of(Completion.from(this.generator.apply(prompt, params))); + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters params, + Consumer<Completion> consumer) { + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization + + long sleep = 1; + executor.submit(() -> { + try { + for (int i = 0; i < completions.length; ++i) { + String completion = (i > 0 ? " " : "") + completions[i]; + consumer.accept(Completion.from(completion, Completion.FinishReason.none)); + Thread.sleep(sleep); + } + completionFuture.complete(Completion.FinishReason.stop); + } catch (InterruptedException e) { + // Do nothing + } + }); + return completionFuture; + } + + } + } diff --git a/container-test/pom.xml b/container-test/pom.xml index 8e1b4870665..d6be6946208 100644 --- a/container-test/pom.xml +++ b/container-test/pom.xml @@ -61,6 +61,16 @@ <artifactId>onnxruntime</artifactId> </dependency> <dependency> + <groupId>de.kherud</groupId> + <artifactId>llama</artifactId> + <exclusions> + <exclusion> + <groupId>org.jetbrains</groupId> + <artifactId>annotations</artifactId> + </exclusion> + </exclusions> + </dependency> + <dependency> <groupId>io.airlift</groupId> <artifactId>airline</artifactId> <exclusions> diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/DimensionId.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/DimensionId.java index a6b09ddefd8..43cc8fda3c9 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/DimensionId.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/DimensionId.java @@ -2,6 +2,7 @@ package ai.vespa.metricsproxy.metric.model; import com.yahoo.concurrent.CopyOnWriteHashMap; +import io.prometheus.client.Collector; import java.util.Map; import java.util.Objects; @@ -13,12 +14,18 @@ public final class DimensionId { private static final Map<String, DimensionId> dictionary = new CopyOnWriteHashMap<>(); public final String id; - private DimensionId(String id) { this.id = id; } + private final String idForPrometheus; + private DimensionId(String id) { + this.id = id; + idForPrometheus = Collector.sanitizeMetricName(id); + } public static DimensionId toDimensionId(String id) { return dictionary.computeIfAbsent(id, key -> new DimensionId(key)); } + public String getIdForPrometheus() { return idForPrometheus; } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/MetricId.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/MetricId.java index 9014e818eab..829eb06101f 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/MetricId.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/MetricId.java @@ -2,6 +2,7 @@ package ai.vespa.metricsproxy.metric.model; import com.yahoo.concurrent.CopyOnWriteHashMap; +import io.prometheus.client.Collector; import java.util.Map; import java.util.Objects; @@ -14,11 +15,16 @@ public class MetricId { private static final Map<String, MetricId> dictionary = new CopyOnWriteHashMap<>(); public static final MetricId empty = toMetricId(""); public final String id; - private MetricId(String id) { this.id = id; } + private final String idForPrometheus; + private MetricId(String id) { + this.id = id; + idForPrometheus = Collector.sanitizeMetricName(id); + } public static MetricId toMetricId(String id) { - return dictionary.computeIfAbsent(id, key -> new MetricId(key)); + return dictionary.computeIfAbsent(id, MetricId::new); } + public String getIdForPrometheus() { return idForPrometheus; } @Override public boolean equals(Object o) { @@ -29,13 +35,9 @@ public class MetricId { } @Override - public int hashCode() { - return Objects.hash(id); - } + public int hashCode() { return Objects.hash(id); } @Override - public String toString() { - return id; - } + public String toString() { return id; } } diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/ServiceId.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/ServiceId.java index 96ee2fa00e2..28c64b012c1 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/ServiceId.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/ServiceId.java @@ -1,6 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.metricsproxy.metric.model; +import io.prometheus.client.Collector; + import java.util.Objects; /** @@ -9,10 +11,16 @@ import java.util.Objects; public class ServiceId { public final String id; - private ServiceId(String id) { this.id = id; } + private final String idForPrometheus; + private ServiceId(String id) { + this.id = id; + idForPrometheus = Collector.sanitizeMetricName(id); + } public static ServiceId toServiceId(String id) { return new ServiceId(id); } + public String getIdForPrometheus() { return idForPrometheus; } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/prometheus/PrometheusUtil.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/prometheus/PrometheusUtil.java index d7436ccf404..2b0db5381bc 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/prometheus/PrometheusUtil.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/prometheus/PrometheusUtil.java @@ -28,13 +28,14 @@ public class PrometheusUtil { Map<String, List<Sample>> samples = new HashMap<>(); packetsByService.forEach(((serviceId, packets) -> { - var serviceName = Collector.sanitizeMetricName(serviceId.id); + var serviceName = serviceId.getIdForPrometheus(); for (var packet : packets) { + Long timeStamp = packet.timestamp * 1000; var dimensions = packet.dimensions(); List<String> labels = new ArrayList<>(dimensions.size()); List<String> labelValues = new ArrayList<>(dimensions.size()); for (var entry : dimensions.entrySet()) { - var labelName = Collector.sanitizeMetricName(entry.getKey().id); + var labelName = entry.getKey().getIdForPrometheus(); labels.add(labelName); labelValues.add(entry.getValue()); } @@ -42,7 +43,7 @@ public class PrometheusUtil { labelValues.add(serviceName); for (var metric : packet.metrics().entrySet()) { - var metricName = Collector.sanitizeMetricName(metric.getKey().id); + var metricName = metric.getKey().getIdForPrometheus(); List<Sample> sampleList; if (samples.containsKey(metricName)) { sampleList = samples.get(metricName); @@ -51,7 +52,7 @@ public class PrometheusUtil { samples.put(metricName, sampleList); metricFamilySamples.add(new MetricFamilySamples(metricName, Collector.Type.UNKNOWN, "", sampleList)); } - sampleList.add(new Sample(metricName, labels, labelValues, metric.getValue().doubleValue(), packet.timestamp * 1000)); + sampleList.add(new Sample(metricName, labels, labelValues, metric.getValue().doubleValue(), timeStamp)); } } if (!packets.isEmpty()) { diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java index 6c3b759e97b..0e33d7dbf2f 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java @@ -28,6 +28,7 @@ import static ai.vespa.metricsproxy.metric.model.DimensionId.toDimensionId; * @author Jo Kristian Bergum */ public class MetricsParser { + private static final Double ZERO_DOUBLE = 0d; public interface Collector { void accept(Metric metric); } @@ -186,7 +187,8 @@ public class MetricsParser { if (token == JsonToken.VALUE_NUMBER_INT) { metrics.add(Map.entry(metricName, parser.getLongValue())); } else if (token == JsonToken.VALUE_NUMBER_FLOAT) { - metrics.add(Map.entry(metricName, parser.getValueAsDouble())); + double value = parser.getValueAsDouble(); + metrics.add(Map.entry(metricName, value == ZERO_DOUBLE ? ZERO_DOUBLE : value)); } else { throw new IllegalArgumentException("Value for aggregator '" + fieldName + "' is not a number"); } diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json index d3c472778e6..e7130d9c777 100644 --- a/model-integration/abi-spec.json +++ b/model-integration/abi-spec.json @@ -1,4 +1,186 @@ { + "ai.vespa.llm.clients.ConfigurableLanguageModel" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "ai.vespa.llm.LanguageModel" + ], + "attributes" : [ + "public", + "abstract" + ], + "methods" : [ + "public void <init>()", + "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", + "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)", + "protected void setApiKey(ai.vespa.llm.InferenceParameters)", + "protected java.lang.String getEndpoint()", + "protected void setEndpoint(ai.vespa.llm.InferenceParameters)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig$Builder" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Builder" + ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void <init>()", + "public void <init>(ai.vespa.llm.clients.LlmClientConfig)", + "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)", + "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)", + "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", + "public final java.lang.String getDefMd5()", + "public final java.lang.String getDefName()", + "public final java.lang.String getDefNamespace()", + "public final boolean getApplyOnRestart()", + "public final void setApplyOnRestart(boolean)", + "public ai.vespa.llm.clients.LlmClientConfig build()" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig$Producer" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Producer" + ], + "attributes" : [ + "public", + "interface", + "abstract" + ], + "methods" : [ + "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmClientConfig" : { + "superClass" : "com.yahoo.config.ConfigInstance", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public static java.lang.String getDefMd5()", + "public static java.lang.String getDefName()", + "public static java.lang.String getDefNamespace()", + "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)", + "public java.lang.String apiKeySecretName()", + "public java.lang.String endpoint()" + ], + "fields" : [ + "public static final java.lang.String CONFIG_DEF_MD5", + "public static final java.lang.String CONFIG_DEF_NAME", + "public static final java.lang.String CONFIG_DEF_NAMESPACE", + "public static final java.lang.String[] CONFIG_DEF_SCHEMA" + ] + }, + "ai.vespa.llm.clients.LlmLocalClientConfig$Builder" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Builder" + ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void <init>()", + "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder model(com.yahoo.config.ModelReference)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder parallelRequests(int)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxQueueSize(int)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder useGpu(boolean)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder gpuLayers(int)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder threads(int)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder contextSize(int)", + "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxTokens(int)", + "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", + "public final java.lang.String getDefMd5()", + "public final java.lang.String getDefName()", + "public final java.lang.String getDefNamespace()", + "public final boolean getApplyOnRestart()", + "public final void setApplyOnRestart(boolean)", + "public ai.vespa.llm.clients.LlmLocalClientConfig build()" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmLocalClientConfig$Producer" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigInstance$Producer" + ], + "attributes" : [ + "public", + "interface", + "abstract" + ], + "methods" : [ + "public abstract void getConfig(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.LlmLocalClientConfig" : { + "superClass" : "com.yahoo.config.ConfigInstance", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public static java.lang.String getDefMd5()", + "public static java.lang.String getDefName()", + "public static java.lang.String getDefNamespace()", + "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)", + "public java.nio.file.Path model()", + "public int parallelRequests()", + "public int maxQueueSize()", + "public boolean useGpu()", + "public int gpuLayers()", + "public int threads()", + "public int contextSize()", + "public int maxTokens()" + ], + "fields" : [ + "public static final java.lang.String CONFIG_DEF_MD5", + "public static final java.lang.String CONFIG_DEF_NAME", + "public static final java.lang.String CONFIG_DEF_NAMESPACE", + "public static final java.lang.String[] CONFIG_DEF_SCHEMA" + ] + }, + "ai.vespa.llm.clients.LocalLLM" : { + "superClass" : "com.yahoo.component.AbstractComponent", + "interfaces" : [ + "ai.vespa.llm.LanguageModel" + ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig)", + "public void deconstruct()", + "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + ], + "fields" : [ ] + }, + "ai.vespa.llm.clients.OpenAI" : { + "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)", + "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + ], + "fields" : [ ] + }, "ai.vespa.llm.generation.Generator" : { "superClass" : "com.yahoo.component.AbstractComponent", "interfaces" : [ ], diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 0bab30e1453..d92fa319251 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -40,6 +40,12 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>container-disc</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>searchcore</artifactId> <version>${project.version}</version> <scope>provided</scope> @@ -76,6 +82,12 @@ </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> + <artifactId>container-llama</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> <artifactId>component</artifactId> <version>${project.version}</version> <scope>provided</scope> diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java index 761fdf0af93..761fdf0af93 100644 --- a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java new file mode 100644 index 00000000000..fd1b8b700c8 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java @@ -0,0 +1,126 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import de.kherud.llama.LlamaModel; +import de.kherud.llama.ModelParameters; +import de.kherud.llama.args.LogFormat; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; +import java.util.logging.Logger; + +/** + * A language model running locally on the container node. + * + * @author lesters + */ +public class LocalLLM extends AbstractComponent implements LanguageModel { + + private final static Logger logger = Logger.getLogger(LocalLLM.class.getName()); + private final LlamaModel model; + private final ThreadPoolExecutor executor; + private final int contextSize; + private final int maxTokens; + + @Inject + public LocalLLM(LlmLocalClientConfig config) { + executor = createExecutor(config); + + // Maximum number of tokens to generate - need this since some models can just generate infinitely + maxTokens = config.maxTokens(); + + // Only used if GPU is not used + var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2; + + var modelFile = config.model().toFile().getAbsolutePath(); + var modelParams = new ModelParameters() + .setModelFilePath(modelFile) + .setContinuousBatching(true) + .setNParallel(config.parallelRequests()) + .setNThreads(config.threads() <= 0 ? defaultThreadCount : config.threads()) + .setNCtx(config.contextSize()) + .setNGpuLayers(config.useGpu() ? config.gpuLayers() : 0); + + long startLoad = System.nanoTime(); + model = new LlamaModel(modelParams); + long loadTime = System.nanoTime() - startLoad; + logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000))); + + // Todo: handle prompt context size - such as give a warning when prompt exceeds context size + contextSize = config.contextSize(); + } + + private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) { + return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(), + 0L, TimeUnit.MILLISECONDS, + config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(), + new ThreadPoolExecutor.AbortPolicy()); + } + + @Override + public void deconstruct() { + logger.info("Closing LLM model..."); + model.close(); + executor.shutdownNow(); + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters options) { + StringBuilder result = new StringBuilder(); + var future = completeAsync(prompt, options, completion -> { + result.append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); + var reason = future.join(); + + List<Completion> completions = new ArrayList<>(); + completions.add(new Completion(result.toString(), reason)); + return completions; + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) { + var inferParams = new de.kherud.llama.InferenceParameters(prompt.asString().stripLeading()); + + // We always set this to some value to avoid infinite token generation + inferParams.setNPredict(maxTokens); + + options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v))); + options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v))); + options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v))); + options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v))); + options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v))); + // Todo: more options? + + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + try { + executor.submit(() -> { + for (LlamaModel.Output output : model.generate(inferParams)) { + consumer.accept(Completion.from(output.text, Completion.FinishReason.none)); + } + completionFuture.complete(Completion.FinishReason.stop); + }); + } catch (RejectedExecutionException e) { + // If we have too many requests (active + any waiting in queue), we reject the completion + int activeCount = executor.getActiveCount(); + int queueSize = executor.getQueue().size(); + String error = String.format("Rejected completion due to too many requests, " + + "%d active, %d in queue", activeCount, queueSize); + throw new RejectedExecutionException(error); + } + return completionFuture; + } + +} diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java index 82e19d47c92..82e19d47c92 100644 --- a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java index c360245901c..c360245901c 100644 --- a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java +++ b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/model-integration/src/main/resources/configdefinitions/llm-client.def index 0866459166a..0866459166a 100755 --- a/container-search/src/main/resources/configdefinitions/llm-client.def +++ b/model-integration/src/main/resources/configdefinitions/llm-client.def diff --git a/model-integration/src/main/resources/configdefinitions/llm-local-client.def b/model-integration/src/main/resources/configdefinitions/llm-local-client.def new file mode 100755 index 00000000000..c06c24b33e5 --- /dev/null +++ b/model-integration/src/main/resources/configdefinitions/llm-local-client.def @@ -0,0 +1,29 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package=ai.vespa.llm.clients + +# The LLM model to use +model model + +# Maximum number of requests to handle in parallel pr container node +parallelRequests int default=10 + +# Additional number of requests to put in queue for processing before starting to reject new requests +maxQueueSize int default=10 + +# Use GPU +useGpu bool default=false + +# Maximum number of model layers to run on GPU +gpuLayers int default=1000000 + +# Number of threads to use for CPU processing - -1 means use all available cores +# Not used for GPU processing +threads int default=-1 + +# Context size for the model +# Context is divided between parallel requests. So for 10 parallel requests, each "slot" gets 1/10 of the context +contextSize int default=512 + +# Maximum number of tokens to process in one request - overriden by inference parameters +maxTokens int default=512 + diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java index 35d5cfd3855..35d5cfd3855 100644 --- a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java +++ b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java new file mode 100644 index 00000000000..a3b260f3fb5 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java @@ -0,0 +1,186 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.clients; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.llm.completion.StringPrompt; +import com.yahoo.config.ModelReference; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests for LocalLLM. + * + * @author lesters + */ +public class LocalLLMTest { + + private static String model = "src/test/models/llm/tinyllm.gguf"; + private static Prompt prompt = StringPrompt.from("A random prompt"); + + @Test + @Disabled + public void testGeneration() { + var config = new LlmLocalClientConfig.Builder() + .parallelRequests(1) + .model(ModelReference.valueOf(model)); + var llm = new LocalLLM(config.build()); + + try { + var result = llm.complete(prompt, defaultOptions()); + assertEquals(Completion.FinishReason.stop, result.get(0).finishReason()); + assertTrue(result.get(0).text().length() > 10); + } finally { + llm.deconstruct(); + } + } + + @Test + @Disabled + public void testAsyncGeneration() { + var sb = new StringBuilder(); + var tokenCount = new AtomicInteger(0); + var config = new LlmLocalClientConfig.Builder() + .parallelRequests(1) + .model(ModelReference.valueOf(model)); + var llm = new LocalLLM(config.build()); + + try { + var future = llm.completeAsync(prompt, defaultOptions(), completion -> { + sb.append(completion.text()); + tokenCount.incrementAndGet(); + }).exceptionally(exception -> Completion.FinishReason.error); + + assertFalse(future.isDone()); + var reason = future.join(); + assertTrue(future.isDone()); + assertNotEquals(reason, Completion.FinishReason.error); + + } finally { + llm.deconstruct(); + } + assertTrue(tokenCount.get() > 0); + System.out.println(sb); + } + + @Test + @Disabled + public void testParallelGeneration() { + var prompts = testPrompts(); + var promptsToUse = prompts.size(); + var parallelRequests = 10; + + var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null)); + var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null)); + var tokenCounts = new ArrayList<>(Collections.nCopies(promptsToUse, 0)); + + var config = new LlmLocalClientConfig.Builder() + .parallelRequests(parallelRequests) + .model(ModelReference.valueOf(model)); + var llm = new LocalLLM(config.build()); + + try { + for (int i = 0; i < promptsToUse; i++) { + final var seq = i; + + completions.set(seq, new StringBuilder()); + futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { + completions.get(seq).append(completion.text()); + tokenCounts.set(seq, tokenCounts.get(seq) + 1); + }).exceptionally(exception -> Completion.FinishReason.error)); + } + for (int i = 0; i < promptsToUse; i++) { + var reason = futures.get(i).join(); + assertNotEquals(reason, Completion.FinishReason.error); + } + } finally { + llm.deconstruct(); + } + for (int i = 0; i < promptsToUse; i++) { + assertFalse(completions.get(i).isEmpty()); + assertTrue(tokenCounts.get(i) > 0); + } + } + + @Test + @Disabled + public void testRejection() { + var prompts = testPrompts(); + var promptsToUse = prompts.size(); + var parallelRequests = 2; + var additionalQueue = 1; + // 7 should be rejected + + var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null)); + var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null)); + + var config = new LlmLocalClientConfig.Builder() + .parallelRequests(parallelRequests) + .maxQueueSize(additionalQueue) + .model(ModelReference.valueOf(model)); + var llm = new LocalLLM(config.build()); + + var rejected = new AtomicInteger(0); + try { + for (int i = 0; i < promptsToUse; i++) { + final var seq = i; + + completions.set(seq, new StringBuilder()); + try { + var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> { + completions.get(seq).append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); + futures.set(seq, future); + } catch (RejectedExecutionException e) { + rejected.incrementAndGet(); + } + } + for (int i = 0; i < promptsToUse; i++) { + if (futures.get(i) != null) { + assertNotEquals(futures.get(i).join(), Completion.FinishReason.error); + } + } + } finally { + llm.deconstruct(); + } + assertEquals(7, rejected.get()); + } + + private static InferenceParameters defaultOptions() { + final Map<String, String> options = Map.of( + "temperature", "0.1", + "npredict", "100" + ); + return new InferenceParameters(options::get); + } + + private List<String> testPrompts() { + List<String> prompts = new ArrayList<>(); + prompts.add("Write a short story about a time-traveling detective who must solve a mystery that spans multiple centuries."); + prompts.add("Explain the concept of blockchain technology and its implications for data security in layman's terms."); + prompts.add("Discuss the socio-economic impacts of the Industrial Revolution in 19th century Europe."); + prompts.add("Describe a future where humans have colonized Mars, focusing on daily life and societal structure."); + prompts.add("Analyze the statement 'If a tree falls in a forest and no one is around to hear it, does it make a sound?' from both a philosophical and a physics perspective."); + prompts.add("Translate the following sentence into French: 'The quick brown fox jumps over the lazy dog.'"); + prompts.add("Explain what the following Python code does: `print([x for x in range(10) if x % 2 == 0])`."); + prompts.add("Provide general guidelines for maintaining a healthy lifestyle to reduce the risk of developing heart disease."); + prompts.add("Create a detailed description of a fictional planet, including its ecosystem, dominant species, and technology level."); + prompts.add("Discuss the impact of social media on interpersonal communication in the 21st century."); + return prompts; + } + +} diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java index 4d0073f1cbe..4d0073f1cbe 100644 --- a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java +++ b/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java index 57339f6ad49..57339f6ad49 100644 --- a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java +++ b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java diff --git a/model-integration/src/test/models/llm/tinyllm.gguf b/model-integration/src/test/models/llm/tinyllm.gguf Binary files differnew file mode 100644 index 00000000000..34367b6b57b --- /dev/null +++ b/model-integration/src/test/models/llm/tinyllm.gguf diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_allocator.cpp b/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_allocator.cpp index 87004d7e5f2..0c986422be6 100644 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_allocator.cpp +++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_allocator.cpp @@ -2,10 +2,11 @@ #include "lid_allocator.h" #include <vespa/searchlib/common/bitvectoriterator.h> -#include <vespa/searchlib/fef/termfieldmatchdataarray.h> #include <vespa/searchlib/fef/matchdata.h> -#include <vespa/searchlib/queryeval/full_search.h> +#include <vespa/searchlib/fef/termfieldmatchdataarray.h> #include <vespa/searchlib/queryeval/blueprint.h> +#include <vespa/searchlib/queryeval/flow_tuning.h> +#include <vespa/searchlib/queryeval/full_search.h> #include <mutex> #include <vespa/log/log.h> @@ -19,6 +20,8 @@ using search::queryeval::SearchIterator; using search::queryeval::SimpleLeafBlueprint; using vespalib::GenerationHolder; +using namespace search::queryeval::flow; + namespace proton::documentmetastore { LidAllocator::LidAllocator(uint32_t size, @@ -206,7 +209,8 @@ private: return search::BitVectorIterator::create(&_activeLids, get_docid_limit(), *tfmd, strict); } FlowStats calculate_flow_stats(uint32_t docid_limit) const override { - return default_flow_stats(docid_limit, _activeLids.size(), 0); + double rel_est = abs_to_rel_est(_activeLids.size(), docid_limit); + return {rel_est, bitvector_cost(), bitvector_strict_cost(rel_est)}; } SearchIterator::UP createLeafSearch(const TermFieldMatchDataArray &tfmda) const override diff --git a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp index 6ec1ffd460e..485410e0eba 100644 --- a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp @@ -39,7 +39,7 @@ public: return mixChildrenFields(); } - void sort(Children &children, bool, bool) const override { + void sort(Children &children, InFlow) const override { std::sort(children.begin(), children.end(), TieredGreaterEstimate()); } diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/common.cpp b/searchlib/src/tests/queryeval/iterator_benchmark/common.cpp index d2b5ec2cb8b..1db9cd58d46 100644 --- a/searchlib/src/tests/queryeval/iterator_benchmark/common.cpp +++ b/searchlib/src/tests/queryeval/iterator_benchmark/common.cpp @@ -20,7 +20,11 @@ to_string(const Config& attr_config) oss << col_type.asString() << "<" << basic_type.asString() << ">"; } if (attr_config.fastSearch()) { - oss << "(fs)"; + oss << "(fs"; + if (attr_config.getIsFilter()) { + oss << ",rf"; + } + oss << ")"; } return oss.str(); } diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp b/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp index d162ef05b06..f7a358efb26 100644 --- a/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp +++ b/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp @@ -792,10 +792,11 @@ gen_ratios(double middle, double range_multiplier, size_t num_samples) } FieldConfig -make_attr_config(BasicType basic_type, CollectionType col_type, bool fast_search) +make_attr_config(BasicType basic_type, CollectionType col_type, bool fast_search, bool rank_filter = false) { Config cfg(basic_type, col_type); cfg.setFastSearch(fast_search); + cfg.setIsFilter(rank_filter); return FieldConfig(cfg); } @@ -812,6 +813,7 @@ const std::vector<double> base_hit_ratios = {0.0001, 0.001, 0.01, 0.1, 0.5, 1.0} const std::vector<double> filter_hit_ratios = {0.00001, 0.00005, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.2, 0.5, 1.0}; const auto int32 = make_attr_config(BasicType::INT32, CollectionType::SINGLE, false); const auto int32_fs = make_attr_config(BasicType::INT32, CollectionType::SINGLE, true); +const auto int32_fs_rf = make_attr_config(BasicType::INT32, CollectionType::SINGLE, true, true); const auto int32_array = make_attr_config(BasicType::INT32, CollectionType::ARRAY, false); const auto int32_array_fs = make_attr_config(BasicType::INT32, CollectionType::ARRAY, true); const auto int32_wset = make_attr_config(BasicType::INT32, CollectionType::WSET, false); @@ -940,6 +942,15 @@ TEST(IteratorBenchmark, analyze_and_with_filter_vs_in) } } +TEST(IteratorBenchmark, analyze_and_with_bitvector_vs_in) +{ + for (uint32_t children: {10, 100, 1000, 10000}) { + run_and_benchmark({int32_fs, QueryOperator::In, {0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60}, children, true}, + {int32_fs_rf, QueryOperator::Term, {1.0}, 1, true}, // this setup returns a bitvector matching all documents. + num_docs); + } +} + TEST(IteratorBenchmark, analyze_and_with_filter_vs_in_array) { for (uint32_t children: {10, 100, 1000}) { @@ -958,6 +969,12 @@ TEST(IteratorBenchmark, analyze_and_with_filter_vs_or) } } +TEST(IteratorBenchmark, analyze_btree_vs_bitvector_iterators_strict) +{ + BenchmarkSetup setup(num_docs, {int32_fs, int32_fs_rf}, {QueryOperator::Term}, {true}, {0.1, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0}, {1}); + run_benchmarks(setup); +} + int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); int res = RUN_ALL_TESTS(); diff --git a/searchlib/src/vespa/searchlib/bitcompression/compression.cpp b/searchlib/src/vespa/searchlib/bitcompression/compression.cpp index 0f089c60e4b..f3fc31ac8b1 100644 --- a/searchlib/src/vespa/searchlib/bitcompression/compression.cpp +++ b/searchlib/src/vespa/searchlib/bitcompression/compression.cpp @@ -359,6 +359,24 @@ getParams(PostingListParams ¶ms) const params.clear(); } +template <bool bigEndian> +void +FeatureEncodeContext<bigEndian>::pad_for_memory_map_and_flush() +{ + // Write some pad bits to avoid decompression readahead going past + // memory mapped file during search and into SIGSEGV territory. + + // First pad to 64 bits alignment. + this->smallAlign(64); + writeComprBufferIfNeeded(); + + // Then write 128 more bits. This allows for 64-bit decoding + // with a readbits that always leaves a nonzero preRead + padBits(128); + this->alignDirectIO(); + this->flush(); + writeComprBuffer(); // Also flushes slack +} template <bool bigEndian> void diff --git a/searchlib/src/vespa/searchlib/bitcompression/compression.h b/searchlib/src/vespa/searchlib/bitcompression/compression.h index 9d4ca38eed3..4124f1f659f 100644 --- a/searchlib/src/vespa/searchlib/bitcompression/compression.h +++ b/searchlib/src/vespa/searchlib/bitcompression/compression.h @@ -1595,6 +1595,8 @@ public: writeComprBufferIfNeeded(); } + void pad_for_memory_map_and_flush(); + virtual void readHeader(const vespalib::GenericHeader &header, const vespalib::string &prefix); virtual void writeHeader(vespalib::GenericHeader &header, const vespalib::string &prefix) const; virtual const vespalib::string &getIdentifier() const; diff --git a/searchlib/src/vespa/searchlib/diskindex/pagedict4file.cpp b/searchlib/src/vespa/searchlib/diskindex/pagedict4file.cpp index 387d95bce66..bceeb1e7bc1 100644 --- a/searchlib/src/vespa/searchlib/diskindex/pagedict4file.cpp +++ b/searchlib/src/vespa/searchlib/diskindex/pagedict4file.cpp @@ -269,11 +269,9 @@ PageDict4FileSeqWrite::DictFileContext::DictFileContext(bool extended, vespalib: } bool -PageDict4FileSeqWrite::DictFileContext::DictFileContext::close() { - //uint64_t usedPBits = _ec.getWriteOffset(); - _ec.flush(); - _writeContext.writeComprBuffer(true); - +PageDict4FileSeqWrite::DictFileContext::DictFileContext::close() +{ + _ec.pad_for_memory_map_and_flush(); _writeContext.dropComprBuf(); bool success = _file.Sync(); success &= _file.Close(); diff --git a/searchlib/src/vespa/searchlib/diskindex/zc4_posting_writer.cpp b/searchlib/src/vespa/searchlib/diskindex/zc4_posting_writer.cpp index c7480633e21..f2b7911ba55 100644 --- a/searchlib/src/vespa/searchlib/diskindex/zc4_posting_writer.cpp +++ b/searchlib/src/vespa/searchlib/diskindex/zc4_posting_writer.cpp @@ -247,19 +247,7 @@ template <bool bigEndian> void Zc4PostingWriter<bigEndian>::on_close() { - // Write some pad bits to avoid decompression readahead going past - // memory mapped file during search and into SIGSEGV territory. - - // First pad to 64 bits alignment. - _encode_context.smallAlign(64); - _encode_context.writeComprBufferIfNeeded(); - - // Then write 128 more bits. This allows for 64-bit decoding - // with a readbits that always leaves a nonzero preRead - _encode_context.padBits(128); - _encode_context.alignDirectIO(); - _encode_context.flush(); - _encode_context.writeComprBuffer(); // Also flushes slack + _encode_context.pad_for_memory_map_and_flush(); } template class Zc4PostingWriter<false>; diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp index 43339b68999..d11ee25a7e5 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp @@ -619,7 +619,7 @@ IntermediateBlueprint::sort(InFlow in_flow) { resolve_strict(in_flow); if (!opt_keep_order()) [[likely]] { - sort(_children, in_flow.strict(), opt_sort_by_cost()); + sort(_children, in_flow); } auto flow = my_flow(in_flow); for (size_t i = 0; i < _children.size(); ++i) { diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.h b/searchlib/src/vespa/searchlib/queryeval/blueprint.h index 1501289c590..a443f34f856 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.h @@ -499,7 +499,7 @@ public: virtual HitEstimate combine(const std::vector<HitEstimate> &data) const = 0; virtual FieldSpecBaseList exposeFields() const = 0; - virtual void sort(Children &children, bool strict, bool sort_by_cost) const = 0; + virtual void sort(Children &children, InFlow in_flow) const = 0; virtual SearchIteratorUP createIntermediateSearch(MultiSearch::Children subSearches, fef::MatchData &md) const = 0; diff --git a/searchlib/src/vespa/searchlib/queryeval/flow_tuning.h b/searchlib/src/vespa/searchlib/queryeval/flow_tuning.h index dae0bd82cd0..356ecd4c992 100644 --- a/searchlib/src/vespa/searchlib/queryeval/flow_tuning.h +++ b/searchlib/src/vespa/searchlib/queryeval/flow_tuning.h @@ -61,6 +61,17 @@ inline double btree_strict_cost(double my_est) { return my_est; } +// Non-strict cost of matching in a bitvector. +inline double bitvector_cost() { + return 1.0; +} + +// Strict cost of matching in a bitvector. +// Test used: IteratorBenchmark::analyze_btree_vs_bitvector_iterators_strict +inline double bitvector_strict_cost(double my_est) { + return 1.5 * my_est; +} + // Non-strict cost of matching in a disk index posting list. inline double disk_index_cost() { return 1.5; diff --git a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp index 449a6a044b9..2fd632f9b97 100644 --- a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp @@ -162,10 +162,10 @@ AndNotBlueprint::get_replacement() } void -AndNotBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const +AndNotBlueprint::sort(Children &children, InFlow in_flow) const { - if (sort_by_cost) { - AndNotFlow::sort(children, strict); + if (opt_sort_by_cost()) { + AndNotFlow::sort(children, in_flow.strict()); } else { if (children.size() > 2) { std::sort(children.begin() + 1, children.end(), TieredGreaterEstimate()); @@ -257,12 +257,12 @@ AndBlueprint::get_replacement() } void -AndBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const +AndBlueprint::sort(Children &children, InFlow in_flow) const { - if (sort_by_cost) { - AndFlow::sort(children, strict); - if (strict && opt_allow_force_strict()) { - AndFlow::reorder_for_extra_strictness(children, true, 3); + if (opt_sort_by_cost()) { + AndFlow::sort(children, in_flow.strict()); + if (opt_allow_force_strict()) { + AndFlow::reorder_for_extra_strictness(children, in_flow, 3); } } else { std::sort(children.begin(), children.end(), TieredLessEstimate()); @@ -360,10 +360,10 @@ OrBlueprint::get_replacement() } void -OrBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const +OrBlueprint::sort(Children &children, InFlow in_flow) const { - if (sort_by_cost) { - OrFlow::sort(children, strict); + if (opt_sort_by_cost()) { + OrFlow::sort(children, in_flow.strict()); } else { std::sort(children.begin(), children.end(), TieredGreaterEstimate()); } @@ -449,7 +449,7 @@ WeakAndBlueprint::exposeFields() const } void -WeakAndBlueprint::sort(Children &, bool, bool) const +WeakAndBlueprint::sort(Children &, InFlow) const { // order needs to stay the same as _weights } @@ -511,10 +511,10 @@ NearBlueprint::exposeFields() const } void -NearBlueprint::sort(Children &children, bool strict, bool sort_by_cost) const +NearBlueprint::sort(Children &children, InFlow in_flow) const { - if (sort_by_cost) { - AndFlow::sort(children, strict); + if (opt_sort_by_cost()) { + AndFlow::sort(children, in_flow.strict()); } else { std::sort(children.begin(), children.end(), TieredLessEstimate()); } @@ -576,7 +576,7 @@ ONearBlueprint::exposeFields() const } void -ONearBlueprint::sort(Children &, bool, bool) const +ONearBlueprint::sort(Children &, InFlow) const { // ordered near cannot sort children here } @@ -662,7 +662,7 @@ RankBlueprint::get_replacement() } void -RankBlueprint::sort(Children &, bool, bool) const +RankBlueprint::sort(Children &, InFlow) const { } @@ -743,7 +743,7 @@ SourceBlenderBlueprint::exposeFields() const } void -SourceBlenderBlueprint::sort(Children &, bool, bool) const +SourceBlenderBlueprint::sort(Children &, InFlow) const { } diff --git a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h index 5d6d098510f..5b7b5b701b5 100644 --- a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h +++ b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h @@ -21,7 +21,7 @@ public: void optimize_self(OptimizePass pass) override; AndNotBlueprint * asAndNot() noexcept final { return this; } Blueprint::UP get_replacement() override; - void sort(Children &children, bool strict, bool sort_by_cost) const override; + void sort(Children &children, InFlow in_flow) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, fef::MatchData &md) const override; @@ -48,7 +48,7 @@ public: void optimize_self(OptimizePass pass) override; AndBlueprint * asAnd() noexcept final { return this; } Blueprint::UP get_replacement() override; - void sort(Children &children, bool strict, bool sort_by_cost) const override; + void sort(Children &children, InFlow in_flow) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, fef::MatchData &md) const override; @@ -72,7 +72,7 @@ public: void optimize_self(OptimizePass pass) override; OrBlueprint * asOr() noexcept final { return this; } Blueprint::UP get_replacement() override; - void sort(Children &children, bool strict, bool sort_by_cost) const override; + void sort(Children &children, InFlow in_flow) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, fef::MatchData &md) const override; @@ -96,7 +96,7 @@ public: FlowStats calculate_flow_stats(uint32_t docid_limit) const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void sort(Children &children, bool strict, bool sort_on_cost) const override; + void sort(Children &children, InFlow in_flow) const override; bool always_needs_unpack() const override; WeakAndBlueprint * asWeakAnd() noexcept final { return this; } SearchIterator::UP @@ -126,7 +126,7 @@ public: FlowStats calculate_flow_stats(uint32_t docid_limit) const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void sort(Children &children, bool strict, bool sort_by_cost) const override; + void sort(Children &children, InFlow in_flow) const override; SearchIteratorUP createSearch(fef::MatchData &md) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, @@ -148,7 +148,7 @@ public: FlowStats calculate_flow_stats(uint32_t docid_limit) const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void sort(Children &children, bool strict, bool sort_by_cost) const override; + void sort(Children &children, InFlow in_flow) const override; SearchIteratorUP createSearch(fef::MatchData &md) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, @@ -168,7 +168,7 @@ public: FieldSpecBaseList exposeFields() const override; void optimize_self(OptimizePass pass) override; Blueprint::UP get_replacement() override; - void sort(Children &children, bool strict, bool sort_by_cost) const override; + void sort(Children &children, InFlow in_flow) const override; bool isRank() const noexcept final { return true; } SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, @@ -196,7 +196,7 @@ public: FlowStats calculate_flow_stats(uint32_t docid_limit) const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; FieldSpecBaseList exposeFields() const override; - void sort(Children &children, bool strict, bool sort_by_cost) const override; + void sort(Children &children, InFlow in_flow) const override; SearchIterator::UP createIntermediateSearch(MultiSearch::Children subSearches, fef::MatchData &md) const override; diff --git a/standalone-container/pom.xml b/standalone-container/pom.xml index 844b912543c..92faa1ae670 100644 --- a/standalone-container/pom.xml +++ b/standalone-container/pom.xml @@ -113,6 +113,7 @@ model-evaluation-jar-with-dependencies.jar, model-integration-jar-with-dependencies.jar, container-onnxruntime.jar, + container-llama.jar, <!-- END config-model dependencies --> </discPreInstallBundle> </configuration> |