diff options
author | Lester Solbakken <lester.solbakken@gmail.com> | 2024-03-08 13:19:33 +0100 |
---|---|---|
committer | Lester Solbakken <lester.solbakken@gmail.com> | 2024-03-08 13:19:33 +0100 |
commit | 25efeef096cb30090c6d1ed0bd804f6a0745adc0 (patch) | |
tree | a6d7023f4e5a3df62e6d69d6551be0052ece08af /container-search/src/main | |
parent | 3ede5019a6fe0881917b165166f413c532fe4bc0 (diff) |
Add server-sent events (SSE) renderer
Diffstat (limited to 'container-search/src/main')
3 files changed, 192 insertions, 0 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java new file mode 100644 index 00000000000..46f9a53e698 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -0,0 +1,119 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.rendering; + +import com.yahoo.search.result.EventStream; +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonFactoryBuilder; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.StreamReadConstraints; +import com.fasterxml.jackson.core.io.SerializedString; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.processing.rendering.AsynchronousSectionedRenderer; +import com.yahoo.processing.response.Data; +import com.yahoo.processing.response.DataList; +import com.yahoo.search.Result; +import com.yahoo.search.result.ErrorHit; +import com.yahoo.search.result.ErrorMessage; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.concurrent.Executor; + +import static com.fasterxml.jackson.databind.SerializationFeature.FLUSH_AFTER_WRITE_VALUE; + +/** + * A Server-Sent Events (SSE) renderer for asynchronous events such as + * tokens from a language model. + * + * @author lesters + */ +public class EventRenderer extends AsynchronousSectionedRenderer<Result> { + + private static final JsonFactory generatorFactory = createGeneratorFactory(); + private volatile JsonGenerator generator; + + private static JsonFactory createGeneratorFactory() { + var factory = new JsonFactoryBuilder() + .streamReadConstraints(StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build()) + .build(); + factory.setCodec(new ObjectMapper(factory).disable(FLUSH_AFTER_WRITE_VALUE)); + return factory; + } + + private static final boolean RENDER_EVENT_HEADER = true; + private static final boolean RENDER_END_EVENT = true; + + public EventRenderer() { + this(null); + } + + public EventRenderer(Executor executor) { + super(executor); + } + + @Override + public void beginResponse(OutputStream outputStream) throws IOException { + generator = generatorFactory.createGenerator(outputStream, JsonEncoding.UTF8); + generator.setRootValueSeparator(new SerializedString("")); + } + + @Override + public void beginList(DataList<?> dataList) throws IOException { + if ( ! (dataList instanceof EventStream)) { + throw new IllegalArgumentException("EventRenderer currently only supports EventStreams"); + // Todo: support results and timing and trace by delegating to JsonRenderer + } + } + + @Override + public void data(Data data) throws IOException { + if (data instanceof EventStream.Event event) { + if (RENDER_EVENT_HEADER) { + generator.writeRaw("event: " + event.type() + "\n"); + } + generator.writeRaw("data: "); + generator.writeStartObject(); + generator.writeStringField(event.type(), event.toString()); + generator.writeEndObject(); + generator.writeRaw("\n\n"); + generator.flush(); + } + else if (data instanceof ErrorHit) { + for (ErrorMessage error : ((ErrorHit) data).errors()) { + generator.writeRaw("event: error\n"); + generator.writeRaw("data: "); + generator.writeStartObject(); + generator.writeStringField("source", error.getSource()); + generator.writeNumberField("error", error.getCode()); + generator.writeStringField("message", error.getMessage()); + generator.writeEndObject(); + generator.writeRaw("\n\n"); + generator.flush(); + } + } + } + + @Override + public void endList(DataList<?> dataList) throws IOException { + } + + @Override + public void endResponse() throws IOException { + if (RENDER_END_EVENT) { + generator.writeRaw("event: end\n"); + } + generator.close(); + } + + @Override + public String getEncoding() { + return "utf-8"; + } + + @Override + public String getMimeType() { + return "text/event-stream"; + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java index 3287a61c81f..d62860afcda 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java @@ -24,6 +24,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi public static final ComponentId xmlRendererId = ComponentId.fromString("XmlRenderer"); public static final ComponentId pageRendererId = ComponentId.fromString("PageTemplatesXmlRenderer"); public static final ComponentId jsonRendererId = ComponentId.fromString("JsonRenderer"); + public static final ComponentId eventRendererId = ComponentId.fromString("EventRenderer"); public static final ComponentId defaultRendererId = jsonRendererId; @@ -56,6 +57,11 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi pageRenderer.initId(pageRendererId); register(pageRenderer.getId(), pageRenderer); + // Add event renderer + Renderer eventRenderer = new EventRenderer(executor); + eventRenderer.initId(eventRendererId); + register(eventRenderer.getId(), eventRenderer); + // add application renderers for (Renderer renderer : renderers) register(renderer.getId(), renderer); @@ -69,6 +75,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi getRenderer(jsonRendererId.toSpecification()).deconstruct(); getRenderer(xmlRendererId.toSpecification()).deconstruct(); getRenderer(pageRendererId.toSpecification()).deconstruct(); + getRenderer(eventRendererId.toSpecification()).deconstruct(); } /** @@ -92,6 +99,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi if (format.stringValue().equals("json")) return getComponent(jsonRendererId); if (format.stringValue().equals("xml")) return getComponent(xmlRendererId); if (format.stringValue().equals("page")) return getComponent(pageRendererId); + if (format.stringValue().equals("sse")) return getComponent(eventRendererId); com.yahoo.processing.rendering.Renderer<Result> renderer = getComponent(format); if (renderer == null) diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java new file mode 100644 index 00000000000..84ef9a8ee86 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java @@ -0,0 +1,65 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.result; + +import com.yahoo.processing.response.DefaultIncomingData; + +/** + * A stream of events which can be rendered as Server-Sent Events (SSE). + * + * @author lesters + */ +public class EventStream extends HitGroup { + + private int eventCount = 0; + + public static final String DEFAULT_EVENT_TYPE = "token"; + + private EventStream(String id, DefaultIncomingData<Hit> incomingData) { + super(id, new Relevance(1), incomingData); + this.setOrdered(true); // avoid hit group ordering - important as sequence as inserted should be kept + } + + public static EventStream create(String id) { + DefaultIncomingData<Hit> incomingData = new DefaultIncomingData<>(); + EventStream stream = new EventStream(id, incomingData); + incomingData.assignOwner(stream); + return stream; + } + + public void add(String data) { + add(data, DEFAULT_EVENT_TYPE); + } + + public void add(String data, String type) { + incoming().add(new Event(String.valueOf(eventCount + 1), data, type)); + eventCount++; + } + + public void error(String source, ErrorMessage message) { + incoming().add(new DefaultErrorHit(source, message)); + } + + public void markComplete() { + incoming().markComplete(); + } + + public static class Event extends Hit { + + private final String type; + + public Event(String id, String data, String type) { + super(id); + this.type = type; + setField(type, data); + } + + public String toString() { + return getField(type).toString(); + } + + public String type() { + return type; + } + + } +} |