diff options
author | Jon Bratseth <bratseth@vespa.ai> | 2023-06-13 23:40:39 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@vespa.ai> | 2023-06-13 23:40:39 +0200 |
commit | 3567995f6b857b677a6e7dbf82f952a3dfc388cd (patch) | |
tree | 72d68e16bf6ba449537a8604ffcbc3be5783341e /vespajlib | |
parent | 50d7555bfe7bdaec86f8b31c4d316c9ba66bb976 (diff) |
Get rid of third party openai client
Diffstat (limited to 'vespajlib')
4 files changed, 158 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
new file mode 100644
index 00000000000..9145e76a2e0
--- /dev/null
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
@@ -0,0 +1,118 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.client.openai;
+
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.slime.ArrayTraverser;
+import com.yahoo.slime.Inspector;
+import com.yahoo.slime.Slime;
+import com.yahoo.slime.SlimeUtils;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/.
+ * Currently only completions are implemented.
+ *
+ * @author bratseth
+ */
+@Beta
+public class OpenAiClient implements LanguageModel {
+
+    private final String token;
+    private final String model;
+    private final double temperature;
+    private final HttpClient httpClient;
+
+    private OpenAiClient(Builder builder) {
+        this.token = builder.token;
+        this.model = builder.model;
+        this.temperature = builder.temperature;
+        this.httpClient = HttpClient.newBuilder().build();
+    }
+
+    @Override
+    public List<Completion> complete(Prompt prompt) {
+        try {
+            HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt), HttpResponse.BodyHandlers.ofByteArray());
+            var response = SlimeUtils.jsonToSlime(httpResponse.body()).get();
+            if ( httpResponse.statusCode() != 200)
+                throw new IllegalArgumentException(SlimeUtils.toJson(response));
+            return toCompletions(response);
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException {
+        var slime = new Slime();
+        var root = slime.setObject();
+        root.setString("model", model);
+        root.setDouble("temperature", temperature);
+        root.setString("prompt", prompt.asString());
+        return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/completions"))
+                          .header("Content-Type", "application/json")
+                          .header("Authorization", "Bearer " + token)
+                          .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime)))
+                          .build();
+    }
+
+    private List<Completion> toCompletions(Inspector response) {
+        List<Completion> completions = new ArrayList<>();
+        response.field("choices")
+                .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice)));
+        return completions;
+    }
+
+    private Completion toCompletion(Inspector choice) {
+        return new Completion(choice.field("text").asString(),
+                              toFinishReason(choice.field("finish_reason").asString()));
+    }
+
+    private Completion.FinishReason toFinishReason(String finishReasonString) {
+        return switch(finishReasonString) {
+            case "length" -> Completion.FinishReason.length;
+            case "stop" -> Completion.FinishReason.stop;
+            default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'");
+        };
+    }
+
+    public static class Builder {
+
+        private final String token;
+        private String model = "text-davinci-003";
+        private double temperature = 0;
+
+        public Builder(String token) {
+            this.token = token;
+        }
+
+        /** One of the language models listed at https://platform.openai.com/docs/models */
+        public Builder model(String model) {
+            this.model = model;
+            return this;
+        }
+
+        /** A value between 0 and 2 - higher gives more random/creative output. */
+        public Builder temperature(double temperature) {
+            this.temperature = temperature;
+            return this;
+        }
+
+        public OpenAiClient build() {
+            return new OpenAiClient(this);
+        }
+
+    }
+
+}
diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java
new file mode 100644
index 00000000000..8b8b99308b0
--- /dev/null
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java
@@ -0,0 +1,11 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.client.openai;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
+
+/**
+ * Client to OpenAi's large language models.
+ */
\ No newline at end of file
diff --git a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java
new file mode 100644
index 00000000000..961a02afea3
--- /dev/null
+++ b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java
@@ -0,0 +1,29 @@
+package ai.vespa.llm.client.openai;
+
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.StringPrompt;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+/**
+ * @author bratseth
+ */
+public class OpenAiClientCompletionTest {
+
+    @Test
+    @Disabled
+    public void testClient() {
+        var client = new OpenAiClient.Builder("your token here").build();
+        String input = "You are an unhelpful assistant who never answers questions straightforwardly. " +
+                       "Be as long-winded as possible. Are humans smarter than cats?";
+        StringPrompt prompt = StringPrompt.from(input);
+        System.out.print(prompt);
+        for (int i = 0; i < 10; i++) {
+            var completion = client.complete(prompt).get(0);
+            System.out.print(completion.text());
+            if (completion.finishReason() == Completion.FinishReason.stop) break;
+            prompt = prompt.append(completion.text());
+        }
+    }
+
+}
diff --git a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
index 1c794c64d1a..26508228ab6 100644
--- a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
+++ b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
@@ -1,8 +1,5 @@
 package ai.vespa.llm.completion;
 
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import ai.vespa.llm.completion.StringPrompt;
 import ai.vespa.llm.test.MockLanguageModel;
 import org.junit.jupiter.api.Test;
 
|