author     Lester Solbakken <lesters@users.noreply.github.com>  2024-04-15 21:18:16 +0200
committer  GitHub <noreply@github.com>  2024-04-15 21:18:16 +0200
commit     4212a0b72398dbd5c555f23f3501d265b9a8724e (patch)
tree       3f94e22d38d63c456a3bb202b4f3787ecff5ded6
parent     44a866c0d648543c04567503990c03c36403d86d (diff)
parent     ed62b750494822cc67a328390178754512baf032 (diff)
Merge pull request #30925 from vespa-engine/revert-30916-lesters/add-local-llms-2
Revert "Lesters/add local llms 2"
-rw-r--r--  application/pom.xml  10
-rw-r--r--  client/go/internal/admin/vespa-wrapper/standalone/start.go  1
-rw-r--r--  cloud-tenant-base-dependencies-enforcer/pom.xml  3
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java  1
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java  7
-rw-r--r--  container-dependencies-enforcer/pom.xml  3
-rw-r--r--  container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java  7
-rw-r--r--  container-llama/src/main/java/de/kherud/llama/package-info.java  8
-rw-r--r--  container-search-and-docproc/pom.xml  22
-rw-r--r--  container-search/abi-spec.json  108
-rw-r--r--  container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java (renamed from model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java)  0
-rw-r--r--  container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java (renamed from model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java)  0
-rw-r--r--  container-search/src/main/java/ai/vespa/llm/clients/package-info.java (renamed from model-integration/src/main/java/ai/vespa/llm/clients/package-info.java)  0
-rwxr-xr-x  container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java  84
-rw-r--r--  container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java  25
-rw-r--r--  container-search/src/main/java/com/yahoo/search/result/EventStream.java  36
-rwxr-xr-x  container-search/src/main/resources/configdefinitions/llm-client.def (renamed from model-integration/src/main/resources/configdefinitions/llm-client.def)  0
-rw-r--r--  container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java (renamed from model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java)  0
-rw-r--r--  container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java (renamed from model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java)  0
-rw-r--r--  container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java (renamed from model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java)  0
-rwxr-xr-x  container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java  150
-rw-r--r--  container-test/pom.xml  10
-rw-r--r--  fat-model-dependencies/pom.xml  14
-rw-r--r--  model-integration/abi-spec.json  182
-rw-r--r--  model-integration/pom.xml  12
-rw-r--r--  model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java  126
-rwxr-xr-x  model-integration/src/main/resources/configdefinitions/llm-local-client.def  29
-rw-r--r--  model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java  181
-rw-r--r--  model-integration/src/test/models/llm/tinyllm.gguf  bin 1185376 -> 0 bytes
-rw-r--r--  standalone-container/pom.xml  1
30 files changed, 164 insertions, 856 deletions
diff --git a/application/pom.xml b/application/pom.xml
index 7c34cd966eb..f5704541308 100644
--- a/application/pom.xml
+++ b/application/pom.xml
@@ -43,16 +43,6 @@
<groupId>com.yahoo.vespa</groupId>
<artifactId>model-integration</artifactId>
<version>${project.version}</version>
- <exclusions>
- <exclusion>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.lz4</groupId>
- <artifactId>lz4-java</artifactId>
- </exclusion>
- </exclusions>
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
diff --git a/client/go/internal/admin/vespa-wrapper/standalone/start.go b/client/go/internal/admin/vespa-wrapper/standalone/start.go
index 16e76562b99..a3703ce930c 100644
--- a/client/go/internal/admin/vespa-wrapper/standalone/start.go
+++ b/client/go/internal/admin/vespa-wrapper/standalone/start.go
@@ -41,7 +41,6 @@ func StartStandaloneContainer(extraArgs []string) int {
c := jvm.NewStandaloneContainer(serviceName)
jvmOpts := c.JvmOptions()
jvmOpts.AddOption("-DOnnxBundleActivator.skip=true")
- jvmOpts.AddOption("-DLlamaBundleActivator.skip=true")
for _, extra := range extraArgs {
jvmOpts.AddOption(extra)
}
diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml
index 6d6ac4a81e7..eff4e4125e9 100644
--- a/cloud-tenant-base-dependencies-enforcer/pom.xml
+++ b/cloud-tenant-base-dependencies-enforcer/pom.xml
@@ -97,7 +97,6 @@
<include>com.yahoo.vespa:messagebus:*:provided</include>
<include>com.yahoo.vespa:metrics:*:provided</include>
<include>com.yahoo.vespa:model-evaluation:*:provided</include>
- <include>com.yahoo.vespa:model-integration:*:provided</include>
<include>com.yahoo.vespa:opennlp-linguistics:*:provided</include>
<include>com.yahoo.vespa:predicate-search-core:*:provided</include>
<include>com.yahoo.vespa:provided-dependencies:*:provided</include>
@@ -123,6 +122,7 @@
<include>com.yahoo.vespa:indexinglanguage:*:test</include>
<include>com.yahoo.vespa:logd:*:test</include>
<include>com.yahoo.vespa:metrics-proxy:*:test</include>
+ <include>com.yahoo.vespa:model-integration:*:test</include>
<include>com.yahoo.vespa:searchsummary:*:test</include>
<include>com.yahoo.vespa:standalone-container:*:test</include>
<include>com.yahoo.vespa:storage:*:test</include>
@@ -141,7 +141,6 @@
<include>com.microsoft.onnxruntime:onnxruntime:jar:${onnxruntime.vespa.version}:test</include>
<include>com.thaiopensource:jing:20091111:test</include>
<include>commons-codec:commons-codec:${commons-codec.vespa.version}:test</include>
- <include>de.kherud:llama:${kherud.llama.vespa.version}:test</include>
<include>io.airlift:aircompressor:${aircompressor.vespa.version}:test</include>
<include>io.airlift:airline:${airline.vespa.version}:test</include>
<include>io.prometheus:simpleclient:${prometheus.client.vespa.version}:test</include>
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
index 5be1690f0dc..62979404025 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
@@ -31,7 +31,6 @@ public class ContainerModelEvaluation implements
public final static String EVALUATION_BUNDLE_NAME = "model-evaluation";
public final static String INTEGRATION_BUNDLE_NAME = "model-integration";
public final static String ONNXRUNTIME_BUNDLE_NAME = "container-onnxruntime.jar";
- public final static String LLAMA_BUNDLE_NAME = "container-llama.jar";
public final static String ONNX_RUNTIME_CLASS = "ai.vespa.modelintegration.evaluator.OnnxRuntime";
private final static String EVALUATOR_NAME = ModelsEvaluator.class.getName();
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
index e801884a73a..9f91f6bf5e1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
@@ -14,7 +14,6 @@ import java.util.stream.Stream;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.EVALUATION_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME;
-import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LLAMA_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.ONNXRUNTIME_BUNDLE_NAME;
/**
@@ -64,8 +63,7 @@ public class PlatformBundles {
"lucene-linguistics",
EVALUATION_BUNDLE_NAME,
INTEGRATION_BUNDLE_NAME,
- ONNXRUNTIME_BUNDLE_NAME,
- LLAMA_BUNDLE_NAME
+ ONNXRUNTIME_BUNDLE_NAME
);
private static Set<Path> toBundlePaths(String... bundleNames) {
@@ -150,8 +148,7 @@ public class PlatformBundles {
com.yahoo.vespa.streamingvisitors.StreamingBackend.class.getName(),
ai.vespa.search.llm.LLMSearcher.class.getName(),
ai.vespa.search.llm.RAGSearcher.class.getName(),
- ai.vespa.llm.clients.OpenAI.class.getName(),
- ai.vespa.llm.clients.LocalLLM.class.getName()
+ ai.vespa.llm.clients.OpenAI.class.getName()
);
}
diff --git a/container-dependencies-enforcer/pom.xml b/container-dependencies-enforcer/pom.xml
index 6451e32941c..a06365abbeb 100644
--- a/container-dependencies-enforcer/pom.xml
+++ b/container-dependencies-enforcer/pom.xml
@@ -116,7 +116,6 @@
<include>com.yahoo.vespa:messagebus:*:provided</include>
<include>com.yahoo.vespa:metrics:*:provided</include>
<include>com.yahoo.vespa:model-evaluation:*:provided</include>
- <include>com.yahoo.vespa:model-integration:*:provided</include>
<include>com.yahoo.vespa:opennlp-linguistics:*:provided</include>
<include>com.yahoo.vespa:predicate-search-core:*:provided</include>
<include>com.yahoo.vespa:provided-dependencies:*:provided</include>
@@ -140,6 +139,7 @@
<include>com.yahoo.vespa:indexinglanguage:*:test</include>
<include>com.yahoo.vespa:logd:*:test</include>
<include>com.yahoo.vespa:metrics-proxy:*:test</include>
+ <include>com.yahoo.vespa:model-integration:*:test</include>
<include>com.yahoo.vespa:searchsummary:*:test</include>
<include>com.yahoo.vespa:standalone-container:*:test</include>
<include>com.yahoo.vespa:storage:*:test</include>
@@ -152,7 +152,6 @@
<include>com.google.protobuf:protobuf-java:${protobuf.vespa.version}:test</include>
<include>com.ibm.icu:icu4j:${icu4j.vespa.version}:test</include>
<include>com.microsoft.onnxruntime:onnxruntime:${onnxruntime.vespa.version}:test</include>
- <include>de.kherud:llama:${kherud.llama.vespa.version}:test</include>
<include>com.thaiopensource:jing:20091111:test</include>
<include>commons-codec:commons-codec:${commons-codec.vespa.version}:test</include>
<include>io.airlift:aircompressor:${aircompressor.vespa.version}:test</include>
diff --git a/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java b/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java
index 846a2008858..11ba05e363d 100644
--- a/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java
+++ b/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java
@@ -13,19 +13,12 @@ import java.util.logging.Logger;
**/
public class LlamaBundleActivator implements BundleActivator {
- private static final String SKIP_SUFFIX = ".skip";
- private static final String SKIP_VALUE = "true";
private static final String PATH_PROPNAME = "de.kherud.llama.lib.path";
private static final Logger log = Logger.getLogger(LlamaBundleActivator.class.getName());
@Override
public void start(BundleContext ctx) {
log.fine("start bundle");
- String skipAll = LlamaBundleActivator.class.getSimpleName() + SKIP_SUFFIX;
- if (SKIP_VALUE.equals(System.getProperty(skipAll))) {
- log.info("skip loading of native libraries");
- return;
- }
if (checkFilenames(
"/dev/nvidia0",
"/opt/vespa-deps/lib64/cuda/libllama.so",
diff --git a/container-llama/src/main/java/de/kherud/llama/package-info.java b/container-llama/src/main/java/de/kherud/llama/package-info.java
deleted file mode 100644
index 3c9773762b4..00000000000
--- a/container-llama/src/main/java/de/kherud/llama/package-info.java
+++ /dev/null
@@ -1,8 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-/**
- * @author lesters
- */
-@ExportPackage
-package de.kherud.llama;
-
-import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search-and-docproc/pom.xml b/container-search-and-docproc/pom.xml
index ee74b86dbec..9554b517586 100644
--- a/container-search-and-docproc/pom.xml
+++ b/container-search-and-docproc/pom.xml
@@ -96,22 +96,6 @@
</exclusion>
</exclusions>
</dependency>
- <dependency>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>model-integration</artifactId>
- <version>${project.version}</version>
- <scope>compile</scope>
- <exclusions>
- <exclusion>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.lz4</groupId>
- <artifactId>lz4-java</artifactId>
- </exclusion>
- </exclusions>
- </dependency>
<!-- PROVIDED scope -->
<dependency>
@@ -226,12 +210,6 @@
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
- <dependency>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>container-llama</artifactId>
- <version>${project.version}</version>
- <scope>provided</scope>
- </dependency>
<!-- TEST scope -->
<dependency>
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 07f0449e61a..e74fe22c588 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -7842,21 +7842,6 @@
"public static final int emptyDocsumsCode"
]
},
- "com.yahoo.search.result.EventStream$ErrorEvent" : {
- "superClass" : "com.yahoo.search.result.EventStream$Event",
- "interfaces" : [ ],
- "attributes" : [
- "public"
- ],
- "methods" : [
- "public void <init>(int, java.lang.String, com.yahoo.search.result.ErrorMessage)",
- "public java.lang.String source()",
- "public int code()",
- "public java.lang.String message()",
- "public com.yahoo.search.result.Hit asHit()"
- ],
- "fields" : [ ]
- },
"com.yahoo.search.result.EventStream$Event" : {
"superClass" : "com.yahoo.component.provider.ListenableFreezableClass",
"interfaces" : [
@@ -9164,6 +9149,99 @@
],
"fields" : [ ]
},
+ "ai.vespa.llm.clients.ConfigurableLanguageModel" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "ai.vespa.llm.LanguageModel"
+ ],
+ "attributes" : [
+ "public",
+ "abstract"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
+ "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)",
+ "protected void setApiKey(ai.vespa.llm.InferenceParameters)",
+ "protected java.lang.String getEndpoint()",
+ "protected void setEndpoint(ai.vespa.llm.InferenceParameters)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmClientConfig$Builder" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Builder"
+ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig)",
+ "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)",
+ "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)",
+ "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
+ "public final java.lang.String getDefMd5()",
+ "public final java.lang.String getDefName()",
+ "public final java.lang.String getDefNamespace()",
+ "public final boolean getApplyOnRestart()",
+ "public final void setApplyOnRestart(boolean)",
+ "public ai.vespa.llm.clients.LlmClientConfig build()"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmClientConfig$Producer" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Producer"
+ ],
+ "attributes" : [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods" : [
+ "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmClientConfig" : {
+ "superClass" : "com.yahoo.config.ConfigInstance",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public static java.lang.String getDefMd5()",
+ "public static java.lang.String getDefName()",
+ "public static java.lang.String getDefNamespace()",
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)",
+ "public java.lang.String apiKeySecretName()",
+ "public java.lang.String endpoint()"
+ ],
+ "fields" : [
+ "public static final java.lang.String CONFIG_DEF_MD5",
+ "public static final java.lang.String CONFIG_DEF_NAME",
+ "public static final java.lang.String CONFIG_DEF_NAMESPACE",
+ "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
+ ]
+ },
+ "ai.vespa.llm.clients.OpenAI" : {
+ "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
+ "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
+ ],
+ "fields" : [ ]
+ },
"ai.vespa.search.llm.LLMSearcher" : {
"superClass" : "com.yahoo.search.Searcher",
"interfaces" : [ ],
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
index 761fdf0af93..761fdf0af93 100644
--- a/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
+++ b/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
index 82e19d47c92..82e19d47c92 100644
--- a/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
+++ b/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
index c360245901c..c360245901c 100644
--- a/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java
+++ b/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
index f565315b775..860fc69af91 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
@@ -20,7 +20,6 @@ import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
import java.util.List;
-import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;
@@ -84,41 +83,27 @@ public class LLMSearcher extends Searcher {
protected Result complete(Query query, Prompt prompt) {
var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
- try {
- return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
- } catch (RejectedExecutionException e) {
- return new Result(query, new ErrorMessage(429, e.getMessage()));
- }
- }
-
- private boolean shouldAddPrompt(Query query) {
- return query.getTrace().getLevel() >= 1;
- }
-
- private boolean shouldAddTokenStats(Query query) {
- return query.getTrace().getLevel() >= 1;
+ return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
}
private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
- final EventStream eventStream = new EventStream();
+ EventStream eventStream = new EventStream();
- if (shouldAddPrompt(query)) {
+ if (query.getTrace().getLevel() >= 1) {
eventStream.add(prompt.asString(), "prompt");
}
- final TokenStats tokenStats = new TokenStats();
- languageModel.completeAsync(prompt, options, completion -> {
- tokenStats.onToken();
- handleCompletion(eventStream, completion);
+ languageModel.completeAsync(prompt, options, token -> {
+ eventStream.add(token.text());
}).exceptionally(exception -> {
- handleException(eventStream, exception);
+ int errorCode = 400;
+ if (exception instanceof LanguageModelException languageModelException) {
+ errorCode = languageModelException.code();
+ }
+ eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
eventStream.markComplete();
return Completion.FinishReason.error;
}).thenAccept(finishReason -> {
- tokenStats.onCompletion();
- if (shouldAddTokenStats(query)) {
- eventStream.add(tokenStats.report(), "stats");
- }
eventStream.markComplete();
});
@@ -127,26 +112,10 @@ public class LLMSearcher extends Searcher {
return new Result(query, hitGroup);
}
- private void handleCompletion(EventStream eventStream, Completion completion) {
- if (completion.finishReason() == Completion.FinishReason.error) {
- eventStream.add(completion.text(), "error");
- } else {
- eventStream.add(completion.text());
- }
- }
-
- private void handleException(EventStream eventStream, Throwable exception) {
- int errorCode = 400;
- if (exception instanceof LanguageModelException languageModelException) {
- errorCode = languageModelException.code();
- }
- eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
- }
-
private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
EventStream eventStream = new EventStream();
- if (shouldAddPrompt(query)) {
+ if (query.getTrace().getLevel() >= 1) {
eventStream.add(prompt.asString(), "prompt");
}
@@ -200,35 +169,4 @@ public class LLMSearcher extends Searcher {
return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
}
- private static class TokenStats {
-
- private long start;
- private long timeToFirstToken;
- private long timeToLastToken;
- private long tokens = 0;
-
- TokenStats() {
- start = System.currentTimeMillis();
- }
-
- void onToken() {
- if (tokens == 0) {
- timeToFirstToken = System.currentTimeMillis() - start;
- }
- tokens++;
- }
-
- void onCompletion() {
- timeToLastToken = System.currentTimeMillis() - start;
- }
-
- String report() {
- return "Time to first token: " + timeToFirstToken + " ms, " +
- "Generation time: " + timeToLastToken + " ms, " +
- "Generated tokens: " + tokens + " " +
- String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0));
- }
-
- }
-
}
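The removed TokenStats helper computed throughput as tokens divided by generation time in seconds. A worked example of the report it produced, using hypothetical timings (the values below are illustrative, not from the source):

    public class TokenStatsExample {
        public static void main(String[] args) {
            // Hypothetical timings: first token after 120 ms, last after 2000 ms, 50 tokens.
            long timeToFirstToken = 120, timeToLastToken = 2000, tokens = 50;
            System.out.println("Time to first token: " + timeToFirstToken + " ms, " +
                    "Generation time: " + timeToLastToken + " ms, " +
                    "Generated tokens: " + tokens + " " +
                    String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0)));
            // Prints: Time to first token: 120 ms, Generation time: 2000 ms,
            //         Generated tokens: 50 (25.00 tokens/sec)
        }
    }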
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
index 88a1e6c1485..83ae349f5a0 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -64,17 +64,7 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void data(Data data) throws IOException {
- if (data instanceof EventStream.ErrorEvent error) {
- generator.writeRaw("event: error\n");
- generator.writeRaw("data: ");
- generator.writeStartObject();
- generator.writeStringField("source", error.source());
- generator.writeNumberField("error", error.code());
- generator.writeStringField("message", error.message());
- generator.writeEndObject();
- generator.writeRaw("\n\n");
- generator.flush();
- } else if (data instanceof EventStream.Event event) {
+ if (data instanceof EventStream.Event event) {
if (RENDER_EVENT_HEADER) {
generator.writeRaw("event: " + event.type() + "\n");
}
@@ -85,6 +75,19 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
generator.writeRaw("\n\n");
generator.flush();
}
+ else if (data instanceof ErrorHit) {
+ for (ErrorMessage error : ((ErrorHit) data).errors()) {
+ generator.writeRaw("event: error\n");
+ generator.writeRaw("data: ");
+ generator.writeStartObject();
+ generator.writeStringField("source", error.getSource());
+ generator.writeNumberField("error", error.getCode());
+ generator.writeStringField("message", error.getMessage());
+ generator.writeEndObject();
+ generator.writeRaw("\n\n");
+ generator.flush();
+ }
+ }
// Todo: support other types of data such as search results (hits), timing and trace
}
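With this change, errors reach the server-sent-events stream through the generic ErrorHit branch rather than the dedicated ErrorEvent class, but the wire shape the renderer emits stays the same. Illustrative output, with a hypothetical source and message:

    event: error
    data: {"source":"openai","error":400,"message":"Invalid API key"}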
diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
index 8e6f7977d55..b393a91e6d0 100644
--- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java
+++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
@@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList<Data> {
}
public void error(String source, ErrorMessage message) {
- incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message));
+ incoming().add(new DefaultErrorHit(source, message));
}
public void markComplete() {
@@ -117,38 +117,4 @@ public class EventStream extends Hit implements DataList<Data> {
}
- public static class ErrorEvent extends Event {
-
- private final String source;
- private final ErrorMessage message;
-
- public ErrorEvent(int eventNumber, String source, ErrorMessage message) {
- super(eventNumber, message.getMessage(), "error");
- this.source = source;
- this.message = message;
- }
-
- public String source() {
- return source;
- }
-
- public int code() {
- return message.getCode();
- }
-
- public String message() {
- return message.getMessage();
- }
-
- @Override
- public Hit asHit() {
- Hit hit = super.asHit();
- hit.setField("source", source);
- hit.setField("code", message.getCode());
- return hit;
- }
-
-
- }
-
}
diff --git a/model-integration/src/main/resources/configdefinitions/llm-client.def b/container-search/src/main/resources/configdefinitions/llm-client.def
index 0866459166a..0866459166a 100755
--- a/model-integration/src/main/resources/configdefinitions/llm-client.def
+++ b/container-search/src/main/resources/configdefinitions/llm-client.def
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
index 35d5cfd3855..35d5cfd3855 100644
--- a/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
+++ b/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
index 4d0073f1cbe..4d0073f1cbe 100644
--- a/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
+++ b/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
index 57339f6ad49..57339f6ad49 100644
--- a/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java
+++ b/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
index 3baa9715c34..1efcf1c736a 100755
--- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
@@ -3,11 +3,14 @@ package ai.vespa.search.llm;
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
-import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.clients.ConfigurableLanguageModelTest;
+import ai.vespa.llm.clients.LlmClientConfig;
+import ai.vespa.llm.clients.MockLLMClient;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.component.ComponentId;
import com.yahoo.component.chain.Chain;
import com.yahoo.component.provider.ComponentRegistry;
+import com.yahoo.container.jdisc.SecretStoreProvider;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
@@ -17,14 +20,10 @@ import org.junit.jupiter.api.Test;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
-import java.util.Arrays;
-import java.util.List;
import java.util.Map;
-import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.BiFunction;
-import java.util.function.Consumer;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
@@ -37,10 +36,10 @@ public class LLMSearcherTest {
@Test
public void testLLMSelection() {
- var client1 = createLLMClient("mock1");
- var client2 = createLLMClient("mock2");
+ var llm1 = createLLMClient("mock1");
+ var llm2 = createLLMClient("mock2");
var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build();
- var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2));
+ var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2));
var result = runMockSearch(searcher, Map.of("prompt", "what is your id?"));
assertEquals(1, result.getHitCount());
assertEquals("My id is mock2", getCompletion(result));
@@ -48,16 +47,14 @@ public class LLMSearcherTest {
@Test
public void testGeneration() {
- var client = createLLMClient();
- var searcher = createLLMSearcher(client);
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
var params = Map.of("prompt", "why are ducks better than cats");
assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params)));
}
@Test
public void testPrompting() {
- var client = createLLMClient();
- var searcher = createLLMSearcher(client);
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
// Prompt with prefix
assertEquals("Ducks have adorable waddling walks.",
@@ -74,8 +71,7 @@ public class LLMSearcherTest {
@Test
public void testPromptEvent() {
- var client = createLLMClient();
- var searcher = createLLMSearcher(client);
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
var params = Map.of(
"prompt", "why are ducks better than cats",
"traceLevel", "1");
@@ -94,8 +90,7 @@ public class LLMSearcherTest {
@Test
public void testParameters() {
- var client = createLLMClient();
- var searcher = createLLMSearcher(client);
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
var params = Map.of(
"llm.prompt", "why are ducks better than cats",
"llm.temperature", "1.0",
@@ -112,18 +107,16 @@ public class LLMSearcherTest {
"foo.maxTokens", "5"
);
var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build();
- var client = createLLMClient();
- var searcher = createLLMSearcher(config, client);
+ var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient()));
assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params)));
}
@Test
public void testApiKeyFromHeader() {
var properties = Map.of("prompt", "why are ducks better than cats");
- var client = createLLMClient(createApiKeyGenerator("a_valid_key"));
- var searcher = createLLMSearcher(client);
- assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key"));
- assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key"));
+ var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore()));
+ assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm"));
+ assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm"));
}
@Test
@@ -136,8 +129,7 @@ public class LLMSearcherTest {
"llm.stream", "true", // ... but inference parameters says do it anyway
"llm.prompt", "why are ducks better than cats?"
);
- var client = createLLMClient(executor);
- var searcher = createLLMSearcher(config, client);
+ var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor)));
Result result = runMockSearch(searcher, params);
assertEquals(1, result.getHitCount());
@@ -170,10 +162,6 @@ public class LLMSearcherTest {
return runMockSearch(searcher, parameters, null, "");
}
- static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey) {
- return runMockSearch(searcher, parameters, apiKey, "llm");
- }
-
static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) {
Chain<Searcher> chain = new Chain<>(searcher);
Execution execution = new Execution(chain, Execution.Context.createContextStub());
@@ -203,59 +191,43 @@ public class LLMSearcherTest {
}
private static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
- return (prompt, options) -> {
- String answer = "I have no opinion on the matter";
- if (prompt.asString().contains("ducks")) {
- answer = "Ducks have adorable waddling walks.";
- var temperature = options.getDouble("temperature");
- if (temperature.isPresent() && temperature.get() > 0.5) {
- answer = "Random text about ducks vs cats that makes no sense whatsoever.";
- }
- }
- var maxTokens = options.getInt("maxTokens");
- if (maxTokens.isPresent()) {
- return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
- }
- return answer;
- };
+ return ConfigurableLanguageModelTest.createGenerator();
}
- private static BiFunction<Prompt, InferenceParameters, String> createApiKeyGenerator(String validApiKey) {
- return (prompt, options) -> {
- if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) {
- throw new IllegalArgumentException("Invalid API key");
- }
- return "Ok";
- };
- }
-
- static MockLLM createLLMClient() {
- return new MockLLM(createGenerator(), null);
+ static MockLLMClient createLLMClient() {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore, generator, null);
}
- static MockLLM createLLMClient(String id) {
- return new MockLLM(createIdGenerator(id), null);
+ static MockLLMClient createLLMClient(String id) {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createIdGenerator(id);
+ return new MockLLMClient(config, secretStore, generator, null);
}
- static MockLLM createLLMClient(BiFunction<Prompt, InferenceParameters, String> generator) {
- return new MockLLM(generator, null);
+ static MockLLMClient createLLMClient(ExecutorService executor) {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore, generator, executor);
}
- static MockLLM createLLMClient(ExecutorService executor) {
- return new MockLLM(createGenerator(), executor);
- }
-
- private static Searcher createLLMSearcher(LanguageModel llm) {
- return createLLMSearcher(Map.of("mock", llm));
+ static MockLLMClient createLLMClientWithoutSecretStore() {
+ var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
+ var secretStore = new SecretStoreProvider();
+ var generator = createGenerator();
+ return new MockLLMClient(config, secretStore.get(), generator, null);
}
private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) {
var config = new LlmSearcherConfig.Builder().stream(false).build();
- return createLLMSearcher(config, llms);
- }
-
- private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) {
- return createLLMSearcher(config, Map.of("mock", llm));
+ ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
+ llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
+ models.freeze();
+ return new LLMSearcher(config, models);
}
private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) {
@@ -265,44 +237,4 @@ public class LLMSearcherTest {
return new LLMSearcher(config, models);
}
- private static class MockLLM implements LanguageModel {
-
- private final ExecutorService executor;
- private final BiFunction<Prompt, InferenceParameters, String> generator;
-
- public MockLLM(BiFunction<Prompt, InferenceParameters, String> generator, ExecutorService executor) {
- this.executor = executor;
- this.generator = generator;
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters params) {
- return List.of(Completion.from(this.generator.apply(prompt, params)));
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
- InferenceParameters params,
- Consumer<Completion> consumer) {
- var completionFuture = new CompletableFuture<Completion.FinishReason>();
- var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization
-
- long sleep = 1;
- executor.submit(() -> {
- try {
- for (int i = 0; i < completions.length; ++i) {
- String completion = (i > 0 ? " " : "") + completions[i];
- consumer.accept(Completion.from(completion, Completion.FinishReason.none));
- Thread.sleep(sleep);
- }
- completionFuture.complete(Completion.FinishReason.stop);
- } catch (InterruptedException e) {
- // Do nothing
- }
- });
- return completionFuture;
- }
-
- }
-
}
diff --git a/container-test/pom.xml b/container-test/pom.xml
index d6be6946208..8e1b4870665 100644
--- a/container-test/pom.xml
+++ b/container-test/pom.xml
@@ -61,16 +61,6 @@
<artifactId>onnxruntime</artifactId>
</dependency>
<dependency>
- <groupId>de.kherud</groupId>
- <artifactId>llama</artifactId>
- <exclusions>
- <exclusion>
- <groupId>org.jetbrains</groupId>
- <artifactId>annotations</artifactId>
- </exclusion>
- </exclusions>
- </dependency>
- <dependency>
<groupId>io.airlift</groupId>
<artifactId>airline</artifactId>
<exclusions>
diff --git a/fat-model-dependencies/pom.xml b/fat-model-dependencies/pom.xml
index c4ecd3d2e1c..069b53a0dcd 100644
--- a/fat-model-dependencies/pom.xml
+++ b/fat-model-dependencies/pom.xml
@@ -96,14 +96,6 @@
<groupId>com.theokanning.openai-gpt3-java</groupId>
<artifactId>service</artifactId>
</exclusion>
- <exclusion>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </exclusion>
- <exclusion>
- <groupId>org.lz4</groupId>
- <artifactId>lz4-java</artifactId>
- </exclusion>
</exclusions>
</dependency>
<dependency>
@@ -238,12 +230,6 @@
<groupId>com.yahoo.vespa</groupId>
<artifactId>container-search-and-docproc</artifactId>
<version>${project.version}</version>
- <exclusions>
- <exclusion>
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>model-integration</artifactId>
- </exclusion>
- </exclusions>
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json
index e7130d9c777..d3c472778e6 100644
--- a/model-integration/abi-spec.json
+++ b/model-integration/abi-spec.json
@@ -1,186 +1,4 @@
{
- "ai.vespa.llm.clients.ConfigurableLanguageModel" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "ai.vespa.llm.LanguageModel"
- ],
- "attributes" : [
- "public",
- "abstract"
- ],
- "methods" : [
- "public void <init>()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
- "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)",
- "protected void setApiKey(ai.vespa.llm.InferenceParameters)",
- "protected java.lang.String getEndpoint()",
- "protected void setEndpoint(ai.vespa.llm.InferenceParameters)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig$Builder" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Builder"
- ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public void <init>()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig)",
- "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)",
- "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)",
- "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
- "public final java.lang.String getDefMd5()",
- "public final java.lang.String getDefName()",
- "public final java.lang.String getDefNamespace()",
- "public final boolean getApplyOnRestart()",
- "public final void setApplyOnRestart(boolean)",
- "public ai.vespa.llm.clients.LlmClientConfig build()"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig$Producer" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Producer"
- ],
- "attributes" : [
- "public",
- "interface",
- "abstract"
- ],
- "methods" : [
- "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig" : {
- "superClass" : "com.yahoo.config.ConfigInstance",
- "interfaces" : [ ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public static java.lang.String getDefMd5()",
- "public static java.lang.String getDefName()",
- "public static java.lang.String getDefNamespace()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)",
- "public java.lang.String apiKeySecretName()",
- "public java.lang.String endpoint()"
- ],
- "fields" : [
- "public static final java.lang.String CONFIG_DEF_MD5",
- "public static final java.lang.String CONFIG_DEF_NAME",
- "public static final java.lang.String CONFIG_DEF_NAMESPACE",
- "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
- ]
- },
- "ai.vespa.llm.clients.LlmLocalClientConfig$Builder" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Builder"
- ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public void <init>()",
- "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder model(com.yahoo.config.ModelReference)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder parallelRequests(int)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxQueueSize(int)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder useGpu(boolean)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder gpuLayers(int)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder threads(int)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder contextSize(int)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxTokens(int)",
- "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
- "public final java.lang.String getDefMd5()",
- "public final java.lang.String getDefName()",
- "public final java.lang.String getDefNamespace()",
- "public final boolean getApplyOnRestart()",
- "public final void setApplyOnRestart(boolean)",
- "public ai.vespa.llm.clients.LlmLocalClientConfig build()"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmLocalClientConfig$Producer" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Producer"
- ],
- "attributes" : [
- "public",
- "interface",
- "abstract"
- ],
- "methods" : [
- "public abstract void getConfig(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmLocalClientConfig" : {
- "superClass" : "com.yahoo.config.ConfigInstance",
- "interfaces" : [ ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public static java.lang.String getDefMd5()",
- "public static java.lang.String getDefName()",
- "public static java.lang.String getDefNamespace()",
- "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)",
- "public java.nio.file.Path model()",
- "public int parallelRequests()",
- "public int maxQueueSize()",
- "public boolean useGpu()",
- "public int gpuLayers()",
- "public int threads()",
- "public int contextSize()",
- "public int maxTokens()"
- ],
- "fields" : [
- "public static final java.lang.String CONFIG_DEF_MD5",
- "public static final java.lang.String CONFIG_DEF_NAME",
- "public static final java.lang.String CONFIG_DEF_NAMESPACE",
- "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
- ]
- },
- "ai.vespa.llm.clients.LocalLLM" : {
- "superClass" : "com.yahoo.component.AbstractComponent",
- "interfaces" : [
- "ai.vespa.llm.LanguageModel"
- ],
- "attributes" : [
- "public"
- ],
- "methods" : [
- "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig)",
- "public void deconstruct()",
- "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
- "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.OpenAI" : {
- "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel",
- "interfaces" : [ ],
- "attributes" : [
- "public"
- ],
- "methods" : [
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
- "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
- "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
- ],
- "fields" : [ ]
- },
"ai.vespa.llm.generation.Generator" : {
"superClass" : "com.yahoo.component.AbstractComponent",
"interfaces" : [ ],
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index d92fa319251..0bab30e1453 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -40,12 +40,6 @@
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
- <artifactId>container-disc</artifactId>
- <version>${project.version}</version>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>com.yahoo.vespa</groupId>
<artifactId>searchcore</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
@@ -82,12 +76,6 @@
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
- <artifactId>container-llama</artifactId>
- <version>${project.version}</version>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <groupId>com.yahoo.vespa</groupId>
<artifactId>component</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
deleted file mode 100644
index fd1b8b700c8..00000000000
--- a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.LanguageModel;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import com.yahoo.component.AbstractComponent;
-import com.yahoo.component.annotation.Inject;
-import de.kherud.llama.LlamaModel;
-import de.kherud.llama.ModelParameters;
-import de.kherud.llama.args.LogFormat;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.ArrayBlockingQueue;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.RejectedExecutionException;
-import java.util.concurrent.SynchronousQueue;
-import java.util.concurrent.ThreadPoolExecutor;
-import java.util.concurrent.TimeUnit;
-import java.util.function.Consumer;
-import java.util.logging.Logger;
-
-/**
- * A language model running locally on the container node.
- *
- * @author lesters
- */
-public class LocalLLM extends AbstractComponent implements LanguageModel {
-
- private final static Logger logger = Logger.getLogger(LocalLLM.class.getName());
- private final LlamaModel model;
- private final ThreadPoolExecutor executor;
- private final int contextSize;
- private final int maxTokens;
-
- @Inject
- public LocalLLM(LlmLocalClientConfig config) {
- executor = createExecutor(config);
-
- // Maximum number of tokens to generate - need this since some models can just generate infinitely
- maxTokens = config.maxTokens();
-
- // Only used if GPU is not used
- var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2;
-
- var modelFile = config.model().toFile().getAbsolutePath();
- var modelParams = new ModelParameters()
- .setModelFilePath(modelFile)
- .setContinuousBatching(true)
- .setNParallel(config.parallelRequests())
- .setNThreads(config.threads() <= 0 ? defaultThreadCount : config.threads())
- .setNCtx(config.contextSize())
- .setNGpuLayers(config.useGpu() ? config.gpuLayers() : 0);
-
- long startLoad = System.nanoTime();
- model = new LlamaModel(modelParams);
- long loadTime = System.nanoTime() - startLoad;
- logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000)));
-
- // Todo: handle prompt context size - such as giving a warning when the prompt exceeds the context size
- contextSize = config.contextSize();
- }
-
- private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) {
- return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(),
- 0L, TimeUnit.MILLISECONDS,
- config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(),
- new ThreadPoolExecutor.AbortPolicy());
- }
-
- @Override
- public void deconstruct() {
- logger.info("Closing LLM model...");
- model.close();
- executor.shutdownNow();
- }
-
- @Override
- public List<Completion> complete(Prompt prompt, InferenceParameters options) {
- StringBuilder result = new StringBuilder();
- var future = completeAsync(prompt, options, completion -> {
- result.append(completion.text());
- }).exceptionally(exception -> Completion.FinishReason.error);
- var reason = future.join();
-
- List<Completion> completions = new ArrayList<>();
- completions.add(new Completion(result.toString(), reason));
- return completions;
- }
-
- @Override
- public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
- var inferParams = new de.kherud.llama.InferenceParameters(prompt.asString().stripLeading());
-
- // We always set this to some value to avoid infinite token generation
- inferParams.setNPredict(maxTokens);
-
- options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v)));
- options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v)));
- options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v)));
- options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v)));
- options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v)));
- // Todo: more options?
-
- var completionFuture = new CompletableFuture<Completion.FinishReason>();
- try {
- executor.submit(() -> {
- for (LlamaModel.Output output : model.generate(inferParams)) {
- consumer.accept(Completion.from(output.text, Completion.FinishReason.none));
- }
- completionFuture.complete(Completion.FinishReason.stop);
- });
- } catch (RejectedExecutionException e) {
- // If we have too many requests (active + any waiting in queue), we reject the completion
- int activeCount = executor.getActiveCount();
- int queueSize = executor.getQueue().size();
- String error = String.format("Rejected completion due to too many requests, " +
- "%d active, %d in queue", activeCount, queueSize);
- throw new RejectedExecutionException(error);
- }
- return completionFuture;
- }
-
-}
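The reverted client's admission control is the notable pattern here: a fixed pool sized to parallelRequests, a bounded (or synchronous) hand-off queue, and AbortPolicy, so an over-capacity submit fails fast instead of queueing unboundedly. A minimal standalone sketch of that pattern, with parameter names taken from the removed code:

    import java.util.concurrent.ArrayBlockingQueue;
    import java.util.concurrent.SynchronousQueue;
    import java.util.concurrent.ThreadPoolExecutor;
    import java.util.concurrent.TimeUnit;

    class AdmissionControl {
        // At most parallelRequests tasks run concurrently; up to maxQueueSize more
        // may wait. Any submit beyond that throws RejectedExecutionException,
        // which the reverted LocalLLM surfaced to callers as a rejection.
        static ThreadPoolExecutor createExecutor(int parallelRequests, int maxQueueSize) {
            return new ThreadPoolExecutor(parallelRequests, parallelRequests,
                    0L, TimeUnit.MILLISECONDS,
                    maxQueueSize > 0 ? new ArrayBlockingQueue<>(maxQueueSize)
                                     : new SynchronousQueue<>(),
                    new ThreadPoolExecutor.AbortPolicy());
        }
    }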
diff --git a/model-integration/src/main/resources/configdefinitions/llm-local-client.def b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
deleted file mode 100755
index c06c24b33e5..00000000000
--- a/model-integration/src/main/resources/configdefinitions/llm-local-client.def
+++ /dev/null
@@ -1,29 +0,0 @@
-# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package=ai.vespa.llm.clients
-
-# The LLM model to use
-model model
-
-# Maximum number of requests to handle in parallel per container node
-parallelRequests int default=10
-
-# Additional number of requests to put in queue for processing before starting to reject new requests
-maxQueueSize int default=10
-
-# Use GPU
-useGpu bool default=false
-
-# Maximum number of model layers to run on GPU
-gpuLayers int default=1000000
-
-# Number of threads to use for CPU processing - -1 means use all available cores
-# Not used for GPU processing
-threads int default=-1
-
-# Context size for the model
-# Context is divided between parallel requests. So for 10 parallel requests, each "slot" gets 1/10 of the context
-contextSize int default=512
-
-# Maximum number of tokens to process in one request - overridden by inference parameters
-maxTokens int default=512
-
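As the removed comment notes, context is shared across request slots: with the defaults above, contextSize=512 divided by parallelRequests=10 leaves each slot roughly 512 / 10 ≈ 51 tokens of context, so raising parallelRequests without also raising contextSize shrinks the usable context per request.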
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
deleted file mode 100644
index e85e397b7ff..00000000000
--- a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
+++ /dev/null
@@ -1,181 +0,0 @@
-// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.llm.clients;
-
-import ai.vespa.llm.InferenceParameters;
-import ai.vespa.llm.completion.Completion;
-import ai.vespa.llm.completion.Prompt;
-import ai.vespa.llm.completion.StringPrompt;
-import com.yahoo.config.ModelReference;
-import org.junit.jupiter.api.Test;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.RejectedExecutionException;
-import java.util.concurrent.atomic.AtomicInteger;
-
-import static org.junit.jupiter.api.Assertions.assertEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-import static org.junit.jupiter.api.Assertions.assertNotEquals;
-import static org.junit.jupiter.api.Assertions.assertTrue;
-
-/**
- * Tests for LocalLLM.
- *
- * @author lesters
- */
-public class LocalLLMTest {
-
- private static String model = "src/test/models/llm/tinyllm.gguf";
- private static Prompt prompt = StringPrompt.from("A random prompt");
-
- @Test
- public void testGeneration() {
- var config = new LlmLocalClientConfig.Builder()
- .parallelRequests(1)
- .model(ModelReference.valueOf(model));
- var llm = new LocalLLM(config.build());
-
- try {
- var result = llm.complete(prompt, defaultOptions());
- assertEquals(Completion.FinishReason.stop, result.get(0).finishReason());
- assertTrue(result.get(0).text().length() > 10);
- } finally {
- llm.deconstruct();
- }
- }
-
- @Test
- public void testAsyncGeneration() {
- var sb = new StringBuilder();
- var tokenCount = new AtomicInteger(0);
- var config = new LlmLocalClientConfig.Builder()
- .parallelRequests(1)
- .model(ModelReference.valueOf(model));
- var llm = new LocalLLM(config.build());
-
- try {
- var future = llm.completeAsync(prompt, defaultOptions(), completion -> {
- sb.append(completion.text());
- tokenCount.incrementAndGet();
- }).exceptionally(exception -> Completion.FinishReason.error);
-
- assertFalse(future.isDone());
- var reason = future.join();
- assertTrue(future.isDone());
- assertNotEquals(reason, Completion.FinishReason.error);
-
- } finally {
- llm.deconstruct();
- }
- assertTrue(tokenCount.get() > 0);
-// System.out.println(sb);
- }
-
- @Test
- public void testParallelGeneration() {
- var prompts = testPrompts();
- var promptsToUse = prompts.size();
- var parallelRequests = 10;
-
- var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null));
- var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null));
- var tokenCounts = new ArrayList<>(Collections.nCopies(promptsToUse, 0));
-
- var config = new LlmLocalClientConfig.Builder()
- .parallelRequests(parallelRequests)
- .model(ModelReference.valueOf(model));
- var llm = new LocalLLM(config.build());
-
- try {
- for (int i = 0; i < promptsToUse; i++) {
- final var seq = i;
-
- completions.set(seq, new StringBuilder());
- futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
- completions.get(seq).append(completion.text());
- tokenCounts.set(seq, tokenCounts.get(seq) + 1);
- }).exceptionally(exception -> Completion.FinishReason.error));
- }
- for (int i = 0; i < promptsToUse; i++) {
- var reason = futures.get(i).join();
- assertNotEquals(reason, Completion.FinishReason.error);
- }
- } finally {
- llm.deconstruct();
- }
- for (int i = 0; i < promptsToUse; i++) {
- assertFalse(completions.get(i).isEmpty());
- assertTrue(tokenCounts.get(i) > 0);
- }
- }
-
- @Test
- public void testRejection() {
- var prompts = testPrompts();
- var promptsToUse = prompts.size();
- var parallelRequests = 2;
- var additionalQueue = 1;
- // 7 should be rejected
-
- var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null));
- var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null));
-
- var config = new LlmLocalClientConfig.Builder()
- .parallelRequests(parallelRequests)
- .maxQueueSize(additionalQueue)
- .model(ModelReference.valueOf(model));
- var llm = new LocalLLM(config.build());
-
- var rejected = new AtomicInteger(0);
- try {
- for (int i = 0; i < promptsToUse; i++) {
- final var seq = i;
-
- completions.set(seq, new StringBuilder());
- try {
- var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
- completions.get(seq).append(completion.text());
- }).exceptionally(exception -> Completion.FinishReason.error);
- futures.set(seq, future);
- } catch (RejectedExecutionException e) {
- rejected.incrementAndGet();
- }
- }
- for (int i = 0; i < promptsToUse; i++) {
- if (futures.get(i) != null) {
- assertNotEquals(futures.get(i).join(), Completion.FinishReason.error);
- }
- }
- } finally {
- llm.deconstruct();
- }
- assertEquals(7, rejected.get());
- }
-
- private static InferenceParameters defaultOptions() {
- final Map<String, String> options = Map.of(
- "temperature", "0.1",
- "npredict", "100"
- );
- return new InferenceParameters(options::get);
- }
-
- private List<String> testPrompts() {
- List<String> prompts = new ArrayList<>();
- prompts.add("Write a short story about a time-traveling detective who must solve a mystery that spans multiple centuries.");
- prompts.add("Explain the concept of blockchain technology and its implications for data security in layman's terms.");
- prompts.add("Discuss the socio-economic impacts of the Industrial Revolution in 19th century Europe.");
- prompts.add("Describe a future where humans have colonized Mars, focusing on daily life and societal structure.");
- prompts.add("Analyze the statement 'If a tree falls in a forest and no one is around to hear it, does it make a sound?' from both a philosophical and a physics perspective.");
- prompts.add("Translate the following sentence into French: 'The quick brown fox jumps over the lazy dog.'");
- prompts.add("Explain what the following Python code does: `print([x for x in range(10) if x % 2 == 0])`.");
- prompts.add("Provide general guidelines for maintaining a healthy lifestyle to reduce the risk of developing heart disease.");
- prompts.add("Create a detailed description of a fictional planet, including its ecosystem, dominant species, and technology level.");
- prompts.add("Discuss the impact of social media on interpersonal communication in the 21st century.");
- return prompts;
- }
-
-}
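The expected count in the removed testRejection follows from executor capacity: parallelRequests(2) plus maxQueueSize(1) admits at most 3 of the 10 prompts at once, so 10 - 3 = 7 submissions are rejected, assuming all ten are submitted before any running slot frees up.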
diff --git a/model-integration/src/test/models/llm/tinyllm.gguf b/model-integration/src/test/models/llm/tinyllm.gguf
deleted file mode 100644
index 34367b6b57b..00000000000
--- a/model-integration/src/test/models/llm/tinyllm.gguf
+++ /dev/null
Binary files differ
diff --git a/standalone-container/pom.xml b/standalone-container/pom.xml
index 92faa1ae670..844b912543c 100644
--- a/standalone-container/pom.xml
+++ b/standalone-container/pom.xml
@@ -113,7 +113,6 @@
model-evaluation-jar-with-dependencies.jar,
model-integration-jar-with-dependencies.jar,
container-onnxruntime.jar,
- container-llama.jar,
<!-- END config-model dependencies -->
</discPreInstallBundle>
</configuration>