diff options
Diffstat (limited to 'container-search/src/main')
-rwxr-xr-x | container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java | 20 | ||||
-rwxr-xr-x | container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java | 43 | ||||
-rw-r--r-- | container-search/src/main/java/ai/vespa/search/llm/TokenStream.java | 58 | ||||
-rw-r--r-- | container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java | 91 | ||||
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java (renamed from container-search/src/main/java/com/yahoo/search/rendering/TokenRenderer.java) | 31 | ||||
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java | 14 | ||||
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/result/EventStream.java | 66 | ||||
-rwxr-xr-x | container-search/src/main/resources/configdefinitions/local-llm-interface.def | 6 |
8 files changed, 232 insertions, 97 deletions
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java index 50ea2646f9f..2b1c553a675 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java @@ -7,7 +7,6 @@ import ai.vespa.llm.LanguageModelException; import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; -import ai.vespa.search.llm.LlmSearcherConfig; import com.yahoo.api.annotations.Beta; import com.yahoo.component.ComponentId; import com.yahoo.component.annotation.Inject; @@ -16,6 +15,7 @@ import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.result.EventStream; import com.yahoo.search.result.Hit; import com.yahoo.search.result.HitGroup; @@ -42,7 +42,7 @@ public abstract class LLMSearcher extends Searcher { LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { this.stream = config.stream(); this.languageModelId = config.providerId(); - this.languageModel = findLanguageModel(config.providerId(), languageModels); + this.languageModel = findLanguageModel(languageModelId, languageModels); this.propertyPrefix = config.propertyPrefix(); } @@ -98,23 +98,27 @@ public abstract class LLMSearcher extends Searcher { } private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { - TokenStream tokenStream = TokenStream.create("token_stream"); + EventStream eventStream = EventStream.create("token_stream"); + + if (query.getTrace().getLevel() >= 1) { + eventStream.add(prompt.asString(), "prompt"); + } languageModel.completeAsync(prompt, options, token -> { - tokenStream.add(token.text()); + eventStream.add(token.text()); }).exceptionally(exception -> { int errorCode = 400; if (exception instanceof LanguageModelException languageModelException) { errorCode = languageModelException.code(); } - tokenStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); - tokenStream.markComplete(); + eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); + eventStream.markComplete(); return Completion.FinishReason.error; }).thenAccept(finishReason -> { - tokenStream.markComplete(); + eventStream.markComplete(); }); - return new Result(query, tokenStream); + return new Result(query, eventStream); } private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { diff --git a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java index 92fd904d709..05c6335139b 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java @@ -20,6 +20,8 @@ public class RAGSearcher extends LLMSearcher { private static Logger log = Logger.getLogger(RAGSearcher.class.getName()); + private static final String PROMPT = "prompt"; + @Inject public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) { super(config, languageModels); @@ -30,24 +32,51 @@ public class RAGSearcher extends LLMSearcher { public Result search(Query query, Execution execution) { Result result = execution.search(query); execution.fill(result); - return complete(query, buildPrompt(query, result)); + + // Todo: Add resulting prompt to the result + Prompt prompt = buildPrompt(query, result); + + return complete(query, prompt); } protected Prompt buildPrompt(Query query, Result result) { - String propertyWithPrefix = this.getPropertyPrefix() + ".prompt"; - String prompt = query.properties().getString(propertyWithPrefix); - if (prompt == null) { - prompt = "Please provide a summary of the above"; + String prompt = getPrompt(query); + + // Replace @query with the actual query + if (prompt.contains("@query")) { + prompt = prompt.replace("@query", query.getModel().getQueryString()); } - if (!prompt.contains("{context}")) { + + if ( !prompt.contains("{context}")) { prompt = "{context}\n" + prompt; } - // Todo: support system and user prompt prompt = prompt.replace("{context}", buildContext(result)); log.info("Prompt: " + prompt); // remove return StringPrompt.from(prompt); } + private String getPrompt(Query query) { + // First, check if prompt is set with a prefix + String propertyWithPrefix = this.getPropertyPrefix() + "." + PROMPT; + String prompt = query.properties().getString(propertyWithPrefix); + if (prompt != null) + return prompt; + + // If not, try without prefix + prompt = query.properties().getString(PROMPT); + if (prompt != null) + return prompt; + + // If not, use query + prompt = query.getModel().getQueryString(); + if (prompt != null) + return prompt; + + // If not, throw exception + throw new IllegalArgumentException("RAG searcher could not find prompt found for query. Tried looking for " + + "'" + propertyWithPrefix + "." + PROMPT + "', '" + PROMPT + "' or '@query'."); + } + private String buildContext(Result result) { StringBuilder sb = new StringBuilder(); var hits = result.hits(); diff --git a/container-search/src/main/java/ai/vespa/search/llm/TokenStream.java b/container-search/src/main/java/ai/vespa/search/llm/TokenStream.java deleted file mode 100644 index bd3cda5a6dc..00000000000 --- a/container-search/src/main/java/ai/vespa/search/llm/TokenStream.java +++ /dev/null @@ -1,58 +0,0 @@ -package ai.vespa.search.llm; - -import com.yahoo.processing.response.DefaultIncomingData; -import com.yahoo.search.result.DefaultErrorHit; -import com.yahoo.search.result.ErrorMessage; -import com.yahoo.search.result.Hit; -import com.yahoo.search.result.HitGroup; -import com.yahoo.search.result.Relevance; - -public class TokenStream extends HitGroup { - - private int tokenCount = 0; - - private TokenStream(String id, DefaultIncomingData<Hit> incomingData) { - super(id, new Relevance(1), incomingData); - this.setOrdered(true); // avoid hit group ordering - important for errors - } - - public static TokenStream create(String id) { - DefaultIncomingData<Hit> incomingData = new DefaultIncomingData<>(); - TokenStream stream = new TokenStream(id, incomingData); - incomingData.assignOwner(stream); - return stream; - } - - public static HitGroup createAsync(String id) { - return create(id); - } - - public void add(String token) { - incoming().add(new Token(String.valueOf(++tokenCount), token)); - } - - public void error(String source, ErrorMessage message) { - incoming().add(new DefaultErrorHit(source, message)); - } - - public void markComplete() { - incoming().markComplete(); - } - - public static class Token extends Hit { - - public Token(String token) { - this("", token); - } - - public Token(String id, String token) { - super(id); - setField("token", token); - } - - public String toString() { - return getField("token").toString(); - } - - } -} diff --git a/container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java b/container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java new file mode 100644 index 00000000000..dd9ed01bd73 --- /dev/null +++ b/container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java @@ -0,0 +1,91 @@ +package ai.vespa.search.llm.interfaces; + +import ai.vespa.llm.InferenceParameters; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.Prompt; +import ai.vespa.search.llm.LocalLlmInterfaceConfig; + +import ai.vespa.util.http.hc4.retry.Sleeper; +import com.yahoo.component.annotation.Inject; +import de.kherud.llama.LlamaModel; +import de.kherud.llama.LogLevel; +import de.kherud.llama.ModelParameters; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.function.Consumer; +import java.util.logging.Logger; + +public class LocalLLMInterface implements LanguageModel { + + private static Logger logger = Logger.getLogger(LocalLLMInterface.class.getName()); + private final LlamaModel model; + private final ExecutorService executor; + + @Inject + public LocalLLMInterface(LocalLlmInterfaceConfig config) { + this(config, Executors.newFixedThreadPool(1)); // until we can run llama.cpp in batch + } + + LocalLLMInterface(LocalLlmInterfaceConfig config, ExecutorService executor) { + this.executor = executor; + + LlamaModel.setLogger(this::log); + var modelParams = new ModelParameters() + // Todo: retrieve from config + ; + + long startLoad = System.nanoTime(); + model = new LlamaModel(config.llmfile(), modelParams); + long loadTime = System.nanoTime() - startLoad; + logger.info("Loaded model " + config.llmfile() + " in " + (loadTime*1.0/1000000000) + " sec"); + } + + @Override + public List<Completion> complete(Prompt prompt, InferenceParameters options) { + StringBuilder result = new StringBuilder(); + var future = completeAsync(prompt, options, completion -> { + result.append(completion.text()); + }).exceptionally(exception -> Completion.FinishReason.error); + var reason = future.join(); + + List<Completion> completions = new ArrayList<>(); + completions.add(new Completion(result.toString(), reason)); + return completions; + } + + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) { + var inferParams = new de.kherud.llama.InferenceParameters(); + options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v))); + options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v))); + options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v))); + options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v))); + options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v))); + // Todo: add more + + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + executor.submit(() -> { + for (LlamaModel.Output output : model.generate(prompt.asString(), inferParams)) { + consumer.accept(Completion.from(output.text, Completion.FinishReason.none)); + } + completionFuture.complete(Completion.FinishReason.stop); + }); + + return completionFuture; + } + + private void log(LogLevel level, String message) { + switch (level) { + case WARN -> logger.warning(message); + case DEBUG -> logger.fine(message); + case ERROR -> logger.severe(message); + default -> logger.info(message); + } + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/rendering/TokenRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java index 0b0c018e1ac..4cca0fcce9f 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/TokenRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -1,6 +1,6 @@ package com.yahoo.search.rendering; -import ai.vespa.search.llm.TokenStream; +import com.yahoo.search.result.EventStream; import com.fasterxml.jackson.core.JsonEncoding; import com.fasterxml.jackson.core.JsonFactory; import com.fasterxml.jackson.core.JsonFactoryBuilder; @@ -18,19 +18,16 @@ import com.yahoo.search.result.ErrorMessage; import java.io.IOException; import java.io.OutputStream; import java.util.concurrent.Executor; -import java.util.logging.Logger; import static com.fasterxml.jackson.databind.SerializationFeature.FLUSH_AFTER_WRITE_VALUE; /** - * - * A comment about SSE + * A Server-Sent Events (SSE) renderer for asynchronous events such as + * tokens from a language model. * * @author lesters */ -public class TokenRenderer extends AsynchronousSectionedRenderer<Result> { - - private static final Logger log = Logger.getLogger(TokenRenderer.class.getName()); +public class EventRenderer extends AsynchronousSectionedRenderer<Result> { private static final JsonFactory generatorFactory = createGeneratorFactory(); private volatile JsonGenerator generator; @@ -43,14 +40,14 @@ public class TokenRenderer extends AsynchronousSectionedRenderer<Result> { return factory; } - private static final boolean RENDER_TOKEN_EVENT_HEADER = true; + private static final boolean RENDER_EVENT_HEADER = true; private static final boolean RENDER_END_EVENT = true; - public TokenRenderer() { + public EventRenderer() { this(null); } - public TokenRenderer(Executor executor) { + public EventRenderer(Executor executor) { super(executor); } @@ -62,21 +59,21 @@ public class TokenRenderer extends AsynchronousSectionedRenderer<Result> { @Override public void beginList(DataList<?> dataList) throws IOException { - if ( ! (dataList instanceof TokenStream)) { - throw new IllegalArgumentException("TokenRenderer currently only supports TokenStreams"); + if ( ! (dataList instanceof EventStream)) { + throw new IllegalArgumentException("EventRenderer currently only supports EventStreams"); // Todo: support results and timing and trace by delegating to JsonRenderer } } @Override public void data(Data data) throws IOException { - if (data instanceof TokenStream.Token token) { - if (RENDER_TOKEN_EVENT_HEADER) { - generator.writeRaw("event: token\n"); + if (data instanceof EventStream.Event event) { + if (RENDER_EVENT_HEADER) { + generator.writeRaw("event: " + event.type() + "\n"); } generator.writeRaw("data: "); generator.writeStartObject(); - generator.writeStringField("token", token.toString()); + generator.writeStringField(event.type(), event.toString()); generator.writeEndObject(); generator.writeRaw("\n\n"); generator.flush(); @@ -115,7 +112,7 @@ public class TokenRenderer extends AsynchronousSectionedRenderer<Result> { @Override public String getMimeType() { - return "application/json"; + return "text/event-stream"; } } diff --git a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java index ad6875b6c9b..d62860afcda 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java @@ -24,7 +24,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi public static final ComponentId xmlRendererId = ComponentId.fromString("XmlRenderer"); public static final ComponentId pageRendererId = ComponentId.fromString("PageTemplatesXmlRenderer"); public static final ComponentId jsonRendererId = ComponentId.fromString("JsonRenderer"); - public static final ComponentId tokenRendererId = ComponentId.fromString("TokenRenderer"); + public static final ComponentId eventRendererId = ComponentId.fromString("EventRenderer"); public static final ComponentId defaultRendererId = jsonRendererId; @@ -57,10 +57,10 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi pageRenderer.initId(pageRendererId); register(pageRenderer.getId(), pageRenderer); - // Add token renderer - Renderer tokenRenderer = new TokenRenderer(executor); - tokenRenderer.initId(tokenRendererId); - register(tokenRenderer.getId(), tokenRenderer); + // Add event renderer + Renderer eventRenderer = new EventRenderer(executor); + eventRenderer.initId(eventRendererId); + register(eventRenderer.getId(), eventRenderer); // add application renderers for (Renderer renderer : renderers) @@ -75,7 +75,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi getRenderer(jsonRendererId.toSpecification()).deconstruct(); getRenderer(xmlRendererId.toSpecification()).deconstruct(); getRenderer(pageRendererId.toSpecification()).deconstruct(); - getRenderer(tokenRendererId.toSpecification()).deconstruct(); + getRenderer(eventRendererId.toSpecification()).deconstruct(); } /** @@ -99,7 +99,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi if (format.stringValue().equals("json")) return getComponent(jsonRendererId); if (format.stringValue().equals("xml")) return getComponent(xmlRendererId); if (format.stringValue().equals("page")) return getComponent(pageRendererId); - if (format.stringValue().equals("token")) return getComponent(tokenRendererId); + if (format.stringValue().equals("sse")) return getComponent(eventRendererId); com.yahoo.processing.rendering.Renderer<Result> renderer = getComponent(format); if (renderer == null) diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java new file mode 100644 index 00000000000..957a027608e --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java @@ -0,0 +1,66 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.search.result; + +import com.yahoo.processing.response.DefaultIncomingData; + +/** + * A stream of events which can be rendered as Server-Sent Events (SSE). + * + * @author lesters + */ +public class EventStream extends HitGroup { + + private int eventCount = 0; + + public static final String DEFAULT_EVENT_TYPE = "token"; + + private EventStream(String id, DefaultIncomingData<Hit> incomingData) { + super(id, new Relevance(1), incomingData); + this.setOrdered(true); // avoid hit group ordering - important as sequence as inserted should be kept + } + + public static EventStream create(String id) { + DefaultIncomingData<Hit> incomingData = new DefaultIncomingData<>(); + EventStream stream = new EventStream(id, incomingData); + incomingData.assignOwner(stream); + return stream; + } + + public void add(String data) { + add(data, DEFAULT_EVENT_TYPE); + } + + public void add(String data, String type) { + incoming().add(new Event(String.valueOf(eventCount + 1), data, type)); + eventCount++; + } + + public void error(String source, ErrorMessage message) { + incoming().add(new DefaultErrorHit(source, message)); + } + + public void markComplete() { + incoming().markComplete(); + } + + public static class Event extends Hit { + + private final String type; + + public Event(String id, String data, String type) { + super(id); + this.type = type; + setField(type, data); + } + + public String toString() { + return getField(type).toString(); + } + + public String type() { + return type; + } + + } +} diff --git a/container-search/src/main/resources/configdefinitions/local-llm-interface.def b/container-search/src/main/resources/configdefinitions/local-llm-interface.def new file mode 100755 index 00000000000..5d26471d2fe --- /dev/null +++ b/container-search/src/main/resources/configdefinitions/local-llm-interface.def @@ -0,0 +1,6 @@ +# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package=ai.vespa.search.llm + +# Something that points to either a url or local path? +llmfile string default="" + |