diff options
Diffstat (limited to 'container-search/src/test/java')
3 files changed, 79 insertions, 2 deletions
diff --git a/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java index d6b66b1a8c6..13b5f540a3a 100755 --- a/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java +++ b/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java @@ -115,7 +115,7 @@ public class RAGSearcherTest { return execution.search(query); } - private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) { + static Searcher createRAGSearcher(Map<String, LanguageModel> llms) { var config = new LlmSearcherConfig.Builder().stream(false).build(); ComponentRegistry<LanguageModel> models = new ComponentRegistry<>(); llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value)); diff --git a/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java b/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java new file mode 100644 index 00000000000..40aba0d5b6a --- /dev/null +++ b/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java @@ -0,0 +1,77 @@ +package ai.vespa.search.llm; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.search.Result; +import com.yahoo.search.rendering.EventRenderer; +import com.yahoo.search.searchchain.Execution; +import com.yahoo.text.Utf8; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayOutputStream; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class RAGWithEventRendererTest { + + @Test + public void testPromptAndHitsAreRendered() throws Exception { + var params = Map.of( + "query", "why are ducks better than cats?", + "llm.stream", "false", + "llm.includePrompt", "true", + "llm.includeHits", "true" + ); + var llm = LLMSearcherTest.createLLMClient(); + var searcher = RAGSearcherTest.createRAGSearcher(Map.of("mock", llm)); + var results = RAGSearcherTest.runMockSearch(searcher, params); + + var result = render(results); + + var promptEvent = extractEvent(result, "prompt"); + assertNotNull(promptEvent); + assertTrue(promptEvent.has("prompt")); + + var resultsEvent = extractEvent(result, "hits"); + assertNotNull(resultsEvent); + assertTrue(resultsEvent.has("root")); + assertEquals(2, resultsEvent.get("root").get("children").size()); + } + + private JsonNode extractEvent(String result, String eventName) throws JsonProcessingException { + var lines = result.split("\n"); + for (int i = 0; i < lines.length; i++) { + if (lines[i].startsWith("event: " + eventName)) { + var data = lines[i + 1].substring("data: ".length()).trim(); + ObjectMapper objectMapper = new ObjectMapper(); + return objectMapper.readTree(data); + } + } + return null; + } + + private String render(Result r) throws InterruptedException, ExecutionException { + var execution = new Execution(Execution.Context.createContextStub()); + return render(execution, r); + } + + private String render(Execution execution, Result r) throws ExecutionException, InterruptedException { + var renderer = new EventRenderer(); + try { + renderer.init(); + ByteArrayOutputStream bs = new ByteArrayOutputStream(); + CompletableFuture<Boolean> f = renderer.renderResponse(bs, r, execution, null); + assertTrue(f.get()); + return Utf8.toString(bs.toByteArray()); + } finally { + renderer.deconstruct(); + } + } + +} diff --git a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java index 2cfb6552379..f6f6f40bdae 100644 --- a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java @@ -232,7 +232,7 @@ public class EventRendererTestCase { event: end """; - assertEquals(expected, result); // Todo: support other types of data such as search results (hits), timing and trace + assertEquals(expected, result); } static HitGroup newHitGroup(EventStream eventStream, String id) { |