summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@vespa.ai>2023-06-13 23:40:39 +0200
committerJon Bratseth <bratseth@vespa.ai>2023-06-13 23:40:39 +0200
commit3567995f6b857b677a6e7dbf82f952a3dfc388cd (patch)
tree72d68e16bf6ba449537a8604ffcbc3be5783341e /vespajlib
parent50d7555bfe7bdaec86f8b31c4d316c9ba66bb976 (diff)
Get rid of third party openai client
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java118
-rw-r--r--vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java11
-rw-r--r--vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java29
-rw-r--r--vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java3
4 files changed, 158 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
new file mode 100644
index 00000000000..9145e76a2e0
--- /dev/null
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
@@ -0,0 +1,118 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.client.openai;
+
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.slime.ArrayTraverser;
+import com.yahoo.slime.Inspector;
+import com.yahoo.slime.Slime;
+import com.yahoo.slime.SlimeUtils;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/.
+ * Currently only completions are implemented.
+ *
+ * @author bratseth
+ */
+@Beta
+public class OpenAiClient implements LanguageModel {
+
+ private final String token;
+ private final String model;
+ private final double temperature;
+ private final HttpClient httpClient;
+
+ private OpenAiClient(Builder builder) {
+ this.token = builder.token;
+ this.model = builder.model;
+ this.temperature = builder.temperature;
+ this.httpClient = HttpClient.newBuilder().build();
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt) {
+ try {
+ HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt), HttpResponse.BodyHandlers.ofByteArray());
+ var response = SlimeUtils.jsonToSlime(httpResponse.body()).get();
+ if ( httpResponse.statusCode() != 200)
+ throw new IllegalArgumentException(SlimeUtils.toJson(response));
+ return toCompletions(response);
+ }
+ catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException {
+ var slime = new Slime();
+ var root = slime.setObject();
+ root.setString("model", model);
+ root.setDouble("temperature", temperature);
+ root.setString("prompt", prompt.asString());
+ return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/completions"))
+ .header("Content-Type", "application/json")
+ .header("Authorization", "Bearer " + token)
+ .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime)))
+ .build();
+ }
+
+ private List<Completion> toCompletions(Inspector response) {
+ List<Completion> completions = new ArrayList<>();
+ response.field("choices")
+ .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice)));
+ return completions;
+ }
+
+ private Completion toCompletion(Inspector choice) {
+ return new Completion(choice.field("text").asString(),
+ toFinishReason(choice.field("finish_reason").asString()));
+ }
+
+ private Completion.FinishReason toFinishReason(String finishReasonString) {
+ return switch(finishReasonString) {
+ case "length" -> Completion.FinishReason.length;
+ case "stop" -> Completion.FinishReason.stop;
+ default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'");
+ };
+ }
+
+ public static class Builder {
+
+ private final String token;
+ private String model = "text-davinci-003";
+ private double temperature = 0;
+
+ public Builder(String token) {
+ this.token = token;
+ }
+
+ /** One of the language models listed at https://platform.openai.com/docs/models */
+ public Builder model(String model) {
+ this.model = model;
+ return this;
+ }
+
+ /** A value between 0 and 2 - higher gives more random/creative output. */
+ public Builder temperature(double temperature) {
+ this.temperature = temperature;
+ return this;
+ }
+
+ public OpenAiClient build() {
+ return new OpenAiClient(this);
+ }
+
+ }
+
+}
diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java
new file mode 100644
index 00000000000..8b8b99308b0
--- /dev/null
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java
@@ -0,0 +1,11 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+@PublicApi
+package ai.vespa.llm.client.openai;
+
+import com.yahoo.api.annotations.PublicApi;
+import com.yahoo.osgi.annotation.ExportPackage;
+
+/**
+ * Client to OpenAi's large language models.
+ */ \ No newline at end of file
diff --git a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java
new file mode 100644
index 00000000000..961a02afea3
--- /dev/null
+++ b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java
@@ -0,0 +1,29 @@
+package ai.vespa.llm.client.openai;
+
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.StringPrompt;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+/**
+ * @author bratseth
+ */
+public class OpenAiClientCompletionTest {
+
+ @Test
+ @Disabled
+ public void testClient() {
+ var client = new OpenAiClient.Builder("your token here").build();
+ String input = "You are an unhelpful assistant who never answers questions straightforwardly. " +
+ "Be as long-winded as possible. Are humans smarter than cats?";
+ StringPrompt prompt = StringPrompt.from(input);
+ System.out.print(prompt);
+ for (int i = 0; i < 10; i++) {
+ var completion = client.complete(prompt).get(0);
+ System.out.print(completion.text());
+ if (completion.finishReason() == Completion.FinishReason.stop) break;
+ prompt = prompt.append(completion.text());
+ }
+ }
+
+}
diff --git a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
index 1c794c64d1a..26508228ab6 100644
--- a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
+++ b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
@@ -1,8 +1,5 @@
package ai.vespa.llm.completion;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import ai.vespa.llm.completion.StringPrompt;
import ai.vespa.llm.test.MockLanguageModel;
import org.junit.jupiter.api.Test;