From 0bab8eefe3443cc6f7befc13607b3c23602998b4 Mon Sep 17 00:00:00 2001
From: Lester Solbakken
Date: Mon, 16 Oct 2023 16:56:40 +0200
Subject: Add OpenAI async client

---
 .../ai/vespa/llm/client/openai/OpenAiClient.java   | 90 +++++++++++++++++++---
 1 file changed, 81 insertions(+), 9 deletions(-)

diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
index efa8927988c..c50731d1ae1 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
@@ -18,25 +18,34 @@ import java.net.http.HttpRequest;
 import java.net.http.HttpResponse;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/.
- * Currently only completions are implemented.
+ * Currently, only completions are implemented.
  *
  * @author bratseth
  */
 @Beta
 public class OpenAiClient implements LanguageModel {
 
+    private static final String DATA_FIELD = "data: ";
+
     private final String token;
     private final String model;
     private final double temperature;
+    private final long maxTokens;
+
     private final HttpClient httpClient;
 
     private OpenAiClient(Builder builder) {
         this.token = builder.token;
         this.model = builder.model;
         this.temperature = builder.temperature;
+        this.maxTokens = builder.maxTokens;
         this.httpClient = HttpClient.newBuilder().build();
     }
 
@@ -54,13 +63,63 @@ public class OpenAiClient implements LanguageModel {
         }
     }
 
+    @Override
+    public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action) {
+        try {
+            var request = toRequest(prompt, true);
+            var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines());
+            var completionFuture = new CompletableFuture<Completion.FinishReason>();
+
+            futureResponse.thenAcceptAsync(response -> {
+                try {
+                    int responseCode = response.statusCode();
+                    if (responseCode != 200) {
+                        throw new IllegalArgumentException("Received code " + responseCode + ": " +
+                                response.body().collect(Collectors.joining()));
+                    }
+
+                    Stream<String> lines = response.body();
+                    lines.forEach(line -> {
+                        if (line.startsWith(DATA_FIELD)) {
+                            var root = SlimeUtils.jsonToSlime(line.substring(DATA_FIELD.length())).get();
+                            var completion = toCompletions(root, "delta").get(0);
+                            action.accept(completion);
+                            if (!completion.finishReason().equals(Completion.FinishReason.none)) {
+                                completionFuture.complete(completion.finishReason());
+                            }
+                        }
+                    });
+                } catch (Exception e) {
+                    completionFuture.completeExceptionally(e);
+                }
+            });
+            return completionFuture;
+
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
     private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException {
+        return toRequest(prompt, false);
+    }
+
+    private HttpRequest toRequest(Prompt prompt, boolean stream) throws IOException, URISyntaxException {
         var slime = new Slime();
         var root = slime.setObject();
         root.setString("model", model);
         root.setDouble("temperature", temperature);
-        root.setString("prompt", prompt.asString());
-        return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/completions"))
+        root.setBool("stream", stream);
+        root.setLong("n", 1);
+        if (maxTokens > 0) {
+            root.setLong("max_tokens", maxTokens);
+        }
+        var messagesArray = root.setArray("messages");
+        var messagesObject = messagesArray.addObject();
+        messagesObject.setString("role", "user");
+        messagesObject.setString("content", prompt.asString());
+
+        return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/chat/completions"))
                 .header("Content-Type", "application/json")
                 .header("Authorization", "Bearer " + token)
                 .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime)))
@@ -68,21 +127,27 @@ public class OpenAiClient implements LanguageModel {
     }
 
     private List<Completion> toCompletions(Inspector response) {
+        return toCompletions(response, "message");
+    }
+
+    private List<Completion> toCompletions(Inspector response, String field) {
         List<Completion> completions = new ArrayList<>();
         response.field("choices")
-                .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice)));
+                .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice, field)));
         return completions;
     }
 
-    private Completion toCompletion(Inspector choice) {
-        return new Completion(choice.field("text").asString(),
-                              toFinishReason(choice.field("finish_reason").asString()));
+    private Completion toCompletion(Inspector choice, String field) {
+        var content = choice.field(field).field("content").asString();
+        var finishReason = toFinishReason(choice.field("finish_reason").asString());
+        return new Completion(content, finishReason);
     }
 
     private Completion.FinishReason toFinishReason(String finishReasonString) {
         return switch(finishReasonString) {
             case "length" -> Completion.FinishReason.length;
             case "stop" -> Completion.FinishReason.stop;
+            case "", "null" -> Completion.FinishReason.none;
             default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'");
         };
     }
@@ -90,8 +155,9 @@ public class OpenAiClient implements LanguageModel {
     public static class Builder {
 
         private final String token;
-        private String model = "text-davinci-003";
-        private double temperature = 0;
+        private String model = "gpt-3.5-turbo";
+        private double temperature = 0.0;
+        private long maxTokens = 0;
 
         public Builder(String token) {
             this.token = token;
@@ -109,6 +175,12 @@
             return this;
         }
 
+        /** The maximum number of tokens to generate. The default, 0, leaves the limit to the model. */
+        public Builder maxTokens(long maxTokens) {
+            this.maxTokens = maxTokens;
+            return this;
+        }
+
         public OpenAiClient build() {
             return new OpenAiClient(this);
         }
--
cgit v1.2.3
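
For reference, a minimal sketch of how a caller might drive the streaming API added in this patch. It assumes the builder also exposes model(...) and temperature(...) setters, that StringPrompt.from(...) in ai.vespa.llm.completion builds a Prompt from a plain string, and that the API key is read from an OPENAI_API_KEY environment variable; none of these names are defined by the diff above.

import ai.vespa.llm.client.openai.OpenAiClient;
import ai.vespa.llm.completion.StringPrompt;

public class OpenAiStreamingExample {

    public static void main(String[] args) {
        // Assumed builder usage; maxTokens(...) is the option added in this patch.
        var client = new OpenAiClient.Builder(System.getenv("OPENAI_API_KEY"))
                .model("gpt-3.5-turbo")
                .temperature(0.0)
                .maxTokens(100)
                .build();

        // Assumed helper: StringPrompt.from(...) wraps a plain string as a Prompt.
        var prompt = StringPrompt.from("Write one sentence about streaming APIs.");

        // completeAsync streams "delta" chunks to the consumer as they arrive and
        // completes the returned future with the finish reason of the final chunk.
        var finishReason = client.completeAsync(prompt, completion -> System.out.print(completion.text()))
                                 .join();

        System.out.println();
        System.out.println("Finish reason: " + finishReason);
    }
}

Each streamed Completion carries only the newest delta text, so the consumer prints tokens as they arrive, while the CompletableFuture resolves once the server reports a terminal finish reason.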