diff options
author | Jon Bratseth <bratseth@vespa.ai> | 2023-06-13 23:40:39 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@vespa.ai> | 2023-06-13 23:40:39 +0200 |
commit | 3567995f6b857b677a6e7dbf82f952a3dfc388cd (patch) | |
tree | 72d68e16bf6ba449537a8604ffcbc3be5783341e | |
parent | 50d7555bfe7bdaec86f8b31c4d316c9ba66bb976 (diff) |
Get rid of third party openai client
-rw-r--r-- | openai-client/abi-spec.json | 29 | ||||
-rw-r--r-- | openai-client/pom.xml | 75 | ||||
-rw-r--r-- | openai-client/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java | 84 | ||||
-rw-r--r-- | pom.xml | 1 | ||||
-rw-r--r-- | vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java | 118 | ||||
-rw-r--r-- | vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java (renamed from openai-client/src/main/java/ai/vespa/llm/client/openai/package-info.java) | 0 | ||||
-rw-r--r-- | vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java | 29 | ||||
-rw-r--r-- | vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java | 3 |
8 files changed, 147 insertions, 192 deletions
diff --git a/openai-client/abi-spec.json b/openai-client/abi-spec.json deleted file mode 100644 index 039ca57fc64..00000000000 --- a/openai-client/abi-spec.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "ai.vespa.llm.client.openai.OpenAiClient$Builder" : { - "superClass" : "java.lang.Object", - "interfaces" : [ ], - "attributes" : [ - "public" - ], - "methods" : [ - "public void <init>(java.lang.String)", - "public ai.vespa.llm.client.openai.OpenAiClient$Builder model(java.lang.String)", - "public ai.vespa.llm.client.openai.OpenAiClient$Builder echo(boolean)", - "public ai.vespa.llm.client.openai.OpenAiClient build()" - ], - "fields" : [ ] - }, - "ai.vespa.llm.client.openai.OpenAiClient" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "ai.vespa.llm.LanguageModel" - ], - "attributes" : [ - "public" - ], - "methods" : [ - "public java.util.List complete(ai.vespa.llm.completion.Prompt)" - ], - "fields" : [ ] - } -}
\ No newline at end of file diff --git a/openai-client/pom.xml b/openai-client/pom.xml deleted file mode 100644 index 71a31a7b859..00000000000 --- a/openai-client/pom.xml +++ /dev/null @@ -1,75 +0,0 @@ -<?xml version="1.0" encoding="UTF-8"?> -<project xmlns="http://maven.apache.org/POM/4.0.0" - xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" - xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> - <modelVersion>4.0.0</modelVersion> - <parent> - <groupId>com.yahoo.vespa</groupId> - <artifactId>parent</artifactId> - <version>8-SNAPSHOT</version> - <relativePath>../parent/pom.xml</relativePath> - </parent> - <artifactId>openai-client</artifactId> - <packaging>container-plugin</packaging> - <version>8-SNAPSHOT</version> - - <properties> - <openai-gpt3.version>0.12.0</openai-gpt3.version> - </properties> - - <dependencies> - <dependency> - <groupId>com.theokanning.openai-gpt3-java</groupId> - <artifactId>service</artifactId> - <version>${openai-gpt3.version}</version> - </dependency> - <dependency> <!-- Missing dependency of openai-gpt3 --> - <groupId>com.squareup.retrofit2</groupId> - <artifactId>converter-jackson</artifactId> - <version>2.9.0</version> - </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>annotations</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> - <dependency> - <groupId>com.yahoo.vespa</groupId> - <artifactId>vespajlib</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> - </dependencies> - - <build> - <plugins> - <plugin> - <groupId>com.yahoo.vespa</groupId> - <artifactId>bundle-plugin</artifactId> - <extensions>true</extensions> - <configuration> - <suppressWarningMissingImportPackages>true</suppressWarningMissingImportPackages> - </configuration> - </plugin> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-compiler-plugin</artifactId> - <configuration> - <!-- openai-gpt3-java produces warnings --> - <compilerArgs> - <arg>-Xlint:all</arg> - <arg>-Xlint:-rawtypes</arg> - <arg>-Xlint:-unchecked</arg> - <arg>-Xlint:-serial</arg> - </compilerArgs> - </configuration> - </plugin> - <plugin> - <groupId>com.yahoo.vespa</groupId> - <artifactId>abi-check-plugin</artifactId> - </plugin> - </plugins> - </build> - -</project> diff --git a/openai-client/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/openai-client/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java deleted file mode 100644 index 66be5ff1f69..00000000000 --- a/openai-client/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.client.openai; - -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.completion.Prompt; -import com.theokanning.openai.OpenAiHttpException; -import com.theokanning.openai.completion.CompletionRequest; -import com.theokanning.openai.service.OpenAiService; -import com.yahoo.api.annotations.Beta; -import com.yahoo.yolean.Exceptions; - -import java.util.List; - -/** - * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/. - * - * @author bratseth - */ -@Beta -public class OpenAiClient implements LanguageModel { - - private final OpenAiService openAiService; - private final String model; - private final boolean echo; - - private OpenAiClient(Builder builder) { - openAiService = new OpenAiService(builder.token); - this.model = builder.model; - this.echo = builder.echo; - } - - @Override - public List<Completion> complete(Prompt prompt) { - try { - CompletionRequest completionRequest = CompletionRequest.builder() - .prompt(prompt.asString()) - .model(model) - .echo(echo) - .build(); - return openAiService.createCompletion(completionRequest).getChoices().stream() - .map(c -> new Completion(c.getText(), toFinishReason(c.getFinish_reason()))).toList(); - } - catch (OpenAiHttpException e) { - throw new RuntimeException(Exceptions.toMessageString(e)); - } - } - - private Completion.FinishReason toFinishReason(String finishReasonString) { - return switch(finishReasonString) { - case "length" -> Completion.FinishReason.length; - case "stop" -> Completion.FinishReason.stop; - default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'"); - }; - } - - public static class Builder { - - private final String token; - private String model = "text-davinci-003"; - private boolean echo = false; - - public Builder(String token) { - this.token = token; - } - - /** One of the language models listed at https://platform.openai.com/docs/models */ - public Builder model(String model) { - this.model = model; - return this; - } - - public Builder echo(boolean echo) { - this.echo = echo; - return this; - } - - public OpenAiClient build() { - return new OpenAiClient(this); - } - - } - -} @@ -101,7 +101,6 @@ <module>model-integration</module> <module>node-repository</module> <module>node-admin</module> - <module>openai-client</module> <module>opennlp-linguistics</module> <module>orchestrator-restapi</module> <module>orchestrator</module> diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java new file mode 100644 index 00000000000..9145e76a2e0 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java @@ -0,0 +1,118 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.client.openai; + +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.completion.Prompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.slime.ArrayTraverser; +import com.yahoo.slime.Inspector; +import com.yahoo.slime.Slime; +import com.yahoo.slime.SlimeUtils; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.ArrayList; +import java.util.List; + +/** + * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/. + * Currently only completions are implemented. + * + * @author bratseth + */ +@Beta +public class OpenAiClient implements LanguageModel { + + private final String token; + private final String model; + private final double temperature; + private final HttpClient httpClient; + + private OpenAiClient(Builder builder) { + this.token = builder.token; + this.model = builder.model; + this.temperature = builder.temperature; + this.httpClient = HttpClient.newBuilder().build(); + } + + @Override + public List<Completion> complete(Prompt prompt) { + try { + HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt), HttpResponse.BodyHandlers.ofByteArray()); + var response = SlimeUtils.jsonToSlime(httpResponse.body()).get(); + if ( httpResponse.statusCode() != 200) + throw new IllegalArgumentException(SlimeUtils.toJson(response)); + return toCompletions(response); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException { + var slime = new Slime(); + var root = slime.setObject(); + root.setString("model", model); + root.setDouble("temperature", temperature); + root.setString("prompt", prompt.asString()); + return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/completions")) + .header("Content-Type", "application/json") + .header("Authorization", "Bearer " + token) + .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime))) + .build(); + } + + private List<Completion> toCompletions(Inspector response) { + List<Completion> completions = new ArrayList<>(); + response.field("choices") + .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice))); + return completions; + } + + private Completion toCompletion(Inspector choice) { + return new Completion(choice.field("text").asString(), + toFinishReason(choice.field("finish_reason").asString())); + } + + private Completion.FinishReason toFinishReason(String finishReasonString) { + return switch(finishReasonString) { + case "length" -> Completion.FinishReason.length; + case "stop" -> Completion.FinishReason.stop; + default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'"); + }; + } + + public static class Builder { + + private final String token; + private String model = "text-davinci-003"; + private double temperature = 0; + + public Builder(String token) { + this.token = token; + } + + /** One of the language models listed at https://platform.openai.com/docs/models */ + public Builder model(String model) { + this.model = model; + return this; + } + + /** A value between 0 and 2 - higher gives more random/creative output. */ + public Builder temperature(double temperature) { + this.temperature = temperature; + return this; + } + + public OpenAiClient build() { + return new OpenAiClient(this); + } + + } + +} diff --git a/openai-client/src/main/java/ai/vespa/llm/client/openai/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java index 8b8b99308b0..8b8b99308b0 100644 --- a/openai-client/src/main/java/ai/vespa/llm/client/openai/package-info.java +++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java diff --git a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java new file mode 100644 index 00000000000..961a02afea3 --- /dev/null +++ b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java @@ -0,0 +1,29 @@ +package ai.vespa.llm.client.openai; + +import ai.vespa.llm.completion.Completion; +import ai.vespa.llm.completion.StringPrompt; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +/** + * @author bratseth + */ +public class OpenAiClientCompletionTest { + + @Test + @Disabled + public void testClient() { + var client = new OpenAiClient.Builder("your token here").build(); + String input = "You are an unhelpful assistant who never answers questions straightforwardly. " + + "Be as long-winded as possible. Are humans smarter than cats?"; + StringPrompt prompt = StringPrompt.from(input); + System.out.print(prompt); + for (int i = 0; i < 10; i++) { + var completion = client.complete(prompt).get(0); + System.out.print(completion.text()); + if (completion.finishReason() == Completion.FinishReason.stop) break; + prompt = prompt.append(completion.text()); + } + } + +} diff --git a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java index 1c794c64d1a..26508228ab6 100644 --- a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java +++ b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java @@ -1,8 +1,5 @@ package ai.vespa.llm.completion; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import ai.vespa.llm.completion.StringPrompt; import ai.vespa.llm.test.MockLanguageModel; import org.junit.jupiter.api.Test; |