diff options
author | Jon Bratseth <bratseth@vespa.ai> | 2023-04-19 10:58:01 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@vespa.ai> | 2023-04-19 10:58:01 +0200 |
commit | 2d9fd2e6c78e9ae8580cd4a21d20d0febb8f9c93 (patch) | |
tree | b42f58bd734f4db2e2e45dfad298f96ade13573e /model-integration/src/main | |
parent | 4dd12fde1043fb42eeac2917d40e77e2682403e4 (diff) |
Llm completion abstraction and OpenAi implementation
Diffstat (limited to 'model-integration/src/main')
7 files changed, 222 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/llm/Completion.java b/model-integration/src/main/java/ai/vespa/llm/Completion.java
new file mode 100644
index 00000000000..13e2ae9f731
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/Completion.java
@@ -0,0 +1,49 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.Objects;
+
+/**
+ * A completion from a language model.
+ *
+ * @author bratseth
+ */
+@Beta
+public class Completion {
+
+    public enum FinishReason {
+
+        /** The maximum length of a completion was reached. */
+        length,
+
+        /** The completion is the predicted ending of the prompt. */
+        stop
+
+    }
+
+    private final String text;
+    private final FinishReason finishReason;
+
+    public Completion(String text, FinishReason finishReason) {
+        this.text = Objects.requireNonNull(text);
+        this.finishReason = Objects.requireNonNull(finishReason);
+    }
+
+    /** Returns the generated text completion. */
+    public String text() { return text; }
+
+    /** Returns the reason this completion ended. */
+    public FinishReason finishReason() { return finishReason; }
+
+    public static Completion from(String text) {
+        return new Completion(text, FinishReason.stop);
+    }
+
+    @Override
+    public String toString() {
+        return text;
+    }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/Generator.java b/model-integration/src/main/java/ai/vespa/llm/Generator.java
index 973b5ac2899..6b60041947b 100644
--- a/model-integration/src/main/java/ai/vespa/llm/Generator.java
+++ b/model-integration/src/main/java/ai/vespa/llm/Generator.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.IndexedTensor;
 import com.yahoo.tensor.PartialAddress;
 import com.yahoo.tensor.Tensor;
 import com.yahoo.tensor.TensorType;
+import com.yahoo.api.annotations.Beta;
 
 import java.util.ArrayList;
 import java.util.List;
@@ -27,6 +28,7 @@ import java.util.Map;
  *
  * @author lesters
  */
+@Beta
 public class Generator extends AbstractComponent {
 
     private final static int TOKEN_EOS = 1; // end of sequence
diff --git a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java b/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
index 743bb7c2f27..8b490a733dd 100644
--- a/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
+++ b/model-integration/src/main/java/ai/vespa/llm/GeneratorOptions.java
@@ -1,5 +1,8 @@
 package ai.vespa.llm;
 
+import com.yahoo.api.annotations.Beta;
+
+@Beta
 public class GeneratorOptions {
 
     public enum SearchMethod {
diff --git a/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java
new file mode 100644
index 00000000000..0739162c5ee
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/LanguageModel.java
@@ -0,0 +1,18 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.List;
+
+/**
+ * Interface to language models.
+ *
+ * @author bratseth
+ */
+@Beta
+public interface LanguageModel {
+
+    List<Completion> complete(Prompt prompt);
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/Prompt.java b/model-integration/src/main/java/ai/vespa/llm/Prompt.java
new file mode 100644
index 00000000000..77093d5e21b
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/Prompt.java
@@ -0,0 +1,23 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+/**
+ * A prompt that can be given to a large language model to generate a completion.
+ *
+ * @author bratseth
+ */
+@Beta
+public abstract class Prompt {
+
+    public abstract String asString();
+
+    /** Returns a new prompt with the text of the given completion appended. */
+    public Prompt append(Completion completion) {
+        return append(completion.text());
+    }
+
+    public abstract Prompt append(String text);
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java b/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java
new file mode 100644
index 00000000000..0af8388dfb1
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/StringPrompt.java
@@ -0,0 +1,43 @@
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.Objects;
+
+/**
+ * A prompt which just consists of a string.
+ *
+ * @author bratseth
+ */
+@Beta
+public class StringPrompt extends Prompt {
+
+    private final String string;
+
+    private StringPrompt(String string) {
+        this.string = Objects.requireNonNull(string);
+    }
+
+    @Override
+    public String asString() { return string; }
+
+    @Override
+    public StringPrompt append(String text) {
+        return StringPrompt.from(string + text);
+    }
+
+    @Override
+    public StringPrompt append(Completion completion) {
+        return append(completion.text());
+    }
+
+    @Override
+    public String toString() {
+        return string;
+    }
+
+    public static StringPrompt from(String string) {
+        return new StringPrompt(string);
+    }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java b/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java
new file mode 100644
index 00000000000..3f4475b2482
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/client/OpenAiClient.java
@@ -0,0 +1,84 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.client;
+
+import ai.vespa.llm.Completion;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.Prompt;
+import com.theokanning.openai.OpenAiHttpException;
+import com.theokanning.openai.completion.CompletionRequest;
+import com.theokanning.openai.service.OpenAiService;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.yolean.Exceptions;
+
+import java.util.List;
+
+/**
+ * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/.
+ *
+ * @author bratseth
+ */
+@Beta
+public class OpenAiClient implements LanguageModel {
+
+    private final OpenAiService openAiService;
+    private final String model;
+    private final boolean echo;
+
+    private OpenAiClient(Builder builder) {
+        openAiService = new OpenAiService(builder.token);
+        this.model = builder.model;
+        this.echo = builder.echo;
+    }
+
+    @Override
+    public List<Completion> complete(Prompt prompt) {
+        try {
+            CompletionRequest completionRequest = CompletionRequest.builder()
+                    .prompt(prompt.asString())
+                    .model(model)
+                    .echo(echo)
+                    .build();
+            return openAiService.createCompletion(completionRequest).getChoices().stream()
+                    .map(c -> new Completion(c.getText(), toFinishReason(c.getFinish_reason()))).toList();
+        }
+        catch (OpenAiHttpException e) {
+            throw new RuntimeException(Exceptions.toMessageString(e), e); // keep the cause so the original stack trace is not lost
+        }
+    }
+
+    private Completion.FinishReason toFinishReason(String finishReasonString) {
+        return switch(finishReasonString) {
+            case "length" -> Completion.FinishReason.length;
+            case "stop" -> Completion.FinishReason.stop;
+            default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'");
+        };
+    }
+
+    public static class Builder {
+
+        private final String token;
+        private String model = "text-davinci-003";
+        private boolean echo = false;
+
+        public Builder(String token) {
+            this.token = token;
+        }
+
+        /** Sets the model to use: one of the language models listed at https://platform.openai.com/docs/models */
+        public Builder model(String model) {
+            this.model = model;
+            return this;
+        }
+
+        public Builder echo(boolean echo) {
+            this.echo = echo;
+            return this;
+        }
+
+        public OpenAiClient build() {
+            return new OpenAiClient(this);
+        }
+
+    }
+
+}