aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src
diff options
context:
space:
mode:
authorLester Solbakken <lester.solbakken@gmail.com>2024-03-15 14:40:04 +0100
committerLester Solbakken <lester.solbakken@gmail.com>2024-03-15 14:40:04 +0100
commite7e6202bb35594f1d95a0a31f2af7f1056498e1b (patch)
tree4935a9c51533907349bdb7f907b5bdc7554353cc /container-search/src
parent25efeef096cb30090c6d1ed0bd804f6a0745adc0 (diff)
Change EventStream to a DataList and be able that with JsonRenderer
Diffstat (limited to 'container-search/src')
-rw-r--r--container-search/src/main/java/com/yahoo/search/Result.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java5
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java24
-rw-r--r--container-search/src/main/java/com/yahoo/search/result/EventStream.java97
-rw-r--r--container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java42
-rw-r--r--container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java40
6 files changed, 161 insertions, 49 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/Result.java b/container-search/src/main/java/com/yahoo/search/Result.java
index b1a0107c6d8..a989688575d 100644
--- a/container-search/src/main/java/com/yahoo/search/Result.java
+++ b/container-search/src/main/java/com/yahoo/search/Result.java
@@ -20,7 +20,7 @@ import java.util.Iterator;
* a single HitGroup containing hits of the result. The HitGroup may contain Hits, which are the individual
* result items, as well as further HitGroups, making up a <i>composite</i> structure. This allows the hits of a result
* to be hierarchically organized. A Hit is polymorphic and may contain any kind of information deemed
- * an approriate partial answer to the Query.
+ * an appropriate partial answer to the Query.
* <p>
* Do not cache this as it holds references to objects that should be garbage collected.
*
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
index 46f9a53e698..83ae349f5a0 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -60,10 +60,6 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void beginList(DataList<?> dataList) throws IOException {
- if ( ! (dataList instanceof EventStream)) {
- throw new IllegalArgumentException("EventRenderer currently only supports EventStreams");
- // Todo: support results and timing and trace by delegating to JsonRenderer
- }
}
@Override
@@ -92,6 +88,7 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
generator.flush();
}
}
+ // Todo: support other types of data such as search results (hits), timing and trace
}
@Override
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
index e876f8e06d0..69410070453 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/JsonRenderer.java
@@ -40,6 +40,7 @@ import com.yahoo.search.result.Coverage;
import com.yahoo.search.result.DefaultErrorHit;
import com.yahoo.search.result.ErrorHit;
import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
@@ -243,17 +244,19 @@ public class JsonRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void beginList(DataList<?> list) throws IOException {
- Preconditions.checkArgument(list instanceof HitGroup,
- "Expected subclass of com.yahoo.search.result.HitGroup, got %s.",
- list.getClass());
moreChildren();
- renderHitGroupHead((HitGroup) list);
+ if (list instanceof HitGroup) {
+ renderHitGroupHead((HitGroup) list);
+ } else if (list instanceof EventStream) {
+ renderHitGroupHead(new HitGroup("event_stream")); // Consider waiting for all events and create a single summary hit
+ } else {
+ throw new IllegalArgumentException("Expected subclass of com.yahoo.search.result.HitGroup, got " + list.getClass());
+ }
}
protected void moreChildren() throws IOException {
if (!renderedChildren.isEmpty())
childrenArray();
-
renderedChildren.push(0);
}
@@ -443,10 +446,13 @@ public class JsonRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void data(Data data) throws IOException {
- Preconditions.checkArgument(data instanceof Hit,
- "Expected subclass of com.yahoo.search.result.Hit, got %s.",
- data.getClass());
- renderHit((Hit) data);
+ if (data instanceof Hit) {
+ renderHit((Hit) data);
+ } else if (data instanceof EventStream.Event) {
+ renderHit(((EventStream.Event) data).asHit());
+ } else {
+ throw new IllegalArgumentException("Expected subclass of com.yahoo.search.result.Hit, got " + data.getClass());
+ }
}
@Override
diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
index 84ef9a8ee86..b393a91e6d0 100644
--- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java
+++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
@@ -1,38 +1,43 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.result;
+import com.yahoo.collections.ListenableArrayList;
+import com.yahoo.component.provider.ListenableFreezableClass;
+import com.yahoo.processing.Request;
+import com.yahoo.processing.response.Data;
+import com.yahoo.processing.response.DataList;
import com.yahoo.processing.response.DefaultIncomingData;
+import com.yahoo.processing.response.IncomingData;
+
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicInteger;
/**
* A stream of events which can be rendered as Server-Sent Events (SSE).
*
* @author lesters
*/
-public class EventStream extends HitGroup {
-
- private int eventCount = 0;
+public class EventStream extends Hit implements DataList<Data> {
- public static final String DEFAULT_EVENT_TYPE = "token";
+ private final ListenableArrayList<Data> data = new ListenableArrayList<>(16);
+ private final IncomingData<Data> incomingData;
+ private final AtomicInteger eventCount = new AtomicInteger(0);
- private EventStream(String id, DefaultIncomingData<Hit> incomingData) {
- super(id, new Relevance(1), incomingData);
- this.setOrdered(true); // avoid hit group ordering - important as sequence as inserted should be kept
- }
+ public final static String EVENT_TYPE_TOKEN = "token";
+ public final static String DEFAULT_EVENT_TYPE = EVENT_TYPE_TOKEN;
- public static EventStream create(String id) {
- DefaultIncomingData<Hit> incomingData = new DefaultIncomingData<>();
- EventStream stream = new EventStream(id, incomingData);
- incomingData.assignOwner(stream);
- return stream;
+ public EventStream() {
+ super();
+ this.incomingData = new DefaultIncomingData<>(this);
}
public void add(String data) {
- add(data, DEFAULT_EVENT_TYPE);
+ incoming().add(new Event(eventCount.incrementAndGet(), data, DEFAULT_EVENT_TYPE));
}
public void add(String data, String type) {
- incoming().add(new Event(String.valueOf(eventCount + 1), data, type));
- eventCount++;
+ incoming().add(new Event(eventCount.incrementAndGet(), data, type));
}
public void error(String source, ErrorMessage message) {
@@ -43,23 +48,73 @@ public class EventStream extends HitGroup {
incoming().markComplete();
}
- public static class Event extends Hit {
+ @Override
+ public Data add(Data event) {
+ data.add(event);
+ return event;
+ }
+
+ @Override
+ public Data get(int index) {
+ return data.get(index);
+ }
+
+ @Override
+ public List<Data> asList() {
+ return data;
+ }
+
+ @Override
+ public IncomingData<Data> incoming() {
+ return incomingData;
+ }
+
+ @Override
+ public CompletableFuture<DataList<Data>> completeFuture() {
+ return incomingData.completedFuture();
+ }
+
+ @Override
+ public void addDataListener(Runnable runnable) {
+ data.addListener(runnable);
+ }
+
+ @Override
+ public void close() {
+ }
+
+ public static class Event extends ListenableFreezableClass implements Data {
+ private final int eventNumber;
+ private final String data;
private final String type;
- public Event(String id, String data, String type) {
- super(id);
+ public Event(int eventNumber, String data, String type) {
+ this.eventNumber = eventNumber;
+ this.data = data;
this.type = type;
- setField(type, data);
}
public String toString() {
- return getField(type).toString();
+ return data;
}
public String type() {
return type;
}
+ @Override
+ public Request request() {
+ return null;
+ }
+
+ // For json rendering
+ public Hit asHit() {
+ Hit hit = new Hit(String.valueOf(eventNumber));
+ hit.setField(type, data);
+ return hit;
+ }
+
}
+
}
diff --git a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
index 9ebe6d048b3..c0a677b2094 100644
--- a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
@@ -33,7 +33,6 @@ import java.util.function.Consumer;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class EventRendererTestCase {
@@ -97,12 +96,12 @@ public class EventRendererTestCase {
event: end
""";
- var tokenStream = EventStream.create("token_stream");
+ var tokenStream = new EventStream();
for (String token : splitter("Ducks have adorable waddling walks")) {
tokenStream.add(token);
}
tokenStream.markComplete();
- var result = render(new Result(new Query(), tokenStream));
+ var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
assertEquals(expected, result);
}
@@ -130,7 +129,7 @@ public class EventRendererTestCase {
var result = "";
var executor = Executors.newFixedThreadPool(1);
try {
- var tokenStream = EventStream.create("token_stream");
+ var tokenStream = new EventStream();
var future = completeAsync("Ducks have adorable waddling walks", executor, token -> {
tokenStream.add(token);
}).exceptionally(e -> {
@@ -141,7 +140,7 @@ public class EventRendererTestCase {
tokenStream.markComplete();
});
assertFalse(future.isDone());
- result = render(new Result(new Query(), tokenStream));
+ result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
assertTrue(future.isDone()); // Renderer waits for async completion
} finally {
@@ -152,12 +151,12 @@ public class EventRendererTestCase {
@Test
public void testErrorEndsStream() throws ExecutionException, InterruptedException {
- var tokenStream = EventStream.create("token_stream");
+ var tokenStream = new EventStream();
tokenStream.add("token1");
tokenStream.add("token2");
tokenStream.error("my_llm", new ErrorMessage(400, "Something went wrong"));
tokenStream.markComplete();
- var result = render(new Result(new Query(), tokenStream));
+ var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
var expected = """
event: token
data: {"token":"token1"}
@@ -177,13 +176,13 @@ public class EventRendererTestCase {
public void testPromptRendering() throws ExecutionException, InterruptedException {
String prompt = "Why are ducks better than cats?\n\nBe concise.\n";
- var tokenStream = EventStream.create("token_stream");
+ var tokenStream = new EventStream();
tokenStream.add(prompt, "prompt");
tokenStream.add("Just");
tokenStream.add(" because");
tokenStream.add(".");
tokenStream.markComplete();
- var result = render(new Result(new Query(), tokenStream));
+ var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
var expected = """
event: prompt
@@ -205,8 +204,8 @@ public class EventRendererTestCase {
@Test
@Timeout(5)
- public void testResultRenderingFails() {
- var tokenStream = EventStream.create("token_stream");
+ public void testResultRenderingIsSkipped() throws ExecutionException, InterruptedException {
+ var tokenStream = new EventStream();
tokenStream.add("token1");
tokenStream.add("token2");
tokenStream.markComplete();
@@ -221,10 +220,25 @@ public class EventRendererTestCase {
var combined = new HitGroup("all");
combined.add(resultsHitGroup);
- combined.add(tokenStream);
+ combined.add(newHitGroup(tokenStream, "token_stream"));
- var result = new Result(new Query(), combined);
- assertThrows(Exception.class, () -> render(result)); // Todo: support this
+ var result = render(new Result(new Query(), combined));
+ var expected = """
+ event: token
+ data: {"token":"token1"}
+
+ event: token
+ data: {"token":"token2"}
+
+ event: end
+ """;
+ assertEquals(expected, result); // Todo: support other types of data such as search results (hits), timing and trace
+ }
+
+ static HitGroup newHitGroup(EventStream eventStream, String id) {
+ var hitGroup = new HitGroup(id);
+ hitGroup.add(eventStream);
+ return hitGroup;
}
private String render(Result r) throws InterruptedException, ExecutionException {
diff --git a/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java
index 3a8584dd0a5..ffa6c82e941 100644
--- a/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/rendering/JsonRendererTestCase.java
@@ -36,6 +36,7 @@ import com.yahoo.search.grouping.result.RootGroup;
import com.yahoo.search.grouping.result.StringId;
import com.yahoo.search.result.Coverage;
import com.yahoo.search.result.ErrorMessage;
+import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
@@ -1566,6 +1567,45 @@ public class JsonRendererTestCase {
assertEqualJson(expected, summary);
}
+ @Test
+ @Timeout(600)
+ void testEventStreamRendering() throws ExecutionException, InterruptedException {
+ var tokenStream = new EventStream();
+ tokenStream.add("token1");
+ tokenStream.add("token2");
+ tokenStream.markComplete();
+
+ var hitGroup = new HitGroup("token_stream");
+ hitGroup.add(tokenStream);
+ var result = render(new Result(new Query(), hitGroup));
+
+ String expected = "{" +
+ "\"root\":{" +
+ "\"id\":\"token_stream\"," +
+ "\"relevance\":1.0," +
+ "\"fields\":{" +
+ "\"totalCount\":0" +
+ "}," +
+ "\"children\":[{" +
+ "\"id\":\"event_stream\"," +
+ "\"relevance\":1.0," +
+ "\"children\":[{" +
+ "\"id\":\"1\"," +
+ "\"relevance\":1.0," +
+ "\"fields\":{" +
+ "\"token\":\"token1\"" +
+ "}},{" +
+ "\"id\":\"2\"," +
+ "\"relevance\":1.0," +
+ "\"fields\":{" +
+ "\"token\":\"token2\"" +
+ "}}" +
+ "]}]" +
+ "}" +
+ "}";
+ assertEqualJson(expected, result);
+ }
+
private Result newEmptyResult(String[] args) {
return new Result(new Query("/?" + String.join("&", args)));
}