aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--client/go/internal/admin/vespa-wrapper/standalone/start.go1
-rw-r--r--cloud-tenant-base-dependencies-enforcer/pom.xml1
-rw-r--r--config-model-api/abi-spec.json2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java1
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java18
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/BundleInstantiationSpecificationBuilder.java9
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java6
-rw-r--r--container-core/src/main/java/com/yahoo/container/bundle/BundleInstantiationSpecification.java1
-rw-r--r--container-dependencies-enforcer/pom.xml1
-rw-r--r--container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java7
-rw-r--r--container-llama/src/main/java/de/kherud/llama/package-info.java8
-rw-r--r--container-search-and-docproc/pom.xml12
-rw-r--r--container-search/abi-spec.json108
-rwxr-xr-xcontainer-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java84
-rw-r--r--container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java25
-rw-r--r--container-search/src/main/java/com/yahoo/search/result/EventStream.java36
-rwxr-xr-xcontainer-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java150
-rw-r--r--container-test/pom.xml10
-rw-r--r--dependency-versions/pom.xml2
-rw-r--r--flags/src/main/java/com/yahoo/vespa/flags/Dimension.java9
-rw-r--r--flags/src/main/java/com/yahoo/vespa/flags/Flags.java14
-rw-r--r--metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/DimensionId.java9
-rw-r--r--metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/MetricId.java18
-rw-r--r--metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/ServiceId.java10
-rw-r--r--metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/prometheus/PrometheusUtil.java9
-rw-r--r--metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java4
-rw-r--r--metrics/src/main/java/ai/vespa/metrics/ControllerMetrics.java4
-rw-r--r--model-integration/abi-spec.json182
-rw-r--r--model-integration/pom.xml12
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java (renamed from container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java)0
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java126
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java (renamed from container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java)0
-rw-r--r--model-integration/src/main/java/ai/vespa/llm/clients/package-info.java (renamed from container-search/src/main/java/ai/vespa/llm/clients/package-info.java)0
-rwxr-xr-xmodel-integration/src/main/resources/configdefinitions/llm-client.def (renamed from container-search/src/main/resources/configdefinitions/llm-client.def)0
-rwxr-xr-xmodel-integration/src/main/resources/configdefinitions/llm-local-client.def29
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java (renamed from container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java)0
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java186
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java (renamed from container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java)0
-rw-r--r--model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java (renamed from container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java)0
-rw-r--r--model-integration/src/test/models/llm/tinyllm.ggufbin0 -> 1185376 bytes
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java19
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializer.java8
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializerTest.java4
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicProvisioningTester.java1
-rw-r--r--searchcore/src/vespa/searchcore/proton/documentmetastore/lid_allocator.cpp10
-rw-r--r--searchlib/src/tests/diskindex/pagedict4/.gitignore1
-rw-r--r--searchlib/src/tests/diskindex/pagedict4/CMakeLists.txt8
-rw-r--r--searchlib/src/tests/diskindex/pagedict4/pagedict4_long_words_test.cpp131
-rw-r--r--searchlib/src/tests/queryeval/iterator_benchmark/CMakeLists.txt1
-rw-r--r--searchlib/src/tests/queryeval/iterator_benchmark/benchmark_blueprint_factory.cpp3
-rw-r--r--searchlib/src/tests/queryeval/iterator_benchmark/benchmark_blueprint_factory.h1
-rw-r--r--searchlib/src/tests/queryeval/iterator_benchmark/common.cpp39
-rw-r--r--searchlib/src/tests/queryeval/iterator_benchmark/common.h2
-rw-r--r--searchlib/src/tests/queryeval/iterator_benchmark/intermediate_blueprint_factory.cpp78
-rw-r--r--searchlib/src/tests/queryeval/iterator_benchmark/intermediate_blueprint_factory.h39
-rw-r--r--searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp271
-rw-r--r--searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h23
-rw-r--r--searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp30
-rw-r--r--searchlib/src/vespa/searchlib/bitcompression/compression.cpp18
-rw-r--r--searchlib/src/vespa/searchlib/bitcompression/compression.h10
-rw-r--r--searchlib/src/vespa/searchlib/diskindex/pagedict4file.cpp8
-rw-r--r--searchlib/src/vespa/searchlib/diskindex/zc4_posting_writer.cpp14
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/flow_tuning.h29
-rw-r--r--searchlib/src/vespa/searchlib/util/comprfile.cpp39
-rw-r--r--searchlib/src/vespa/searchlib/util/comprfile.h6
-rw-r--r--standalone-container/pom.xml1
66 files changed, 1572 insertions, 316 deletions
diff --git a/client/go/internal/admin/vespa-wrapper/standalone/start.go b/client/go/internal/admin/vespa-wrapper/standalone/start.go
index a3703ce930c..16e76562b99 100644
--- a/client/go/internal/admin/vespa-wrapper/standalone/start.go
+++ b/client/go/internal/admin/vespa-wrapper/standalone/start.go
@@ -41,6 +41,7 @@ func StartStandaloneContainer(extraArgs []string) int {
c := jvm.NewStandaloneContainer(serviceName)
jvmOpts := c.JvmOptions()
jvmOpts.AddOption("-DOnnxBundleActivator.skip=true")
+ jvmOpts.AddOption("-DLlamaBundleActivator.skip=true")
for _, extra := range extraArgs {
jvmOpts.AddOption(extra)
}
diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml
index eff4e4125e9..98bef7df402 100644
--- a/cloud-tenant-base-dependencies-enforcer/pom.xml
+++ b/cloud-tenant-base-dependencies-enforcer/pom.xml
@@ -141,6 +141,7 @@
<include>com.microsoft.onnxruntime:onnxruntime:jar:${onnxruntime.vespa.version}:test</include>
<include>com.thaiopensource:jing:20091111:test</include>
<include>commons-codec:commons-codec:${commons-codec.vespa.version}:test</include>
+ <include>de.kherud:llama:${kherud.llama.vespa.version}:test</include>
<include>io.airlift:aircompressor:${aircompressor.vespa.version}:test</include>
<include>io.airlift:airline:${airline.vespa.version}:test</include>
<include>io.prometheus:simpleclient:${prometheus.client.vespa.version}:test</include>
diff --git a/config-model-api/abi-spec.json b/config-model-api/abi-spec.json
index c416a5e3a0b..42e7e23dfcc 100644
--- a/config-model-api/abi-spec.json
+++ b/config-model-api/abi-spec.json
@@ -1862,4 +1862,4 @@
"public final java.lang.String serviceName"
]
}
-}
+} \ No newline at end of file
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
index 62979404025..5be1690f0dc 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
@@ -31,6 +31,7 @@ public class ContainerModelEvaluation implements
public final static String EVALUATION_BUNDLE_NAME = "model-evaluation";
public final static String INTEGRATION_BUNDLE_NAME = "model-integration";
public final static String ONNXRUNTIME_BUNDLE_NAME = "container-onnxruntime.jar";
+ public final static String LLAMA_BUNDLE_NAME = "container-llama.jar";
public final static String ONNX_RUNTIME_CLASS = "ai.vespa.modelintegration.evaluator.OnnxRuntime";
private final static String EVALUATOR_NAME = ModelsEvaluator.class.getName();
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
index 9f91f6bf5e1..468cf8dd961 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/PlatformBundles.java
@@ -14,6 +14,7 @@ import java.util.stream.Stream;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.EVALUATION_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LINGUISTICS_BUNDLE_NAME;
+import static com.yahoo.vespa.model.container.ContainerModelEvaluation.LLAMA_BUNDLE_NAME;
import static com.yahoo.vespa.model.container.ContainerModelEvaluation.ONNXRUNTIME_BUNDLE_NAME;
/**
@@ -37,6 +38,7 @@ public class PlatformBundles {
public static final Path LIBRARY_PATH = Paths.get(Defaults.getDefaults().underVespaHome("lib/jars"));
public static final String SEARCH_AND_DOCPROC_BUNDLE = BundleInstantiationSpecification.CONTAINER_SEARCH_AND_DOCPROC;
+ public static final String MODEL_INTEGRATION_BUNDLE = BundleInstantiationSpecification.MODEL_INTEGRATION;
// Bundles that must be loaded for all container types.
public static final Set<Path> COMMON_VESPA_BUNDLES = toBundlePaths(
@@ -63,7 +65,8 @@ public class PlatformBundles {
"lucene-linguistics",
EVALUATION_BUNDLE_NAME,
INTEGRATION_BUNDLE_NAME,
- ONNXRUNTIME_BUNDLE_NAME
+ ONNXRUNTIME_BUNDLE_NAME,
+ LLAMA_BUNDLE_NAME
);
private static Set<Path> toBundlePaths(String... bundleNames) {
@@ -86,6 +89,10 @@ public class PlatformBundles {
return searchAndDocprocComponents.contains(className);
}
+ public static boolean isModelIntegrationClass(String className) {
+ return modelIntegrationComponents.contains(className);
+ }
+
// This is a hack to allow users to declare components from the search-and-docproc bundle without naming the bundle.
private static final Set<String> searchAndDocprocComponents = Set.of(
com.yahoo.docproc.AbstractConcreteDocumentFactory.class.getName(),
@@ -147,8 +154,13 @@ public class PlatformBundles {
com.yahoo.vespa.streamingvisitors.MetricsSearcher.class.getName(),
com.yahoo.vespa.streamingvisitors.StreamingBackend.class.getName(),
ai.vespa.search.llm.LLMSearcher.class.getName(),
- ai.vespa.search.llm.RAGSearcher.class.getName(),
- ai.vespa.llm.clients.OpenAI.class.getName()
+ ai.vespa.search.llm.RAGSearcher.class.getName()
+ );
+
+ // This is a hack to allow users to declare components from the model-integration bundle without naming the bundle.
+ private static final Set<String> modelIntegrationComponents = Set.of(
+ ai.vespa.llm.clients.OpenAI.class.getName(),
+ ai.vespa.llm.clients.LocalLLM.class.getName()
);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/BundleInstantiationSpecificationBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/BundleInstantiationSpecificationBuilder.java
index 7e14eafc2ee..1323506eaeb 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/BundleInstantiationSpecificationBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/BundleInstantiationSpecificationBuilder.java
@@ -26,17 +26,18 @@ public class BundleInstantiationSpecificationBuilder {
BundleInstantiationSpecification instSpec = new BundleInstantiationSpecification(id, classId, bundle);
validate(instSpec);
- return bundle == null ? setBundleForSearchAndDocprocComponents(instSpec) : instSpec;
+ return bundle == null ? setBundleForComponent(instSpec) : instSpec;
}
- private static BundleInstantiationSpecification setBundleForSearchAndDocprocComponents(BundleInstantiationSpecification spec) {
+ private static BundleInstantiationSpecification setBundleForComponent(BundleInstantiationSpecification spec) {
if (PlatformBundles.isSearchAndDocprocClass(spec.getClassName()))
return spec.inBundle(PlatformBundles.SEARCH_AND_DOCPROC_BUNDLE);
+ else if (PlatformBundles.isModelIntegrationClass(spec.getClassName()))
+ return spec.inBundle(PlatformBundles.MODEL_INTEGRATION_BUNDLE);
else
return spec;
}
-
private static void validate(BundleInstantiationSpecification instSpec) {
List<String> forbiddenClasses = List.of(SearchHandler.HANDLER_CLASSNAME, PROCESSING_HANDLER_CLASS);
@@ -47,7 +48,7 @@ public class BundleInstantiationSpecificationBuilder {
}
}
- //null if missing
+ // null if missing
private static ComponentSpecification getComponentSpecification(Element spec, String attributeName) {
return (spec.hasAttribute(attributeName)) ?
new ComponentSpecification(spec.getAttribute(attributeName)) :
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
index f9993b770e5..867ac86f8d5 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
@@ -26,6 +26,7 @@ public class ModelIdResolver {
public static final String ONNX_MODEL = "onnx-model";
public static final String BERT_VOCAB = "bert-vocabulary";
public static final String SIGNIFICANCE_MODEL = "significance-model";
+ public static final String GGUF_MODEL = "gguf-model";
private static Map<String, ProvidedModel> setupProvidedModels() {
var m = new HashMap<String, ProvidedModel>();
@@ -60,6 +61,9 @@ public class ModelIdResolver {
register(m, "e5-large-v2", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/model.onnx", Set.of(ONNX_MODEL));
register(m, "e5-large-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-large-v2/tokenizer.json", Set.of(HF_TOKENIZER));
+
+ register(m, "mistral-7b", "https://data.vespa.oath.cloud/gguf_models/mistral-7b-instruct-v0.1.Q6_K.gguf", Set.of(GGUF_MODEL));
+ register(m, "mistral-7b-q8", "https://data.vespa.oath.cloud/gguf_models/mistral-7b-instruct-v0.1.Q8_0.gguf", Set.of(GGUF_MODEL));
return Map.copyOf(m);
}
@@ -124,7 +128,7 @@ public class ModelIdResolver {
throw new IllegalArgumentException("Unknown model id '" + modelId + "' on '" + valueName + "'. Available models are [" +
providedModels.keySet().stream().sorted().collect(Collectors.joining(", ")) + "]");
var providedModel = providedModels.get(modelId);
- if (!providedModel.tags().containsAll(requiredTags)) {
+ if ( ! providedModel.tags().containsAll(requiredTags)) {
throw new IllegalArgumentException(
"Model '%s' on '%s' has tags %s but are missing required tags %s"
.formatted(modelId, valueName, providedModel.tags(), requiredTags));
diff --git a/container-core/src/main/java/com/yahoo/container/bundle/BundleInstantiationSpecification.java b/container-core/src/main/java/com/yahoo/container/bundle/BundleInstantiationSpecification.java
index bd35d257813..b49f519906f 100644
--- a/container-core/src/main/java/com/yahoo/container/bundle/BundleInstantiationSpecification.java
+++ b/container-core/src/main/java/com/yahoo/container/bundle/BundleInstantiationSpecification.java
@@ -15,6 +15,7 @@ import com.yahoo.component.ComponentSpecification;
public final class BundleInstantiationSpecification {
public static final String CONTAINER_SEARCH_AND_DOCPROC = "container-search-and-docproc";
+ public static final String MODEL_INTEGRATION = "model-integration";
public final ComponentId id;
public final ComponentSpecification classId;
diff --git a/container-dependencies-enforcer/pom.xml b/container-dependencies-enforcer/pom.xml
index a06365abbeb..f67f33a3b05 100644
--- a/container-dependencies-enforcer/pom.xml
+++ b/container-dependencies-enforcer/pom.xml
@@ -154,6 +154,7 @@
<include>com.microsoft.onnxruntime:onnxruntime:${onnxruntime.vespa.version}:test</include>
<include>com.thaiopensource:jing:20091111:test</include>
<include>commons-codec:commons-codec:${commons-codec.vespa.version}:test</include>
+ <include>de.kherud:llama:${kherud.llama.vespa.version}:test</include>
<include>io.airlift:aircompressor:${aircompressor.vespa.version}:test</include>
<include>io.airlift:airline:${airline.vespa.version}:test</include>
<include>io.prometheus:simpleclient:${prometheus.client.vespa.version}:test</include>
diff --git a/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java b/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java
index 11ba05e363d..846a2008858 100644
--- a/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java
+++ b/container-llama/src/main/java/ai/vespa/llama/LlamaBundleActivator.java
@@ -13,12 +13,19 @@ import java.util.logging.Logger;
**/
public class LlamaBundleActivator implements BundleActivator {
+ private static final String SKIP_SUFFIX = ".skip";
+ private static final String SKIP_VALUE = "true";
private static final String PATH_PROPNAME = "de.kherud.llama.lib.path";
private static final Logger log = Logger.getLogger(LlamaBundleActivator.class.getName());
@Override
public void start(BundleContext ctx) {
log.fine("start bundle");
+ String skipAll = LlamaBundleActivator.class.getSimpleName() + SKIP_SUFFIX;
+ if (SKIP_VALUE.equals(System.getProperty(skipAll))) {
+ log.info("skip loading of native libraries");
+ return;
+ }
if (checkFilenames(
"/dev/nvidia0",
"/opt/vespa-deps/lib64/cuda/libllama.so",
diff --git a/container-llama/src/main/java/de/kherud/llama/package-info.java b/container-llama/src/main/java/de/kherud/llama/package-info.java
new file mode 100644
index 00000000000..3c9773762b4
--- /dev/null
+++ b/container-llama/src/main/java/de/kherud/llama/package-info.java
@@ -0,0 +1,8 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+/**
+ * @author lesters
+ */
+@ExportPackage
+package de.kherud.llama;
+
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search-and-docproc/pom.xml b/container-search-and-docproc/pom.xml
index 9554b517586..e2afa0e91f4 100644
--- a/container-search-and-docproc/pom.xml
+++ b/container-search-and-docproc/pom.xml
@@ -210,6 +210,18 @@
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>model-integration</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>container-llama</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
<!-- TEST scope -->
<dependency>
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index e74fe22c588..07f0449e61a 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -7842,6 +7842,21 @@
"public static final int emptyDocsumsCode"
]
},
+ "com.yahoo.search.result.EventStream$ErrorEvent" : {
+ "superClass" : "com.yahoo.search.result.EventStream$Event",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(int, java.lang.String, com.yahoo.search.result.ErrorMessage)",
+ "public java.lang.String source()",
+ "public int code()",
+ "public java.lang.String message()",
+ "public com.yahoo.search.result.Hit asHit()"
+ ],
+ "fields" : [ ]
+ },
"com.yahoo.search.result.EventStream$Event" : {
"superClass" : "com.yahoo.component.provider.ListenableFreezableClass",
"interfaces" : [
@@ -9149,99 +9164,6 @@
],
"fields" : [ ]
},
- "ai.vespa.llm.clients.ConfigurableLanguageModel" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "ai.vespa.llm.LanguageModel"
- ],
- "attributes" : [
- "public",
- "abstract"
- ],
- "methods" : [
- "public void <init>()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
- "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)",
- "protected void setApiKey(ai.vespa.llm.InferenceParameters)",
- "protected java.lang.String getEndpoint()",
- "protected void setEndpoint(ai.vespa.llm.InferenceParameters)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig$Builder" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Builder"
- ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public void <init>()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig)",
- "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)",
- "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)",
- "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
- "public final java.lang.String getDefMd5()",
- "public final java.lang.String getDefName()",
- "public final java.lang.String getDefNamespace()",
- "public final boolean getApplyOnRestart()",
- "public final void setApplyOnRestart(boolean)",
- "public ai.vespa.llm.clients.LlmClientConfig build()"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig$Producer" : {
- "superClass" : "java.lang.Object",
- "interfaces" : [
- "com.yahoo.config.ConfigInstance$Producer"
- ],
- "attributes" : [
- "public",
- "interface",
- "abstract"
- ],
- "methods" : [
- "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)"
- ],
- "fields" : [ ]
- },
- "ai.vespa.llm.clients.LlmClientConfig" : {
- "superClass" : "com.yahoo.config.ConfigInstance",
- "interfaces" : [ ],
- "attributes" : [
- "public",
- "final"
- ],
- "methods" : [
- "public static java.lang.String getDefMd5()",
- "public static java.lang.String getDefName()",
- "public static java.lang.String getDefNamespace()",
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)",
- "public java.lang.String apiKeySecretName()",
- "public java.lang.String endpoint()"
- ],
- "fields" : [
- "public static final java.lang.String CONFIG_DEF_MD5",
- "public static final java.lang.String CONFIG_DEF_NAME",
- "public static final java.lang.String CONFIG_DEF_NAMESPACE",
- "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
- ]
- },
- "ai.vespa.llm.clients.OpenAI" : {
- "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel",
- "interfaces" : [ ],
- "attributes" : [
- "public"
- ],
- "methods" : [
- "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
- "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
- "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
- ],
- "fields" : [ ]
- },
"ai.vespa.search.llm.LLMSearcher" : {
"superClass" : "com.yahoo.search.Searcher",
"interfaces" : [ ],
diff --git a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
index 860fc69af91..f565315b775 100755
--- a/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
+++ b/container-search/src/main/java/ai/vespa/search/llm/LLMSearcher.java
@@ -20,6 +20,7 @@ import com.yahoo.search.result.HitGroup;
import com.yahoo.search.searchchain.Execution;
import java.util.List;
+import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import java.util.logging.Logger;
import java.util.stream.Collectors;
@@ -83,27 +84,41 @@ public class LLMSearcher extends Searcher {
protected Result complete(Query query, Prompt prompt) {
var options = new InferenceParameters(getApiKeyHeader(query), s -> lookupProperty(s, query));
var stream = lookupPropertyBool(STREAM_PROPERTY, query, this.stream); // query value overwrites config
- return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ try {
+ return stream ? completeAsync(query, prompt, options) : completeSync(query, prompt, options);
+ } catch (RejectedExecutionException e) {
+ return new Result(query, new ErrorMessage(429, e.getMessage()));
+ }
+ }
+
+ private boolean shouldAddPrompt(Query query) {
+ return query.getTrace().getLevel() >= 1;
+ }
+
+ private boolean shouldAddTokenStats(Query query) {
+ return query.getTrace().getLevel() >= 1;
}
private Result completeAsync(Query query, Prompt prompt, InferenceParameters options) {
- EventStream eventStream = new EventStream();
+ final EventStream eventStream = new EventStream();
- if (query.getTrace().getLevel() >= 1) {
+ if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
- languageModel.completeAsync(prompt, options, token -> {
- eventStream.add(token.text());
+ final TokenStats tokenStats = new TokenStats();
+ languageModel.completeAsync(prompt, options, completion -> {
+ tokenStats.onToken();
+ handleCompletion(eventStream, completion);
}).exceptionally(exception -> {
- int errorCode = 400;
- if (exception instanceof LanguageModelException languageModelException) {
- errorCode = languageModelException.code();
- }
- eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ handleException(eventStream, exception);
eventStream.markComplete();
return Completion.FinishReason.error;
}).thenAccept(finishReason -> {
+ tokenStats.onCompletion();
+ if (shouldAddTokenStats(query)) {
+ eventStream.add(tokenStats.report(), "stats");
+ }
eventStream.markComplete();
});
@@ -112,10 +127,26 @@ public class LLMSearcher extends Searcher {
return new Result(query, hitGroup);
}
+ private void handleCompletion(EventStream eventStream, Completion completion) {
+ if (completion.finishReason() == Completion.FinishReason.error) {
+ eventStream.add(completion.text(), "error");
+ } else {
+ eventStream.add(completion.text());
+ }
+ }
+
+ private void handleException(EventStream eventStream, Throwable exception) {
+ int errorCode = 400;
+ if (exception instanceof LanguageModelException languageModelException) {
+ errorCode = languageModelException.code();
+ }
+ eventStream.error(languageModelId, new ErrorMessage(errorCode, exception.getMessage()));
+ }
+
private Result completeSync(Query query, Prompt prompt, InferenceParameters options) {
EventStream eventStream = new EventStream();
- if (query.getTrace().getLevel() >= 1) {
+ if (shouldAddPrompt(query)) {
eventStream.add(prompt.asString(), "prompt");
}
@@ -169,4 +200,35 @@ public class LLMSearcher extends Searcher {
return lookupPropertyWithOrWithoutPrefix(API_KEY_HEADER, p -> query.getHttpRequest().getHeader(p));
}
+ private static class TokenStats {
+
+ private long start;
+ private long timeToFirstToken;
+ private long timeToLastToken;
+ private long tokens = 0;
+
+ TokenStats() {
+ start = System.currentTimeMillis();
+ }
+
+ void onToken() {
+ if (tokens == 0) {
+ timeToFirstToken = System.currentTimeMillis() - start;
+ }
+ tokens++;
+ }
+
+ void onCompletion() {
+ timeToLastToken = System.currentTimeMillis() - start;
+ }
+
+ String report() {
+ return "Time to first token: " + timeToFirstToken + " ms, " +
+ "Generation time: " + timeToLastToken + " ms, " +
+ "Generated tokens: " + tokens + " " +
+ String.format("(%.2f tokens/sec)", tokens / (timeToLastToken / 1000.0));
+ }
+
+ }
+
}
diff --git a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
index 83ae349f5a0..88a1e6c1485 100644
--- a/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
+++ b/container-search/src/main/java/com/yahoo/search/rendering/EventRenderer.java
@@ -64,7 +64,17 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
@Override
public void data(Data data) throws IOException {
- if (data instanceof EventStream.Event event) {
+ if (data instanceof EventStream.ErrorEvent error) {
+ generator.writeRaw("event: error\n");
+ generator.writeRaw("data: ");
+ generator.writeStartObject();
+ generator.writeStringField("source", error.source());
+ generator.writeNumberField("error", error.code());
+ generator.writeStringField("message", error.message());
+ generator.writeEndObject();
+ generator.writeRaw("\n\n");
+ generator.flush();
+ } else if (data instanceof EventStream.Event event) {
if (RENDER_EVENT_HEADER) {
generator.writeRaw("event: " + event.type() + "\n");
}
@@ -75,19 +85,6 @@ public class EventRenderer extends AsynchronousSectionedRenderer<Result> {
generator.writeRaw("\n\n");
generator.flush();
}
- else if (data instanceof ErrorHit) {
- for (ErrorMessage error : ((ErrorHit) data).errors()) {
- generator.writeRaw("event: error\n");
- generator.writeRaw("data: ");
- generator.writeStartObject();
- generator.writeStringField("source", error.getSource());
- generator.writeNumberField("error", error.getCode());
- generator.writeStringField("message", error.getMessage());
- generator.writeEndObject();
- generator.writeRaw("\n\n");
- generator.flush();
- }
- }
// Todo: support other types of data such as search results (hits), timing and trace
}
diff --git a/container-search/src/main/java/com/yahoo/search/result/EventStream.java b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
index b393a91e6d0..8e6f7977d55 100644
--- a/container-search/src/main/java/com/yahoo/search/result/EventStream.java
+++ b/container-search/src/main/java/com/yahoo/search/result/EventStream.java
@@ -41,7 +41,7 @@ public class EventStream extends Hit implements DataList<Data> {
}
public void error(String source, ErrorMessage message) {
- incoming().add(new DefaultErrorHit(source, message));
+ incoming().add(new ErrorEvent(eventCount.incrementAndGet(), source, message));
}
public void markComplete() {
@@ -117,4 +117,38 @@ public class EventStream extends Hit implements DataList<Data> {
}
+ public static class ErrorEvent extends Event {
+
+ private final String source;
+ private final ErrorMessage message;
+
+ public ErrorEvent(int eventNumber, String source, ErrorMessage message) {
+ super(eventNumber, message.getMessage(), "error");
+ this.source = source;
+ this.message = message;
+ }
+
+ public String source() {
+ return source;
+ }
+
+ public int code() {
+ return message.getCode();
+ }
+
+ public String message() {
+ return message.getMessage();
+ }
+
+ @Override
+ public Hit asHit() {
+ Hit hit = super.asHit();
+ hit.setField("source", source);
+ hit.setField("code", message.getCode());
+ return hit;
+ }
+
+
+ }
+
}
diff --git a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
index 1efcf1c736a..3baa9715c34 100755
--- a/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
+++ b/container-search/src/test/java/ai/vespa/search/llm/LLMSearcherTest.java
@@ -3,14 +3,11 @@ package ai.vespa.search.llm;
import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
-import ai.vespa.llm.clients.ConfigurableLanguageModelTest;
-import ai.vespa.llm.clients.LlmClientConfig;
-import ai.vespa.llm.clients.MockLLMClient;
+import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import com.yahoo.component.ComponentId;
import com.yahoo.component.chain.Chain;
import com.yahoo.component.provider.ComponentRegistry;
-import com.yahoo.container.jdisc.SecretStoreProvider;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
@@ -20,10 +17,14 @@ import org.junit.jupiter.api.Test;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
+import java.util.Arrays;
+import java.util.List;
import java.util.Map;
+import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.BiFunction;
+import java.util.function.Consumer;
import java.util.stream.Collectors;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
@@ -36,10 +37,10 @@ public class LLMSearcherTest {
@Test
public void testLLMSelection() {
- var llm1 = createLLMClient("mock1");
- var llm2 = createLLMClient("mock2");
+ var client1 = createLLMClient("mock1");
+ var client2 = createLLMClient("mock2");
var config = new LlmSearcherConfig.Builder().stream(false).providerId("mock2").build();
- var searcher = createLLMSearcher(config, Map.of("mock1", llm1, "mock2", llm2));
+ var searcher = createLLMSearcher(config, Map.of("mock1", client1, "mock2", client2));
var result = runMockSearch(searcher, Map.of("prompt", "what is your id?"));
assertEquals(1, result.getHitCount());
assertEquals("My id is mock2", getCompletion(result));
@@ -47,14 +48,16 @@ public class LLMSearcherTest {
@Test
public void testGeneration() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
var params = Map.of("prompt", "why are ducks better than cats");
assertEquals("Ducks have adorable waddling walks.", getCompletion(runMockSearch(searcher, params)));
}
@Test
public void testPrompting() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
// Prompt with prefix
assertEquals("Ducks have adorable waddling walks.",
@@ -71,7 +74,8 @@ public class LLMSearcherTest {
@Test
public void testPromptEvent() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
var params = Map.of(
"prompt", "why are ducks better than cats",
"traceLevel", "1");
@@ -90,7 +94,8 @@ public class LLMSearcherTest {
@Test
public void testParameters() {
- var searcher = createLLMSearcher(Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(client);
var params = Map.of(
"llm.prompt", "why are ducks better than cats",
"llm.temperature", "1.0",
@@ -107,16 +112,18 @@ public class LLMSearcherTest {
"foo.maxTokens", "5"
);
var config = new LlmSearcherConfig.Builder().stream(false).propertyPrefix(prefix).providerId("mock").build();
- var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient()));
+ var client = createLLMClient();
+ var searcher = createLLMSearcher(config, client);
assertEquals("I have no opinion on", getCompletion(runMockSearch(searcher, params)));
}
@Test
public void testApiKeyFromHeader() {
var properties = Map.of("prompt", "why are ducks better than cats");
- var searcher = createLLMSearcher(Map.of("mock", createLLMClientWithoutSecretStore()));
- assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key", "llm"));
- assertDoesNotThrow(() -> runMockSearch(searcher, properties, MockLLMClient.ACCEPTED_API_KEY, "llm"));
+ var client = createLLMClient(createApiKeyGenerator("a_valid_key"));
+ var searcher = createLLMSearcher(client);
+ assertThrows(IllegalArgumentException.class, () -> runMockSearch(searcher, properties, "invalid_key"));
+ assertDoesNotThrow(() -> runMockSearch(searcher, properties, "a_valid_key"));
}
@Test
@@ -129,7 +136,8 @@ public class LLMSearcherTest {
"llm.stream", "true", // ... but inference parameters says do it anyway
"llm.prompt", "why are ducks better than cats?"
);
- var searcher = createLLMSearcher(config, Map.of("mock", createLLMClient(executor)));
+ var client = createLLMClient(executor);
+ var searcher = createLLMSearcher(config, client);
Result result = runMockSearch(searcher, params);
assertEquals(1, result.getHitCount());
@@ -162,6 +170,10 @@ public class LLMSearcherTest {
return runMockSearch(searcher, parameters, null, "");
}
+ static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey) {
+ return runMockSearch(searcher, parameters, apiKey, "llm");
+ }
+
static Result runMockSearch(Searcher searcher, Map<String, String> parameters, String apiKey, String prefix) {
Chain<Searcher> chain = new Chain<>(searcher);
Execution execution = new Execution(chain, Execution.Context.createContextStub());
@@ -191,43 +203,59 @@ public class LLMSearcherTest {
}
private static BiFunction<Prompt, InferenceParameters, String> createGenerator() {
- return ConfigurableLanguageModelTest.createGenerator();
+ return (prompt, options) -> {
+ String answer = "I have no opinion on the matter";
+ if (prompt.asString().contains("ducks")) {
+ answer = "Ducks have adorable waddling walks.";
+ var temperature = options.getDouble("temperature");
+ if (temperature.isPresent() && temperature.get() > 0.5) {
+ answer = "Random text about ducks vs cats that makes no sense whatsoever.";
+ }
+ }
+ var maxTokens = options.getInt("maxTokens");
+ if (maxTokens.isPresent()) {
+ return Arrays.stream(answer.split(" ")).limit(maxTokens.get()).collect(Collectors.joining(" "));
+ }
+ return answer;
+ };
}
- static MockLLMClient createLLMClient() {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
- var generator = createGenerator();
- return new MockLLMClient(config, secretStore, generator, null);
+ private static BiFunction<Prompt, InferenceParameters, String> createApiKeyGenerator(String validApiKey) {
+ return (prompt, options) -> {
+ if (options.getApiKey().isEmpty() || ! options.getApiKey().get().equals(validApiKey)) {
+ throw new IllegalArgumentException("Invalid API key");
+ }
+ return "Ok";
+ };
+ }
+
+ static MockLLM createLLMClient() {
+ return new MockLLM(createGenerator(), null);
}
- static MockLLMClient createLLMClient(String id) {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
- var generator = createIdGenerator(id);
- return new MockLLMClient(config, secretStore, generator, null);
+ static MockLLM createLLMClient(String id) {
+ return new MockLLM(createIdGenerator(id), null);
}
- static MockLLMClient createLLMClient(ExecutorService executor) {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = ConfigurableLanguageModelTest.createSecretStore(Map.of("api-key", MockLLMClient.ACCEPTED_API_KEY));
- var generator = createGenerator();
- return new MockLLMClient(config, secretStore, generator, executor);
+ static MockLLM createLLMClient(BiFunction<Prompt, InferenceParameters, String> generator) {
+ return new MockLLM(generator, null);
}
- static MockLLMClient createLLMClientWithoutSecretStore() {
- var config = new LlmClientConfig.Builder().apiKeySecretName("api-key").build();
- var secretStore = new SecretStoreProvider();
- var generator = createGenerator();
- return new MockLLMClient(config, secretStore.get(), generator, null);
+ static MockLLM createLLMClient(ExecutorService executor) {
+ return new MockLLM(createGenerator(), executor);
+ }
+
+ private static Searcher createLLMSearcher(LanguageModel llm) {
+ return createLLMSearcher(Map.of("mock", llm));
}
private static Searcher createLLMSearcher(Map<String, LanguageModel> llms) {
var config = new LlmSearcherConfig.Builder().stream(false).build();
- ComponentRegistry<LanguageModel> models = new ComponentRegistry<>();
- llms.forEach((key, value) -> models.register(ComponentId.fromString(key), value));
- models.freeze();
- return new LLMSearcher(config, models);
+ return createLLMSearcher(config, llms);
+ }
+
+ private static Searcher createLLMSearcher(LlmSearcherConfig config, LanguageModel llm) {
+ return createLLMSearcher(config, Map.of("mock", llm));
}
private static Searcher createLLMSearcher(LlmSearcherConfig config, Map<String, LanguageModel> llms) {
@@ -237,4 +265,44 @@ public class LLMSearcherTest {
return new LLMSearcher(config, models);
}
+ private static class MockLLM implements LanguageModel {
+
+ private final ExecutorService executor;
+ private final BiFunction<Prompt, InferenceParameters, String> generator;
+
+ public MockLLM(BiFunction<Prompt, InferenceParameters, String> generator, ExecutorService executor) {
+ this.executor = executor;
+ this.generator = generator;
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters params) {
+ return List.of(Completion.from(this.generator.apply(prompt, params)));
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt,
+ InferenceParameters params,
+ Consumer<Completion> consumer) {
+ var completionFuture = new CompletableFuture<Completion.FinishReason>();
+ var completions = this.generator.apply(prompt, params).split(" "); // Simple tokenization
+
+ long sleep = 1;
+ executor.submit(() -> {
+ try {
+ for (int i = 0; i < completions.length; ++i) {
+ String completion = (i > 0 ? " " : "") + completions[i];
+ consumer.accept(Completion.from(completion, Completion.FinishReason.none));
+ Thread.sleep(sleep);
+ }
+ completionFuture.complete(Completion.FinishReason.stop);
+ } catch (InterruptedException e) {
+ // Do nothing
+ }
+ });
+ return completionFuture;
+ }
+
+ }
+
}
diff --git a/container-test/pom.xml b/container-test/pom.xml
index 8e1b4870665..d6be6946208 100644
--- a/container-test/pom.xml
+++ b/container-test/pom.xml
@@ -61,6 +61,16 @@
<artifactId>onnxruntime</artifactId>
</dependency>
<dependency>
+ <groupId>de.kherud</groupId>
+ <artifactId>llama</artifactId>
+ <exclusions>
+ <exclusion>
+ <groupId>org.jetbrains</groupId>
+ <artifactId>annotations</artifactId>
+ </exclusion>
+ </exclusions>
+ </dependency>
+ <dependency>
<groupId>io.airlift</groupId>
<artifactId>airline</artifactId>
<exclusions>
diff --git a/dependency-versions/pom.xml b/dependency-versions/pom.xml
index fe3982c4e34..1c87e1da589 100644
--- a/dependency-versions/pom.xml
+++ b/dependency-versions/pom.xml
@@ -126,7 +126,7 @@
<mimepull.vespa.version>1.10.0</mimepull.vespa.version>
<mockito.vespa.version>5.11.0</mockito.vespa.version>
<mojo-executor.vespa.version>2.4.0</mojo-executor.vespa.version>
- <netty.vespa.version>4.1.108.Final</netty.vespa.version>
+ <netty.vespa.version>4.1.109.Final</netty.vespa.version>
<netty-tcnative.vespa.version>2.0.65.Final</netty-tcnative.vespa.version>
<onnxruntime.vespa.version>1.17.1</onnxruntime.vespa.version>
<opennlp.vespa.version>2.3.2</opennlp.vespa.version>
diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Dimension.java b/flags/src/main/java/com/yahoo/vespa/flags/Dimension.java
index 328d581aed3..b02fa949dbb 100644
--- a/flags/src/main/java/com/yahoo/vespa/flags/Dimension.java
+++ b/flags/src/main/java/com/yahoo/vespa/flags/Dimension.java
@@ -44,8 +44,8 @@ public enum Dimension {
/**
* Cloud from com.yahoo.config.provision.CloudName::value, e.g. yahoo, aws, gcp.
*
- * <p><em>Eager resolution</em>: This dimension is resolved before putting the flag data to the config server
- * or controller, unless controller and the flag has declared this dimension.
+ * <p><em>Eager resolution</em>: This dimension is resolved before storing the flag data in the config server and
+ * controller ZooKeeper, UNLESS it is the controller and the flag was defined with this dimension in [Permanent]Flags.
*/
CLOUD("cloud"),
@@ -61,7 +61,10 @@ public enum Dimension {
/** Email address of user - provided by auth0 in console. */
CONSOLE_USER_EMAIL("console-user-email"),
- /** Hosted Vespa environment from com.yahoo.config.provision.Environment::value, e.g. prod, staging, test. */
+ /**
+ * Hosted Vespa environment from com.yahoo.config.provision.Environment::value, e.g. prod, staging, test.
+ * <em>Eager resolution</em>, see {@link #CLOUD}.
+ */
ENVIRONMENT("environment"),
/**
diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
index ba7775b1790..e9de8cdca20 100644
--- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
+++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
@@ -231,6 +231,13 @@ public class Flags {
"Takes effect on next tick.",
NODE_TYPE);
+ public static final UnboundStringFlag DIST_HOST = defineStringFlag(
+ "dist-host", "",
+ List.of("freva"), "2024-04-15", "2024-05-31",
+ "Sets dist_host YUM variable, empty means old behavior. Only effective in Public.",
+ "Provisioning of instance or next host-admin tick",
+ HOSTNAME, NODE_TYPE, CLOUD_ACCOUNT);
+
public static final UnboundBooleanFlag ENABLED_HORIZON_DASHBOARD = defineFeatureFlag(
"enabled-horizon-dashboard", false,
List.of("olaa"), "2021-09-13", "2024-09-01",
@@ -370,13 +377,6 @@ public class Flags {
"Takes effect at redeployment",
INSTANCE_ID);
- public static final UnboundBooleanFlag DYNAMIC_HEAP_SIZE = defineFeatureFlag(
- "dynamic-heap-size", true,
- List.of("bjorncs"), "2023-09-21", "2024-04-15",
- "Whether to calculate JVM heap size based on predicted Onnx model memory requirements",
- "Takes effect at redeployment",
- INSTANCE_ID);
-
public static final UnboundStringFlag UNKNOWN_CONFIG_DEFINITION = defineStringFlag(
"unknown-config-definition", "warn",
List.of("hmusum"), "2023-09-25", "2024-09-01",
diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/DimensionId.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/DimensionId.java
index a6b09ddefd8..43cc8fda3c9 100644
--- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/DimensionId.java
+++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/DimensionId.java
@@ -2,6 +2,7 @@
package ai.vespa.metricsproxy.metric.model;
import com.yahoo.concurrent.CopyOnWriteHashMap;
+import io.prometheus.client.Collector;
import java.util.Map;
import java.util.Objects;
@@ -13,12 +14,18 @@ public final class DimensionId {
private static final Map<String, DimensionId> dictionary = new CopyOnWriteHashMap<>();
public final String id;
- private DimensionId(String id) { this.id = id; }
+ private final String idForPrometheus;
+ private DimensionId(String id) {
+ this.id = id;
+ idForPrometheus = Collector.sanitizeMetricName(id);
+ }
public static DimensionId toDimensionId(String id) {
return dictionary.computeIfAbsent(id, key -> new DimensionId(key));
}
+ public String getIdForPrometheus() { return idForPrometheus; }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/MetricId.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/MetricId.java
index 9014e818eab..829eb06101f 100644
--- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/MetricId.java
+++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/MetricId.java
@@ -2,6 +2,7 @@
package ai.vespa.metricsproxy.metric.model;
import com.yahoo.concurrent.CopyOnWriteHashMap;
+import io.prometheus.client.Collector;
import java.util.Map;
import java.util.Objects;
@@ -14,11 +15,16 @@ public class MetricId {
private static final Map<String, MetricId> dictionary = new CopyOnWriteHashMap<>();
public static final MetricId empty = toMetricId("");
public final String id;
- private MetricId(String id) { this.id = id; }
+ private final String idForPrometheus;
+ private MetricId(String id) {
+ this.id = id;
+ idForPrometheus = Collector.sanitizeMetricName(id);
+ }
public static MetricId toMetricId(String id) {
- return dictionary.computeIfAbsent(id, key -> new MetricId(key));
+ return dictionary.computeIfAbsent(id, MetricId::new);
}
+ public String getIdForPrometheus() { return idForPrometheus; }
@Override
public boolean equals(Object o) {
@@ -29,13 +35,9 @@ public class MetricId {
}
@Override
- public int hashCode() {
- return Objects.hash(id);
- }
+ public int hashCode() { return Objects.hash(id); }
@Override
- public String toString() {
- return id;
- }
+ public String toString() { return id; }
}
diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/ServiceId.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/ServiceId.java
index 96ee2fa00e2..28c64b012c1 100644
--- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/ServiceId.java
+++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/ServiceId.java
@@ -1,6 +1,8 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.metricsproxy.metric.model;
+import io.prometheus.client.Collector;
+
import java.util.Objects;
/**
@@ -9,10 +11,16 @@ import java.util.Objects;
public class ServiceId {
public final String id;
- private ServiceId(String id) { this.id = id; }
+ private final String idForPrometheus;
+ private ServiceId(String id) {
+ this.id = id;
+ idForPrometheus = Collector.sanitizeMetricName(id);
+ }
public static ServiceId toServiceId(String id) { return new ServiceId(id); }
+ public String getIdForPrometheus() { return idForPrometheus; }
+
@Override
public boolean equals(Object o) {
if (this == o) return true;
diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/prometheus/PrometheusUtil.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/prometheus/PrometheusUtil.java
index d7436ccf404..2b0db5381bc 100644
--- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/prometheus/PrometheusUtil.java
+++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/model/prometheus/PrometheusUtil.java
@@ -28,13 +28,14 @@ public class PrometheusUtil {
Map<String, List<Sample>> samples = new HashMap<>();
packetsByService.forEach(((serviceId, packets) -> {
- var serviceName = Collector.sanitizeMetricName(serviceId.id);
+ var serviceName = serviceId.getIdForPrometheus();
for (var packet : packets) {
+ Long timeStamp = packet.timestamp * 1000;
var dimensions = packet.dimensions();
List<String> labels = new ArrayList<>(dimensions.size());
List<String> labelValues = new ArrayList<>(dimensions.size());
for (var entry : dimensions.entrySet()) {
- var labelName = Collector.sanitizeMetricName(entry.getKey().id);
+ var labelName = entry.getKey().getIdForPrometheus();
labels.add(labelName);
labelValues.add(entry.getValue());
}
@@ -42,7 +43,7 @@ public class PrometheusUtil {
labelValues.add(serviceName);
for (var metric : packet.metrics().entrySet()) {
- var metricName = Collector.sanitizeMetricName(metric.getKey().id);
+ var metricName = metric.getKey().getIdForPrometheus();
List<Sample> sampleList;
if (samples.containsKey(metricName)) {
sampleList = samples.get(metricName);
@@ -51,7 +52,7 @@ public class PrometheusUtil {
samples.put(metricName, sampleList);
metricFamilySamples.add(new MetricFamilySamples(metricName, Collector.Type.UNKNOWN, "", sampleList));
}
- sampleList.add(new Sample(metricName, labels, labelValues, metric.getValue().doubleValue(), packet.timestamp * 1000));
+ sampleList.add(new Sample(metricName, labels, labelValues, metric.getValue().doubleValue(), timeStamp));
}
}
if (!packets.isEmpty()) {
diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java
index 6c3b759e97b..0e33d7dbf2f 100644
--- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java
+++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java
@@ -28,6 +28,7 @@ import static ai.vespa.metricsproxy.metric.model.DimensionId.toDimensionId;
* @author Jo Kristian Bergum
*/
public class MetricsParser {
+ private static final Double ZERO_DOUBLE = 0d;
public interface Collector {
void accept(Metric metric);
}
@@ -186,7 +187,8 @@ public class MetricsParser {
if (token == JsonToken.VALUE_NUMBER_INT) {
metrics.add(Map.entry(metricName, parser.getLongValue()));
} else if (token == JsonToken.VALUE_NUMBER_FLOAT) {
- metrics.add(Map.entry(metricName, parser.getValueAsDouble()));
+ double value = parser.getValueAsDouble();
+ metrics.add(Map.entry(metricName, value == ZERO_DOUBLE ? ZERO_DOUBLE : value));
} else {
throw new IllegalArgumentException("Value for aggregator '" + fieldName + "' is not a number");
}
diff --git a/metrics/src/main/java/ai/vespa/metrics/ControllerMetrics.java b/metrics/src/main/java/ai/vespa/metrics/ControllerMetrics.java
index 05d51967166..dc461fbdbfa 100644
--- a/metrics/src/main/java/ai/vespa/metrics/ControllerMetrics.java
+++ b/metrics/src/main/java/ai/vespa/metrics/ControllerMetrics.java
@@ -45,8 +45,8 @@ public enum ControllerMetrics implements VespaMetrics {
AUTH0_EXCEPTIONS("auth0.exceptions", Unit.FAILURE, "Controller: Auth0 exceptions"),
CERTIFICATE_POOL_AVAILABLE("certificate_pool_available", Unit.FRACTION, "Available certificates in the pool, fraction of configured size"),
BILLING_EXCEPTIONS("billing.exceptions", Unit.FAILURE, "Controller: Billing related exceptions"),
- BILLING_WEBHOOK_FILTER_FAILURES("billing.webhook.failures", Unit.FAILURE, "Controller: webhook filter failures"),
- BILLING_WEBHOOK_FILTER_REQUESTS("billing.webhook.requests", Unit.REQUEST, "Controller: webhook filter requests"),
+ BILLING_WEBHOOK_FAILURES("billing.webhook.failures", Unit.FAILURE, "Controller: webhook failures"),
+ BILLING_WEBHOOK_REQUESTS("billing.webhook.requests", Unit.REQUEST, "Controller: webhook requests"),
// Metrics per API, metrics names generated in ControllerMaintainer/MetricsReporter
OPERATION_APPLICATION("operation.application", Unit.REQUEST, "Controller: Requests for /application API"),
diff --git a/model-integration/abi-spec.json b/model-integration/abi-spec.json
index d3c472778e6..e7130d9c777 100644
--- a/model-integration/abi-spec.json
+++ b/model-integration/abi-spec.json
@@ -1,4 +1,186 @@
{
+ "ai.vespa.llm.clients.ConfigurableLanguageModel" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "ai.vespa.llm.LanguageModel"
+ ],
+ "attributes" : [
+ "public",
+ "abstract"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
+ "protected java.lang.String getApiKey(ai.vespa.llm.InferenceParameters)",
+ "protected void setApiKey(ai.vespa.llm.InferenceParameters)",
+ "protected java.lang.String getEndpoint()",
+ "protected void setEndpoint(ai.vespa.llm.InferenceParameters)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmClientConfig$Builder" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Builder"
+ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig)",
+ "public ai.vespa.llm.clients.LlmClientConfig$Builder apiKeySecretName(java.lang.String)",
+ "public ai.vespa.llm.clients.LlmClientConfig$Builder endpoint(java.lang.String)",
+ "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
+ "public final java.lang.String getDefMd5()",
+ "public final java.lang.String getDefName()",
+ "public final java.lang.String getDefNamespace()",
+ "public final boolean getApplyOnRestart()",
+ "public final void setApplyOnRestart(boolean)",
+ "public ai.vespa.llm.clients.LlmClientConfig build()"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmClientConfig$Producer" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Producer"
+ ],
+ "attributes" : [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods" : [
+ "public abstract void getConfig(ai.vespa.llm.clients.LlmClientConfig$Builder)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmClientConfig" : {
+ "superClass" : "com.yahoo.config.ConfigInstance",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public static java.lang.String getDefMd5()",
+ "public static java.lang.String getDefName()",
+ "public static java.lang.String getDefNamespace()",
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig$Builder)",
+ "public java.lang.String apiKeySecretName()",
+ "public java.lang.String endpoint()"
+ ],
+ "fields" : [
+ "public static final java.lang.String CONFIG_DEF_MD5",
+ "public static final java.lang.String CONFIG_DEF_NAME",
+ "public static final java.lang.String CONFIG_DEF_NAMESPACE",
+ "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
+ ]
+ },
+ "ai.vespa.llm.clients.LlmLocalClientConfig$Builder" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Builder"
+ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public void <init>()",
+ "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder model(com.yahoo.config.ModelReference)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder parallelRequests(int)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxQueueSize(int)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder useGpu(boolean)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder gpuLayers(int)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder threads(int)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder contextSize(int)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig$Builder maxTokens(int)",
+ "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)",
+ "public final java.lang.String getDefMd5()",
+ "public final java.lang.String getDefName()",
+ "public final java.lang.String getDefNamespace()",
+ "public final boolean getApplyOnRestart()",
+ "public final void setApplyOnRestart(boolean)",
+ "public ai.vespa.llm.clients.LlmLocalClientConfig build()"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmLocalClientConfig$Producer" : {
+ "superClass" : "java.lang.Object",
+ "interfaces" : [
+ "com.yahoo.config.ConfigInstance$Producer"
+ ],
+ "attributes" : [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods" : [
+ "public abstract void getConfig(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.LlmLocalClientConfig" : {
+ "superClass" : "com.yahoo.config.ConfigInstance",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public",
+ "final"
+ ],
+ "methods" : [
+ "public static java.lang.String getDefMd5()",
+ "public static java.lang.String getDefName()",
+ "public static java.lang.String getDefNamespace()",
+ "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig$Builder)",
+ "public java.nio.file.Path model()",
+ "public int parallelRequests()",
+ "public int maxQueueSize()",
+ "public boolean useGpu()",
+ "public int gpuLayers()",
+ "public int threads()",
+ "public int contextSize()",
+ "public int maxTokens()"
+ ],
+ "fields" : [
+ "public static final java.lang.String CONFIG_DEF_MD5",
+ "public static final java.lang.String CONFIG_DEF_NAME",
+ "public static final java.lang.String CONFIG_DEF_NAMESPACE",
+ "public static final java.lang.String[] CONFIG_DEF_SCHEMA"
+ ]
+ },
+ "ai.vespa.llm.clients.LocalLLM" : {
+ "superClass" : "com.yahoo.component.AbstractComponent",
+ "interfaces" : [
+ "ai.vespa.llm.LanguageModel"
+ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(ai.vespa.llm.clients.LlmLocalClientConfig)",
+ "public void deconstruct()",
+ "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
+ ],
+ "fields" : [ ]
+ },
+ "ai.vespa.llm.clients.OpenAI" : {
+ "superClass" : "ai.vespa.llm.clients.ConfigurableLanguageModel",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(ai.vespa.llm.clients.LlmClientConfig, com.yahoo.container.jdisc.secretstore.SecretStore)",
+ "public java.util.List complete(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters)",
+ "public java.util.concurrent.CompletableFuture completeAsync(ai.vespa.llm.completion.Prompt, ai.vespa.llm.InferenceParameters, java.util.function.Consumer)"
+ ],
+ "fields" : [ ]
+ },
"ai.vespa.llm.generation.Generator" : {
"superClass" : "com.yahoo.component.AbstractComponent",
"interfaces" : [ ],
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 0bab30e1453..d92fa319251 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -40,6 +40,12 @@
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
+ <artifactId>container-disc</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
<artifactId>searchcore</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
@@ -76,6 +82,12 @@
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
+ <artifactId>container-llama</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
<artifactId>component</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
index 761fdf0af93..761fdf0af93 100644
--- a/container-search/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/ConfigurableLanguageModel.java
diff --git a/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
new file mode 100644
index 00000000000..fd1b8b700c8
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/LocalLLM.java
@@ -0,0 +1,126 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.LanguageModel;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import com.yahoo.component.AbstractComponent;
+import com.yahoo.component.annotation.Inject;
+import de.kherud.llama.LlamaModel;
+import de.kherud.llama.ModelParameters;
+import de.kherud.llama.args.LogFormat;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Consumer;
+import java.util.logging.Logger;
+
+/**
+ * A language model running locally on the container node.
+ *
+ * @author lesters
+ */
+public class LocalLLM extends AbstractComponent implements LanguageModel {
+
+ private final static Logger logger = Logger.getLogger(LocalLLM.class.getName());
+ private final LlamaModel model;
+ private final ThreadPoolExecutor executor;
+ private final int contextSize;
+ private final int maxTokens;
+
+ @Inject
+ public LocalLLM(LlmLocalClientConfig config) {
+ executor = createExecutor(config);
+
+ // Maximum number of tokens to generate - need this since some models can just generate infinitely
+ maxTokens = config.maxTokens();
+
+ // Only used if GPU is not used
+ var defaultThreadCount = Runtime.getRuntime().availableProcessors() - 2;
+
+ var modelFile = config.model().toFile().getAbsolutePath();
+ var modelParams = new ModelParameters()
+ .setModelFilePath(modelFile)
+ .setContinuousBatching(true)
+ .setNParallel(config.parallelRequests())
+ .setNThreads(config.threads() <= 0 ? defaultThreadCount : config.threads())
+ .setNCtx(config.contextSize())
+ .setNGpuLayers(config.useGpu() ? config.gpuLayers() : 0);
+
+ long startLoad = System.nanoTime();
+ model = new LlamaModel(modelParams);
+ long loadTime = System.nanoTime() - startLoad;
+ logger.info(String.format("Loaded model %s in %.2f sec", modelFile, (loadTime*1.0/1000000000)));
+
+ // Todo: handle prompt context size - such as give a warning when prompt exceeds context size
+ contextSize = config.contextSize();
+ }
+
+ private ThreadPoolExecutor createExecutor(LlmLocalClientConfig config) {
+ return new ThreadPoolExecutor(config.parallelRequests(), config.parallelRequests(),
+ 0L, TimeUnit.MILLISECONDS,
+ config.maxQueueSize() > 0 ? new ArrayBlockingQueue<>(config.maxQueueSize()) : new SynchronousQueue<>(),
+ new ThreadPoolExecutor.AbortPolicy());
+ }
+
+ @Override
+ public void deconstruct() {
+ logger.info("Closing LLM model...");
+ model.close();
+ executor.shutdownNow();
+ }
+
+ @Override
+ public List<Completion> complete(Prompt prompt, InferenceParameters options) {
+ StringBuilder result = new StringBuilder();
+ var future = completeAsync(prompt, options, completion -> {
+ result.append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+ var reason = future.join();
+
+ List<Completion> completions = new ArrayList<>();
+ completions.add(new Completion(result.toString(), reason));
+ return completions;
+ }
+
+ @Override
+ public CompletableFuture<Completion.FinishReason> completeAsync(Prompt prompt, InferenceParameters options, Consumer<Completion> consumer) {
+ var inferParams = new de.kherud.llama.InferenceParameters(prompt.asString().stripLeading());
+
+ // We always set this to some value to avoid infinite token generation
+ inferParams.setNPredict(maxTokens);
+
+ options.ifPresent("temperature", (v) -> inferParams.setTemperature(Float.parseFloat(v)));
+ options.ifPresent("topk", (v) -> inferParams.setTopK(Integer.parseInt(v)));
+ options.ifPresent("topp", (v) -> inferParams.setTopP(Integer.parseInt(v)));
+ options.ifPresent("npredict", (v) -> inferParams.setNPredict(Integer.parseInt(v)));
+ options.ifPresent("repeatpenalty", (v) -> inferParams.setRepeatPenalty(Float.parseFloat(v)));
+ // Todo: more options?
+
+ var completionFuture = new CompletableFuture<Completion.FinishReason>();
+ try {
+ executor.submit(() -> {
+ for (LlamaModel.Output output : model.generate(inferParams)) {
+ consumer.accept(Completion.from(output.text, Completion.FinishReason.none));
+ }
+ completionFuture.complete(Completion.FinishReason.stop);
+ });
+ } catch (RejectedExecutionException e) {
+ // If we have too many requests (active + any waiting in queue), we reject the completion
+ int activeCount = executor.getActiveCount();
+ int queueSize = executor.getQueue().size();
+ String error = String.format("Rejected completion due to too many requests, " +
+ "%d active, %d in queue", activeCount, queueSize);
+ throw new RejectedExecutionException(error);
+ }
+ return completionFuture;
+ }
+
+}
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
index 82e19d47c92..82e19d47c92 100644
--- a/container-search/src/main/java/ai/vespa/llm/clients/OpenAI.java
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/OpenAI.java
diff --git a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java
index c360245901c..c360245901c 100644
--- a/container-search/src/main/java/ai/vespa/llm/clients/package-info.java
+++ b/model-integration/src/main/java/ai/vespa/llm/clients/package-info.java
diff --git a/container-search/src/main/resources/configdefinitions/llm-client.def b/model-integration/src/main/resources/configdefinitions/llm-client.def
index 0866459166a..0866459166a 100755
--- a/container-search/src/main/resources/configdefinitions/llm-client.def
+++ b/model-integration/src/main/resources/configdefinitions/llm-client.def
diff --git a/model-integration/src/main/resources/configdefinitions/llm-local-client.def b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
new file mode 100755
index 00000000000..c06c24b33e5
--- /dev/null
+++ b/model-integration/src/main/resources/configdefinitions/llm-local-client.def
@@ -0,0 +1,29 @@
+# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package=ai.vespa.llm.clients
+
+# The LLM model to use
+model model
+
+# Maximum number of requests to handle in parallel pr container node
+parallelRequests int default=10
+
+# Additional number of requests to put in queue for processing before starting to reject new requests
+maxQueueSize int default=10
+
+# Use GPU
+useGpu bool default=false
+
+# Maximum number of model layers to run on GPU
+gpuLayers int default=1000000
+
+# Number of threads to use for CPU processing - -1 means use all available cores
+# Not used for GPU processing
+threads int default=-1
+
+# Context size for the model
+# Context is divided between parallel requests. So for 10 parallel requests, each "slot" gets 1/10 of the context
+contextSize int default=512
+
+# Maximum number of tokens to process in one request - overriden by inference parameters
+maxTokens int default=512
+
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
index 35d5cfd3855..35d5cfd3855 100644
--- a/container-search/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/ConfigurableLanguageModelTest.java
diff --git a/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
new file mode 100644
index 00000000000..a3b260f3fb5
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/LocalLLMTest.java
@@ -0,0 +1,186 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.llm.clients;
+
+import ai.vespa.llm.InferenceParameters;
+import ai.vespa.llm.completion.Completion;
+import ai.vespa.llm.completion.Prompt;
+import ai.vespa.llm.completion.StringPrompt;
+import com.yahoo.config.ModelReference;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+/**
+ * Tests for LocalLLM.
+ *
+ * @author lesters
+ */
+public class LocalLLMTest {
+
+ private static String model = "src/test/models/llm/tinyllm.gguf";
+ private static Prompt prompt = StringPrompt.from("A random prompt");
+
+ @Test
+ @Disabled
+ public void testGeneration() {
+ var config = new LlmLocalClientConfig.Builder()
+ .parallelRequests(1)
+ .model(ModelReference.valueOf(model));
+ var llm = new LocalLLM(config.build());
+
+ try {
+ var result = llm.complete(prompt, defaultOptions());
+ assertEquals(Completion.FinishReason.stop, result.get(0).finishReason());
+ assertTrue(result.get(0).text().length() > 10);
+ } finally {
+ llm.deconstruct();
+ }
+ }
+
+ @Test
+ @Disabled
+ public void testAsyncGeneration() {
+ var sb = new StringBuilder();
+ var tokenCount = new AtomicInteger(0);
+ var config = new LlmLocalClientConfig.Builder()
+ .parallelRequests(1)
+ .model(ModelReference.valueOf(model));
+ var llm = new LocalLLM(config.build());
+
+ try {
+ var future = llm.completeAsync(prompt, defaultOptions(), completion -> {
+ sb.append(completion.text());
+ tokenCount.incrementAndGet();
+ }).exceptionally(exception -> Completion.FinishReason.error);
+
+ assertFalse(future.isDone());
+ var reason = future.join();
+ assertTrue(future.isDone());
+ assertNotEquals(reason, Completion.FinishReason.error);
+
+ } finally {
+ llm.deconstruct();
+ }
+ assertTrue(tokenCount.get() > 0);
+ System.out.println(sb);
+ }
+
+ @Test
+ @Disabled
+ public void testParallelGeneration() {
+ var prompts = testPrompts();
+ var promptsToUse = prompts.size();
+ var parallelRequests = 10;
+
+ var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null));
+ var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null));
+ var tokenCounts = new ArrayList<>(Collections.nCopies(promptsToUse, 0));
+
+ var config = new LlmLocalClientConfig.Builder()
+ .parallelRequests(parallelRequests)
+ .model(ModelReference.valueOf(model));
+ var llm = new LocalLLM(config.build());
+
+ try {
+ for (int i = 0; i < promptsToUse; i++) {
+ final var seq = i;
+
+ completions.set(seq, new StringBuilder());
+ futures.set(seq, llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
+ completions.get(seq).append(completion.text());
+ tokenCounts.set(seq, tokenCounts.get(seq) + 1);
+ }).exceptionally(exception -> Completion.FinishReason.error));
+ }
+ for (int i = 0; i < promptsToUse; i++) {
+ var reason = futures.get(i).join();
+ assertNotEquals(reason, Completion.FinishReason.error);
+ }
+ } finally {
+ llm.deconstruct();
+ }
+ for (int i = 0; i < promptsToUse; i++) {
+ assertFalse(completions.get(i).isEmpty());
+ assertTrue(tokenCounts.get(i) > 0);
+ }
+ }
+
+ @Test
+ @Disabled
+ public void testRejection() {
+ var prompts = testPrompts();
+ var promptsToUse = prompts.size();
+ var parallelRequests = 2;
+ var additionalQueue = 1;
+ // 7 should be rejected
+
+ var futures = new ArrayList<CompletableFuture<Completion.FinishReason>>(Collections.nCopies(promptsToUse, null));
+ var completions = new ArrayList<StringBuilder>(Collections.nCopies(promptsToUse, null));
+
+ var config = new LlmLocalClientConfig.Builder()
+ .parallelRequests(parallelRequests)
+ .maxQueueSize(additionalQueue)
+ .model(ModelReference.valueOf(model));
+ var llm = new LocalLLM(config.build());
+
+ var rejected = new AtomicInteger(0);
+ try {
+ for (int i = 0; i < promptsToUse; i++) {
+ final var seq = i;
+
+ completions.set(seq, new StringBuilder());
+ try {
+ var future = llm.completeAsync(StringPrompt.from(prompts.get(seq)), defaultOptions(), completion -> {
+ completions.get(seq).append(completion.text());
+ }).exceptionally(exception -> Completion.FinishReason.error);
+ futures.set(seq, future);
+ } catch (RejectedExecutionException e) {
+ rejected.incrementAndGet();
+ }
+ }
+ for (int i = 0; i < promptsToUse; i++) {
+ if (futures.get(i) != null) {
+ assertNotEquals(futures.get(i).join(), Completion.FinishReason.error);
+ }
+ }
+ } finally {
+ llm.deconstruct();
+ }
+ assertEquals(7, rejected.get());
+ }
+
+ private static InferenceParameters defaultOptions() {
+ final Map<String, String> options = Map.of(
+ "temperature", "0.1",
+ "npredict", "100"
+ );
+ return new InferenceParameters(options::get);
+ }
+
+ private List<String> testPrompts() {
+ List<String> prompts = new ArrayList<>();
+ prompts.add("Write a short story about a time-traveling detective who must solve a mystery that spans multiple centuries.");
+ prompts.add("Explain the concept of blockchain technology and its implications for data security in layman's terms.");
+ prompts.add("Discuss the socio-economic impacts of the Industrial Revolution in 19th century Europe.");
+ prompts.add("Describe a future where humans have colonized Mars, focusing on daily life and societal structure.");
+ prompts.add("Analyze the statement 'If a tree falls in a forest and no one is around to hear it, does it make a sound?' from both a philosophical and a physics perspective.");
+ prompts.add("Translate the following sentence into French: 'The quick brown fox jumps over the lazy dog.'");
+ prompts.add("Explain what the following Python code does: `print([x for x in range(10) if x % 2 == 0])`.");
+ prompts.add("Provide general guidelines for maintaining a healthy lifestyle to reduce the risk of developing heart disease.");
+ prompts.add("Create a detailed description of a fictional planet, including its ecosystem, dominant species, and technology level.");
+ prompts.add("Discuss the impact of social media on interpersonal communication in the 21st century.");
+ return prompts;
+ }
+
+}
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java b/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
index 4d0073f1cbe..4d0073f1cbe 100644
--- a/container-search/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/MockLLMClient.java
diff --git a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java
index 57339f6ad49..57339f6ad49 100644
--- a/container-search/src/test/java/ai/vespa/llm/clients/OpenAITest.java
+++ b/model-integration/src/test/java/ai/vespa/llm/clients/OpenAITest.java
diff --git a/model-integration/src/test/models/llm/tinyllm.gguf b/model-integration/src/test/models/llm/tinyllm.gguf
new file mode 100644
index 00000000000..34367b6b57b
--- /dev/null
+++ b/model-integration/src/test/models/llm/tinyllm.gguf
Binary files differ
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java
index 8646121bd4b..f00414aa654 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.provision.applications;
+import com.yahoo.config.provision.CloudAccount;
import com.yahoo.config.provision.ClusterInfo;
import com.yahoo.config.provision.IntRange;
import com.yahoo.config.provision.Capacity;
@@ -33,6 +34,7 @@ public class Cluster {
private final ClusterResources min, max;
private final IntRange groupSize;
private final boolean required;
+ private final Optional<CloudAccount> cloudAccount;
private final List<Autoscaling> suggestions;
private final Autoscaling target;
private final ClusterInfo clusterInfo;
@@ -47,6 +49,7 @@ public class Cluster {
ClusterResources maxResources,
IntRange groupSize,
boolean required,
+ Optional<CloudAccount> cloudAccount,
List<Autoscaling> suggestions,
Autoscaling target,
ClusterInfo clusterInfo,
@@ -58,6 +61,7 @@ public class Cluster {
this.max = Objects.requireNonNull(maxResources);
this.groupSize = Objects.requireNonNull(groupSize);
this.required = required;
+ this.cloudAccount = Objects.requireNonNull(cloudAccount);
this.suggestions = Objects.requireNonNull(suggestions);
Objects.requireNonNull(target);
if (target.resources().isPresent() && ! target.resources().get().isWithin(minResources, maxResources))
@@ -89,6 +93,9 @@ public class Cluster {
*/
public boolean required() { return required; }
+ /** Returns the enclave cloud account of this cluster, or empty if not enclave. */
+ public Optional<CloudAccount> cloudAccount() { return cloudAccount; }
+
/**
* Returns the computed resources (between min and max, inclusive) this cluster should
* have allocated at the moment (whether or not it actually has it),
@@ -134,19 +141,19 @@ public class Cluster {
public Cluster withConfiguration(boolean exclusive, Capacity capacity) {
return new Cluster(id, exclusive,
capacity.minResources(), capacity.maxResources(), capacity.groupSize(), capacity.isRequired(),
- suggestions, target, capacity.clusterInfo(), bcpGroupInfo, scalingEvents);
+ capacity.cloudAccount(), suggestions, target, capacity.clusterInfo(), bcpGroupInfo, scalingEvents);
}
public Cluster withSuggestions(List<Autoscaling> suggestions) {
- return new Cluster(id, exclusive, min, max, groupSize, required, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents);
+ return new Cluster(id, exclusive, min, max, groupSize, required, cloudAccount, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents);
}
public Cluster withTarget(Autoscaling target) {
- return new Cluster(id, exclusive, min, max, groupSize, required, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents);
+ return new Cluster(id, exclusive, min, max, groupSize, required, cloudAccount, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents);
}
public Cluster with(BcpGroupInfo bcpGroupInfo) {
- return new Cluster(id, exclusive, min, max, groupSize, required, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents);
+ return new Cluster(id, exclusive, min, max, groupSize, required, cloudAccount, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents);
}
/** Add or update (based on "at" time) a scaling event */
@@ -160,7 +167,7 @@ public class Cluster {
scalingEvents.add(scalingEvent);
prune(scalingEvents);
- return new Cluster(id, exclusive, min, max, groupSize, required, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents);
+ return new Cluster(id, exclusive, min, max, groupSize, required, cloudAccount, suggestions, target, clusterInfo, bcpGroupInfo, scalingEvents);
}
@Override
@@ -192,7 +199,7 @@ public class Cluster {
public static Cluster create(ClusterSpec.Id id, boolean exclusive, Capacity requested) {
return new Cluster(id, exclusive,
requested.minResources(), requested.maxResources(), requested.groupSize(), requested.isRequired(),
- List.of(), Autoscaling.empty(), requested.clusterInfo(), BcpGroupInfo.empty(), List.of());
+ requested.cloudAccount(), List.of(), Autoscaling.empty(), requested.clusterInfo(), BcpGroupInfo.empty(), List.of());
}
/** The predicted time it will take to rescale this cluster. */
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializer.java
index 0c3a1df0f27..8a5780496b1 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializer.java
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.provision.persistence;
+import com.yahoo.config.provision.CloudAccount;
import com.yahoo.config.provision.ClusterInfo;
import com.yahoo.config.provision.IntRange;
import com.yahoo.config.provision.ApplicationId;
@@ -56,6 +57,7 @@ public class ApplicationSerializer {
private static final String maxResourcesKey = "max";
private static final String groupSizeKey = "groupSize";
private static final String requiredKey = "required";
+ private static final String cloudAccountKey = "cloudAccount";
private static final String suggestionsKey = "suggestionsKey";
private static final String clusterInfoKey = "clusterInfo";
private static final String bcpDeadlineKey = "bcpDeadline";
@@ -140,6 +142,7 @@ public class ApplicationSerializer {
toSlime(cluster.maxResources(), clusterObject.setObject(maxResourcesKey));
toSlime(cluster.groupSize(), clusterObject.setObject(groupSizeKey));
clusterObject.setBool(requiredKey, cluster.required());
+ cluster.cloudAccount().ifPresent(cloudAccount -> clusterObject.setString(cloudAccountKey, cloudAccount.value()));
toSlime(cluster.suggestions(), clusterObject.setArray(suggestionsKey));
toSlime(cluster.target(), clusterObject.setObject(targetKey));
if (! cluster.clusterInfo().isEmpty())
@@ -156,6 +159,7 @@ public class ApplicationSerializer {
clusterResourcesFromSlime(clusterObject.field(maxResourcesKey)),
intRangeFromSlime(clusterObject.field(groupSizeKey)),
clusterObject.field(requiredKey).asBool(),
+ optionalCloudAccount(clusterObject.field(cloudAccountKey)),
suggestionsFromSlime(clusterObject.field(suggestionsKey)),
autoscalingFromSlime(clusterObject.field(targetKey)),
clusterInfoFromSlime(clusterObject.field(clusterInfoKey)),
@@ -326,6 +330,10 @@ public class ApplicationSerializer {
};
}
+ private static Optional<CloudAccount> optionalCloudAccount(Inspector inspector) {
+ return inspector.valid() ? Optional.of(CloudAccount.from(inspector.asString())) : Optional.empty();
+ }
+
private static Optional<Instant> optionalInstant(Inspector inspector) {
return inspector.valid() ? Optional.of(Instant.ofEpochMilli(inspector.asLong())) : Optional.empty();
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializerTest.java
index f25d4cc3c30..72f402ca997 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializerTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/ApplicationSerializerTest.java
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.provision.persistence;
+import com.yahoo.config.provision.CloudAccount;
import com.yahoo.config.provision.ClusterInfo;
import com.yahoo.config.provision.IntRange;
import com.yahoo.config.provision.ApplicationId;
@@ -40,6 +41,7 @@ public class ApplicationSerializerTest {
new ClusterResources(12, 6, new NodeResources(3, 6, 21, 24)),
IntRange.empty(),
true,
+ Optional.empty(),
List.of(),
Autoscaling.empty(),
ClusterInfo.empty(),
@@ -52,6 +54,7 @@ public class ApplicationSerializerTest {
new ClusterResources(14, 7, new NodeResources(3, 6, 21, 24)),
IntRange.of(3, 5),
false,
+ Optional.of(CloudAccount.from("aws:123456789012")),
List.of(new Autoscaling(Autoscaling.Status.unavailable,
"",
Optional.of(new ClusterResources(20, 10,
@@ -97,6 +100,7 @@ public class ApplicationSerializerTest {
assertEquals(originalCluster.maxResources(), serializedCluster.maxResources());
assertEquals(originalCluster.groupSize(), serializedCluster.groupSize());
assertEquals(originalCluster.required(), serializedCluster.required());
+ assertEquals(originalCluster.cloudAccount(), serializedCluster.cloudAccount());
assertEquals(originalCluster.suggestions(), serializedCluster.suggestions());
assertEquals(originalCluster.target(), serializedCluster.target());
assertEquals(originalCluster.clusterInfo(), serializedCluster.clusterInfo());
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicProvisioningTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicProvisioningTester.java
index 1f8178dff6a..ca6dc5a0044 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicProvisioningTester.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicProvisioningTester.java
@@ -143,6 +143,7 @@ public class DynamicProvisioningTester {
cluster.maxResources(),
cluster.groupSize(),
cluster.required(),
+ cluster.cloudAccount(),
cluster.suggestions(),
cluster.target(),
cluster.clusterInfo(),
diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_allocator.cpp b/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_allocator.cpp
index 87004d7e5f2..0c986422be6 100644
--- a/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_allocator.cpp
+++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/lid_allocator.cpp
@@ -2,10 +2,11 @@
#include "lid_allocator.h"
#include <vespa/searchlib/common/bitvectoriterator.h>
-#include <vespa/searchlib/fef/termfieldmatchdataarray.h>
#include <vespa/searchlib/fef/matchdata.h>
-#include <vespa/searchlib/queryeval/full_search.h>
+#include <vespa/searchlib/fef/termfieldmatchdataarray.h>
#include <vespa/searchlib/queryeval/blueprint.h>
+#include <vespa/searchlib/queryeval/flow_tuning.h>
+#include <vespa/searchlib/queryeval/full_search.h>
#include <mutex>
#include <vespa/log/log.h>
@@ -19,6 +20,8 @@ using search::queryeval::SearchIterator;
using search::queryeval::SimpleLeafBlueprint;
using vespalib::GenerationHolder;
+using namespace search::queryeval::flow;
+
namespace proton::documentmetastore {
LidAllocator::LidAllocator(uint32_t size,
@@ -206,7 +209,8 @@ private:
return search::BitVectorIterator::create(&_activeLids, get_docid_limit(), *tfmd, strict);
}
FlowStats calculate_flow_stats(uint32_t docid_limit) const override {
- return default_flow_stats(docid_limit, _activeLids.size(), 0);
+ double rel_est = abs_to_rel_est(_activeLids.size(), docid_limit);
+ return {rel_est, bitvector_cost(), bitvector_strict_cost(rel_est)};
}
SearchIterator::UP
createLeafSearch(const TermFieldMatchDataArray &tfmda) const override
diff --git a/searchlib/src/tests/diskindex/pagedict4/.gitignore b/searchlib/src/tests/diskindex/pagedict4/.gitignore
index 2381ed57229..8aa95a29b63 100644
--- a/searchlib/src/tests/diskindex/pagedict4/.gitignore
+++ b/searchlib/src/tests/diskindex/pagedict4/.gitignore
@@ -3,3 +3,4 @@ Makefile
pagedict4_test
fakedict.*
searchlib_pagedict4_test_app
+/long_words_dir/
diff --git a/searchlib/src/tests/diskindex/pagedict4/CMakeLists.txt b/searchlib/src/tests/diskindex/pagedict4/CMakeLists.txt
index 6be544db829..34114e195bc 100644
--- a/searchlib/src/tests/diskindex/pagedict4/CMakeLists.txt
+++ b/searchlib/src/tests/diskindex/pagedict4/CMakeLists.txt
@@ -16,3 +16,11 @@ vespa_add_executable(searchlib_pagedict4_hugeword_cornercase_test_app TEST
searchlib
)
vespa_add_test(NAME searchlib_pagedict4_hugeword_cornercase_test_app COMMAND searchlib_pagedict4_hugeword_cornercase_test_app)
+
+vespa_add_executable(searchlib_pagedict4_long_words_test_app TEST
+ SOURCES
+ pagedict4_long_words_test.cpp
+ DEPENDS
+ searchlib_test
+ searchlib
+)
diff --git a/searchlib/src/tests/diskindex/pagedict4/pagedict4_long_words_test.cpp b/searchlib/src/tests/diskindex/pagedict4/pagedict4_long_words_test.cpp
new file mode 100644
index 00000000000..dba7980a4b7
--- /dev/null
+++ b/searchlib/src/tests/diskindex/pagedict4/pagedict4_long_words_test.cpp
@@ -0,0 +1,131 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/searchlib/common/tunefileinfo.h>
+#include <vespa/searchlib/diskindex/pagedict4file.h>
+#include <vespa/searchlib/diskindex/pagedict4randread.h>
+#include <vespa/searchlib/index/dummyfileheadercontext.h>
+#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/stllike/asciistream.h>
+#include <filesystem>
+
+using search::diskindex::PageDict4FileSeqRead;
+using search::diskindex::PageDict4FileSeqWrite;
+using search::diskindex::PageDict4RandRead;
+using search::index::DummyFileHeaderContext;
+using search::index::PostingListCounts;
+using search::index::PostingListOffsetAndCounts;
+using search::index::PostingListParams;
+
+
+namespace {
+
+vespalib::string test_dir("long_words_dir");
+vespalib::string dict(test_dir + "/dict");
+
+PostingListCounts make_counts()
+{
+ PostingListCounts counts;
+ counts._bitLength = 100;
+ counts._numDocs = 1;
+ counts._segments.clear();
+ return counts;
+}
+
+vespalib::string
+make_word(int i)
+{
+ vespalib::asciistream os;
+ vespalib::string word(5_Ki, 'a');
+ os << vespalib::setfill('0') << vespalib::setw(8) << i;
+ word.append(os.str());
+ return word;
+}
+
+}
+
+/*
+ * A long word that don't fit into a 4 KiB 'page' causes a fallback to
+ * overflow handling where the word is put in the .ssdat file.
+ *
+ * Many long words causes excessive growth of the .ssdat file, with
+ * overflow potentials when the whole file is read into a buffer.
+ *
+ * 4 GiB size: Overflow in ComprFileReadBase::ReadComprBuffer for expression
+ * readUnits * cbuf.getUnitSize() when both are 32-bits.
+ * Testable by setting num_words to 900_Ki
+ *
+ * 16 GiB size: Overflow in ComprFileReadBase::ReadComprBuffer when
+ * readUnits is 32-bit signed.
+ * Some overflows in ComprFileDecodeContext API.
+ * Overflow in DecodeContext64Base::getBitPos
+ * Testable by setting num_words to 4_Mi
+ *
+ * 32 GiB size: Overflow when calling ComprFileReadContext::allocComprBuf when
+ * comprBufSize is 32-bit unsigned.
+ * Overflow in DecodeContext64Base::setEnd.
+ * Testable by setting num_words to 9_Mi
+ *
+ * These overflows are fixed.
+ */
+TEST(PageDict4LongWordsTest, test_many_long_words)
+{
+ int num_words = 9_Mi;
+ auto counts = make_counts();
+ std::filesystem::remove_all(std::filesystem::path(test_dir));
+ std::filesystem::create_directories(std::filesystem::path(test_dir));
+
+ auto dw = std::make_unique<PageDict4FileSeqWrite>();
+ DummyFileHeaderContext file_header_context;
+ PostingListParams params;
+ search::TuneFileSeqWrite tune_file_write;
+ params.set("numWordIds", num_words);
+ params.set("minChunkDocs", 256_Ki);
+ dw->setParams(params);
+ EXPECT_TRUE(dw->open(dict, tune_file_write, file_header_context));
+ for (int i = 0; i < num_words; ++i) {
+ auto word = make_word(i);
+ dw->writeWord(word, counts);
+ }
+ EXPECT_TRUE(dw->close());
+ dw.reset();
+
+ auto drr = std::make_unique<PageDict4RandRead>();
+ search::TuneFileRandRead tune_file_rand_read;
+ EXPECT_TRUE(drr->open(dict, tune_file_rand_read));
+ PostingListOffsetAndCounts offset_and_counts;
+ uint64_t exp_offset = 0;
+ uint64_t exp_acc_num_docs = 0;
+ for (int i = 0; i < num_words; ++i) {
+ auto word = make_word(i);
+ uint64_t check_word_num = 0;
+ EXPECT_TRUE(drr->lookup(word, check_word_num, offset_and_counts));
+ EXPECT_EQ(i + 1, (int) check_word_num);
+ EXPECT_EQ(exp_offset, offset_and_counts._offset);
+ EXPECT_EQ(exp_acc_num_docs, offset_and_counts._accNumDocs);
+ EXPECT_EQ(counts, offset_and_counts._counts);
+ exp_offset += offset_and_counts._counts._bitLength;
+ exp_acc_num_docs += offset_and_counts._counts._numDocs;
+ }
+ EXPECT_TRUE(drr->close());
+ drr.reset();
+
+ auto dr = std::make_unique<PageDict4FileSeqRead>();
+ search::TuneFileSeqRead tune_file_read;
+ EXPECT_TRUE(dr->open(dict, tune_file_read));
+ vespalib::string check_word;
+ PostingListCounts check_counts;
+ for (int i = 0; i < num_words; ++i) {
+ uint64_t check_word_num = 0;
+ check_word.clear();
+ dr->readWord(check_word, check_word_num, check_counts);
+ EXPECT_EQ(i + 1, (int) check_word_num);
+ EXPECT_EQ(make_word(i), check_word);
+ EXPECT_EQ(counts, check_counts);
+ }
+ EXPECT_TRUE(dr->close());
+ dr.reset();
+
+ std::filesystem::remove_all(std::filesystem::path(test_dir));
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/CMakeLists.txt b/searchlib/src/tests/queryeval/iterator_benchmark/CMakeLists.txt
index 872fb4ca6ca..dadd06ee7cd 100644
--- a/searchlib/src/tests/queryeval/iterator_benchmark/CMakeLists.txt
+++ b/searchlib/src/tests/queryeval/iterator_benchmark/CMakeLists.txt
@@ -1,6 +1,7 @@
# Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
vespa_add_executable(searchlib_iterator_benchmark_test_app TEST
SOURCES
+ intermediate_blueprint_factory.cpp
attribute_ctx_builder.cpp
benchmark_blueprint_factory.cpp
common.cpp
diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/benchmark_blueprint_factory.cpp b/searchlib/src/tests/queryeval/iterator_benchmark/benchmark_blueprint_factory.cpp
index 0496a0e6dc8..504ad057d3d 100644
--- a/searchlib/src/tests/queryeval/iterator_benchmark/benchmark_blueprint_factory.cpp
+++ b/searchlib/src/tests/queryeval/iterator_benchmark/benchmark_blueprint_factory.cpp
@@ -162,6 +162,9 @@ public:
double op_hit_ratio, uint32_t children, bool disjunct_children);
std::unique_ptr<Blueprint> make_blueprint() override;
+ vespalib::string get_name(Blueprint& blueprint) const override {
+ return get_class_name(blueprint);
+ }
};
MyFactory::MyFactory(const FieldConfig& field_cfg, QueryOperator query_op,
diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/benchmark_blueprint_factory.h b/searchlib/src/tests/queryeval/iterator_benchmark/benchmark_blueprint_factory.h
index 423f517ffb0..2a90dbbbef8 100644
--- a/searchlib/src/tests/queryeval/iterator_benchmark/benchmark_blueprint_factory.h
+++ b/searchlib/src/tests/queryeval/iterator_benchmark/benchmark_blueprint_factory.h
@@ -16,6 +16,7 @@ class BenchmarkBlueprintFactory {
public:
virtual ~BenchmarkBlueprintFactory() = default;
virtual std::unique_ptr<Blueprint> make_blueprint() = 0;
+ virtual vespalib::string get_name(Blueprint& blueprint) const = 0;
};
std::unique_ptr<BenchmarkBlueprintFactory>
diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/common.cpp b/searchlib/src/tests/queryeval/iterator_benchmark/common.cpp
index c67a5ee1074..1db9cd58d46 100644
--- a/searchlib/src/tests/queryeval/iterator_benchmark/common.cpp
+++ b/searchlib/src/tests/queryeval/iterator_benchmark/common.cpp
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "common.h"
+#include <vespa/searchlib/queryeval/blueprint.h>
#include <sstream>
using search::attribute::CollectionType;
@@ -19,7 +20,11 @@ to_string(const Config& attr_config)
oss << col_type.asString() << "<" << basic_type.asString() << ">";
}
if (attr_config.fastSearch()) {
- oss << "(fs)";
+ oss << "(fs";
+ if (attr_config.getIsFilter()) {
+ oss << ",rf";
+ }
+ oss << ")";
}
return oss.str();
}
@@ -42,6 +47,38 @@ to_string(QueryOperator query_op)
namespace {
+std::string
+delete_substr_from(const std::string& source, const std::string& substr)
+{
+ std::string res = source;
+ auto i = res.find(substr);
+ while (i != std::string::npos) {
+ res.erase(i, substr.length());
+ i = res.find(substr, i);
+ }
+ return res;
+}
+
+}
+
+vespalib::string
+get_class_name(const auto& obj)
+{
+ auto res = obj.getClassName();
+ res = delete_substr_from(res, "search::attribute::");
+ res = delete_substr_from(res, "search::queryeval::");
+ res = delete_substr_from(res, "vespalib::btree::");
+ res = delete_substr_from(res, "search::");
+ res = delete_substr_from(res, "vespalib::");
+ res = delete_substr_from(res, "anonymous namespace");
+ return res;
+}
+
+template vespalib::string get_class_name<Blueprint>(const Blueprint& obj);
+template vespalib::string get_class_name<SearchIterator>(const SearchIterator& obj);
+
+namespace {
+
// TODO: Make seed configurable.
constexpr uint32_t default_seed = 1234;
std::mt19937 gen(default_seed);
diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/common.h b/searchlib/src/tests/queryeval/iterator_benchmark/common.h
index bf16e6f51d7..6341b16e96a 100644
--- a/searchlib/src/tests/queryeval/iterator_benchmark/common.h
+++ b/searchlib/src/tests/queryeval/iterator_benchmark/common.h
@@ -79,6 +79,8 @@ public:
auto end() const { return _specs.end(); }
};
+vespalib::string get_class_name(const auto& obj);
+
std::mt19937& get_gen();
BitVector::UP random_docids(uint32_t docid_limit, uint32_t count);
diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/intermediate_blueprint_factory.cpp b/searchlib/src/tests/queryeval/iterator_benchmark/intermediate_blueprint_factory.cpp
new file mode 100644
index 00000000000..ad8c0b1008e
--- /dev/null
+++ b/searchlib/src/tests/queryeval/iterator_benchmark/intermediate_blueprint_factory.cpp
@@ -0,0 +1,78 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "intermediate_blueprint_factory.h"
+#include <vespa/searchlib/queryeval/intermediate_blueprints.h>
+#include <sstream>
+
+namespace search::queryeval::test {
+
+template <typename BlueprintType>
+char
+IntermediateBlueprintFactory<BlueprintType>::child_name(void* blueprint) const
+{
+ auto itr = _child_names.find(blueprint);
+ if (itr != _child_names.end()) {
+ return itr->second;
+ }
+ return '?';
+}
+
+template <typename BlueprintType>
+IntermediateBlueprintFactory<BlueprintType>::IntermediateBlueprintFactory(vespalib::stringref name)
+ : _name(name),
+ _children(),
+ _child_names()
+{
+}
+
+template <typename BlueprintType>
+IntermediateBlueprintFactory<BlueprintType>::~IntermediateBlueprintFactory() = default;
+
+template <typename BlueprintType>
+std::unique_ptr<Blueprint>
+IntermediateBlueprintFactory<BlueprintType>::make_blueprint()
+{
+ auto res = std::make_unique<BlueprintType>();
+ _child_names.clear();
+ char name = 'A';
+ for (const auto& factory : _children) {
+ auto child = factory->make_blueprint();
+ _child_names[child.get()] = name++;
+ res->addChild(std::move(child));
+ }
+ return res;
+}
+
+template <typename BlueprintType>
+vespalib::string
+IntermediateBlueprintFactory<BlueprintType>::get_name(Blueprint& blueprint) const
+{
+ auto* intermediate = blueprint.asIntermediate();
+ if (intermediate != nullptr) {
+ std::ostringstream oss;
+ bool first = true;
+ oss << _name << "[";
+ for (size_t i = 0; i < intermediate->childCnt(); ++i) {
+ auto* child = &intermediate->getChild(i);
+ oss << (first ? "" : ",") << child_name(child) << ".";
+ if (child->strict()) {
+ oss << "s(" << std::setw(6) << std::setprecision(3) << child->strict_cost() << ")";
+ } else {
+ oss << "n(" << std::setw(6) << std::setprecision(3) << child->cost() << ")";
+ }
+ first = false;
+ }
+ oss << "]";
+ return oss.str();
+ }
+ return get_class_name(blueprint);
+}
+
+template class IntermediateBlueprintFactory<AndBlueprint>;
+
+AndBlueprintFactory::AndBlueprintFactory()
+ : IntermediateBlueprintFactory<AndBlueprint>("AND")
+{}
+
+}
+
diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/intermediate_blueprint_factory.h b/searchlib/src/tests/queryeval/iterator_benchmark/intermediate_blueprint_factory.h
new file mode 100644
index 00000000000..6f7fe4f9ee7
--- /dev/null
+++ b/searchlib/src/tests/queryeval/iterator_benchmark/intermediate_blueprint_factory.h
@@ -0,0 +1,39 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "benchmark_blueprint_factory.h"
+#include <vespa/searchlib/queryeval/intermediate_blueprints.h>
+#include <unordered_map>
+
+namespace search::queryeval::test {
+
+/**
+ * Factory that creates an IntermediateBlueprint (of the given type) with children created by the given factories.
+ */
+template <typename BlueprintType>
+class IntermediateBlueprintFactory : public BenchmarkBlueprintFactory {
+private:
+ vespalib::string _name;
+ std::vector<std::shared_ptr<BenchmarkBlueprintFactory>> _children;
+ std::unordered_map<void*, char> _child_names;
+
+ char child_name(void* blueprint) const;
+
+public:
+ IntermediateBlueprintFactory(vespalib::stringref name);
+ ~IntermediateBlueprintFactory();
+ void add_child(std::shared_ptr<BenchmarkBlueprintFactory> child) {
+ _children.push_back(std::move(child));
+ }
+ std::unique_ptr<Blueprint> make_blueprint() override;
+ vespalib::string get_name(Blueprint& blueprint) const override;
+};
+
+class AndBlueprintFactory : public IntermediateBlueprintFactory<AndBlueprint> {
+public:
+ AndBlueprintFactory();
+};
+
+}
+
diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp b/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp
index c6dae52fd69..f7a358efb26 100644
--- a/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp
+++ b/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp
@@ -1,5 +1,6 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#include "intermediate_blueprint_factory.h"
#include "benchmark_blueprint_factory.h"
#include "common.h"
#include <vespa/searchlib/fef/matchdata.h>
@@ -25,6 +26,25 @@ using vespalib::make_string_short::fmt;
const vespalib::string field_name = "myfield";
double budget_sec = 1.0;
+enum class PlanningAlgo {
+ Order,
+ Estimate,
+ Cost,
+ CostForceStrict
+};
+
+vespalib::string
+to_string(PlanningAlgo algo)
+{
+ switch (algo) {
+ case PlanningAlgo::Order: return "ordr";
+ case PlanningAlgo::Estimate: return "esti";
+ case PlanningAlgo::Cost: return "cost";
+ case PlanningAlgo::CostForceStrict: return "forc";
+ }
+ return "unknown";
+}
+
struct BenchmarkResult {
double time_ms;
uint32_t seeks;
@@ -127,31 +147,6 @@ public:
}
};
-std::string
-delete_substr_from(const std::string& source, const std::string& substr)
-{
- std::string res = source;
- auto i = res.find(substr);
- while (i != std::string::npos) {
- res.erase(i, substr.length());
- i = res.find(substr, i);
- }
- return res;
-}
-
-vespalib::string
-get_class_name(const auto& obj)
-{
- auto res = obj.getClassName();
- res = delete_substr_from(res, "search::attribute::");
- res = delete_substr_from(res, "search::queryeval::");
- res = delete_substr_from(res, "vespalib::btree::");
- res = delete_substr_from(res, "search::");
- res = delete_substr_from(res, "vespalib::");
- res = delete_substr_from(res, "anonymous namespace");
- return res;
-}
-
struct MatchLoopContext {
Blueprint::UP blueprint;
MatchData::UP match_data;
@@ -174,12 +169,37 @@ struct MatchLoopContext {
MatchLoopContext::~MatchLoopContext() = default;
+Blueprint::Options
+to_sort_options(PlanningAlgo algo)
+{
+ Blueprint::Options opts;
+ if (algo == PlanningAlgo::Order) {
+ opts.keep_order(true);
+ } else if (algo == PlanningAlgo::Cost) {
+ opts.sort_by_cost(true);
+ } else if (algo == PlanningAlgo::CostForceStrict) {
+ opts.sort_by_cost(true).allow_force_strict(true);
+ }
+ return opts;
+}
+
+void
+sort_blueprint(Blueprint& blueprint, InFlow in_flow, uint32_t docid_limit, Blueprint::Options opts)
+{
+ auto opts_guard = blueprint.bind_opts(opts);
+ blueprint.setDocIdLimit(docid_limit);
+ blueprint.each_node_post_order([docid_limit](Blueprint &bp){
+ bp.update_flow_stats(docid_limit);
+ });
+ blueprint.sort(in_flow);
+}
+
MatchLoopContext
-make_match_loop_context(BenchmarkBlueprintFactory& factory, InFlow in_flow, uint32_t docid_limit)
+make_match_loop_context(BenchmarkBlueprintFactory& factory, InFlow in_flow, uint32_t docid_limit, PlanningAlgo algo)
{
auto blueprint = factory.make_blueprint();
assert(blueprint);
- blueprint->basic_plan(in_flow, docid_limit);
+ sort_blueprint(*blueprint, in_flow, docid_limit, to_sort_options(algo));
blueprint->fetchPostings(ExecuteInfo::FULL);
// Note: All blueprints get the same TermFieldMatchData instance.
// This is OK as long as we don't do unpacking and only use 1 thread.
@@ -191,13 +211,13 @@ make_match_loop_context(BenchmarkBlueprintFactory& factory, InFlow in_flow, uint
template <bool do_unpack>
BenchmarkResult
-strict_search(BenchmarkBlueprintFactory& factory, uint32_t docid_limit)
+strict_search(BenchmarkBlueprintFactory& factory, uint32_t docid_limit, PlanningAlgo algo)
{
BenchmarkTimer timer(budget_sec);
uint32_t hits = 0;
MatchLoopContext ctx;
while (timer.has_budget()) {
- ctx = make_match_loop_context(factory, true, docid_limit);
+ ctx = make_match_loop_context(factory, true, docid_limit, algo);
auto* itr = ctx.iterator.get();
timer.before();
hits = 0;
@@ -216,12 +236,12 @@ strict_search(BenchmarkBlueprintFactory& factory, uint32_t docid_limit)
timer.after();
}
FlowStats flow(ctx.blueprint->estimate(), ctx.blueprint->cost(), ctx.blueprint->strict_cost());
- return {timer.min_time() * 1000.0, hits + 1, hits, flow, flow.strict_cost, get_class_name(*ctx.iterator), get_class_name(*ctx.blueprint)};
+ return {timer.min_time() * 1000.0, hits + 1, hits, flow, flow.strict_cost, get_class_name(*ctx.iterator), factory.get_name(*ctx.blueprint)};
}
template <bool do_unpack>
BenchmarkResult
-non_strict_search(BenchmarkBlueprintFactory& factory, uint32_t docid_limit, double filter_hit_ratio, bool force_strict)
+non_strict_search(BenchmarkBlueprintFactory& factory, uint32_t docid_limit, double filter_hit_ratio, bool force_strict, PlanningAlgo algo)
{
BenchmarkTimer timer(budget_sec);
uint32_t seeks = 0;
@@ -231,7 +251,7 @@ non_strict_search(BenchmarkBlueprintFactory& factory, uint32_t docid_limit, doub
uint32_t docid_skip = 1.0 / filter_hit_ratio;
MatchLoopContext ctx;
while (timer.has_budget()) {
- ctx = make_match_loop_context(factory, InFlow(force_strict, filter_hit_ratio), docid_limit);
+ ctx = make_match_loop_context(factory, InFlow(force_strict, filter_hit_ratio), docid_limit, algo);
auto* itr = ctx.iterator.get();
timer.before();
seeks = 0;
@@ -250,27 +270,31 @@ non_strict_search(BenchmarkBlueprintFactory& factory, uint32_t docid_limit, doub
}
FlowStats flow(ctx.blueprint->estimate(), ctx.blueprint->cost(), ctx.blueprint->strict_cost());
double actual_cost = flow.cost * filter_hit_ratio;
- return {timer.min_time() * 1000.0, seeks, hits, flow, actual_cost, get_class_name(*ctx.iterator), get_class_name(*ctx.blueprint)};
+ return {timer.min_time() * 1000.0, seeks, hits, flow, actual_cost, get_class_name(*ctx.iterator), factory.get_name(*ctx.blueprint)};
}
BenchmarkResult
-benchmark_search(BenchmarkBlueprintFactory& factory, uint32_t docid_limit, bool strict_context, bool force_strict, bool unpack_iterator, double filter_hit_ratio)
+benchmark_search(BenchmarkBlueprintFactory& factory, uint32_t docid_limit, bool strict_context, bool force_strict, bool unpack_iterator, double filter_hit_ratio, PlanningAlgo algo)
{
if (strict_context) {
if (unpack_iterator) {
- return strict_search<true>(factory, docid_limit);
+ return strict_search<true>(factory, docid_limit, algo);
} else {
- return strict_search<false>(factory, docid_limit);
+ return strict_search<false>(factory, docid_limit, algo);
}
} else {
if (unpack_iterator) {
- return non_strict_search<true>(factory, docid_limit, filter_hit_ratio, force_strict);
+ return non_strict_search<true>(factory, docid_limit, filter_hit_ratio, force_strict, algo);
} else {
- return non_strict_search<false>(factory, docid_limit, filter_hit_ratio, force_strict);
+ return non_strict_search<false>(factory, docid_limit, filter_hit_ratio, force_strict, algo);
}
}
}
+
+
+
+
//-----------------------------------------------------------------------------
double est_forced_strict_cost(double estimate, double strict_cost, double rate) {
@@ -343,12 +367,12 @@ void analyze_crossover(BenchmarkBlueprintFactory &fixed, std::function<std::uniq
auto a = first.make_blueprint();
a->basic_plan(true, docid_limit);
double est_a = a->estimate();
- double a_ms = benchmark_search(first, docid_limit, true, false, false, 1.0).time_ms;
- double b_ms = benchmark_search(last, docid_limit, false, false, false, est_a).time_ms;
+ double a_ms = benchmark_search(first, docid_limit, true, false, false, 1.0, PlanningAlgo::Cost).time_ms;
+ double b_ms = benchmark_search(last, docid_limit, false, false, false, est_a, PlanningAlgo::Cost).time_ms;
if (!allow_force_strict) {
return Sample(a_ms + b_ms);
}
- double c_ms = benchmark_search(last, docid_limit, false, true, false, est_a).time_ms;
+ double c_ms = benchmark_search(last, docid_limit, false, true, false, est_a, PlanningAlgo::Cost).time_ms;
if (c_ms < b_ms) {
return Sample(a_ms + c_ms, a_ms + b_ms, true);
}
@@ -615,7 +639,7 @@ run_benchmark_case(const BenchmarkCaseSetup& setup)
for (double filter_hit_ratio : setup.filter_hit_ratios) {
if (filter_hit_ratio * setup.filter_crossover_factor <= op_hit_ratio) {
auto res = benchmark_search(*factory, setup.num_docs + 1,
- setup.bcase.strict_context, setup.bcase.force_strict, setup.bcase.unpack_iterator, filter_hit_ratio);
+ setup.bcase.strict_context, setup.bcase.force_strict, setup.bcase.unpack_iterator, filter_hit_ratio, PlanningAlgo::Cost);
print_result(res, children, op_hit_ratio, filter_hit_ratio, setup.num_docs);
result.add(res);
}
@@ -650,11 +674,129 @@ run_benchmarks(const BenchmarkSetup& setup)
print_summary(summary);
}
+//---------------------------------------------------------------------------------------
+// Tools for benchmarking root intermediate blueprints with configurable children setups.
+//---------------------------------------------------------------------------------------
+
+void
+print_intermediate_blueprint_result_header(size_t children)
+{
+ // This matches the naming scheme in IntermediateBlueprintFactory.
+ char name = 'A';
+ for (size_t i = 0; i < children; ++i) {
+ std::cout << "| " << name++ << ".ratio ";
+ }
+ std::cout << "| flow.cost | flow.scost | flow.est | ratio | hits | seeks | ms_per_cost | time_ms | algo | blueprint |" << std::endl;
+}
+
+void
+print_intermediate_blueprint_result(const BenchmarkResult& res, const std::vector<double>& children_ratios, PlanningAlgo algo, uint32_t num_docs)
+{
+ std::cout << std::fixed << std::setprecision(5);
+ for (auto ratio : children_ratios) {
+ std::cout << "| " << std::setw(7) << ratio << " ";
+ }
+ std::cout << std::setprecision(5)
+ << "| " << std::setw(10) << res.flow.cost
+ << " | " << std::setw(10) << res.flow.strict_cost
+ << " | " << std::setw(8) << res.flow.estimate
+ << " | " << std::setw(7) << ((double) res.hits / (double) num_docs)
+ << std::setprecision(4)
+ << " | " << std::setw(8) << res.hits
+ << " | " << std::setw(8) << res.seeks
+ << std::setprecision(3)
+ << " | " << std::setw(11) << res.ms_per_actual_cost()
+ << " | " << std::setw(8) << res.time_ms
+ << " | " << to_string(algo)
+ << " | " << res.blueprint_name << " |" << std::endl;
+}
+
+struct BlueprintFactorySetup {
+ FieldConfig field_cfg;
+ QueryOperator query_op;
+ std::vector<double> op_hit_ratios;
+ uint32_t children;
+ bool disjunct_children;
+ uint32_t default_values_per_document;
+
+ BlueprintFactorySetup(const FieldConfig& field_cfg_in, QueryOperator query_op_in, const std::vector<double>& op_hit_ratios_in)
+ : BlueprintFactorySetup(field_cfg_in, query_op_in, op_hit_ratios_in, 1, false)
+ {}
+ BlueprintFactorySetup(const FieldConfig& field_cfg_in, QueryOperator query_op_in, const std::vector<double>& op_hit_ratios_in,
+ uint32_t children_in, bool disjunct_children_in)
+ : field_cfg(field_cfg_in),
+ query_op(query_op_in),
+ op_hit_ratios(op_hit_ratios_in),
+ children(children_in),
+ disjunct_children(disjunct_children_in),
+ default_values_per_document(0)
+ {}
+ ~BlueprintFactorySetup();
+ std::unique_ptr<BenchmarkBlueprintFactory> make_factory(size_t num_docs, double op_hit_ratio) const {
+ return make_blueprint_factory(field_cfg, query_op, num_docs, default_values_per_document, op_hit_ratio, children, disjunct_children);
+ }
+ std::shared_ptr<BenchmarkBlueprintFactory> make_factory_shared(size_t num_docs, double op_hit_ratio) const {
+ return std::shared_ptr<BenchmarkBlueprintFactory>(make_factory(num_docs, op_hit_ratio));
+ }
+ vespalib::string to_string() const {
+ return "field=" + field_cfg.to_string() + ", query=" + test::to_string(query_op) + ", children=" + std::to_string(children);
+ }
+};
+
+BlueprintFactorySetup::~BlueprintFactorySetup() = default;
+
+template <typename IntermediateBlueprintFactoryType>
+void
+run_intermediate_blueprint_benchmark(const BlueprintFactorySetup& a, const BlueprintFactorySetup& b, size_t num_docs)
+{
+ print_intermediate_blueprint_result_header(2);
+ for (double b_hit_ratio: b.op_hit_ratios) {
+ auto b_factory = b.make_factory_shared(num_docs, b_hit_ratio);
+ for (double a_hit_ratio : a.op_hit_ratios) {
+ IntermediateBlueprintFactoryType factory;
+ factory.add_child(a.make_factory(num_docs, a_hit_ratio));
+ factory.add_child(b_factory);
+ for (auto algo: {PlanningAlgo::Order, PlanningAlgo::Estimate, PlanningAlgo::Cost, PlanningAlgo::CostForceStrict}) {
+ auto res = benchmark_search(factory, num_docs + 1, true, false, false, 1.0, algo);
+ print_intermediate_blueprint_result(res, {a_hit_ratio, b_hit_ratio}, algo, num_docs);
+ }
+ std::cout << std::endl;
+ }
+ }
+}
+
+void
+run_and_benchmark(const BlueprintFactorySetup& a, const BlueprintFactorySetup& b, size_t num_docs)
+{
+ std::cout << "AND[A={" << a.to_string() << "},B={" << b.to_string() << "}]" << std::endl;
+ run_intermediate_blueprint_benchmark<AndBlueprintFactory>(a, b, num_docs);
+}
+
+//-------------------------------------------------------------------------------------
+
+std::vector<double>
+gen_ratios(double middle, double range_multiplier, size_t num_samples)
+{
+ double lower = middle / range_multiplier;
+ double upper = middle * range_multiplier;
+ // Solve the following equation:
+ // lower * (factor ^ (num_samples - 1)) = upper;
+ double factor = std::pow(upper / lower, 1.0 / (num_samples - 1));
+ std::vector<double> res;
+ double ratio = lower;
+ for (size_t i = 0; i < num_samples; ++i) {
+ res.push_back(ratio);
+ ratio *= factor;
+ }
+ return res;
+}
+
FieldConfig
-make_attr_config(BasicType basic_type, CollectionType col_type, bool fast_search)
+make_attr_config(BasicType basic_type, CollectionType col_type, bool fast_search, bool rank_filter = false)
{
Config cfg(basic_type, col_type);
cfg.setFastSearch(fast_search);
+ cfg.setIsFilter(rank_filter);
return FieldConfig(cfg);
}
@@ -671,6 +813,7 @@ const std::vector<double> base_hit_ratios = {0.0001, 0.001, 0.01, 0.1, 0.5, 1.0}
const std::vector<double> filter_hit_ratios = {0.00001, 0.00005, 0.0001, 0.0005, 0.001, 0.005, 0.01, 0.05, 0.1, 0.2, 0.5, 1.0};
const auto int32 = make_attr_config(BasicType::INT32, CollectionType::SINGLE, false);
const auto int32_fs = make_attr_config(BasicType::INT32, CollectionType::SINGLE, true);
+const auto int32_fs_rf = make_attr_config(BasicType::INT32, CollectionType::SINGLE, true, true);
const auto int32_array = make_attr_config(BasicType::INT32, CollectionType::ARRAY, false);
const auto int32_array_fs = make_attr_config(BasicType::INT32, CollectionType::ARRAY, true);
const auto int32_wset = make_attr_config(BasicType::INT32, CollectionType::WSET, false);
@@ -790,6 +933,48 @@ TEST(IteratorBenchmark, or_vs_filter_crossover_with_allow_force_strict)
analyze_crossover(*fixed_or, variable_term, num_docs + 1, true, 0.0001);
}
+TEST(IteratorBenchmark, analyze_and_with_filter_vs_in)
+{
+ for (uint32_t children: {10, 100, 1000}) {
+ run_and_benchmark({int32_fs, QueryOperator::Term, gen_ratios(0.1, 8.0, 15)},
+ {int32_fs, QueryOperator::In, {0.1}, children, false},
+ num_docs);
+ }
+}
+
+TEST(IteratorBenchmark, analyze_and_with_bitvector_vs_in)
+{
+ for (uint32_t children: {10, 100, 1000, 10000}) {
+ run_and_benchmark({int32_fs, QueryOperator::In, {0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.40, 0.45, 0.50, 0.55, 0.60}, children, true},
+ {int32_fs_rf, QueryOperator::Term, {1.0}, 1, true}, // this setup returns a bitvector matching all documents.
+ num_docs);
+ }
+}
+
+TEST(IteratorBenchmark, analyze_and_with_filter_vs_in_array)
+{
+ for (uint32_t children: {10, 100, 1000}) {
+ run_and_benchmark({int32_fs, QueryOperator::Term, gen_ratios(0.1, 8.0, 15)},
+ {int32_array_fs, QueryOperator::In, {0.1}, children, false},
+ num_docs);
+ }
+}
+
+TEST(IteratorBenchmark, analyze_and_with_filter_vs_or)
+{
+ for (uint32_t children: {10, 100, 1000}) {
+ run_and_benchmark({int32_fs, QueryOperator::Term, gen_ratios(0.1, 8.0, 15)},
+ {int32_fs, QueryOperator::Or, {0.1}, children, false},
+ num_docs);
+ }
+}
+
+TEST(IteratorBenchmark, analyze_btree_vs_bitvector_iterators_strict)
+{
+ BenchmarkSetup setup(num_docs, {int32_fs, int32_fs_rf}, {QueryOperator::Term}, {true}, {0.1, 0.2, 0.4, 0.5, 0.6, 0.8, 1.0}, {1});
+ run_benchmarks(setup);
+}
+
int main(int argc, char **argv) {
::testing::InitGoogleTest(&argc, argv);
int res = RUN_ALL_TESTS();
diff --git a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h
index 413d0dd0bf4..51321a56885 100644
--- a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h
+++ b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.h
@@ -76,28 +76,7 @@ public:
resolve_strict(in_flow);
}
- queryeval::FlowStats calculate_flow_stats(uint32_t docid_limit) const override {
- using OrFlow = search::queryeval::OrFlow;
- struct MyAdapter {
- uint32_t docid_limit;
- MyAdapter(uint32_t docid_limit_in) noexcept : docid_limit(docid_limit_in) {}
- double estimate(const IDirectPostingStore::LookupResult &term) const noexcept {
- return abs_to_rel_est(term.posting_size, docid_limit);
- }
- double cost(const IDirectPostingStore::LookupResult &) const noexcept {
- return search::queryeval::flow::btree_cost();
- }
- double strict_cost(const IDirectPostingStore::LookupResult &term) const noexcept {
- double rel_est = abs_to_rel_est(term.posting_size, docid_limit);
- return search::queryeval::flow::btree_strict_cost(rel_est);
- }
- };
- double est = OrFlow::estimate_of(MyAdapter(docid_limit), _terms);
- // Iterator benchmarking has shown that non-strict cost should be 1.0.
- // Program: searchlib/src/tests/queryeval/iterator_benchmark
- // TODO: Add more details, and consider moving constant to flow_tuning.h
- return {est, 1.0, OrFlow::cost_of(MyAdapter(docid_limit), _terms, true) + queryeval::flow::heap_cost(est, _terms.size())};
- }
+ queryeval::FlowStats calculate_flow_stats(uint32_t docid_limit) const override;
std::unique_ptr<queryeval::SearchIterator> createLeafSearch(const fef::TermFieldMatchDataArray &tfmda) const override;
diff --git a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp
index 817eab3e070..160bb199fb8 100644
--- a/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/direct_multi_term_blueprint.hpp
@@ -181,4 +181,34 @@ DirectMultiTermBlueprint<PostingStoreType, SearchType>::createFilterSearch(Filte
return wrapper;
}
+template <typename PostingStoreType, typename SearchType>
+queryeval::FlowStats
+DirectMultiTermBlueprint<PostingStoreType, SearchType>::calculate_flow_stats(uint32_t docid_limit) const
+{
+ using OrFlow = search::queryeval::OrFlow;
+ struct MyAdapter {
+ uint32_t docid_limit;
+ MyAdapter(uint32_t docid_limit_in) noexcept : docid_limit(docid_limit_in) {}
+ double estimate(const IDirectPostingStore::LookupResult &term) const noexcept {
+ return abs_to_rel_est(term.posting_size, docid_limit);
+ }
+ double cost(const IDirectPostingStore::LookupResult &) const noexcept {
+ return search::queryeval::flow::btree_cost();
+ }
+ double strict_cost(const IDirectPostingStore::LookupResult &term) const noexcept {
+ double rel_est = abs_to_rel_est(term.posting_size, docid_limit);
+ return search::queryeval::flow::btree_strict_cost(rel_est);
+ }
+ };
+ double est = OrFlow::estimate_of(MyAdapter(docid_limit), _terms);
+ // Iterator benchmarking has shown that non-strict cost is different for attributes
+ // that support using a reverse hash filter (see use_hash_filter()).
+ // Program used: searchlib/src/tests/queryeval/iterator_benchmark
+ // Tests: analyze_and_with_filter_vs_in(), analyze_and_with_filter_vs_in_array()
+ double non_strict_cost = (SearchType::supports_hash_filter && !_iattr.hasMultiValue())
+ ? queryeval::flow::reverse_hash_lookup()
+ : OrFlow::cost_of(MyAdapter(docid_limit), _terms, false);
+ return {est, non_strict_cost, OrFlow::cost_of(MyAdapter(docid_limit), _terms, true) + queryeval::flow::heap_cost(est, _terms.size())};
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/bitcompression/compression.cpp b/searchlib/src/vespa/searchlib/bitcompression/compression.cpp
index 0f089c60e4b..f3fc31ac8b1 100644
--- a/searchlib/src/vespa/searchlib/bitcompression/compression.cpp
+++ b/searchlib/src/vespa/searchlib/bitcompression/compression.cpp
@@ -359,6 +359,24 @@ getParams(PostingListParams &params) const
params.clear();
}
+template <bool bigEndian>
+void
+FeatureEncodeContext<bigEndian>::pad_for_memory_map_and_flush()
+{
+ // Write some pad bits to avoid decompression readahead going past
+ // memory mapped file during search and into SIGSEGV territory.
+
+ // First pad to 64 bits alignment.
+ this->smallAlign(64);
+ writeComprBufferIfNeeded();
+
+ // Then write 128 more bits. This allows for 64-bit decoding
+ // with a readbits that always leaves a nonzero preRead
+ padBits(128);
+ this->alignDirectIO();
+ this->flush();
+ writeComprBuffer(); // Also flushes slack
+}
template <bool bigEndian>
void
diff --git a/searchlib/src/vespa/searchlib/bitcompression/compression.h b/searchlib/src/vespa/searchlib/bitcompression/compression.h
index 232572a6314..4124f1f659f 100644
--- a/searchlib/src/vespa/searchlib/bitcompression/compression.h
+++ b/searchlib/src/vespa/searchlib/bitcompression/compression.h
@@ -1167,7 +1167,7 @@ public:
* Get remaining units in buffer (e.g. _realValE - _valI)
*/
- int32_t remainingUnits() const override { return _realValE - _valI; }
+ int64_t remainingUnits() const override { return _realValE - _valI; }
/**
* Get unit ptr (e.g. _valI) from decode context.
@@ -1181,7 +1181,7 @@ public:
}
uint64_t getBitPos(int bitOffset, uint64_t bufferEndFilePos) const override {
- int intOffset = _realValE - _valI;
+ int64_t intOffset = _realValE - _valI;
if (bitOffset == -1) {
bitOffset = -64 - _preRead;
}
@@ -1200,7 +1200,7 @@ public:
uint64_t getBitPosV() const override { return getReadOffset(); }
- void adjUnitPtr(int newRemainingUnits) override {
+ void adjUnitPtr(int64_t newRemainingUnits) override {
_valI = _realValE - newRemainingUnits;
}
@@ -1219,7 +1219,7 @@ public:
* @param unitCount Number of bytes in buffer
* @param moreData Set if there is more data available
*/
- void setEnd(unsigned int unitCount, bool moreData) {
+ void setEnd(uint64_t unitCount, bool moreData) {
_valE = _realValE = _valI + unitCount;
if (moreData) {
_valE -= END_BUFFER_SAFETY;
@@ -1595,6 +1595,8 @@ public:
writeComprBufferIfNeeded();
}
+ void pad_for_memory_map_and_flush();
+
virtual void readHeader(const vespalib::GenericHeader &header, const vespalib::string &prefix);
virtual void writeHeader(vespalib::GenericHeader &header, const vespalib::string &prefix) const;
virtual const vespalib::string &getIdentifier() const;
diff --git a/searchlib/src/vespa/searchlib/diskindex/pagedict4file.cpp b/searchlib/src/vespa/searchlib/diskindex/pagedict4file.cpp
index 387d95bce66..bceeb1e7bc1 100644
--- a/searchlib/src/vespa/searchlib/diskindex/pagedict4file.cpp
+++ b/searchlib/src/vespa/searchlib/diskindex/pagedict4file.cpp
@@ -269,11 +269,9 @@ PageDict4FileSeqWrite::DictFileContext::DictFileContext(bool extended, vespalib:
}
bool
-PageDict4FileSeqWrite::DictFileContext::DictFileContext::close() {
- //uint64_t usedPBits = _ec.getWriteOffset();
- _ec.flush();
- _writeContext.writeComprBuffer(true);
-
+PageDict4FileSeqWrite::DictFileContext::DictFileContext::close()
+{
+ _ec.pad_for_memory_map_and_flush();
_writeContext.dropComprBuf();
bool success = _file.Sync();
success &= _file.Close();
diff --git a/searchlib/src/vespa/searchlib/diskindex/zc4_posting_writer.cpp b/searchlib/src/vespa/searchlib/diskindex/zc4_posting_writer.cpp
index c7480633e21..f2b7911ba55 100644
--- a/searchlib/src/vespa/searchlib/diskindex/zc4_posting_writer.cpp
+++ b/searchlib/src/vespa/searchlib/diskindex/zc4_posting_writer.cpp
@@ -247,19 +247,7 @@ template <bool bigEndian>
void
Zc4PostingWriter<bigEndian>::on_close()
{
- // Write some pad bits to avoid decompression readahead going past
- // memory mapped file during search and into SIGSEGV territory.
-
- // First pad to 64 bits alignment.
- _encode_context.smallAlign(64);
- _encode_context.writeComprBufferIfNeeded();
-
- // Then write 128 more bits. This allows for 64-bit decoding
- // with a readbits that always leaves a nonzero preRead
- _encode_context.padBits(128);
- _encode_context.alignDirectIO();
- _encode_context.flush();
- _encode_context.writeComprBuffer(); // Also flushes slack
+ _encode_context.pad_for_memory_map_and_flush();
}
template class Zc4PostingWriter<false>;
diff --git a/searchlib/src/vespa/searchlib/queryeval/flow_tuning.h b/searchlib/src/vespa/searchlib/queryeval/flow_tuning.h
index 51e544b2e30..356ecd4c992 100644
--- a/searchlib/src/vespa/searchlib/queryeval/flow_tuning.h
+++ b/searchlib/src/vespa/searchlib/queryeval/flow_tuning.h
@@ -6,8 +6,19 @@
namespace search::queryeval::flow {
+/**
+ * This function is used when calculating the strict cost of
+ * intermediate and complex leaf blueprints that use a heap for their strict iterator implementation.
+ *
+ * Iterator benchmarking has shown the need to increase the strict cost
+ * of complex blueprints, to avoid that they are forced strict too early.
+ * The 5.0 multiplier reflects this.
+ *
+ * Program used: searchlib/src/tests/queryeval/iterator_benchmark
+ * Tests used: analyze_and_with_filter_vs_*
+ */
inline double heap_cost(double my_est, size_t num_children) {
- return my_est * std::log2(std::max(size_t(1),num_children));
+ return 5.0 * my_est * std::log2(std::max(size_t(1), num_children));
}
/**
@@ -30,6 +41,11 @@ inline double lookup_cost(size_t num_indirections) {
return 1.0 + (num_indirections * 1.0);
}
+// Non-strict cost of reverse lookup into a hash table (containing terms from a multi-term operator).
+inline double reverse_hash_lookup() {
+ return 5.0;
+}
+
// Strict cost of lookup based matching in an attribute (not fast-search).
inline double lookup_strict_cost(size_t num_indirections) {
return lookup_cost(num_indirections);
@@ -45,6 +61,17 @@ inline double btree_strict_cost(double my_est) {
return my_est;
}
+// Non-strict cost of matching in a bitvector.
+inline double bitvector_cost() {
+ return 1.0;
+}
+
+// Strict cost of matching in a bitvector.
+// Test used: IteratorBenchmark::analyze_btree_vs_bitvector_iterators_strict
+inline double bitvector_strict_cost(double my_est) {
+ return 1.5 * my_est;
+}
+
// Non-strict cost of matching in a disk index posting list.
inline double disk_index_cost() {
return 1.5;
diff --git a/searchlib/src/vespa/searchlib/util/comprfile.cpp b/searchlib/src/vespa/searchlib/util/comprfile.cpp
index ff74dc7a0e0..db8fe14d658 100644
--- a/searchlib/src/vespa/searchlib/util/comprfile.cpp
+++ b/searchlib/src/vespa/searchlib/util/comprfile.cpp
@@ -23,14 +23,17 @@ ComprFileReadBase::ReadComprBuffer(uint64_t stopOffset,
bool isretryread = false;
retry:
- if (decodeContext.lastChunk())
+ if (decodeContext.lastChunk()) {
return; // Already reached end of file.
- int remainingUnits = decodeContext.remainingUnits();
+ }
+ int64_t remainingUnits = decodeContext.remainingUnits();
+ assert(remainingUnits >= 0);
// There's a good amount of data here already.
if (remainingUnits >
- static_cast<ssize_t>(ComprBuffer::minimumPadding())) //FIX! Tune
+ static_cast<ssize_t>(ComprBuffer::minimumPadding())) { //FIX! Tune
return;
+ }
// Assert that file read offset is aligned on unit boundary
assert((static_cast<size_t>(fileReadByteOffset) &
@@ -47,9 +50,9 @@ ComprFileReadBase::ReadComprBuffer(uint64_t stopOffset,
// Continuation reads starts at aligned boundary.
assert(remainingUnits == 0 || padBeforeUnits == 0);
- if (readAll)
+ if (readAll) {
stopOffset = fileSize << 3;
- else if (!isretryread) {
+ } else if (!isretryread) {
stopOffset += 8 * cbuf.getUnitBitSize(); // XXX: Magic integer
// Realign stop offset to direct IO alignment boundary
uint64_t fileDirectIOBitAlign =
@@ -93,20 +96,19 @@ ComprFileReadBase::ReadComprBuffer(uint64_t stopOffset,
fileReadByteOffset -= padBeforeUnits * cbuf.getUnitSize();
file.SetPosition(fileReadByteOffset);
}
- int readUnits0 = 0;
- if (readBits > 0)
- readUnits0 = static_cast<int>((readBits + cbuf.getUnitBitSize() - 1) /
- cbuf.getUnitBitSize());
+ size_t readUnits0 = 0;
+ if (readBits > 0) {
+ readUnits0 = (readBits + cbuf.getUnitBitSize() - 1) / cbuf.getUnitBitSize();
+ }
// Try to align end of read to an alignment boundary
- int readUnits = cbuf.getAligner().adjustElements(fileReadByteOffset /
- cbuf.getUnitSize(), readUnits0);
- if (readUnits < readUnits0)
+ size_t readUnits = cbuf.getAligner().adjustElements(fileReadByteOffset / cbuf.getUnitSize(), readUnits0);
+ if (readUnits < readUnits0) {
isMore = true;
+ }
if (readUnits > 0) {
- int64_t padBytes = fileReadByteOffset +
- static_cast<int64_t>(readUnits) * cbuf.getUnitSize() - fileSize;
+ int64_t padBytes = fileReadByteOffset + readUnits * cbuf.getUnitSize() - fileSize;
if (!isMore && padBytes > 0) {
// Pad reading of file written with smaller unit size with
// NUL bytes.
@@ -115,17 +117,18 @@ ComprFileReadBase::ReadComprBuffer(uint64_t stopOffset,
readUnits * cbuf.getUnitSize() - padBytes,
0,
padBytes);
- } else
+ } else {
file.ReadBuf(cbuf.getComprBuf(), readUnits * cbuf.getUnitSize());
+ }
}
// If at end of file then add units of zero bits as padding
- if (!isMore)
+ if (!isMore) {
memset(reinterpret_cast<char *>(cbuf.getComprBuf()) +
readUnits * cbuf.getUnitSize(),
0,
cbuf.getUnitSize() * ComprBuffer::minimumPadding());
+ }
- assert(remainingUnits + readUnits >= 0);
decodeContext.afterRead(reinterpret_cast<char *>(cbuf.getComprBuf()) +
(padBeforeUnits - remainingUnits) *
static_cast<int32_t>(cbuf.getUnitSize()),
@@ -343,7 +346,7 @@ ComprFileReadContext::setPosition(uint64_t newPosition)
}
void
-ComprFileReadContext::allocComprBuf(unsigned int comprBufSize, size_t preferredFileAlignment)
+ComprFileReadContext::allocComprBuf(size_t comprBufSize, size_t preferredFileAlignment)
{
ComprBuffer::allocComprBuf(comprBufSize, preferredFileAlignment, _file, true);
}
diff --git a/searchlib/src/vespa/searchlib/util/comprfile.h b/searchlib/src/vespa/searchlib/util/comprfile.h
index 8f8cffaffd6..6a0b72ce7e5 100644
--- a/searchlib/src/vespa/searchlib/util/comprfile.h
+++ b/searchlib/src/vespa/searchlib/util/comprfile.h
@@ -32,7 +32,7 @@ public:
* Get remaining units in buffer (e.g. _realValE - _valI)
*/
- virtual int32_t remainingUnits() const = 0;
+ virtual int64_t remainingUnits() const = 0;
/**
* Get unit ptr (e.g. _valI) from decode context.
@@ -51,7 +51,7 @@ public:
virtual uint64_t getBitPos(int bitOffset, uint64_t bufferEndFilePos) const = 0;
virtual uint64_t getBitPosV() const = 0;
virtual void skipBits(int bits) = 0;
- virtual void adjUnitPtr(int newRemainingUnits) = 0;
+ virtual void adjUnitPtr(int64_t newRemainingUnits) = 0;
virtual void emptyBuffer(uint64_t newBitPosition) = 0;
/**
@@ -105,7 +105,7 @@ public:
void readComprBuffer(uint64_t stopOffset, bool readAll);
void readComprBuffer();
void setPosition(uint64_t newPosition);
- void allocComprBuf(unsigned int comprBufSize, size_t preferredFileAlignment);
+ void allocComprBuf(size_t comprBufSize, size_t preferredFileAlignment);
void setDecodeContext(ComprFileDecodeContext *decodeContext) { _decodeContext = decodeContext; }
ComprFileDecodeContext *getDecodeContext() const { return _decodeContext; }
void setFile(FastOS_FileInterface *file) { _file = file; }
diff --git a/standalone-container/pom.xml b/standalone-container/pom.xml
index 844b912543c..92faa1ae670 100644
--- a/standalone-container/pom.xml
+++ b/standalone-container/pom.xml
@@ -113,6 +113,7 @@
model-evaluation-jar-with-dependencies.jar,
model-integration-jar-with-dependencies.jar,
container-onnxruntime.jar,
+ container-llama.jar,
<!-- END config-model dependencies -->
</discPreInstallBundle>
</configuration>