Diffstat (limited to 'vespajlib/src/main/java/ai')
 vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java        | 76
 vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java              |  6
 vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java     | 19
 vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java | 89
 vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java      | 11
 vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java     |  7
 6 files changed, 140 insertions(+), 68 deletions(-)
diff --git a/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
new file mode 100755
index 00000000000..a942e5090e5
--- /dev/null
+++ b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java
@@ -0,0 +1,76 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+import java.util.Objects;
+import java.util.Optional;
+import java.util.function.Consumer;
+import java.util.function.Function;
+
+/**
+ * Parameters for inference to language models. Parameters are typically
+ * supplied by searchers or processors and come from query strings,
+ * headers, or other sources. Which parameters are available depends on
+ * the language model used.
+ *
+ * @author lesters
+ */
+@Beta
+public class InferenceParameters {
+
+ private String apiKey;
+ private String endpoint;
+ private final Function<String, String> options;
+
+ public InferenceParameters(Function<String, String> options) {
+ this(null, options);
+ }
+
+ public InferenceParameters(String apiKey, Function<String, String> options) {
+ this.apiKey = apiKey;
+ this.options = Objects.requireNonNull(options);
+ }
+
+ public void setApiKey(String apiKey) {
+ this.apiKey = apiKey;
+ }
+
+ public Optional<String> getApiKey() {
+ return Optional.ofNullable(apiKey);
+ }
+
+ public void setEndpoint(String endpoint) {
+ this.endpoint = endpoint;
+ }
+
+ public Optional<String> getEndpoint() {
+ return Optional.ofNullable(endpoint);
+ }
+
+ public Optional<String> get(String option) {
+ return Optional.ofNullable(options.apply(option));
+ }
+
+ public Optional<Double> getDouble(String option) {
+ try {
+ return Optional.of(Double.parseDouble(options.apply(option)));
+ } catch (Exception e) {
+ return Optional.empty();
+ }
+ }
+
+ public Optional<Integer> getInt(String option) {
+ try {
+ return Optional.of(Integer.parseInt(options.apply(option)));
+ } catch (Exception e) {
+ return Optional.empty();
+ }
+ }
+
+ public void ifPresent(String option, Consumer<String> func) {
+ get(option).ifPresent(func);
+ }
+
+}
+
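Review note: the new class is a thin wrapper over an option-lookup function, so any string-to-string source (query properties, headers, a plain map) can back it. A minimal usage sketch, assuming a Map-backed lookup; the option names and key value below are illustrative:

    import ai.vespa.llm.InferenceParameters;
    import java.util.Map;

    public class InferenceParametersSketch {
        public static void main(String[] args) {
            // Any Function<String, String> works; here a Map's get method backs the lookup.
            Map<String, String> query = Map.of("temperature", "0.7", "maxTokens", "128");
            var params = new InferenceParameters("my-api-key", query::get);

            params.getDouble("temperature").ifPresent(t -> System.out.println("temperature = " + t)); // 0.7
            params.getInt("maxTokens").ifPresent(n -> System.out.println("max tokens = " + n));       // 128
            System.out.println(params.get("model").orElse("<no model set>"));                         // absent
        }
    }

Note that getDouble and getInt swallow both missing values and unparsable ones, returning empty in either case.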
diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java
index f4b8938934b..059f25fadb4 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java
@@ -17,8 +17,10 @@ import java.util.function.Consumer;
@Beta
public interface LanguageModel {
- List<Completion> complete(Prompt prompt);
+ List<Completion> complete(Prompt prompt, InferenceParameters options);
- CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action);
+ CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters options,
+ Consumer<Completion> consumer);
}
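Review note: options are now a required argument on both interface methods, so existing callers must pass something even when they have no parameters. A migration sketch, assuming a lookup that always misses stands in for "no options":

    import ai.vespa.llm.InferenceParameters;
    import ai.vespa.llm.LanguageModel;
    import ai.vespa.llm.completion.Completion;
    import ai.vespa.llm.completion.Prompt;
    import java.util.List;

    public final class CompleteWithDefaults {
        // Before: model.complete(prompt). After: thread options through explicitly.
        static List<Completion> complete(LanguageModel model, Prompt prompt) {
            return model.complete(prompt, new InferenceParameters(s -> null));
        }
    }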
diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java
new file mode 100755
index 00000000000..b5dbf615c08
--- /dev/null
+++ b/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java
@@ -0,0 +1,19 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm;
+
+import com.yahoo.api.annotations.Beta;
+
+@Beta
+public class LanguageModelException extends RuntimeException {
+
+ private final int code;
+
+ public LanguageModelException(int code, String message) {
+ super(message);
+ this.code = code;
+ }
+
+ public int code() {
+ return code;
+ }
+}
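Review note: in this change the exception is used to fail the streaming future on non-200 responses (see OpenAiClient below), so callers observe it through the returned future rather than a throws clause. A handling sketch; whether the exception arrives bare or wrapped in a CompletionException depends on the future wiring, so it unwraps defensively:

    import ai.vespa.llm.LanguageModelException;
    import ai.vespa.llm.completion.Completion;
    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.CompletionException;

    public final class ErrorHandlingSketch {
        static void logOnFailure(CompletableFuture<Completion.FinishReason> future) {
            future.whenComplete((finishReason, thrown) -> {
                if (thrown == null) return;
                // Exceptions from async stages often arrive wrapped; unwrap defensively.
                Throwable cause = (thrown instanceof CompletionException && thrown.getCause() != null)
                        ? thrown.getCause() : thrown;
                if (cause instanceof LanguageModelException lme)
                    System.err.println("Upstream error " + lme.code() + ": " + lme.getMessage());
            });
        }
    }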
diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
index d7334b40963..75308a84faa 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java
@@ -1,6 +1,8 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.llm.client.openai;
+import ai.vespa.llm.LanguageModelException;
+import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.completion.Prompt;
@@ -28,31 +30,28 @@ import java.util.stream.Stream;
* Currently, only completions are implemented.
*
* @author bratseth
+ * @author lesters
*/
@Beta
public class OpenAiClient implements LanguageModel {
+ private static final String DEFAULT_MODEL = "gpt-3.5-turbo";
private static final String DATA_FIELD = "data: ";
- private final String token;
- private final String model;
- private final double temperature;
- private final long maxTokens;
+ private static final String OPTION_MODEL = "model";
+ private static final String OPTION_TEMPERATURE = "temperature";
+ private static final String OPTION_MAX_TOKENS = "maxTokens";
private final HttpClient httpClient;
- private OpenAiClient(Builder builder) {
- this.token = builder.token;
- this.model = builder.model;
- this.temperature = builder.temperature;
- this.maxTokens = builder.maxTokens;
+ public OpenAiClient() {
this.httpClient = HttpClient.newBuilder().build();
}
@Override
- public List<Completion> complete(Prompt prompt) {
+ public List<Completion> complete(Prompt prompt, InferenceParameters options) {
try {
- HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt), HttpResponse.BodyHandlers.ofByteArray());
+ HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt, options, false), HttpResponse.BodyHandlers.ofByteArray());
var response = SlimeUtils.jsonToSlime(httpResponse.body()).get();
if ( httpResponse.statusCode() != 200)
throw new IllegalArgumentException(SlimeUtils.toJson(response));
@@ -64,9 +63,11 @@ public class OpenAiClient implements LanguageModel {
}
@Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> consumer) {
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters options,
+ Consumer<Completion> consumer) {
try {
- var request = toRequest(prompt, true);
+ var request = toRequest(prompt, options, true);
var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines());
var completionFuture = new CompletableFuture<Completion.FinishReason>();
@@ -74,8 +75,7 @@ public class OpenAiClient implements LanguageModel {
try {
int responseCode = response.statusCode();
if (responseCode != 200) {
- throw new IllegalArgumentException("Received code " + responseCode + ": " +
- response.body().collect(Collectors.joining()));
+ throw new LanguageModelException(responseCode, response.body().collect(Collectors.joining()));
}
Stream<String> lines = response.body();
@@ -100,28 +100,28 @@ public class OpenAiClient implements LanguageModel {
}
}
- private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException {
- return toRequest(prompt, false);
- }
-
- private HttpRequest toRequest(Prompt prompt, boolean stream) throws IOException, URISyntaxException {
+ private HttpRequest toRequest(Prompt prompt, InferenceParameters options, boolean stream) throws IOException, URISyntaxException {
var slime = new Slime();
var root = slime.setObject();
- root.setString("model", model);
- root.setDouble("temperature", temperature);
+ root.setString("model", options.get(OPTION_MODEL).orElse(DEFAULT_MODEL));
root.setBool("stream", stream);
root.setLong("n", 1);
- if (maxTokens > 0) {
- root.setLong("max_tokens", maxTokens);
- }
+
+ if (options.getDouble(OPTION_TEMPERATURE).isPresent())
+ root.setDouble("temperature", options.getDouble(OPTION_TEMPERATURE).get());
+ if (options.getInt(OPTION_MAX_TOKENS).isPresent())
+ root.setLong("max_tokens", options.getInt(OPTION_MAX_TOKENS).get());
+ // Others?
+
var messagesArray = root.setArray("messages");
var messagesObject = messagesArray.addObject();
messagesObject.setString("role", "user");
messagesObject.setString("content", prompt.asString());
- return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/chat/completions"))
+ var endpoint = options.getEndpoint().orElse("https://api.openai.com/v1/chat/completions");
+ return HttpRequest.newBuilder(new URI(endpoint))
.header("Content-Type", "application/json")
- .header("Authorization", "Bearer " + token)
+ .header("Authorization", "Bearer " + options.getApiKey().orElse(""))
.POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime)))
.build();
}
@@ -152,39 +152,4 @@ public class OpenAiClient implements LanguageModel {
};
}
- public static class Builder {
-
- private final String token;
- private String model = "gpt-3.5-turbo";
- private double temperature = 0.0;
- private long maxTokens = 0;
-
- public Builder(String token) {
- this.token = token;
- }
-
- /** One of the language models listed at https://platform.openai.com/docs/models */
- public Builder model(String model) {
- this.model = model;
- return this;
- }
-
- /** A value between 0 and 2 - higher gives more random/creative output. */
- public Builder temperature(double temperature) {
- this.temperature = temperature;
- return this;
- }
-
- /** Maximum number of tokens to generate */
- public Builder maxTokens(long maxTokens) {
- this.maxTokens = maxTokens;
- return this;
- }
-
- public OpenAiClient build() {
- return new OpenAiClient(this);
- }
-
- }
-
}
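Review note: with the builder removed, the client is stateless and all configuration travels per request. An equivalent-usage sketch; StringPrompt.from is assumed as the existing Prompt factory in ai.vespa.llm.completion, and the option values are illustrative:

    import ai.vespa.llm.InferenceParameters;
    import ai.vespa.llm.client.openai.OpenAiClient;
    import ai.vespa.llm.completion.StringPrompt;
    import java.util.Map;

    public final class OpenAiClientSketch {
        public static void main(String[] args) {
            var client = new OpenAiClient();  // no token or model baked in anymore
            var options = Map.of("model", "gpt-4", "temperature", "0.2", "maxTokens", "256");
            var params = new InferenceParameters(System.getenv("OPENAI_API_KEY"), options::get);
            client.complete(StringPrompt.from("Write a haiku about Vespa"), params)
                  .forEach(completion -> System.out.println(completion.text()));
        }
    }

One consequence of getApiKey().orElse("") in toRequest: a missing key no longer fails fast at construction time but surfaces as an upstream 401 from an empty Bearer token.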
diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java
index ea784013812..91d0ad9bd02 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java
@@ -22,7 +22,10 @@ public record Completion(String text, FinishReason finishReason) {
stop,
/** The completion is not finished yet, more tokens are incoming. */
- none
+ none,
+
+ /** An error occurred while generating the completion */
+ error
}
public Completion(String text, FinishReason finishReason) {
@@ -37,7 +40,11 @@ public record Completion(String text, FinishReason finishReason) {
public FinishReason finishReason() { return finishReason; }
public static Completion from(String text) {
- return new Completion(text, FinishReason.stop);
+ return from(text, FinishReason.stop);
+ }
+
+ public static Completion from(String text, FinishReason reason) {
+ return new Completion(text, reason);
}
}
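Review note: the new overload exists mainly so stream handlers can emit a terminal chunk with a non-stop reason, such as the new error value. A short sketch:

    import ai.vespa.llm.completion.Completion;

    public final class CompletionSketch {
        public static void main(String[] args) {
            // An error handler can now hand consumers an explicit terminal marker.
            Completion chunk = Completion.from("", Completion.FinishReason.error);
            System.out.println(chunk.finishReason());  // error
        }
    }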
diff --git a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java
index db1b42fbbac..0e757a1f1e7 100644
--- a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java
+++ b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java
@@ -2,6 +2,7 @@
package ai.vespa.llm.test;
import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.api.annotations.Beta;
@@ -24,12 +25,14 @@ public class MockLanguageModel implements LanguageModel {
}
@Override
- public List<Completion> complete(Prompt prompt) {
+ public List<Completion> complete(Prompt prompt, InferenceParameters options) {
return completer.apply(prompt);
}
@Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action) {
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters options,
+ Consumer<Completion> action) {
throw new RuntimeException("Not implemented");
}