aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-04-19 14:46:47 +0200
committerGitHub <noreply@github.com>2023-04-19 14:46:47 +0200
commit7406547ddead0bf97c30a912b7873d2c5fbd9e3a (patch)
treef7a9cb49c67b725a4a58909d0e770e6936fc7f75
parentdd2f89843c95423c243ee858ce07c8f9554bab54 (diff)
parentbb387e070c81bb45cb31f47c7962bdf885ca522b (diff)
Merge pull request #26777 from vespa-engine/bratseth/openai-client
Llm completion abstraction and OpenAi implementation
-rw-r--r--cloud-tenant-base-dependencies-enforcer/pom.xml15
-rw-r--r--fat-model-dependencies/pom.xml4
-rw-r--r--model-integration/pom.xml17
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/Completion.java41
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/Generator.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/LanguageModel.java18
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/Prompt.java23
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/StringPrompt.java43
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java84
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java44
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/CompletionTest.java37
-rw-r--r--parent/pom.xml12
-rw-r--r--vespa-dependencies-enforcer/allowed-maven-dependencies.txt10
14 files changed, 352 insertions, 1 deletions
diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml
index 4d5d801e0e3..69e1a94a813 100644
--- a/cloud-tenant-base-dependencies-enforcer/pom.xml
+++ b/cloud-tenant-base-dependencies-enforcer/pom.xml
@@ -30,11 +30,12 @@
<httpcore.version>4.4.16</httpcore.version>
<junit5.version>5.8.1</junit5.version> <!-- TODO: in parent this is named 'junit.version' -->
<onnxruntime.version>1.13.1</onnxruntime.version>
+ <openai-gpt3.version>0.12.0</openai-gpt3.version>
<!-- END parent/pom.xml -->
<!-- ALL BELOW MUST BE KEPT IN SYNC WITH container-dependency-versions pom
- Copied here because vz-tenant-base does not have a parent. -->
+ Copied here because cloud-tenant-base does not have a parent. -->
<aopalliance.version>1.0</aopalliance.version>
<guava.version>27.1-jre</guava.version>
<guice.version>4.2.3</guice.version>
@@ -234,6 +235,18 @@
<include>org.osgi:org.osgi.compendium:4.1.0:test</include>
<include>org.osgi:org.osgi.core:4.1.0:test</include>
<include>xerces:xercesImpl:2.12.2:test</include>
+
+ <include>com.squareup.okhttp3:okhttp:3.14.9:test</include>
+ <include>com.squareup.okio:okio:1.17.2:test</include>
+ <include>com.squareup.retrofit2:adapter-rxjava2:2.9.0:test</include>
+ <include>com.squareup.retrofit2:converter-jackson:2.9.0:test</include>
+ <include>com.squareup.retrofit2:retrofit:2.9.0:test</include>
+ <include>com.theokanning.openai-gpt3-java:api:${openai-gpt3.version}:test</include>
+ <include>com.theokanning.openai-gpt3-java:client:${openai-gpt3.version}:test</include>
+ <include>com.theokanning.openai-gpt3-java:service:${openai-gpt3.version}:test</include>
+ <include>io.reactivex.rxjava2:rxjava:2.0.0:test</include>
+ <include>org.reactivestreams:reactive-streams:1.0.3:test</include>
+
</allowed>
</enforceDependencies>
</rules>
diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml
index b58533b32e9..0ae2c68e6a8 100644
--- a/fat-model-dependencies/pom.xml
+++ b/fat-model-dependencies/pom.xml
@@ -96,6 +96,10 @@
<groupId>ai.djl.huggingface</groupId>
<artifactId>tokenizers</artifactId>
</exclusion>
+ <exclusion>
+ <groupId>com.theokanning.openai-gpt3-java</groupId>
+ <artifactId>service</artifactId>
+ </exclusion>
</exclusions>
</dependency>
<dependency>
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index c27ed9d2c31..c96441f11a7 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -111,6 +111,11 @@
</dependency>
<dependency>
+ <groupId>com.theokanning.openai-gpt3-java</groupId>
+ <artifactId>service</artifactId>
+ </dependency>
+
+ <dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
@@ -146,6 +151,18 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
+ <configuration>
+ <!--
+ openai-gpt3-java depends on a different Jackson version than the one we provide,
+ which leads to warnings, so we must disable error on warnings.
+ -->
+ <compilerArgs>
+ <arg>-Xlint:all</arg>
+ <arg>-Xlint:-rawtypes</arg>
+ <arg>-Xlint:-unchecked</arg>
+ <arg>-Xlint:-serial</arg>
+ </compilerArgs>
+ </configuration>
</plugin>
<plugin>
<groupId>com.github.os72</groupId>
diff --git a/model-integration/src/main/java/ai/vespa/llm/Completion.java b/model-integration/src/main/java/ai/vespa/llm/Completion.java
new file mode 100644
index 00000000000..5f483a65186
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/Completion.java
@@ -0,0 +1,41 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.Objects;
+
+/**
+ * A completion from a language model.
+ *
+ * @author bratseth
+ */
+@Beta
+public record Completion(String text, FinishReason finishReason) {
+
+ public enum FinishReason {
+
+ /** The maximum length of a completion was reached. */
+ length,
+
+ /** The completion is the predicted ending of the prompt. */
+ stop
+
+ }
+
+ public Completion(String text, FinishReason finishReason) {
+ this.text = Objects.requireNonNull(text);
+ this.finishReason = Objects.requireNonNull(finishReason);
+ }
+
+ /** Returns the generated text completion. */
+ public String text() { return text; }
+
+ /** Returns the reason this completion ended. */
+ public FinishReason finishReason() { return finishReason; }
+
+ public static Completion from(String text) {
+ return new Completion(text, FinishReason.stop);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/Generator.java b/model-integration/src/main/java/ai/vespa/llm/Generator.java
index 973b5ac2899..6b60041947b 100644
--- a/model-integration/src/main/java/ai/vespa/llm/Generator.java
+++ b/model-integration/src/main/java/ai/vespa/llm/Generator.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.api.annotations.Beta;
import java.util.ArrayList;
import java.util.List;
@@ -27,6 +28,7 @@ import java.util.Map;
*
* @author lesters
*/
+@Beta
public class Generator extends AbstractComponent {
private final static int TOKEN_EOS = 1; // end of sequence
diff --git a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java b/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
index 743bb7c2f27..8b490a733dd 100644
--- a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
@@ -1,5 +1,8 @@
package ai.vespa.llm;
+import com.yahoo.api.annotations.Beta;
+
+@Beta
public class GeneratorOptions {
public enum SearchMethod {
diff --git a/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java
new file mode 100644
index 00000000000..0739162c5ee
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java
@@ -0,0 +1,18 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.List;
+
+/**
+ * Interface to language models.
+ *
+ * @author bratseth
+ */
+@Beta
+public interface LanguageModel {
+
+ List<Completion> complete(Prompt prompt);
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/Prompt.java b/model-integration/src/main/java/ai/vespa/llm/Prompt.java
new file mode 100644
index 00000000000..77093d5e21b
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/Prompt.java
@@ -0,0 +1,23 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+/**
+ * A prompt that can be given to a large language model to generate a completion.
+ *
+ * @author bratseth
+ */
+@Beta
+public abstract class Prompt {
+
+ public abstract String asString();
+
+ /** Returns a new prompt with the text of the given completion appended. */
+ public Prompt append(Completion completion) {
+ return append(completion.text());
+ }
+
+ public abstract Prompt append(String text);
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java b/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java
new file mode 100644
index 00000000000..0af8388dfb1
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java
@@ -0,0 +1,43 @@
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.Objects;
+
+/**
+ * A prompt which just consists of a string.
+ *
+ * @author bratseth
+ */
+@Beta
+public class StringPrompt extends Prompt {
+
+ private final String string;
+
+ private StringPrompt(String string) {
+ this.string = Objects.requireNonNull(string);
+ }
+
+ @Override
+ public String asString() { return string; }
+
+ @Override
+ public StringPrompt append(String text) {
+ return StringPrompt.from(string + text);
+ }
+
+ @Override
+ public StringPrompt append(Completion completion) {
+ return append(completion.text());
+ }
+
+ @Override
+ public String toString() {
+ return string;
+ }
+
+ public static StringPrompt from(String string) {
+ return new StringPrompt(string);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java b/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java
new file mode 100644
index 00000000000..3f4475b2482
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java
@@ -0,0 +1,84 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.client;
+
+import ai.vespa.llm.Completion;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.Prompt;
+import com.theokanning.openai.OpenAiHttpException;
+import com.theokanning.openai.completion.CompletionRequest;
+import com.theokanning.openai.service.OpenAiService;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.yolean.Exceptions;
+
+import java.util.List;
+
+/**
+ * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/.
+ *
+ * @author bratseth
+ */
+@Beta
+public class OpenAiClient implements LanguageModel {
+
+ private final OpenAiService openAiService;
+ private final String model;
+ private final boolean echo;
+
+ private OpenAiClient(Builder builder) {
+ openAiService = new OpenAiService(builder.token);
+ this.model = builder.model;
+ this.echo = builder.echo;
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt) {
+ try {
+ CompletionRequest completionRequest = CompletionRequest.builder()
+ .prompt(prompt.asString())
+ .model(model)
+ .echo(echo)
+ .build();
+ return openAiService.createCompletion(completionRequest).getChoices().stream()
+ .map(c -> new Completion(c.getText(), toFinishReason(c.getFinish_reason()))).toList();
+ }
+ catch (OpenAiHttpException e) {
+ throw new RuntimeException(Exceptions.toMessageString(e));
+ }
+ }
+
+ private Completion.FinishReason toFinishReason(String finishReasonString) {
+ return switch(finishReasonString) {
+ case "length" -> Completion.FinishReason.length;
+ case "stop" -> Completion.FinishReason.stop;
+ default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'");
+ };
+ }
+
+ public static class Builder {
+
+ private final String token;
+ private String model = "text-davinci-003";
+ private boolean echo = false;
+
+ public Builder(String token) {
+ this.token = token;
+ }
+
+ /** One of the language models listed at https://platform.openai.com/docs/models */
+ public Builder model(String model) {
+ this.model = model;
+ return this;
+ }
+
+ public Builder echo(boolean echo) {
+ this.echo = echo;
+ return this;
+ }
+
+ public OpenAiClient build() {
+ return new OpenAiClient(this);
+ }
+
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java
new file mode 100644
index 00000000000..54b085a451c
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java
@@ -0,0 +1,44 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.test;
+
+import ai.vespa.llm.Completion;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.Prompt;
+import com.yahoo.api.annotations.Beta;
+
+import java.util.List;
+import java.util.function.Function;
+
+/**
+ * @author bratseth
+ */
+@Beta
+public class MockLanguageModel implements LanguageModel {
+
+ private final Function<Prompt, List<Completion>> completer;
+
+ public MockLanguageModel(Builder builder) {
+ completer = builder.completer;
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt) {
+ return completer.apply(prompt);
+ }
+
+ public static class Builder {
+
+ private Function<Prompt, List<Completion>> completer = prompt -> List.of(Completion.from(""));
+
+ public Builder completer(Function<Prompt, List<Completion>> completer) {
+ this.completer = completer;
+ return this;
+ }
+
+ public Builder() {}
+
+ public MockLanguageModel build() { return new MockLanguageModel(this); }
+
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java b/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java
new file mode 100644
index 00000000000..30b1c8c2fb1
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java
@@ -0,0 +1,37 @@
+package ai.vespa.llm;
+
+import ai.vespa.llm.test.MockLanguageModel;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.function.Function;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Tests completion with a mock completer.
+ *
+ * @author bratseth
+ */
+public class CompletionTest {
+
+ @Test
+ public void testCompletion() {
+ Function<Prompt, List<Completion>> completer = in ->
+ switch (in.asString()) {
+ case "Complete this: " -> List.of(Completion.from("The completion"));
+ default -> throw new RuntimeException("Cannot complete '" + in + "'");
+ };
+ var llm = new MockLanguageModel.Builder().completer(completer).build();
+
+ String input = "Complete this: ";
+ StringPrompt prompt = StringPrompt.from(input);
+ for (int i = 0; i < 10; i++) {
+ var completion = llm.complete(prompt).get(0);
+ prompt = prompt.append(completion);
+ if (completion.finishReason() == Completion.FinishReason.stop) break;
+ }
+ assertEquals("Complete this: The completion", prompt.asString());
+ }
+
+}
diff --git a/parent/pom.xml b/parent/pom.xml
index ffd8c596277..76f4ef30dda 100644
--- a/parent/pom.xml
+++ b/parent/pom.xml
@@ -565,6 +565,17 @@
<version>${onnxruntime.version}</version>
</dependency>
<dependency>
+ <groupId>com.theokanning.openai-gpt3-java</groupId>
+ <artifactId>service</artifactId>
+ <version>${openai-gpt3.version}</version>
+ <exclusions>
+ <exclusion>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
<groupId>com.yahoo.athenz</groupId>
<artifactId>athenz-cert-refresher</artifactId>
<version>${athenz.version}</version>
@@ -1171,6 +1182,7 @@
<netty.version>4.1.86.Final</netty.version>
<netty-tcnative.version>2.0.54.Final</netty-tcnative.version>
<onnxruntime.version>1.13.1</onnxruntime.version> <!-- WARNING: sync cloud-tenant-base-dependencies-enforcer/pom.xml -->
+ <openai-gpt3.version>0.12.0</openai-gpt3.version>
<org.json.version>20230227</org.json.version>
<org.lz4.version>1.8.0</org.lz4.version>
<prometheus.client.version>0.6.0</prometheus.client.version>
diff --git a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt
index b5841d1c9e4..0d007097fa2 100644
--- a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt
+++ b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt
@@ -34,12 +34,20 @@ com.google.protobuf:protobuf-java:3.21.7
com.ibm.icu:icu4j:70.1
com.intellij:annotations:9.0.4
com.microsoft.onnxruntime:onnxruntime:1.13.1
+com.squareup.okhttp3:okhttp:3.14.9
+com.squareup.okio:okio:1.17.2
+com.squareup.retrofit2:adapter-rxjava2:2.9.0
+com.squareup.retrofit2:converter-jackson:2.9.0
+com.squareup.retrofit2:retrofit:2.9.0
com.sun.activation:javax.activation:1.2.0
com.sun.istack:istack-commons-runtime:3.0.8
com.sun.xml.bind:jaxb-core:2.3.0
com.sun.xml.bind:jaxb-impl:2.3.0
com.sun.xml.fastinfoset:FastInfoset:1.2.16
com.thaiopensource:jing:20091111
+com.theokanning.openai-gpt3-java:api:0.12.0
+com.theokanning.openai-gpt3-java:client:0.12.0
+com.theokanning.openai-gpt3-java:service:0.12.0
com.yahoo.athenz:athenz-auth-core:1.10.54
com.yahoo.athenz:athenz-client-common:1.10.54
com.yahoo.athenz:athenz-zms-core:1.10.54
@@ -69,6 +77,7 @@ io.netty:netty-transport-native-epoll:4.1.86.Final
io.netty:netty-transport-native-unix-common:4.1.86.Final
io.prometheus:simpleclient:0.6.0
io.prometheus:simpleclient_common:0.6.0
+io.reactivex.rxjava2:rxjava:2.0.0
javax.annotation:javax.annotation-api:1.2
javax.inject:javax.inject:1
javax.servlet:javax.servlet-api:3.1.0
@@ -201,6 +210,7 @@ org.ow2.asm:asm-commons:9.3
org.ow2.asm:asm-tree:9.3
org.ow2.asm:asm-util:9.3
org.questdb:questdb:6.2
+org.reactivestreams:reactive-streams:1.0.3
org.slf4j:jcl-over-slf4j:1.7.32
org.slf4j:log4j-over-slf4j:1.7.32
org.slf4j:slf4j-api:1.7.32