diff options
Diffstat (limited to 'vespajlib/src')
8 files changed, 160 insertions, 86 deletions
diff --git a/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java new file mode 100755 index 00000000000..a942e5090e5 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/llm/InferenceParameters.java @@ -0,0 +1,76 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.llm; + +import com.yahoo.api.annotations.Beta; + +import java.util.Objects; +import java.util.Optional; +import java.util.function.Consumer; +import java.util.function.Function; + +/** + * Parameters for inference to language models. Parameters are typically + * supplied from searchers or processors and come from query strings, + * headers, or other sources. Which parameters are available depends on + * the language model used. + * + * @author lesters + */ +@Beta +public class InferenceParameters { + + private String apiKey; + private String endpoint; + private final Function<String, String> options; + + public InferenceParameters(Function<String, String> options) { + this(null, options); + } + + public InferenceParameters(String apiKey, Function<String, String> options) { + this.apiKey = apiKey; + this.options = Objects.requireNonNull(options); + } + + public void setApiKey(String apiKey) { + this.apiKey = apiKey; + } + + public Optional<String> getApiKey() { + return Optional.ofNullable(apiKey); + } + + public void setEndpoint(String endpoint) { + this.endpoint = endpoint; + } + + public Optional<String> getEndpoint() { + return Optional.ofNullable(endpoint); + } + + public Optional<String> get(String option) { + return Optional.ofNullable(options.apply(option)); + } + + public Optional<Double> getDouble(String option) { + try { + return Optional.of(Double.parseDouble(options.apply(option))); + } catch (Exception e) { + return Optional.empty(); + } + } + + public Optional<Integer> getInt(String option) { + try { + return 
Optional.of(Integer.parseInt(options.apply(option))); + } catch (Exception e) { + return Optional.empty(); + } + } + + public void ifPresent(String option, Consumer<String> func) { + get(option).ifPresent(func); + } + +} + diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java index f4b8938934b..059f25fadb4 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java +++ b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java @@ -17,8 +17,10 @@ import java.util.function.Consumer; @Beta public interface LanguageModel { - List<Completion> complete(Prompt prompt); + List<Completion> complete(Prompt prompt, InferenceParameters options); - CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action); + CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters options, + Consumer<Completion> consumer); } diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java new file mode 100755 index 00000000000..b5dbf615c08 --- /dev/null +++ b/vespajlib/src/main/java/ai/vespa/llm/LanguageModelException.java @@ -0,0 +1,19 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.llm; + +import com.yahoo.api.annotations.Beta; + +@Beta +public class LanguageModelException extends RuntimeException { + + private final int code; + + public LanguageModelException(int code, String message) { + super(message); + this.code = code; + } + + public int code() { + return code; + } +} diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java index d7334b40963..75308a84faa 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java +++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java @@ -1,6 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.client.openai; +import ai.vespa.llm.LanguageModelException; +import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.LanguageModel; import ai.vespa.llm.completion.Prompt; @@ -28,31 +30,28 @@ import java.util.stream.Stream; * Currently, only completions are implemented. 
* * @author bratseth + * @author lesters */ @Beta public class OpenAiClient implements LanguageModel { + private static final String DEFAULT_MODEL = "gpt-3.5-turbo"; private static final String DATA_FIELD = "data: "; - private final String token; - private final String model; - private final double temperature; - private final long maxTokens; + private static final String OPTION_MODEL = "model"; + private static final String OPTION_TEMPERATURE = "temperature"; + private static final String OPTION_MAX_TOKENS = "maxTokens"; private final HttpClient httpClient; - private OpenAiClient(Builder builder) { - this.token = builder.token; - this.model = builder.model; - this.temperature = builder.temperature; - this.maxTokens = builder.maxTokens; + public OpenAiClient() { this.httpClient = HttpClient.newBuilder().build(); } @Override - public List<Completion> complete(Prompt prompt) { + public List<Completion> complete(Prompt prompt, InferenceParameters options) { try { - HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt), HttpResponse.BodyHandlers.ofByteArray()); + HttpResponse<byte[]> httpResponse = httpClient.send(toRequest(prompt, options, false), HttpResponse.BodyHandlers.ofByteArray()); var response = SlimeUtils.jsonToSlime(httpResponse.body()).get(); if ( httpResponse.statusCode() != 200) throw new IllegalArgumentException(SlimeUtils.toJson(response)); @@ -64,9 +63,11 @@ public class OpenAiClient implements LanguageModel { } @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> consumer) { + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters options, + Consumer<Completion> consumer) { try { - var request = toRequest(prompt, true); + var request = toRequest(prompt, options, true); var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines()); var completionFuture = new CompletableFuture<Completion.FinishReason>(); @@ 
-74,8 +75,7 @@ public class OpenAiClient implements LanguageModel { try { int responseCode = response.statusCode(); if (responseCode != 200) { - throw new IllegalArgumentException("Received code " + responseCode + ": " + - response.body().collect(Collectors.joining())); + throw new LanguageModelException(responseCode, response.body().collect(Collectors.joining())); } Stream<String> lines = response.body(); @@ -100,28 +100,28 @@ public class OpenAiClient implements LanguageModel { } } - private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException { - return toRequest(prompt, false); - } - - private HttpRequest toRequest(Prompt prompt, boolean stream) throws IOException, URISyntaxException { + private HttpRequest toRequest(Prompt prompt, InferenceParameters options, boolean stream) throws IOException, URISyntaxException { var slime = new Slime(); var root = slime.setObject(); - root.setString("model", model); - root.setDouble("temperature", temperature); + root.setString("model", options.get(OPTION_MODEL).orElse(DEFAULT_MODEL)); root.setBool("stream", stream); root.setLong("n", 1); - if (maxTokens > 0) { - root.setLong("max_tokens", maxTokens); - } + + if (options.getDouble(OPTION_TEMPERATURE).isPresent()) + root.setDouble("temperature", options.getDouble(OPTION_TEMPERATURE).get()); + if (options.getInt(OPTION_MAX_TOKENS).isPresent()) + root.setLong("max_tokens", options.getInt(OPTION_MAX_TOKENS).get()); + // Others? 
+ var messagesArray = root.setArray("messages"); var messagesObject = messagesArray.addObject(); messagesObject.setString("role", "user"); messagesObject.setString("content", prompt.asString()); - return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/chat/completions")) + var endpoint = options.getEndpoint().orElse("https://api.openai.com/v1/chat/completions"); + return HttpRequest.newBuilder(new URI(endpoint)) .header("Content-Type", "application/json") - .header("Authorization", "Bearer " + token) + .header("Authorization", "Bearer " + options.getApiKey().orElse("")) .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime))) .build(); } @@ -152,39 +152,4 @@ public class OpenAiClient implements LanguageModel { }; } - public static class Builder { - - private final String token; - private String model = "gpt-3.5-turbo"; - private double temperature = 0.0; - private long maxTokens = 0; - - public Builder(String token) { - this.token = token; - } - - /** One of the language models listed at https://platform.openai.com/docs/models */ - public Builder model(String model) { - this.model = model; - return this; - } - - /** A value between 0 and 2 - higher gives more random/creative output. 
*/ - public Builder temperature(double temperature) { - this.temperature = temperature; - return this; - } - - /** Maximum number of tokens to generate */ - public Builder maxTokens(long maxTokens) { - this.maxTokens = maxTokens; - return this; - } - - public OpenAiClient build() { - return new OpenAiClient(this); - } - - } - } diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java index ea784013812..91d0ad9bd02 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java +++ b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java @@ -22,7 +22,10 @@ public record Completion(String text, FinishReason finishReason) { stop, /** The completion is not finished yet, more tokens are incoming. */ - none + none, + + /** An error occurred while generating the completion */ + error } public Completion(String text, FinishReason finishReason) { @@ -37,7 +40,11 @@ public record Completion(String text, FinishReason finishReason) { public FinishReason finishReason() { return finishReason; } public static Completion from(String text) { - return new Completion(text, FinishReason.stop); + return from(text, FinishReason.stop); + } + + public static Completion from(String text, FinishReason reason) { + return new Completion(text, reason); } } diff --git a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java index db1b42fbbac..0e757a1f1e7 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java +++ b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java @@ -2,6 +2,7 @@ package ai.vespa.llm.test; import ai.vespa.llm.LanguageModel; +import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.Prompt; import com.yahoo.api.annotations.Beta; @@ -24,12 +25,14 @@ public class MockLanguageModel implements LanguageModel { 
} @Override - public List<Completion> complete(Prompt prompt) { + public List<Completion> complete(Prompt prompt, InferenceParameters options) { return completer.apply(prompt); } @Override - public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action) { + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, + InferenceParameters options, + Consumer<Completion> action) { throw new RuntimeException("Not implemented"); } diff --git a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java index 45ef7e270aa..1baab26f496 100644 --- a/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java +++ b/vespajlib/src/test/java/ai/vespa/llm/client/openai/OpenAiClientCompletionTest.java @@ -1,46 +1,46 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.client.openai; +import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.completion.Completion; import ai.vespa.llm.completion.StringPrompt; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import java.util.Map; + /** * @author bratseth */ public class OpenAiClientCompletionTest { - private static final String apiKey = "your-api-key-here"; + private static final String apiKey = "<your-api-key-here>"; @Test @Disabled public void testClient() { - var client = new OpenAiClient.Builder(apiKey).maxTokens(10).build(); - String input = "You are an unhelpful assistant who never answers questions straightforwardly. " + - "Be as long-winded as possible. Are humans smarter than cats?\n\n"; - StringPrompt prompt = StringPrompt.from(input); + var client = new OpenAiClient(); + var options = Map.of("maxTokens", "10"); + var prompt = StringPrompt.from("You are an unhelpful assistant who never answers questions straightforwardly. 
" + + "Be as long-winded as possible. Are humans smarter than cats?"); + System.out.print(prompt); - for (int i = 0; i < 10; i++) { - var completion = client.complete(prompt).get(0); - System.out.print(completion.text()); - if (completion.finishReason() == Completion.FinishReason.stop) break; - prompt = prompt.append(completion.text()); - } + var completion = client.complete(prompt, new InferenceParameters(apiKey, options::get)).get(0); + System.out.print(completion.text()); } @Test @Disabled public void testAsyncClient() { - var client = new OpenAiClient.Builder(apiKey).build(); - String input = "You are an unhelpful assistant who never answers questions straightforwardly. " + - "Be as long-winded as possible. Are humans smarter than cats?\n\n"; - StringPrompt prompt = StringPrompt.from(input); + var client = new OpenAiClient(); + var options = Map.of("maxTokens", "10"); + var prompt = StringPrompt.from("You are an unhelpful assistant who never answers questions straightforwardly. " + + "Be as long-winded as possible. Are humans smarter than cats?"); System.out.print(prompt); - var future = client.completeAsync(prompt, completion -> { + var future = client.completeAsync(prompt, new InferenceParameters(apiKey, options::get), completion -> { System.out.print(completion.text()); }); - System.out.println("Waiting for completion..."); + System.out.println("\nWaiting for completion...\n\n"); System.out.println("\nFinished streaming because of " + future.join()); } diff --git a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java index 7407eb526e7..24c496a3d2c 100644 --- a/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java +++ b/vespajlib/src/test/java/ai/vespa/llm/completion/CompletionTest.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package ai.vespa.llm.completion; +import ai.vespa.llm.InferenceParameters; import ai.vespa.llm.test.MockLanguageModel; import org.junit.jupiter.api.Test; @@ -27,8 +28,9 @@ public class CompletionTest { String input = "Complete this: "; StringPrompt prompt = StringPrompt.from(input); + InferenceParameters options = new InferenceParameters(s -> ""); for (int i = 0; i < 10; i++) { - var completion = llm.complete(prompt).get(0); + var completion = llm.complete(prompt, options).get(0); prompt = prompt.append(completion); if (completion.finishReason() == Completion.FinishReason.stop) break; } |