aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java')
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java43
1 files changed, 37 insertions, 6 deletions
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
index 4c39506ed96..d0d2cd3a442 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
@@ -14,11 +14,14 @@ import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
+import com.yahoo.search.rendering.JsonRenderer;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.result.EventStream;
import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
+import com.yahoo.text.Utf8;
+import java.io.ByteArrayOutputStream;
import java.util.List;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
@@ -38,6 +41,10 @@ public class LLMSearcher extends Searcher {
private static final String API_KEY_HEADER = "X-LLM-API-KEY";
private static final String STREAM_PROPERTY = "stream";
private static final String PROMPT_PROPERTY = "prompt";
+ private static final String INCLUDE_PROMPT_IN_RESULT = "includePrompt";
+ private static final String INCLUDE_HITS_IN_RESULT = "includeHits";
+
+ private final JsonRenderer jsonRenderer;
private final String propertyPrefix;
private final boolean stream;
@@ -50,11 +57,13 @@ public class LLMSearcher extends Searcher {
this.languageModelId = config.providerId();
this.languageModel = findLanguageModel(languageModelId, languageModels);
this.propertyPrefix = config.propertyPrefix();
+
+ this.jsonRenderer = new JsonRenderer();
}
@Override
public Result search(Query query, Execution execution) {
- return complete(query, StringPrompt.from(getPrompt(query)));
+ return complete(query, StringPrompt.from(getPrompt(query)), null, execution);
}
private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels)
@@ -81,30 +90,37 @@ public class LLMSearcher extends Searcher {
return languageModel;
}
- protected Result complete(Query query, Prompt prompt) {
+ protected Result complete(Query query, Prompt prompt, Result result, Execution execution) {
var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
try {
- return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ if (stream) {
+ return completeAsync(query, prompt, options, result, execution);
+ }
+ return completeSync(query, prompt, options, result, execution);
} catch (RejectedExecutionException e) {
return new Result(query, new ErrorMessage(429, e.getMessage()));
}
}
private boolean shouldAddPrompt(Query query) {
- return query.getTrace().getLevel() >= 1;
+ var includePrompt = lookupPropertyBool(INCLUDE_PROMPT_IN_RESULT, query, false);
+ return query.getTrace().getLevel() >= 1 || includePrompt;
}
private boolean shouldAddTokenStats(Query query) {
return query.getTrace().getLevel() >= 1;
}
- private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
+ private Result completeAsync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) {
final EventStream eventStream = new EventStream();
if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
+ if (shouldAddHits(query) && result != null) {
+ eventStream.add(renderHits(result, execution), "hits");
+ }
final TokenStats tokenStats = new TokenStats();
languageModel.completeAsync(prompt, options, completion -> {
@@ -143,12 +159,15 @@ public class LLMSearcher extends Searcher {
eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
}
- private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
+ private Result completeSync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) {
EventStream eventStream = new EventStream();
if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
+ if (shouldAddHits(query) && result != null) {
+ eventStream.add(renderHits(result, execution), "hits");
+ }
List<Completion> completions = languageModel.complete(prompt, options);
eventStream.add(completions.get(0).text(), "completion");
@@ -200,6 +219,18 @@ public class LLMSearcher extends Searcher {
return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
}
+ private boolean shouldAddHits(Query query) {
+ return lookupPropertyBool(INCLUDE_HITS_IN_RESULT, query, false);
+ }
+
+ private String renderHits(Result results, Execution execution) {
+ var bs = new ByteArrayOutputStream();
+ var renderer = jsonRenderer.clone();
+ renderer.init();
+ renderer.renderResponse(bs, results, execution, null).join(); // wait for renderer to complete
+ return Utf8.toString(bs.toByteArray());
+ }
+
private static class TokenStats {
private final long start;