diff options
author | Jon Bratseth <bratseth@vespa.ai> | 2023-04-25 20:06:15 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@vespa.ai> | 2023-04-25 20:06:15 +0200 |
commit | b1c12e25e9698501440b46bca37ada23c5116239 (patch) | |
tree | b620b6da03bd81f8805733f626ebf375a5a25357 /model-integration | |
parent | d5f17d23f377776e85aa687be17b211b54423c59 (diff) |
Put the openai client in a separate component
Diffstat (limited to 'model-integration')
15 files changed, 20 insertions, 482 deletions
diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json index 0c76ba38660..d3c472778e6 100644 --- a/model-integration/abi-spec.json +++ b/model-integration/abi-spec.json @@ -1,41 +1,5 @@ { - "ai.vespa.llm.Completion$FinishReason" : { - "superClass" : "java.lang.Enum", - "interfaces" : [ ], - "attributes" : [ - "public", - "final", - "enum" - ], - "methods" : [ - "public static ai.vespa.llm.Completion$FinishReason[] values()", - "public static ai.vespa.llm.Completion$FinishReason valueOf(java.lang.String)" - ], - "fields" : [ - "public static final enum ai.vespa.llm.Completion$FinishReason length", - "public static final enum ai.vespa.llm.Completion$FinishReason stop" - ] - }, - "ai.vespa.llm.Completion" : { - "superClass" : "java.lang.Record", - "interfaces" : [ ], - "attributes" : [ - "public", - "final", - "record" - ], - "methods" : [ - "public void <init>(java.lang.String, ai.vespa.llm.Completion$FinishReason)", - "public java.lang.String text()", - "public ai.vespa.llm.Completion$FinishReason finishReason()", - "public static ai.vespa.llm.Completion from(java.lang.String)", - "public final java.lang.String toString()", - "public final int hashCode()", - "public final boolean equals(java.lang.Object)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.Generator" : { + "ai.vespa.llm.generation.Generator" : { "superClass" : "com.yahoo.component.AbstractComponent", "interfaces" : [ ], "attributes" : [ @@ -43,13 +7,13 @@ ], "methods" : [ "public void <init>(ai.vespa.modelintegration.evaluator.OnnxRuntime, com.yahoo.llm.GeneratorConfig)", - "public java.lang.String generate(java.lang.String, ai.vespa.llm.GeneratorOptions)", + "public java.lang.String generate(java.lang.String, ai.vespa.llm.generation.GeneratorOptions)", "public java.lang.String generate(java.lang.String)", "public void deconstruct()" ], "fields" : [ ] }, - "ai.vespa.llm.GeneratorOptions$SearchMethod" : { + "ai.vespa.llm.generation.GeneratorOptions$SearchMethod" : { "superClass" : "java.lang.Enum", "interfaces" : [ ], "attributes" : [ @@ -58,17 +22,17 @@ "enum" ], "methods" : [ - "public static ai.vespa.llm.GeneratorOptions$SearchMethod[] values()", - "public static ai.vespa.llm.GeneratorOptions$SearchMethod valueOf(java.lang.String)" + "public static ai.vespa.llm.generation.GeneratorOptions$SearchMethod[] values()", + "public static ai.vespa.llm.generation.GeneratorOptions$SearchMethod valueOf(java.lang.String)" ], "fields" : [ - "public static final enum ai.vespa.llm.GeneratorOptions$SearchMethod GREEDY", - "public static final enum ai.vespa.llm.GeneratorOptions$SearchMethod CONTRASTIVE", - "public static final enum ai.vespa.llm.GeneratorOptions$SearchMethod BEAM", - "public static final enum ai.vespa.llm.GeneratorOptions$SearchMethod SAMPLE" + "public static final enum ai.vespa.llm.generation.GeneratorOptions$SearchMethod GREEDY", + "public static final enum ai.vespa.llm.generation.GeneratorOptions$SearchMethod CONTRASTIVE", + "public static final enum ai.vespa.llm.generation.GeneratorOptions$SearchMethod BEAM", + "public static final enum ai.vespa.llm.generation.GeneratorOptions$SearchMethod SAMPLE" ] }, - "ai.vespa.llm.GeneratorOptions" : { + "ai.vespa.llm.generation.GeneratorOptions" : { "superClass" : "java.lang.Object", "interfaces" : [ ], "attributes" : [ @@ -76,109 +40,10 @@ ], "methods" : [ "public void <init>()", - "public ai.vespa.llm.GeneratorOptions$SearchMethod getSearchMethod()", - "public ai.vespa.llm.GeneratorOptions setSearchMethod(ai.vespa.llm.GeneratorOptions$SearchMethod)", + "public ai.vespa.llm.generation.GeneratorOptions$SearchMethod getSearchMethod()", + "public ai.vespa.llm.generation.GeneratorOptions setSearchMethod(ai.vespa.llm.generation.GeneratorOptions$SearchMethod)", "public int getMaxLength()", - "public ai.vespa.llm.GeneratorOptions setMaxLength(int)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.LanguageModel" : { - "superClass" : "java.lang.Object", - "interfaces" : [ ], - "attributes" : [ - "public", - "interface", - "abstract" - ], - "methods" : [ - "public abstract java.util.List complete(ai.vespa.llm.Prompt)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.Prompt" : { - "superClass" : "java.lang.Object", - "interfaces" : [ ], - "attributes" : [ - "public", - "abstract" - ], - "methods" : [ - "public void <init>()", - "public abstract java.lang.String asString()", - "public ai.vespa.llm.Prompt append(ai.vespa.llm.Completion)", - "public abstract ai.vespa.llm.Prompt append(java.lang.String)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.StringPrompt" : { - "superClass" : "ai.vespa.llm.Prompt", - "interfaces" : [ ], - "attributes" : [ - "public" - ], - "methods" : [ - "public java.lang.String asString()", - "public ai.vespa.llm.StringPrompt append(java.lang.String)", - "public ai.vespa.llm.StringPrompt append(ai.vespa.llm.Completion)", - "public java.lang.String toString()", - "public static ai.vespa.llm.StringPrompt from(java.lang.String)", - "public bridge synthetic ai.vespa.llm.Prompt append(java.lang.String)", - "public bridge synthetic ai.vespa.llm.Prompt append(ai.vespa.llm.Completion)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.client.OpenAiClient$Builder" : { - "superClass" : "java.lang.Object", - "interfaces" : [ ], - "attributes" : [ - "public" - ], - "methods" : [ - "public void <init>(java.lang.String)", - "public ai.vespa.llm.client.OpenAiClient$Builder model(java.lang.String)", - "public ai.vespa.llm.client.OpenAiClient$Builder echo(boolean)", - "public ai.vespa.llm.client.OpenAiClient build()" - ], - "fields" : [ ] - }, - "ai.vespa.llm.client.OpenAiClient" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "ai.vespa.llm.LanguageModel" - ], - "attributes" : [ - "public" - ], - "methods" : [ - "public java.util.List complete(ai.vespa.llm.Prompt)" - ], - "fields" : [ ] - }, - "ai.vespa.llm.test.MockLanguageModel$Builder" : { - "superClass" : "java.lang.Object", - "interfaces" : [ ], - "attributes" : [ - "public" - ], - "methods" : [ - "public ai.vespa.llm.test.MockLanguageModel$Builder completer(java.util.function.Function)", - "public void <init>()", - "public ai.vespa.llm.test.MockLanguageModel build()" - ], - "fields" : [ ] - }, - "ai.vespa.llm.test.MockLanguageModel" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "ai.vespa.llm.LanguageModel" - ], - "attributes" : [ - "public" - ], - "methods" : [ - "public void <init>(ai.vespa.llm.test.MockLanguageModel$Builder)", - "public java.util.List complete(ai.vespa.llm.Prompt)" + "public ai.vespa.llm.generation.GeneratorOptions setMaxLength(int)" ], "fields" : [ ] } diff --git a/model-integration/pom.xml b/model-integration/pom.xml index c96441f11a7..c27ed9d2c31 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -111,11 +111,6 @@ </dependency> <dependency> - <groupId>com.theokanning.openai-gpt3-java</groupId> - <artifactId>service</artifactId> - </dependency> - - <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> <scope>test</scope> @@ -151,18 +146,6 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> - <configuration> - <!-- - openai-gpt3-java depends on a different Jackson version than the one we provide, - which leads to warnings, so we must disable error on warnings. - --> - <compilerArgs> - <arg>-Xlint:all</arg> - <arg>-Xlint:-rawtypes</arg> - <arg>-Xlint:-unchecked</arg> - <arg>-Xlint:-serial</arg> - </compilerArgs> - </configuration> </plugin> <plugin> <groupId>com.github.os72</groupId> diff --git a/model-integration/src/main/java/ai/vespa/llm/Completion.java b/model-integration/src/main/java/ai/vespa/llm/Completion.java deleted file mode 100644 index 5f483a65186..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/Completion.java +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm; - -import com.yahoo.api.annotations.Beta; - -import java.util.Objects; - -/** - * A completion from a language model. - * - * @author bratseth - */ -@Beta -public record Completion(String text, FinishReason finishReason) { - - public enum FinishReason { - - /** The maximum length of a completion was reached. */ - length, - - /** The completion is the predicted ending of the prompt. */ - stop - - } - - public Completion(String text, FinishReason finishReason) { - this.text = Objects.requireNonNull(text); - this.finishReason = Objects.requireNonNull(finishReason); - } - - /** Returns the generated text completion. */ - public String text() { return text; } - - /** Returns the reason this completion ended. */ - public FinishReason finishReason() { return finishReason; } - - public static Completion from(String text) { - return new Completion(text, FinishReason.stop); - } - -} diff --git a/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java deleted file mode 100644 index 0739162c5ee..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm; - -import com.yahoo.api.annotations.Beta; - -import java.util.List; - -/** - * Interface to language models. - * - * @author bratseth - */ -@Beta -public interface LanguageModel { - - List<Completion> complete(Prompt prompt); - -} diff --git a/model-integration/src/main/java/ai/vespa/llm/Prompt.java b/model-integration/src/main/java/ai/vespa/llm/Prompt.java deleted file mode 100644 index 77093d5e21b..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/Prompt.java +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm; - -import com.yahoo.api.annotations.Beta; - -/** - * A prompt that can be given to a large language model to generate a completion. - * - * @author bratseth - */ -@Beta -public abstract class Prompt { - - public abstract String asString(); - - /** Returns a new prompt with the text of the given completion appended. */ - public Prompt append(Completion completion) { - return append(completion.text()); - } - - public abstract Prompt append(String text); - -} diff --git a/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java b/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java deleted file mode 100644 index 0af8388dfb1..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java +++ /dev/null @@ -1,43 +0,0 @@ -package ai.vespa.llm; - -import com.yahoo.api.annotations.Beta; - -import java.util.Objects; - -/** - * A prompt which just consists of a string. - * - * @author bratseth - */ -@Beta -public class StringPrompt extends Prompt { - - private final String string; - - private StringPrompt(String string) { - this.string = Objects.requireNonNull(string); - } - - @Override - public String asString() { return string; } - - @Override - public StringPrompt append(String text) { - return StringPrompt.from(string + text); - } - - @Override - public StringPrompt append(Completion completion) { - return append(completion.text()); - } - - @Override - public String toString() { - return string; - } - - public static StringPrompt from(String string) { - return new StringPrompt(string); - } - -} diff --git a/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java b/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java deleted file mode 100644 index 3f4475b2482..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.client; - -import ai.vespa.llm.Completion; -import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.Prompt; -import com.theokanning.openai.OpenAiHttpException; -import com.theokanning.openai.completion.CompletionRequest; -import com.theokanning.openai.service.OpenAiService; -import com.yahoo.api.annotations.Beta; -import com.yahoo.yolean.Exceptions; - -import java.util.List; - -/** - * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/. - * - * @author bratseth - */ -@Beta -public class OpenAiClient implements LanguageModel { - - private final OpenAiService openAiService; - private final String model; - private final boolean echo; - - private OpenAiClient(Builder builder) { - openAiService = new OpenAiService(builder.token); - this.model = builder.model; - this.echo = builder.echo; - } - - @Override - public List<Completion> complete(Prompt prompt) { - try { - CompletionRequest completionRequest = CompletionRequest.builder() - .prompt(prompt.asString()) - .model(model) - .echo(echo) - .build(); - return openAiService.createCompletion(completionRequest).getChoices().stream() - .map(c -> new Completion(c.getText(), toFinishReason(c.getFinish_reason()))).toList(); - } - catch (OpenAiHttpException e) { - throw new RuntimeException(Exceptions.toMessageString(e)); - } - } - - private Completion.FinishReason toFinishReason(String finishReasonString) { - return switch(finishReasonString) { - case "length" -> Completion.FinishReason.length; - case "stop" -> Completion.FinishReason.stop; - default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'"); - }; - } - - public static class Builder { - - private final String token; - private String model = "text-davinci-003"; - private boolean echo = false; - - public Builder(String token) { - this.token = token; - } - - /** One of the language models listed at https://platform.openai.com/docs/models */ - public Builder model(String model) { - this.model = model; - return this; - } - - public Builder echo(boolean echo) { - this.echo = echo; - return this; - } - - public OpenAiClient build() { - return new OpenAiClient(this); - } - - } - -} diff --git a/model-integration/src/main/java/ai/vespa/llm/client/package-info.java b/model-integration/src/main/java/ai/vespa/llm/client/package-info.java deleted file mode 100644 index c95f87eec3c..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/client/package-info.java +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -@PublicApi -package ai.vespa.llm.client; - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage; - -/** - * Clients to externally hosted large language models. - */
\ No newline at end of file diff --git a/model-integration/src/main/java/ai/vespa/llm/Generator.java b/model-integration/src/main/java/ai/vespa/llm/generation/Generator.java index 6b60041947b..f20925b86ee 100644 --- a/model-integration/src/main/java/ai/vespa/llm/Generator.java +++ b/model-integration/src/main/java/ai/vespa/llm/generation/Generator.java @@ -1,4 +1,4 @@ -package ai.vespa.llm; +package ai.vespa.llm.generation; import ai.vespa.modelintegration.evaluator.OnnxEvaluator; import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; diff --git a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java b/model-integration/src/main/java/ai/vespa/llm/generation/GeneratorOptions.java index 8b490a733dd..79b466e5a74 100644 --- a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java +++ b/model-integration/src/main/java/ai/vespa/llm/generation/GeneratorOptions.java @@ -1,4 +1,4 @@ -package ai.vespa.llm; +package ai.vespa.llm.generation; import com.yahoo.api.annotations.Beta; diff --git a/model-integration/src/main/java/ai/vespa/llm/package-info.java b/model-integration/src/main/java/ai/vespa/llm/generation/package-info.java index 04fc24c51ee..ed3adb2f59e 100644 --- a/model-integration/src/main/java/ai/vespa/llm/package-info.java +++ b/model-integration/src/main/java/ai/vespa/llm/generation/package-info.java @@ -1,11 +1,11 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage @PublicApi -package ai.vespa.llm; +package ai.vespa.llm.generation; import com.yahoo.api.annotations.PublicApi; import com.yahoo.osgi.annotation.ExportPackage; /** - * API for working with large language models. + * API for generating text with language models. */
\ No newline at end of file diff --git a/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java deleted file mode 100644 index 54b085a451c..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/test/MockLanguageModel.java +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.test; - -import ai.vespa.llm.Completion; -import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.Prompt; -import com.yahoo.api.annotations.Beta; - -import java.util.List; -import java.util.function.Function; - -/** - * @author bratseth - */ -@Beta -public class MockLanguageModel implements LanguageModel { - - private final Function<Prompt, List<Completion>> completer; - - public MockLanguageModel(Builder builder) { - completer = builder.completer; - } - - @Override - public List<Completion> complete(Prompt prompt) { - return completer.apply(prompt); - } - - public static class Builder { - - private Function<Prompt, List<Completion>> completer = prompt -> List.of(Completion.from("")); - - public Builder completer(Function<Prompt, List<Completion>> completer) { - this.completer = completer; - return this; - } - - public Builder() {} - - public MockLanguageModel build() { return new MockLanguageModel(this); } - - } - -} diff --git a/model-integration/src/main/java/ai/vespa/llm/test/package-info.java b/model-integration/src/main/java/ai/vespa/llm/test/package-info.java deleted file mode 100644 index 0d51815fd6d..00000000000 --- a/model-integration/src/main/java/ai/vespa/llm/test/package-info.java +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -@PublicApi -package ai.vespa.llm.test; - -/** - * Tools for writing tests when working with large language models. - */ - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java b/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java deleted file mode 100644 index 30b1c8c2fb1..00000000000 --- a/model-integration/src/test/java/ai/vespa/llm/CompletionTest.java +++ /dev/null @@ -1,37 +0,0 @@ -package ai.vespa.llm; - -import ai.vespa.llm.test.MockLanguageModel; -import org.junit.jupiter.api.Test; - -import java.util.List; -import java.util.function.Function; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -/** - * Tests completion with a mock completer. - * - * @author bratseth - */ -public class CompletionTest { - - @Test - public void testCompletion() { - Function<Prompt, List<Completion>> completer = in -> - switch (in.asString()) { - case "Complete this: " -> List.of(Completion.from("The completion")); - default -> throw new RuntimeException("Cannot complete '" + in + "'"); - }; - var llm = new MockLanguageModel.Builder().completer(completer).build(); - - String input = "Complete this: "; - StringPrompt prompt = StringPrompt.from(input); - for (int i = 0; i < 10; i++) { - var completion = llm.complete(prompt).get(0); - prompt = prompt.append(completion); - if (completion.finishReason() == Completion.FinishReason.stop) break; - } - assertEquals("Complete this: The completion", prompt.asString()); - } - -} diff --git a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java b/model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java index c22902b344f..8c9b961f4a8 100644 --- a/model-integration/src/test/java/ai/vespa/llm/GeneratorTest.java +++ b/model-integration/src/test/java/ai/vespa/llm/generation/GeneratorTest.java @@ -1,5 +1,7 @@ -package ai.vespa.llm; +package ai.vespa.llm.generation; +import ai.vespa.llm.generation.Generator; +import ai.vespa.llm.generation.GeneratorOptions; import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.ModelReference; import com.yahoo.llm.GeneratorConfig; |