diff options
Diffstat (limited to 'container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java')
-rwxr-xr-x | container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java | 43 |
1 files changed, 37 insertions, 6 deletions
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java index 4c39506ed96..d0d2cd3a442 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java @@ -14,11 +14,14 @@ import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; +import com.yahoo.search.rendering.JsonRenderer; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.EventStream; import com.yahoo.search.result.HitGroup; import com.yahoo.search.searchchain.Execution; +import com.yahoo.text.Utf8; +import java.io.ByteArrayOutputStream; import java.util.List; import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; @@ -38,6 +41,10 @@ public class LLMSearcher extends Searcher { private static final String API_KEY_HEADER = "X-LLM-API-KEY"; private static final String STREAM_PROPERTY = "stream"; private static final String PROMPT_PROPERTY = "prompt"; + private static final String INCLUDE_PROMPT_IN_RESULT = "includePrompt"; + private static final String INCLUDE_HITS_IN_RESULT = "includeHits"; + + private final JsonRenderer jsonRenderer; private final String propertyPrefix; private final boolean stream; @@ -50,11 +57,13 @@ public class LLMSearcher extends Searcher { this.languageModelId = config.providerId(); this.languageModel = findLanguageModel(languageModelId, languageModels); this.propertyPrefix = config.propertyPrefix(); + + this.jsonRenderer = new JsonRenderer(); } @Override public Result search(Query query, Execution execution) { - return complete(query, StringPrompt.from(getPrompt(query))); + return complete(query, StringPrompt.from(getPrompt(query)), null, execution); } private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels) @@ -81,30 +90,37 @@ public class LLMSearcher extends Searcher { return languageModel; } - protected Result complete(Query query, Prompt prompt) { + protected Result complete(Query query, Prompt prompt, Result result, Execution execution) { var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config try { - return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); + if (stream) { + return completeAsync(query, prompt, options, result, execution); + } + return completeSync(query, prompt, options, result, execution); } catch (RejectedExecutionException e) { return new Result(query, new ErrorMessage(429, e.getMessage())); } } private boolean shouldAddPrompt(Query query) { - return query.getTrace().getLevel() >= 1; + var includePrompt = lookupPropertyBool(INCLUDE_PROMPT_IN_RESULT, query, false); + return query.getTrace().getLevel() >= 1 || includePrompt; } private boolean shouldAddTokenStats(Query query) { return query.getTrace().getLevel() >= 1; } - private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { + private Result completeAsync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) { final EventStream eventStream = new EventStream(); if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } + if (shouldAddHits(query) && result != null) { + eventStream.add(renderHits(result, execution), "hits"); + } final TokenStats tokenStats = new TokenStats(); languageModel.completeAsync(prompt, options, completion -> { @@ -143,12 +159,15 @@ public class LLMSearcher extends Searcher { eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); } - private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { + private Result completeSync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) { EventStream eventStream = new EventStream(); if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } + if (shouldAddHits(query) && result != null) { + eventStream.add(renderHits(result, execution), "hits"); + } List<Completion> completions = languageModel.complete(prompt, options); eventStream.add(completions.get(0).text(), "completion"); @@ -200,6 +219,18 @@ public class LLMSearcher extends Searcher { return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); } + private boolean shouldAddHits(Query query) { + return lookupPropertyBool(INCLUDE_HITS_IN_RESULT, query, false); + } + + private String renderHits(Result results, Execution execution) { + var bs = new ByteArrayOutputStream(); + var renderer = jsonRenderer.clone(); + renderer.init(); + renderer.renderResponse(bs, results, execution, null).join(); // wait for renderer to complete + return Utf8.toString(bs.toByteArray()); + } + private static class TokenStats { private final long start; |