summaryrefslogtreecommitdiffstats
path: root/container-search/src
diff options
context:
space:
mode:
authorLester Solbakken <lester.solbakken@gmail.com>2024-03-08 13:19:33 +0100
committerLester Solbakken <lester.solbakken@gmail.com>2024-03-08 13:19:33 +0100
commit25efeef096cb30090c6d1ed0bd804f6a0745adc0 (patch)
treea6d7023f4e5a3df62e6d69d6551be0052ece08af /container-search/src
parent3ede5019a6fe0881917b165166f413c532fe4bc0 (diff)
Add server-sent events (SSE) renderer
Diffstat (limited to 'container-search/src')
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java119
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java8
-rw-r--r--container-search/src/main/java/com/yahoo/search/result/EventStream.java65
-rw-r--r--container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java279
4 files changed, 471 insertions, 0 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
new file mode 100644
index 00000000000..46f9a53e698
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -0,0 +1,119 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.rendering;
+
+import com.yahoo.search.result.EventStream;
+import com.fasterxml.jackson.core.JsonEncoding;
+import com.fasterxml.jackson.core.JsonFactory;
+import com.fasterxml.jackson.core.JsonFactoryBuilder;
+import com.fasterxml.jackson.core.JsonGenerator;
+import com.fasterxml.jackson.core.StreamReadConstraints;
+import com.fasterxml.jackson.core.io.SerializedString;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.processing.rendering.AsynchronousSectionedRenderer;
+import com.yahoo.processing.response.Data;
+import com.yahoo.processing.response.DataList;
+import com.yahoo.search.Result;
+import com.yahoo.search.result.ErrorHit;
+import com.yahoo.search.result.ErrorMessage;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.concurrent.Executor;
+
+import static com.fasterxml.jackson.databind.SerializationFeature.FLUSH_AFTER_WRITE_VALUE;
+
+/**
+ * A Server-Sent Events (SSE) renderer for asynchronous events such as
+ * tokens from a language model.
+ *
+ * @author lesters
+ */
+public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
+
+ private static final JsonFactory generatorFactory = createGeneratorFactory();
+ private volatile JsonGenerator generator;
+
+ private static JsonFactory createGeneratorFactory() {
+ var factory = new JsonFactoryBuilder()
+ .streamReadConstraints(StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build())
+ .build();
+ factory.setCodec(new ObjectMapper(factory).disable(FLUSH_AFTER_WRITE_VALUE));
+ return factory;
+ }
+
+ private static final boolean RENDER_EVENT_HEADER = true;
+ private static final boolean RENDER_END_EVENT = true;
+
+ public EventRenderer() {
+ this(null);
+ }
+
+ public EventRenderer(Executor executor) {
+ super(executor);
+ }
+
+ @Override
+ public void beginResponse(OutputStream outputStream) throws IOException {
+ generator = generatorFactory.createGenerator(outputStream, JsonEncoding.UTF8);
+ generator.setRootValueSeparator(new SerializedString(""));
+ }
+
+ @Override
+ public void beginList(DataList<?> dataList) throws IOException {
+ if ( ! (dataList instanceof EventStream)) {
+ throw new IllegalArgumentException("EventRenderer currently only supports EventStreams");
+ // Todo: support results and timing and trace by delegating to JsonRenderer
+ }
+ }
+
+ @Override
+ public void data(Data data) throws IOException {
+ if (data instanceof EventStream.Event event) {
+ if (RENDER_EVENT_HEADER) {
+ generator.writeRaw("event: " + event.type() + "\n");
+ }
+ generator.writeRaw("data: ");
+ generator.writeStartObject();
+ generator.writeStringField(event.type(), event.toString());
+ generator.writeEndObject();
+ generator.writeRaw("\n\n");
+ generator.flush();
+ }
+ else if (data instanceof ErrorHit) {
+ for (ErrorMessage error : ((ErrorHit) data).errors()) {
+ generator.writeRaw("event: error\n");
+ generator.writeRaw("data: ");
+ generator.writeStartObject();
+ generator.writeStringField("source", error.getSource());
+ generator.writeNumberField("error", error.getCode());
+ generator.writeStringField("message", error.getMessage());
+ generator.writeEndObject();
+ generator.writeRaw("\n\n");
+ generator.flush();
+ }
+ }
+ }
+
+ @Override
+ public void endList(DataList<?> dataList) throws IOException {
+ }
+
+ @Override
+ public void endResponse() throws IOException {
+ if (RENDER_END_EVENT) {
+ generator.writeRaw("event: end\n");
+ }
+ generator.close();
+ }
+
+ @Override
+ public String getEncoding() {
+ return "utf-8";
+ }
+
+ @Override
+ public String getMimeType() {
+ return "text/event-stream";
+ }
+
+}
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
index 3287a61c81f..d62860afcda 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
@@ -24,6 +24,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
public static final ComponentId xmlRendererId = ComponentId.fromString("XmlRenderer");
public static final ComponentId pageRendererId = ComponentId.fromString("PageTemplatesXmlRenderer");
public static final ComponentId jsonRendererId = ComponentId.fromString("JsonRenderer");
+ public static final ComponentId eventRendererId = ComponentId.fromString("EventRenderer");
public static final ComponentId defaultRendererId = jsonRendererId;
@@ -56,6 +57,11 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
pageRenderer.initId(pageRendererId);
register(pageRenderer.getId(), pageRenderer);
+ // Add event renderer
+ Renderer eventRenderer = new EventRenderer(executor);
+ eventRenderer.initId(eventRendererId);
+ register(eventRenderer.getId(), eventRenderer);
+
// add application renderers
for (Renderer renderer : renderers)
register(renderer.getId(), renderer);
@@ -69,6 +75,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
getRenderer(jsonRendererId.toSpecification()).deconstruct();
getRenderer(xmlRendererId.toSpecification()).deconstruct();
getRenderer(pageRendererId.toSpecification()).deconstruct();
+ getRenderer(eventRendererId.toSpecification()).deconstruct();
}
/**
@@ -92,6 +99,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
if (format.stringValue().equals("json")) return getComponent(jsonRendererId);
if (format.stringValue().equals("xml")) return getComponent(xmlRendererId);
if (format.stringValue().equals("page")) return getComponent(pageRendererId);
+ if (format.stringValue().equals("sse")) return getComponent(eventRendererId);
com.yahoo.processing.rendering.Renderer<Result> renderer = getComponent(format);
if (renderer == null)
diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
new file mode 100644
index 00000000000..84ef9a8ee86
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
@@ -0,0 +1,65 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.result;
+
+import com.yahoo.processing.response.DefaultIncomingData;
+
+/**
+ * A stream of events which can be rendered as Server-Sent Events (SSE).
+ *
+ * @author lesters
+ */
+public class EventStream extends HitGroup {
+
+ private int eventCount = 0;
+
+ public static final String DEFAULT_EVENT_TYPE = "token";
+
+ private EventStream(String id, DefaultIncomingData<Hit> incomingData) {
+ super(id, new Relevance(1), incomingData);
+ this.setOrdered(true); // avoid hit group ordering - important as sequence as inserted should be kept
+ }
+
+ public static EventStream create(String id) {
+ DefaultIncomingData<Hit> incomingData = new DefaultIncomingData<>();
+ EventStream stream = new EventStream(id, incomingData);
+ incomingData.assignOwner(stream);
+ return stream;
+ }
+
+ public void add(String data) {
+ add(data, DEFAULT_EVENT_TYPE);
+ }
+
+ public void add(String data, String type) {
+ incoming().add(new Event(String.valueOf(eventCount + 1), data, type));
+ eventCount++;
+ }
+
+ public void error(String source, ErrorMessage message) {
+ incoming().add(new DefaultErrorHit(source, message));
+ }
+
+ public void markComplete() {
+ incoming().markComplete();
+ }
+
+ public static class Event extends Hit {
+
+ private final String type;
+
+ public Event(String id, String data, String type) {
+ super(id);
+ this.type = type;
+ setField(type, data);
+ }
+
+ public String toString() {
+ return getField(type).toString();
+ }
+
+ public String type() {
+ return type;
+ }
+
+ }
+}
diff --git a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
new file mode 100644
index 00000000000..9ebe6d048b3
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
@@ -0,0 +1,279 @@
+
+package com.yahoo.search.rendering;
+
+import com.yahoo.search.result.EventStream;
+import com.yahoo.concurrent.ThreadFactoryFactory;
+import com.yahoo.document.DocumentId;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.Hit;
+import com.yahoo.search.result.HitGroup;
+import com.yahoo.search.searchchain.Execution;
+import com.yahoo.text.Utf8;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import java.io.ByteArrayOutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class EventRendererTestCase {
+
+ private static ThreadPoolExecutor executor;
+ private static EventRenderer blueprint;
+ private EventRenderer renderer;
+
+ @BeforeAll
+ public static void createExecutorAndBlueprint() {
+ ThreadFactory threadFactory = ThreadFactoryFactory.getThreadFactory("test-rendering");
+ executor = new ThreadPoolExecutor(4, 4, 1L, TimeUnit.MINUTES, new LinkedBlockingQueue<>(), threadFactory);
+ executor.prestartAllCoreThreads();
+ blueprint = new EventRenderer(executor);
+ }
+
+ @BeforeEach
+ public void createClone() {
+ // Use the shared renderer as a prototype object, as specified in the API contract
+ renderer = (EventRenderer) blueprint.clone();
+ renderer.init();
+ }
+
+ @AfterEach
+ public void deconstructClone() {
+ if (renderer != null) {
+ renderer.deconstruct();
+ renderer = null;
+ }
+ }
+
+ @AfterAll
+ public static void deconstructBlueprintAndExecutor() throws InterruptedException {
+ blueprint.deconstruct();
+ blueprint = null;
+ executor.shutdown();
+ if (!executor.awaitTermination(1, TimeUnit.MINUTES)) {
+ throw new RuntimeException("Failed to shutdown executor");
+ }
+ executor = null;
+ }
+
+ @Test
+ @Timeout(5)
+ public void testRendering() throws InterruptedException, ExecutionException {
+ var expected = """
+ event: token
+ data: {"token":"Ducks"}
+
+ event: token
+ data: {"token":" have"}
+
+ event: token
+ data: {"token":" adorable"}
+
+ event: token
+ data: {"token":" waddling"}
+
+ event: token
+ data: {"token":" walks"}
+
+ event: end
+ """;
+ var tokenStream = EventStream.create("token_stream");
+ for (String token : splitter("Ducks have adorable waddling walks")) {
+ tokenStream.add(token);
+ }
+ tokenStream.markComplete();
+ var result = render(new Result(new Query(), tokenStream));
+ assertEquals(expected, result);
+ }
+
+ @Test
+ @Timeout(5)
+ public void testAsyncRendering() throws InterruptedException, ExecutionException {
+ var expected = """
+ event: token
+ data: {"token":"Ducks"}
+
+ event: token
+ data: {"token":" have"}
+
+ event: token
+ data: {"token":" adorable"}
+
+ event: token
+ data: {"token":" waddling"}
+
+ event: token
+ data: {"token":" walks"}
+
+ event: end
+ """;
+ var result = "";
+ var executor = Executors.newFixedThreadPool(1);
+ try {
+ var tokenStream = EventStream.create("token_stream");
+ var future = completeAsync("Ducks have adorable waddling walks", executor, token -> {
+ tokenStream.add(token);
+ }).exceptionally(e -> {
+ tokenStream.error("error", new ErrorMessage(400, e.getMessage()));
+ tokenStream.markComplete();
+ return false;
+ }).thenAccept(finishReason -> {
+ tokenStream.markComplete();
+ });
+ assertFalse(future.isDone());
+ result = render(new Result(new Query(), tokenStream));
+ assertTrue(future.isDone()); // Renderer waits for async completion
+
+ } finally {
+ executor.shutdownNow();
+ }
+ assertEquals(expected, result);
+ }
+
+ @Test
+ public void testErrorEndsStream() throws ExecutionException, InterruptedException {
+ var tokenStream = EventStream.create("token_stream");
+ tokenStream.add("token1");
+ tokenStream.add("token2");
+ tokenStream.error("my_llm", new ErrorMessage(400, "Something went wrong"));
+ tokenStream.markComplete();
+ var result = render(new Result(new Query(), tokenStream));
+ var expected = """
+ event: token
+ data: {"token":"token1"}
+
+ event: token
+ data: {"token":"token2"}
+
+ event: error
+ data: {"source":"my_llm","error":400,"message":"Something went wrong"}
+
+ event: end
+ """;
+ assertEquals(expected, result);
+ }
+
+ @Test
+ public void testPromptRendering() throws ExecutionException, InterruptedException {
+ String prompt = "Why are ducks better than cats?\n\nBe concise.\n";
+
+ var tokenStream = EventStream.create("token_stream");
+ tokenStream.add(prompt, "prompt");
+ tokenStream.add("Just");
+ tokenStream.add(" because");
+ tokenStream.add(".");
+ tokenStream.markComplete();
+ var result = render(new Result(new Query(), tokenStream));
+
+ var expected = """
+ event: prompt
+ data: {"prompt":"Why are ducks better than cats?\\n\\nBe concise.\\n"}
+
+ event: token
+ data: {"token":"Just"}
+
+ event: token
+ data: {"token":" because"}
+
+ event: token
+ data: {"token":"."}
+
+ event: end
+ """;
+ assertEquals(expected, result);
+ }
+
+ @Test
+ @Timeout(5)
+ public void testResultRenderingFails() {
+ var tokenStream = EventStream.create("token_stream");
+ tokenStream.add("token1");
+ tokenStream.add("token2");
+ tokenStream.markComplete();
+
+ var resultsHitGroup = new HitGroup("test_results");
+ var hit1 = new Hit("result_1");
+ var hit2 = new Hit("result_2");
+ hit1.setField("documentid", new DocumentId("id:unittest:test::1"));
+ hit2.setField("documentid", new DocumentId("id:unittest:test::2"));
+ resultsHitGroup.add(hit1);
+ resultsHitGroup.add(hit2);
+
+ var combined = new HitGroup("all");
+ combined.add(resultsHitGroup);
+ combined.add(tokenStream);
+
+ var result = new Result(new Query(), combined);
+ assertThrows(Exception.class, () -> render(result)); // Todo: support this
+ }
+
+ private String render(Result r) throws InterruptedException, ExecutionException {
+ var execution = new Execution(Execution.Context.createContextStub());
+ return render(execution, r);
+ }
+
+ private String render(Execution execution, Result r) throws InterruptedException, ExecutionException {
+ if (renderer == null) createClone();
+ try {
+ ByteArrayOutputStream bs = new ByteArrayOutputStream(); //new DebugOutputStream();
+ CompletableFuture<Boolean> f = renderer.renderResponse(bs, r, execution, null);
+ assertTrue(f.get());
+ return Utf8.toString(bs.toByteArray());
+ } finally {
+ deconstructClone();
+ }
+ }
+
+ private static class DebugOutputStream extends ByteArrayOutputStream {
+ @Override
+ public synchronized void write(byte[] b, int off, int len) {
+ super.write(b, off, len);
+ System.out.print(new String(b, off, len));
+ }
+ }
+
+ private static List<String> splitter(String text) {
+ var list = new ArrayList<String>();
+ for (String token : text.split(" ")) {
+ list.add(list.isEmpty() ? token : " " + token);
+ }
+ return list;
+ }
+
+ private static CompletableFuture<Boolean> completeAsync(String text, ExecutorService executor, Consumer<String> consumer) {
+ var completionFuture = new CompletableFuture<Boolean>();
+ executor.submit(() -> {
+ try {
+ for (String s : splitter(text)) {
+ consumer.accept(s);
+ Thread.sleep(10);
+ }
+ completionFuture.complete(true);
+ } catch (Exception e) {
+ completionFuture.completeExceptionally(e);
+ }
+ });
+ return completionFuture;
+ }
+
+}