From 98ccab8d442bc3b13de47746bbd265a08b319add Mon Sep 17 00:00:00 2001
From: Lester Solbakken
Date: Tue, 2 Apr 2024 09:23:54 +0200
Subject: Move LLM classes in vespajlib from ai.vespa.llm to ai.vespa.languagemodels

---
 vespajlib/abi-spec.json | 78 +++++------
 .../vespa/languagemodels/InferenceParameters.java | 76 ++++++++++
 .../ai/vespa/languagemodels/LanguageModel.java | 26 ++++
 .../languagemodels/LanguageModelException.java | 19 +++
 .../languagemodels/client/openai/OpenAiClient.java | 155 +++++++++++++++++++++
 .../languagemodels/client/openai/package-info.java | 11 ++
 .../languagemodels/completion/Completion.java | 50 +++++++
 .../ai/vespa/languagemodels/completion/Prompt.java | 23 +++
 .../languagemodels/completion/StringPrompt.java | 44 ++++++
 .../languagemodels/completion/package-info.java | 11 ++
 .../java/ai/vespa/languagemodels/package-info.java | 11 ++
 .../languagemodels/test/MockLanguageModel.java | 54 +++++++
 .../ai/vespa/languagemodels/test/package-info.java | 11 ++
 .../java/ai/vespa/llm/InferenceParameters.java | 76 ----------
 .../src/main/java/ai/vespa/llm/LanguageModel.java | 26 ----
 .../java/ai/vespa/llm/LanguageModelException.java | 19 ---
 .../ai/vespa/llm/client/openai/OpenAiClient.java | 155 ---------------------
 .../ai/vespa/llm/client/openai/package-info.java | 11 --
 .../java/ai/vespa/llm/completion/Completion.java | 50 -------
 .../main/java/ai/vespa/llm/completion/Prompt.java | 23 ---
 .../java/ai/vespa/llm/completion/StringPrompt.java | 44 ------
 .../java/ai/vespa/llm/completion/package-info.java | 11 --
 .../src/main/java/ai/vespa/llm/package-info.java | 11 --
 .../java/ai/vespa/llm/test/MockLanguageModel.java | 54 -------
 .../main/java/ai/vespa/llm/test/package-info.java | 11 --
 .../client/openai/OpenAiClientCompletionTest.java | 46 ++++++
 .../languagemodels/completion/CompletionTest.java | 40 ++++++
 .../client/openai/OpenAiClientCompletionTest.java | 47 -------
 .../ai/vespa/llm/completion/CompletionTest.java | 40 ------
 29 files changed, 616 insertions(+), 617 deletions(-)
 create mode 100755 vespajlib/src/main/java/ai/vespa/languagemodels/InferenceParameters.java
 create mode 100644 vespajlib/src/main/java/ai/vespa/languagemodels/LanguageModel.java
 create mode 100755 vespajlib/src/main/java/ai/vespa/languagemodels/LanguageModelException.java
 create mode 100644 vespajlib/src/main/java/ai/vespa/languagemodels/client/openai/OpenAiClient.java
 create mode 100644 vespajlib/src/main/java/ai/vespa/languagemodels/client/openai/package-info.java
 create mode 100644 vespajlib/src/main/java/ai/vespa/languagemodels/completion/Completion.java
 create mode 100644 vespajlib/src/main/java/ai/vespa/languagemodels/completion/Prompt.java
 create mode 100644 vespajlib/src/main/java/ai/vespa/languagemodels/completion/StringPrompt.java
 create mode 100644 vespajlib/src/main/java/ai/vespa/languagemodels/completion/package-info.java
 create mode 100644 vespajlib/src/main/java/ai/vespa/languagemodels/package-info.java
 create mode 100644 vespajlib/src/main/java/ai/vespa/languagemodels/test/MockLanguageModel.java
 create mode 100644 vespajlib/src/main/java/ai/vespa/languagemodels/test/package-info.java
 delete mode 100755 vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
 delete mode 100644 vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java
 delete mode 100755 vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java
 delete mode 100644 vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
 delete mode 100644
vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java delete mode 100644 vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java delete mode 100644 vespajlib/src/main/java/ai/vespa/llm/completion/Prompt.java delete mode 100644 vespajlib/src/main/java/ai/vespa/llm/completion/StringPrompt.java delete mode 100644 vespajlib/src/main/java/ai/vespa/llm/completion/package-info.java delete mode 100644 vespajlib/src/main/java/ai/vespa/llm/package-info.java delete mode 100644 vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java delete mode 100644 vespajlib/src/main/java/ai/vespa/llm/test/package-info.java create mode 100644 vespajlib/src/test/java/ai/vespa/languagemodels/client/openai/OpenAiClientCompletionTest.java create mode 100644 vespajlib/src/test/java/ai/vespa/languagemodels/completion/CompletionTest.java delete mode 100644 vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java delete mode 100644 vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java (limited to 'vespajlib') diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 45e88ac2e94..c92d41edfd9 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -4077,7 +4077,7 @@ ], "fields" : [ ] }, - "ai.vespa.llm.InferenceParameters" : { + "ai.vespa.languagemodels.InferenceParameters" : { "superClass" : "java.lang.Object", "interfaces" : [ ], "attributes" : [ @@ -4097,7 +4097,7 @@ ], "fields" : [ ] }, - "ai.vespa.llm.LanguageModel" : { + "ai.vespa.languagemodels.LanguageModel" : { "superClass" : "java.lang.Object", "interfaces" : [ ], "attributes" : [ @@ -4106,12 +4106,12 @@ "abstract" ], "methods" : [ - "public abstract java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public abstract java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + "public abstract java.util.List complete(ai.vespa.languagemodels.completion.Prompt, ai.vespa.languagemodels.InferenceParameters)", + "public abstract java.util.concurrent.CompletableFuture completeAsync(ai.vespa.languagemodels.completion.Prompt, ai.vespa.languagemodels.InferenceParameters, java.util.function.Consumer)" ], "fields" : [ ] }, - "ai.vespa.llm.LanguageModelException" : { + "ai.vespa.languagemodels.LanguageModelException" : { "superClass" : "java.lang.RuntimeException", "interfaces" : [ ], "attributes" : [ @@ -4123,22 +4123,22 @@ ], "fields" : [ ] }, - "ai.vespa.llm.client.openai.OpenAiClient" : { + "ai.vespa.languagemodels.client.openai.OpenAiClient" : { "superClass" : "java.lang.Object", "interfaces" : [ - "ai.vespa.llm.LanguageModel" + "ai.vespa.languagemodels.LanguageModel" ], "attributes" : [ "public" ], "methods" : [ "public void ()", - "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + "public java.util.List complete(ai.vespa.languagemodels.completion.Prompt, ai.vespa.languagemodels.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.languagemodels.completion.Prompt, ai.vespa.languagemodels.InferenceParameters, java.util.function.Consumer)" ], "fields" : [ ] }, - "ai.vespa.llm.completion.Completion$FinishReason" : { + "ai.vespa.languagemodels.completion.Completion$FinishReason" : { "superClass" : 
"java.lang.Enum", "interfaces" : [ ], "attributes" : [ @@ -4147,17 +4147,17 @@ "enum" ], "methods" : [ - "public static ai.vespa.llm.completion.Completion$FinishReason[] values()", - "public static ai.vespa.llm.completion.Completion$FinishReason valueOf(java.lang.String)" + "public static ai.vespa.languagemodels.completion.Completion$FinishReason[] values()", + "public static ai.vespa.languagemodels.completion.Completion$FinishReason valueOf(java.lang.String)" ], "fields" : [ - "public static final enum ai.vespa.llm.completion.Completion$FinishReason length", - "public static final enum ai.vespa.llm.completion.Completion$FinishReason stop", - "public static final enum ai.vespa.llm.completion.Completion$FinishReason none", - "public static final enum ai.vespa.llm.completion.Completion$FinishReason error" + "public static final enum ai.vespa.languagemodels.completion.Completion$FinishReason length", + "public static final enum ai.vespa.languagemodels.completion.Completion$FinishReason stop", + "public static final enum ai.vespa.languagemodels.completion.Completion$FinishReason none", + "public static final enum ai.vespa.languagemodels.completion.Completion$FinishReason error" ] }, - "ai.vespa.llm.completion.Completion" : { + "ai.vespa.languagemodels.completion.Completion" : { "superClass" : "java.lang.Record", "interfaces" : [ ], "attributes" : [ @@ -4166,18 +4166,18 @@ "record" ], "methods" : [ - "public void (java.lang.String, ai.vespa.llm.completion.Completion$FinishReason)", + "public void (java.lang.String, ai.vespa.languagemodels.completion.Completion$FinishReason)", "public java.lang.String text()", - "public ai.vespa.llm.completion.Completion$FinishReason finishReason()", - "public static ai.vespa.llm.completion.Completion from(java.lang.String)", - "public static ai.vespa.llm.completion.Completion from(java.lang.String, ai.vespa.llm.completion.Completion$FinishReason)", + "public ai.vespa.languagemodels.completion.Completion$FinishReason finishReason()", + "public static ai.vespa.languagemodels.completion.Completion from(java.lang.String)", + "public static ai.vespa.languagemodels.completion.Completion from(java.lang.String, ai.vespa.languagemodels.completion.Completion$FinishReason)", "public final java.lang.String toString()", "public final int hashCode()", "public final boolean equals(java.lang.Object)" ], "fields" : [ ] }, - "ai.vespa.llm.completion.Prompt" : { + "ai.vespa.languagemodels.completion.Prompt" : { "superClass" : "java.lang.Object", "interfaces" : [ ], "attributes" : [ @@ -4187,53 +4187,53 @@ "methods" : [ "public void ()", "public abstract java.lang.String asString()", - "public ai.vespa.llm.completion.Prompt append(ai.vespa.llm.completion.Completion)", - "public abstract ai.vespa.llm.completion.Prompt append(java.lang.String)" + "public ai.vespa.languagemodels.completion.Prompt append(ai.vespa.languagemodels.completion.Completion)", + "public abstract ai.vespa.languagemodels.completion.Prompt append(java.lang.String)" ], "fields" : [ ] }, - "ai.vespa.llm.completion.StringPrompt" : { - "superClass" : "ai.vespa.llm.completion.Prompt", + "ai.vespa.languagemodels.completion.StringPrompt" : { + "superClass" : "ai.vespa.languagemodels.completion.Prompt", "interfaces" : [ ], "attributes" : [ "public" ], "methods" : [ "public java.lang.String asString()", - "public ai.vespa.llm.completion.StringPrompt append(java.lang.String)", - "public ai.vespa.llm.completion.StringPrompt append(ai.vespa.llm.completion.Completion)", + "public 
ai.vespa.languagemodels.completion.StringPrompt append(java.lang.String)", + "public ai.vespa.languagemodels.completion.StringPrompt append(ai.vespa.languagemodels.completion.Completion)", "public java.lang.String toString()", - "public static ai.vespa.llm.completion.StringPrompt from(java.lang.String)", - "public bridge synthetic ai.vespa.llm.completion.Prompt append(java.lang.String)", - "public bridge synthetic ai.vespa.llm.completion.Prompt append(ai.vespa.llm.completion.Completion)" + "public static ai.vespa.languagemodels.completion.StringPrompt from(java.lang.String)", + "public bridge synthetic ai.vespa.languagemodels.completion.Prompt append(java.lang.String)", + "public bridge synthetic ai.vespa.languagemodels.completion.Prompt append(ai.vespa.languagemodels.completion.Completion)" ], "fields" : [ ] }, - "ai.vespa.llm.test.MockLanguageModel$Builder" : { + "ai.vespa.languagemodels.test.MockLanguageModel$Builder" : { "superClass" : "java.lang.Object", "interfaces" : [ ], "attributes" : [ "public" ], "methods" : [ - "public ai.vespa.llm.test.MockLanguageModel$Builder completer(java.util.function.Function)", + "public ai.vespa.languagemodels.test.MockLanguageModel$Builder completer(java.util.function.Function)", "public void ()", - "public ai.vespa.llm.test.MockLanguageModel build()" + "public ai.vespa.languagemodels.test.MockLanguageModel build()" ], "fields" : [ ] }, - "ai.vespa.llm.test.MockLanguageModel" : { + "ai.vespa.languagemodels.test.MockLanguageModel" : { "superClass" : "java.lang.Object", "interfaces" : [ - "ai.vespa.llm.LanguageModel" + "ai.vespa.languagemodels.LanguageModel" ], "attributes" : [ "public" ], "methods" : [ - "public void (ai.vespa.llm.test.MockLanguageModel$Builder)", - "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)", - "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)" + "public void (ai.vespa.languagemodels.test.MockLanguageModel$Builder)", + "public java.util.List complete(ai.vespa.languagemodels.completion.Prompt, ai.vespa.languagemodels.InferenceParameters)", + "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.languagemodels.completion.Prompt, ai.vespa.languagemodels.InferenceParameters, java.util.function.Consumer)" ], "fields" : [ ] } diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/InferenceParameters.java b/vespajlib/src/main/java/ai/vespa/languagemodels/InferenceParameters.java new file mode 100755 index 00000000000..dd86db039dc --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/InferenceParameters.java @@ -0,0 +1,76 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.languagemodels; + +import com.yahoo.api.annotations.Beta; + +import java.util.Objects; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * Parameters for inference to language models. Parameters are typically + * supplied from searchers or processors and comes from query strings, + * headers, or other sources. Which parameters are available depends on + * the language model used. 
+ * + * author lesters + */ +@Beta +public class InferenceParameters { + + private String apiKey; + private String endpoint; + private final Function options; + + public InferenceParameters(Function options) { + this(null, options); + } + + public InferenceParameters(String apiKey, Function options) { + this.apiKey = apiKey; + this.options = Objects.requireNonNull(options); + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public Optional getApiKey() { + return Optional.ofNullable(apiKey); + } + + public void setEndpoint(String endpoint) { + this.endpoint = endpoint; + } + + public Optional getEndpoint() { + return Optional.ofNullable(endpoint); + } + + public Optional get(String option) { + return Optional.ofNullable(options.apply(option)); + } + + public Optional getDouble(String option) { + try { + return Optional.of(Double.parseDouble(options.apply(option))); + } catch (Exception e) { + return Optional.empty(); + } + } + + public Optional getInt(String option) { + try { + return Optional.of(Integer.parseInt(options.apply(option))); + } catch (Exception e) { + return Optional.empty(); + } + } + + public void ifPresent(String option, Consumer func) { + get(option).ifPresent(func); + } + +} + diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/LanguageModel.java b/vespajlib/src/main/java/ai/vespa/languagemodels/LanguageModel.java new file mode 100644 index 00000000000..115a6e21bd6 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/LanguageModel.java @@ -0,0 +1,26 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.languagemodels; + +import ai.vespa.languagemodels.completion.Completion; +import ai.vespa.languagemodels.completion.Prompt; +import com.yahoo.api.annotations.Beta; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; + +/** + * Interface to language models. + * + * @author bratseth + */ +@Beta +public interface LanguageModel { + + List complete(Prompt prompt, InferenceParameters options); + + CompletableFuture completeAsync(Prompt prompt, + InferenceParameters options, + Consumer consumer); + +} diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/LanguageModelException.java b/vespajlib/src/main/java/ai/vespa/languagemodels/LanguageModelException.java new file mode 100755 index 00000000000..f7ae7472d13 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/LanguageModelException.java @@ -0,0 +1,19 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.languagemodels; + +import com.yahoo.api.annotations.Beta; + +@Beta +public class LanguageModelException extends RuntimeException { + + private final int code; + + public LanguageModelException(int code, String message) { + super(message); + this.code = code; + } + + public int code() { + return code; + } +} diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/languagemodels/client/openai/OpenAiClient.java new file mode 100644 index 00000000000..c83fa4799c8 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/client/openai/OpenAiClient.java @@ -0,0 +1,155 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.languagemodels.client.openai; + +import ai.vespa.languagemodels.LanguageModelException; +import ai.vespa.languagemodels.InferenceParameters; +import ai.vespa.languagemodels.completion.Completion; +import ai.vespa.languagemodels.LanguageModel; +import ai.vespa.languagemodels.completion.Prompt; +import com.yahoo.api.annotations.Beta; +import com.yahoo.slime.ArrayTraverser; +import com.yahoo.slime.Inspector; +import com.yahoo.slime.Slime; +import com.yahoo.slime.SlimeUtils; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/. + * Currently, only completions are implemented. + * + * @author bratseth + * @author lesters + */ +@Beta +public class OpenAiClient implements LanguageModel { + + private static final String DEFAULT_MODEL = "gpt-3.5-turbo"; + private static final String DATA_FIELD = "data: "; + + private static final String OPTION_MODEL = "model"; + private static final String OPTION_TEMPERATURE = "temperature"; + private static final String OPTION_MAX_TOKENS = "maxTokens"; + + private final HttpClient httpClient; + + public OpenAiClient() { + this.httpClient = HttpClient.newBuilder().build(); + } + + @Override + public List complete(Prompt prompt, InferenceParameters options) { + try { + HttpResponse httpResponse = httpClient.send(toRequest(prompt, options, false), HttpResponse.BodyHandlers.ofByteArray()); + var response = SlimeUtils.jsonToSlime(httpResponse.body()).get(); + if ( httpResponse.statusCode() != 200) + throw new IllegalArgumentException(SlimeUtils.toJson(response)); + return toCompletions(response); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public CompletableFuture completeAsync(Prompt prompt, + InferenceParameters options, + Consumer consumer) { + try { + var request = toRequest(prompt, options, true); + var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines()); + var completionFuture = new CompletableFuture(); + + futureResponse.thenAcceptAsync(response -> { + try { + int responseCode = response.statusCode(); + if (responseCode != 200) { + throw new LanguageModelException(responseCode, response.body().collect(Collectors.joining())); + } + + Stream lines = response.body(); + lines.forEach(line -> { + if (line.startsWith(DATA_FIELD)) { + var root = SlimeUtils.jsonToSlime(line.substring(DATA_FIELD.length())).get(); + var completion = toCompletions(root, "delta").get(0); + consumer.accept(completion); + if (!completion.finishReason().equals(Completion.FinishReason.none)) { + completionFuture.complete(completion.finishReason()); + } + } + }); + } catch (Exception e) { + completionFuture.completeExceptionally(e); + } + }); + return completionFuture; + + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private HttpRequest toRequest(Prompt prompt, InferenceParameters options, boolean stream) throws IOException, URISyntaxException { + var slime = new Slime(); + var root = slime.setObject(); + root.setString("model", options.get(OPTION_MODEL).orElse(DEFAULT_MODEL)); + root.setBool("stream", stream); + root.setLong("n", 
1); + + if (options.getDouble(OPTION_TEMPERATURE).isPresent()) + root.setDouble("temperature", options.getDouble(OPTION_TEMPERATURE).get()); + if (options.getInt(OPTION_MAX_TOKENS).isPresent()) + root.setLong("max_tokens", options.getInt(OPTION_MAX_TOKENS).get()); + // Others? + + var messagesArray = root.setArray("messages"); + var messagesObject = messagesArray.addObject(); + messagesObject.setString("role", "user"); + messagesObject.setString("content", prompt.asString()); + + var endpoint = options.getEndpoint().orElse("https://api.openai.com/v1/chat/completions"); + return HttpRequest.newBuilder(new URI(endpoint)) + .header("Content-Type", "application/json") + .header("Authorization", "Bearer " + options.getApiKey().orElse("")) + .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime))) + .build(); + } + + private List toCompletions(Inspector response) { + return toCompletions(response, "message"); + } + + private List toCompletions(Inspector response, String field) { + List completions = new ArrayList<>(); + response.field("choices") + .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice, field))); + return completions; + } + + private Completion toCompletion(Inspector choice, String field) { + var content = choice.field(field).field("content").asString(); + var finishReason = toFinishReason(choice.field("finish_reason").asString()); + return new Completion(content, finishReason); + } + + private Completion.FinishReason toFinishReason(String finishReasonString) { + return switch(finishReasonString) { + case "length" -> Completion.FinishReason.length; + case "stop" -> Completion.FinishReason.stop; + case "", "null" -> Completion.FinishReason.none; + default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'"); + }; + } + +} diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/client/openai/package-info.java b/vespajlib/src/main/java/ai/vespa/languagemodels/client/openai/package-info.java new file mode 100644 index 00000000000..d3f6d9042d8 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/client/openai/package-info.java @@ -0,0 +1,11 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.languagemodels.client.openai; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; + +/** + * Client to OpenAi's large language models. + */ diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/completion/Completion.java b/vespajlib/src/main/java/ai/vespa/languagemodels/completion/Completion.java new file mode 100644 index 00000000000..68b6dc47bfd --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/completion/Completion.java @@ -0,0 +1,50 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.languagemodels.completion; + +import com.yahoo.api.annotations.Beta; + +import java.util.Objects; + +/** + * A completion from a language model. + * + * @author bratseth + */ +@Beta +public record Completion(String text, FinishReason finishReason) { + + public enum FinishReason { + + /** The maximum length of a completion was reached. */ + length, + + /** The completion is the predicted ending of the prompt. */ + stop, + + /** The completion is not finished yet, more tokens are incoming. 
*/ + none, + + /** An error occurred while generating the completion */ + error + } + + public Completion(String text, FinishReason finishReason) { + this.text = Objects.requireNonNull(text); + this.finishReason = Objects.requireNonNull(finishReason); + } + + /** Returns the generated text completion. */ + public String text() { return text; } + + /** Returns the reason this completion ended. */ + public FinishReason finishReason() { return finishReason; } + + public static Completion from(String text) { + return from(text, FinishReason.stop); + } + + public static Completion from(String text, FinishReason reason) { + return new Completion(text, reason); + } + +} diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/completion/Prompt.java b/vespajlib/src/main/java/ai/vespa/languagemodels/completion/Prompt.java new file mode 100644 index 00000000000..f10a1768b8c --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/completion/Prompt.java @@ -0,0 +1,23 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.languagemodels.completion; + +import com.yahoo.api.annotations.Beta; + +/** + * A prompt that can be given to a large language model to generate a completion. + * + * @author bratseth + */ +@Beta +public abstract class Prompt { + + public abstract String asString(); + + /** Returns a new prompt with the text of the given completion appended. */ + public Prompt append(Completion completion) { + return append(completion.text()); + } + + public abstract Prompt append(String text); + +} diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/completion/StringPrompt.java b/vespajlib/src/main/java/ai/vespa/languagemodels/completion/StringPrompt.java new file mode 100644 index 00000000000..d4bb387cd6c --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/completion/StringPrompt.java @@ -0,0 +1,44 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.languagemodels.completion; + +import com.yahoo.api.annotations.Beta; + +import java.util.Objects; + +/** + * A prompt which just consists of a string. + * + * @author bratseth + */ +@Beta +public class StringPrompt extends Prompt { + + private final String string; + + private StringPrompt(String string) { + this.string = Objects.requireNonNull(string); + } + + @Override + public String asString() { return string; } + + @Override + public StringPrompt append(String text) { + return StringPrompt.from(string + text); + } + + @Override + public StringPrompt append(Completion completion) { + return append(completion.text()); + } + + @Override + public String toString() { + return string; + } + + public static StringPrompt from(String string) { + return new StringPrompt(string); + } + +} diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/completion/package-info.java b/vespajlib/src/main/java/ai/vespa/languagemodels/completion/package-info.java new file mode 100644 index 00000000000..fbae8d3a0e9 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/completion/package-info.java @@ -0,0 +1,11 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.languagemodels.completion; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; + +/** + * Classes for generating text completions with language models. 
+ */ diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/package-info.java b/vespajlib/src/main/java/ai/vespa/languagemodels/package-info.java new file mode 100644 index 00000000000..9ec6a7773e5 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/package-info.java @@ -0,0 +1,11 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.languagemodels; + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; + +/** + * API for working with large language models. + */ diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/test/MockLanguageModel.java b/vespajlib/src/main/java/ai/vespa/languagemodels/test/MockLanguageModel.java new file mode 100644 index 00000000000..42bfb8b3e93 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/test/MockLanguageModel.java @@ -0,0 +1,54 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.languagemodels.test; + +import ai.vespa.languagemodels.LanguageModel; +import ai.vespa.languagemodels.InferenceParameters; +import ai.vespa.languagemodels.completion.Completion; +import ai.vespa.languagemodels.completion.Prompt; +import com.yahoo.api.annotations.Beta; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * @author bratseth + */ +@Beta +public class MockLanguageModel implements LanguageModel { + + private final Function> completer; + + public MockLanguageModel(Builder builder) { + completer = builder.completer; + } + + @Override + public List complete(Prompt prompt, InferenceParameters options) { + return completer.apply(prompt); + } + + @Override + public CompletableFuture completeAsync(Prompt prompt, + InferenceParameters options, + Consumer action) { + throw new RuntimeException("Not implemented"); + } + + public static class Builder { + + private Function> completer = prompt -> List.of(Completion.from("")); + + public Builder completer(Function> completer) { + this.completer = completer; + return this; + } + + public Builder() {} + + public MockLanguageModel build() { return new MockLanguageModel(this); } + + } + +} diff --git a/vespajlib/src/main/java/ai/vespa/languagemodels/test/package-info.java b/vespajlib/src/main/java/ai/vespa/languagemodels/test/package-info.java new file mode 100644 index 00000000000..ba5cf265408 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/languagemodels/test/package-info.java @@ -0,0 +1,11 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +@PublicApi +package ai.vespa.languagemodels.test; + +/** + * Tools for writing tests when working with large language models. + */ + +import com.yahoo.api.annotations.PublicApi; +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java deleted file mode 100755 index a942e5090e5..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package ai.vespa.llm; - -import com.yahoo.api.annotations.Beta; - -import java.util.Objects; -import java.util.Optional; -import java.util.function.Consumer; -import java.util.function.Function; - -/** - * Parameters for inference to language models. Parameters are typically - * supplied from searchers or processors and comes from query strings, - * headers, or other sources. Which parameters are available depends on - * the language model used. - * - * author lesters - */ -@Beta -public class InferenceParameters { - - private String apiKey; - private String endpoint; - private final Function options; - - public InferenceParameters(Function options) { - this(null, options); - } - - public InferenceParameters(String apiKey, Function options) { - this.apiKey = apiKey; - this.options = Objects.requireNonNull(options); - } - - public void setApiKey(String apiKey) { - this.apiKey = apiKey; - } - - public Optional getApiKey() { - return Optional.ofNullable(apiKey); - } - - public void setEndpoint(String endpoint) { - this.endpoint = endpoint; - } - - public Optional getEndpoint() { - return Optional.ofNullable(endpoint); - } - - public Optional get(String option) { - return Optional.ofNullable(options.apply(option)); - } - - public Optional getDouble(String option) { - try { - return Optional.of(Double.parseDouble(options.apply(option))); - } catch (Exception e) { - return Optional.empty(); - } - } - - public Optional getInt(String option) { - try { - return Optional.of(Integer.parseInt(options.apply(option))); - } catch (Exception e) { - return Optional.empty(); - } - } - - public void ifPresent(String option, Consumer func) { - get(option).ifPresent(func); - } - -} - diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java deleted file mode 100644 index 059f25fadb4..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm; - -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.api.annotations.Beta; - -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; - -/** - * Interface to language models. - * - * @author bratseth - */ -@Beta -public interface LanguageModel { - - List complete(Prompt prompt, InferenceParameters options); - - CompletableFuture completeAsync(Prompt prompt, - InferenceParameters options, - Consumer consumer); - -} diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java deleted file mode 100755 index b5dbf615c08..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package ai.vespa.llm; - -import com.yahoo.api.annotations.Beta; - -@Beta -public class LanguageModelException extends RuntimeException { - - private final int code; - - public LanguageModelException(int code, String message) { - super(message); - this.code = code; - } - - public int code() { - return code; - } -} diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java deleted file mode 100644 index 75308a84faa..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.client.openai; - -import ai.vespa.llm.LanguageModelException; -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.api.annotations.Beta; -import com.yahoo.slime.ArrayTraverser; -import com.yahoo.slime.Inspector; -import com.yahoo.slime.Slime; -import com.yahoo.slime.SlimeUtils; - -import java.io.IOException; -import java.net.URI; -import java.net.URISyntaxException; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -/** - * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/. - * Currently, only completions are implemented. - * - * @author bratseth - * @author lesters - */ -@Beta -public class OpenAiClient implements LanguageModel { - - private static final String DEFAULT_MODEL = "gpt-3.5-turbo"; - private static final String DATA_FIELD = "data: "; - - private static final String OPTION_MODEL = "model"; - private static final String OPTION_TEMPERATURE = "temperature"; - private static final String OPTION_MAX_TOKENS = "maxTokens"; - - private final HttpClient httpClient; - - public OpenAiClient() { - this.httpClient = HttpClient.newBuilder().build(); - } - - @Override - public List complete(Prompt prompt, InferenceParameters options) { - try { - HttpResponse httpResponse = httpClient.send(toRequest(prompt, options, false), HttpResponse.BodyHandlers.ofByteArray()); - var response = SlimeUtils.jsonToSlime(httpResponse.body()).get(); - if ( httpResponse.statusCode() != 200) - throw new IllegalArgumentException(SlimeUtils.toJson(response)); - return toCompletions(response); - } - catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public CompletableFuture completeAsync(Prompt prompt, - InferenceParameters options, - Consumer consumer) { - try { - var request = toRequest(prompt, options, true); - var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines()); - var completionFuture = new CompletableFuture(); - - futureResponse.thenAcceptAsync(response -> { - try { - int responseCode = response.statusCode(); - if (responseCode != 200) { - throw new LanguageModelException(responseCode, response.body().collect(Collectors.joining())); - } - - Stream lines = response.body(); - lines.forEach(line -> { - if (line.startsWith(DATA_FIELD)) { - var root = SlimeUtils.jsonToSlime(line.substring(DATA_FIELD.length())).get(); - var completion = toCompletions(root, "delta").get(0); 
- consumer.accept(completion); - if (!completion.finishReason().equals(Completion.FinishReason.none)) { - completionFuture.complete(completion.finishReason()); - } - } - }); - } catch (Exception e) { - completionFuture.completeExceptionally(e); - } - }); - return completionFuture; - - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - private HttpRequest toRequest(Prompt prompt, InferenceParameters options, boolean stream) throws IOException, URISyntaxException { - var slime = new Slime(); - var root = slime.setObject(); - root.setString("model", options.get(OPTION_MODEL).orElse(DEFAULT_MODEL)); - root.setBool("stream", stream); - root.setLong("n", 1); - - if (options.getDouble(OPTION_TEMPERATURE).isPresent()) - root.setDouble("temperature", options.getDouble(OPTION_TEMPERATURE).get()); - if (options.getInt(OPTION_MAX_TOKENS).isPresent()) - root.setLong("max_tokens", options.getInt(OPTION_MAX_TOKENS).get()); - // Others? - - var messagesArray = root.setArray("messages"); - var messagesObject = messagesArray.addObject(); - messagesObject.setString("role", "user"); - messagesObject.setString("content", prompt.asString()); - - var endpoint = options.getEndpoint().orElse("https://api.openai.com/v1/chat/completions"); - return HttpRequest.newBuilder(new URI(endpoint)) - .header("Content-Type", "application/json") - .header("Authorization", "Bearer " + options.getApiKey().orElse("")) - .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime))) - .build(); - } - - private List toCompletions(Inspector response) { - return toCompletions(response, "message"); - } - - private List toCompletions(Inspector response, String field) { - List completions = new ArrayList<>(); - response.field("choices") - .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice, field))); - return completions; - } - - private Completion toCompletion(Inspector choice, String field) { - var content = choice.field(field).field("content").asString(); - var finishReason = toFinishReason(choice.field("finish_reason").asString()); - return new Completion(content, finishReason); - } - - private Completion.FinishReason toFinishReason(String finishReasonString) { - return switch(finishReasonString) { - case "length" -> Completion.FinishReason.length; - case "stop" -> Completion.FinishReason.stop; - case "", "null" -> Completion.FinishReason.none; - default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'"); - }; - } - -} diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java deleted file mode 100644 index 2593d919499..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -@PublicApi -package ai.vespa.llm.client.openai; - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage; - -/** - * Client to OpenAi's large language models. - */ diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java deleted file mode 100644 index 91d0ad9bd02..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. 
See LICENSE in the project root. -package ai.vespa.llm.completion; - -import com.yahoo.api.annotations.Beta; - -import java.util.Objects; - -/** - * A completion from a language model. - * - * @author bratseth - */ -@Beta -public record Completion(String text, FinishReason finishReason) { - - public enum FinishReason { - - /** The maximum length of a completion was reached. */ - length, - - /** The completion is the predicted ending of the prompt. */ - stop, - - /** The completion is not finished yet, more tokens are incoming. */ - none, - - /** An error occurred while generating the completion */ - error - } - - public Completion(String text, FinishReason finishReason) { - this.text = Objects.requireNonNull(text); - this.finishReason = Objects.requireNonNull(finishReason); - } - - /** Returns the generated text completion. */ - public String text() { return text; } - - /** Returns the reason this completion ended. */ - public FinishReason finishReason() { return finishReason; } - - public static Completion from(String text) { - return from(text, FinishReason.stop); - } - - public static Completion from(String text, FinishReason reason) { - return new Completion(text, reason); - } - -} diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/Prompt.java b/vespajlib/src/main/java/ai/vespa/llm/completion/Prompt.java deleted file mode 100644 index 44dfb8499a8..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/Prompt.java +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.completion; - -import com.yahoo.api.annotations.Beta; - -/** - * A prompt that can be given to a large language model to generate a completion. - * - * @author bratseth - */ -@Beta -public abstract class Prompt { - - public abstract String asString(); - - /** Returns a new prompt with the text of the given completion appended. */ - public Prompt append(Completion completion) { - return append(completion.text()); - } - - public abstract Prompt append(String text); - -} diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/StringPrompt.java b/vespajlib/src/main/java/ai/vespa/llm/completion/StringPrompt.java deleted file mode 100644 index 9e702c79a7a..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/StringPrompt.java +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.completion; - -import com.yahoo.api.annotations.Beta; - -import java.util.Objects; - -/** - * A prompt which just consists of a string. 
- * - * @author bratseth - */ -@Beta -public class StringPrompt extends Prompt { - - private final String string; - - private StringPrompt(String string) { - this.string = Objects.requireNonNull(string); - } - - @Override - public String asString() { return string; } - - @Override - public StringPrompt append(String text) { - return StringPrompt.from(string + text); - } - - @Override - public StringPrompt append(Completion completion) { - return append(completion.text()); - } - - @Override - public String toString() { - return string; - } - - public static StringPrompt from(String string) { - return new StringPrompt(string); - } - -} diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/completion/package-info.java deleted file mode 100644 index 57c2b3f3364..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/package-info.java +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -@PublicApi -package ai.vespa.llm.completion; - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage; - -/** - * Classes for generating text completions with language models. - */ diff --git a/vespajlib/src/main/java/ai/vespa/llm/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/package-info.java deleted file mode 100644 index 8640f652ad4..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/package-info.java +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -@PublicApi -package ai.vespa.llm; - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage; - -/** - * API for working with large language models. - */ diff --git a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java deleted file mode 100644 index 0e757a1f1e7..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package ai.vespa.llm.test; - -import ai.vespa.llm.LanguageModel; -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.Prompt; -import com.yahoo.api.annotations.Beta; - -import java.util.List; -import java.util.concurrent.CompletableFuture; -import java.util.function.Consumer; -import java.util.function.Function; - -/** - * @author bratseth - */ -@Beta -public class MockLanguageModel implements LanguageModel { - - private final Function> completer; - - public MockLanguageModel(Builder builder) { - completer = builder.completer; - } - - @Override - public List complete(Prompt prompt, InferenceParameters options) { - return completer.apply(prompt); - } - - @Override - public CompletableFuture completeAsync(Prompt prompt, - InferenceParameters options, - Consumer action) { - throw new RuntimeException("Not implemented"); - } - - public static class Builder { - - private Function> completer = prompt -> List.of(Completion.from("")); - - public Builder completer(Function> completer) { - this.completer = completer; - return this; - } - - public Builder() {} - - public MockLanguageModel build() { return new MockLanguageModel(this); } - - } - -} diff --git a/vespajlib/src/main/java/ai/vespa/llm/test/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/test/package-info.java deleted file mode 100644 index ab3b7acc657..00000000000 --- a/vespajlib/src/main/java/ai/vespa/llm/test/package-info.java +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage -@PublicApi -package ai.vespa.llm.test; - -/** - * Tools for writing tests when working with large language models. - */ - -import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage; diff --git a/vespajlib/src/test/java/ai/vespa/languagemodels/client/openai/OpenAiClientCompletionTest.java b/vespajlib/src/test/java/ai/vespa/languagemodels/client/openai/OpenAiClientCompletionTest.java new file mode 100644 index 00000000000..d6cb4b13f50 --- /dev/null +++ b/vespajlib/src/test/java/ai/vespa/languagemodels/client/openai/OpenAiClientCompletionTest.java @@ -0,0 +1,46 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.languagemodels.client.openai; + +import ai.vespa.languagemodels.InferenceParameters; +import ai.vespa.languagemodels.completion.StringPrompt; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +/** + * @author bratseth + */ +public class OpenAiClientCompletionTest { + + private static final String apiKey = ""; + + @Test + @Disabled + public void testClient() { + var client = new OpenAiClient(); + var options = Map.of("maxTokens", "10"); + var prompt = StringPrompt.from("You are an unhelpful assistant who never answers questions straightforwardly. " + + "Be as long-winded as possible. Are humans smarter than cats?"); + + System.out.print(prompt); + var completion = client.complete(prompt, new InferenceParameters(apiKey, options::get)).get(0); + System.out.print(completion.text()); + } + + @Test + @Disabled + public void testAsyncClient() { + var client = new OpenAiClient(); + var options = Map.of("maxTokens", "10"); + var prompt = StringPrompt.from("You are an unhelpful assistant who never answers questions straightforwardly. " + + "Be as long-winded as possible. 
Are humans smarter than cats?"); + System.out.print(prompt); + var future = client.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> { + System.out.print(completion.text()); + }); + System.out.println("\nWaiting for completion...\n\n"); + System.out.println("\nFinished streaming because of " + future.join()); + } + +} diff --git a/vespajlib/src/test/java/ai/vespa/languagemodels/completion/CompletionTest.java b/vespajlib/src/test/java/ai/vespa/languagemodels/completion/CompletionTest.java new file mode 100644 index 00000000000..304334a0ea2 --- /dev/null +++ b/vespajlib/src/test/java/ai/vespa/languagemodels/completion/CompletionTest.java @@ -0,0 +1,40 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.languagemodels.completion; + +import ai.vespa.languagemodels.test.MockLanguageModel; +import ai.vespa.languagemodels.InferenceParameters; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.function.Function; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Tests completion with a mock completer. + * + * @author bratseth + */ +public class CompletionTest { + + @Test + public void testCompletion() { + Function> completer = in -> + switch (in.asString()) { + case "Complete this: " -> List.of(Completion.from("The completion")); + default -> throw new RuntimeException("Cannot complete '" + in + "'"); + }; + var llm = new MockLanguageModel.Builder().completer(completer).build(); + + String input = "Complete this: "; + StringPrompt prompt = StringPrompt.from(input); + InferenceParameters options = new InferenceParameters(s -> ""); + for (int i = 0; i < 10; i++) { + var completion = llm.complete(prompt, options).get(0); + prompt = prompt.append(completion); + if (completion.finishReason() == Completion.FinishReason.stop) break; + } + assertEquals("Complete this: The completion", prompt.asString()); + } + +} diff --git a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java deleted file mode 100644 index 1baab26f496..00000000000 --- a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.client.openai; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.completion.Completion; -import ai.vespa.llm.completion.StringPrompt; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; - -import java.util.Map; - -/** - * @author bratseth - */ -public class OpenAiClientCompletionTest { - - private static final String apiKey = ""; - - @Test - @Disabled - public void testClient() { - var client = new OpenAiClient(); - var options = Map.of("maxTokens", "10"); - var prompt = StringPrompt.from("You are an unhelpful assistant who never answers questions straightforwardly. " + - "Be as long-winded as possible. 
Are humans smarter than cats?"); - - System.out.print(prompt); - var completion = client.complete(prompt, new InferenceParameters(apiKey, options::get)).get(0); - System.out.print(completion.text()); - } - - @Test - @Disabled - public void testAsyncClient() { - var client = new OpenAiClient(); - var options = Map.of("maxTokens", "10"); - var prompt = StringPrompt.from("You are an unhelpful assistant who never answers questions straightforwardly. " + - "Be as long-winded as possible. Are humans smarter than cats?"); - System.out.print(prompt); - var future = client.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> { - System.out.print(completion.text()); - }); - System.out.println("\nWaiting for completion...\n\n"); - System.out.println("\nFinished streaming because of " + future.join()); - } - -} diff --git a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java deleted file mode 100644 index 24c496a3d2c..00000000000 --- a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java +++ /dev/null @@ -1,40 +0,0 @@ -// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.llm.completion; - -import ai.vespa.llm.InferenceParameters; -import ai.vespa.llm.test.MockLanguageModel; -import org.junit.jupiter.api.Test; - -import java.util.List; -import java.util.function.Function; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -/** - * Tests completion with a mock completer. - * - * @author bratseth - */ -public class CompletionTest { - - @Test - public void testCompletion() { - Function> completer = in -> - switch (in.asString()) { - case "Complete this: " -> List.of(Completion.from("The completion")); - default -> throw new RuntimeException("Cannot complete '" + in + "'"); - }; - var llm = new MockLanguageModel.Builder().completer(completer).build(); - - String input = "Complete this: "; - StringPrompt prompt = StringPrompt.from(input); - InferenceParameters options = new InferenceParameters(s -> ""); - for (int i = 0; i < 10; i++) { - var completion = llm.complete(prompt, options).get(0); - prompt = prompt.append(completion); - if (completion.finishReason() == Completion.FinishReason.stop) break; - } - assertEquals("Complete this: The completion", prompt.asString()); - } - -} -- cgit v1.2.3
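
For readers unfamiliar with the relocated API, the short sketch below shows how the moved classes fit together after this change. It is not part of the commit: it simply mirrors the CompletionTest added in this patch. The class and method names (MockLanguageModel, StringPrompt, InferenceParameters, Completion) come from the files above; the wrapper class and main method are illustrative scaffolding only.

// Minimal usage sketch of the relocated ai.vespa.languagemodels API (assumes the packages added in this patch).
import ai.vespa.languagemodels.InferenceParameters;
import ai.vespa.languagemodels.completion.Completion;
import ai.vespa.languagemodels.completion.StringPrompt;
import ai.vespa.languagemodels.test.MockLanguageModel;

import java.util.List;

public class LanguageModelUsageSketch {

    public static void main(String[] args) {
        // A mock model that returns a fixed completion, as in CompletionTest above.
        var llm = new MockLanguageModel.Builder()
                .completer(in -> List.of(Completion.from("The completion")))
                .build();

        // Options are resolved by name through the supplied function; CompletionTest uses s -> "".
        var options = new InferenceParameters(s -> "");

        var prompt = StringPrompt.from("Complete this: ");
        Completion completion = llm.complete(prompt, options).get(0);

        // Completion.from(String) defaults to FinishReason.stop, so a single round suffices here.
        prompt = prompt.append(completion);
        System.out.println(prompt.asString()); // prints "Complete this: The completion"
    }
}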