summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@vespa.ai>2023-06-13 23:40:39 +0200
committerJon Bratseth <bratseth@vespa.ai>2023-06-13 23:40:39 +0200
commit3567995f6b857b677a6e7dbf82f952a3dfc388cd (patch)
tree72d68e16bf6ba449537a8604ffcbc3be5783341e
parent50d7555bfe7bdaec86f8b31c4d316c9ba66bb976 (diff)
Get rid of third party openai client
-rw-r--r--openai-client/abi-spec.json29
-rw-r--r--openai-client/pom.xml75
-rw-r--r--openai-client/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java84
-rw-r--r--pom.xml1
-rw-r--r--vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java118
-rw-r--r--vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java (renamed from openai-client/src/main/java/ai/vespa/llm/client/openai/package-info.java)0
-rw-r--r--vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java29
-rw-r--r--vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java3
8 files changed, 147 insertions, 192 deletions
diff --git a/openai-client/abi-spec.json b/openai-client/abi-spec.json
deleted file mode 100644
index 039ca57fc64..00000000000
--- a/openai-client/abi-spec.json
+++ /dev/null
@@ -1,29 +0,0 @@
-{
- "ai.vespa.llm.client.openai.OpenAiClient$Builder" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [ ],
- "attributes" : [
- "public"
- ],
- "methods" : [
- "public void <init>(java.lang.String)",
- "public ai.vespa.llm.client.openai.OpenAiClient$Builder model(java.lang.String)",
- "public ai.vespa.llm.client.openai.OpenAiClient$Builder echo(boolean)",
- "public ai.vespa.llm.client.openai.OpenAiClient build()"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.client.openai.OpenAiClient" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "ai.vespa.llm.LanguageModel"
- ],
- "attributes" : [
- "public"
- ],
- "methods" : [
- "public java.util.List complete(ai.vespa.llm.completion.Prompt)"
- ],
- "fields" : [ ]
- }
-} \ No newline at end of file
diff --git a/openai-client/pom.xml b/openai-client/pom.xml
deleted file mode 100644
index 71a31a7b859..00000000000
--- a/openai-client/pom.xml
+++ /dev/null
@@ -1,75 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0"
- xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
- xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
- <modelVersion>4.0.0</modelVersion>
- <parent>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>parent</artifactId>
- <version>8-SNAPSHOT</version>
- <relativePath>../parent/pom.xml</relativePath>
- </parent>
- <artifactId>openai-client</artifactId>
- <packaging>container-plugin</packaging>
- <version>8-SNAPSHOT</version>
-
- <properties>
- <openai-gpt3.version>0.12.0</openai-gpt3.version>
- </properties>
-
- <dependencies>
- <dependency>
- <groupId>com.theokanning.openai-gpt3-java</groupId>
- <artifactId>service</artifactId>
- <version>${openai-gpt3.version}</version>
- </dependency>
- <dependency> <!-- Missing dependency of openai-gpt3 -->
- <groupId>com.squareup.retrofit2</groupId>
- <artifactId>converter-jackson</artifactId>
- <version>2.9.0</version>
- </dependency>
- <dependency>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>annotations</artifactId>
- <version>${project.version}</version>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>vespajlib</artifactId>
- <version>${project.version}</version>
- <scope>provided</scope>
- </dependency>
- </dependencies>
-
- <build>
- <plugins>
- <plugin>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>bundle-plugin</artifactId>
- <extensions>true</extensions>
- <configuration>
- <suppressWarningMissingImportPackages>true</suppressWarningMissingImportPackages>
- </configuration>
- </plugin>
- <plugin>
- <groupId>org.apache.maven.plugins</groupId>
- <artifactId>maven-compiler-plugin</artifactId>
- <configuration>
- <!-- openai-gpt3-java produces warnings -->
- <compilerArgs>
- <arg>-Xlint:all</arg>
- <arg>-Xlint:-rawtypes</arg>
- <arg>-Xlint:-unchecked</arg>
- <arg>-Xlint:-serial</arg>
- </compilerArgs>
- </configuration>
- </plugin>
- <plugin>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>abi-check-plugin</artifactId>
- </plugin>
- </plugins>
- </build>
-
-</project>
diff --git a/openai-client/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/openai-client/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
deleted file mode 100644
index 66be5ff1f69..00000000000
--- a/openai-client/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
+++ /dev/null
@@ -1,84 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.client.openai;
-
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.LanguageModel;
-import ai.vespa.llm.completion.Prompt;
-import com.theokanning.openai.OpenAiHttpException;
-import com.theokanning.openai.completion.CompletionRequest;
-import com.theokanning.openai.service.OpenAiService;
-import com.yahoo.api.annotations.Beta;
-import com.yahoo.yolean.Exceptions;
-
-import java.util.List;
-
-/**
- * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/.
- *
- * @author bratseth
- */
-@Beta
-public class OpenAiClient implements LanguageModel {
-
- private final OpenAiService openAiService;
- private final String model;
- private final boolean echo;
-
- private OpenAiClient(Builder builder) {
- openAiService = new OpenAiService(builder.token);
- this.model = builder.model;
- this.echo = builder.echo;
- }
-
- @Override
- public List<Completion> complete(Prompt prompt) {
- try {
- CompletionRequest completionRequest = CompletionRequest.builder()
- .prompt(prompt.asString())
- .model(model)
- .echo(echo)
- .build();
- return openAiService.createCompletion(completionRequest).getChoices().stream()
- .map(c -> new Completion(c.getText(), toFinishReason(c.getFinish_reason()))).toList();
- }
- catch (OpenAiHttpException e) {
- throw new RuntimeException(Exceptions.toMessageString(e));
- }
- }
-
- private Completion.FinishReason toFinishReason(String finishReasonString) {
- return switch(finishReasonString) {
- case "length" -> Completion.FinishReason.length;
- case "stop" -> Completion.FinishReason.stop;
- default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'");
- };
- }
-
- public static class Builder {
-
- private final String token;
- private String model = "text-davinci-003";
- private boolean echo = false;
-
- public Builder(String token) {
- this.token = token;
- }
-
- /** One of the language models listed at https://platform.openai.com/docs/models */
- public Builder model(String model) {
- this.model = model;
- return this;
- }
-
- public Builder echo(boolean echo) {
- this.echo = echo;
- return this;
- }
-
- public OpenAiClient build() {
- return new OpenAiClient(this);
- }
-
- }
-
-}
diff --git a/pom.xml b/pom.xml
index a601e847e2b..7e54b97383e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -101,7 +101,6 @@
<module>model-integration</module>
<module>node-repository</module>
<module>node-admin</module>
- <module>openai-client</module>
<module>opennlp-linguistics</module>
<module>orchestrator-restapi</module>
<module>orchestrator</module>
diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
new file mode 100644
index 00000000000..9145e76a2e0
--- /dev/null
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
@@ -0,0 +1,118 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.client.openai;
+
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.slime.ArrayTraverser;
+import com.yahoo.slime.Inspector;
+import com.yahoo.slime.Slime;
+import com.yahoo.slime.SlimeUtils;
+
+import java.io.IOException;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.net.http.HttpClient;
+import java.net.http.HttpRequest;
+import java.net.http.HttpResponse;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/.
+ * Currently only completions are implemented.
+ *
+ * @author bratseth
+ */
+@Beta
+public class OpenAiClient implements LanguageModel {
+
+ private final String token;
+ private final String model;
+ private final double temperature;
+ private final HttpClient httpClient;
+
+ private OpenAiClient(Builder builder) {
+ this.token = builder.token;
+ this.model = builder.model;
+ this.temperature = builder.temperature;
+ this.httpClient = HttpClient.newBuilder().build();
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt) {
+ try {
+ HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt), HttpResponse.BodyHandlers.ofByteArray());
+ var response = SlimeUtils.jsonToSlime(httpResponse.body()).get();
+ if ( httpResponse.statusCode() != 200)
+ throw new IllegalArgumentException(SlimeUtils.toJson(response));
+ return toCompletions(response);
+ }
+ catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException {
+ var slime = new Slime();
+ var root = slime.setObject();
+ root.setString("model", model);
+ root.setDouble("temperature", temperature);
+ root.setString("prompt", prompt.asString());
+ return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/completions"))
+ .header("Content-Type", "application/json")
+ .header("Authorization", "Bearer " + token)
+ .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime)))
+ .build();
+ }
+
+ private List<Completion> toCompletions(Inspector response) {
+ List<Completion> completions = new ArrayList<>();
+ response.field("choices")
+ .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice)));
+ return completions;
+ }
+
+ private Completion toCompletion(Inspector choice) {
+ return new Completion(choice.field("text").asString(),
+ toFinishReason(choice.field("finish_reason").asString()));
+ }
+
+ private Completion.FinishReason toFinishReason(String finishReasonString) {
+ return switch(finishReasonString) {
+ case "length" -> Completion.FinishReason.length;
+ case "stop" -> Completion.FinishReason.stop;
+ default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'");
+ };
+ }
+
+ public static class Builder {
+
+ private final String token;
+ private String model = "text-davinci-003";
+ private double temperature = 0;
+
+ public Builder(String token) {
+ this.token = token;
+ }
+
+ /** One of the language models listed at https://platform.openai.com/docs/models */
+ public Builder model(String model) {
+ this.model = model;
+ return this;
+ }
+
+ /** A value between 0 and 2 - higher gives more random/creative output. */
+ public Builder temperature(double temperature) {
+ this.temperature = temperature;
+ return this;
+ }
+
+ public OpenAiClient build() {
+ return new OpenAiClient(this);
+ }
+
+ }
+
+}
diff --git a/openai-client/src/main/java/ai/vespa/llm/client/openai/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java
index 8b8b99308b0..8b8b99308b0 100644
--- a/openai-client/src/main/java/ai/vespa/llm/client/openai/package-info.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java
diff --git a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java
new file mode 100644
index 00000000000..961a02afea3
--- /dev/null
+++ b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java
@@ -0,0 +1,29 @@
+package ai.vespa.llm.client.openai;
+
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.StringPrompt;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+/**
+ * @author bratseth
+ */
+public class OpenAiClientCompletionTest {
+
+ @Test
+ @Disabled
+ public void testClient() {
+ var client = new OpenAiClient.Builder("your token here").build();
+ String input = "You are an unhelpful assistant who never answers questions straightforwardly. " +
+ "Be as long-winded as possible. Are humans smarter than cats?";
+ StringPrompt prompt = StringPrompt.from(input);
+ System.out.print(prompt);
+ for (int i = 0; i < 10; i++) {
+ var completion = client.complete(prompt).get(0);
+ System.out.print(completion.text());
+ if (completion.finishReason() == Completion.FinishReason.stop) break;
+ prompt = prompt.append(completion.text());
+ }
+ }
+
+}
diff --git a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
index 1c794c64d1a..26508228ab6 100644
--- a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
+++ b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java
@@ -1,8 +1,5 @@
package ai.vespa.llm.completion;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import ai.vespa.llm.completion.StringPrompt;
import ai.vespa.llm.test.MockLanguageModel;
import org.junit.jupiter.api.Test;