From 25efeef096cb30090c6d1ed0bd804f6a0745adc0 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Fri, 8 Mar 2024 13:19:33 +0100 Subject: Add server-sent events (SSE) renderer --- .../com/yahoo/search/rendering/EventRenderer.java | 119 +++++++++ .../yahoo/search/rendering/RendererRegistry.java | 8 + .../java/com/yahoo/search/result/EventStream.java | 65 +++++ .../search/rendering/EventRendererTestCase.java | 279 +++++++++++++++++++++ 4 files changed, 471 insertions(+) create mode 100644 container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java create mode 100644 container-search/src/main/java/com/yahoo/search/result/EventStream.java create mode 100644 container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java (limited to 'container-search/src') diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java new file mode 100644 index 00000000000..46f9a53e698 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -0,0 +1,119 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.rendering; + +import com.yahoo.search.result.EventStream; +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonFactoryBuilder; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.StreamReadConstraints; +import com.fasterxml.jackson.core.io.SerializedString; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.processing.rendering.AsynchronousSectionedRenderer; +import com.yahoo.processing.response.Data; +import com.yahoo.processing.response.DataList; +import com.yahoo.search.Result; +import com.yahoo.search.result.ErrorHit; +import com.yahoo.search.result.ErrorMessage; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.concurrent.Executor; + +import static com.fasterxml.jackson.databind.SerializationFeature.FLUSH_AFTER_WRITE_VALUE; + +/** + * A Server-Sent Events (SSE) renderer for asynchronous events such as + * tokens from a language model. + * + * @author lesters + */ +public class EventRenderer extends AsynchronousSectionedRenderer { + + private static final JsonFactory generatorFactory = createGeneratorFactory(); + private volatile JsonGenerator generator; + + private static JsonFactory createGeneratorFactory() { + var factory = new JsonFactoryBuilder() + .streamReadConstraints(StreamReadConstraints.builder().maxStringLength(Integer.MAX_VALUE).build()) + .build(); + factory.setCodec(new ObjectMapper(factory).disable(FLUSH_AFTER_WRITE_VALUE)); + return factory; + } + + private static final boolean RENDER_EVENT_HEADER = true; + private static final boolean RENDER_END_EVENT = true; + + public EventRenderer() { + this(null); + } + + public EventRenderer(Executor executor) { + super(executor); + } + + @Override + public void beginResponse(OutputStream outputStream) throws IOException { + generator = generatorFactory.createGenerator(outputStream, JsonEncoding.UTF8); + generator.setRootValueSeparator(new SerializedString("")); + } + + @Override + public void beginList(DataList dataList) throws IOException { + if ( ! (dataList instanceof EventStream)) { + throw new IllegalArgumentException("EventRenderer currently only supports EventStreams"); + // Todo: support results and timing and trace by delegating to JsonRenderer + } + } + + @Override + public void data(Data data) throws IOException { + if (data instanceof EventStream.Event event) { + if (RENDER_EVENT_HEADER) { + generator.writeRaw("event: " + event.type() + "\n"); + } + generator.writeRaw("data: "); + generator.writeStartObject(); + generator.writeStringField(event.type(), event.toString()); + generator.writeEndObject(); + generator.writeRaw("\n\n"); + generator.flush(); + } + else if (data instanceof ErrorHit) { + for (ErrorMessage error : ((ErrorHit) data).errors()) { + generator.writeRaw("event: error\n"); + generator.writeRaw("data: "); + generator.writeStartObject(); + generator.writeStringField("source", error.getSource()); + generator.writeNumberField("error", error.getCode()); + generator.writeStringField("message", error.getMessage()); + generator.writeEndObject(); + generator.writeRaw("\n\n"); + generator.flush(); + } + } + } + + @Override + public void endList(DataList dataList) throws IOException { + } + + @Override + public void endResponse() throws IOException { + if (RENDER_END_EVENT) { + generator.writeRaw("event: end\n"); + } + generator.close(); + } + + @Override + public String getEncoding() { + return "utf-8"; + } + + @Override + public String getMimeType() { + return "text/event-stream"; + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java index 3287a61c81f..d62860afcda 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/RendererRegistry.java @@ -24,6 +24,7 @@ public final class RendererRegistry extends ComponentRegistry renderer = getComponent(format); if (renderer == null) diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java new file mode 100644 index 00000000000..84ef9a8ee86 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java @@ -0,0 +1,65 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.result; + +import com.yahoo.processing.response.DefaultIncomingData; + +/** + * A stream of events which can be rendered as Server-Sent Events (SSE). + * + * @author lesters + */ +public class EventStream extends HitGroup { + + private int eventCount = 0; + + public static final String DEFAULT_EVENT_TYPE = "token"; + + private EventStream(String id, DefaultIncomingData incomingData) { + super(id, new Relevance(1), incomingData); + this.setOrdered(true); // avoid hit group ordering - important as sequence as inserted should be kept + } + + public static EventStream create(String id) { + DefaultIncomingData incomingData = new DefaultIncomingData<>(); + EventStream stream = new EventStream(id, incomingData); + incomingData.assignOwner(stream); + return stream; + } + + public void add(String data) { + add(data, DEFAULT_EVENT_TYPE); + } + + public void add(String data, String type) { + incoming().add(new Event(String.valueOf(eventCount + 1), data, type)); + eventCount++; + } + + public void error(String source, ErrorMessage message) { + incoming().add(new DefaultErrorHit(source, message)); + } + + public void markComplete() { + incoming().markComplete(); + } + + public static class Event extends Hit { + + private final String type; + + public Event(String id, String data, String type) { + super(id); + this.type = type; + setField(type, data); + } + + public String toString() { + return getField(type).toString(); + } + + public String type() { + return type; + } + + } +} diff --git a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java new file mode 100644 index 00000000000..9ebe6d048b3 --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java @@ -0,0 +1,279 @@ + +package com.yahoo.search.rendering; + +import com.yahoo.search.result.EventStream; +import com.yahoo.concurrent.ThreadFactoryFactory; +import com.yahoo.document.DocumentId; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.result.Hit; +import com.yahoo.search.result.HitGroup; +import com.yahoo.search.searchchain.Execution; +import com.yahoo.text.Utf8; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.io.ByteArrayOutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class EventRendererTestCase { + + private static ThreadPoolExecutor executor; + private static EventRenderer blueprint; + private EventRenderer renderer; + + @BeforeAll + public static void createExecutorAndBlueprint() { + ThreadFactory threadFactory = ThreadFactoryFactory.getThreadFactory("test-rendering"); + executor = new ThreadPoolExecutor(4, 4, 1L, TimeUnit.MINUTES, new LinkedBlockingQueue<>(), threadFactory); + executor.prestartAllCoreThreads(); + blueprint = new EventRenderer(executor); + } + + @BeforeEach + public void createClone() { + // Use the shared renderer as a prototype object, as specified in the API contract + renderer = (EventRenderer) blueprint.clone(); + renderer.init(); + } + + @AfterEach + public void deconstructClone() { + if (renderer != null) { + renderer.deconstruct(); + renderer = null; + } + } + + @AfterAll + public static void deconstructBlueprintAndExecutor() throws InterruptedException { + blueprint.deconstruct(); + blueprint = null; + executor.shutdown(); + if (!executor.awaitTermination(1, TimeUnit.MINUTES)) { + throw new RuntimeException("Failed to shutdown executor"); + } + executor = null; + } + + @Test + @Timeout(5) + public void testRendering() throws InterruptedException, ExecutionException { + var expected = """ + event: token + data: {"token":"Ducks"} + + event: token + data: {"token":" have"} + + event: token + data: {"token":" adorable"} + + event: token + data: {"token":" waddling"} + + event: token + data: {"token":" walks"} + + event: end + """; + var tokenStream = EventStream.create("token_stream"); + for (String token : splitter("Ducks have adorable waddling walks")) { + tokenStream.add(token); + } + tokenStream.markComplete(); + var result = render(new Result(new Query(), tokenStream)); + assertEquals(expected, result); + } + + @Test + @Timeout(5) + public void testAsyncRendering() throws InterruptedException, ExecutionException { + var expected = """ + event: token + data: {"token":"Ducks"} + + event: token + data: {"token":" have"} + + event: token + data: {"token":" adorable"} + + event: token + data: {"token":" waddling"} + + event: token + data: {"token":" walks"} + + event: end + """; + var result = ""; + var executor = Executors.newFixedThreadPool(1); + try { + var tokenStream = EventStream.create("token_stream"); + var future = completeAsync("Ducks have adorable waddling walks", executor, token -> { + tokenStream.add(token); + }).exceptionally(e -> { + tokenStream.error("error", new ErrorMessage(400, e.getMessage())); + tokenStream.markComplete(); + return false; + }).thenAccept(finishReason -> { + tokenStream.markComplete(); + }); + assertFalse(future.isDone()); + result = render(new Result(new Query(), tokenStream)); + assertTrue(future.isDone()); // Renderer waits for async completion + + } finally { + executor.shutdownNow(); + } + assertEquals(expected, result); + } + + @Test + public void testErrorEndsStream() throws ExecutionException, InterruptedException { + var tokenStream = EventStream.create("token_stream"); + tokenStream.add("token1"); + tokenStream.add("token2"); + tokenStream.error("my_llm", new ErrorMessage(400, "Something went wrong")); + tokenStream.markComplete(); + var result = render(new Result(new Query(), tokenStream)); + var expected = """ + event: token + data: {"token":"token1"} + + event: token + data: {"token":"token2"} + + event: error + data: {"source":"my_llm","error":400,"message":"Something went wrong"} + + event: end + """; + assertEquals(expected, result); + } + + @Test + public void testPromptRendering() throws ExecutionException, InterruptedException { + String prompt = "Why are ducks better than cats?\n\nBe concise.\n"; + + var tokenStream = EventStream.create("token_stream"); + tokenStream.add(prompt, "prompt"); + tokenStream.add("Just"); + tokenStream.add(" because"); + tokenStream.add("."); + tokenStream.markComplete(); + var result = render(new Result(new Query(), tokenStream)); + + var expected = """ + event: prompt + data: {"prompt":"Why are ducks better than cats?\\n\\nBe concise.\\n"} + + event: token + data: {"token":"Just"} + + event: token + data: {"token":" because"} + + event: token + data: {"token":"."} + + event: end + """; + assertEquals(expected, result); + } + + @Test + @Timeout(5) + public void testResultRenderingFails() { + var tokenStream = EventStream.create("token_stream"); + tokenStream.add("token1"); + tokenStream.add("token2"); + tokenStream.markComplete(); + + var resultsHitGroup = new HitGroup("test_results"); + var hit1 = new Hit("result_1"); + var hit2 = new Hit("result_2"); + hit1.setField("documentid", new DocumentId("id:unittest:test::1")); + hit2.setField("documentid", new DocumentId("id:unittest:test::2")); + resultsHitGroup.add(hit1); + resultsHitGroup.add(hit2); + + var combined = new HitGroup("all"); + combined.add(resultsHitGroup); + combined.add(tokenStream); + + var result = new Result(new Query(), combined); + assertThrows(Exception.class, () -> render(result)); // Todo: support this + } + + private String render(Result r) throws InterruptedException, ExecutionException { + var execution = new Execution(Execution.Context.createContextStub()); + return render(execution, r); + } + + private String render(Execution execution, Result r) throws InterruptedException, ExecutionException { + if (renderer == null) createClone(); + try { + ByteArrayOutputStream bs = new ByteArrayOutputStream(); //new DebugOutputStream(); + CompletableFuture f = renderer.renderResponse(bs, r, execution, null); + assertTrue(f.get()); + return Utf8.toString(bs.toByteArray()); + } finally { + deconstructClone(); + } + } + + private static class DebugOutputStream extends ByteArrayOutputStream { + @Override + public synchronized void write(byte[] b, int off, int len) { + super.write(b, off, len); + System.out.print(new String(b, off, len)); + } + } + + private static List splitter(String text) { + var list = new ArrayList(); + for (String token : text.split(" ")) { + list.add(list.isEmpty() ? token : " " + token); + } + return list; + } + + private static CompletableFuture completeAsync(String text, ExecutorService executor, Consumer consumer) { + var completionFuture = new CompletableFuture(); + executor.submit(() -> { + try { + for (String s : splitter(text)) { + consumer.accept(s); + Thread.sleep(10); + } + completionFuture.complete(true); + } catch (Exception e) { + completionFuture.completeExceptionally(e); + } + }); + return completionFuture; + } + +} -- cgit v1.2.3 From e7e6202bb35594f1d95a0a31f2af7f1056498e1b Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Fri, 15 Mar 2024 14:40:04 +0100 Subject: Change EventStream to a DataList and be able that with JsonRenderer --- .../src/main/java/com/yahoo/search/Result.java | 2 +- .../com/yahoo/search/rendering/EventRenderer.java | 5 +- .../com/yahoo/search/rendering/JsonRenderer.java | 24 ++++-- .../java/com/yahoo/search/result/EventStream.java | 97 +++++++++++++++++----- .../search/rendering/EventRendererTestCase.java | 42 ++++++---- .../search/rendering/JsonRendererTestCase.java | 40 +++++++++ 6 files changed, 161 insertions(+), 49 deletions(-) (limited to 'container-search/src') diff --git a/container-search/src/main/java/com/yahoo/search/Result.java b/container-search/src/main/java/com/yahoo/search/Result.java index b1a0107c6d8..a989688575d 100644 --- a/container-search/src/main/java/com/yahoo/search/Result.java +++ b/container-search/src/main/java/com/yahoo/search/Result.java @@ -20,7 +20,7 @@ import java.util.Iterator; * a single HitGroup containing hits of the result. The HitGroup may contain Hits, which are the individual * result items, as well as further HitGroups, making up a composite structure. This allows the hits of a result * to be hierarchically organized. A Hit is polymorphic and may contain any kind of information deemed - * an approriate partial answer to the Query. + * an appropriate partial answer to the Query. *

* Do not cache this as it holds references to objects that should be garbage collected. * diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java index 46f9a53e698..83ae349f5a0 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -60,10 +60,6 @@ public class EventRenderer extends AsynchronousSectionedRenderer { @Override public void beginList(DataList dataList) throws IOException { - if ( ! (dataList instanceof EventStream)) { - throw new IllegalArgumentException("EventRenderer currently only supports EventStreams"); - // Todo: support results and timing and trace by delegating to JsonRenderer - } } @Override @@ -92,6 +88,7 @@ public class EventRenderer extends AsynchronousSectionedRenderer { generator.flush(); } } + // Todo: support other types of data such as search results (hits), timing and trace } @Override diff --git a/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java index e876f8e06d0..69410070453 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java @@ -40,6 +40,7 @@ import com.yahoo.search.result.Coverage; import com.yahoo.search.result.DefaultErrorHit; import com.yahoo.search.result.ErrorHit; import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.result.EventStream; import com.yahoo.search.result.FeatureData; import com.yahoo.search.result.Hit; import com.yahoo.search.result.HitGroup; @@ -243,17 +244,19 @@ public class JsonRenderer extends AsynchronousSectionedRenderer { @Override public void beginList(DataList list) throws IOException { - Preconditions.checkArgument(list instanceof HitGroup, - "Expected subclass of com.yahoo.search.result.HitGroup, got %s.", - list.getClass()); moreChildren(); - renderHitGroupHead((HitGroup) list); + if (list instanceof HitGroup) { + renderHitGroupHead((HitGroup) list); + } else if (list instanceof EventStream) { + renderHitGroupHead(new HitGroup("event_stream")); // Consider waiting for all events and create a single summary hit + } else { + throw new IllegalArgumentException("Expected subclass of com.yahoo.search.result.HitGroup, got " + list.getClass()); + } } protected void moreChildren() throws IOException { if (!renderedChildren.isEmpty()) childrenArray(); - renderedChildren.push(0); } @@ -443,10 +446,13 @@ public class JsonRenderer extends AsynchronousSectionedRenderer { @Override public void data(Data data) throws IOException { - Preconditions.checkArgument(data instanceof Hit, - "Expected subclass of com.yahoo.search.result.Hit, got %s.", - data.getClass()); - renderHit((Hit) data); + if (data instanceof Hit) { + renderHit((Hit) data); + } else if (data instanceof EventStream.Event) { + renderHit(((EventStream.Event) data).asHit()); + } else { + throw new IllegalArgumentException("Expected subclass of com.yahoo.search.result.Hit, got " + data.getClass()); + } } @Override diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java index 84ef9a8ee86..b393a91e6d0 100644 --- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java +++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java @@ -1,38 +1,43 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.result; +import com.yahoo.collections.ListenableArrayList; +import com.yahoo.component.provider.ListenableFreezableClass; +import com.yahoo.processing.Request; +import com.yahoo.processing.response.Data; +import com.yahoo.processing.response.DataList; import com.yahoo.processing.response.DefaultIncomingData; +import com.yahoo.processing.response.IncomingData; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; /** * A stream of events which can be rendered as Server-Sent Events (SSE). * * @author lesters */ -public class EventStream extends HitGroup { - - private int eventCount = 0; +public class EventStream extends Hit implements DataList { - public static final String DEFAULT_EVENT_TYPE = "token"; + private final ListenableArrayList data = new ListenableArrayList<>(16); + private final IncomingData incomingData; + private final AtomicInteger eventCount = new AtomicInteger(0); - private EventStream(String id, DefaultIncomingData incomingData) { - super(id, new Relevance(1), incomingData); - this.setOrdered(true); // avoid hit group ordering - important as sequence as inserted should be kept - } + public final static String EVENT_TYPE_TOKEN = "token"; + public final static String DEFAULT_EVENT_TYPE = EVENT_TYPE_TOKEN; - public static EventStream create(String id) { - DefaultIncomingData incomingData = new DefaultIncomingData<>(); - EventStream stream = new EventStream(id, incomingData); - incomingData.assignOwner(stream); - return stream; + public EventStream() { + super(); + this.incomingData = new DefaultIncomingData<>(this); } public void add(String data) { - add(data, DEFAULT_EVENT_TYPE); + incoming().add(new Event(eventCount.incrementAndGet(), data, DEFAULT_EVENT_TYPE)); } public void add(String data, String type) { - incoming().add(new Event(String.valueOf(eventCount + 1), data, type)); - eventCount++; + incoming().add(new Event(eventCount.incrementAndGet(), data, type)); } public void error(String source, ErrorMessage message) { @@ -43,23 +48,73 @@ public class EventStream extends HitGroup { incoming().markComplete(); } - public static class Event extends Hit { + @Override + public Data add(Data event) { + data.add(event); + return event; + } + + @Override + public Data get(int index) { + return data.get(index); + } + + @Override + public List asList() { + return data; + } + + @Override + public IncomingData incoming() { + return incomingData; + } + + @Override + public CompletableFuture> completeFuture() { + return incomingData.completedFuture(); + } + + @Override + public void addDataListener(Runnable runnable) { + data.addListener(runnable); + } + + @Override + public void close() { + } + + public static class Event extends ListenableFreezableClass implements Data { + private final int eventNumber; + private final String data; private final String type; - public Event(String id, String data, String type) { - super(id); + public Event(int eventNumber, String data, String type) { + this.eventNumber = eventNumber; + this.data = data; this.type = type; - setField(type, data); } public String toString() { - return getField(type).toString(); + return data; } public String type() { return type; } + @Override + public Request request() { + return null; + } + + // For json rendering + public Hit asHit() { + Hit hit = new Hit(String.valueOf(eventNumber)); + hit.setField(type, data); + return hit; + } + } + } diff --git a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java index 9ebe6d048b3..c0a677b2094 100644 --- a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java @@ -33,7 +33,6 @@ import java.util.function.Consumer; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class EventRendererTestCase { @@ -97,12 +96,12 @@ public class EventRendererTestCase { event: end """; - var tokenStream = EventStream.create("token_stream"); + var tokenStream = new EventStream(); for (String token : splitter("Ducks have adorable waddling walks")) { tokenStream.add(token); } tokenStream.markComplete(); - var result = render(new Result(new Query(), tokenStream)); + var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream"))); assertEquals(expected, result); } @@ -130,7 +129,7 @@ public class EventRendererTestCase { var result = ""; var executor = Executors.newFixedThreadPool(1); try { - var tokenStream = EventStream.create("token_stream"); + var tokenStream = new EventStream(); var future = completeAsync("Ducks have adorable waddling walks", executor, token -> { tokenStream.add(token); }).exceptionally(e -> { @@ -141,7 +140,7 @@ public class EventRendererTestCase { tokenStream.markComplete(); }); assertFalse(future.isDone()); - result = render(new Result(new Query(), tokenStream)); + result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream"))); assertTrue(future.isDone()); // Renderer waits for async completion } finally { @@ -152,12 +151,12 @@ public class EventRendererTestCase { @Test public void testErrorEndsStream() throws ExecutionException, InterruptedException { - var tokenStream = EventStream.create("token_stream"); + var tokenStream = new EventStream(); tokenStream.add("token1"); tokenStream.add("token2"); tokenStream.error("my_llm", new ErrorMessage(400, "Something went wrong")); tokenStream.markComplete(); - var result = render(new Result(new Query(), tokenStream)); + var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream"))); var expected = """ event: token data: {"token":"token1"} @@ -177,13 +176,13 @@ public class EventRendererTestCase { public void testPromptRendering() throws ExecutionException, InterruptedException { String prompt = "Why are ducks better than cats?\n\nBe concise.\n"; - var tokenStream = EventStream.create("token_stream"); + var tokenStream = new EventStream(); tokenStream.add(prompt, "prompt"); tokenStream.add("Just"); tokenStream.add(" because"); tokenStream.add("."); tokenStream.markComplete(); - var result = render(new Result(new Query(), tokenStream)); + var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream"))); var expected = """ event: prompt @@ -205,8 +204,8 @@ public class EventRendererTestCase { @Test @Timeout(5) - public void testResultRenderingFails() { - var tokenStream = EventStream.create("token_stream"); + public void testResultRenderingIsSkipped() throws ExecutionException, InterruptedException { + var tokenStream = new EventStream(); tokenStream.add("token1"); tokenStream.add("token2"); tokenStream.markComplete(); @@ -221,10 +220,25 @@ public class EventRendererTestCase { var combined = new HitGroup("all"); combined.add(resultsHitGroup); - combined.add(tokenStream); + combined.add(newHitGroup(tokenStream, "token_stream")); - var result = new Result(new Query(), combined); - assertThrows(Exception.class, () -> render(result)); // Todo: support this + var result = render(new Result(new Query(), combined)); + var expected = """ + event: token + data: {"token":"token1"} + + event: token + data: {"token":"token2"} + + event: end + """; + assertEquals(expected, result); // Todo: support other types of data such as search results (hits), timing and trace + } + + static HitGroup newHitGroup(EventStream eventStream, String id) { + var hitGroup = new HitGroup(id); + hitGroup.add(eventStream); + return hitGroup; } private String render(Result r) throws InterruptedException, ExecutionException { diff --git a/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java index 3a8584dd0a5..ffa6c82e941 100644 --- a/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java @@ -36,6 +36,7 @@ import com.yahoo.search.grouping.result.RootGroup; import com.yahoo.search.grouping.result.StringId; import com.yahoo.search.result.Coverage; import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.result.EventStream; import com.yahoo.search.result.FeatureData; import com.yahoo.search.result.Hit; import com.yahoo.search.result.HitGroup; @@ -1566,6 +1567,45 @@ public class JsonRendererTestCase { assertEqualJson(expected, summary); } + @Test + @Timeout(600) + void testEventStreamRendering() throws ExecutionException, InterruptedException { + var tokenStream = new EventStream(); + tokenStream.add("token1"); + tokenStream.add("token2"); + tokenStream.markComplete(); + + var hitGroup = new HitGroup("token_stream"); + hitGroup.add(tokenStream); + var result = render(new Result(new Query(), hitGroup)); + + String expected = "{" + + "\"root\":{" + + "\"id\":\"token_stream\"," + + "\"relevance\":1.0," + + "\"fields\":{" + + "\"totalCount\":0" + + "}," + + "\"children\":[{" + + "\"id\":\"event_stream\"," + + "\"relevance\":1.0," + + "\"children\":[{" + + "\"id\":\"1\"," + + "\"relevance\":1.0," + + "\"fields\":{" + + "\"token\":\"token1\"" + + "}},{" + + "\"id\":\"2\"," + + "\"relevance\":1.0," + + "\"fields\":{" + + "\"token\":\"token2\"" + + "}}" + + "]}]" + + "}" + + "}"; + assertEqualJson(expected, result); + } + private Result newEmptyResult(String[] args) { return new Result(new Query("/?" + String.join("&", args))); } -- cgit v1.2.3