aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lester.solbakken@gmail.com>2024-03-08 10:57:27 +0100
committerLester Solbakken <lester.solbakken@gmail.com>2024-03-08 10:57:27 +0100
commit02c1d07d6f6aeb473d7fbc941d4a8e87ed062ffe (patch)
tree76f8cb55371a4b957d7ef27abaa14ebafbb6865f
parent44a8c8eaf8b2f56dd8e89a0ac55917362d751ea8 (diff)
-rw-r--r--cloud-tenant-base-dependencies-enforcer/pom.xml2
-rw-r--r--config-model-fat/pom.xml2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java3
-rw-r--r--container-dependencies-enforcer/pom.xml2
-rw-r--r--container-search/abi-spec.json74
-rw-r--r--container-search/pom.xml8
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java20
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java43
-rw-r--r--container-search/src/main/java/ai/vespa/search/llm/TokenStream.java58
-rw-r--r--container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java91
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java (renamed from container-search/src/main/java/com/yahoo/search/rendering/TokenRenderer.java)31
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java14
-rw-r--r--container-search/src/main/java/com/yahoo/search/result/EventStream.java66
-rwxr-xr-xcontainer-search/src/main/resources/configdefinitions/local-llm-interface.def6
-rw-r--r--container-search/src/test/java/ai/vespa/search/llm/interfaces/LocalLLMInterfaceTest.java124
-rwxr-xr-xcontainer-search/src/test/java/ai/vespa/search/llm/interfaces/OpenAIInterfaceTest.java3
-rw-r--r--container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java (renamed from container-search/src/test/java/com/yahoo/search/rendering/TokenRendererTestCase.java)50
-rw-r--r--vespa-dependencies-enforcer/allowed-maven-dependencies.txt2
-rw-r--r--vespajlib/abi-spec.json3
-rwxr-xr-xvespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java6
20 files changed, 498 insertions, 110 deletions
diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml
index 2cc573bc096..a2ce268dfd3 100644
--- a/cloud-tenant-base-dependencies-enforcer/pom.xml
+++ b/cloud-tenant-base-dependencies-enforcer/pom.xml
@@ -109,6 +109,8 @@
<include>com.yahoo.vespa:vespalog:*:provided</include>
<!-- Vespa test dependencies -->
+ <include>de.kherud:llama:*:*</include>
+ <include>org.jetbrains:annotations:jar:*:*</include>
<include>com.yahoo.vespa:application:*:test</include>
<include>com.yahoo.vespa:cloud-tenant-cd:*:test</include>
<include>com.yahoo.vespa:config-application-package:*:test</include>
diff --git a/config-model-fat/pom.xml b/config-model-fat/pom.xml
index db97d5d2e2c..5fdf015c031 100644
--- a/config-model-fat/pom.xml
+++ b/config-model-fat/pom.xml
@@ -184,6 +184,8 @@
<i>com.yahoo.vespa:vsm:*:*</i>
<!-- 3rd party artifacts embedded -->
+ <i>de.kherud:llama:*:*</i>
+ <i>org.jetbrains:annotations:jar:*:*</i>
<i>aopalliance:aopalliance:*:*</i>
<i>com.google.errorprone:error_prone_annotations:*:*</i>
<i>com.google.guava:failureaccess:*:*</i>
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
index 4510a9d68bc..d1e77b90100 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
@@ -156,7 +156,8 @@ public class PlatformBundles {
"com.yahoo.vespa.streamingvisitors.MetricsSearcher",
"com.yahoo.vespa.streamingvisitors.VdsStreamingSearcher",
"ai.vespa.search.llm.RAGSearcher",
- "ai.vespa.search.llm.interfaces.OpenAIInterface"
+ "ai.vespa.search.llm.interfaces.OpenAIInterface",
+ "ai.vespa.search.llm.interfaces.LocalLLMInterface"
);
}
diff --git a/container-dependencies-enforcer/pom.xml b/container-dependencies-enforcer/pom.xml
index 4f624d0a870..9feea68d8cc 100644
--- a/container-dependencies-enforcer/pom.xml
+++ b/container-dependencies-enforcer/pom.xml
@@ -148,6 +148,8 @@
<include>com.yahoo.vespa:vsm:*:test</include>
<!-- 3rd party test dependencies -->
+ <include>de.kherud:llama:*:*</include>
+ <include>org.jetbrains:annotations:jar:*:*</include>
<include>com.google.code.findbugs:jsr305:${findbugs.vespa.version}:test</include>
<include>com.google.protobuf:protobuf-java:${protobuf.vespa.version}:test</include>
<include>com.ibm.icu:icu4j:${icu4j.vespa.version}:test</include>
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 8227a9351ba..b7cc4887a64 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -9230,6 +9230,65 @@
"public static final java.lang.String[] CONFIG_DEF_SCHEMA"
]
},
+ "ai.vespa.search.llm.LocalLlmInterfaceConfig$Builder" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Builder"
+ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.search.llm.LocalLlmInterfaceConfig)",
+ "public ai.vespa.search.llm.LocalLlmInterfaceConfig$Builder llmfile(java.lang.String)",
+ "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
+ "public final java.lang.String getDefMd5()",
+ "public final java.lang.String getDefName()",
+ "public final java.lang.String getDefNamespace()",
+ "public final boolean getApplyOnRestart()",
+ "public final void setApplyOnRestart(boolean)",
+ "public ai.vespa.search.llm.LocalLlmInterfaceConfig build()"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.search.llm.LocalLlmInterfaceConfig$Producer" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Producer"
+ ],
+ "attributes" : [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods" : [
+ "public abstract void getConfig(ai.vespa.search.llm.LocalLlmInterfaceConfig$Builder)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.search.llm.LocalLlmInterfaceConfig" : {
+ "superClass" : "com.yahoo.config.ConfigInstance",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public static java.lang.String getDefMd5()",
+ "public static java.lang.String getDefName()",
+ "public static java.lang.String getDefNamespace()",
+ "public void <init>(ai.vespa.search.llm.LocalLlmInterfaceConfig$Builder)",
+ "public java.lang.String llmfile()"
+ ],
+ "fields" : [
+ "public static final java.lang.String CONFIG_DEF_MD5",
+ "public static final java.lang.String CONFIG_DEF_NAME",
+ "public static final java.lang.String CONFIG_DEF_NAMESPACE",
+ "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
+ ]
+ },
"ai.vespa.search.llm.RAGSearcher" : {
"superClass" : "ai.vespa.search.llm.LLMSearcher",
"interfaces" : [ ],
@@ -9289,6 +9348,21 @@
],
"fields" : [ ]
},
+ "ai.vespa.search.llm.interfaces.LocalLLMInterface" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "ai.vespa.llm.LanguageModel"
+ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(ai.vespa.search.llm.LocalLlmInterfaceConfig)",
+ "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
+ ],
+ "fields" : [ ]
+ },
"ai.vespa.search.llm.interfaces.OpenAIInterface" : {
"superClass" : "ai.vespa.search.llm.interfaces.LLMInterface",
"interfaces" : [ ],
diff --git a/container-search/pom.xml b/container-search/pom.xml
index 5e7c60d49c3..d6ed94e504f 100644
--- a/container-search/pom.xml
+++ b/container-search/pom.xml
@@ -88,6 +88,14 @@
<scope>provided</scope>
</dependency>
+ <!-- Temporary: this needs to be moved out of here -->
+ <dependency>
+ <groupId>de.kherud</groupId>
+ <artifactId>llama</artifactId>
+ <version>2.3.5</version>
+ <scope>compile</scope>
+ </dependency>
+
<dependency>
<groupId>xerces</groupId>
<artifactId>xercesImpl</artifactId>
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
index 50ea2646f9f..2b1c553a675 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
@@ -7,7 +7,6 @@ import ai.vespa.llm.LanguageModelException;
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
-import ai.vespa.search.llm.LlmSearcherConfig;
import com.yahoo.api.annotations.Beta;
import com.yahoo.component.ComponentId;
import com.yahoo.component.annotation.Inject;
@@ -16,6 +15,7 @@ import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
@@ -42,7 +42,7 @@ public abstract class LLMSearcher extends Searcher {
LLMSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
this.stream = config.stream();
this.languageModelId = config.providerId();
- this.languageModel = findLanguageModel(config.providerId(), languageModels);
+ this.languageModel = findLanguageModel(languageModelId, languageModels);
this.propertyPrefix = config.propertyPrefix();
}
@@ -98,23 +98,27 @@ public abstract class LLMSearcher extends Searcher {
}
private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
- TokenStream tokenStream = TokenStream.create("token_stream");
+ EventStream eventStream = EventStream.create("token_stream");
+
+ if (query.getTrace().getLevel() >= 1) {
+ eventStream.add(prompt.asString(), "prompt");
+ }
languageModel.completeAsync(prompt, options, token -> {
- tokenStream.add(token.text());
+ eventStream.add(token.text());
}).exceptionally(exception -> {
int errorCode = 400;
if (exception instanceof LanguageModelException languageModelException) {
errorCode = languageModelException.code();
}
- tokenStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
- tokenStream.markComplete();
+ eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ eventStream.markComplete();
return Completion.FinishReason.error;
}).thenAccept(finishReason -> {
- tokenStream.markComplete();
+ eventStream.markComplete();
});
- return new Result(query, tokenStream);
+ return new Result(query, eventStream);
}
private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
diff --git a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java
index 92fd904d709..05c6335139b 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java
@@ -20,6 +20,8 @@ public class RAGSearcher extends LLMSearcher {
private static Logger log = Logger.getLogger(RAGSearcher.class.getName());
+ private static final String PROMPT = "prompt";
+
@Inject
public RAGSearcher(LlmSearcherConfig config, ComponentRegistry<LanguageModel> languageModels) {
super(config, languageModels);
@@ -30,24 +32,51 @@ public class RAGSearcher extends LLMSearcher {
public Result search(Query query, Execution execution) {
Result result = execution.search(query);
execution.fill(result);
- return complete(query, buildPrompt(query, result));
+
+ // Todo: Add resulting prompt to the result
+ Prompt prompt = buildPrompt(query, result);
+
+ return complete(query, prompt);
}
protected Prompt buildPrompt(Query query, Result result) {
- String propertyWithPrefix = this.getPropertyPrefix() + ".prompt";
- String prompt = query.properties().getString(propertyWithPrefix);
- if (prompt == null) {
- prompt = "Please provide a summary of the above";
+ String prompt = getPrompt(query);
+
+ // Replace @query with the actual query
+ if (prompt.contains("@query")) {
+ prompt = prompt.replace("@query", query.getModel().getQueryString());
}
- if (!prompt.contains("{context}")) {
+
+ if ( !prompt.contains("{context}")) {
prompt = "{context}\n" + prompt;
}
- // Todo: support system and user prompt
prompt = prompt.replace("{context}", buildContext(result));
log.info("Prompt: " + prompt); // remove
return StringPrompt.from(prompt);
}
+ private String getPrompt(Query query) {
+ // First, check if prompt is set with a prefix
+ String propertyWithPrefix = this.getPropertyPrefix() + "." + PROMPT;
+ String prompt = query.properties().getString(propertyWithPrefix);
+ if (prompt != null)
+ return prompt;
+
+ // If not, try without prefix
+ prompt = query.properties().getString(PROMPT);
+ if (prompt != null)
+ return prompt;
+
+ // If not, use query
+ prompt = query.getModel().getQueryString();
+ if (prompt != null)
+ return prompt;
+
+ // If not, throw exception
+ throw new IllegalArgumentException("RAG searcher could not find prompt found for query. Tried looking for " +
+ "'" + propertyWithPrefix + "." + PROMPT + "', '" + PROMPT + "' or '@query'.");
+ }
+
private String buildContext(Result result) {
StringBuilder sb = new StringBuilder();
var hits = result.hits();
diff --git a/container-search/src/main/java/ai/vespa/search/llm/TokenStream.java b/container-search/src/main/java/ai/vespa/search/llm/TokenStream.java
deleted file mode 100644
index bd3cda5a6dc..00000000000
--- a/container-search/src/main/java/ai/vespa/search/llm/TokenStream.java
+++ /dev/null
@@ -1,58 +0,0 @@
-package ai.vespa.search.llm;
-
-import com.yahoo.processing.response.DefaultIncomingData;
-import com.yahoo.search.result.DefaultErrorHit;
-import com.yahoo.search.result.ErrorMessage;
-import com.yahoo.search.result.Hit;
-import com.yahoo.search.result.HitGroup;
-import com.yahoo.search.result.Relevance;
-
-public class TokenStream extends HitGroup {
-
- private int tokenCount = 0;
-
- private TokenStream(String id, DefaultIncomingData<Hit> incomingData) {
- super(id, new Relevance(1), incomingData);
- this.setOrdered(true); // avoid hit group ordering - important for errors
- }
-
- public static TokenStream create(String id) {
- DefaultIncomingData<Hit> incomingData = new DefaultIncomingData<>();
- TokenStream stream = new TokenStream(id, incomingData);
- incomingData.assignOwner(stream);
- return stream;
- }
-
- public static HitGroup createAsync(String id) {
- return create(id);
- }
-
- public void add(String token) {
- incoming().add(new Token(String.valueOf(++tokenCount), token));
- }
-
- public void error(String source, ErrorMessage message) {
- incoming().add(new DefaultErrorHit(source, message));
- }
-
- public void markComplete() {
- incoming().markComplete();
- }
-
- public static class Token extends Hit {
-
- public Token(String token) {
- this("", token);
- }
-
- public Token(String id, String token) {
- super(id);
- setField("token", token);
- }
-
- public String toString() {
- return getField("token").toString();
- }
-
- }
-}
diff --git a/container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java b/container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java
new file mode 100644
index 00000000000..dd9ed01bd73
--- /dev/null
+++ b/container-search/src/main/java/ai/vespa/search/llm/interfaces/LocalLLMInterface.java
@@ -0,0 +1,91 @@
+package ai.vespa.search.llm.interfaces;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.search.llm.LocalLlmInterfaceConfig;
+
+import ai.vespa.util.http.hc4.retry.Sleeper;
+import com.yahoo.component.annotation.Inject;
+import de.kherud.llama.LlamaModel;
+import de.kherud.llama.LogLevel;
+import de.kherud.llama.ModelParameters;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.function.Consumer;
+import java.util.logging.Logger;
+
+public class LocalLLMInterface implements LanguageModel {
+
+ private static Logger logger = Logger.getLogger(LocalLLMInterface.class.getName());
+ private final LlamaModel model;
+ private final ExecutorService executor;
+
+ @Inject
+ public LocalLLMInterface(LocalLlmInterfaceConfig config) {
+ this(config, Executors.newFixedThreadPool(1)); // until we can run llama.cpp in batch
+ }
+
+ LocalLLMInterface(LocalLlmInterfaceConfig config, ExecutorService executor) {
+ this.executor = executor;
+
+ LlamaModel.setLogger(this::log);
+ var modelParams = new ModelParameters()
+ // Todo: retrieve from config
+ ;
+
+ long startLoad = System.nanoTime();
+ model = new LlamaModel(config.llmfile(), modelParams);
+ long loadTime = System.nanoTime() - startLoad;
+ logger.info("Loaded model " + config.llmfile() + " in " + (loadTime*1.0/1000000000) + " sec");
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters options) {
+ StringBuilder result = new StringBuilder();
+ var future = completeAsync(prompt, options, completion -> {
+ result.append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+ var reason = future.join();
+
+ List<Completion> completions = new ArrayList<>();
+ completions.add(new Completion(result.toString(), reason));
+ return completions;
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
+ var inferParams = new de.kherud.llama.InferenceParameters();
+ options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v)));
+ options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v)));
+ options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v)));
+ options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v)));
+ options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v)));
+ // Todo: add more
+
+ var completionFuture = new CompletableFuture<Completion.FinishReason>();
+ executor.submit(() -> {
+ for (LlamaModel.Output output : model.generate(prompt.asString(), inferParams)) {
+ consumer.accept(Completion.from(output.text, Completion.FinishReason.none));
+ }
+ completionFuture.complete(Completion.FinishReason.stop);
+ });
+
+ return completionFuture;
+ }
+
+ private void log(LogLevel level, String message) {
+ switch (level) {
+ case WARN -> logger.warning(message);
+ case DEBUG -> logger.fine(message);
+ case ERROR -> logger.severe(message);
+ default -> logger.info(message);
+ }
+ }
+
+}
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/TokenRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
index 0b0c018e1ac..4cca0fcce9f 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/TokenRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -1,6 +1,6 @@
package com.yahoo.search.rendering;
-import ai.vespa.search.llm.TokenStream;
+import com.yahoo.search.result.EventStream;
import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonFactoryBuilder;
@@ -18,19 +18,16 @@ import com.yahoo.search.result.ErrorMessage;
import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.Executor;
-import java.util.logging.Logger;
import static com.fasterxml.jackson.databind.SerializationFeature.FLUSH_AFTER_WRITE_VALUE;
/**
- *
- * A comment about SSE
+ * A Server-Sent Events (SSE) renderer for asynchronous events such as
+ * tokens from a language model.
*
* @author lesters
*/
-public class TokenRenderer extends AsynchronousSectionedRenderer<Result> {
-
- private static final Logger log = Logger.getLogger(TokenRenderer.class.getName());
+public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
private static final JsonFactory generatorFactory = createGeneratorFactory();
private volatile JsonGenerator generator;
@@ -43,14 +40,14 @@ public class TokenRenderer extends AsynchronousSectionedRenderer<Result> {
return factory;
}
- private static final boolean RENDER_TOKEN_EVENT_HEADER = true;
+ private static final boolean RENDER_EVENT_HEADER = true;
private static final boolean RENDER_END_EVENT = true;
- public TokenRenderer() {
+ public EventRenderer() {
this(null);
}
- public TokenRenderer(Executor executor) {
+ public EventRenderer(Executor executor) {
super(executor);
}
@@ -62,21 +59,21 @@ public class TokenRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void beginList(DataList<?> dataList) throws IOException {
- if ( ! (dataList instanceof TokenStream)) {
- throw new IllegalArgumentException("TokenRenderer currently only supports TokenStreams");
+ if ( ! (dataList instanceof EventStream)) {
+ throw new IllegalArgumentException("EventRenderer currently only supports EventStreams");
// Todo: support results and timing and trace by delegating to JsonRenderer
}
}
@Override
public void data(Data data) throws IOException {
- if (data instanceof TokenStream.Token token) {
- if (RENDER_TOKEN_EVENT_HEADER) {
- generator.writeRaw("event: token\n");
+ if (data instanceof EventStream.Event event) {
+ if (RENDER_EVENT_HEADER) {
+ generator.writeRaw("event: " + event.type() + "\n");
}
generator.writeRaw("data: ");
generator.writeStartObject();
- generator.writeStringField("token", token.toString());
+ generator.writeStringField(event.type(), event.toString());
generator.writeEndObject();
generator.writeRaw("\n\n");
generator.flush();
@@ -115,7 +112,7 @@ public class TokenRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public String getMimeType() {
- return "application/json";
+ return "text/event-stream";
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
index ad6875b6c9b..d62860afcda 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
@@ -24,7 +24,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
public static final ComponentId xmlRendererId = ComponentId.fromString("XmlRenderer");
public static final ComponentId pageRendererId = ComponentId.fromString("PageTemplatesXmlRenderer");
public static final ComponentId jsonRendererId = ComponentId.fromString("JsonRenderer");
- public static final ComponentId tokenRendererId = ComponentId.fromString("TokenRenderer");
+ public static final ComponentId eventRendererId = ComponentId.fromString("EventRenderer");
public static final ComponentId defaultRendererId = jsonRendererId;
@@ -57,10 +57,10 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
pageRenderer.initId(pageRendererId);
register(pageRenderer.getId(), pageRenderer);
- // Add token renderer
- Renderer tokenRenderer = new TokenRenderer(executor);
- tokenRenderer.initId(tokenRendererId);
- register(tokenRenderer.getId(), tokenRenderer);
+ // Add event renderer
+ Renderer eventRenderer = new EventRenderer(executor);
+ eventRenderer.initId(eventRendererId);
+ register(eventRenderer.getId(), eventRenderer);
// add application renderers
for (Renderer renderer : renderers)
@@ -75,7 +75,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
getRenderer(jsonRendererId.toSpecification()).deconstruct();
getRenderer(xmlRendererId.toSpecification()).deconstruct();
getRenderer(pageRendererId.toSpecification()).deconstruct();
- getRenderer(tokenRendererId.toSpecification()).deconstruct();
+ getRenderer(eventRendererId.toSpecification()).deconstruct();
}
/**
@@ -99,7 +99,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
if (format.stringValue().equals("json")) return getComponent(jsonRendererId);
if (format.stringValue().equals("xml")) return getComponent(xmlRendererId);
if (format.stringValue().equals("page")) return getComponent(pageRendererId);
- if (format.stringValue().equals("token")) return getComponent(tokenRendererId);
+ if (format.stringValue().equals("sse")) return getComponent(eventRendererId);
com.yahoo.processing.rendering.Renderer<Result> renderer = getComponent(format);
if (renderer == null)
diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
new file mode 100644
index 00000000000..957a027608e
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
@@ -0,0 +1,66 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.search.result;
+
+import com.yahoo.processing.response.DefaultIncomingData;
+
+/**
+ * A stream of events which can be rendered as Server-Sent Events (SSE).
+ *
+ * @author lesters
+ */
+public class EventStream extends HitGroup {
+
+ private int eventCount = 0;
+
+ public static final String DEFAULT_EVENT_TYPE = "token";
+
+ private EventStream(String id, DefaultIncomingData<Hit> incomingData) {
+ super(id, new Relevance(1), incomingData);
+ this.setOrdered(true); // avoid hit group ordering - important as sequence as inserted should be kept
+ }
+
+ public static EventStream create(String id) {
+ DefaultIncomingData<Hit> incomingData = new DefaultIncomingData<>();
+ EventStream stream = new EventStream(id, incomingData);
+ incomingData.assignOwner(stream);
+ return stream;
+ }
+
+ public void add(String data) {
+ add(data, DEFAULT_EVENT_TYPE);
+ }
+
+ public void add(String data, String type) {
+ incoming().add(new Event(String.valueOf(eventCount + 1), data, type));
+ eventCount++;
+ }
+
+ public void error(String source, ErrorMessage message) {
+ incoming().add(new DefaultErrorHit(source, message));
+ }
+
+ public void markComplete() {
+ incoming().markComplete();
+ }
+
+ public static class Event extends Hit {
+
+ private final String type;
+
+ public Event(String id, String data, String type) {
+ super(id);
+ this.type = type;
+ setField(type, data);
+ }
+
+ public String toString() {
+ return getField(type).toString();
+ }
+
+ public String type() {
+ return type;
+ }
+
+ }
+}
diff --git a/container-search/src/main/resources/configdefinitions/local-llm-interface.def b/container-search/src/main/resources/configdefinitions/local-llm-interface.def
new file mode 100755
index 00000000000..5d26471d2fe
--- /dev/null
+++ b/container-search/src/main/resources/configdefinitions/local-llm-interface.def
@@ -0,0 +1,6 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.search.llm
+
+# Something that points to either a url or local path?
+llmfile string default=""
+
diff --git a/container-search/src/test/java/ai/vespa/search/llm/interfaces/LocalLLMInterfaceTest.java b/container-search/src/test/java/ai/vespa/search/llm/interfaces/LocalLLMInterfaceTest.java
new file mode 100644
index 00000000000..34ac01aa6ec
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/search/llm/interfaces/LocalLLMInterfaceTest.java
@@ -0,0 +1,124 @@
+package ai.vespa.search.llm.interfaces;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import ai.vespa.search.llm.LocalLlmInterfaceConfig;
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class LocalLLMInterfaceTest {
+
+ private static String model = "/Users/lesters/dev/openai/models/mistral-7b-instruct-v0.1.Q8_0.gguf";
+ private static Prompt prompt = StringPrompt.from("Why are ducks better than cats? Be concise, " +
+ "but use the word 'spoon' somewhere in your answer.");
+
+ @Test
+ public void testGeneration() {
+ var result = createLLM(model).complete(prompt, defaultOptions());
+ assertEquals(Completion.FinishReason.stop, result.get(0).finishReason());
+ assertTrue(result.get(0).text().contains("spoon"));
+ }
+
+ @Test
+ public void testAsyncGeneration() {
+ var executor = Executors.newFixedThreadPool(1);
+ var sb = new StringBuilder();
+ Prompt prompt = StringPrompt.from("sddocname: passage\n" +
+ "id: 2\n" +
+ "text: Essay on The Manhattan Project - The Manhattan Project The Manhattan Project was to see if making an atomic bomb possible. The success of this project would forever change the world forever making it known that something this powerful can be manmade.\n" +
+ "documentid: id:msmarco:passage::2\n" +
+ "\n" +
+ "sddocname: passage\n" +
+ "id: 0\n" +
+ "text: The presence of communication amid scientific minds was equally important to the success of the Manhattan Project as scientific intellect was. The only cloud hanging over the impressive achievement of the atomic researchers and engineers is what their success truly meant; hundreds of thousands of innocent lives obliterated.\n" +
+ "documentid: id:msmarco:passage::0\n" +
+ "\n" +
+ "sddocname: passage\n" +
+ "id: 7\n" +
+ "text: Manhattan Project. The Manhattan Project was a research and development undertaking during World War II that produced the first nuclear weapons. It was led by the United States with the support of the United Kingdom and Canada. From 1942 to 1946, the project was under the direction of Major General Leslie Groves of the U.S. Army Corps of Engineers. Nuclear physicist Robert Oppenheimer was the director of the Los Alamos Laboratory that designed the actual bombs. The Army component of the project was designated the\n" +
+ "documentid: id:msmarco:passage::7\n" +
+ "\n" +
+ "sddocname: passage\n" +
+ "id: 3\n" +
+ "text: The Manhattan Project was the name for a project conducted during World War II, to develop the first atomic bomb. It refers specifically to the period of the project from 194 … 2-1946 under the control of the U.S. Army Corps of Engineers, under the administration of General Leslie R. Groves.\n" +
+ "documentid: id:msmarco:passage::3\n" +
+ "\n" +
+ "sddocname: passage\n" +
+ "id: 9\n" +
+ "text: One of the main reasons Hanford was selected as a site for the Manhattan Project's B Reactor was its proximity to the Columbia River, the largest river flowing into the Pacific Ocean from the North American coast.\n" +
+ "documentid: id:msmarco:passage::9\n" +
+ "\n" +
+ "sddocname: passage\n" +
+ "id: 5\n" +
+ "text: The Manhattan Project. This once classified photograph features the first atomic bomb — a weapon that atomic scientists had nicknamed Gadget.. The nuclear age began on July 16, 1945, when it was detonated in the New Mexico desert.\n" +
+ "documentid: id:msmarco:passage::5\n" +
+ "\n" +
+ "sddocname: passage\n" +
+ "id: 8\n" +
+ "text: In June 1942, the United States Army Corps of Engineersbegan the Manhattan Project- The secret name for the 2 atomic bombs.\n" +
+ "documentid: id:msmarco:passage::8\n" +
+ "\n" +
+ "sddocname: passage\n" +
+ "id: 1\n" +
+ "text: The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science.\n" +
+ "documentid: id:msmarco:passage::1\n" +
+ "\n" +
+ "sddocname: passage\n" +
+ "id: 6\n" +
+ "text: Nor will it attempt to substitute for the extraordinarily rich literature on the atomic bombs and the end of World War II. This collection does not attempt to document the origins and development of the Manhattan Project.\n" +
+ "documentid: id:msmarco:passage::6\n" +
+ "\n" +
+ "sddocname: passage\n" +
+ "id: 4\n" +
+ "text: versions of each volume as well as complementary websites. The first website–The Manhattan Project: An Interactive History–is available on the Office of History and Heritage Resources website, http://www.cfo. doe.gov/me70/history. The Office of History and Heritage Resources and the National Nuclear Security\n" +
+ "documentid: id:msmarco:passage::4\n" +
+ "\n" +
+ "\n" +
+// "Given the documents above, what was the id of the last passage given?");
+ "Rank the documents above according to their relevance to the Manhattan Project. Answer using a json structure.");
+ try {
+ var future = createLLM(model, executor).completeAsync(prompt, defaultOptions(), completion -> {
+ sb.append(completion.text());
+ System.out.println(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+
+ assertFalse(future.isDone());
+ var reason = future.join();
+ assertTrue(future.isDone());
+ assertNotEquals(reason, Completion.FinishReason.error);
+ } finally {
+ executor.shutdownNow();
+ }
+ System.out.println(sb);
+ assertTrue(sb.toString().contains("spoon"));
+ }
+
+ private static InferenceParameters defaultOptions() {
+ final Map<String, String> options = Map.of(
+ "temperature", "0.0",
+ "npredict", "10"
+ );
+ return new InferenceParameters(options::get);
+ }
+
+ private static LocalLLMInterface createLLM(String modelPath) {
+ var config = new LocalLlmInterfaceConfig.Builder().llmfile(modelPath).build();
+ return new LocalLLMInterface(config);
+ }
+
+ private static LocalLLMInterface createLLM(String modelPath, ExecutorService executor) {
+ var config = new LocalLlmInterfaceConfig.Builder().llmfile(modelPath).build();
+ return new LocalLLMInterface(config, executor);
+ }
+}
diff --git a/container-search/src/test/java/ai/vespa/search/llm/interfaces/OpenAIInterfaceTest.java b/container-search/src/test/java/ai/vespa/search/llm/interfaces/OpenAIInterfaceTest.java
index 7e3ee59a85c..1386bea1a59 100755
--- a/container-search/src/test/java/ai/vespa/search/llm/interfaces/OpenAIInterfaceTest.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/interfaces/OpenAIInterfaceTest.java
@@ -30,7 +30,8 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
// This should be a test for the language model, not openai searchers
public class OpenAIInterfaceTest {
- private static final String apiKey = "foobar";
+ private static final String apiKey = "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";
+// private static final String apiKey = "foobar";
// Change only to interface actually
diff --git a/container-search/src/test/java/com/yahoo/search/rendering/TokenRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
index b5a10a934fd..cf18aa96d18 100644
--- a/container-search/src/test/java/com/yahoo/search/rendering/TokenRendererTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
@@ -1,6 +1,6 @@
package com.yahoo.search.rendering;
-import ai.vespa.search.llm.TokenStream;
+import com.yahoo.search.result.EventStream;
import com.yahoo.concurrent.ThreadFactoryFactory;
import com.yahoo.document.DocumentId;
import com.yahoo.search.Query;
@@ -35,24 +35,24 @@ import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
-public class TokenRendererTestCase {
+public class EventRendererTestCase {
private static ThreadPoolExecutor executor;
- private static TokenRenderer blueprint;
- private TokenRenderer renderer;
+ private static EventRenderer blueprint;
+ private EventRenderer renderer;
@BeforeAll
public static void createExecutorAndBlueprint() {
ThreadFactory threadFactory = ThreadFactoryFactory.getThreadFactory("test-rendering");
executor = new ThreadPoolExecutor(4, 4, 1L, TimeUnit.MINUTES, new LinkedBlockingQueue<>(), threadFactory);
executor.prestartAllCoreThreads();
- blueprint = new TokenRenderer(executor);
+ blueprint = new EventRenderer(executor);
}
@BeforeEach
public void createClone() {
// Use the shared renderer as a prototype object, as specified in the API contract
- renderer = (TokenRenderer) blueprint.clone();
+ renderer = (EventRenderer) blueprint.clone();
renderer.init();
}
@@ -96,7 +96,7 @@ public class TokenRendererTestCase {
event: end
""";
- var tokenStream = TokenStream.create("token_stream");
+ var tokenStream = EventStream.create("token_stream");
for (String token : splitter("Ducks have adorable waddling walks")) {
tokenStream.add(token);
}
@@ -129,7 +129,7 @@ public class TokenRendererTestCase {
var result = "";
var executor = Executors.newFixedThreadPool(1);
try {
- var tokenStream = TokenStream.create("token_stream");
+ var tokenStream = EventStream.create("token_stream");
var future = completeAsync("Ducks have adorable waddling walks", executor, token -> {
tokenStream.add(token);
}).exceptionally(e -> {
@@ -151,7 +151,7 @@ public class TokenRendererTestCase {
@Test
public void testErrorEndsStream() throws ExecutionException, InterruptedException {
- var tokenStream = TokenStream.create("token_stream");
+ var tokenStream = EventStream.create("token_stream");
tokenStream.add("token1");
tokenStream.add("token2");
tokenStream.error("my_llm", new ErrorMessage(400, "Something went wrong"));
@@ -173,9 +173,39 @@ public class TokenRendererTestCase {
}
@Test
+ public void testPromptRendering() throws ExecutionException, InterruptedException {
+ String prompt = "Why are ducks better than cats?\n\nBe concise.\n";
+
+ var tokenStream = EventStream.create("token_stream");
+ tokenStream.add(prompt, "prompt");
+ tokenStream.add("Just");
+ tokenStream.add(" because");
+ tokenStream.add(".");
+ tokenStream.markComplete();
+ var result = render(new Result(new Query(), tokenStream));
+
+ var expected = """
+ event: prompt
+ data: {"prompt":"Why are ducks better than cats?\\n\\nBe concise.\\n"}
+
+ event: token
+ data: {"token":"Just"}
+
+ event: token
+ data: {"token":" because"}
+
+ event: token
+ data: {"token":"."}
+
+ event: end
+ """;
+ assertEquals(expected, result);
+ }
+
+ @Test
@Timeout(5)
public void testResultRenderingFails() {
- var tokenStream = TokenStream.create("token_stream");
+ var tokenStream = EventStream.create("token_stream");
tokenStream.add("token1");
tokenStream.add("token2");
tokenStream.markComplete();
diff --git a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt
index 9b1bc20fd7d..fbe0713370a 100644
--- a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt
+++ b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt
@@ -45,6 +45,7 @@ commons-cli:commons-cli:${commons-cli.vespa.version}
commons-codec:commons-codec:${commons-codec.vespa.version}
commons-io:commons-io:${commons-io.vespa.version}
commons-logging:commons-logging:${commons-logging.vespa.version}
+de.kherud:llama:2.3.5
io.airlift:aircompressor:${aircompressor.vespa.version}
io.airlift:airline:${airline.vespa.version}
io.dropwizard.metrics:metrics-core:${dropwizard.metrics.vespa.version}
@@ -169,6 +170,7 @@ org.hamcrest:hamcrest-core:${hamcrest.vespa.version}
org.hamcrest:hamcrest:${hamcrest.vespa.version}
org.hdrhistogram:HdrHistogram:${hdrhistogram.vespa.version}
org.iq80.snappy:snappy:0.4
+org.jetbrains:annotations:24.0.1
org.json:json:${org.json.vespa.version}
org.junit.jupiter:junit-jupiter-api:${junit.vespa.tenant.version}
org.junit.jupiter:junit-jupiter-api:${junit.vespa.version}
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 0c8f8e7d941..cf761a4d4f1 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -4089,7 +4089,8 @@
"public java.util.Optional getApiKey()",
"public java.util.Optional get(java.lang.String)",
"public java.util.Optional getDouble(java.lang.String)",
- "public java.util.Optional getInt(java.lang.String)"
+ "public java.util.Optional getInt(java.lang.String)",
+ "public void ifPresent(java.lang.String, java.util.function.Consumer)"
],
"fields" : [ ]
},
diff --git a/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
index 2ef1b97a8b2..8c47ae92e38 100755
--- a/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
@@ -4,7 +4,9 @@ import ai.vespa.llm.completion.Prompt;
import java.util.Objects;
import java.util.Optional;
+import java.util.function.Consumer;
import java.util.function.Function;
+import java.util.function.Supplier;
public class InferenceParameters {
@@ -44,5 +46,9 @@ public class InferenceParameters {
}
}
+ public void ifPresent(String option, Consumer<String> func) {
+ get(option).ifPresent(func);
+ }
+
}