summaryrefslogtreecommitdiffstats
path: root/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java')
-rw-r--r--container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java42
1 files changed, 28 insertions, 14 deletions
diff --git a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
index 9ebe6d048b3..c0a677b2094 100644
--- a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
@@ -33,7 +33,6 @@ import java.util.function.Consumer;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class EventRendererTestCase {
@@ -97,12 +96,12 @@ public class EventRendererTestCase {
event: end
""";
- var tokenStream = EventStream.create("token_stream");
+ var tokenStream = new EventStream();
for (String token : splitter("Ducks have adorable waddling walks")) {
tokenStream.add(token);
}
tokenStream.markComplete();
- var result = render(new Result(new Query(), tokenStream));
+ var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
assertEquals(expected, result);
}
@@ -130,7 +129,7 @@ public class EventRendererTestCase {
var result = "";
var executor = Executors.newFixedThreadPool(1);
try {
- var tokenStream = EventStream.create("token_stream");
+ var tokenStream = new EventStream();
var future = completeAsync("Ducks have adorable waddling walks", executor, token -> {
tokenStream.add(token);
}).exceptionally(e -> {
@@ -141,7 +140,7 @@ public class EventRendererTestCase {
tokenStream.markComplete();
});
assertFalse(future.isDone());
- result = render(new Result(new Query(), tokenStream));
+ result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
assertTrue(future.isDone()); // Renderer waits for async completion
} finally {
@@ -152,12 +151,12 @@ public class EventRendererTestCase {
@Test
public void testErrorEndsStream() throws ExecutionException, InterruptedException {
- var tokenStream = EventStream.create("token_stream");
+ var tokenStream = new EventStream();
tokenStream.add("token1");
tokenStream.add("token2");
tokenStream.error("my_llm", new ErrorMessage(400, "Something went wrong"));
tokenStream.markComplete();
- var result = render(new Result(new Query(), tokenStream));
+ var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
var expected = """
event: token
data: {"token":"token1"}
@@ -177,13 +176,13 @@ public class EventRendererTestCase {
public void testPromptRendering() throws ExecutionException, InterruptedException {
String prompt = "Why are ducks better than cats?\n\nBe concise.\n";
- var tokenStream = EventStream.create("token_stream");
+ var tokenStream = new EventStream();
tokenStream.add(prompt, "prompt");
tokenStream.add("Just");
tokenStream.add(" because");
tokenStream.add(".");
tokenStream.markComplete();
- var result = render(new Result(new Query(), tokenStream));
+ var result = render(new Result(new Query(), newHitGroup(tokenStream, "token_stream")));
var expected = """
event: prompt
@@ -205,8 +204,8 @@ public class EventRendererTestCase {
@Test
@Timeout(5)
- public void testResultRenderingFails() {
- var tokenStream = EventStream.create("token_stream");
+ public void testResultRenderingIsSkipped() throws ExecutionException, InterruptedException {
+ var tokenStream = new EventStream();
tokenStream.add("token1");
tokenStream.add("token2");
tokenStream.markComplete();
@@ -221,10 +220,25 @@ public class EventRendererTestCase {
var combined = new HitGroup("all");
combined.add(resultsHitGroup);
- combined.add(tokenStream);
+ combined.add(newHitGroup(tokenStream, "token_stream"));
- var result = new Result(new Query(), combined);
- assertThrows(Exception.class, () -> render(result)); // Todo: support this
+ var result = render(new Result(new Query(), combined));
+ var expected = """
+ event: token
+ data: {"token":"token1"}
+
+ event: token
+ data: {"token":"token2"}
+
+ event: end
+ """;
+ assertEquals(expected, result); // Todo: support other types of data such as search results (hits), timing and trace
+ }
+
+ static HitGroup newHitGroup(EventStream eventStream, String id) {
+ var hitGroup = new HitGroup(id);
+ hitGroup.add(eventStream);
+ return hitGroup;
}
private String render(Result r) throws InterruptedException, ExecutionException {