summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Harald Musum <musum@yahooinc.com> 2024-06-10 18:50:58 +0200
committer: GitHub <noreply@github.com> 2024-06-10 18:50:58 +0200
commit: 7d5685f0b5e1988a02d4aa74eed4254e486fc26f (patch)
tree: eaa56f2a91e98311aa3ca030e9e92a6bb694c94d
parent: ccfb0e8294b21dd799adf41ca892f69763c1b222 (diff)
parent: 3e31b1f29cd1c6cc77a873dafd67dd8294ca2039 (diff)
Merge pull request #31519 from vespa-engine/lesters/render-hits-in-sse [v8.355.18]
Add rendering of hits (and trace and timing etc) in llm rendering
-rw-r--r-- container-search/abi-spec.json | 2
-rwxr-xr-x container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java | 43
-rwxr-xr-x container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java | 2
-rw-r--r-- container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java | 11
-rwxr-xr-x container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java | 2
-rw-r--r-- container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java | 77
-rw-r--r-- container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java | 2
7 files changed, 125 insertions, 14 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 80fe11b7174..8fbf12b16b4 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -9227,7 +9227,7 @@
"methods" : [
"public void <init>(ai.vespa.search.llm.LlmSearcherConfig, com.yahoo.component.provider.ComponentRegistry)",
"public com.yahoo.search.Result search(com.yahoo.search.Query, com.yahoo.search.searchchain.Execution)",
- "protected com.yahoo.search.Result complete(com.yahoo.search.Query, ai.vespa.llm.completion.Prompt)",
+ "protected com.yahoo.search.Result complete(com.yahoo.search.Query, ai.vespa.llm.completion.Prompt, com.yahoo.search.Result, com.yahoo.search.searchchain.Execution)",
"public java.lang.String getPrompt(com.yahoo.search.Query)",
"public java.lang.String getPropertyPrefix()",
"public java.lang.String lookupProperty(java.lang.String, com.yahoo.search.Query)",
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
index 4c39506ed96..d0d2cd3a442 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
@@ -14,11 +14,14 @@ import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
+import com.yahoo.search.rendering.JsonRenderer;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
+import com.yahoo.text.Utf8;
+import java.io.ByteArrayOutputStream;
import java.util.List;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
@@ -38,6 +41,10 @@ public class LLMSearcher extends Searcher {
private static final String API_KEY_HEADER = "X-LLM-API-KEY";
private static final String STREAM_PROPERTY = "stream";
private static final String PROMPT_PROPERTY = "prompt";
+ private static final String INCLUDE_PROMPT_IN_RESULT = "includePrompt";
+ private static final String INCLUDE_HITS_IN_RESULT = "includeHits";
+
+ private final JsonRenderer jsonRenderer;
private final String propertyPrefix;
private final boolean stream;
@@ -50,11 +57,13 @@ public class LLMSearcher extends Searcher {
this.languageModelId = config.providerId();
this.languageModel = findLanguageModel(languageModelId, languageModels);
this.propertyPrefix = config.propertyPrefix();
+
+ this.jsonRenderer = new JsonRenderer();
}
@Override
public Result search(Query query, Execution execution) {
- return complete(query, StringPrompt.from(getPrompt(query)));
+ return complete(query, StringPrompt.from(getPrompt(query)), null, execution);
}
private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels)
@@ -81,30 +90,37 @@ public class LLMSearcher extends Searcher {
return languageModel;
}
- protected Result complete(Query query, Prompt prompt) {
+ protected Result complete(Query query, Prompt prompt, Result result, Execution execution) {
var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
try {
- return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ if (stream) {
+ return completeAsync(query, prompt, options, result, execution);
+ }
+ return completeSync(query, prompt, options, result, execution);
} catch (RejectedExecutionException e) {
return new Result(query, new ErrorMessage(429, e.getMessage()));
}
}
private boolean shouldAddPrompt(Query query) {
- return query.getTrace().getLevel() >= 1;
+ var includePrompt = lookupPropertyBool(INCLUDE_PROMPT_IN_RESULT, query, false);
+ return query.getTrace().getLevel() >= 1 || includePrompt;
}
private boolean shouldAddTokenStats(Query query) {
return query.getTrace().getLevel() >= 1;
}
- private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
+ private Result completeAsync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) {
final EventStream eventStream = new EventStream();
if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
+ if (shouldAddHits(query) && result != null) {
+ eventStream.add(renderHits(result, execution), "hits");
+ }
final TokenStats tokenStats = new TokenStats();
languageModel.completeAsync(prompt, options, completion -> {
@@ -143,12 +159,15 @@ public class LLMSearcher extends Searcher {
eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
}
- private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
+ private Result completeSync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) {
EventStream eventStream = new EventStream();
if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
+ if (shouldAddHits(query) && result != null) {
+ eventStream.add(renderHits(result, execution), "hits");
+ }
List<Completion> completions = languageModel.complete(prompt, options);
eventStream.add(completions.get(0).text(), "completion");
@@ -200,6 +219,18 @@ public class LLMSearcher extends Searcher {
return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
}
+ private boolean shouldAddHits(Query query) {
+ return lookupPropertyBool(INCLUDE_HITS_IN_RESULT, query, false);
+ }
+
+ private String renderHits(Result results, Execution execution) {
+ var bs = new ByteArrayOutputStream();
+ var renderer = jsonRenderer.clone();
+ renderer.init();
+ renderer.renderResponse(bs, results, execution, null).join(); // wait for renderer to complete
+ return Utf8.toString(bs.toByteArray());
+ }
+
private static class TokenStats {
private final long start;
diff --git a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java
index cba153d881d..cdf57922bce 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java
@@ -37,7 +37,7 @@ public class RAGSearcher extends LLMSearcher {
public Result search(Query query, Execution execution) {
Result result = execution.search(query);
execution.fill(result);
- return complete(query, buildPrompt(query, result));
+ return complete(query, buildPrompt(query, result), result, execution);
}
protected Prompt buildPrompt(Query query, Result result) {
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
index 88a1e6c1485..ffbb63514f1 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -79,13 +79,16 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
generator.writeRaw("event: " + event.type() + "\n");
}
generator.writeRaw("data: ");
- generator.writeStartObject();
- generator.writeStringField(event.type(), event.toString());
- generator.writeEndObject();
+ if (event.type().equals("hits")) {
+ generator.writeRaw(event.toString());
+ } else {
+ generator.writeStartObject();
+ generator.writeStringField(event.type(), event.toString());
+ generator.writeEndObject();
+ }
generator.writeRaw("\n\n");
generator.flush();
}
- // Todo: support other types of data such as search results (hits), timing and trace
}
@Override
diff --git a/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java
index d6b66b1a8c6..13b5f540a3a 100755
--- a/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/RAGSearcherTest.java
@@ -115,7 +115,7 @@ public class RAGSearcherTest {
return execution.search(query);
}
- private static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
+ static Searcher createRAGSearcher(Map<String, LanguageModel> llms) {
var config = new LlmSearcherConfig.Builder().stream(false).build();
ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
diff --git a/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java b/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java
new file mode 100644
index 00000000000..40aba0d5b6a
--- /dev/null
+++ b/container-search/src/test/java/ai/vespa/search/llm/RAGWithEventRendererTest.java
@@ -0,0 +1,77 @@
+package ai.vespa.search.llm;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.search.Result;
+import com.yahoo.search.rendering.EventRenderer;
+import com.yahoo.search.searchchain.Execution;
+import com.yahoo.text.Utf8;
+import org.junit.jupiter.api.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class RAGWithEventRendererTest {
+
+ @Test
+ public void testPromptAndHitsAreRendered() throws Exception {
+ var params = Map.of(
+ "query", "why are ducks better than cats?",
+ "llm.stream", "false",
+ "llm.includePrompt", "true",
+ "llm.includeHits", "true"
+ );
+ var llm = LLMSearcherTest.createLLMClient();
+ var searcher = RAGSearcherTest.createRAGSearcher(Map.of("mock", llm));
+ var results = RAGSearcherTest.runMockSearch(searcher, params);
+
+ var result = render(results);
+
+ var promptEvent = extractEvent(result, "prompt");
+ assertNotNull(promptEvent);
+ assertTrue(promptEvent.has("prompt"));
+
+ var resultsEvent = extractEvent(result, "hits");
+ assertNotNull(resultsEvent);
+ assertTrue(resultsEvent.has("root"));
+ assertEquals(2, resultsEvent.get("root").get("children").size());
+ }
+
+ private JsonNode extractEvent(String result, String eventName) throws JsonProcessingException {
+ var lines = result.split("\n");
+ for (int i = 0; i < lines.length; i++) {
+ if (lines[i].startsWith("event: " + eventName)) {
+ var data = lines[i + 1].substring("data: ".length()).trim();
+ ObjectMapper objectMapper = new ObjectMapper();
+ return objectMapper.readTree(data);
+ }
+ }
+ return null;
+ }
+
+ private String render(Result r) throws InterruptedException, ExecutionException {
+ var execution = new Execution(Execution.Context.createContextStub());
+ return render(execution, r);
+ }
+
+ private String render(Execution execution, Result r) throws ExecutionException, InterruptedException {
+ var renderer = new EventRenderer();
+ try {
+ renderer.init();
+ ByteArrayOutputStream bs = new ByteArrayOutputStream();
+ CompletableFuture<Boolean> f = renderer.renderResponse(bs, r, execution, null);
+ assertTrue(f.get());
+ return Utf8.toString(bs.toByteArray());
+ } finally {
+ renderer.deconstruct();
+ }
+ }
+
+}
diff --git a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
index 2cfb6552379..f6f6f40bdae 100644
--- a/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/rendering/EventRendererTestCase.java
@@ -232,7 +232,7 @@ public class EventRendererTestCase {
event: end
""";
- assertEquals(expected, result); // Todo: support other types of data such as search results (hits), timing and trace
+ assertEquals(expected, result);
}
static HitGroup newHitGroup(EventStream eventStream, String id) {