diff options
Diffstat (limited to 'vespajlib/src/main/java/ai/vespa/llm/client/openai')
-rw-r--r-- | vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java | 89 |
1 file changed, 27 insertions, 62 deletions
diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java index d7334b40963..75308a84faa 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java +++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java @@ -1,6 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.client.openai; +import ai.vespa.llm.LanguageModelException; +import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.LanguageModel; import ai.vespa.llm.completion.Prompt; @@ -28,31 +30,28 @@ import java.util.stream.Stream; * Currently, only completions are implemented. * * @author bratseth + * @author lesters */ @Beta public class OpenAiClient implements LanguageModel { + private static final String DEFAULT_MODEL = "gpt-3.5-turbo"; private static final String DATA_FIELD = "data: "; - private final String token; - private final String model; - private final double temperature; - private final long maxTokens; + private static final String OPTION_MODEL = "model"; + private static final String OPTION_TEMPERATURE = "temperature"; + private static final String OPTION_MAX_TOKENS = "maxTokens"; private final HttpClient httpClient; - private OpenAiClient(Builder builder) { - this.token = builder.token; - this.model = builder.model; - this.temperature = builder.temperature; - this.maxTokens = builder.maxTokens; + public OpenAiClient() { this.httpClient = HttpClient.newBuilder().build(); } @Override - public List<Completion> complete(Prompt prompt) { + public List<Completion> complete(Prompt prompt, InferenceParameters options) { try { - HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt), HttpResponse.BodyHandlers.ofByteArray()); + HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt, options, false), 
HttpResponse.BodyHandlers.ofByteArray()); var response = SlimeUtils.jsonToSlime(httpResponse.body()).get(); if ( httpResponse.statusCode() != 200) throw new IllegalArgumentException(SlimeUtils.toJson(response)); @@ -64,9 +63,11 @@ public class OpenAiClient implements LanguageModel { } @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> consumer) { + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters options, + Consumer<Completion> consumer) { try { - var request = toRequest(prompt, true); + var request = toRequest(prompt, options, true); var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines()); var completionFuture = new CompletableFuture<Completion.FinishReason>(); @@ -74,8 +75,7 @@ public class OpenAiClient implements LanguageModel { try { int responseCode = response.statusCode(); if (responseCode != 200) { - throw new IllegalArgumentException("Received code " + responseCode + ": " + - response.body().collect(Collectors.joining())); + throw new LanguageModelException(responseCode, response.body().collect(Collectors.joining())); } Stream<String> lines = response.body(); @@ -100,28 +100,28 @@ public class OpenAiClient implements LanguageModel { } } - private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException { - return toRequest(prompt, false); - } - - private HttpRequest toRequest(Prompt prompt, boolean stream) throws IOException, URISyntaxException { + private HttpRequest toRequest(Prompt prompt, InferenceParameters options, boolean stream) throws IOException, URISyntaxException { var slime = new Slime(); var root = slime.setObject(); - root.setString("model", model); - root.setDouble("temperature", temperature); + root.setString("model", options.get(OPTION_MODEL).orElse(DEFAULT_MODEL)); root.setBool("stream", stream); root.setLong("n", 1); - if (maxTokens > 0) { - root.setLong("max_tokens", 
maxTokens); - } + + if (options.getDouble(OPTION_TEMPERATURE).isPresent()) + root.setDouble("temperature", options.getDouble(OPTION_TEMPERATURE).get()); + if (options.getInt(OPTION_MAX_TOKENS).isPresent()) + root.setLong("max_tokens", options.getInt(OPTION_MAX_TOKENS).get()); + // Others? + var messagesArray = root.setArray("messages"); var messagesObject = messagesArray.addObject(); messagesObject.setString("role", "user"); messagesObject.setString("content", prompt.asString()); - return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/chat/completions")) + var endpoint = options.getEndpoint().orElse("https://api.openai.com/v1/chat/completions"); + return HttpRequest.newBuilder(new URI(endpoint)) .header("Content-Type", "application/json") - .header("Authorization", "Bearer " + token) + .header("Authorization", "Bearer " + options.getApiKey().orElse("")) .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime))) .build(); } @@ -152,39 +152,4 @@ public class OpenAiClient implements LanguageModel { }; } - public static class Builder { - - private final String token; - private String model = "gpt-3.5-turbo"; - private double temperature = 0.0; - private long maxTokens = 0; - - public Builder(String token) { - this.token = token; - } - - /** One of the language models listed at https://platform.openai.com/docs/models */ - public Builder model(String model) { - this.model = model; - return this; - } - - /** A value between 0 and 2 - higher gives more random/creative output. */ - public Builder temperature(double temperature) { - this.temperature = temperature; - return this; - } - - /** Maximum number of tokens to generate */ - public Builder maxTokens(long maxTokens) { - this.maxTokens = maxTokens; - return this; - } - - public OpenAiClient build() { - return new OpenAiClient(this); - } - - } - } |