diff options
Diffstat (limited to 'vespajlib/src/main/java/ai/vespa')
21 files changed, 121 insertions, 34 deletions
diff --git a/vespajlib/src/main/java/ai/vespa/http/DomainName.java b/vespajlib/src/main/java/ai/vespa/http/DomainName.java index 86242a1af0c..fa6964002bc 100644 --- a/vespajlib/src/main/java/ai/vespa/http/DomainName.java +++ b/vespajlib/src/main/java/ai/vespa/http/DomainName.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.http; import ai.vespa.validation.PatternedStringWrapper; diff --git a/vespajlib/src/main/java/ai/vespa/http/HttpURL.java b/vespajlib/src/main/java/ai/vespa/http/HttpURL.java index ba1a8e08740..94d99fa7aab 100644 --- a/vespajlib/src/main/java/ai/vespa/http/HttpURL.java +++ b/vespajlib/src/main/java/ai/vespa/http/HttpURL.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.http; import ai.vespa.validation.StringWrapper; diff --git a/vespajlib/src/main/java/ai/vespa/http/package-info.java b/vespajlib/src/main/java/ai/vespa/http/package-info.java index e5600c6f82d..ab62a1ec4dd 100644 --- a/vespajlib/src/main/java/ai/vespa/http/package-info.java +++ b/vespajlib/src/main/java/ai/vespa/http/package-info.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage package ai.vespa.http; diff --git a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java index 829b74f7bf4..f4b8938934b 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java +++ b/vespajlib/src/main/java/ai/vespa/llm/LanguageModel.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm; import ai.vespa.llm.completion.Completion; @@ -6,6 +6,8 @@ import ai.vespa.llm.completion.Prompt; import com.yahoo.api.annotations.Beta; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; /** * Interface to language models. @@ -17,4 +19,6 @@ public interface LanguageModel { List<Completion> complete(Prompt prompt); + CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action); + } diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java index 9145e76a2e0..d7334b40963 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java +++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/OpenAiClient.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.client.openai; import ai.vespa.llm.completion.Completion; @@ -18,25 +18,34 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * A client to the OpenAI language model API. Refer to https://platform.openai.com/docs/api-reference/. - * Currently only completions are implemented. + * Currently, only completions are implemented. * * @author bratseth */ @Beta public class OpenAiClient implements LanguageModel { + private static final String DATA_FIELD = "data: "; + private final String token; private final String model; private final double temperature; + private final long maxTokens; + private final HttpClient httpClient; private OpenAiClient(Builder builder) { this.token = builder.token; this.model = builder.model; this.temperature = builder.temperature; + this.maxTokens = builder.maxTokens; this.httpClient = HttpClient.newBuilder().build(); } @@ -54,13 +63,63 @@ public class OpenAiClient implements LanguageModel { } } + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> consumer) { + try { + var request = toRequest(prompt, true); + var futureResponse = httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofLines()); + var completionFuture = new CompletableFuture<Completion.FinishReason>(); + + futureResponse.thenAcceptAsync(response -> { + try { + int responseCode = response.statusCode(); + if (responseCode != 200) { + throw new IllegalArgumentException("Received code " + responseCode + ": " + + response.body().collect(Collectors.joining())); + } + + Stream<String> lines = response.body(); + lines.forEach(line -> { + if (line.startsWith(DATA_FIELD)) { + var root = SlimeUtils.jsonToSlime(line.substring(DATA_FIELD.length())).get(); + var completion = toCompletions(root, "delta").get(0); + consumer.accept(completion); + if (!completion.finishReason().equals(Completion.FinishReason.none)) { + completionFuture.complete(completion.finishReason()); + } + } + }); + } catch (Exception e) { + completionFuture.completeExceptionally(e); + } + }); + return completionFuture; + + } catch (Exception e) { + throw new RuntimeException(e); + } + } + private HttpRequest toRequest(Prompt prompt) throws IOException, URISyntaxException { + return toRequest(prompt, false); + } + + private HttpRequest toRequest(Prompt prompt, boolean stream) throws IOException, URISyntaxException { var slime = new Slime(); var root = slime.setObject(); root.setString("model", model); root.setDouble("temperature", temperature); - root.setString("prompt", prompt.asString()); - return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/completions")) + root.setBool("stream", stream); + root.setLong("n", 1); + if (maxTokens > 0) { + root.setLong("max_tokens", maxTokens); + } + var messagesArray = root.setArray("messages"); + var messagesObject = messagesArray.addObject(); + messagesObject.setString("role", "user"); + messagesObject.setString("content", prompt.asString()); + + return HttpRequest.newBuilder(new URI("https://api.openai.com/v1/chat/completions")) .header("Content-Type", "application/json") .header("Authorization", "Bearer " + token) .POST(HttpRequest.BodyPublishers.ofByteArray(SlimeUtils.toJsonBytes(slime))) @@ -68,21 +127,27 @@ public class OpenAiClient implements LanguageModel { } private List<Completion> toCompletions(Inspector response) { + return toCompletions(response, "message"); + } + + private List<Completion> toCompletions(Inspector response, String field) { List<Completion> completions = new ArrayList<>(); response.field("choices") - .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice))); + .traverse((ArrayTraverser) (__, choice) -> completions.add(toCompletion(choice, field))); return completions; } - private Completion toCompletion(Inspector choice) { - return new Completion(choice.field("text").asString(), - toFinishReason(choice.field("finish_reason").asString())); + private Completion toCompletion(Inspector choice, String field) { + var content = choice.field(field).field("content").asString(); + var finishReason = toFinishReason(choice.field("finish_reason").asString()); + return new Completion(content, finishReason); } private Completion.FinishReason toFinishReason(String finishReasonString) { return switch(finishReasonString) { case "length" -> Completion.FinishReason.length; case "stop" -> Completion.FinishReason.stop; + case "", "null" -> Completion.FinishReason.none; default -> throw new IllegalStateException("Unknown OpenAi completion finish reason '" + finishReasonString + "'"); }; } @@ -90,8 +155,9 @@ public class OpenAiClient implements LanguageModel { public static class Builder { private final String token; - private String model = "text-davinci-003"; - private double temperature = 0; + private String model = "gpt-3.5-turbo"; + private double temperature = 0.0; + private long maxTokens = 0; public Builder(String token) { this.token = token; @@ -109,6 +175,12 @@ public class OpenAiClient implements LanguageModel { return this; } + /** Maximum number of tokens to generate */ + public Builder maxTokens(long maxTokens) { + this.maxTokens = maxTokens; + return this; + } + public OpenAiClient build() { return new OpenAiClient(this); } diff --git a/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java index 8b8b99308b0..2593d919499 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java +++ b/vespajlib/src/main/java/ai/vespa/llm/client/openai/package-info.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage @PublicApi package ai.vespa.llm.client.openai; @@ -8,4 +8,4 @@ import com.yahoo.osgi.annotation.ExportPackage; /** * Client to OpenAi's large language models. - */
\ No newline at end of file + */ diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java index 30645b5151f..ea784013812 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java +++ b/vespajlib/src/main/java/ai/vespa/llm/completion/Completion.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.completion; import com.yahoo.api.annotations.Beta; @@ -19,8 +19,10 @@ public record Completion(String text, FinishReason finishReason) { length, /** The completion is the predicted ending of the prompt. */ - stop + stop, + /** The completion is not finished yet, more tokens are incoming. */ + none } public Completion(String text, FinishReason finishReason) { diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/Prompt.java b/vespajlib/src/main/java/ai/vespa/llm/completion/Prompt.java index d5d0247d6b0..44dfb8499a8 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/Prompt.java +++ b/vespajlib/src/main/java/ai/vespa/llm/completion/Prompt.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.completion; import com.yahoo.api.annotations.Beta; diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/StringPrompt.java b/vespajlib/src/main/java/ai/vespa/llm/completion/StringPrompt.java index e8392ca992e..9e702c79a7a 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/StringPrompt.java +++ b/vespajlib/src/main/java/ai/vespa/llm/completion/StringPrompt.java @@ -1,3 +1,4 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.completion; import com.yahoo.api.annotations.Beta; diff --git a/vespajlib/src/main/java/ai/vespa/llm/completion/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/completion/package-info.java index 79898c694ca..57c2b3f3364 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/completion/package-info.java +++ b/vespajlib/src/main/java/ai/vespa/llm/completion/package-info.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage @PublicApi package ai.vespa.llm.completion; @@ -8,4 +8,4 @@ import com.yahoo.osgi.annotation.ExportPackage; /** * Classes for generating text completions with language models. - */
\ No newline at end of file + */ diff --git a/vespajlib/src/main/java/ai/vespa/llm/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/package-info.java index 04fc24c51ee..8640f652ad4 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/package-info.java +++ b/vespajlib/src/main/java/ai/vespa/llm/package-info.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage @PublicApi package ai.vespa.llm; @@ -8,4 +8,4 @@ import com.yahoo.osgi.annotation.ExportPackage; /** * API for working with large language models. - */
\ No newline at end of file + */ diff --git a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java index d47f43c55b2..db1b42fbbac 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java +++ b/vespajlib/src/main/java/ai/vespa/llm/test/MockLanguageModel.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.llm.test; import ai.vespa.llm.LanguageModel; @@ -7,6 +7,8 @@ import ai.vespa.llm.completion.Prompt; import com.yahoo.api.annotations.Beta; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; import java.util.function.Function; /** @@ -26,6 +28,11 @@ public class MockLanguageModel implements LanguageModel { return completer.apply(prompt); } + @Override + public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, Consumer<Completion> action) { + throw new RuntimeException("Not implemented"); + } + public static class Builder { private Function<Prompt, List<Completion>> completer = prompt -> List.of(Completion.from("")); diff --git a/vespajlib/src/main/java/ai/vespa/llm/test/package-info.java b/vespajlib/src/main/java/ai/vespa/llm/test/package-info.java index 0d51815fd6d..ab3b7acc657 100644 --- a/vespajlib/src/main/java/ai/vespa/llm/test/package-info.java +++ b/vespajlib/src/main/java/ai/vespa/llm/test/package-info.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage @PublicApi package ai.vespa.llm.test; @@ -8,4 +8,4 @@ package ai.vespa.llm.test; */ import com.yahoo.api.annotations.PublicApi; -import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/vespajlib/src/main/java/ai/vespa/net/CidrBlock.java b/vespajlib/src/main/java/ai/vespa/net/CidrBlock.java index 7bf1970663e..751f3ef8f32 100644 --- a/vespajlib/src/main/java/ai/vespa/net/CidrBlock.java +++ b/vespajlib/src/main/java/ai/vespa/net/CidrBlock.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.net; import com.google.common.net.InetAddresses; diff --git a/vespajlib/src/main/java/ai/vespa/net/package-info.java b/vespajlib/src/main/java/ai/vespa/net/package-info.java index 5d5bb613870..0240bb8ffa0 100644 --- a/vespajlib/src/main/java/ai/vespa/net/package-info.java +++ b/vespajlib/src/main/java/ai/vespa/net/package-info.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage package ai.vespa.net; diff --git a/vespajlib/src/main/java/ai/vespa/validation/Name.java b/vespajlib/src/main/java/ai/vespa/validation/Name.java index a6ab456c285..0e30557b89b 100644 --- a/vespajlib/src/main/java/ai/vespa/validation/Name.java +++ b/vespajlib/src/main/java/ai/vespa/validation/Name.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.validation; import java.util.regex.Pattern; diff --git a/vespajlib/src/main/java/ai/vespa/validation/PathValidator.java b/vespajlib/src/main/java/ai/vespa/validation/PathValidator.java index 0ae81e2315d..8f96cac76ac 100644 --- a/vespajlib/src/main/java/ai/vespa/validation/PathValidator.java +++ b/vespajlib/src/main/java/ai/vespa/validation/PathValidator.java @@ -1,3 +1,4 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.validation; import java.nio.file.Path; diff --git a/vespajlib/src/main/java/ai/vespa/validation/PatternedStringWrapper.java b/vespajlib/src/main/java/ai/vespa/validation/PatternedStringWrapper.java index b97a7ed9cc1..9a3fb17fa90 100644 --- a/vespajlib/src/main/java/ai/vespa/validation/PatternedStringWrapper.java +++ b/vespajlib/src/main/java/ai/vespa/validation/PatternedStringWrapper.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.validation; import java.util.regex.Pattern; diff --git a/vespajlib/src/main/java/ai/vespa/validation/StringWrapper.java b/vespajlib/src/main/java/ai/vespa/validation/StringWrapper.java index 45241f97ce9..95476be2ad8 100644 --- a/vespajlib/src/main/java/ai/vespa/validation/StringWrapper.java +++ b/vespajlib/src/main/java/ai/vespa/validation/StringWrapper.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.validation; import static java.util.Objects.requireNonNull; diff --git a/vespajlib/src/main/java/ai/vespa/validation/Validation.java b/vespajlib/src/main/java/ai/vespa/validation/Validation.java index 292cb2f0aa5..c03aa71c1bb 100644 --- a/vespajlib/src/main/java/ai/vespa/validation/Validation.java +++ b/vespajlib/src/main/java/ai/vespa/validation/Validation.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.validation; import com.yahoo.yolean.Exceptions; @@ -69,4 +69,4 @@ public class Validation { throw new IllegalArgumentException(description + ", but got: '" + value + "'"); } -}
\ No newline at end of file +} diff --git a/vespajlib/src/main/java/ai/vespa/validation/package-info.java b/vespajlib/src/main/java/ai/vespa/validation/package-info.java index edbab3a6fd1..1612537004c 100644 --- a/vespajlib/src/main/java/ai/vespa/validation/package-info.java +++ b/vespajlib/src/main/java/ai/vespa/validation/package-info.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage package ai.vespa.validation; |