diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2023-10-17 13:29:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-10-17 13:29:43 +0200 |
commit | d82eefec840e49c196ca843cceeae6fb91c71b9f (patch) | |
tree | b73237e7355850c46a2617d43a3261b20062919c | |
parent | 9c78c8d33c29d38aaa17df0497f1d48cb4d8c80e (diff) | |
parent | 42e0c3ad9d3efebde578313cfd059e3d895f0300 (diff) |
Merge pull request #28953 from vespa-engine/lesters/openai-async-client
Lesters/openai async client
6 files changed, 123 insertions, 16 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 1c19c2ba5d6..3e588e24d47 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -4045,7 +4045,8 @@ "abstract" ], "methods" : [ - "public abstract java.util.List complete(ai.vespa.llm.completion.Prompt)" + "public abstract java.util.List complete(ai.vespa.llm.completion.Prompt)", + "public abstract java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, java.util.function.Consumer)" ], "fields" : [ ] }, @@ -4059,6 +4060,7 @@ "public void <init>(java.lang.String)", "public ai.vespa.llm.client.openai.OpenAiClient$Builder model(java.lang.String)", "public ai.vespa.llm.client.openai.OpenAiClient$Builder temperature(double)", + "public ai.vespa.llm.client.openai.OpenAiClient$Builder maxTokens(long)", "public ai.vespa.llm.client.openai.OpenAiClient build()" ], "fields" : [ ] @@ -4072,7 +4074,8 @@ "public" ], "methods" : [ - "public java.util.List complete(ai.vespa.llm.completion.Prompt)" + "public java.util.List complete(ai.vespa.llm.completion.Prompt)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, java.util.function.Consumer)" ], "fields" : [ ] }, @@ -4090,7 +4093,8 @@ ], "fields" : [ "public static final enum ai.vespa.llm.completion.Completion$FinishReason length", - "public static final enum ai.vespa.llm.completion.Completion$FinishReason stop" + "public static final enum ai.vespa.llm.completion.Completion$FinishReason stop", + "public static final enum ai.vespa.llm.completion.Completion$FinishReason none" ] }, "ai.vespa.llm.completion.Completion" : { @@ -4167,7 +4171,8 @@ ], "methods" : [ "public void <init>(ai.vespa.llm.test.MockLanguageModel$Builder)", - "public java.util.List complete(ai.vespa.llm.completion.Prompt)" + "public java.util.List complete(ai.vespa.llm.completion.Prompt)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, java.util.function.Consumer)" ], "fields" : [ ] } diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java index bd9004a659b..f4b8938934b 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java +++ b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java @@ -6,6 +6,8 @@ import ai.vespa.llm.completion.Prompt; import com.yahoo.api.annotations.Beta; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; /** * Interface to language models. @@ -17,4 +19,6 @@ public interface LanguageModel { List<Completion> complete(Prompt prompt); + CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action); + } diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java index efa8927988c..d7334b40963 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java +++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java @@ -18,25 +18,34 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/. - * Currently only completions are implemented. + * Currently, only completions are implemented. * * @author bratseth */ @Beta public class OpenAiClient implements LanguageModel { + private static final String DATA_FIELD = "data: "; + private final String token; private final String model; private final double temperature; + private final long maxTokens; + private final HttpClient httpClient; private OpenAiClient(Builder builder) { this.token = builder.token; this.model = builder.model; this.temperature = builder.temperature; + this.maxTokens = builder.maxTokens; this.httpClient = HttpClient.newBuilder().build(); } @@ -54,13 +63,63 @@ public class OpenAiClient implements LanguageModel { } } + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> consumer) { + try { + var request = toRequest(prompt, true); + var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines()); + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + + futureResponse.thenAcceptAsync(response -> { + try { + int responseCode = response.statusCode(); + if (responseCode != 200) { + throw new IllegalArgumentException("Received code " + responseCode + ": " + + response.body().collect(Collectors.joining())); + } + + Stream<String> lines = response.body(); + lines.forEach(line -> { + if (line.startsWith(DATA_FIELD)) { + var root = SlimeUtils.jsonToSlime(line.substring(DATA_FIELD.length())).get(); + var completion = toCompletions(root, "delta").get(0); + consumer.accept(completion); + if (!completion.finishReason().equals(Completion.FinishReason.none)) { + completionFuture.complete(completion.finishReason()); + } + } + }); + } catch (Exception e) { + completionFuture.completeExceptionally(e); + } + }); + return completionFuture; + + } catch (Exception e) { + throw new RuntimeException(e); + } + } + private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException { + return toRequest(prompt, false); + } + + private HttpRequest toRequest(Prompt prompt, boolean stream) throws IOException, URISyntaxException { var slime = new Slime(); var root = slime.setObject(); root.setString("model", model); root.setDouble("temperature", temperature); - root.setString("prompt", prompt.asString()); - return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/completions")) + root.setBool("stream", stream); + root.setLong("n", 1); + if (maxTokens > 0) { + root.setLong("max_tokens", maxTokens); + } + var messagesArray = root.setArray("messages"); + var messagesObject = messagesArray.addObject(); + messagesObject.setString("role", "user"); + messagesObject.setString("content", prompt.asString()); + + return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/chat/completions")) .header("Content-Type", "application/json") .header("Authorization", "Bearer " + token) .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime))) @@ -68,21 +127,27 @@ public class OpenAiClient implements LanguageModel { } private List<Completion> toCompletions(Inspector response) { + return toCompletions(response, "message"); + } + + private List<Completion> toCompletions(Inspector response, String field) { List<Completion> completions = new ArrayList<>(); response.field("choices") - .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice))); + .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice, field))); return completions; } - private Completion toCompletion(Inspector choice) { - return new Completion(choice.field("text").asString(), - toFinishReason(choice.field("finish_reason").asString())); + private Completion toCompletion(Inspector choice, String field) { + var content = choice.field(field).field("content").asString(); + var finishReason = toFinishReason(choice.field("finish_reason").asString()); + return new Completion(content, finishReason); } private Completion.FinishReason toFinishReason(String finishReasonString) { return switch(finishReasonString) { case "length" -> Completion.FinishReason.length; case "stop" -> Completion.FinishReason.stop; + case "", "null" -> Completion.FinishReason.none; default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'"); }; } @@ -90,8 +155,9 @@ public class OpenAiClient implements LanguageModel { public static class Builder { private final String token; - private String model = "text-davinci-003"; - private double temperature = 0; + private String model = "gpt-3.5-turbo"; + private double temperature = 0.0; + private long maxTokens = 0; public Builder(String token) { this.token = token; @@ -109,6 +175,12 @@ public class OpenAiClient implements LanguageModel { return this; } + /** Maximum number of tokens to generate */ + public Builder maxTokens(long maxTokens) { + this.maxTokens = maxTokens; + return this; + } + public OpenAiClient build() { return new OpenAiClient(this); } diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java index f5731852d93..ea784013812 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java +++ b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java @@ -19,8 +19,10 @@ public record Completion(String text, FinishReason finishReason) { length, /** The completion is the predicted ending of the prompt. */ - stop + stop, + /** The completion is not finished yet, more tokens are incoming. */ + none } public Completion(String text, FinishReason finishReason) { diff --git a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java index 16e9c4e1848..db1b42fbbac 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java +++ b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java @@ -7,6 +7,8 @@ import ai.vespa.llm.completion.Prompt; import com.yahoo.api.annotations.Beta; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; import java.util.function.Function; /** @@ -26,6 +28,11 @@ public class MockLanguageModel implements LanguageModel { return completer.apply(prompt); } + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action) { + throw new RuntimeException("Not implemented"); + } + public static class Builder { private Function<Prompt, List<Completion>> completer = prompt -> List.of(Completion.from("")); diff --git a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java index 444f082b1c0..45ef7e270aa 100644 --- a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java +++ b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java @@ -11,12 +11,14 @@ import org.junit.jupiter.api.Test; */ public class OpenAiClientCompletionTest { + private static final String apiKey = "your-api-key-here"; + @Test @Disabled public void testClient() { - var client = new OpenAiClient.Builder("your token here").build(); + var client = new OpenAiClient.Builder(apiKey).maxTokens(10).build(); String input = "You are an unhelpful assistant who never answers questions straightforwardly. " + - "Be as long-winded as possible. Are humans smarter than cats?"; + "Be as long-winded as possible. Are humans smarter than cats?\n\n"; StringPrompt prompt = StringPrompt.from(input); System.out.print(prompt); for (int i = 0; i < 10; i++) { @@ -27,4 +29,19 @@ public class OpenAiClientCompletionTest { } } + @Test + @Disabled + public void testAsyncClient() { + var client = new OpenAiClient.Builder(apiKey).build(); + String input = "You are an unhelpful assistant who never answers questions straightforwardly. " + + "Be as long-winded as possible. Are humans smarter than cats?\n\n"; + StringPrompt prompt = StringPrompt.from(input); + System.out.print(prompt); + var future = client.completeAsync(prompt, completion -> { + System.out.print(completion.text()); + }); + System.out.println("Waiting for completion..."); + System.out.println("\nFinished streaming because of " + future.join()); + } + } |