about summary refs log tree commit diff stats
path: root/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java')
-rw-r--r-- vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java | 89
1 file changed, 27 insertions(+), 62 deletions(-)
diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
index d7334b40963..75308a84faa 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
@@ -1,6 +1,8 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.llm.client.openai;
+import ai.vespa.llm.LanguageModelException;
+import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.completion.Prompt;
@@ -28,31 +30,28 @@ import java.util.stream.Stream;
* Currently, only completions are implemented.
*
* @author bratseth
+ * @author lesters
*/
@Beta
public class OpenAiClient implements LanguageModel {
+ private static final String DEFAULT_MODEL = "gpt-3.5-turbo";
private static final String DATA_FIELD = "data: ";
- private final String token;
- private final String model;
- private final double temperature;
- private final long maxTokens;
+ private static final String OPTION_MODEL = "model";
+ private static final String OPTION_TEMPERATURE = "temperature";
+ private static final String OPTION_MAX_TOKENS = "maxTokens";
private final HttpClient httpClient;
- private OpenAiClient(Builder builder) {
- this.token = builder.token;
- this.model = builder.model;
- this.temperature = builder.temperature;
- this.maxTokens = builder.maxTokens;
+ public OpenAiClient() {
this.httpClient = HttpClient.newBuilder().build();
}
@Override
- public List<Completion> complete(Prompt prompt) {
+ public List<Completion> complete(Prompt prompt, InferenceParameters options) {
try {
- HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt), HttpResponse.BodyHandlers.ofByteArray());
+ HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt, options, false), HttpResponse.BodyHandlers.ofByteArray());
var response = SlimeUtils.jsonToSlime(httpResponse.body()).get();
if ( httpResponse.statusCode() != 200)
throw new IllegalArgumentException(SlimeUtils.toJson(response));
@@ -64,9 +63,11 @@ public class OpenAiClient implements LanguageModel {
}
@Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> consumer) {
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters options,
+ Consumer<Completion> consumer) {
try {
- var request = toRequest(prompt, true);
+ var request = toRequest(prompt, options, true);
var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines());
var completionFuture = new CompletableFuture<Completion.FinishReason>();
@@ -74,8 +75,7 @@ public class OpenAiClient implements LanguageModel {
try {
int responseCode = response.statusCode();
if (responseCode != 200) {
- throw new IllegalArgumentException("Received code " + responseCode + ": " +
- response.body().collect(Collectors.joining()));
+ throw new LanguageModelException(responseCode, response.body().collect(Collectors.joining()));
}
Stream<String> lines = response.body();
@@ -100,28 +100,28 @@ public class OpenAiClient implements LanguageModel {
}
}
- private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException {
- return toRequest(prompt, false);
- }
-
- private HttpRequest toRequest(Prompt prompt, boolean stream) throws IOException, URISyntaxException {
+ private HttpRequest toRequest(Prompt prompt, InferenceParameters options, boolean stream) throws IOException, URISyntaxException {
var slime = new Slime();
var root = slime.setObject();
- root.setString("model", model);
- root.setDouble("temperature", temperature);
+ root.setString("model", options.get(OPTION_MODEL).orElse(DEFAULT_MODEL));
root.setBool("stream", stream);
root.setLong("n", 1);
- if (maxTokens > 0) {
- root.setLong("max_tokens", maxTokens);
- }
+
+ if (options.getDouble(OPTION_TEMPERATURE).isPresent())
+ root.setDouble("temperature", options.getDouble(OPTION_TEMPERATURE).get());
+ if (options.getInt(OPTION_MAX_TOKENS).isPresent())
+ root.setLong("max_tokens", options.getInt(OPTION_MAX_TOKENS).get());
+ // Others?
+
var messagesArray = root.setArray("messages");
var messagesObject = messagesArray.addObject();
messagesObject.setString("role", "user");
messagesObject.setString("content", prompt.asString());
- return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/chat/completions"))
+ var endpoint = options.getEndpoint().orElse("https://api.openai.com/v1/chat/completions");
+ return HttpRequest.newBuilder(new URI(endpoint))
.header("Content-Type", "application/json")
- .header("Authorization", "Bearer " + token)
+ .header("Authorization", "Bearer " + options.getApiKey().orElse(""))
.POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime)))
.build();
}
@@ -152,39 +152,4 @@ public class OpenAiClient implements LanguageModel {
};
}
- public static class Builder {
-
- private final String token;
- private String model = "gpt-3.5-turbo";
- private double temperature = 0.0;
- private long maxTokens = 0;
-
- public Builder(String token) {
- this.token = token;
- }
-
- /** One of the language models listed at https://platform.openai.com/docs/models */
- public Builder model(String model) {
- this.model = model;
- return this;
- }
-
- /** A value between 0 and 2 - higher gives more random/creative output. */
- public Builder temperature(double temperature) {
- this.temperature = temperature;
- return this;
- }
-
- /** Maximum number of tokens to generate */
- public Builder maxTokens(long maxTokens) {
- this.maxTokens = maxTokens;
- return this;
- }
-
- public OpenAiClient build() {
- return new OpenAiClient(this);
- }
-
- }
-
}