diff options
Diffstat (limited to 'container-search/src/main')
3 files changed, 45 insertions, 11 deletions
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java index 4c39506ed96..d0d2cd3a442 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java @@ -14,11 +14,14 @@ import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; +import com.yahoo.search.rendering.JsonRenderer; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.EventStream; import com.yahoo.search.result.HitGroup; import com.yahoo.search.searchchain.Execution; +import com.yahoo.text.Utf8; +import java.io.ByteArrayOutputStream; import java.util.List; import java.util.concurrent.RejectedExecutionException; import java.util.function.Function; @@ -38,6 +41,10 @@ public class LLMSearcher extends Searcher { private static final String API_KEY_HEADER = "X-LLM-API-KEY"; private static final String STREAM_PROPERTY = "stream"; private static final String PROMPT_PROPERTY = "prompt"; + private static final String INCLUDE_PROMPT_IN_RESULT = "includePrompt"; + private static final String INCLUDE_HITS_IN_RESULT = "includeHits"; + + private final JsonRenderer jsonRenderer; private final String propertyPrefix; private final boolean stream; @@ -50,11 +57,13 @@ public class LLMSearcher extends Searcher { this.languageModelId = config.providerId(); this.languageModel = findLanguageModel(languageModelId, languageModels); this.propertyPrefix = config.propertyPrefix(); + + this.jsonRenderer = new JsonRenderer(); } @Override public Result search(Query query, Execution execution) { - return complete(query, StringPrompt.from(getPrompt(query))); + return complete(query, StringPrompt.from(getPrompt(query)), null, execution); } private LanguageModel findLanguageModel(String providerId, ComponentRegistry<LanguageModel> languageModels) @@ -81,30 +90,37 @@ public class LLMSearcher extends Searcher { return languageModel; } - protected Result complete(Query query, Prompt prompt) { + protected Result complete(Query query, Prompt prompt, Result result, Execution execution) { var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query)); var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config try { - return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options); + if (stream) { + return completeAsync(query, prompt, options, result, execution); + } + return completeSync(query, prompt, options, result, execution); } catch (RejectedExecutionException e) { return new Result(query, new ErrorMessage(429, e.getMessage())); } } private boolean shouldAddPrompt(Query query) { - return query.getTrace().getLevel() >= 1; + var includePrompt = lookupPropertyBool(INCLUDE_PROMPT_IN_RESULT, query, false); + return query.getTrace().getLevel() >= 1 || includePrompt; } private boolean shouldAddTokenStats(Query query) { return query.getTrace().getLevel() >= 1; } - private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) { + private Result completeAsync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) { final EventStream eventStream = new EventStream(); if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } + if (shouldAddHits(query) && result != null) { + eventStream.add(renderHits(result, execution), "hits"); + } final TokenStats tokenStats = new TokenStats(); languageModel.completeAsync(prompt, options, completion -> { @@ -143,12 +159,15 @@ public class LLMSearcher extends Searcher { eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage())); } - private Result completeSync(Query query, Prompt prompt, InferenceParameters options) { + private Result completeSync(Query query, Prompt prompt, InferenceParameters options, Result result, Execution execution) { EventStream eventStream = new EventStream(); if (shouldAddPrompt(query)) { eventStream.add(prompt.asString(), "prompt"); } + if (shouldAddHits(query) && result != null) { + eventStream.add(renderHits(result, execution), "hits"); + } List<Completion> completions = languageModel.complete(prompt, options); eventStream.add(completions.get(0).text(), "completion"); @@ -200,6 +219,18 @@ public class LLMSearcher extends Searcher { return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p)); } + private boolean shouldAddHits(Query query) { + return lookupPropertyBool(INCLUDE_HITS_IN_RESULT, query, false); + } + + private String renderHits(Result results, Execution execution) { + var bs = new ByteArrayOutputStream(); + var renderer = jsonRenderer.clone(); + renderer.init(); + renderer.renderResponse(bs, results, execution, null).join(); // wait for renderer to complete + return Utf8.toString(bs.toByteArray()); + } + private static class TokenStats { private final long start; diff --git a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java index cba153d881d..cdf57922bce 100755 --- a/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java +++ b/container-search/src/main/java/ai/vespa/search/llm/RAGSearcher.java @@ -37,7 +37,7 @@ public class RAGSearcher extends LLMSearcher { public Result search(Query query, Execution execution) { Result result = execution.search(query); execution.fill(result); - return complete(query, buildPrompt(query, result)); + return complete(query, buildPrompt(query, result), result, execution); } protected Prompt buildPrompt(Query query, Result result) { diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java index 88a1e6c1485..ffbb63514f1 100644 --- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java +++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java @@ -79,13 +79,16 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> { generator.writeRaw("event: " + event.type() + "\n"); } generator.writeRaw("data: "); - generator.writeStartObject(); - generator.writeStringField(event.type(), event.toString()); - generator.writeEndObject(); + if (event.type().equals("hits")) { + generator.writeRaw(event.toString()); + } else { + generator.writeStartObject(); + generator.writeStringField(event.type(), event.toString()); + generator.writeEndObject(); + } generator.writeRaw("\n\n"); generator.flush(); } - // Todo: support other types of data such as search results (hits), timing and trace } @Override |