aboutsummaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2024-03-21 14:22:03 +0100
committerGitHub <noreply@github.com>2024-03-21 14:22:03 +0100
commit138da140cfde273599fb37c8f65e28f1a5c6957a (patch)
tree635ced173447b6fec679c88b947958e577db95cf /container-search
parent723d6cacbdce4c45e01c92cb3e2eeb71f7b513f2 (diff)
parentd0333079f0cc7c13185b2bf4f015a304c72af2f1 (diff)
Merge pull request #30526 from vespa-engine/lesters/server-sent-events
Add server-sent events (SSE) renderer
Diffstat (limited to 'container-search')
-rw-r--r--container-search/abi-spec.json64
-rw-r--r--container-search/src/main/java/com/yahoo/search/Result.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java116
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java24
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java8
-rw-r--r--container-search/src/main/java/com/yahoo/search/result/EventStream.java120
-rw-r--r--container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java293
-rw-r--r--container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java40
8 files changed, 657 insertions, 10 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 18d1345cb06..bdb6cd9e7a5 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -7460,6 +7460,25 @@
],
"fields" : [ ]
},
+ "com.yahoo.search.rendering.EventRenderer" : {
+ "superClass" : "com.yahoo.processing.rendering.AsynchronousSectionedRenderer",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(java.util.concurrent.Executor)",
+ "public void beginResponse(java.io.OutputStream)",
+ "public void beginList(com.yahoo.processing.response.DataList)",
+ "public void data(com.yahoo.processing.response.Data)",
+ "public void endList(com.yahoo.processing.response.DataList)",
+ "public void endResponse()",
+ "public java.lang.String getEncoding()",
+ "public java.lang.String getMimeType()"
+ ],
+ "fields" : [ ]
+ },
"com.yahoo.search.rendering.JsonRenderer$FieldConsumer" : {
"superClass" : "java.lang.Object",
"interfaces" : [
@@ -7554,6 +7573,7 @@
"public static final com.yahoo.component.ComponentId xmlRendererId",
"public static final com.yahoo.component.ComponentId pageRendererId",
"public static final com.yahoo.component.ComponentId jsonRendererId",
+ "public static final com.yahoo.component.ComponentId eventRendererId",
"public static final com.yahoo.component.ComponentId defaultRendererId"
]
},
@@ -7822,6 +7842,50 @@
"public static final int emptyDocsumsCode"
]
},
+ "com.yahoo.search.result.EventStream$Event" : {
+ "superClass" : "com.yahoo.component.provider.ListenableFreezableClass",
+ "interfaces" : [
+ "com.yahoo.processing.response.Data"
+ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(int, java.lang.String, java.lang.String)",
+ "public java.lang.String toString()",
+ "public java.lang.String type()",
+ "public com.yahoo.processing.Request request()",
+ "public com.yahoo.search.result.Hit asHit()"
+ ],
+ "fields" : [ ]
+ },
+ "com.yahoo.search.result.EventStream" : {
+ "superClass" : "com.yahoo.search.result.Hit",
+ "interfaces" : [
+ "com.yahoo.processing.response.DataList"
+ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void add(java.lang.String)",
+ "public void add(java.lang.String, java.lang.String)",
+ "public void error(java.lang.String, com.yahoo.search.result.ErrorMessage)",
+ "public void markComplete()",
+ "public com.yahoo.processing.response.Data add(com.yahoo.processing.response.Data)",
+ "public com.yahoo.processing.response.Data get(int)",
+ "public java.util.List asList()",
+ "public com.yahoo.processing.response.IncomingData incoming()",
+ "public java.util.concurrent.CompletableFuture completeFuture()",
+ "public void addDataListener(java.lang.Runnable)",
+ "public void close()"
+ ],
+ "fields" : [
+ "public static final java.lang.String EVENT_TYPE_TOKEN",
+ "public static final java.lang.String DEFAULT_EVENT_TYPE"
+ ]
+ },
"com.yahoo.search.result.FeatureData" : {
"superClass" : "java.lang.Object",
"interfaces" : [
diff --git a/container-search/src/main/java/com/yahoo/search/Result.java b/container-search/src/main/java/com/yahoo/search/Result.java
index b1a0107c6d8..a989688575d 100644
--- a/container-search/src/main/java/com/yahoo/search/Result.java
+++ b/container-search/src/main/java/com/yahoo/search/Result.java
@@ -20,7 +20,7 @@ import java.util.Iterator;
* a single HitGroup containing hits of the result. The HitGroup may contain Hits, which are the individual
* result items, as well as further HitGroups, making up a <i>composite</i> structure. This allows the hits of a result
* to be hierarchically organized. A Hit is polymorphic and may contain any kind of information deemed
- * an approriate partial answer to the Query.
+ * an appropriate partial answer to the Query.
* <p>
* Do not cache this as it holds references to objects that should be garbage collected.
*
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
new file mode 100644
index 00000000000..83ae349f5a0
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -0,0 +1,116 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.rendering;
+
+import com.yahoo.search.result.EventStream;
+import com.fasterxml.jackson.core.JsonEncoding;
+import com.fasterxml.jackson.core.JsonFactory;
+import com.fasterxml.jackson.core.JsonFactoryBuilder;
+import com.fasterxml.jackson.core.JsonGenerator;
+import com.fasterxml.jackson.core.StreamReadConstraints;
+import com.fasterxml.jackson.core.io.SerializedString;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.processing.rendering.AsynchronousSectionedRenderer;
+import com.yahoo.processing.response.Data;
+import com.yahoo.processing.response.DataList;
+import com.yahoo.search.Result;
+import com.yahoo.search.result.ErrorHit;
+import com.yahoo.search.result.ErrorMessage;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.concurrent.Executor;
+
+import static com.fasterxml.jackson.databind.SerializationFeature.FLUSH_AFTER_WRITE_VALUE;
+
+/**
+ * A Server-Sent Events (SSE) renderer for asynchronous events such as
+ * tokens from a language model.
+ *
+ * @author lesters
+ */
+public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
+
+ private static final JsonFactory generatorFactory = createGeneratorFactory();
+ private volatile JsonGenerator generator;
+
+ private static JsonFactory createGeneratorFactory() {
+ var factory = new JsonFactoryBuilder()
+ .streamReadConstraints(StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build())
+ .build();
+ factory.setCodec(new ObjectMapper(factory).disable(FLUSH_AFTER_WRITE_VALUE));
+ return factory;
+ }
+
+ private static final boolean RENDER_EVENT_HEADER = true;
+ private static final boolean RENDER_END_EVENT = true;
+
+ public EventRenderer() {
+ this(null);
+ }
+
+ public EventRenderer(Executor executor) {
+ super(executor);
+ }
+
+ @Override
+ public void beginResponse(OutputStream outputStream) throws IOException {
+ generator = generatorFactory.createGenerator(outputStream, JsonEncoding.UTF8);
+ generator.setRootValueSeparator(new SerializedString(""));
+ }
+
+ @Override
+ public void beginList(DataList<?> dataList) throws IOException {
+ }
+
+ @Override
+ public void data(Data data) throws IOException {
+ if (data instanceof EventStream.Event event) {
+ if (RENDER_EVENT_HEADER) {
+ generator.writeRaw("event: " + event.type() + "\n");
+ }
+ generator.writeRaw("data: ");
+ generator.writeStartObject();
+ generator.writeStringField(event.type(), event.toString());
+ generator.writeEndObject();
+ generator.writeRaw("\n\n");
+ generator.flush();
+ }
+ else if (data instanceof ErrorHit) {
+ for (ErrorMessage error : ((ErrorHit) data).errors()) {
+ generator.writeRaw("event: error\n");
+ generator.writeRaw("data: ");
+ generator.writeStartObject();
+ generator.writeStringField("source", error.getSource());
+ generator.writeNumberField("error", error.getCode());
+ generator.writeStringField("message", error.getMessage());
+ generator.writeEndObject();
+ generator.writeRaw("\n\n");
+ generator.flush();
+ }
+ }
+ // Todo: support other types of data such as search results (hits), timing and trace
+ }
+
+ @Override
+ public void endList(DataList<?> dataList) throws IOException {
+ }
+
+ @Override
+ public void endResponse() throws IOException {
+ if (RENDER_END_EVENT) {
+ generator.writeRaw("event: end\n");
+ }
+ generator.close();
+ }
+
+ @Override
+ public String getEncoding() {
+ return "utf-8";
+ }
+
+ @Override
+ public String getMimeType() {
+ return "text/event-stream";
+ }
+
+}
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
index e876f8e06d0..69410070453 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
@@ -40,6 +40,7 @@ import com.yahoo.search.result.Coverage;
import com.yahoo.search.result.DefaultErrorHit;
import com.yahoo.search.result.ErrorHit;
import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
@@ -243,17 +244,19 @@ public class JsonRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void beginList(DataList<?> list) throws IOException {
- Preconditions.checkArgument(list instanceof HitGroup,
- "Expected subclass of com.yahoo.search.result.HitGroup, got %s.",
- list.getClass());
moreChildren();
- renderHitGroupHead((HitGroup) list);
+ if (list instanceof HitGroup) {
+ renderHitGroupHead((HitGroup) list);
+ } else if (list instanceof EventStream) {
+ renderHitGroupHead(new HitGroup("event_stream")); // Consider waiting for all events and create a single summary hit
+ } else {
+ throw new IllegalArgumentException("Expected subclass of com.yahoo.search.result.HitGroup, got " + list.getClass());
+ }
}
protected void moreChildren() throws IOException {
if (!renderedChildren.isEmpty())
childrenArray();
-
renderedChildren.push(0);
}
@@ -443,10 +446,13 @@ public class JsonRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void data(Data data) throws IOException {
- Preconditions.checkArgument(data instanceof Hit,
- "Expected subclass of com.yahoo.search.result.Hit, got %s.",
- data.getClass());
- renderHit((Hit) data);
+ if (data instanceof Hit) {
+ renderHit((Hit) data);
+ } else if (data instanceof EventStream.Event) {
+ renderHit(((EventStream.Event) data).asHit());
+ } else {
+ throw new IllegalArgumentException("Expected subclass of com.yahoo.search.result.Hit, got " + data.getClass());
+ }
}
@Override
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
index 3287a61c81f..d62860afcda 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java
@@ -24,6 +24,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
public static final ComponentId xmlRendererId = ComponentId.fromString("XmlRenderer");
public static final ComponentId pageRendererId = ComponentId.fromString("PageTemplatesXmlRenderer");
public static final ComponentId jsonRendererId = ComponentId.fromString("JsonRenderer");
+ public static final ComponentId eventRendererId = ComponentId.fromString("EventRenderer");
public static final ComponentId defaultRendererId = jsonRendererId;
@@ -56,6 +57,11 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
pageRenderer.initId(pageRendererId);
register(pageRenderer.getId(), pageRenderer);
+ // Add event renderer
+ Renderer eventRenderer = new EventRenderer(executor);
+ eventRenderer.initId(eventRendererId);
+ register(eventRenderer.getId(), eventRenderer);
+
// add application renderers
for (Renderer renderer : renderers)
register(renderer.getId(), renderer);
@@ -69,6 +75,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
getRenderer(jsonRendererId.toSpecification()).deconstruct();
getRenderer(xmlRendererId.toSpecification()).deconstruct();
getRenderer(pageRendererId.toSpecification()).deconstruct();
+ getRenderer(eventRendererId.toSpecification()).deconstruct();
}
/**
@@ -92,6 +99,7 @@ public final class RendererRegistry extends ComponentRegistry<com.yahoo.processi
if (format.stringValue().equals("json")) return getComponent(jsonRendererId);
if (format.stringValue().equals("xml")) return getComponent(xmlRendererId);
if (format.stringValue().equals("page")) return getComponent(pageRendererId);
+ if (format.stringValue().equals("sse")) return getComponent(eventRendererId);
com.yahoo.processing.rendering.Renderer<Result> renderer = getComponent(format);
if (renderer == null)
diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
new file mode 100644
index 00000000000..b393a91e6d0
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
@@ -0,0 +1,120 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.result;
+
+import com.yahoo.collections.ListenableArrayList;
+import com.yahoo.component.provider.ListenableFreezableClass;
+import com.yahoo.processing.Request;
+import com.yahoo.processing.response.Data;
+import com.yahoo.processing.response.DataList;
+import com.yahoo.processing.response.DefaultIncomingData;
+import com.yahoo.processing.response.IncomingData;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * A stream of events which can be rendered as Server-Sent Events (SSE).
+ *
+ * @author lesters
+ */
+public class EventStream extends Hit implements DataList<Data> {
+
+ private final ListenableArrayList<Data> data = new ListenableArrayList<>(16);
+ private final IncomingData<Data> incomingData;
+ private final AtomicInteger eventCount = new AtomicInteger(0);
+
+ public final static String EVENT_TYPE_TOKEN = "token";
+ public final static String DEFAULT_EVENT_TYPE = EVENT_TYPE_TOKEN;
+
+ public EventStream() {
+ super();
+ this.incomingData = new DefaultIncomingData<>(this);
+ }
+
+ public void add(String data) {
+ incoming().add(new Event(eventCount.incrementAndGet(), data, DEFAULT_EVENT_TYPE));
+ }
+
+ public void add(String data, String type) {
+ incoming().add(new Event(eventCount.incrementAndGet(), data, type));
+ }
+
+ public void error(String source, ErrorMessage message) {
+ incoming().add(new DefaultErrorHit(source, message));
+ }
+
+ public void markComplete() {
+ incoming().markComplete();
+ }
+
+ @Override
+ public Data add(Data event) {
+ data.add(event);
+ return event;
+ }
+
+ @Override
+ public Data get(int index) {
+ return data.get(index);
+ }
+
+ @Override
+ public List<Data> asList() {
+ return data;
+ }
+
+ @Override
+ public IncomingData<Data> incoming() {
+ return incomingData;
+ }
+
+ @Override
+ public CompletableFuture<DataList<Data>> completeFuture() {
+ return incomingData.completedFuture();
+ }
+
+ @Override
+ public void addDataListener(Runnable runnable) {
+ data.addListener(runnable);
+ }
+
+ @Override
+ public void close() {
+ }
+
+ public static class Event extends ListenableFreezableClass implements Data {
+
+ private final int eventNumber;
+ private final String data;
+ private final String type;
+
+ public Event(int eventNumber, String data, String type) {
+ this.eventNumber = eventNumber;
+ this.data = data;
+ this.type = type;
+ }
+
+ public String toString() {
+ return data;
+ }
+
+ public String type() {
+ return type;
+ }
+
+ @Override
+ public Request request() {
+ return null;
+ }
+
+ // For json rendering
+ public Hit asHit() {
+ Hit hit = new Hit(String.valueOf(eventNumber));
+ hit.setField(type, data);
+ return hit;
+ }
+
+ }
+
+}
diff --git a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
new file mode 100644
index 00000000000..c0a677b2094
--- /dev/null
+++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
@@ -0,0 +1,293 @@
+
+package com.yahoo.search.rendering;
+
+import com.yahoo.search.result.EventStream;
+import com.yahoo.concurrent.ThreadFactoryFactory;
+import com.yahoo.document.DocumentId;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.Hit;
+import com.yahoo.search.result.HitGroup;
+import com.yahoo.search.searchchain.Execution;
+import com.yahoo.text.Utf8;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
+
+import java.io.ByteArrayOutputStream;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class EventRendererTestCase {
+
+ private static ThreadPoolExecutor executor;
+ private static EventRenderer blueprint;
+ private EventRenderer renderer;
+
+ @BeforeAll
+ public static void createExecutorAndBlueprint() {
+ ThreadFactory threadFactory = ThreadFactoryFactory.getThreadFactory("test-rendering");
+ executor = new ThreadPoolExecutor(4, 4, 1L, TimeUnit.MINUTES, new LinkedBlockingQueue<>(), threadFactory);
+ executor.prestartAllCoreThreads();
+ blueprint = new EventRenderer(executor);
+ }
+
+ @BeforeEach
+ public void createClone() {
+ // Use the shared renderer as a prototype object, as specified in the API contract
+ renderer = (EventRenderer) blueprint.clone();
+ renderer.init();
+ }
+
+ @AfterEach
+ public void deconstructClone() {
+ if (renderer != null) {
+ renderer.deconstruct();
+ renderer = null;
+ }
+ }
+
+ @AfterAll
+ public static void deconstructBlueprintAndExecutor() throws InterruptedException {
+ blueprint.deconstruct();
+ blueprint = null;
+ executor.shutdown();
+ if (!executor.awaitTermination(1, TimeUnit.MINUTES)) {
+ throw new RuntimeException("Failed to shutdown executor");
+ }
+ executor = null;
+ }
+
+ @Test
+ @Timeout(5)
+ public void testRendering() throws InterruptedException, ExecutionException {
+ var expected = """
+ event: token
+ data: {"token":"Ducks"}
+
+ event: token
+ data: {"token":" have"}
+
+ event: token
+ data: {"token":" adorable"}
+
+ event: token
+ data: {"token":" waddling"}
+
+ event: token
+ data: {"token":" walks"}
+
+ event: end
+ """;
+ var tokenStream = new EventStream();
+ for (String token : splitter("Ducks have adorable waddling walks")) {
+ tokenStream.add(token);
+ }
+ tokenStream.markComplete();
+ var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
+ assertEquals(expected, result);
+ }
+
+ @Test
+ @Timeout(5)
+ public void testAsyncRendering() throws InterruptedException, ExecutionException {
+ var expected = """
+ event: token
+ data: {"token":"Ducks"}
+
+ event: token
+ data: {"token":" have"}
+
+ event: token
+ data: {"token":" adorable"}
+
+ event: token
+ data: {"token":" waddling"}
+
+ event: token
+ data: {"token":" walks"}
+
+ event: end
+ """;
+ var result = "";
+ var executor = Executors.newFixedThreadPool(1);
+ try {
+ var tokenStream = new EventStream();
+ var future = completeAsync("Ducks have adorable waddling walks", executor, token -> {
+ tokenStream.add(token);
+ }).exceptionally(e -> {
+ tokenStream.error("error", new ErrorMessage(400, e.getMessage()));
+ tokenStream.markComplete();
+ return false;
+ }).thenAccept(finishReason -> {
+ tokenStream.markComplete();
+ });
+ assertFalse(future.isDone());
+ result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
+ assertTrue(future.isDone()); // Renderer waits for async completion
+
+ } finally {
+ executor.shutdownNow();
+ }
+ assertEquals(expected, result);
+ }
+
+ @Test
+ public void testErrorEndsStream() throws ExecutionException, InterruptedException {
+ var tokenStream = new EventStream();
+ tokenStream.add("token1");
+ tokenStream.add("token2");
+ tokenStream.error("my_llm", new ErrorMessage(400, "Something went wrong"));
+ tokenStream.markComplete();
+ var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
+ var expected = """
+ event: token
+ data: {"token":"token1"}
+
+ event: token
+ data: {"token":"token2"}
+
+ event: error
+ data: {"source":"my_llm","error":400,"message":"Something went wrong"}
+
+ event: end
+ """;
+ assertEquals(expected, result);
+ }
+
+ @Test
+ public void testPromptRendering() throws ExecutionException, InterruptedException {
+ String prompt = "Why are ducks better than cats?\n\nBe concise.\n";
+
+ var tokenStream = new EventStream();
+ tokenStream.add(prompt, "prompt");
+ tokenStream.add("Just");
+ tokenStream.add(" because");
+ tokenStream.add(".");
+ tokenStream.markComplete();
+ var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
+
+ var expected = """
+ event: prompt
+ data: {"prompt":"Why are ducks better than cats?\\n\\nBe concise.\\n"}
+
+ event: token
+ data: {"token":"Just"}
+
+ event: token
+ data: {"token":" because"}
+
+ event: token
+ data: {"token":"."}
+
+ event: end
+ """;
+ assertEquals(expected, result);
+ }
+
+ @Test
+ @Timeout(5)
+ public void testResultRenderingIsSkipped() throws ExecutionException, InterruptedException {
+ var tokenStream = new EventStream();
+ tokenStream.add("token1");
+ tokenStream.add("token2");
+ tokenStream.markComplete();
+
+ var resultsHitGroup = new HitGroup("test_results");
+ var hit1 = new Hit("result_1");
+ var hit2 = new Hit("result_2");
+ hit1.setField("documentid", new DocumentId("id:unittest:test::1"));
+ hit2.setField("documentid", new DocumentId("id:unittest:test::2"));
+ resultsHitGroup.add(hit1);
+ resultsHitGroup.add(hit2);
+
+ var combined = new HitGroup("all");
+ combined.add(resultsHitGroup);
+ combined.add(newHitGroup(tokenStream, "token_stream"));
+
+ var result = render(new Result(new Query(), combined));
+ var expected = """
+ event: token
+ data: {"token":"token1"}
+
+ event: token
+ data: {"token":"token2"}
+
+ event: end
+ """;
+ assertEquals(expected, result); // Todo: support other types of data such as search results (hits), timing and trace
+ }
+
+ static HitGroup newHitGroup(EventStream eventStream, String id) {
+ var hitGroup = new HitGroup(id);
+ hitGroup.add(eventStream);
+ return hitGroup;
+ }
+
+ private String render(Result r) throws InterruptedException, ExecutionException {
+ var execution = new Execution(Execution.Context.createContextStub());
+ return render(execution, r);
+ }
+
+ private String render(Execution execution, Result r) throws InterruptedException, ExecutionException {
+ if (renderer == null) createClone();
+ try {
+ ByteArrayOutputStream bs = new ByteArrayOutputStream(); //new DebugOutputStream();
+ CompletableFuture<Boolean> f = renderer.renderResponse(bs, r, execution, null);
+ assertTrue(f.get());
+ return Utf8.toString(bs.toByteArray());
+ } finally {
+ deconstructClone();
+ }
+ }
+
+ private static class DebugOutputStream extends ByteArrayOutputStream {
+ @Override
+ public synchronized void write(byte[] b, int off, int len) {
+ super.write(b, off, len);
+ System.out.print(new String(b, off, len));
+ }
+ }
+
+ private static List<String> splitter(String text) {
+ var list = new ArrayList<String>();
+ for (String token : text.split(" ")) {
+ list.add(list.isEmpty() ? token : " " + token);
+ }
+ return list;
+ }
+
+ private static CompletableFuture<Boolean> completeAsync(String text, ExecutorService executor, Consumer<String> consumer) {
+ var completionFuture = new CompletableFuture<Boolean>();
+ executor.submit(() -> {
+ try {
+ for (String s : splitter(text)) {
+ consumer.accept(s);
+ Thread.sleep(10);
+ }
+ completionFuture.complete(true);
+ } catch (Exception e) {
+ completionFuture.completeExceptionally(e);
+ }
+ });
+ return completionFuture;
+ }
+
+}
diff --git a/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java
index 3a8584dd0a5..ffa6c82e941 100644
--- a/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java
@@ -36,6 +36,7 @@ import com.yahoo.search.grouping.result.RootGroup;
import com.yahoo.search.grouping.result.StringId;
import com.yahoo.search.result.Coverage;
import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
@@ -1566,6 +1567,45 @@ public class JsonRendererTestCase {
assertEqualJson(expected, summary);
}
+ @Test
+ @Timeout(600)
+ void testEventStreamRendering() throws ExecutionException, InterruptedException {
+ var tokenStream = new EventStream();
+ tokenStream.add("token1");
+ tokenStream.add("token2");
+ tokenStream.markComplete();
+
+ var hitGroup = new HitGroup("token_stream");
+ hitGroup.add(tokenStream);
+ var result = render(new Result(new Query(), hitGroup));
+
+ String expected = "{" +
+ "\"root\":{" +
+ "\"id\":\"token_stream\"," +
+ "\"relevance\":1.0," +
+ "\"fields\":{" +
+ "\"totalCount\":0" +
+ "}," +
+ "\"children\":[{" +
+ "\"id\":\"event_stream\"," +
+ "\"relevance\":1.0," +
+ "\"children\":[{" +
+ "\"id\":\"1\"," +
+ "\"relevance\":1.0," +
+ "\"fields\":{" +
+ "\"token\":\"token1\"" +
+ "}},{" +
+ "\"id\":\"2\"," +
+ "\"relevance\":1.0," +
+ "\"fields\":{" +
+ "\"token\":\"token2\"" +
+ "}}" +
+ "]}]" +
+ "}" +
+ "}";
+ assertEqualJson(expected, result);
+ }
+
private Result newEmptyResult(String[] args) {
return new Result(new Query("/?" + String.join("&", args)));
}