Diffstat (limited to 'container-search/src/main')
-rwxr-xr-x  container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java  |  20
-rwxr-xr-x  container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java  |  43
-rw-r--r--  container-search/src/main/java/ai/vespa/search/llm/TokenStream.java  |  58
-rw-r--r--  container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java  |  91
-rw-r--r--  container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java (renamed from container-search/src/main/java/com/yahoo/search/rendering/TokenRenderer.java)  |  31
-rw-r--r--  container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java  |  14
-rw-r--r--  container-search/src/main/java/com/yahoo/search/result/EventStream.java  |  66
-rwxr-xr-x  container-search/src/main/resources/configdefinitions/local-llm-interface.def  |   6
8 files changed, 232 insertions, 97 deletions
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
index 50ea2646f9f..2b1c553a675 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
@@ -7,7 +7,6 @@ import ai.vespa.llm.LanguageModelException;
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
-import ai.vespa.search.llm.LlmSearcherConfig;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.ComponentId;
import com.yahoo.component.annotation.Inject;
@@ -16,6 +15,7 @@ import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
@@ -42,7 +42,7 @@ public abstract class LLMSearcher extends Searcher {
LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
this.stream = config.stream();
this.languageModelId = config.providerId();
- this.languageModel = findLanguageModel(config.providerId(), languageModels);
+ this.languageModel = findLanguageModel(languageModelId, languageModels);
this.propertyPrefix = config.propertyPrefix();
}
@@ -98,23 +98,27 @@ public abstract class LLMSearcher extends Searcher {
}
private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
- TokenStream tokenStream = TokenStream.create("token_stream");
+ EventStream eventStream = EventStream.create("token_stream");
+
+ if (query.getTrace().getLevel() >= 1) {
+ eventStream.add(prompt.asString(), "prompt");
+ }
languageModel.completeAsync(prompt, options, token -> {
- tokenStream.add(token.text());
+ eventStream.add(token.text());
}).exceptionally(exception -> {
int errorCode = 400;
if (exception instanceof LanguageModelException languageModelException) {
errorCode = languageModelException.code();
}
- tokenStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
- tokenStream.markComplete();
+ eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ eventStream.markComplete();
return Completion.FinishReason.error;
}).thenAccept(finishReason -> {
- tokenStream.markComplete();
+ eventStream.markComplete();
});
- return new Result(query, tokenStream);
+ return new Result(query, eventStream);
}
private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
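For reference, the streaming flow introduced above can be sketched in isolation roughly as follows (a simplified, hypothetical method; the trace event and the error handling of the real completeAsync are omitted):

    private Result completeAsyncSketch(Query query, Prompt prompt, InferenceParameters options) {
        EventStream eventStream = EventStream.create("token_stream");
        // Tokens are appended to the stream as the model produces them; the Result is
        // returned immediately and rendered incrementally by EventRenderer.
        languageModel.completeAsync(prompt, options, token -> eventStream.add(token.text()))
                     .thenAccept(finishReason -> eventStream.markComplete());
        return new Result(query, eventStream);
    }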
diff --git a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java
index 92fd904d709..05c6335139b 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java
@@ -20,6 +20,8 @@ public class RAGSearcher extends LLMSearcher {
private static Logger log = Logger.getLogger(RAGSearcher.class.getName());
+ private static final String PROMPT = "prompt";
+
@Inject
public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
super(config, languageModels);
@@ -30,24 +32,51 @@ public class RAGSearcher extends LLMSearcher {
public Result search(Query query, Execution execution) {
Result result = execution.search(query);
execution.fill(result);
- return complete(query, buildPrompt(query, result));
+
+ // Todo: Add resulting prompt to the result
+ Prompt prompt = buildPrompt(query, result);
+
+ return complete(query, prompt);
}
protected Prompt buildPrompt(Query query, Result result) {
- String propertyWithPrefix = this.getPropertyPrefix() + ".prompt";
- String prompt = query.properties().getString(propertyWithPrefix);
- if (prompt == null) {
- prompt = "Please provide a summary of the above";
+ String prompt = getPrompt(query);
+
+ // Replace @query with the actual query
+ if (prompt.contains("@query")) {
+ prompt = prompt.replace("@query", query.getModel().getQueryString());
}
- if (!prompt.contains("{context}")) {
+
+ if ( !prompt.contains("{context}")) {
prompt = "{context}\n" + prompt;
}
- // Todo: support system and user prompt
prompt = prompt.replace("{context}", buildContext(result));
log.info("Prompt: " + prompt); // remove
return StringPrompt.from(prompt);
}
+ private String getPrompt(Query query) {
+ // First, check if prompt is set with a prefix
+ String propertyWithPrefix = this.getPropertyPrefix() + "." + PROMPT;
+ String prompt = query.properties().getString(propertyWithPrefix);
+ if (prompt != null)
+ return prompt;
+
+ // If not, try without prefix
+ prompt = query.properties().getString(PROMPT);
+ if (prompt != null)
+ return prompt;
+
+ // If not, use query
+ prompt = query.getModel().getQueryString();
+ if (prompt != null)
+ return prompt;
+
+ // If not, throw exception
+ throw new IllegalArgumentException("RAG searcher could not find a prompt for the query. Tried looking for " +
+ "'" + propertyWithPrefix + "', '" + PROMPT + "' or the query string.");
+ }
+
private String buildContext(Result result) {
StringBuilder sb = new StringBuilder();
var hits = result.hits();
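To make the placeholder handling above concrete, here is a small, self-contained sketch of the same substitutions; the template, query string and context text are invented for illustration:

    // Standalone illustration of the @query / {context} handling in buildPrompt.
    public class PromptSubstitutionSketch {
        public static void main(String[] args) {
            String prompt = "Answer this question: @query";                       // e.g. the 'prompt' query property
            String queryString = "what is vespa?";                                // query.getModel().getQueryString()
            String context = "title: Vespa\nbody: Vespa is a search engine.\n";   // buildContext(result)

            prompt = prompt.replace("@query", queryString);
            if ( ! prompt.contains("{context}"))
                prompt = "{context}\n" + prompt;
            prompt = prompt.replace("{context}", context);

            System.out.println(prompt);   // the context followed by "Answer this question: what is vespa?"
        }
    }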
diff --git a/container-search/src/main/java/ai/vespa/search/llm/TokenStream.java b/container-search/src/main/java/ai/vespa/search/llm/TokenStream.java
deleted file mode 100644
index bd3cda5a6dc..00000000000
--- a/container-search/src/main/java/ai/vespa/search/llm/TokenStream.java
+++ /dev/null
@@ -1,58 +0,0 @@
-package ai.vespa.search.llm;
-
-import com.yahoo.processing.response.DefaultIncomingData;
-import com.yahoo.search.result.DefaultErrorHit;
-import com.yahoo.search.result.ErrorMessage;
-import com.yahoo.search.result.Hit;
-import com.yahoo.search.result.HitGroup;
-import com.yahoo.search.result.Relevance;
-
-public class TokenStream extends HitGroup {
-
- private int tokenCount = 0;
-
- private TokenStream(String id, DefaultIncomingData<Hit> incomingData) {
- super(id, new Relevance(1), incomingData);
- this.setOrdered(true); // avoid hit group ordering - important for errors
- }
-
- public static TokenStream create(String id) {
- DefaultIncomingData<Hit> incomingData = new DefaultIncomingData<>();
- TokenStream stream = new TokenStream(id, incomingData);
- incomingData.assignOwner(stream);
- return stream;
- }
-
- public static HitGroup createAsync(String id) {
- return create(id);
- }
-
- public void add(String token) {
- incoming().add(new Token(String.valueOf(++tokenCount), token));
- }
-
- public void error(String source, ErrorMessage message) {
- incoming().add(new DefaultErrorHit(source, message));
- }
-
- public void markComplete() {
- incoming().markComplete();
- }
-
- public static class Token extends Hit {
-
- public Token(String token) {
- this("", token);
- }
-
- public Token(String id, String token) {
- super(id);
- setField("token", token);
- }
-
- public String toString() {
- return getField("token").toString();
- }
-
- }
-}
diff --git a/container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java b/container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java
new file mode 100644
index 00000000000..dd9ed01bd73
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java
@@ -0,0 +1,91 @@
+package ai.vespa.search.llm.interfaces;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.search.llm.LocalLlmInterfaceConfig;
+
+import ai.vespa.util.http.hc4.retry.Sleeper;
+import com.yahoo.component.annotation.Inject;
+import de.kherud.llama.LlamaModel;
+import de.kherud.llama.LogLevel;
+import de.kherud.llama.ModelParameters;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.Consumer;
+import java.util.logging.Logger;
+
+public class LocalLLMInterface implements LanguageModel {
+
+ private static Logger logger = Logger.getLogger(LocalLLMInterface.class.getName());
+ private final LlamaModel model;
+ private final ExecutorService executor;
+
+ @Inject
+ public LocalLLMInterface(LocalLlmInterfaceConfig config) {
+ this(config, Executors.newFixedThreadPool(1)); // until we can run llama.cpp in batch
+ }
+
+ LocalLLMInterface(LocalLlmInterfaceConfig config, ExecutorService executor) {
+ this.executor = executor;
+
+ LlamaModel.setLogger(this::log);
+ var modelParams = new ModelParameters()
+ // Todo: retrieve from config
+ ;
+
+ long startLoad = System.nanoTime();
+ model = new LlamaModel(config.llmfile(), modelParams);
+ long loadTime = System.nanoTime() - startLoad;
+ logger.info("Loaded model " + config.llmfile() + " in " + (loadTime*1.0/1000000000) + " sec");
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters options) {
+ StringBuilder result = new StringBuilder();
+ var future = completeAsync(prompt, options, completion -> {
+ result.append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+ var reason = future.join();
+
+ List<Completion> completions = new ArrayList<>();
+ completions.add(new Completion(result.toString(), reason));
+ return completions;
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
+ var inferParams = new de.kherud.llama.InferenceParameters();
+ options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v)));
+ options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v)));
+ options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v)));
+ options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v)));
+ options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v)));
+ // Todo: add more
+
+ var completionFuture = new CompletableFuture<Completion.FinishReason>();
+ executor.submit(() -> {
+ for (LlamaModel.Output output : model.generate(prompt.asString(), inferParams)) {
+ consumer.accept(Completion.from(output.text, Completion.FinishReason.none));
+ }
+ completionFuture.complete(Completion.FinishReason.stop);
+ });
+
+ return completionFuture;
+ }
+
+ private void log(LogLevel level, String message) {
+ switch (level) {
+ case WARN -> logger.warning(message);
+ case DEBUG -> logger.fine(message);
+ case ERROR -> logger.severe(message);
+ default -> logger.info(message);
+ }
+ }
+
+}
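A hedged usage sketch of the new interface, e.g. from a test. The model path and option values are placeholders, and the generated config builder and the Function-based InferenceParameters constructor are assumed from standard Vespa conventions, not shown in this change:

    var config = new LocalLlmInterfaceConfig.Builder()
            .llmfile("models/llama-2-7b.Q4_K_M.gguf")   // placeholder path
            .build();
    var llm = new LocalLLMInterface(config);

    var options = new InferenceParameters(Map.of("temperature", "0.7", "npredict", "64")::get);
    llm.completeAsync(StringPrompt.from("Write one sentence about Vespa."), options,
                      completion -> System.out.print(completion.text()))
       .join();   // wait for the generation thread to complete the future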
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/TokenRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
index 0b0c018e1ac..4cca0fcce9f 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/TokenRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -1,6 +1,6 @@
package com.yahoo.search.rendering;
-import ai.vespa.search.llm.TokenStream;
+import com.yahoo.search.result.EventStream;
import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonFactoryBuilder;
@@ -18,19 +18,16 @@ import com.yahoo.search.result.ErrorMessage;
import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.Executor;
-import java.util.logging.Logger;
import static com.fasterxml.jackson.databind.SerializationFeature.FLUSH_AFTER_WRITE_VALUE;
/**
- *
- * A comment about SSE
+ * A Server-Sent Events (SSE) renderer for asynchronous events such as
+ * tokens from a language model.
*
* @author lesters
*/
-public class TokenRenderer extends AsynchronousSectionedRenderer<Result> {
-
- private static final Logger log = Logger.getLogger(TokenRenderer.class.getName());
+public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
private static final JsonFactory generatorFactory = createGeneratorFactory();
private volatile JsonGenerator generator;
@@ -43,14 +40,14 @@ public class TokenRenderer extends AsynchronousSectionedRenderer<Result> {
return factory;
}
- private static final boolean RENDER_TOKEN_EVENT_HEADER = true;
+ private static final boolean RENDER_EVENT_HEADER = true;
private static final boolean RENDER_END_EVENT = true;
- public TokenRenderer() {
+ public EventRenderer() {
this(null);
}
- public TokenRenderer(Executor executor) {
+ public EventRenderer(Executor executor) {
super(executor);
}
@@ -62,21 +59,21 @@ public class TokenRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void beginList(DataList<?> dataList) throws IOException {
- if ( ! (dataList instanceof TokenStream)) {
- throw new IllegalArgumentException("TokenRenderer currently only supports TokenStreams");
+ if ( ! (dataList instanceof EventStream)) {
+ throw new IllegalArgumentException("EventRenderer currently only supports EventStreams");
// Todo: support results and timing and trace by delegating to JsonRenderer
}
}
@Override
public void data(Data data) throws IOException {
- if (data instanceof TokenStream.Token token) {
- if (RENDER_TOKEN_EVENT_HEADER) {
- generator.writeRaw("event: token\n");
+ if (data instanceof EventStream.Event event) {
+ if (RENDER_EVENT_HEADER) {
+ generator.writeRaw("event: " + event.type() + "\n");
}
generator.writeRaw("data: ");
generator.writeStartObject();
- generator.writeStringField("token", token.toString());
+ generator.writeStringField(event.type(), event.toString());
generator.writeEndObject();
generator.writeRaw("\n\n");
generator.flush();
@@ -115,7 +112,7 @@ public class TokenRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public String getMimeType() {
- return "application/json";
+ return "text/event-stream";
}
}
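Given the data() method above, the rendered response body for a few token events is a standard SSE stream along these lines (token text invented):

    event: token
    data: {"token":"Vespa"}

    event: token
    data: {"token":" is"}

    event: token
    data: {"token":" a search engine."}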
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
index ad6875b6c9b..d62860afcda 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
@@ -24,7 +24,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
public static final ComponentId xmlRendererId = ComponentId.fromString("XmlRenderer");
public static final ComponentId pageRendererId = ComponentId.fromString("PageTemplatesXmlRenderer");
public static final ComponentId jsonRendererId = ComponentId.fromString("JsonRenderer");
- public static final ComponentId tokenRendererId = ComponentId.fromString("TokenRenderer");
+ public static final ComponentId eventRendererId = ComponentId.fromString("EventRenderer");
public static final ComponentId defaultRendererId = jsonRendererId;
@@ -57,10 +57,10 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
pageRenderer.initId(pageRendererId);
register(pageRenderer.getId(), pageRenderer);
- // Add token renderer
- Renderer tokenRenderer = new TokenRenderer(executor);
- tokenRenderer.initId(tokenRendererId);
- register(tokenRenderer.getId(), tokenRenderer);
+ // Add event renderer
+ Renderer eventRenderer = new EventRenderer(executor);
+ eventRenderer.initId(eventRendererId);
+ register(eventRenderer.getId(), eventRenderer);
// add application renderers
for (Renderer renderer : renderers)
@@ -75,7 +75,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
getRenderer(jsonRendererId.toSpecification()).deconstruct();
getRenderer(xmlRendererId.toSpecification()).deconstruct();
getRenderer(pageRendererId.toSpecification()).deconstruct();
- getRenderer(tokenRendererId.toSpecification()).deconstruct();
+ getRenderer(eventRendererId.toSpecification()).deconstruct();
}
/**
@@ -99,7 +99,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
if (format.stringValue().equals("json")) return getComponent(jsonRendererId);
if (format.stringValue().equals("xml")) return getComponent(xmlRendererId);
if (format.stringValue().equals("page")) return getComponent(pageRendererId);
- if (format.stringValue().equals("token")) return getComponent(tokenRendererId);
+ if (format.stringValue().equals("sse")) return getComponent(eventRendererId);
com.yahoo.processing.rendering.Renderer<Result> renderer = getComponent(format);
if (renderer == null)
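With this mapping in place, the event renderer can be selected per request through the format parameter, for example (host, port and query are placeholders):

    http://localhost:8080/search/?query=what+is+vespa&format=sse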
diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
new file mode 100644
index 00000000000..957a027608e
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
@@ -0,0 +1,66 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.search.result;
+
+import com.yahoo.processing.response.DefaultIncomingData;
+
+/**
+ * A stream of events which can be rendered as Server-Sent Events (SSE).
+ *
+ * @author lesters
+ */
+public class EventStream extends HitGroup {
+
+ private int eventCount = 0;
+
+ public static final String DEFAULT_EVENT_TYPE = "token";
+
+ private EventStream(String id, DefaultIncomingData<Hit> incomingData) {
+ super(id, new Relevance(1), incomingData);
+ this.setOrdered(true); // avoid hit group ordering - events must be kept in insertion order
+ }
+
+ public static EventStream create(String id) {
+ DefaultIncomingData<Hit> incomingData = new DefaultIncomingData<>();
+ EventStream stream = new EventStream(id, incomingData);
+ incomingData.assignOwner(stream);
+ return stream;
+ }
+
+ public void add(String data) {
+ add(data, DEFAULT_EVENT_TYPE);
+ }
+
+ public void add(String data, String type) {
+ incoming().add(new Event(String.valueOf(eventCount + 1), data, type));
+ eventCount++;
+ }
+
+ public void error(String source, ErrorMessage message) {
+ incoming().add(new DefaultErrorHit(source, message));
+ }
+
+ public void markComplete() {
+ incoming().markComplete();
+ }
+
+ public static class Event extends Hit {
+
+ private final String type;
+
+ public Event(String id, String data, String type) {
+ super(id);
+ this.type = type;
+ setField(type, data);
+ }
+
+ public String toString() {
+ return getField(type).toString();
+ }
+
+ public String type() {
+ return type;
+ }
+
+ }
+}
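A minimal usage sketch of the class above (the id and event contents are arbitrary):

    EventStream stream = EventStream.create("events");
    stream.add("Hello");                        // default type "token", rendered as "event: token"
    stream.add("generation finished", "info");  // custom event type, rendered as "event: info"
    stream.error("my-llm", new ErrorMessage(400, "something went wrong"));  // becomes a DefaultErrorHit
    stream.markComplete();                      // signals that no more events will arrive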
diff --git a/container-search/src/main/resources/configdefinitions/local-llm-interface.def b/container-search/src/main/resources/configdefinitions/local-llm-interface.def
new file mode 100755
index 00000000000..5d26471d2fe
--- /dev/null
+++ b/container-search/src/main/resources/configdefinitions/local-llm-interface.def
@@ -0,0 +1,6 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.search.llm
+
+# Something that points to either a url or local path?
+llmfile string default=""
+
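For completeness, a hedged sketch of how this config could be wired to the component in services.xml; the component id, model path and surrounding elements are assumptions, not part of this change:

    <component id="local-llm" class="ai.vespa.search.llm.interfaces.LocalLLMInterface">
        <config name="ai.vespa.search.llm.local-llm-interface">
            <llmfile>models/llama-2-7b.Q4_K_M.gguf</llmfile>
        </config>
    </component>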