aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/test/java/ai
diff options
context:
space:
mode:
authorLester Solbakken <lester.solbakken@gmail.com>2024-06-10 15:54:10 +0200
committerLester Solbakken <lester.solbakken@gmail.com>2024-06-10 15:54:10 +0200
commit3e31b1f29cd1c6cc77a873dafd67dd8294ca2039 (patch)
tree509f2610b53bd4f64be3e56a2feff75fc25b807e /container-search/src/test/java/ai
parent367b751a72f52f8baa890ff0f2fe4a78653976fa (diff)
Add rendering of hits (and trace and timing etc) in llm rendering
Diffstat (limited to 'container-search/src/test/java/ai')
-rwxr-xr-xcontainer-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java2
-rw-r--r--container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java77
2 files changed, 78 insertions, 1 deletions
diff --git a/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java
index d6b66b1a8c6..13b5f540a3a 100755
--- a/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java
@@ -115,7 +115,7 @@ public class RAGSearcherTest {
return execution.search(query);
}
- private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
+ static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
var config = new LlmSearcherConfig.Builder().stream(false).build();
ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
diff --git a/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java b/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java
new file mode 100644
index 00000000000..40aba0d5b6a
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java
@@ -0,0 +1,77 @@
+package ai.vespa.search.llm;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.search.Result;
+import com.yahoo.search.rendering.EventRenderer;
+import com.yahoo.search.searchchain.Execution;
+import com.yahoo.text.Utf8;
+import org.junit.jupiter.api.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class RAGWithEventRendererTest {
+
+ @Test
+ public void testPromptAndHitsAreRendered() throws Exception {
+ var params = Map.of(
+ "query", "why are ducks better than cats?",
+ "llm.stream", "false",
+ "llm.includePrompt", "true",
+ "llm.includeHits", "true"
+ );
+ var llm = LLMSearcherTest.createLLMClient();
+ var searcher = RAGSearcherTest.createRAGSearcher(Map.of("mock", llm));
+ var results = RAGSearcherTest.runMockSearch(searcher, params);
+
+ var result = render(results);
+
+ var promptEvent = extractEvent(result, "prompt");
+ assertNotNull(promptEvent);
+ assertTrue(promptEvent.has("prompt"));
+
+ var resultsEvent = extractEvent(result, "hits");
+ assertNotNull(resultsEvent);
+ assertTrue(resultsEvent.has("root"));
+ assertEquals(2, resultsEvent.get("root").get("children").size());
+ }
+
+ private JsonNode extractEvent(String result, String eventName) throws JsonProcessingException {
+ var lines = result.split("\n");
+ for (int i = 0; i < lines.length; i++) {
+ if (lines[i].startsWith("event: " + eventName)) {
+ var data = lines[i + 1].substring("data: ".length()).trim();
+ ObjectMapper objectMapper = new ObjectMapper();
+ return objectMapper.readTree(data);
+ }
+ }
+ return null;
+ }
+
+ private String render(Result r) throws InterruptedException, ExecutionException {
+ var execution = new Execution(Execution.Context.createContextStub());
+ return render(execution, r);
+ }
+
+ private String render(Execution execution, Result r) throws ExecutionException, InterruptedException {
+ var renderer = new EventRenderer();
+ try {
+ renderer.init();
+ ByteArrayOutputStream bs = new ByteArrayOutputStream();
+ CompletableFuture<Boolean> f = renderer.renderResponse(bs, r, execution, null);
+ assertTrue(f.get());
+ return Utf8.toString(bs.toByteArray());
+ } finally {
+ renderer.deconstruct();
+ }
+ }
+
+}