diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-04-19 14:46:47 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-04-19 14:46:47 +0200 |
commit | 7406547ddead0bf97c30a912b7873d2c5fbd9e3a (patch) | |
tree | f7a9cb49c67b725a4a58909d0e770e6936fc7f75 | |
parent | dd2f89843c95423c243ee858ce07c8f9554bab54 (diff) | |
parent | bb387e070c81bb45cb31f47c7962bdf885ca522b (diff) |
Merge pull request #26777 from vespa-engine/bratseth/openai-client
Llm completion abstraction and OpenAi implementation
14 files changed, 352 insertions, 1 deletions
diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml index 4d5d801e0e3..69e1a94a813 100644 --- a/cloud-tenant-base-dependencies-enforcer/pom.xml +++ b/cloud-tenant-base-dependencies-enforcer/pom.xml @@ -30,11 +30,12 @@ <httpcore.version>4.4.16</httpcore.version> <junit5.version>5.8.1</junit5.version> <!-- TODO: in parent this is named 'junit.version' --> <onnxruntime.version>1.13.1</onnxruntime.version> + <openai-gpt3.version>0.12.0</openai-gpt3.version> <!-- END parent/pom.xml --> <!-- ALL BELOW MUST BE KEPT IN SYNC WITH container-dependency-versions pom - Copied here because vz-tenant-base does not have a parent. --> + Copied here because cloud-tenant-base does not have a parent. --> <aopalliance.version>1.0</aopalliance.version> <guava.version>27.1-jre</guava.version> <guice.version>4.2.3</guice.version> @@ -234,6 +235,18 @@ <include>org.osgi:org.osgi.compendium:4.1.0:test</include> <include>org.osgi:org.osgi.core:4.1.0:test</include> <include>xerces:xercesImpl:2.12.2:test</include> + + <include>com.squareup.okhttp3:okhttp:3.14.9:test</include> + <include>com.squareup.okio:okio:1.17.2:test</include> + <include>com.squareup.retrofit2:adapter-rxjava2:2.9.0:test</include> + <include>com.squareup.retrofit2:converter-jackson:2.9.0:test</include> + <include>com.squareup.retrofit2:retrofit:2.9.0:test</include> + <include>com.theokanning.openai-gpt3-java:api:${openai-gpt3.version}:test</include> + <include>com.theokanning.openai-gpt3-java:client:${openai-gpt3.version}:test</include> + <include>com.theokanning.openai-gpt3-java:service:${openai-gpt3.version}:test</include> + <include>io.reactivex.rxjava2:rxjava:2.0.0:test</include> + <include>org.reactivestreams:reactive-streams:1.0.3:test</include> + </allowed> </enforceDependencies> </rules> diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml index b58533b32e9..0ae2c68e6a8 100644 --- a/fat-model-dependencies/pom.xml +++ b/fat-model-dependencies/pom.xml @@ -96,6 +96,10 @@ <groupId>ai.djl.huggingface</groupId> <artifactId>tokenizers</artifactId> </exclusion> + <exclusion> + <groupId>com.theokanning.openai-gpt3-java</groupId> + <artifactId>service</artifactId> + </exclusion> </exclusions> </dependency> <dependency> diff --git a/model-integration/pom.xml b/model-integration/pom.xml index c27ed9d2c31..c96441f11a7 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -111,6 +111,11 @@ </dependency> <dependency> + <groupId>com.theokanning.openai-gpt3-java</groupId> + <artifactId>service</artifactId> + </dependency> + + <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> <scope>test</scope> @@ -146,6 +151,18 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> + <configuration> + <!-- + openai-gpt3-java depends on a different Jackson version than the one we provide, + which leads to warnings, so we must disable error on warnings. + --> + <compilerArgs> + <arg>-Xlint:all</arg> + <arg>-Xlint:-rawtypes</arg> + <arg>-Xlint:-unchecked</arg> + <arg>-Xlint:-serial</arg> + </compilerArgs> + </configuration> </plugin> <plugin> <groupId>com.github.os72</groupId> diff --git a/model-integration/src/main/java/ai/vespa/llm/Completion.java b/model-integration/src/main/java/ai/vespa/llm/Completion.java new file mode 100644 index 00000000000..5f483a65186 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/Completion.java @@ -0,0 +1,41 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm; + +import com.yahoo.api.annotations.Beta; + +import java.util.Objects; + +/** + * A completion from a language model. + * + * @author bratseth + */ +@Beta +public record Completion(String text, FinishReason finishReason) { + + public enum FinishReason { + + /** The maximum length of a completion was reached. */ + length, + + /** The completion is the predicted ending of the prompt. */ + stop + + } + + public Completion(String text, FinishReason finishReason) { + this.text = Objects.requireNonNull(text); + this.finishReason = Objects.requireNonNull(finishReason); + } + + /** Returns the generated text completion. */ + public String text() { return text; } + + /** Returns the reason this completion ended. */ + public FinishReason finishReason() { return finishReason; } + + public static Completion from(String text) { + return new Completion(text, FinishReason.stop); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/llm/Generator.java b/model-integration/src/main/java/ai/vespa/llm/Generator.java index 973b5ac2899..6b60041947b 100644 --- a/model-integration/src/main/java/ai/vespa/llm/Generator.java +++ b/model-integration/src/main/java/ai/vespa/llm/Generator.java @@ -13,6 +13,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.PartialAddress; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.api.annotations.Beta; import java.util.ArrayList; import java.util.List; @@ -27,6 +28,7 @@ import java.util.Map; * * @author lesters */ +@Beta public class Generator extends AbstractComponent { private final static int TOKEN_EOS = 1; // end of sequence diff --git a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java b/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java index 743bb7c2f27..8b490a733dd 100644 --- a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java +++ b/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java @@ -1,5 +1,8 @@ package ai.vespa.llm; +import com.yahoo.api.annotations.Beta; + +@Beta public class GeneratorOptions { public enum SearchMethod { diff --git a/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java new file mode 100644 index 00000000000..0739162c5ee --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java @@ -0,0 +1,18 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm; + +import com.yahoo.api.annotations.Beta; + +import java.util.List; + +/** + * Interface to language models. + * + * @author bratseth + */ +@Beta +public interface LanguageModel { + + List<Completion> complete(Prompt prompt); + +} diff --git a/model-integration/src/main/java/ai/vespa/llm/Prompt.java b/model-integration/src/main/java/ai/vespa/llm/Prompt.java new file mode 100644 index 00000000000..77093d5e21b --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/Prompt.java @@ -0,0 +1,23 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm; + +import com.yahoo.api.annotations.Beta; + +/** + * A prompt that can be given to a large language model to generate a completion. + * + * @author bratseth + */ +@Beta +public abstract class Prompt { + + public abstract String asString(); + + /** Returns a new prompt with the text of the given completion appended. */ + public Prompt append(Completion completion) { + return append(completion.text()); + } + + public abstract Prompt append(String text); + +} diff --git a/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java b/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java new file mode 100644 index 00000000000..0af8388dfb1 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java @@ -0,0 +1,43 @@ +package ai.vespa.llm; + +import com.yahoo.api.annotations.Beta; + +import java.util.Objects; + +/** + * A prompt which just consists of a string. + * + * @author bratseth + */ +@Beta +public class StringPrompt extends Prompt { + + private final String string; + + private StringPrompt(String string) { + this.string = Objects.requireNonNull(string); + } + + @Override + public String asString() { return string; } + + @Override + public StringPrompt append(String text) { + return StringPrompt.from(string + text); + } + + @Override + public StringPrompt append(Completion completion) { + return append(completion.text()); + } + + @Override + public String toString() { + return string; + } + + public static StringPrompt from(String string) { + return new StringPrompt(string); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java b/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java new file mode 100644 index 00000000000..3f4475b2482 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java @@ -0,0 +1,84 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.client; + +import ai.vespa.llm.Completion; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.Prompt; +import com.theokanning.openai.OpenAiHttpException; +import com.theokanning.openai.completion.CompletionRequest; +import com.theokanning.openai.service.OpenAiService; +import com.yahoo.api.annotations.Beta; +import com.yahoo.yolean.Exceptions; + +import java.util.List; + +/** + * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/. + * + * @author bratseth + */ +@Beta +public class OpenAiClient implements LanguageModel { + + private final OpenAiService openAiService; + private final String model; + private final boolean echo; + + private OpenAiClient(Builder builder) { + openAiService = new OpenAiService(builder.token); + this.model = builder.model; + this.echo = builder.echo; + } + + @Override + public List<Completion> complete(Prompt prompt) { + try { + CompletionRequest completionRequest = CompletionRequest.builder() + .prompt(prompt.asString()) + .model(model) + .echo(echo) + .build(); + return openAiService.createCompletion(completionRequest).getChoices().stream() + .map(c -> new Completion(c.getText(), toFinishReason(c.getFinish_reason()))).toList(); + } + catch (OpenAiHttpException e) { + throw new RuntimeException(Exceptions.toMessageString(e)); + } + } + + private Completion.FinishReason toFinishReason(String finishReasonString) { + return switch(finishReasonString) { + case "length" -> Completion.FinishReason.length; + case "stop" -> Completion.FinishReason.stop; + default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'"); + }; + } + + public static class Builder { + + private final String token; + private String model = "text-davinci-003"; + private boolean echo = false; + + public Builder(String token) { + this.token = token; + } + + /** One of the language models listed at https://platform.openai.com/docs/models */ + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder echo(boolean echo) { + this.echo = echo; + return this; + } + + public OpenAiClient build() { + return new OpenAiClient(this); + } + + } + +} diff --git a/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java new file mode 100644 index 00000000000..54b085a451c --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java @@ -0,0 +1,44 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm.test; + +import ai.vespa.llm.Completion; +import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.Prompt; +import com.yahoo.api.annotations.Beta; + +import java.util.List; +import java.util.function.Function; + +/** + * @author bratseth + */ +@Beta +public class MockLanguageModel implements LanguageModel { + + private final Function<Prompt, List<Completion>> completer; + + public MockLanguageModel(Builder builder) { + completer = builder.completer; + } + + @Override + public List<Completion> complete(Prompt prompt) { + return completer.apply(prompt); + } + + public static class Builder { + + private Function<Prompt, List<Completion>> completer = prompt -> List.of(Completion.from("")); + + public Builder completer(Function<Prompt, List<Completion>> completer) { + this.completer = completer; + return this; + } + + public Builder() {} + + public MockLanguageModel build() { return new MockLanguageModel(this); } + + } + +} diff --git a/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java b/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java new file mode 100644 index 00000000000..30b1c8c2fb1 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java @@ -0,0 +1,37 @@ +package ai.vespa.llm; + +import ai.vespa.llm.test.MockLanguageModel; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.function.Function; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Tests completion with a mock completer. + * + * @author bratseth + */ +public class CompletionTest { + + @Test + public void testCompletion() { + Function<Prompt, List<Completion>> completer = in -> + switch (in.asString()) { + case "Complete this: " -> List.of(Completion.from("The completion")); + default -> throw new RuntimeException("Cannot complete '" + in + "'"); + }; + var llm = new MockLanguageModel.Builder().completer(completer).build(); + + String input = "Complete this: "; + StringPrompt prompt = StringPrompt.from(input); + for (int i = 0; i < 10; i++) { + var completion = llm.complete(prompt).get(0); + prompt = prompt.append(completion); + if (completion.finishReason() == Completion.FinishReason.stop) break; + } + assertEquals("Complete this: The completion", prompt.asString()); + } + +} diff --git a/parent/pom.xml b/parent/pom.xml index ffd8c596277..76f4ef30dda 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -565,6 +565,17 @@ <version>${onnxruntime.version}</version> </dependency> <dependency> + <groupId>com.theokanning.openai-gpt3-java</groupId> + <artifactId>service</artifactId> + <version>${openai-gpt3.version}</version> + <exclusions> + <exclusion> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-databind</artifactId> + </exclusion> + </exclusions> + </dependency> + <dependency> <groupId>com.yahoo.athenz</groupId> <artifactId>athenz-cert-refresher</artifactId> <version>${athenz.version}</version> @@ -1171,6 +1182,7 @@ <netty.version>4.1.86.Final</netty.version> <netty-tcnative.version>2.0.54.Final</netty-tcnative.version> <onnxruntime.version>1.13.1</onnxruntime.version> <!-- WARNING: sync cloud-tenant-base-dependencies-enforcer/pom.xml --> + <openai-gpt3.version>0.12.0</openai-gpt3.version> <org.json.version>20230227</org.json.version> <org.lz4.version>1.8.0</org.lz4.version> <prometheus.client.version>0.6.0</prometheus.client.version> diff --git a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt index b5841d1c9e4..0d007097fa2 100644 --- a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt +++ b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt @@ -34,12 +34,20 @@ com.google.protobuf:protobuf-java:3.21.7 com.ibm.icu:icu4j:70.1 com.intellij:annotations:9.0.4 com.microsoft.onnxruntime:onnxruntime:1.13.1 +com.squareup.okhttp3:okhttp:3.14.9 +com.squareup.okio:okio:1.17.2 +com.squareup.retrofit2:adapter-rxjava2:2.9.0 +com.squareup.retrofit2:converter-jackson:2.9.0 +com.squareup.retrofit2:retrofit:2.9.0 com.sun.activation:javax.activation:1.2.0 com.sun.istack:istack-commons-runtime:3.0.8 com.sun.xml.bind:jaxb-core:2.3.0 com.sun.xml.bind:jaxb-impl:2.3.0 com.sun.xml.fastinfoset:FastInfoset:1.2.16 com.thaiopensource:jing:20091111 +com.theokanning.openai-gpt3-java:api:0.12.0 +com.theokanning.openai-gpt3-java:client:0.12.0 +com.theokanning.openai-gpt3-java:service:0.12.0 com.yahoo.athenz:athenz-auth-core:1.10.54 com.yahoo.athenz:athenz-client-common:1.10.54 com.yahoo.athenz:athenz-zms-core:1.10.54 @@ -69,6 +77,7 @@ io.netty:netty-transport-native-epoll:4.1.86.Final io.netty:netty-transport-native-unix-common:4.1.86.Final io.prometheus:simpleclient:0.6.0 io.prometheus:simpleclient_common:0.6.0 +io.reactivex.rxjava2:rxjava:2.0.0 javax.annotation:javax.annotation-api:1.2 javax.inject:javax.inject:1 javax.servlet:javax.servlet-api:3.1.0 @@ -201,6 +210,7 @@ org.ow2.asm:asm-commons:9.3 org.ow2.asm:asm-tree:9.3 org.ow2.asm:asm-util:9.3 org.questdb:questdb:6.2 +org.reactivestreams:reactive-streams:1.0.3 org.slf4j:jcl-over-slf4j:1.7.32 org.slf4j:log4j-over-slf4j:1.7.32 org.slf4j:slf4j-api:1.7.32 |