summaryrefslogtreecommitdiffstats
path: root/container-search/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/test')
-rwxr-xr-xcontainer-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java2
-rw-r--r--container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java77
-rw-r--r--container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java2
3 files changed, 79 insertions, 2 deletions
diff --git a/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java
index d6b66b1a8c6..13b5f540a3a 100755
--- a/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java
@@ -115,7 +115,7 @@ public class RAGSearcherTest {
return execution.search(query);
}
- private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
+ static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
var config = new LlmSearcherConfig.Builder().stream(false).build();
ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
diff --git a/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java b/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java
new file mode 100644
index 00000000000..40aba0d5b6a
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java
@@ -0,0 +1,77 @@
+package ai.vespa.search.llm;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.search.Result;
+import com.yahoo.search.rendering.EventRenderer;
+import com.yahoo.search.searchchain.Execution;
+import com.yahoo.text.Utf8;
+import org.junit.jupiter.api.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class RAGWithEventRendererTest {
+
+ @Test
+ public void testPromptAndHitsAreRendered() throws Exception {
+ var params = Map.of(
+ "query", "why are ducks better than cats?",
+ "llm.stream", "false",
+ "llm.includePrompt", "true",
+ "llm.includeHits", "true"
+ );
+ var llm = LLMSearcherTest.createLLMClient();
+ var searcher = RAGSearcherTest.createRAGSearcher(Map.of("mock", llm));
+ var results = RAGSearcherTest.runMockSearch(searcher, params);
+
+ var result = render(results);
+
+ var promptEvent = extractEvent(result, "prompt");
+ assertNotNull(promptEvent);
+ assertTrue(promptEvent.has("prompt"));
+
+ var resultsEvent = extractEvent(result, "hits");
+ assertNotNull(resultsEvent);
+ assertTrue(resultsEvent.has("root"));
+ assertEquals(2, resultsEvent.get("root").get("children").size());
+ }
+
+ private JsonNode extractEvent(String result, String eventName) throws JsonProcessingException {
+ var lines = result.split("\n");
+ for (int i = 0; i < lines.length; i++) {
+ if (lines[i].startsWith("event: " + eventName)) {
+ var data = lines[i + 1].substring("data: ".length()).trim();
+ ObjectMapper objectMapper = new ObjectMapper();
+ return objectMapper.readTree(data);
+ }
+ }
+ return null;
+ }
+
+ private String render(Result r) throws InterruptedException, ExecutionException {
+ var execution = new Execution(Execution.Context.createContextStub());
+ return render(execution, r);
+ }
+
+ private String render(Execution execution, Result r) throws ExecutionException, InterruptedException {
+ var renderer = new EventRenderer();
+ try {
+ renderer.init();
+ ByteArrayOutputStream bs = new ByteArrayOutputStream();
+ CompletableFuture<Boolean> f = renderer.renderResponse(bs, r, execution, null);
+ assertTrue(f.get());
+ return Utf8.toString(bs.toByteArray());
+ } finally {
+ renderer.deconstruct();
+ }
+ }
+
+}
diff --git a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
index 2cfb6552379..f6f6f40bdae 100644
--- a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
@@ -232,7 +232,7 @@ public class EventRendererTestCase {
event: end
""";
- assertEquals(expected, result); // Todo: support other types of data such as search results (hits), timing and trace
+ assertEquals(expected, result);
}
static HitGroup newHitGroup(EventStream eventStream, String id) {