summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--clustercontroller-core/pom.xml9
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java3
-rw-r--r--clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java7
-rw-r--r--clustercontroller-reindexer/src/main/java/ai/vespa/reindexing/ReindexingMetrics.java3
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RedundancyIncreaseValidator.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/first/RedundancyValidator.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomSearchTuningBuilder.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/Container.java1
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java19
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java15
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/StorageGroup.java18
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java5
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java34
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorageCluster.java25
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java3
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java4
-rw-r--r--config-model/src/main/protobuf/onnx.proto517
-rw-r--r--config-model/src/main/resources/schema/content.rnc3
-rw-r--r--config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java36
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomSchemaTuningBuilderTest.java6
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java119
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java10
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java2
-rw-r--r--container-core/abi-spec.json2
-rw-r--r--container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java5
-rw-r--r--container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java37
-rw-r--r--container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def3
-rw-r--r--container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerTest.java14
-rw-r--r--container-dependencies-enforcer/pom.xml117
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java45
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java3
-rw-r--r--container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java34
-rw-r--r--container-search/abi-spec.json3
-rw-r--r--container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java3
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/Ranking.java10
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java3
-rw-r--r--container-search/src/main/java/com/yahoo/search/searchers/ContainerLatencySearcher.java3
-rwxr-xr-xcontainer-search/src/main/java/com/yahoo/search/searchers/RateLimitingSearcher.java3
-rw-r--r--container-search/src/test/java/com/yahoo/search/searchers/test/RateLimitingBenchmark.java3
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java6
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java10
-rw-r--r--dist/vespa.spec4
-rw-r--r--metrics/src/main/java/ai/vespa/metrics/HostedNodeAdminMetrics.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java7
-rw-r--r--model-integration/src/main/protobuf/onnx.proto517
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java4
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java7
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java9
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java1
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java2
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java10
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java77
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainer.java63
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostResumeProvisioner.java24
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java8
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java4
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java6
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java435
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java14
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java2
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/FlavorConfigBuilder.java43
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/HostProvisioner.java4
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockHostProvisioner.java23
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNameResolver.java1
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java2
-rw-r--r--node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockProvisioner.java1
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java2
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java7
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java3
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java307
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java8
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java3
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java6
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java6
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/HostCapacityTest.java2
-rw-r--r--node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java2
-rw-r--r--searchcore/src/vespa/searchcore/proton/matching/matcher.cpp8
-rw-r--r--searchlib/abi-spec.json4
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj32
-rw-r--r--searchlib/src/tests/attribute/multi_value_mapping/multi_value_mapping_test.cpp4
-rw-r--r--searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp72
-rw-r--r--searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp22
-rw-r--r--searchlib/src/tests/tensor/tensor_buffer_store/tensor_buffer_store_test.cpp2
-rw-r--r--searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp12
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h12
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp18
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp4
-rw-r--r--searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/docstore/visitcache.cpp23
-rw-r--r--searchlib/src/vespa/searchlib/fef/indexproperties.cpp8
-rw-r--r--searchlib/src/vespa/searchlib/fef/indexproperties.h1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp9
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h7
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp19
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp3
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp7
-rw-r--r--searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h4
-rw-r--r--storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp1
-rw-r--r--vespa-application-maven-plugin/src/main/java/com/yahoo/container/plugin/mojo/ApplicationMojo.java1
-rw-r--r--vespajlib/abi-spec.json36
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java93
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java89
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java66
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java54
-rw-r--r--vespalib/src/tests/datastore/array_store/array_store_test.cpp15
-rw-r--r--vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp74
-rw-r--r--vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp42
-rw-r--r--vespalib/src/tests/datastore/dynamic_array_buffer_type/dynamic_array_buffer_type_test.cpp60
-rw-r--r--vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp123
-rw-r--r--vespalib/src/vespa/vespalib/datastore/array_store.h4
-rw-r--r--vespalib/src/vespa/vespalib/datastore/array_store.hpp14
-rw-r--r--vespalib/src/vespa/vespalib/datastore/array_store_config.cpp20
-rw-r--r--vespalib/src/vespa/vespalib/datastore/array_store_config.h4
-rw-r--r--vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.h4
-rw-r--r--vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.hpp9
-rw-r--r--vespalib/src/vespa/vespalib/datastore/buffer_type.cpp11
-rw-r--r--vespalib/src/vespa/vespalib/datastore/buffer_type.h16
-rw-r--r--vespalib/src/vespa/vespalib/datastore/buffer_type.hpp4
-rw-r--r--vespalib/src/vespa/vespalib/datastore/bufferstate.cpp42
-rw-r--r--vespalib/src/vespa/vespalib/datastore/bufferstate.h16
-rw-r--r--vespalib/src/vespa/vespalib/datastore/datastorebase.cpp9
-rw-r--r--vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.h10
-rw-r--r--vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.hpp17
-rw-r--r--vespalib/src/vespa/vespalib/datastore/unique_store_string_allocator.h4
-rw-r--r--vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h2
-rw-r--r--vespalib/src/vespa/vespalib/stllike/cache.h2
-rw-r--r--vespalib/src/vespa/vespalib/stllike/cache.hpp2
-rw-r--r--vespalib/src/vespa/vespalib/test/nexus.h1
-rw-r--r--vespalib/src/vespa/vespalib/test/thread_meets.h46
-rw-r--r--vespamalloc/src/vespamalloc/malloc/threadlist.hpp13
-rw-r--r--vespamalloc/src/vespamalloc/malloc/threadpool.h4
-rw-r--r--vespamalloc/src/vespamalloc/malloc/threadpool.hpp9
135 files changed, 2984 insertions, 990 deletions
diff --git a/clustercontroller-core/pom.xml b/clustercontroller-core/pom.xml
index 647d8ca4e64..61176c178c6 100644
--- a/clustercontroller-core/pom.xml
+++ b/clustercontroller-core/pom.xml
@@ -15,14 +15,7 @@
<dependencies>
<dependency>
<groupId>com.yahoo.vespa</groupId>
- <artifactId>annotations</artifactId>
- <version>${project.version}</version>
- <scope>provided</scope>
- </dependency>
- <dependency>
- <!-- required for bundle-plugin to generate import-package statements for Java's standard library -->
- <groupId>com.yahoo.vespa</groupId>
- <artifactId>jdisc_core</artifactId>
+ <artifactId>container-dev</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java
index c823c94afd1..60671c96474 100644
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java
+++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.clustercontroller.core;
+import ai.vespa.metrics.StorageMetrics;
import com.yahoo.lang.MutableBoolean;
import com.yahoo.lang.SettableOptional;
import com.yahoo.vdslib.distribution.ConfiguredNode;
@@ -44,7 +45,7 @@ import static java.util.logging.Level.FINE;
public class NodeStateChangeChecker {
private static final Logger log = Logger.getLogger(NodeStateChangeChecker.class.getName());
- private static final String BUCKETS_METRIC_NAME = "vds.datastored.bucket_space.buckets_total";
+ private static final String BUCKETS_METRIC_NAME = StorageMetrics.VDS_DATASTORED_BUCKET_SPACE_BUCKETS_TOTAL.baseName();
private static final Map<String, String> BUCKETS_METRIC_DIMENSIONS = Map.of("bucketSpace", "default");
private final int requiredRedundancy;
diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java
index 65c08e67850..09aac786b2f 100644
--- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java
+++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.clustercontroller.core.restapiv2.requests;
+import ai.vespa.metrics.StorageMetrics;
import com.yahoo.vespa.clustercontroller.core.NodeInfo;
import com.yahoo.vespa.clustercontroller.core.RemoteClusterControllerTask;
import com.yahoo.vespa.clustercontroller.core.hostinfo.Metrics;
@@ -46,17 +47,17 @@ public class NodeStateRequest extends Request<Response.NodeResponse> {
}
private static void fillInMetricValue(String name, Metrics.Value value, Response.NodeResponse result) {
- if (name.equals("vds.datastored.alldisks.docs")) {
+ if (name.equals(StorageMetrics.VDS_DATASTORED_ALLDISKS_DOCS.baseName())) {
if (value.getLast() == null) {
return;
}
result.addMetric("unique-document-count", value.getLast());
- } else if (name.equals("vds.datastored.alldisks.bytes")) {
+ } else if (name.equals(StorageMetrics.VDS_DATASTORED_ALLDISKS_BYTES.baseName())) {
if (value.getLast() == null) {
return;
}
result.addMetric("unique-document-total-size", value.getLast());
- } else if (name.equals("vds.datastored.alldisks.buckets")) {
+ } else if (name.equals(StorageMetrics.VDS_DATASTORED_ALLDISKS_BUCKETS.baseName())) {
if (value.getLast() == null) {
return;
}
diff --git a/clustercontroller-reindexer/src/main/java/ai/vespa/reindexing/ReindexingMetrics.java b/clustercontroller-reindexer/src/main/java/ai/vespa/reindexing/ReindexingMetrics.java
index a1aa5287d2f..83be05d970e 100644
--- a/clustercontroller-reindexer/src/main/java/ai/vespa/reindexing/ReindexingMetrics.java
+++ b/clustercontroller-reindexer/src/main/java/ai/vespa/reindexing/ReindexingMetrics.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.reindexing;
+import ai.vespa.metrics.ClusterControllerMetrics;
import com.yahoo.documentapi.ProgressToken;
import com.yahoo.jdisc.Metric;
@@ -27,7 +28,7 @@ class ReindexingMetrics {
void dump(Reindexing reindexing) {
reindexing.status().forEach((type, status) -> {
- metric.set("reindexing.progress",
+ metric.set(ClusterControllerMetrics.REINDEXING_PROGRESS.baseName(),
status.progress().map(ProgressToken::percentFinished).map(percentage -> percentage * 1e-2)
.orElse(status.state() == SUCCESSFUL ? 1.0 : 0.0),
metric.createContext(Map.of("clusterid", cluster, "documenttype", type.getName())));
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RedundancyIncreaseValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RedundancyIncreaseValidator.java
index 82ad8e5d6e8..47024c1171c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RedundancyIncreaseValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RedundancyIncreaseValidator.java
@@ -35,7 +35,7 @@ public class RedundancyIncreaseValidator implements ChangeValidator {
}
private int redundancyOf(ContentCluster cluster) {
- return cluster.redundancy().finalRedundancy();
+ return cluster.getRedundancy().finalRedundancy();
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/first/RedundancyValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/first/RedundancyValidator.java
index 5228610537f..2be0f0b8422 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/first/RedundancyValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/first/RedundancyValidator.java
@@ -2,7 +2,6 @@
package com.yahoo.vespa.model.application.validation.first;
import com.yahoo.config.application.api.ValidationId;
-import com.yahoo.config.application.api.ValidationOverrides;
import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.vespa.model.VespaModel;
@@ -10,7 +9,6 @@ import com.yahoo.vespa.model.application.validation.Validator;
import com.yahoo.vespa.model.application.validation.change.ChangeValidator;
import com.yahoo.vespa.model.content.cluster.ContentCluster;
-import java.time.Instant;
import java.util.List;
import java.util.stream.Stream;
@@ -48,7 +46,7 @@ public class RedundancyValidator extends Validator implements ChangeValidator {
}
private boolean hasRedundancyOne(ContentCluster cluster) {
- return cluster != null && cluster.redundancy().finalRedundancy() == 1 && cluster.redundancy().groups() == 1;
+ return cluster != null && cluster.getRedundancy().finalRedundancy() == 1 && cluster.getRedundancy().groups() == 1;
}
private void invalidRedundancy(ContentCluster cluster, DeployState deployState) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomSearchTuningBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomSearchTuningBuilder.java
index 64592e75c41..a0a4151daf5 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomSearchTuningBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomSearchTuningBuilder.java
@@ -298,6 +298,8 @@ public class DomSearchTuningBuilder extends VespaDomBuilder.DomConfigProducerBui
for (Element e : XML.getChildren(spec)) {
if (equals("concurrency", e)) {
sn.feeding.concurrency = asDouble(e);
+ } else if (equals("niceness", e)) {
+ sn.feeding.niceness = asDouble(e);
}
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java b/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java
index 2ca6d5d7155..f7d4fe28c6e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java
@@ -109,6 +109,7 @@ public abstract class Container extends AbstractService implements
addChild(new SimpleComponent("com.yahoo.container.jdisc.ConfiguredApplication$ApplicationContext"));
appendJvmOptions(jvmOmitStackTraceInFastThrowOption(deployState.featureFlags()));
+ addEnvironmentVariable("VESPA_MALLOC_MMAP_THRESHOLD","0x200000");
}
protected String jvmOmitStackTraceInFastThrowOption(ModelContext.FeatureFlags featureFlags) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java
index ec7acaf819f..34ea41384bc 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java
@@ -270,6 +270,25 @@ public class ContentSearchCluster extends TreeConfigProducer<AnyConfigProducer>
clusters.put(sc.getClusterName(), sc);
}
+ /**
+ * Returns whether the schemas in this cluster use streaming mode.
+ *
+ * @return True if this cluster only has schemas with streaming mode, False if it only has schemas
+ * with indexing, null if it has both or none.
+ */
+ public Boolean isStreaming() {
+ boolean hasStreaming = false;
+ boolean hasIndexed = false;
+ for (var cluster : clusters.values()) {
+ if (cluster.isStreaming())
+ hasStreaming = true;
+ else
+ hasIndexed = true;
+ }
+ if (hasIndexed == hasStreaming) return null;
+ return hasStreaming;
+ }
+
public List<SearchNode> getSearchNodes() {
return hasIndexedCluster() ? getIndexed().getSearchNodes() : nonIndexed;
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java
index 4aac8bfb647..6f0a03bab60 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.content;
+import ai.vespa.metrics.DistributorMetrics;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.vespa.config.content.core.StorDistributormanagerConfig;
import com.yahoo.vespa.config.content.core.StorServerConfig;
@@ -135,13 +136,13 @@ public class DistributorCluster extends TreeConfigProducer<Distributor> implemen
@Override
public void getConfig(MetricsmanagerConfig.Builder builder) {
ContentCluster.getMetricBuilder("log", builder).
- addedmetrics("vds.distributor.docsstored").
- addedmetrics("vds.distributor.bytesstored").
- addedmetrics("vds.idealstate.delete_bucket.done_ok").
- addedmetrics("vds.idealstate.merge_bucket.done_ok").
- addedmetrics("vds.idealstate.split_bucket.done_ok").
- addedmetrics("vds.idealstate.join_bucket.done_ok").
- addedmetrics("vds.idealstate.buckets_rechecking");
+ addedmetrics(DistributorMetrics.VDS_DISTRIBUTOR_DOCSSTORED.baseName()).
+ addedmetrics(DistributorMetrics.VDS_DISTRIBUTOR_BYTESSTORED.baseName()).
+ addedmetrics(DistributorMetrics.VDS_IDEALSTATE_DELETE_BUCKET_DONE_OK.baseName()).
+ addedmetrics(DistributorMetrics.VDS_IDEALSTATE_MERGE_BUCKET_DONE_OK.baseName()).
+ addedmetrics(DistributorMetrics.VDS_IDEALSTATE_SPLIT_BUCKET_DONE_OK.baseName()).
+ addedmetrics(DistributorMetrics.VDS_IDEALSTATE_JOIN_BUCKET_DONE_OK.baseName()).
+ addedmetrics(DistributorMetrics.VDS_IDEALSTATE_BUCKETS_RECHECKING.baseName());
}
@Override
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/StorageGroup.java b/config-model/src/main/java/com/yahoo/vespa/model/content/StorageGroup.java
index 52b2ce06dfe..6078215f9b6 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/StorageGroup.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/StorageGroup.java
@@ -171,12 +171,11 @@ public class StorageGroup {
}
@Override
- public boolean equals(Object obj) {
- if (obj instanceof StorageGroup) {
- StorageGroup rhs = (StorageGroup)obj;
- return this.index.equals(rhs.index) &&
- this.name.equals(rhs.name) &&
- this.partitions.equals(rhs.partitions);
+ public boolean equals(Object o) {
+ if (o instanceof StorageGroup other) {
+ return this.index.equals(other.index) &&
+ this.name.equals(other.name) &&
+ this.partitions.equals(other.partitions);
}
return false;
}
@@ -208,9 +207,7 @@ public class StorageGroup {
this.context = context;
}
- public StorageGroup buildRootGroup(DeployState deployState,
- RedundancyBuilder redundancyBuilder,
- ContentCluster owner) {
+ public StorageGroup buildRootGroup(DeployState deployState, ContentCluster owner, Boolean isStreaming) {
try {
if (owner.isHosted())
validateRedundancyAndGroups(deployState.zone().environment());
@@ -229,7 +226,8 @@ public class StorageGroup {
? groupBuilder.buildHosted(deployState, owner, Optional.empty(), context)
: groupBuilder.buildNonHosted(deployState, owner, Optional.empty());
- Redundancy redundancy = redundancyBuilder.build(owner.isHosted(), storageGroup.subgroups.size(),
+ RedundancyBuilder redundancyBuilder = new RedundancyBuilder(clusterElement);
+ Redundancy redundancy = redundancyBuilder.build(owner.isHosted(), isStreaming, storageGroup.subgroups.size(),
storageGroup.getNumberOfLeafGroups(), storageGroup.countNodes(false));
owner.setRedundancy(redundancy);
if (storageGroup.partitions.isEmpty() && (redundancy.groups() > 1)) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java
index 2592beca6c6..f792ac3a591 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java
@@ -114,7 +114,6 @@ public class ContentCluster extends TreeConfigProducer<AnyConfigProducer> implem
new SearchDefinitionBuilder().build(deployState.getDocumentModel().getDocumentManager(), documentsElement);
String routingSelection = new DocumentSelectionBuilder().build(documentsElement);
- RedundancyBuilder redundancyBuilder = new RedundancyBuilder(contentElement);
Set<NewDocumentType> globallyDistributedDocuments = new GlobalDistributionBuilder(documentDefinitions).build(documentsElement);
String clusterId = getClusterId(contentElement);
@@ -133,7 +132,7 @@ public class ContentCluster extends TreeConfigProducer<AnyConfigProducer> implem
c.persistenceFactory = new EngineFactoryBuilder().build(contentElement, c);
c.storageNodes = new StorageCluster.Builder().build(deployState, c, w3cContentElement);
c.distributorNodes = new DistributorCluster.Builder(c).build(deployState, c, w3cContentElement);
- c.rootGroup = new StorageGroup.Builder(contentElement, context).buildRootGroup(deployState, redundancyBuilder, c);
+ c.rootGroup = new StorageGroup.Builder(contentElement, context).buildRootGroup(deployState, c, c.search.isStreaming());
c.clusterControllerConfig = createClusterControllerConfig(contentElement, deployState, c, resourceLimits);
validateThatGroupSiblingsAreUnique(c.clusterId, c.rootGroup);
c.search.handleRedundancy(c.redundancy);
@@ -447,7 +446,7 @@ public class ContentCluster extends TreeConfigProducer<AnyConfigProducer> implem
public final ContentSearchCluster getSearch() { return search; }
- public Redundancy redundancy() { return redundancy; }
+ public Redundancy getRedundancy() { return redundancy; }
public ContentCluster setRedundancy(Redundancy redundancy) {
this.redundancy = redundancy;
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java
index e7bafdf52e4..d310db067a6 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java
@@ -20,7 +20,7 @@ public class RedundancyBuilder {
// Always global (across groups)
private Integer globalMinRedundancy = null;
- RedundancyBuilder(ModelElement clusterXml) {
+ public RedundancyBuilder(ModelElement clusterXml) {
ModelElement redundancyElement = clusterXml.child("redundancy");
if (redundancyElement != null) {
initialRedundancy = redundancyElement.integerAttribute("reply-after");
@@ -47,22 +47,40 @@ public class RedundancyBuilder {
throw new IllegalArgumentException("Either <redundancy> or <min-redundancy> must be set");
}
- public Redundancy build(boolean isHosted, int subGroups, int leafGroups, int totalNodes) {
+ /**
+ * @param isHosted
+ * @param isStreaming true if this cluster only has schemas with streaming mode, false if it only has schemas
+ * without streaming, null if it has both
+ * @param subGroups
+ * @param leafGroups
+ * @param totalNodes
+ * @return
+ */
+ public Redundancy build(boolean isHosted, Boolean isStreaming, int subGroups, int leafGroups, int totalNodes) {
if (isHosted) {
if (globalMinRedundancy != null && ( finalRedundancy == null || finalRedundancy * leafGroups < globalMinRedundancy ))
initialRedundancy = finalRedundancy = (int)Math.ceil((double)globalMinRedundancy / leafGroups);
if (readyCopies == null) {
- if (leafGroups > 1)
- readyCopies = 1;
- else
- readyCopies = finalRedundancy > 1 ? 2 : 1;
+ if (isStreaming == Boolean.TRUE) {
+ readyCopies = finalRedundancy;
+ }
+ else { // If isStreaming is null (mixed mode cluster) there are no good options ...
+ if (leafGroups > 1)
+ readyCopies = 1;
+ else
+ readyCopies = finalRedundancy > 1 ? 2 : 1;
+ }
}
return new Redundancy(initialRedundancy, finalRedundancy, readyCopies, leafGroups, totalNodes);
} else {
if (globalMinRedundancy != null && ( finalRedundancy == null || finalRedundancy < globalMinRedundancy))
initialRedundancy = finalRedundancy = globalMinRedundancy;
- if (readyCopies == null)
- readyCopies = finalRedundancy > 1 ? Math.max(subGroups, 2) : 1;
+ if (readyCopies == null) {
+ if (isStreaming == Boolean.TRUE)
+ readyCopies = finalRedundancy;
+ else // If isStreaming is null (mixed mode cluster) there are no good options ...
+ readyCopies = finalRedundancy > 1 ? Math.max(subGroups, 2) : 1;
+ }
subGroups = Math.max(1, subGroups);
IndexedHierarchicDistributionValidator.validateThatLeafGroupsCountIsAFactorOfRedundancy(finalRedundancy, subGroups);
IndexedHierarchicDistributionValidator.validateThatReadyCopiesIsCompatibleWithRedundancy(finalRedundancy, readyCopies, subGroups);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorageCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorageCluster.java
index 2d67a344a17..a1e809098f2 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorageCluster.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorageCluster.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.content.storagecluster;
+import ai.vespa.metrics.StorageMetrics;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.vespa.config.content.core.StorIntegritycheckerConfig;
import com.yahoo.vespa.config.content.core.StorBucketmoverConfig;
@@ -75,24 +76,24 @@ public class StorageCluster extends TreeConfigProducer<StorageNode>
@Override
public void getConfig(MetricsmanagerConfig.Builder builder) {
ContentCluster.getMetricBuilder("fleetcontroller", builder).
- addedmetrics("vds.datastored.alldisks.docs").
- addedmetrics("vds.datastored.alldisks.bytes").
- addedmetrics("vds.datastored.alldisks.buckets").
- addedmetrics("vds.datastored.bucket_space.buckets_total");
+ addedmetrics(StorageMetrics.VDS_DATASTORED_ALLDISKS_DOCS.baseName()).
+ addedmetrics(StorageMetrics.VDS_DATASTORED_ALLDISKS_BYTES.baseName()).
+ addedmetrics(StorageMetrics.VDS_DATASTORED_ALLDISKS_BUCKETS.baseName()).
+ addedmetrics(StorageMetrics.VDS_DATASTORED_BUCKET_SPACE_BUCKETS_TOTAL.baseName());
ContentCluster.getMetricBuilder("log", builder).
addedmetrics("vds.filestor.allthreads.put").
addedmetrics("vds.filestor.allthreads.get").
addedmetrics("vds.filestor.allthreads.remove").
addedmetrics("vds.filestor.allthreads.update").
- addedmetrics("vds.datastored.alldisks.docs").
- addedmetrics("vds.datastored.alldisks.bytes").
- addedmetrics("vds.filestor.queuesize").
- addedmetrics("vds.filestor.averagequeuewait").
- addedmetrics("vds.visitor.cv_queuewaittime").
- addedmetrics("vds.visitor.allthreads.averagequeuewait").
- addedmetrics("vds.visitor.allthreads.averagevisitorlifetime").
- addedmetrics("vds.visitor.allthreads.created");
+ addedmetrics(StorageMetrics.VDS_DATASTORED_ALLDISKS_DOCS.baseName()).
+ addedmetrics(StorageMetrics.VDS_DATASTORED_ALLDISKS_BYTES.baseName()).
+ addedmetrics(StorageMetrics.VDS_FILESTOR_QUEUESIZE.baseName()).
+ addedmetrics(StorageMetrics.VDS_FILESTOR_AVERAGEQUEUEWAIT.baseName()).
+ addedmetrics(StorageMetrics.VDS_VISITOR_CV_QUEUEWAITTIME.baseName()).
+ addedmetrics(StorageMetrics.VDS_VISITOR_ALLTHREADS_AVERAGEQUEUEWAIT.baseName()).
+ addedmetrics(StorageMetrics.VDS_VISITOR_ALLTHREADS_AVERAGEVISITORLIFETIME.baseName()).
+ addedmetrics(StorageMetrics.VDS_VISITOR_ALLTHREADS_CREATED.baseName());
}
public String getClusterName() {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
index 1984ceadac6..8edd446b209 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -274,7 +274,8 @@ public class OnnxModelInfo {
static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {
g.writeStartObject();
g.writeStringField("name", valueInfo.getName());
- g.writeStringField("type", onnxValueTypeToString(valueInfo.getType().getTensorType().getElemType()));
+ var elemType = Onnx.TensorProto.DataType.forNumber(valueInfo.getType().getTensorType().getElemType());
+ g.writeStringField("type", onnxValueTypeToString(elemType));
g.writeArrayFieldStart("dim");
for (Onnx.TensorShapeProto.Dimension dim : valueInfo.getType().getTensorType().getShape().getDimList()) {
g.writeStartObject();
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java b/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java
index 93e3a6e7a19..83eccc8697c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java
@@ -371,12 +371,16 @@ public class Tuning extends AnyConfigProducer implements ProtonConfig.Producer {
public static class Feeding implements ProtonConfig.Producer {
public Double concurrency = null;
+ public Double niceness = null;
@Override
public void getConfig(ProtonConfig.Builder builder) {
if (concurrency != null) {
builder.feeding.concurrency(concurrency);
}
+ if (niceness != null) {
+ builder.feeding.niceness(niceness);
+ }
}
}
diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto
index dc6542867e0..1d265ae9f28 100644
--- a/config-model/src/main/protobuf/onnx.proto
+++ b/config-model/src/main/protobuf/onnx.proto
@@ -3,8 +3,8 @@
//
-// Copyright (c) Facebook Inc. and Microsoft Corporation.
-// Licensed under the MIT license.
+// SPDX-License-Identifier: Apache-2.0
+
syntax = "proto2";
@@ -20,23 +20,16 @@ package onnx;
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
-// Intermediate Representation, or 'IR' for short.
+// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
-// Release
-//
-// We are still in the very early stage of defining ONNX. The current
-// version of ONNX is a starting point. While we are actively working
-// towards a complete spec, we would like to get the community involved
-// by sharing our working version of ONNX.
-//
// Protobuf compatibility
-//
-// To simplify framework compatibility, ONNX is defined using the subset of protobuf
+//
+// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
@@ -60,22 +53,60 @@ enum Version {
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
- // control. We should use version as
- // xx(major) - xx(minor) - xxxx(bugfix)
- // and we are starting with 0x00000001 (0.0.1), which was the
- // version we published on Oct 10, 2017.
- IR_VERSION_2017_10_10 = 0x00000001;
+ // control.
+ // For the IR, we are using simple numbers starting with 0x00000001,
+ // which was the version we published on Oct 10, 2017.
+ IR_VERSION_2017_10_10 = 0x0000000000000001;
- // IR_VERSION 0.0.2 published on Oct 30, 2017
+ // IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
- IR_VERSION_2017_10_30 = 0x00000002;
+ IR_VERSION_2017_10_30 = 0x0000000000000002;
- // IR VERSION 0.0.3 published on Nov 3, 2017
+ // IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
- IR_VERSION = 0x00000003;
+ IR_VERSION_2017_11_3 = 0x0000000000000003;
+
+ // IR VERSION 4 published on Jan 22, 2019
+ // - Relax constraint that initializers should be a subset of graph inputs
+ // - Add type BFLOAT16
+ IR_VERSION_2019_1_22 = 0x0000000000000004;
+
+ // IR VERSION 5 published on March 18, 2019
+ // - Add message TensorAnnotation.
+ // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
+ IR_VERSION_2019_3_18 = 0x0000000000000005;
+
+ // IR VERSION 6 published on Sep 19, 2019
+ // - Add support for sparse tensor constants stored in model.
+ // - Add message SparseTensorProto
+ // - Add sparse initializers
+ IR_VERSION_2019_9_19 = 0x0000000000000006;
+
+ // IR VERSION 7 published on May 8, 2020
+ // - Add support to allow function body graph to rely on multiple external operator sets.
+ // - Add a list to promote inference graph's initializers to global and
+ // mutable variables. Global variables are visible in all graphs of the
+ // stored models.
+ // - Add message TrainingInfoProto to store initialization
+ // method and training algorithm. The execution of TrainingInfoProto
+ // can modify the values of mutable variables.
+ // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
+ IR_VERSION_2020_5_8 = 0x0000000000000007;
+
+ // IR VERSION 8 published on July 30, 2021
+ // Introduce TypeProto.SparseTensor
+ // Introduce TypeProto.Optional
+ // Added a list of FunctionProtos local to the model
+ // Deprecated since_version and operator status from FunctionProto
+ IR_VERSION_2021_7_30 = 0x0000000000000008;
+
+ // IR VERSION 9 published on May 5, 2023
+ // Added AttributeProto to FunctionProto so that default attribute values can be set.
+ // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
+ IR_VERSION = 0x0000000000000009;
}
// Attributes
@@ -95,17 +126,21 @@ message AttributeProto {
STRING = 3;
TENSOR = 4;
GRAPH = 5;
+ SPARSE_TENSOR = 11;
+ TYPE_PROTO = 13;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
+ SPARSE_TENSORS = 12;
+ TYPE_PROTOS = 14;
}
// The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute
-
+
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
@@ -117,10 +152,10 @@ message AttributeProto {
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
- // implementations needed to use has_field hueristics to determine
+ // implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
- // change was made to accomodate proto3 implementations.
+ // change was made to accommodate proto3 implementations.
optional AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
@@ -129,14 +164,18 @@ message AttributeProto {
optional bytes s = 4; // UTF-8 string
optional TensorProto t = 5; // tensor value
optional GraphProto g = 6; // graph
+ optional SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
+ optional TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
+ repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
+ repeated TypeProto type_protos = 15;// list of type protos
}
// Defines information on value, including the name, the type, and
@@ -144,7 +183,8 @@ message AttributeProto {
message ValueInfoProto {
// This field MUST be present in this version of the IR.
optional string name = 1; // namespace Value
- // This field MUST be present in this version of the IR.
+ // This field MUST be present in this version of the IR for
+ // inputs and outputs of the top-level graph.
optional TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
optional string doc_string = 3;
@@ -155,7 +195,7 @@ message ValueInfoProto {
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
-// For example, it can be a node of type "Conv" that takes in an image, a filter
+// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
@@ -177,12 +217,130 @@ message NodeProto {
optional string doc_string = 6;
}
+// Training information
+// TrainingInfoProto stores information for training a model.
+// In particular, this defines two functionalities: an initialization-step
+// and a training-algorithm-step. Initialization resets the model
+// back to its original state as if no training has been performed.
+// Training algorithm improves the model based on input data.
+//
+// The semantics of the initialization-step is that the initializers
+// in ModelProto.graph and in TrainingInfoProto.algorithm are first
+// initialized as specified by the initializers in the graph, and then
+// updated by the "initialization_binding" in every instance in
+// ModelProto.training_info.
+//
+// The field "algorithm" defines a computation graph which represents a
+// training algorithm's step. After the execution of a
+// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
+// may be immediately updated. If the targeted training algorithm contains
+// consecutive update steps (such as block coordinate descent methods),
+// the user needs to create a TrainingInfoProto for each step.
+message TrainingInfoProto {
+ // This field describes a graph to compute the initial tensors
+ // upon starting the training process. Initialization graph has no input
+ // and can have multiple outputs. Usually, trainable tensors in neural
+ // networks are randomly initialized. To achieve that, for each tensor,
+ // the user can put a random number operator such as RandomNormal or
+ // RandomUniform in TrainingInfoProto.initialization.node and assign its
+ // random output to the specific tensor using "initialization_binding".
+ // This graph can also set the initializers in "algorithm" in the same
+ // TrainingInfoProto; a use case is resetting the number of training
+ // iteration to zero.
+ //
+ // By default, this field is an empty graph and its evaluation does not
+ // produce any output. Thus, no initializer would be changed by default.
+ optional GraphProto initialization = 1;
+
+ // This field represents a training algorithm step. Given required inputs,
+ // it computes outputs to update initializers in its own or inference graph's
+ // initializer lists. In general, this field contains loss node, gradient node,
+ // optimizer node, increment of iteration count.
+ //
+ // An execution of the training algorithm step is performed by executing the
+ // graph obtained by combining the inference graph (namely "ModelProto.graph")
+ // and the "algorithm" graph. That is, the actual
+ // input/initializer/output/node/value_info/sparse_initializer list of
+ // the training graph is the concatenation of
+ // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
+ // and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
+ // in that order. This combined graph must satisfy the normal ONNX conditions.
+ // Now, let's provide a visualization of graph combination for clarity.
+ // Let the inference graph (i.e., "ModelProto.graph") be
+ // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
+ // and the "algorithm" graph be
+ // tensor_d -> Add -> tensor_e
+ // The combination process results
+ // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
+ //
+ // Notice that an input of a node in the "algorithm" graph may reference the
+ // output of a node in the inference graph (but not the other way round). Also, inference
+ // node cannot reference inputs of "algorithm". With these restrictions, inference graph
+ // can always be run independently without training information.
+ //
+ // By default, this field is an empty graph and its evaluation does not
+ // produce any output. Evaluating the default training step never
+ // updates any initializers.
+ optional GraphProto algorithm = 2;
+
+ // This field specifies the bindings from the outputs of "initialization" to
+ // some initializers in "ModelProto.graph.initializer" and
+ // the "algorithm.initializer" in the same TrainingInfoProto.
+ // See "update_binding" below for details.
+ //
+ // By default, this field is empty and no initializer would be changed
+ // by the execution of "initialization".
+ repeated StringStringEntryProto initialization_binding = 3;
+
+ // Gradient-based training is usually an iterative procedure. In one gradient
+ // descent iteration, we apply
+ //
+ // x = x - r * g
+ //
+ // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
+ // gradient of "x" with respect to a chosen loss. To avoid adding assignments
+ // into the training graph, we split the update equation into
+ //
+ // y = x - r * g
+ // x = y
+ //
+ // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
+ // tell that "y" should be assigned to "x", the field "update_binding" may
+ // contain a key-value pair of strings, "x" (key of StringStringEntryProto)
+ // and "y" (value of StringStringEntryProto).
+ // For a neural network with multiple trainable (mutable) tensors, there can
+ // be multiple key-value pairs in "update_binding".
+ //
+ // The initializers appearing as keys in "update_binding" are considered
+ // mutable variables. This implies some behaviors
+ // as described below.
+ //
+ // 1. We have only unique keys in all "update_binding"s so that two
+ // variables may not have the same name. This ensures that one
+ // variable is assigned up to once.
+ // 2. The keys must appear in names of "ModelProto.graph.initializer" or
+ // "TrainingInfoProto.algorithm.initializer".
+ // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
+ // 4. Mutable variables are initialized to the value specified by the
+ // corresponding initializer, and then potentially updated by
+ // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
+ //
+ // This field usually contains names of trainable tensors
+ // (in ModelProto.graph), optimizer states such as momentums in advanced
+ // stochastic gradient methods (in TrainingInfoProto.graph),
+ // and number of training iterations (in TrainingInfoProto.graph).
+ //
+ // By default, this field is empty and no initializer would be changed
+ // by the execution of "algorithm".
+ repeated StringStringEntryProto update_binding = 4;
+}
+
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
-// The semantics of the model are described by the associated GraphProto.
+// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
@@ -227,18 +385,58 @@ message ModelProto {
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
+
+ // Training-specific information. Sequentially executing all stored
+ // `TrainingInfoProto.algorithm`s and assigning their outputs following
+ // the corresponding `TrainingInfoProto.update_binding`s is one training
+ // iteration. Similarly, to initialize the model
+ // (as if training hasn't happened), the user should sequentially execute
+ // all stored `TrainingInfoProto.initialization`s and assigns their outputs
+ // using `TrainingInfoProto.initialization_binding`s.
+ //
+ // If this field is empty, the training behavior of the model is undefined.
+ repeated TrainingInfoProto training_info = 20;
+
+ // A list of function protos local to the model.
+ //
+ // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
+ // In case of any conflicts the behavior (whether the model local functions are given higher priority,
+ // or standard operator sets are given higher priority or this is treated as error) is defined by
+ // the runtimes.
+ //
+ // The operator sets imported by FunctionProto should be compatible with the ones
+ // imported by ModelProto and other model local FunctionProtos.
+ // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
+ // or by 2 FunctionProtos then versions for the operator set may be different but,
+ // the operator schema returned for op_type, domain, version combination
+ // for both the versions should be same for every node in the function body.
+ //
+ // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
+ // is not allowed.
+ repeated FunctionProto functions = 25;
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
optional string key = 1;
- optional string value= 2;
+ optional string value = 2;
};
+message TensorAnnotation {
+ optional string tensor_name = 1;
+ // <key, value> pairs to annotate tensor specified by <tensor_name> above.
+ // The keys used in the mapping below must be pre-defined in ONNX spec.
+ // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
+ // quantization parameter keys.
+ repeated StringStringEntryProto quant_parameter_tensor_names = 2;
+}
+
+
+
// Graphs
//
-// A graph defines the computational logic of a model and is comprised of a parameterized
+// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
@@ -250,10 +448,14 @@ message GraphProto {
optional string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
- // Each TensorProto entry must have a distinct name (within the list) that
- // also appears in the input list.
+ // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
+ // The name MUST be unique across both initializer and sparse_initializer,
+ // but the name MAY also appear in the input list.
repeated TensorProto initializer = 5;
+ // Initializers (see above) stored in sparse format.
+ repeated SparseTensorProto sparse_initializer = 15;
+
// A human-readable documentation for this graph. Markdown is allowed.
optional string doc_string = 10;
@@ -265,13 +467,14 @@ message GraphProto {
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
- // DO NOT USE the following fields, they were deprecated from earlier versions.
- // repeated string input = 3;
- // repeated string output = 4;
- // optional int64 ir_version = 6;
- // optional int64 producer_version = 7;
- // optional string producer_tag = 8;
- // optional string domain = 9;
+ // This field carries information to indicate the mapping among a tensor and its
+ // quantization parameter tensors. For example:
+ // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
+ // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
+ repeated TensorAnnotation quantization_annotation = 14;
+
+ reserved 3, 4, 6 to 9;
+ reserved "ir_version", "producer_version", "producer_tag", "domain";
}
// Tensors
@@ -291,13 +494,32 @@ message TensorProto {
STRING = 8; // string
BOOL = 9; // bool
- // Advanced types
+ // IEEE754 half-precision floating-point format (16 bits wide).
+ // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
+
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
+
+ // Non-IEEE floating-point format based on IEEE754 single-precision
+ // floating-point number truncated to 16 bits.
+ // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+ BFLOAT16 = 16;
+
+ // Non-IEEE floating-point format based on papers
+ // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
+ // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
+ // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
+ // The computation usually happens inside a block quantize / dequantize
+ // fused by the runtime.
+ FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
+ FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
+ FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
+ FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
+
// Future extensions go here.
}
@@ -305,7 +527,8 @@ message TensorProto {
repeated int64 dims = 1;
// The data type of the tensor.
- optional DataType data_type = 2;
+ // This field MUST have a valid TensorProto.DataType value
+ optional int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
@@ -324,17 +547,17 @@ message TensorProto {
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
+ // and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
- // For int32, uint8, int8, uint16, int16, bool, and float16 values
- // float16 values must be bit-wise converted to an uint16_t prior
+ // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
+ // float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
- // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32
+ // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true];
// For strings.
@@ -371,10 +594,32 @@ message TensorProto {
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
optional bytes raw_data = 9;
+ // Data can be stored inside the protobuf file using type-specific fields or raw_data.
+ // Alternatively, raw bytes data can be stored in an external file, using the external_data field.
+ // external_data stores key-value pairs describing data location. Recognized keys are:
+ // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
+ // protobuf model was stored
+ // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
+ // Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
+ // - "length" (optional) - number of bytes containing data. Integer stored as string.
+ // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key.
+ repeated StringStringEntryProto external_data = 13;
+
+ // Location of the data for this tensor. MUST be one of:
+ // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
+ // - EXTERNAL - data stored in an external location as described by external_data field.
+ enum DataLocation {
+ DEFAULT = 0;
+ EXTERNAL = 1;
+ }
+
+ // If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
+ optional DataLocation data_location = 14;
+
// For double
- // Complex64 tensors are encoded as a single array of doubles,
+ // Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
+ // and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
@@ -386,6 +631,30 @@ message TensorProto {
repeated uint64 uint64_data = 11 [packed = true];
}
+// A serialized sparse-tensor value
+message SparseTensorProto {
+ // The sequence of non-default values are encoded as a tensor of shape [NNZ].
+ // The default-value is zero for numeric tensors, and empty-string for string tensors.
+ // values must have a non-empty name present which serves as a name for SparseTensorProto
+ // when used in sparse_initializer list.
+ optional TensorProto values = 1;
+
+ // The indices of the non-default values, which may be stored in one of two formats.
+ // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
+ // corresponding to the j-th index of the i-th value (in the values tensor).
+ // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
+ // must be the linearized-index of the i-th value (in the values tensor).
+ // The linearized-index can be converted into an index tuple (k_1,...,k_rank)
+ // using the shape provided below.
+ // The indices must appear in ascending order without duplication.
+ // In the first format, the ordering is lexicographic-ordering:
+ // e.g., index-value [1,4] must appear before [2,1]
+ optional TensorProto indices = 2;
+
+ // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
+ repeated int64 dims = 3;
+}
+
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
@@ -398,36 +667,13 @@ message TensorShapeProto {
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
+ // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
+ // for pre-defined dimension denotations.
optional string denotation = 3;
};
repeated Dimension dim = 1;
}
-// A set of pre-defined constants to be used as values for
-// the standard denotation field in TensorShapeProto.Dimension
-// for semantic description of the tensor dimension.
-message DenotationConstProto {
- // Describe a batch number dimension.
- optional string DATA_BATCH = 1 [default = "DATA_BATCH"];
- // Describe a channel dimension.
- optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"];
- // Describe a time dimension.
- optional string DATA_TIME = 3 [default = "DATA_TIME"];
- // Describe a feature dimension. This is typically a feature
- // dimension in RNN and/or spatial dimension in CNN.
- optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"];
- // Describe a filter in-channel dimension. This is the dimension
- // that is identical (in size) to the channel dimension of the input
- // image feature maps.
- optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"];
- // Describe a filter out channel dimension. This is the dimension
- // that is identical (int size) to the channel dimension of the output
- // image feature maps.
- optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"];
- // Describe a filter spatial dimension.
- optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"];
-}
-
// Types
//
// The standard ONNX data types.
@@ -435,8 +681,43 @@ message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ optional int32 elem_type = 1;
+ optional TensorShapeProto shape = 2;
+ }
+
+ // repeated T
+ message Sequence {
+ // The type and optional shape of each element of the sequence.
+ // This field MUST be present for this version of the IR.
+ optional TypeProto elem_type = 1;
+ };
+
+ // map<K,V>
+ message Map {
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
+ optional int32 key_type = 1;
+ // This field MUST be present for this version of the IR.
+ optional TypeProto value_type = 2;
+ };
+
+ // wrapper for Tensor, Sequence, or Map
+ message Optional {
+ // The type and optional shape of the element wrapped.
+ // This field MUST be present for this version of the IR.
+ // Possible values correspond to OptionalProto.DataType enum
+ optional TypeProto elem_type = 1;
+ };
+
+
+ message SparseTensor {
+ // This field MUST NOT have the value of UNDEFINED
+ // This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
- optional TensorProto.DataType elem_type = 1;
+ optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}
@@ -445,7 +726,31 @@ message TypeProto {
// The type of a tensor.
Tensor tensor_type = 1;
+ // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
+ // as input and output to graphs and nodes. These types are needed to naturally
+ // support classical ML operators. DNN operators SHOULD restrict their input
+ // and output types to tensors.
+
+ // The type of a sequence.
+ Sequence sequence_type = 4;
+
+ // The type of a map.
+ Map map_type = 5;
+
+ // The type of an optional.
+ Optional optional_type = 9;
+
+
+ // Type of the sparse tensor
+ SparseTensor sparse_tensor_type = 8;
+
}
+
+ // An optional denotation can be used to denote the whole
+ // type with a standard semantic description as to what is
+ // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
+ // for pre-defined type denotations.
+ optional string denotation = 6;
}
// Operator Sets
@@ -461,4 +766,70 @@ message OperatorSetIdProto {
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
optional int64 version = 2;
-} \ No newline at end of file
+}
+
+// Operator/function status.
+enum OperatorStatus {
+ EXPERIMENTAL = 0;
+ STABLE = 1;
+}
+
+message FunctionProto {
+ // The name of the function, similar usage of op_type in OperatorProto.
+ // Combined with FunctionProto.domain, this forms the unique identity of
+ // the FunctionProto.
+ optional string name = 1;
+
+ // Deprecated since IR Version 8
+ // optional int64 since_version = 2;
+ reserved 2;
+ reserved "since_version";
+
+ // Deprecated since IR Version 8
+ // optional OperatorStatus status = 3;
+ reserved 3;
+ reserved "status";
+
+ // The inputs and outputs of the function.
+ repeated string input = 4;
+ repeated string output = 5;
+
+ // The attribute parameters of the function.
+ // It is for function parameters without default values.
+ repeated string attribute = 6;
+
+ // The attribute protos of the function.
+ // It is for function attributes with default values.
+ // A function attribute shall be represented either as
+ // a string attribute or an AttributeProto, not both.
+ repeated AttributeProto attribute_proto = 11;
+
+ // The nodes in the function.
+ repeated NodeProto node = 7;
+ // A human-readable documentation for this function. Markdown is allowed.
+ optional string doc_string = 8;
+
+ // The OperatorSets this function body (graph) relies on.
+ //
+ // All nodes in the function body (graph) will bind against the operator
+ // with the same-domain/same-op_type operator with the HIGHEST version
+ // in the referenced operator sets. This means at most one version can be relied
+ // for one domain.
+ //
+ // The operator sets imported by FunctionProto should be compatible with the ones
+ // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
+ // and ModelProto then versions for the operator set may be different but,
+ // the operator schema returned for op_type, domain, version combination
+ // for both the versions should be same.
+
+ repeated OperatorSetIdProto opset_import = 9;
+
+ // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
+ // the FunctionProto.
+ optional string domain = 10;
+}
+
+
+// For using protobuf-lite
+option optimize_for = LITE_RUNTIME;
+
diff --git a/config-model/src/main/resources/schema/content.rnc b/config-model/src/main/resources/schema/content.rnc
index 5833b575a74..a73236454c6 100644
--- a/config-model/src/main/resources/schema/content.rnc
+++ b/config-model/src/main/resources/schema/content.rnc
@@ -370,7 +370,8 @@ Tuning = element tuning {
element threads { xsd:nonNegativeInteger }?
}? &
element feeding {
- element concurrency { xsd:double { minInclusive = "0.0" maxInclusive = "1.0" } }?
+ element concurrency { xsd:double { minInclusive = "0.0" maxInclusive = "1.0" } }? &
+ element niceness { xsd:double { minInclusive = "0.0" maxInclusive = "1.0" } }?
}? &
element removed-db {
element prune {
diff --git a/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java b/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java
index 84804bc48fa..19fe9e0038d 100644
--- a/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java
+++ b/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java
@@ -1183,9 +1183,9 @@ public class ModelProvisioningTest {
ContentCluster cluster = model.getContentClusters().get("bar");
List<StorageGroup> subGroups = cluster.getRootGroup().getSubgroups();
- assertEquals(2*3, cluster.redundancy().effectiveInitialRedundancy()); // Reduced from 3*3
- assertEquals(2*3, cluster.redundancy().effectiveFinalRedundancy()); // Reduced from 3*4
- assertEquals(2*3, cluster.redundancy().effectiveReadyCopies()); // Reduced from 3*3
+ assertEquals(2*3, cluster.getRedundancy().effectiveInitialRedundancy()); // Reduced from 3*3
+ assertEquals(2*3, cluster.getRedundancy().effectiveFinalRedundancy()); // Reduced from 3*4
+ assertEquals(2*3, cluster.getRedundancy().effectiveReadyCopies()); // Reduced from 3*3
assertEquals("2|2|*", cluster.getRootGroup().getPartitions().get()); // Reduced from 4|4|*
assertEquals(0, cluster.getRootGroup().getNodes().size());
assertEquals(3, subGroups.size());
@@ -1257,9 +1257,9 @@ public class ModelProvisioningTest {
assertEquals(numberOfHosts, model.getRoot().hostSystem().getHosts().size());
ContentCluster cluster = model.getContentClusters().get("bar");
- assertEquals(2, cluster.redundancy().effectiveInitialRedundancy());
- assertEquals(2, cluster.redundancy().effectiveFinalRedundancy());
- assertEquals(2, cluster.redundancy().effectiveReadyCopies());
+ assertEquals(2, cluster.getRedundancy().effectiveInitialRedundancy());
+ assertEquals(2, cluster.getRedundancy().effectiveFinalRedundancy());
+ assertEquals(2, cluster.getRedundancy().effectiveReadyCopies());
assertEquals("1|*", cluster.getRootGroup().getPartitions().get());
assertEquals(0, cluster.getRootGroup().getNodes().size());
assertEquals(2, cluster.getRootGroup().getSubgroups().size());
@@ -1287,9 +1287,9 @@ public class ModelProvisioningTest {
ContentCluster cluster = model.getContentClusters().get("bar");
assertEquals(2, cluster.getStorageCluster().getChildren().size());
- assertEquals(1, cluster.redundancy().effectiveInitialRedundancy());
- assertEquals(1, cluster.redundancy().effectiveFinalRedundancy());
- assertEquals(1, cluster.redundancy().effectiveReadyCopies());
+ assertEquals(1, cluster.getRedundancy().effectiveInitialRedundancy());
+ assertEquals(1, cluster.getRedundancy().effectiveFinalRedundancy());
+ assertEquals(1, cluster.getRedundancy().effectiveReadyCopies());
assertEquals(2, cluster.getRootGroup().getNodes().size());
assertEquals(0, cluster.getRootGroup().getSubgroups().size());
}
@@ -1324,9 +1324,9 @@ public class ModelProvisioningTest {
assertEquals(numberOfHosts, model.getRoot().hostSystem().getHosts().size());
ContentCluster cluster = model.getContentClusters().get("bar");
- assertEquals(4, cluster.redundancy().effectiveInitialRedundancy());
- assertEquals(4, cluster.redundancy().effectiveFinalRedundancy());
- assertEquals(4, cluster.redundancy().effectiveReadyCopies());
+ assertEquals(4, cluster.getRedundancy().effectiveInitialRedundancy());
+ assertEquals(4, cluster.getRedundancy().effectiveFinalRedundancy());
+ assertEquals(4, cluster.getRedundancy().effectiveReadyCopies());
assertEquals(4, cluster.getSearch().getIndexed().getDispatchSpec().getGroups().size());
assertEquals(4, cluster.getSearch().getIndexed().getSearchableCopies());
assertFalse(cluster.getRootGroup().getPartitions().isPresent());
@@ -1368,9 +1368,9 @@ public class ModelProvisioningTest {
assertEquals(numberOfHosts, model.getRoot().hostSystem().getHosts().size());
ContentCluster cluster = model.getContentClusters().get("bar");
- assertEquals(1, cluster.redundancy().effectiveInitialRedundancy()); // Reduced from 3*3
- assertEquals(1, cluster.redundancy().effectiveFinalRedundancy()); // Reduced from 3*4
- assertEquals(1, cluster.redundancy().effectiveReadyCopies()); // Reduced from 3*3
+ assertEquals(1, cluster.getRedundancy().effectiveInitialRedundancy()); // Reduced from 3*3
+ assertEquals(1, cluster.getRedundancy().effectiveFinalRedundancy()); // Reduced from 3*4
+ assertEquals(1, cluster.getRedundancy().effectiveReadyCopies()); // Reduced from 3*3
assertFalse(cluster.getRootGroup().getPartitions().isPresent()); // 1 group - > flattened -> no distribution
assertEquals(1, cluster.getRootGroup().getNodes().size());
assertEquals(0, cluster.getRootGroup().getSubgroups().size());
@@ -1473,9 +1473,9 @@ public class ModelProvisioningTest {
assertEquals(numberOfHosts, model.getRoot().hostSystem().getHosts().size());
ContentCluster cluster = model.getContentClusters().get("bar");
- assertEquals(1, cluster.redundancy().effectiveInitialRedundancy());
- assertEquals(1, cluster.redundancy().effectiveFinalRedundancy());
- assertEquals(1, cluster.redundancy().effectiveReadyCopies());
+ assertEquals(1, cluster.getRedundancy().effectiveInitialRedundancy());
+ assertEquals(1, cluster.getRedundancy().effectiveFinalRedundancy());
+ assertEquals(1, cluster.getRedundancy().effectiveReadyCopies());
assertEquals(1, cluster.getSearch().getIndexed().getDispatchSpec().getGroups().size());
assertFalse(cluster.getRootGroup().getPartitions().isPresent());
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomSchemaTuningBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomSchemaTuningBuilderTest.java
index e3e9fc1a232..db15d7e0a78 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomSchemaTuningBuilderTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomSchemaTuningBuilderTest.java
@@ -285,9 +285,13 @@ public class DomSchemaTuningBuilderTest extends DomBuilderTest {
void requireThatWeCanParseFeedingTag() {
Tuning t = createTuning(parseXml("<feeding>",
"<concurrency>0.7</concurrency>",
+ "<niceness>0.3</niceness>",
"</feeding>"));
assertEquals(0.7, t.searchNode.feeding.concurrency, DELTA);
- assertEquals(getProtonCfg(t).feeding().concurrency(), 0.7, DELTA);
+ assertEquals(0.3, t.searchNode.feeding.niceness, DELTA);
+ var cfg = getProtonCfg(t);
+ assertEquals(cfg.feeding().concurrency(), 0.7, DELTA);
+ assertEquals(cfg.feeding().niceness(), 0.3, DELTA);
}
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java
index 4ce7119f5f7..73bbd6ee464 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java
@@ -38,6 +38,7 @@ import com.yahoo.vespa.model.routing.DocumentProtocol;
import com.yahoo.vespa.model.routing.Routing;
import com.yahoo.vespa.model.test.utils.ApplicationPackageUtils;
import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithMockPkg;
+import com.yahoo.yolean.Exceptions;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
@@ -470,7 +471,8 @@ public class ContentClusterTest extends ContentBaseTest {
new VespaModelCreatorWithMockPkg(getHosts(), xml, sds).create();
fail("Deploying without redundancy should fail");
} catch (IllegalArgumentException e) {
- assertTrue(e.getMessage().contains("Either <redundancy> or <min-redundancy> must be set"), e.getMessage());
+ assertEquals("In content cluster 'bar': Either <redundancy> or <min-redundancy> must be set",
+ Exceptions.toMessageString(e));
}
}
@@ -478,12 +480,13 @@ public class ContentClusterTest extends ContentBaseTest {
void testRedundancyFinalLessThanInitial() {
try {
parse(
- "<content version=\"1.0\" id=\"storage\">\n" +
- " <redundancy reply-after=\"4\">2</redundancy>\n" +
- " <group>" +
- " <node hostalias='node0' distribution-key='0' />" +
- " </group>" +
- "</content>"
+ """
+ <content version="1.0" id="storage">
+ <redundancy reply-after="4">2</redundancy>
+ <group>
+ <node hostalias='node0' distribution-key='0' />
+ </group>
+ </content>"""
);
fail("no exception thrown");
} catch (Exception e) { /* ignore */
@@ -494,17 +497,18 @@ public class ContentClusterTest extends ContentBaseTest {
void testReadyTooHigh() {
try {
parse(
- "<content version=\"1.0\" id=\"storage\">\n" +
- " <engine>" +
- " <proton>" +
- " <searchable-copies>3</searchable-copies>" +
- " </proton>" +
- " </engine>" +
- " <redundancy>2</redundancy>\n" +
- " <group>" +
- " <node hostalias='node0' distribution-key='0' />" +
- " </group>" +
- "</content>"
+ """
+ <content version="1.0" id="storage">
+ <engine>
+ <proton>
+ <searchable-copies>3</searchable-copies>
+ </proton>
+ </engine>
+ <redundancy>2</redundancy>
+ <group>
+ <node hostalias='node0' distribution-key='0' />
+ </group>
+ </content>"""
);
fail("no exception thrown");
} catch (Exception e) { /* ignore */
@@ -972,15 +976,17 @@ public class ContentClusterTest extends ContentBaseTest {
@Test
void reserved_document_name_throws_exception() {
- String xml = "<content version=\"1.0\" id=\"storage\">" +
- " <redundancy>1</redundancy>" +
- " <documents>" +
- " <document type=\"true\" mode=\"index\"/>" +
- " </documents>" +
- " <group>" +
- " <node distribution-key=\"0\" hostalias=\"mockhost\"/>" +
- " </group>" +
- "</content>";
+ String xml = """
+ <content version="1.0" id="storage">
+ <redundancy>1</redundancy>
+ <documents>
+ <document type="true" mode="index"/>
+ </documents>
+ <group>
+ <node distribution-key="0" hostalias="mockhost"/>
+ </group>
+ </content>
+ """;
List<String> sds = ApplicationPackageUtils.generateSchemas("true");
try {
@@ -991,6 +997,65 @@ public class ContentClusterTest extends ContentBaseTest {
}
}
+ @Test
+ void default_searchable_copies_indexing() {
+ String services = """
+ <content version="1.0" id="storage">
+ <redundancy>3</redundancy>
+ <documents>
+ <document type="music" mode="index"/>
+ </documents>
+ <group>
+ <node distribution-key="0" hostalias="mockhost"/>
+ <node distribution-key="1" hostalias="mockhost"/>
+ <node distribution-key="2" hostalias="mockhost"/>
+ </group>
+ </content>
+ """;
+ var model = new VespaModelCreatorWithMockPkg(null, services, ApplicationPackageUtils.generateSchemas("music")).create();
+ assertEquals(2, model.getContentClusters().get("storage").getRedundancy().readyCopies());
+ }
+
+ @Test
+ void default_searchable_copies_streaming() {
+ String services = """
+ <content version="1.0" id="storage">
+ <redundancy>3</redundancy>
+ <documents>
+ <document type="mail" mode="streaming"/>
+ </documents>
+ <group>
+ <node distribution-key="0" hostalias="mockhost"/>
+ <node distribution-key="1" hostalias="mockhost"/>
+ <node distribution-key="2" hostalias="mockhost"/>
+ </group>
+ </content>
+ """;
+ var model = new VespaModelCreatorWithMockPkg(null, services, ApplicationPackageUtils.generateSchemas("mail")).create();
+ assertEquals(3, model.getContentClusters().get("storage").getRedundancy().readyCopies());
+ }
+
+ /** Here there is no good choice. */
+ @Test
+ void default_searchable_copies_mixed() {
+ String services = """
+ <content version="1.0" id="storage">
+ <redundancy>3</redundancy>
+ <documents>
+ <document type="music" mode="index"/>
+ <document type="mail" mode="streaming"/>
+ </documents>
+ <group>
+ <node distribution-key="0" hostalias="mockhost"/>
+ <node distribution-key="1" hostalias="mockhost"/>
+ <node distribution-key="2" hostalias="mockhost"/>
+ </group>
+ </content>
+ """;
+ var model = new VespaModelCreatorWithMockPkg(null, services, ApplicationPackageUtils.generateSchemas("music", "mail")).create();
+ assertEquals(2, model.getContentClusters().get("storage").getRedundancy().readyCopies());
+ }
+
private void assertClusterHasBucketSpaceMappings(AllClustersBucketSpacesConfig config, String clusterId,
List<String> defaultSpaceTypes, List<String> globalSpaceTypes) {
AllClustersBucketSpacesConfig.Cluster cluster = config.cluster(clusterId);
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
index 36f09f989a7..b662179c418 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
@@ -24,7 +24,6 @@ import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.DataplaneToken;
import com.yahoo.config.provision.DockerImage;
import com.yahoo.config.provision.HostName;
-import com.yahoo.config.provision.TenantName;
import com.yahoo.config.provision.Zone;
import com.yahoo.container.jdisc.secretstore.SecretStore;
import com.yahoo.vespa.config.server.tenant.SecretStoreExternalIdRetriever;
@@ -34,6 +33,7 @@ import com.yahoo.vespa.flags.Flags;
import com.yahoo.vespa.flags.PermanentFlags;
import com.yahoo.vespa.flags.StringFlag;
import com.yahoo.vespa.flags.UnboundFlag;
+
import java.io.File;
import java.net.URI;
import java.security.cert.X509Certificate;
@@ -319,13 +319,7 @@ public class ModelContextImpl implements ModelContext {
return flag.bindTo(source)
.with(FetchVector.Dimension.APPLICATION_ID, appId.serializedForm())
.with(FetchVector.Dimension.VESPA_VERSION, vespaVersion.toFullString())
- .boxedValue();
- }
-
- private static <V> V flagValue(FlagSource source, TenantName tenant, Version vespaVersion, UnboundFlag<? extends V, ?, ?> flag) {
- return flag.bindTo(source)
- .with(FetchVector.Dimension.TENANT_ID, tenant.value())
- .with(FetchVector.Dimension.VESPA_VERSION, vespaVersion.toFullString())
+ .with(FetchVector.Dimension.TENANT_ID, appId.tenant().value())
.boxedValue();
}
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java
index 0532a81617f..b2762b2a3d4 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java
@@ -98,7 +98,7 @@ public class ApplicationApiHandler extends SessionHandler {
"Unable to parse multipart in deploy from tenant '" + tenantName.value() + "': " + Exceptions.toMessageString(e));
var message = "Deploy request from '" + tenantName.value() + "' contains invalid data: " + e.getMessage();
- log.log(INFO, message + ", parts: " + parts, e);
+ log.log(FINE, message + ", parts: " + parts, e);
throw new BadRequestException("Deploy request from '" + tenantName.value() + "' contains invalid data: " + e.getMessage());
}
} else {
diff --git a/container-core/abi-spec.json b/container-core/abi-spec.json
index 236586a4132..572d18b02f3 100644
--- a/container-core/abi-spec.json
+++ b/container-core/abi-spec.json
@@ -1044,6 +1044,7 @@
"public com.yahoo.jdisc.http.ConnectorConfig$Builder requestHeaderSize(int)",
"public com.yahoo.jdisc.http.ConnectorConfig$Builder responseHeaderSize(int)",
"public com.yahoo.jdisc.http.ConnectorConfig$Builder acceptQueueSize(int)",
+ "public com.yahoo.jdisc.http.ConnectorConfig$Builder maxContentSize(long)",
"public com.yahoo.jdisc.http.ConnectorConfig$Builder reuseAddress(boolean)",
"public com.yahoo.jdisc.http.ConnectorConfig$Builder idleTimeout(double)",
"public com.yahoo.jdisc.http.ConnectorConfig$Builder tcpKeepAliveEnabled(boolean)",
@@ -1413,6 +1414,7 @@
"public int requestHeaderSize()",
"public int responseHeaderSize()",
"public int acceptQueueSize()",
+ "public long maxContentSize()",
"public boolean reuseAddress()",
"public double idleTimeout()",
"public boolean tcpKeepAliveEnabled()",
diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java b/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java
index e1ec22bd622..af98e380f2a 100644
--- a/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java
+++ b/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.jdisc.state;
+import ai.vespa.metrics.ContainerMetrics;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
@@ -289,7 +290,7 @@ public class StateHandler extends AbstractRequestHandler implements CapabilityRe
Tuple latencySeconds = new Tuple(NULL_DIMENSIONS, "latencySeconds", null);
for (Map.Entry<MetricDimensions, MetricSet> entry : snapshot) {
MetricSet metricSet = entry.getValue();
- MetricValue val = metricSet.get("serverTotalSuccessfulResponseLatency");
+ MetricValue val = metricSet.get(ContainerMetrics.SERVER_TOTAL_SUCCESFUL_RESPONSE_LATENCY.baseName());
if (val instanceof GaugeMetric gauge) {
latencySeconds.add(GaugeMetric.newInstance(gauge.getLast() / 1000,
gauge.getMax() / 1000,
@@ -297,7 +298,7 @@ public class StateHandler extends AbstractRequestHandler implements CapabilityRe
gauge.getSum() / 1000,
gauge.getCount()));
}
- requestsPerSecond.add(metricSet.get("serverNumSuccessfulResponses"));
+ requestsPerSecond.add(metricSet.get(ContainerMetrics.SERVER_NUM_SUCCESSFUL_RESPONSES.baseName()));
}
List<Tuple> lst = new ArrayList<>();
if (requestsPerSecond.val != null) {
diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java
index 2f2c48e0b48..75ef655c60c 100644
--- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java
+++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java
@@ -1,16 +1,19 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.jdisc.http.server.jetty;
+import com.yahoo.jdisc.Response;
import com.yahoo.jdisc.handler.CompletionHandler;
import com.yahoo.jdisc.handler.ContentChannel;
import jakarta.servlet.ReadListener;
import jakarta.servlet.ServletInputStream;
import jakarta.servlet.http.HttpServletRequest;
+import org.eclipse.jetty.server.Request;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -33,6 +36,7 @@ import java.util.logging.Logger;
*/
class ServletRequestReader {
+
private enum State {
NOT_STARTED, READING, ALL_DATA_READ, REQUEST_CONTENT_CLOSED
}
@@ -96,12 +100,15 @@ class ServletRequestReader {
private final CompletableFuture<Void> finishedFuture = new CompletableFuture<>();
ServletRequestReader(
- HttpServletRequest req,
+ Request req,
ContentChannel requestContentChannel,
Janitor janitor,
RequestMetricReporter metricReporter) {
this.req = Objects.requireNonNull(req);
- this.requestContentChannel = Objects.requireNonNull(requestContentChannel);
+ long maxContentSize = RequestUtils.getConnector(req).connectorConfig().maxContentSize();
+ this.requestContentChannel = maxContentSize >= 0
+ ? new ByteLimitedContentChannel(Objects.requireNonNull(requestContentChannel), maxContentSize)
+ : Objects.requireNonNull(requestContentChannel);
this.janitor = Objects.requireNonNull(janitor);
this.metricReporter = Objects.requireNonNull(metricReporter);
}
@@ -259,4 +266,30 @@ class ServletRequestReader {
}
}
+ private static class ByteLimitedContentChannel implements ContentChannel {
+ private final long maxContentSize;
+ private final AtomicLong bytesWritten = new AtomicLong();
+ private final ContentChannel delegate;
+
+ ByteLimitedContentChannel(ContentChannel delegate, long maxContentSize) {
+ this.delegate = delegate;
+ this.maxContentSize = maxContentSize;
+ }
+
+ @Override
+ public void write(ByteBuffer buf, CompletionHandler handler) {
+ long written = bytesWritten.addAndGet(buf.remaining());
+ if (written > maxContentSize) {
+ handler.failed(new RequestException(
+ Response.Status.REQUEST_TOO_LONG,
+ "Request content length %d exceeds limit of %d bytes".formatted(written, maxContentSize)));
+ return;
+ }
+ delegate.write(buf, handler);
+ }
+
+ @Override public void close(CompletionHandler h) { delegate.close(h); }
+ @Override public void onError(Throwable t) { delegate.onError(t); }
+ }
+
}
diff --git a/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def b/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def
index bdcc3f9e40a..3c01012fd9e 100644
--- a/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def
+++ b/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def
@@ -22,6 +22,9 @@ responseHeaderSize int default=65536
# The accept queue size (also known as accept backlog).
acceptQueueSize int default=0
+# Max content size allowed for requests. Set to -1 to disable.
+maxContentSize long default=-1
+
# Whether the server socket reuses addresses.
reuseAddress bool default=true
diff --git a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerTest.java b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerTest.java
index 0a697bd8fb3..6f9c854be64 100644
--- a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerTest.java
+++ b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerTest.java
@@ -70,6 +70,7 @@ import static com.yahoo.jdisc.Response.Status.GATEWAY_TIMEOUT;
import static com.yahoo.jdisc.Response.Status.INTERNAL_SERVER_ERROR;
import static com.yahoo.jdisc.Response.Status.NOT_FOUND;
import static com.yahoo.jdisc.Response.Status.OK;
+import static com.yahoo.jdisc.Response.Status.REQUEST_TOO_LONG;
import static com.yahoo.jdisc.Response.Status.REQUEST_URI_TOO_LONG;
import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED;
import static com.yahoo.jdisc.Response.Status.UNSUPPORTED_MEDIA_TYPE;
@@ -172,11 +173,11 @@ public class HttpServerTest {
driver.client()
.newGet("/status.html").addHeader("Host", "localhost").addHeader("Host", "vespa.ai").execute()
.expectStatusCode(is(BAD_REQUEST)).expectContent(containsString("reason: Duplicate Host Header"));
- assertTrue(driver.close());
var aggregator = ResponseMetricAggregator.getBean(driver.server());
var metric = waitForStatistics(aggregator);
assertEquals(400, metric.dimensions.statusCode);
assertEquals("GET", metric.dimensions.method);
+ assertTrue(driver.close());
}
@Test
@@ -795,6 +796,17 @@ public class HttpServerTest {
assertTrue(driver.close());
}
+ @Test
+ void exceedingMaxContentSizeReturns413() throws IOException {
+ JettyTestDriver driver = JettyTestDriver.newConfiguredInstance(
+ new EchoRequestHandler(),
+ new ServerConfig.Builder(),
+ new ConnectorConfig.Builder().maxContentSize(4));
+ driver.client().newPost("/").setBinaryContent(new byte[4]).execute().expectStatusCode(is(OK));
+ driver.client().newPost("/").setBinaryContent(new byte[5]).execute().expectStatusCode(is(REQUEST_TOO_LONG));
+ assertTrue(driver.close());
+ }
+
private static JettyTestDriver createSslWithTlsClientAuthenticationEnforcer(Path certificateFile, Path privateKeyFile) {
ConnectorConfig.Builder connectorConfig = new ConnectorConfig.Builder()
.tlsClientAuthEnforcer(
diff --git a/container-dependencies-enforcer/pom.xml b/container-dependencies-enforcer/pom.xml
index d63a1867f9a..a016c44f829 100644
--- a/container-dependencies-enforcer/pom.xml
+++ b/container-dependencies-enforcer/pom.xml
@@ -60,8 +60,6 @@
<rules>
<enforceDependencies implementation="com.yahoo.vespa.maven.plugin.enforcer.EnforceDependencies">
<allowed>
- <include>*:*:*:test</include>
- <include>com.yahoo.vespa:*:*:*</include>
<include>aopalliance:aopalliance:${aopalliance.version}:provided</include>
<include>com.fasterxml.jackson.core:jackson-annotations:${jackson2.version}:provided</include>
<include>com.fasterxml.jackson.core:jackson-core:${jackson2.version}:provided</include>
@@ -89,6 +87,121 @@
<include>org.slf4j:slf4j-api:${slf4j.version}:provided</include>
<include>org.slf4j:slf4j-jdk14:${slf4j.version}:provided</include>
<include>xml-apis:xml-apis:${xml-apis.version}:provided</include>
+
+ <!-- Vespa provided dependencies -->
+ <include>com.yahoo.vespa:annotations:*:provided</include>
+ <include>com.yahoo.vespa:component:*:provided</include>
+ <include>com.yahoo.vespa:config-bundle:*:provided</include>
+ <include>com.yahoo.vespa:config-lib:*:provided</include>
+ <include>com.yahoo.vespa:config:*:provided</include>
+ <include>com.yahoo.vespa:configdefinitions:*:provided</include>
+ <include>com.yahoo.vespa:configgen:*:provided</include>
+ <include>com.yahoo.vespa:container-core:*:provided</include>
+ <include>com.yahoo.vespa:container-dev:*:provided</include>
+ <include>com.yahoo.vespa:container-disc:*:provided</include>
+ <include>com.yahoo.vespa:container-documentapi:*:provided</include>
+ <include>com.yahoo.vespa:container-messagebus:*:provided</include>
+ <include>com.yahoo.vespa:container-onnxruntime:*:provided</include>
+ <include>com.yahoo.vespa:container-search-and-docproc:*:provided</include>
+ <include>com.yahoo.vespa:container-search:*:provided</include>
+ <include>com.yahoo.vespa:container:*:provided</include>
+ <include>com.yahoo.vespa:defaults:*:provided</include>
+ <include>com.yahoo.vespa:docproc:*:provided</include>
+ <include>com.yahoo.vespa:document:*:provided</include>
+ <include>com.yahoo.vespa:documentapi:*:provided</include>
+ <include>com.yahoo.vespa:fileacquirer:*:provided</include>
+ <include>com.yahoo.vespa:fsa:*:provided</include>
+ <include>com.yahoo.vespa:hosted-zone-api:*:provided</include>
+ <include>com.yahoo.vespa:http-utils:*:provided</include>
+ <include>com.yahoo.vespa:jdisc_core:*:provided</include>
+ <include>com.yahoo.vespa:jrt:*:provided</include>
+ <include>com.yahoo.vespa:linguistics:*:provided</include>
+ <include>com.yahoo.vespa:messagebus:*:provided</include>
+ <include>com.yahoo.vespa:metrics:*:provided</include>
+ <include>com.yahoo.vespa:model-evaluation:*:provided</include>
+ <include>com.yahoo.vespa:opennlp-linguistics:*:provided</include>
+ <include>com.yahoo.vespa:predicate-search-core:*:provided</include>
+ <include>com.yahoo.vespa:provided-dependencies:*:provided</include>
+ <include>com.yahoo.vespa:searchcore:*:provided</include>
+ <include>com.yahoo.vespa:searchlib:*:provided</include>
+ <include>com.yahoo.vespa:security-utils:*:provided</include>
+ <include>com.yahoo.vespa:vdslib:*:provided</include>
+ <include>com.yahoo.vespa:vespa-3party-bundles:pom:*:provided</include>
+ <include>com.yahoo.vespa:vespaclient-container-plugin:*:provided</include>
+ <include>com.yahoo.vespa:vespajlib:*:provided</include>
+ <include>com.yahoo.vespa:vespalog:*:provided</include>
+
+ <!-- Vespa test dependencies -->
+ <include>com.yahoo.vespa:airlift-zstd:*:test</include>
+ <include>com.yahoo.vespa:application:*:test</include>
+ <include>com.yahoo.vespa:config-application-package:*:test</include>
+ <include>com.yahoo.vespa:config-model-api:*:test</include>
+ <include>com.yahoo.vespa:config-model:*:test</include>
+ <include>com.yahoo.vespa:config-provisioning:*:test</include>
+ <include>com.yahoo.vespa:container-apache-http-client-bundle:*:test</include>
+ <include>com.yahoo.vespa:container-test:*:test</include>
+ <include>com.yahoo.vespa:indexinglanguage:*:test</include>
+ <include>com.yahoo.vespa:logd:*:test</include>
+ <include>com.yahoo.vespa:metrics-proxy:*:test</include>
+ <include>com.yahoo.vespa:model-integration:*:test</include>
+ <include>com.yahoo.vespa:searchsummary:*:test</include>
+ <include>com.yahoo.vespa:standalone-container:*:test</include>
+ <include>com.yahoo.vespa:storage:*:test</include>
+ <include>com.yahoo.vespa:vespaclient-core:*:test</include>
+ <include>com.yahoo.vespa:vsm:*:test</include>
+
+ <!-- 3rd party test dependencies -->
+ <include>com.google.code.findbugs:jsr305:3.0.2:test</include>
+ <include>com.google.protobuf:protobuf-java:${protobuf.version}:test</include>
+ <include>com.ibm.icu:icu4j:70.1:test</include>
+ <include>com.microsoft.onnxruntime:onnxruntime:${onnxruntime.version}:test</include>
+ <include>com.thaiopensource:jing:20091111:test</include>
+ <include>commons-codec:commons-codec:${commons-codec.version}:test</include>
+ <include>io.airlift:airline:0.9:test</include>
+ <include>io.prometheus:simpleclient:0.6.0:test</include>
+ <include>io.prometheus:simpleclient_common:0.6.0:test</include>
+ <include>junit:junit:4.13.2:test</include>
+ <include>net.java.dev.jna:jna:5.11.0:test</include>
+ <include>net.openhft:zero-allocation-hashing:jar:0.16:test</include>
+ <include>org.antlr:antlr-runtime:3.5.3:test</include>
+ <include>org.antlr:antlr4-runtime:4.11.1:test</include>
+ <include>org.apache.commons:commons-exec:1.3:test</include>
+ <include>org.apache.commons:commons-math3:3.6.1:test</include>
+ <include>org.apache.felix:org.apache.felix.framework:${felix.version}:test</include>
+ <include>org.apache.felix:org.apache.felix.log:1.0.1:test</include>
+ <include>org.apache.httpcomponents.client5:httpclient5:${apache.httpclient5.version}:test</include>
+ <include>org.apache.httpcomponents.core5:httpcore5:${apache.httpcore5.version}:test</include>
+ <include>org.apache.httpcomponents.core5:httpcore5-h2:${apache.httpcore5.version}:test</include>
+ <include>org.apache.httpcomponents:httpclient:${apache.httpclient.version}:test</include>
+ <include>org.apache.httpcomponents:httpcore:${apache.httpcore.version}:test</include>
+ <include>org.apache.httpcomponents:httpmime:${apache.httpclient.version}:test</include>
+ <include>org.apache.opennlp:opennlp-tools:1.9.3:test</include>
+ <include>org.bouncycastle:bcpkix-jdk18on:${bouncycastle.version}:test</include>
+ <include>org.bouncycastle:bcprov-jdk18on:${bouncycastle.version}:test</include>
+ <include>org.bouncycastle:bcutil-jdk18on:${bouncycastle.version}:test</include>
+ <include>org.eclipse.jetty.http2:http2-common:${jetty.version}:test</include>
+ <include>org.eclipse.jetty.http2:http2-hpack:${jetty.version}:test</include>
+ <include>org.eclipse.jetty.http2:http2-server:${jetty.version}:test</include>
+ <include>org.eclipse.jetty.toolchain:jetty-jakarta-servlet-api:5.0.2:test</include>
+ <include>org.eclipse.jetty:jetty-alpn-client:${jetty.version}:test</include>
+ <include>org.eclipse.jetty:jetty-alpn-java-server:${jetty.version}:test</include>
+ <include>org.eclipse.jetty:jetty-alpn-server:${jetty.version}:test</include>
+ <include>org.eclipse.jetty:jetty-client:${jetty.version}:test</include>
+ <include>org.eclipse.jetty:jetty-http:${jetty.version}:test</include>
+ <include>org.eclipse.jetty:jetty-io:${jetty.version}:test</include>
+ <include>org.eclipse.jetty:jetty-jmx:${jetty.version}:test</include>
+ <include>org.eclipse.jetty:jetty-security:${jetty.version}:test</include>
+ <include>org.eclipse.jetty:jetty-server:${jetty.version}:test</include>
+ <include>org.eclipse.jetty:jetty-servlet:${jetty.version}:test</include>
+ <include>org.eclipse.jetty:jetty-util:${jetty.version}:test</include>
+ <include>org.hamcrest:hamcrest-core:1.3:test</include>
+ <include>org.hdrhistogram:HdrHistogram:2.1.12:test</include>
+ <include>org.json:json:${org.json.version}:test</include> <!-- TODO: Remove on Vespa 9 -->
+ <include>org.lz4:lz4-java:${org.lz4.version}:test</include>
+ <include>org.osgi:org.osgi.compendium:4.1.0:test</include>
+ <include>org.osgi:org.osgi.core:4.1.0:test</include>
+ <include>xerces:xercesImpl:2.12.2:test</include>
</allowed>
</enforceDependencies>
</rules>
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java b/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java
index e6af65c0bc8..47050168b80 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java
@@ -36,8 +36,8 @@ public class DataplaneProxyService extends AbstractComponent {
private final Path root;
enum NginxState {INITIALIZING, RUNNING, RELOAD_REQUIRED, STOPPED};
- private NginxState state;
- private NginxState wantedState;
+ private volatile NginxState state;
+ private volatile NginxState wantedState;
private DataplaneProxyConfig cfg;
private Path proxyCredentialsCert;
@@ -113,35 +113,46 @@ public class DataplaneProxyService extends AbstractComponent {
throw new RuntimeException("Error reconfiguring data plane proxy", e);
}
}
- if (wantedState == NginxState.RUNNING) {
+ NginxState convergeTo = wantedState;
+ if (convergeTo == NginxState.RUNNING) {
boolean nginxRunning = proxyCommands.isRunning();
if (!nginxRunning) {
try {
proxyCommands.start(nginxConf);
- changeState(wantedState);
+ changeState(convergeTo);
} catch (Exception e) {
logger.log(Level.INFO, "Failed to start nginx, will retry");
+ logger.log(Level.FINE, "Exception from nginx start", e);
}
- } else if (nginxRunning && state == NginxState.RELOAD_REQUIRED) {
- try {
- proxyCommands.reload();
- changeState(wantedState);
- } catch (Exception e) {
- logger.log(Level.INFO, "Failed to reconfigure nginx, will retry.");
+ } else {
+ if (state == NginxState.RELOAD_REQUIRED) {
+ try {
+ proxyCommands.reload();
+ changeState(convergeTo);
+ } catch (Exception e) {
+ logger.log(Level.INFO, "Failed to reconfigure nginx, will retry.");
+ logger.log(Level.FINE, "Exception from nginx reload", e);
+ }
+ } else if (state != convergeTo) {
+ // Already running, but state not updated
+ changeState(convergeTo);
}
}
- } else if (wantedState == NginxState.STOPPED) {
+ } else if (convergeTo == NginxState.STOPPED) {
if (proxyCommands.isRunning()) {
try {
proxyCommands.stop();
- changeState(wantedState);
- executorService.shutdownNow();
} catch (Exception e) {
logger.log(Level.INFO, "Failed to stop nginx, will retry");
+ logger.log(Level.FINE, "Exception from nginx stop", e);
}
}
+ if (! proxyCommands.isRunning()) {
+ changeState(convergeTo);
+ executorService.shutdownNow();
+ }
} else {
- logger.warning("Unknown state " + wantedState);
+ logger.warning("Unknown state " + convergeTo);
}
}
@@ -150,9 +161,9 @@ public class DataplaneProxyService extends AbstractComponent {
super.deconstruct();
wantedState = NginxState.STOPPED;
try {
- executorService.awaitTermination(5, TimeUnit.MINUTES);
+ executorService.awaitTermination(30, TimeUnit.SECONDS);
} catch (InterruptedException e) {
- logger.log(Level.WARNING, "Error shutting down proxy reload thread");
+ logger.log(Level.WARNING, "Error shutting down proxy reload thread", e);
}
}
@@ -203,10 +214,12 @@ public class DataplaneProxyService extends AbstractComponent {
return template.replaceAll("\\$\\{%s\\}".formatted(key), value);
}
+ // Used for testing
NginxState state() {
return state;
}
+ // Used for testing
NginxState wantedState() {
return wantedState;
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java
index 3e6ee3a35a2..501438c8c2e 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java
@@ -2,6 +2,7 @@
package com.yahoo.container.jdisc.metric;
+import ai.vespa.metrics.ContainerMetrics;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.jdisc.Metric;
@@ -141,7 +142,7 @@ public class MetricUpdater extends AbstractComponent {
"home", System.getProperty("java.home"),
"vendor", System.getProperty("java.vm.vendor"),
"arch", System.getProperty("os.arch")));
- metric.set("jdisc.jvm", Runtime.version().feature(), ctx);
+ metric.set(ContainerMetrics.JDISC_JVM.baseName(), Runtime.version().feature(), ctx);
}
private void tlsMetrics() {
diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java
index 947c99adf51..351890e2a3a 100644
--- a/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java
+++ b/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java
@@ -22,13 +22,16 @@ import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
public class DataplaneProxyServiceTest {
private FileSystem fileSystem = Jimfs.newFileSystem();
- DataplaneProxyService.ProxyCommands proxyCommandsMock = Mockito.mock(DataplaneProxyService.ProxyCommands.class);
+ DataplaneProxyService.ProxyCommands proxyCommandsMock = mock(DataplaneProxyService.ProxyCommands.class);
@Test
public void starts_and_reloads_if_no_errors() throws IOException {
@@ -122,6 +125,35 @@ public class DataplaneProxyServiceTest {
assertFalse(proxyCommands.isRunning());
}
+ @Test
+ public void stops_executor_when_nginx_stop_throws() throws IOException, InterruptedException {
+ DataplaneProxyService.ProxyCommands mockProxyCommands = mock(DataplaneProxyService.ProxyCommands.class);
+ DataplaneProxyService service = dataplaneProxyService(mockProxyCommands);
+ service.converge();
+ when(mockProxyCommands.isRunning()).thenReturn(true);
+ assertEquals(DataplaneProxyService.NginxState.RUNNING, service.state());
+
+ reset(mockProxyCommands);
+
+ when(mockProxyCommands.isRunning()).thenReturn(true).thenReturn(false);
+ doThrow(new RuntimeException("Failed to stop proxy")).when(mockProxyCommands).stop();
+ Thread thread = new Thread(service::deconstruct); // deconstruct will block until nginx is stopped
+ thread.start();
+
+ // Wait for above thread to set the wanted state to STOPPED
+ while (service.wantedState() != DataplaneProxyService.NginxState.STOPPED) {
+ try {
+ Thread.sleep(10);
+ } catch (InterruptedException e) {
+ }
+ }
+ service.converge();
+ assertEquals(DataplaneProxyService.NginxState.STOPPED, service.state());
+ thread.join();
+
+ verify(mockProxyCommands, times(1)).stop();
+ }
+
private DataplaneProxyService dataplaneProxyService(DataplaneProxyService.ProxyCommands proxyCommands) throws IOException {
Path root = fileSystem.getPath("/opt/vespa");
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index e439f7905cc..c41c1c79149 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -5414,6 +5414,8 @@
"public java.lang.Integer getRerankCount()",
"public void setKeepRankCount(int)",
"public java.lang.Integer getKeepRankCount()",
+ "public void setRankScoreDropLimit(double)",
+ "public java.lang.Double getRankScoreDropLimit()",
"public com.yahoo.prelude.Location getLocation()",
"public void setLocation(com.yahoo.prelude.Location)",
"public void setLocation(java.lang.String)",
@@ -5449,6 +5451,7 @@
"public static final java.lang.String QUERYCACHE",
"public static final java.lang.String RERANKCOUNT",
"public static final java.lang.String KEEPRANKCOUNT",
+ "public static final java.lang.String RANKSCOREDROPLIMIT",
"public static final java.lang.String MATCH_PHASE",
"public static final java.lang.String DIVERSITY",
"public static final java.lang.String SOFTTIMEOUT",
diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
index c9e125ab55f..ecce5ddd740 100644
--- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
+++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.handler;
import ai.vespa.cloud.ZoneInfo;
import ai.vespa.metrics.ContainerMetrics;
import com.yahoo.collections.Tuple2;
@@ -81,7 +82,7 @@ public class SearchHandler extends LoggingRequestHandler {
private static final CompoundName FORCE_TIMESTAMPS = CompoundName.from("trace.timestamps");
/** Event name for number of connections to the search subsystem */
- private static final String SEARCH_CONNECTIONS = "search_connections";
+ private static final String SEARCH_CONNECTIONS = ContainerMetrics.SEARCH_CONNECTIONS.baseName();
static final String RENDER_LATENCY_METRIC = ContainerMetrics.JDISC_RENDER_LATENCY.baseName();
static final String MIME_DIMENSION = "mime";
static final String RENDERER_DIMENSION = "renderer";
diff --git a/container-search/src/main/java/com/yahoo/search/query/Ranking.java b/container-search/src/main/java/com/yahoo/search/query/Ranking.java
index 5426268d173..ac32bec80ef 100644
--- a/container-search/src/main/java/com/yahoo/search/query/Ranking.java
+++ b/container-search/src/main/java/com/yahoo/search/query/Ranking.java
@@ -42,6 +42,7 @@ public class Ranking implements Cloneable {
public static final String QUERYCACHE = "queryCache";
public static final String RERANKCOUNT = "rerankCount";
public static final String KEEPRANKCOUNT = "keepRankCount";
+ public static final String RANKSCOREDROPLIMIT = "rankScoreDropLimit";
public static final String MATCH_PHASE = "matchPhase";
public static final String DIVERSITY = "diversity";
public static final String SOFTTIMEOUT = "softtimeout";
@@ -63,6 +64,7 @@ public class Ranking implements Cloneable {
argumentType.addField(new FieldDescription(QUERYCACHE, "boolean"));
argumentType.addField(new FieldDescription(RERANKCOUNT, "integer"));
argumentType.addField(new FieldDescription(KEEPRANKCOUNT, "integer"));
+ argumentType.addField(new FieldDescription(RANKSCOREDROPLIMIT, "double"));
argumentType.addField(new FieldDescription(MATCH_PHASE, new QueryProfileFieldType(MatchPhase.getArgumentType()), "matchPhase"));
argumentType.addField(new FieldDescription(DIVERSITY, new QueryProfileFieldType(Diversity.getArgumentType())));
argumentType.addField(new FieldDescription(SOFTTIMEOUT, new QueryProfileFieldType(SoftTimeout.getArgumentType())));
@@ -94,6 +96,7 @@ public class Ranking implements Cloneable {
private Integer rerankCount = null;
private Integer keepRankCount = null;
+ private Double rankScoreDropLimit = null;
private RankProperties rankProperties = new RankProperties();
@@ -165,6 +168,11 @@ public class Ranking implements Cloneable {
/** Returns the keep-rank-count that will be used, or null if not set */
public Integer getKeepRankCount() { return keepRankCount; }
+ /** Sets the rank-score-drop-limit that will be used, or null if not set */
+ public void setRankScoreDropLimit(double rankScoreDropLimit) { this.rankScoreDropLimit = rankScoreDropLimit; }
+ /** Returns the rank-score-drop-limit that will be used, or null if not set */
+ public Double getRankScoreDropLimit() { return rankScoreDropLimit; }
+
/** Returns the location of this query, or null if none */
public Location getLocation() { return location; }
@@ -241,6 +249,8 @@ public class Ranking implements Cloneable {
rankProperties.put("vespa.hitcollector.heapsize", rerankCount);
if (keepRankCount != null)
rankProperties.put("vespa.hitcollector.arraysize", keepRankCount);
+ if (rankScoreDropLimit != null)
+ rankProperties.put("vespa.hitcollector.rankscoredroplimit", rankScoreDropLimit);
}
private void prepareNow(Freshness freshness) {
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
index da0051c527c..240da3f123f 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
@@ -77,6 +77,7 @@ public class QueryProperties extends Properties {
if (key.last().equals(Ranking.QUERYCACHE)) return ranking.getQueryCache();
if (key.last().equals(Ranking.RERANKCOUNT)) return ranking.getRerankCount();
if (key.last().equals(Ranking.KEEPRANKCOUNT)) return ranking.getKeepRankCount();
+ if (key.last().equals(Ranking.RANKSCOREDROPLIMIT)) return ranking.getRankScoreDropLimit();
if (key.last().equals(Ranking.LIST_FEATURES)) return ranking.getListFeatures();
}
else if (key.size() >= 3 && key.get(1).equals(Ranking.MATCH_PHASE)) {
@@ -203,6 +204,8 @@ public class QueryProperties extends Properties {
ranking.setRerankCount(asInteger(value, null));
else if (key.last().equals(Ranking.KEEPRANKCOUNT))
ranking.setKeepRankCount(asInteger(value, null));
+ else if (key.last().equals(Ranking.RANKSCOREDROPLIMIT))
+ ranking.setRankScoreDropLimit(asDouble(value, null));
else if (key.last().equals(Ranking.LIST_FEATURES))
ranking.setListFeatures(asBoolean(value,false));
else
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ContainerLatencySearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ContainerLatencySearcher.java
index 742f4b0f889..f510d68d32e 100644
--- a/container-search/src/main/java/com/yahoo/search/searchers/ContainerLatencySearcher.java
+++ b/container-search/src/main/java/com/yahoo/search/searchers/ContainerLatencySearcher.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.searchers;
+import ai.vespa.metrics.ContainerMetrics;
import com.yahoo.component.chain.dependencies.After;
import com.yahoo.metrics.simple.Gauge;
import com.yahoo.metrics.simple.Point;
@@ -21,7 +22,7 @@ public class ContainerLatencySearcher extends Searcher {
private final Gauge latencyGauge;
public ContainerLatencySearcher(MetricReceiver metrics) {
- latencyGauge = metrics.declareGauge("query_container_latency");
+ latencyGauge = metrics.declareGauge(ContainerMetrics.QUERY_CONTAINER_LATENCY.baseName());
}
@Override
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/RateLimitingSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/RateLimitingSearcher.java
index 35a3c86f763..846b8881cef 100755
--- a/container-search/src/main/java/com/yahoo/search/searchers/RateLimitingSearcher.java
+++ b/container-search/src/main/java/com/yahoo/search/searchers/RateLimitingSearcher.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.searchers;
+import ai.vespa.metrics.ContainerMetrics;
import com.yahoo.component.annotation.Inject;
import com.yahoo.cloud.config.ClusterInfoConfig;
@@ -60,7 +61,7 @@ public class RateLimitingSearcher extends Searcher {
public static final CompoundName idDimensionKey = CompoundName.from("rate.idDimension");
public static final CompoundName dryRunKey = CompoundName.from("rate.dryRun");
- private static final String requestsOverQuotaMetricName = "requestsOverQuota";
+ private static final String requestsOverQuotaMetricName = ContainerMetrics.REQUESTS_OVER_QUOTA.baseName();
/** Used to divide quota by nodes. Assumption: All nodes get the same share of traffic. */
private final int nodeCount;
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/test/RateLimitingBenchmark.java b/container-search/src/test/java/com/yahoo/search/searchers/test/RateLimitingBenchmark.java
index 5537528d36b..76da6407166 100644
--- a/container-search/src/test/java/com/yahoo/search/searchers/test/RateLimitingBenchmark.java
+++ b/container-search/src/test/java/com/yahoo/search/searchers/test/RateLimitingBenchmark.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.searchers.test;
+import ai.vespa.metrics.ContainerMetrics;
import com.yahoo.cloud.config.ClusterInfoConfig;
import com.yahoo.component.chain.Chain;
import com.yahoo.metrics.simple.Bucket;
@@ -114,7 +115,7 @@ public class RateLimitingBenchmark {
private int rejectedRequests(int id) {
Point context = metric.pointBuilder().set("id", toClientId(id)).build();
- UntypedMetric rejectedRequestsMetric = metricSnapshot.getMapForMetric("requestsOverQuota").get(context);
+ UntypedMetric rejectedRequestsMetric = metricSnapshot.getMapForMetric(ContainerMetrics.REQUESTS_OVER_QUOTA.baseName()).get(context);
if (rejectedRequestsMetric == null) return 0;
return (int)rejectedRequestsMetric.getCount();
}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java
index ffaee34e727..d73a7410cc6 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java
@@ -26,18 +26,21 @@ public class InstanceInformation {
public URI url;
public String scope;
public RoutingMethod routingMethod;
+ public String auth;
@JsonCreator
public Endpoint(@JsonProperty("cluster") String cluster ,
@JsonProperty("tls") boolean tls,
@JsonProperty("url") URI url,
@JsonProperty("scope") String scope,
- @JsonProperty("routingMethod") RoutingMethod routingMethod) {
+ @JsonProperty("routingMethod") RoutingMethod routingMethod,
+ @JsonProperty("authMethod") String auth) {
this.cluster = cluster;
this.tls = tls;
this.url = url;
this.scope = scope;
this.routingMethod = routingMethod;
+ this.auth = auth;
}
@Override
@@ -47,6 +50,7 @@ public class InstanceInformation {
", tls=" + tls +
", url=" + url +
", scope='" + scope + '\'' +
+ ", authType='" + auth + '\'' +
", routingMethod=" + routingMethod +
'}';
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java
index 0f3f9479176..68852f90055 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java
@@ -129,7 +129,7 @@ public class EndpointCertificates {
}
private Optional<EndpointCertificateMetadata> getOrProvision(Instance instance, ZoneId zone, DeploymentSpec deploymentSpec) {
- if (useRandomizedCert.with(FetchVector.Dimension.APPLICATION_ID, instance.id().toFullString()).value()) {
+ if (useRandomizedCert.with(FetchVector.Dimension.APPLICATION_ID, instance.id().serializedForm()).value()) {
return Optional.of(assignFromPool(instance, zone));
}
Optional<AssignedCertificate> assignedCertificate = curator.readAssignedCertificate(TenantAndApplicationId.from(instance.id()), Optional.of(instance.id().instance()));
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
index aa3f78f1395..693275987c5 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java
@@ -139,7 +139,6 @@ import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
-import java.io.UncheckedIOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
@@ -911,14 +910,17 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler {
}
private HttpResponse listTokens(String tenant, HttpRequest request) {
- List<DataplaneTokenVersions> dataplaneTokenVersions = controller.dataplaneTokenService().listTokens(TenantName.from(tenant));
+ var tokens = controller.dataplaneTokenService().listTokens(TenantName.from(tenant))
+ .stream().sorted(Comparator.comparing(DataplaneTokenVersions::tokenId)).toList();
Slime slime = new Slime();
Cursor tokensArray = slime.setObject().setArray("tokens");
- for (DataplaneTokenVersions token : dataplaneTokenVersions) {
+ for (DataplaneTokenVersions token : tokens) {
Cursor tokenObject = tokensArray.addObject();
tokenObject.setString("id", token.tokenId().value());
Cursor fingerprintsArray = tokenObject.setArray("versions");
- for (DataplaneTokenVersions.Version tokenVersion : token.tokenVersions()) {
+ var versions = token.tokenVersions().stream()
+ .sorted(Comparator.comparing(DataplaneTokenVersions.Version::creationTime)).toList();
+ for (var tokenVersion : versions) {
Cursor fingerprintObject = fingerprintsArray.addObject();
fingerprintObject.setString("fingerprint", tokenVersion.fingerPrint().value());
fingerprintObject.setString("created", tokenVersion.creationTime().toString());
diff --git a/dist/vespa.spec b/dist/vespa.spec
index 091a17e822b..c1a409e057f 100644
--- a/dist/vespa.spec
+++ b/dist/vespa.spec
@@ -90,7 +90,7 @@ BuildRequires: llvm-devel
BuildRequires: vespa-boost-devel >= 1.76.0-1
BuildRequires: vespa-openssl-devel >= 1.1.1o-1
%define _use_vespa_openssl 1
-BuildRequires: vespa-gtest = 1.11.0
+BuildRequires: vespa-gtest = 1.13.0
%define _use_vespa_gtest 1
BuildRequires: vespa-lz4-devel >= 1.9.4-1
BuildRequires: vespa-onnxruntime-devel = 1.13.1
@@ -192,7 +192,7 @@ Requires: unzip
Requires: zlib
Requires: zstd
%if 0%{?el8}
-Requires: vespa-gtest = 1.11.0
+Requires: vespa-gtest = 1.13.0
%endif
%if 0%{?el9}
Requires: gtest
diff --git a/metrics/src/main/java/ai/vespa/metrics/HostedNodeAdminMetrics.java b/metrics/src/main/java/ai/vespa/metrics/HostedNodeAdminMetrics.java
index 927672a43f7..97185e9c703 100644
--- a/metrics/src/main/java/ai/vespa/metrics/HostedNodeAdminMetrics.java
+++ b/metrics/src/main/java/ai/vespa/metrics/HostedNodeAdminMetrics.java
@@ -30,9 +30,9 @@ public enum HostedNodeAdminMetrics implements VespaMetrics {
NET_IN_BYTES("net.in.bytes", Unit.BYTE, "Network bytes received (rxBytes) (COUNT metric)"),
NET_IN_ERROR("net.in.errors", Unit.FAILURE, "Network receive errors (rxErrors)"),
NET_IN_DROPPED("net.in.dropped", Unit.PACKET, "Inbound network packets dropped (rxDropped)"),
- NET_OUT_BYTES("net.in.bytes", Unit.BYTE, "Network bytes sent (txBytes) (COUNT metric)"),
- NET_OUT_ERROR("net.in.errors", Unit.FAILURE, "Network send errors (txErrors)"),
- NET_OUT_DROPPED("net.in.dropped", Unit.PACKET, "Outbound network packets dropped (txDropped)"),
+ NET_OUT_BYTES("net.out.bytes", Unit.BYTE, "Network bytes sent (txBytes) (COUNT metric)"),
+ NET_OUT_ERROR("net.out.errors", Unit.FAILURE, "Network send errors (txErrors)"),
+ NET_OUT_DROPPED("net.out.dropped", Unit.PACKET, "Outbound network packets dropped (txDropped)"),
BANDWIDTH_LIMIT("bandwidth.limit", Unit.BYTE_PER_SECOND, "Available network bandwidth");
private final String name;
@@ -58,4 +58,3 @@ public enum HostedNodeAdminMetrics implements VespaMetrics {
}
}
-
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
index f12f60dcc8e..f690b8e8c8a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
@@ -32,8 +32,9 @@ class TensorConverter {
}
private static Values readValuesOf(Onnx.TensorProto tensorProto) {
+ var elemType = Onnx.TensorProto.DataType.forNumber(tensorProto.getDataType());
if (tensorProto.hasRawData()) {
- switch (tensorProto.getDataType()) {
+ switch (elemType) {
case BOOL: return new RawBoolValues(tensorProto);
case FLOAT: return new RawFloatValues(tensorProto);
case DOUBLE: return new RawDoubleValues(tensorProto);
@@ -41,7 +42,7 @@ class TensorConverter {
case INT64: return new RawLongValues(tensorProto);
}
} else {
- switch (tensorProto.getDataType()) {
+ switch (elemType) {
case FLOAT: return new FloatValues(tensorProto);
case DOUBLE: return new DoubleValues(tensorProto);
case INT32: return new IntValues(tensorProto);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 35ec1d8c54a..deac950d324 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -37,7 +37,8 @@ class TypeConverter {
static OrderedTensorType typeFrom(Onnx.TypeProto type) {
String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(type.getTensorType().getElemType()));
+ var elemType = Onnx.TensorProto.DataType.forNumber(type.getTensorType().getElemType());
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(elemType));
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
@@ -52,8 +53,8 @@ class TypeConverter {
}
static OrderedTensorType typeFrom(Onnx.TensorProto tensor) {
- return OrderedTensorType.fromDimensionList(toValueType(tensor.getDataType()),
- tensor.getDimsList());
+ var elemType = Onnx.TensorProto.DataType.forNumber(tensor.getDataType());
+ return OrderedTensorType.fromDimensionList(toValueType(elemType), tensor.getDimsList());
}
private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
diff --git a/model-integration/src/main/protobuf/onnx.proto b/model-integration/src/main/protobuf/onnx.proto
index dc6542867e0..1d265ae9f28 100644
--- a/model-integration/src/main/protobuf/onnx.proto
+++ b/model-integration/src/main/protobuf/onnx.proto
@@ -3,8 +3,8 @@
//
-// Copyright (c) Facebook Inc. and Microsoft Corporation.
-// Licensed under the MIT license.
+// SPDX-License-Identifier: Apache-2.0
+
syntax = "proto2";
@@ -20,23 +20,16 @@ package onnx;
//
// This document describes the syntax of models and their computation graphs,
// as well as the standard data types. Together, they are referred to as the ONNX
-// Intermediate Representation, or 'IR' for short.
+// Intermediate Representation, or 'IR' for short.
//
// The normative semantic specification of the ONNX IR is found in docs/IR.md.
// Definitions of the built-in neural network operators may be found in docs/Operators.md.
// Notes
//
-// Release
-//
-// We are still in the very early stage of defining ONNX. The current
-// version of ONNX is a starting point. While we are actively working
-// towards a complete spec, we would like to get the community involved
-// by sharing our working version of ONNX.
-//
// Protobuf compatibility
-//
-// To simplify framework compatibility, ONNX is defined using the subset of protobuf
+//
+// To simplify framework compatibility, ONNX is defined using the subset of protobuf
// that is compatible with both protobuf v2 and v3. This means that we do not use any
// protobuf features that are only available in one of the two versions.
//
@@ -60,22 +53,60 @@ enum Version {
_START_VERSION = 0;
// The version field is always serialized and we will use it to store the
// version that the graph is generated from. This helps us set up version
- // control. We should use version as
- // xx(major) - xx(minor) - xxxx(bugfix)
- // and we are starting with 0x00000001 (0.0.1), which was the
- // version we published on Oct 10, 2017.
- IR_VERSION_2017_10_10 = 0x00000001;
+ // control.
+ // For the IR, we are using simple numbers starting with 0x00000001,
+ // which was the version we published on Oct 10, 2017.
+ IR_VERSION_2017_10_10 = 0x0000000000000001;
- // IR_VERSION 0.0.2 published on Oct 30, 2017
+ // IR_VERSION 2 published on Oct 30, 2017
// - Added type discriminator to AttributeProto to support proto3 users
- IR_VERSION_2017_10_30 = 0x00000002;
+ IR_VERSION_2017_10_30 = 0x0000000000000002;
- // IR VERSION 0.0.3 published on Nov 3, 2017
+ // IR VERSION 3 published on Nov 3, 2017
// - For operator versioning:
// - Added new message OperatorSetIdProto
// - Added opset_import in ModelProto
// - For vendor extensions, added domain in NodeProto
- IR_VERSION = 0x00000003;
+ IR_VERSION_2017_11_3 = 0x0000000000000003;
+
+ // IR VERSION 4 published on Jan 22, 2019
+ // - Relax constraint that initializers should be a subset of graph inputs
+ // - Add type BFLOAT16
+ IR_VERSION_2019_1_22 = 0x0000000000000004;
+
+ // IR VERSION 5 published on March 18, 2019
+ // - Add message TensorAnnotation.
+ // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters.
+ IR_VERSION_2019_3_18 = 0x0000000000000005;
+
+ // IR VERSION 6 published on Sep 19, 2019
+ // - Add support for sparse tensor constants stored in model.
+ // - Add message SparseTensorProto
+ // - Add sparse initializers
+ IR_VERSION_2019_9_19 = 0x0000000000000006;
+
+ // IR VERSION 7 published on May 8, 2020
+  // - Add support to allow function body graph to rely on multiple external operator sets.
+ // - Add a list to promote inference graph's initializers to global and
+ // mutable variables. Global variables are visible in all graphs of the
+ // stored models.
+ // - Add message TrainingInfoProto to store initialization
+ // method and training algorithm. The execution of TrainingInfoProto
+ // can modify the values of mutable variables.
+ // - Implicitly add inference graph into each TrainingInfoProto's algorithm.
+ IR_VERSION_2020_5_8 = 0x0000000000000007;
+
+ // IR VERSION 8 published on July 30, 2021
+ // Introduce TypeProto.SparseTensor
+ // Introduce TypeProto.Optional
+ // Added a list of FunctionProtos local to the model
+ // Deprecated since_version and operator status from FunctionProto
+ IR_VERSION_2021_7_30 = 0x0000000000000008;
+
+ // IR VERSION 9 published on May 5, 2023
+ // Added AttributeProto to FunctionProto so that default attribute values can be set.
+ // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ.
+ IR_VERSION = 0x0000000000000009;
}
// Attributes
@@ -95,17 +126,21 @@ message AttributeProto {
STRING = 3;
TENSOR = 4;
GRAPH = 5;
+ SPARSE_TENSOR = 11;
+ TYPE_PROTO = 13;
FLOATS = 6;
INTS = 7;
STRINGS = 8;
TENSORS = 9;
GRAPHS = 10;
+ SPARSE_TENSORS = 12;
+ TYPE_PROTOS = 14;
}
// The name field MUST be present for this version of the IR.
optional string name = 1; // namespace Attribute
-
+
// if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
// In this case, this AttributeProto does not contain data, and it's a reference of attribute
// in parent scope.
@@ -117,10 +152,10 @@ message AttributeProto {
// The type field MUST be present for this version of the IR.
// For 0.0.1 versions of the IR, this field was not defined, and
- // implementations needed to use has_field hueristics to determine
+ // implementations needed to use has_field heuristics to determine
// which value field was in use. For IR_VERSION 0.0.2 or later, this
// field MUST be set and match the f|i|s|t|... field in use. This
- // change was made to accomodate proto3 implementations.
+ // change was made to accommodate proto3 implementations.
optional AttributeType type = 20; // discriminator that indicates which field below is in use
// Exactly ONE of the following fields must be present for this version of the IR
@@ -129,14 +164,18 @@ message AttributeProto {
optional bytes s = 4; // UTF-8 string
optional TensorProto t = 5; // tensor value
optional GraphProto g = 6; // graph
+ optional SparseTensorProto sparse_tensor = 22; // sparse tensor value
// Do not use field below, it's deprecated.
// optional ValueProto v = 12; // value - subsumes everything but graph
+ optional TypeProto tp = 14; // type proto
repeated float floats = 7; // list of floats
repeated int64 ints = 8; // list of ints
repeated bytes strings = 9; // list of UTF-8 strings
repeated TensorProto tensors = 10; // list of tensors
repeated GraphProto graphs = 11; // list of graph
+ repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors
+ repeated TypeProto type_protos = 15;// list of type protos
}
// Defines information on value, including the name, the type, and
@@ -144,7 +183,8 @@ message AttributeProto {
message ValueInfoProto {
// This field MUST be present in this version of the IR.
optional string name = 1; // namespace Value
- // This field MUST be present in this version of the IR.
+ // This field MUST be present in this version of the IR for
+ // inputs and outputs of the top-level graph.
optional TypeProto type = 2;
// A human-readable documentation for this value. Markdown is allowed.
optional string doc_string = 3;
@@ -155,7 +195,7 @@ message ValueInfoProto {
// Computation graphs are made up of a DAG of nodes, which represent what is
// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
//
-// For example, it can be a node of type "Conv" that takes in an image, a filter
+// For example, it can be a node of type "Conv" that takes in an image, a filter
// tensor and a bias tensor, and produces the convolved output.
message NodeProto {
repeated string input = 1; // namespace Value
@@ -177,12 +217,130 @@ message NodeProto {
optional string doc_string = 6;
}
+// Training information
+// TrainingInfoProto stores information for training a model.
+// In particular, this defines two functionalities: an initialization-step
+// and a training-algorithm-step. Initialization resets the model
+// back to its original state as if no training has been performed.
+// Training algorithm improves the model based on input data.
+//
+// The semantics of the initialization-step is that the initializers
+// in ModelProto.graph and in TrainingInfoProto.algorithm are first
+// initialized as specified by the initializers in the graph, and then
+// updated by the "initialization_binding" in every instance in
+// ModelProto.training_info.
+//
+// The field "algorithm" defines a computation graph which represents a
+// training algorithm's step. After the execution of a
+// TrainingInfoProto.algorithm, the initializers specified by "update_binding"
+// may be immediately updated. If the targeted training algorithm contains
+// consecutive update steps (such as block coordinate descent methods),
+// the user needs to create a TrainingInfoProto for each step.
+message TrainingInfoProto {
+ // This field describes a graph to compute the initial tensors
+ // upon starting the training process. Initialization graph has no input
+ // and can have multiple outputs. Usually, trainable tensors in neural
+ // networks are randomly initialized. To achieve that, for each tensor,
+ // the user can put a random number operator such as RandomNormal or
+ // RandomUniform in TrainingInfoProto.initialization.node and assign its
+ // random output to the specific tensor using "initialization_binding".
+ // This graph can also set the initializers in "algorithm" in the same
+ // TrainingInfoProto; a use case is resetting the number of training
+ // iteration to zero.
+ //
+ // By default, this field is an empty graph and its evaluation does not
+ // produce any output. Thus, no initializer would be changed by default.
+ optional GraphProto initialization = 1;
+
+ // This field represents a training algorithm step. Given required inputs,
+ // it computes outputs to update initializers in its own or inference graph's
+ // initializer lists. In general, this field contains loss node, gradient node,
+ // optimizer node, increment of iteration count.
+ //
+ // An execution of the training algorithm step is performed by executing the
+ // graph obtained by combining the inference graph (namely "ModelProto.graph")
+ // and the "algorithm" graph. That is, the actual
+ // input/initializer/output/node/value_info/sparse_initializer list of
+ // the training graph is the concatenation of
+ // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer"
+ // and "algorithm.input/initializer/output/node/value_info/sparse_initializer"
+ // in that order. This combined graph must satisfy the normal ONNX conditions.
+ // Now, let's provide a visualization of graph combination for clarity.
+ // Let the inference graph (i.e., "ModelProto.graph") be
+ // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d
+ // and the "algorithm" graph be
+ // tensor_d -> Add -> tensor_e
+ // The combination process results
+ // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e
+ //
+ // Notice that an input of a node in the "algorithm" graph may reference the
+ // output of a node in the inference graph (but not the other way round). Also, inference
+ // node cannot reference inputs of "algorithm". With these restrictions, inference graph
+ // can always be run independently without training information.
+ //
+ // By default, this field is an empty graph and its evaluation does not
+ // produce any output. Evaluating the default training step never
+  // updates any initializers.
+ optional GraphProto algorithm = 2;
+
+ // This field specifies the bindings from the outputs of "initialization" to
+ // some initializers in "ModelProto.graph.initializer" and
+ // the "algorithm.initializer" in the same TrainingInfoProto.
+ // See "update_binding" below for details.
+ //
+ // By default, this field is empty and no initializer would be changed
+ // by the execution of "initialization".
+ repeated StringStringEntryProto initialization_binding = 3;
+
+ // Gradient-based training is usually an iterative procedure. In one gradient
+ // descent iteration, we apply
+ //
+ // x = x - r * g
+ //
+ // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is
+ // gradient of "x" with respect to a chosen loss. To avoid adding assignments
+ // into the training graph, we split the update equation into
+ //
+ // y = x - r * g
+ // x = y
+ //
+ // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To
+ // tell that "y" should be assigned to "x", the field "update_binding" may
+ // contain a key-value pair of strings, "x" (key of StringStringEntryProto)
+ // and "y" (value of StringStringEntryProto).
+ // For a neural network with multiple trainable (mutable) tensors, there can
+ // be multiple key-value pairs in "update_binding".
+ //
+  // The initializers appearing as keys in "update_binding" are considered
+ // mutable variables. This implies some behaviors
+ // as described below.
+ //
+ // 1. We have only unique keys in all "update_binding"s so that two
+ // variables may not have the same name. This ensures that one
+ // variable is assigned up to once.
+ // 2. The keys must appear in names of "ModelProto.graph.initializer" or
+ // "TrainingInfoProto.algorithm.initializer".
+ // 3. The values must be output names of "algorithm" or "ModelProto.graph.output".
+ // 4. Mutable variables are initialized to the value specified by the
+ // corresponding initializer, and then potentially updated by
+ // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s.
+ //
+ // This field usually contains names of trainable tensors
+ // (in ModelProto.graph), optimizer states such as momentums in advanced
+ // stochastic gradient methods (in TrainingInfoProto.graph),
+ // and number of training iterations (in TrainingInfoProto.graph).
+ //
+ // By default, this field is empty and no initializer would be changed
+ // by the execution of "algorithm".
+ repeated StringStringEntryProto update_binding = 4;
+}
+
// Models
//
// ModelProto is a top-level file/container format for bundling a ML model and
// associating its computation graph with metadata.
//
-// The semantics of the model are described by the associated GraphProto.
+// The semantics of the model are described by the associated GraphProto's.
message ModelProto {
// The version of the IR this model targets. See Version enum above.
// This field MUST be present.
@@ -227,18 +385,58 @@ message ModelProto {
// Named metadata values; keys should be distinct.
repeated StringStringEntryProto metadata_props = 14;
+
+ // Training-specific information. Sequentially executing all stored
+ // `TrainingInfoProto.algorithm`s and assigning their outputs following
+ // the corresponding `TrainingInfoProto.update_binding`s is one training
+ // iteration. Similarly, to initialize the model
+ // (as if training hasn't happened), the user should sequentially execute
+ // all stored `TrainingInfoProto.initialization`s and assigns their outputs
+ // using `TrainingInfoProto.initialization_binding`s.
+ //
+ // If this field is empty, the training behavior of the model is undefined.
+ repeated TrainingInfoProto training_info = 20;
+
+ // A list of function protos local to the model.
+ //
+ // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain".
+ // In case of any conflicts the behavior (whether the model local functions are given higher priority,
+  // or standard operator sets are given higher priority or this is treated as error) is defined by
+ // the runtimes.
+ //
+ // The operator sets imported by FunctionProto should be compatible with the ones
+ // imported by ModelProto and other model local FunctionProtos.
+ // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto
+ // or by 2 FunctionProtos then versions for the operator set may be different but,
+ // the operator schema returned for op_type, domain, version combination
+ // for both the versions should be same for every node in the function body.
+ //
+ // One FunctionProto can reference other FunctionProto in the model, however, recursive reference
+ // is not allowed.
+ repeated FunctionProto functions = 25;
};
// StringStringEntryProto follows the pattern for cross-proto-version maps.
// See https://developers.google.com/protocol-buffers/docs/proto3#maps
message StringStringEntryProto {
optional string key = 1;
- optional string value= 2;
+ optional string value = 2;
};
+message TensorAnnotation {
+ optional string tensor_name = 1;
+ // <key, value> pairs to annotate tensor specified by <tensor_name> above.
+ // The keys used in the mapping below must be pre-defined in ONNX spec.
+ // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as
+ // quantization parameter keys.
+ repeated StringStringEntryProto quant_parameter_tensor_names = 2;
+}
+
+
+
// Graphs
//
-// A graph defines the computational logic of a model and is comprised of a parameterized
+// A graph defines the computational logic of a model and is comprised of a parameterized
// list of nodes that form a directed acyclic graph based on their inputs and outputs.
// This is the equivalent of the "network" or "graph" in many deep learning
// frameworks.
@@ -250,10 +448,14 @@ message GraphProto {
optional string name = 2; // namespace Graph
// A list of named tensor values, used to specify constant inputs of the graph.
- // Each TensorProto entry must have a distinct name (within the list) that
- // also appears in the input list.
+ // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name.
+ // The name MUST be unique across both initializer and sparse_initializer,
+ // but the name MAY also appear in the input list.
repeated TensorProto initializer = 5;
+ // Initializers (see above) stored in sparse format.
+ repeated SparseTensorProto sparse_initializer = 15;
+
// A human-readable documentation for this graph. Markdown is allowed.
optional string doc_string = 10;
@@ -265,13 +467,14 @@ message GraphProto {
// must be distinct. It is optional for a value to appear in value_info list.
repeated ValueInfoProto value_info = 13;
- // DO NOT USE the following fields, they were deprecated from earlier versions.
- // repeated string input = 3;
- // repeated string output = 4;
- // optional int64 ir_version = 6;
- // optional int64 producer_version = 7;
- // optional string producer_tag = 8;
- // optional string domain = 9;
+ // This field carries information to indicate the mapping among a tensor and its
+ // quantization parameter tensors. For example:
+ // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated,
+ // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model.
+ repeated TensorAnnotation quantization_annotation = 14;
+
+ reserved 3, 4, 6 to 9;
+ reserved "ir_version", "producer_version", "producer_tag", "domain";
}
// Tensors
@@ -291,13 +494,32 @@ message TensorProto {
STRING = 8; // string
BOOL = 9; // bool
- // Advanced types
+ // IEEE754 half-precision floating-point format (16 bits wide).
+ // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
FLOAT16 = 10;
+
DOUBLE = 11;
UINT32 = 12;
UINT64 = 13;
COMPLEX64 = 14; // complex with float32 real and imaginary components
COMPLEX128 = 15; // complex with float64 real and imaginary components
+
+ // Non-IEEE floating-point format based on IEEE754 single-precision
+ // floating-point number truncated to 16 bits.
+ // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
+ BFLOAT16 = 16;
+
+ // Non-IEEE floating-point format based on papers
+ // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433,
+ // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf.
+ // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
+ // The computation usually happens inside a block quantize / dequantize
+ // fused by the runtime.
+ FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf
+ FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
+ FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients
+ FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
+
// Future extensions go here.
}
@@ -305,7 +527,8 @@ message TensorProto {
repeated int64 dims = 1;
// The data type of the tensor.
- optional DataType data_type = 2;
+ // This field MUST have a valid TensorProto.DataType value
+ optional int32 data_type = 2;
// For very large tensors, we may want to store them in chunks, in which
// case the following fields will specify the segment that is stored in
@@ -324,17 +547,17 @@ message TensorProto {
// For float and complex64 values
// Complex64 tensors are encoded as a single array of floats,
// with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
+ // and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
repeated float float_data = 4 [packed = true];
- // For int32, uint8, int8, uint16, int16, bool, and float16 values
- // float16 values must be bit-wise converted to an uint16_t prior
+ // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values
+ // float16 and float8 values must be bit-wise converted to an uint16_t prior
// to writing to the buffer.
// When this field is present, the data_type field MUST be
- // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32
+ // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ
repeated int32 int32_data = 5 [packed = true];
// For strings.
@@ -371,10 +594,32 @@ message TensorProto {
// When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
optional bytes raw_data = 9;
+ // Data can be stored inside the protobuf file using type-specific fields or raw_data.
+ // Alternatively, raw bytes data can be stored in an external file, using the external_data field.
+ // external_data stores key-value pairs describing data location. Recognized keys are:
+ // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX
+ // protobuf model was stored
+  // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string.
+  //                          Offset values SHOULD be multiples of 4096 (page size) to enable mmap support.
+ // Offset values SHOULD be multiples 4096 (page size) to enable mmap support.
+ // - "length" (optional) - number of bytes containing data. Integer stored as string.
+  // - "checksum" (optional) - SHA1 digest of the file specified under the 'location' key.
+ repeated StringStringEntryProto external_data = 13;
+
+ // Location of the data for this tensor. MUST be one of:
+ // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field.
+ // - EXTERNAL - data stored in an external location as described by external_data field.
+ enum DataLocation {
+ DEFAULT = 0;
+ EXTERNAL = 1;
+ }
+
+ // If value not set, data is stored in raw_data (if set) otherwise in type-specified field.
+ optional DataLocation data_location = 14;
+
// For double
- // Complex64 tensors are encoded as a single array of doubles,
+ // Complex128 tensors are encoded as a single array of doubles,
// with the real components appearing in odd numbered positions,
- // and the corresponding imaginary component apparing in the
+ // and the corresponding imaginary component appearing in the
// subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
// is encoded as [1.0, 2.0 ,3.0 ,4.0]
// When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
@@ -386,6 +631,30 @@ message TensorProto {
repeated uint64 uint64_data = 11 [packed = true];
}
+// A serialized sparse-tensor value
+message SparseTensorProto {
+ // The sequence of non-default values are encoded as a tensor of shape [NNZ].
+ // The default-value is zero for numeric tensors, and empty-string for string tensors.
+ // values must have a non-empty name present which serves as a name for SparseTensorProto
+ // when used in sparse_initializer list.
+ optional TensorProto values = 1;
+
+ // The indices of the non-default values, which may be stored in one of two formats.
+ // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value
+ // corresponding to the j-th index of the i-th value (in the values tensor).
+ // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value
+ // must be the linearized-index of the i-th value (in the values tensor).
+ // The linearized-index can be converted into an index tuple (k_1,...,k_rank)
+ // using the shape provided below.
+ // The indices must appear in ascending order without duplication.
+ // In the first format, the ordering is lexicographic-ordering:
+ // e.g., index-value [1,4] must appear before [2,1]
+ optional TensorProto indices = 2;
+
+ // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank]
+ repeated int64 dims = 3;
+}
+
// Defines a tensor shape. A dimension can be either an integer value
// or a symbolic variable. A symbolic variable represents an unknown
// dimension.
@@ -398,36 +667,13 @@ message TensorShapeProto {
// Standard denotation can optionally be used to denote tensor
// dimensions with standard semantic descriptions to ensure
// that operations are applied to the correct axis of a tensor.
+ // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
+ // for pre-defined dimension denotations.
optional string denotation = 3;
};
repeated Dimension dim = 1;
}
-// A set of pre-defined constants to be used as values for
-// the standard denotation field in TensorShapeProto.Dimension
-// for semantic description of the tensor dimension.
-message DenotationConstProto {
- // Describe a batch number dimension.
- optional string DATA_BATCH = 1 [default = "DATA_BATCH"];
- // Describe a channel dimension.
- optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"];
- // Describe a time dimension.
- optional string DATA_TIME = 3 [default = "DATA_TIME"];
- // Describe a feature dimension. This is typically a feature
- // dimension in RNN and/or spatial dimension in CNN.
- optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"];
- // Describe a filter in-channel dimension. This is the dimension
- // that is identical (in size) to the channel dimension of the input
- // image feature maps.
- optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"];
- // Describe a filter out channel dimension. This is the dimension
- // that is identical (int size) to the channel dimension of the output
- // image feature maps.
- optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"];
- // Describe a filter spatial dimension.
- optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"];
-}
-
// Types
//
// The standard ONNX data types.
@@ -435,8 +681,43 @@ message TypeProto {
message Tensor {
// This field MUST NOT have the value of UNDEFINED
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ optional int32 elem_type = 1;
+ optional TensorShapeProto shape = 2;
+ }
+
+ // repeated T
+ message Sequence {
+ // The type and optional shape of each element of the sequence.
+ // This field MUST be present for this version of the IR.
+ optional TypeProto elem_type = 1;
+ };
+
+ // map<K,V>
+ message Map {
+ // This field MUST have a valid TensorProto.DataType value
+ // This field MUST be present for this version of the IR.
+ // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING
+ optional int32 key_type = 1;
+ // This field MUST be present for this version of the IR.
+ optional TypeProto value_type = 2;
+ };
+
+ // wrapper for Tensor, Sequence, or Map
+ message Optional {
+ // The type and optional shape of the element wrapped.
+ // This field MUST be present for this version of the IR.
+ // Possible values correspond to OptionalProto.DataType enum
+ optional TypeProto elem_type = 1;
+ };
+
+
+ message SparseTensor {
+ // This field MUST NOT have the value of UNDEFINED
+ // This field MUST have a valid TensorProto.DataType value
// This field MUST be present for this version of the IR.
- optional TensorProto.DataType elem_type = 1;
+ optional int32 elem_type = 1;
optional TensorShapeProto shape = 2;
}
@@ -445,7 +726,31 @@ message TypeProto {
// The type of a tensor.
Tensor tensor_type = 1;
+ // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values
+ // as input and output to graphs and nodes. These types are needed to naturally
+ // support classical ML operators. DNN operators SHOULD restrict their input
+ // and output types to tensors.
+
+ // The type of a sequence.
+ Sequence sequence_type = 4;
+
+ // The type of a map.
+ Map map_type = 5;
+
+ // The type of an optional.
+ Optional optional_type = 9;
+
+
+ // Type of the sparse tensor
+ SparseTensor sparse_tensor_type = 8;
+
}
+
+ // An optional denotation can be used to denote the whole
+ // type with a standard semantic description as to what is
+ // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition
+ // for pre-defined type denotations.
+ optional string denotation = 6;
}
// Operator Sets
@@ -461,4 +766,70 @@ message OperatorSetIdProto {
// The version of the operator set being identified.
// This field MUST be present in this version of the IR.
optional int64 version = 2;
-} \ No newline at end of file
+}
+
+// Operator/function status.
+enum OperatorStatus {
+ EXPERIMENTAL = 0;
+ STABLE = 1;
+}
+
+message FunctionProto {
+ // The name of the function, similar usage of op_type in OperatorProto.
+ // Combined with FunctionProto.domain, this forms the unique identity of
+ // the FunctionProto.
+ optional string name = 1;
+
+ // Deprecated since IR Version 8
+ // optional int64 since_version = 2;
+ reserved 2;
+ reserved "since_version";
+
+ // Deprecated since IR Version 8
+ // optional OperatorStatus status = 3;
+ reserved 3;
+ reserved "status";
+
+ // The inputs and outputs of the function.
+ repeated string input = 4;
+ repeated string output = 5;
+
+ // The attribute parameters of the function.
+ // It is for function parameters without default values.
+ repeated string attribute = 6;
+
+ // The attribute protos of the function.
+ // It is for function attributes with default values.
+ // A function attribute shall be represented either as
+ // a string attribute or an AttributeProto, not both.
+ repeated AttributeProto attribute_proto = 11;
+
+ // The nodes in the function.
+ repeated NodeProto node = 7;
+ // A human-readable documentation for this function. Markdown is allowed.
+ optional string doc_string = 8;
+
+ // The OperatorSets this function body (graph) relies on.
+ //
+ // All nodes in the function body (graph) will bind against the operator
+ // with the same-domain/same-op_type operator with the HIGHEST version
+ // in the referenced operator sets. This means at most one version can be relied
+ // upon for one domain.
+ //
+ // The operator sets imported by FunctionProto should be compatible with the ones
+ // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto
+ // and ModelProto then versions for the operator set may be different but,
+ // the operator schema returned for op_type, domain, version combination
+ // for both the versions should be same.
+
+ repeated OperatorSetIdProto opset_import = 9;
+
+ // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of
+ // the FunctionProto.
+ optional string domain = 10;
+}
+
+
+// For using protobuf-lite
+option optimize_for = LITE_RUNTIME;
+
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index 3ef96cdf166..2b707c3beb3 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -775,10 +775,10 @@ public class OnnxOperationsTestCase {
Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder();
tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get()));
if (tensor.type().valueType() == TensorType.Value.FLOAT) {
- builder.setDataType(Onnx.TensorProto.DataType.FLOAT);
+ builder.setDataType(Onnx.TensorProto.DataType.FLOAT_VALUE);
tensor.valueIterator().forEachRemaining(d -> builder.addFloatData(d.floatValue()));
} else {
- builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);
+ builder.setDataType(Onnx.TensorProto.DataType.DOUBLE_VALUE);
tensor.valueIterator().forEachRemaining(builder::addDoubleData);
}
Onnx.TensorProto val = builder.build();
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java
index cd7eee8ba50..ad067110f59 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.node.admin.nodeadmin;
+import ai.vespa.metrics.ContainerMetrics;
import com.yahoo.jdisc.Timer;
import com.yahoo.vespa.hosted.node.admin.container.ContainerStats;
import com.yahoo.vespa.hosted.node.admin.container.metrics.Counter;
@@ -80,9 +81,9 @@ public class NodeAdminImpl implements NodeAdmin {
new Dimensions(Map.of("src", "node-agents")));
this.procMeminfoReader = procMeminfoReader;
- this.jvmHeapUsed = metrics.declareGauge("mem.heap.used");
- this.jvmHeapFree = metrics.declareGauge("mem.heap.free");
- this.jvmHeapTotal = metrics.declareGauge("mem.heap.total");
+ this.jvmHeapUsed = metrics.declareGauge(ContainerMetrics.MEM_HEAP_USED.baseName());
+ this.jvmHeapFree = metrics.declareGauge(ContainerMetrics.MEM_HEAP_FREE.baseName());
+ this.jvmHeapTotal = metrics.declareGauge(ContainerMetrics.MEM_HEAP_TOTAL.baseName());
this.containerCount = metrics.declareGauge("container.count");
this.metrics = metrics;
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java
index dc86daf2c67..9bc18533ddf 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java
@@ -17,9 +17,16 @@ import java.util.Objects;
*/
public final class LockedNodeList extends NodeList {
+ private final Mutex lock;
+
public LockedNodeList(List<Node> nodes, Mutex lock) {
super(nodes, false);
- Objects.requireNonNull(lock, "lock must be non-null");
+ this.lock = Objects.requireNonNull(lock, "lock must be non-null");
+ }
+
+ /** Returns a new LockedNodeList with the given nodes, for the same lock. */
+ public LockedNodeList childList(List<Node> nodes) {
+ return new LockedNodeList(nodes, lock);
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java
index 60fd07951c6..20c246b3ebd 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java
@@ -28,4 +28,5 @@ public class NodeMutex implements Mutex {
return new NodeMutex(updatedNode, mutex);
}
+
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java
index d6671d41cbd..9da66413b9c 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java
@@ -229,7 +229,7 @@ public class NodeRepository extends AbstractComponent {
applicationNodes.asList(),
Agent.system,
Optional.of("Application is removed"),
- transaction.nested());
+ transaction);
applications.remove(transaction);
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java
index 8766dea3d61..e300591fbb2 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java
@@ -3,6 +3,8 @@ package com.yahoo.vespa.hosted.provision.maintenance;
import com.yahoo.jdisc.Metric;
import com.yahoo.vespa.hosted.provision.Node;
+import com.yahoo.vespa.hosted.provision.Node.State;
+import com.yahoo.vespa.hosted.provision.NodeList;
import com.yahoo.vespa.hosted.provision.NodeRepository;
import com.yahoo.vespa.hosted.provision.node.Agent;
import com.yahoo.vespa.hosted.provision.node.History;
@@ -33,8 +35,12 @@ public class DirtyExpirer extends Expirer {
@Override
protected void expire(List<Node> expired) {
- for (Node expiredNode : expired)
- nodeRepository().nodes().fail(expiredNode.hostname(), wantToDeprovisionOnExpiry, Agent.DirtyExpirer, "Node is stuck in dirty");
+ nodeRepository().nodes().performOn(NodeList.copyOf(expired),
+ node -> node.state() == State.dirty && isExpired(node),
+ (node, lock) -> nodeRepository().nodes().fail(node.hostname(),
+ wantToDeprovisionOnExpiry,
+ Agent.DirtyExpirer,
+ "Node is stuck in dirty"));
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java
index fa3f9435c70..cb0a8005e87 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java
@@ -6,13 +6,14 @@ import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.Zone;
import com.yahoo.jdisc.Metric;
import com.yahoo.vespa.hosted.provision.Node;
+import com.yahoo.vespa.hosted.provision.Node.State;
import com.yahoo.vespa.hosted.provision.NodeList;
+import com.yahoo.vespa.hosted.provision.NodeMutex;
import com.yahoo.vespa.hosted.provision.NodeRepository;
import com.yahoo.vespa.hosted.provision.node.Agent;
-import com.yahoo.vespa.hosted.provision.node.History;
+import com.yahoo.vespa.hosted.provision.node.History.Event.Type;
import java.time.Duration;
-import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Predicate;
@@ -67,55 +68,47 @@ public class FailedExpirer extends NodeRepositoryMaintainer {
@Override
protected double maintain() {
- NodeList allNodes = nodeRepository.nodes().list();
- List<Node> remainingNodes = new ArrayList<>(allNodes.state(Node.State.failed)
- .nodeType(NodeType.tenant, NodeType.host)
- .asList());
+ Predicate<Node> isExpired = node -> node.state() == State.failed
+ && node.history().hasEventBefore(Type.failed, clock().instant().minus(expiryFor(node)));
+ NodeList allNodes = nodeRepository.nodes().list(); // Stale snapshot, not critical.
- recycleIf(node -> node.allocation().isEmpty(), remainingNodes, allNodes);
- recycleIf(node -> !node.allocation().get().membership().cluster().isStateful() &&
- node.history().hasEventBefore(History.Event.Type.failed, clock().instant().minus(statelessExpiry)),
- remainingNodes,
- allNodes);
- recycleIf(node -> node.allocation().get().membership().cluster().isStateful() &&
- node.history().hasEventBefore(History.Event.Type.failed, clock().instant().minus(statefulExpiry)),
- remainingNodes,
- allNodes);
+ nodeRepository.nodes().performOn(allNodes.nodeType(NodeType.tenant),
+ isExpired,
+ (node, lock) -> recycle(node, List.of(), allNodes).get());
+
+ nodeRepository.nodes().performOnRecursively(allNodes.nodeType(NodeType.host),
+ nodes -> isExpired.test(nodes.parent().node()),
+ nodes -> recycle(nodes.parent().node(),
+ nodes.children().stream().map(NodeMutex::node).toList(),
+ allNodes)
+ .map(List::of).orElse(List.of()));
return 1.0;
}
- /** Recycle the nodes matching condition, and remove those nodes from the nodes list. */
- private void recycleIf(Predicate<Node> condition, List<Node> failedNodes, NodeList allNodes) {
- List<Node> nodesToRecycle = failedNodes.stream().filter(condition).toList();
- failedNodes.removeAll(nodesToRecycle);
- recycle(nodesToRecycle, allNodes);
+ private Duration expiryFor(Node node) {
+ return node.allocation().isEmpty() ? Duration.ZERO
+ : node.allocation().get().membership().cluster().isStateful() ? statefulExpiry
+ : statelessExpiry;
}
- /** Move eligible nodes to dirty or parked. This may be a subset of the given nodes */
- private void recycle(List<Node> nodes, NodeList allNodes) {
- List<Node> nodesToRecycle = new ArrayList<>();
- for (Node candidate : nodes) {
- Optional<String> reason = shouldPark(candidate, allNodes);
- if (reason.isPresent()) {
- List<String> unparkedChildren = candidate.type().isHost() ?
- allNodes.childrenOf(candidate)
- .not()
- .state(Node.State.parked)
- .mapToList(Node::hostname) :
- List.of();
-
- if (unparkedChildren.isEmpty()) {
- nodeRepository.nodes().park(candidate.hostname(), true, Agent.FailedExpirer,
- "Parked by FailedExpirer due to " + reason.get());
- } else {
- log.info(String.format("Expired failed node %s was not parked because of unparked children: %s",
- candidate.hostname(), String.join(", ", unparkedChildren)));
- }
+ private Optional<Node> recycle(Node node, List<Node> children, NodeList allNodes) {
+ Optional<String> reason = shouldPark(node, allNodes);
+ if (reason.isPresent()) {
+ List<String> unparkedChildren = children.stream()
+ .filter(child -> child.state() != Node.State.parked)
+ .map(Node::hostname)
+ .toList();
+ if (unparkedChildren.isEmpty()) {
+ return Optional.of(nodeRepository.nodes().park(node.hostname(), true, Agent.FailedExpirer,
+ "Parked by FailedExpirer due to " + reason.get()));
} else {
- nodesToRecycle.add(candidate);
+ log.info(String.format("Expired failed node %s was not parked because of unparked children: %s",
+ node.hostname(), String.join(", ", unparkedChildren)));
+ return Optional.empty();
}
+ } else {
+ return Optional.of(nodeRepository.nodes().deallocate(node, Agent.FailedExpirer, "Expired by FailedExpirer"));
}
- nodeRepository.nodes().deallocate(nodesToRecycle, Agent.FailedExpirer, "Expired by FailedExpirer");
}
/** Returns whether the node should be parked instead of recycled */
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainer.java
index e9e3fd5179a..d70ee825860 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainer.java
@@ -72,7 +72,14 @@ public class HostCapacityMaintainer extends NodeRepositoryMaintainer {
protected double maintain() {
List<Node> provisionedSnapshot;
try {
- provisionedSnapshot = provision(nodeRepository().nodes().list());
+ NodeList nodes;
+ // Host and child nodes are written in separate transactions, but both are written while holding the
+ // unallocated lock. Hold the unallocated lock while reading nodes to ensure we get all the children
+ // of newly provisioned hosts.
+ try (Mutex ignored = nodeRepository().nodes().lockUnallocated()) {
+ nodes = nodeRepository().nodes().list();
+ }
+ provisionedSnapshot = provision(nodes);
} catch (NodeAllocationException | IllegalStateException e) {
log.log(Level.WARNING, "Failed to allocate preprovisioned capacity and/or find excess hosts: " + e.getMessage());
return 0; // avoid removing excess hosts
@@ -85,16 +92,7 @@ public class HostCapacityMaintainer extends NodeRepositoryMaintainer {
}
private double markForRemoval(List<Node> provisionedSnapshot) {
- // Group nodes by parent; no parent means it's a host.
- Map<Optional<String>, List<Node>> nodesByParent = provisionedSnapshot.stream().collect(groupingBy(Node::parentHostname));
-
- // Find all hosts that we once thought were empty (first clause), or whose children are now all removable (second clause).
- List<Node> emptyHosts = nodesByParent.get(Optional.<String>empty()).stream()
- .filter(host -> host.hostEmptyAt().isPresent()
- || nodesByParent.getOrDefault(Optional.of(host.hostname()), List.of())
- .stream().allMatch(HostCapacityMaintainer::canDeprovision))
- .toList();
-
+ List<Node> emptyHosts = findEmptyOrRemovableHosts(provisionedSnapshot);
if (emptyHosts.isEmpty()) return 1;
int attempts = 0, success = 0;
@@ -108,18 +106,16 @@ public class HostCapacityMaintainer extends NodeRepositoryMaintainer {
// Re-read all nodes under lock and compute the candidates for removal. The actual nodes we want
// to mark for removal is the intersection with typeEmptyHosts, which excludes the preprovisioned hosts.
Map<Optional<String>, List<Node>> currentNodesByParent = nodeRepository().nodes().list().stream().collect(groupingBy(Node::parentHostname));
- List<Node> candidateHosts = new ArrayList<>(currentNodesByParent.get(Optional.<String>empty()));
+ List<Node> candidateHosts = new ArrayList<>(getHosts(currentNodesByParent));
candidateHosts.retainAll(typeEmptyHosts);
for (Node host : candidateHosts) {
attempts++;
// Any hosts that are no longer empty should be marked as such, and excluded from removal.
- if (currentNodesByParent.getOrDefault(Optional.of(host.hostname()), List.of())
- .stream().anyMatch(n -> ! canDeprovision(n))) {
- if (host.hostEmptyAt().isPresent()) {
- nodeRepository().nodes().write(host.withHostEmptyAt(null), lock);
- }
+ if (currentNodesByParent.getOrDefault(Optional.of(host.hostname()), List.of()).stream().anyMatch(n -> ! canDeprovision(n))
+ && host.hostEmptyAt().isPresent()) {
+ nodeRepository().nodes().write(host.withHostEmptyAt(null), lock);
}
// If the host is still empty, we can mark it as empty now, or mark it for removal if it has already expired.
else {
@@ -282,11 +278,38 @@ public class HostCapacityMaintainer extends NodeRepositoryMaintainer {
nodeResources,
nodeRepository().clock().instant()))
.toList();
-
}
private static NodeResources toNodeResources(ClusterCapacity clusterCapacity) {
- return new NodeResources(clusterCapacity.vcpu(), clusterCapacity.memoryGb(), clusterCapacity.diskGb(),
- clusterCapacity.bandwidthGbps());
+ return new NodeResources(clusterCapacity.vcpu(),
+ clusterCapacity.memoryGb(),
+ clusterCapacity.diskGb(),
+ clusterCapacity.bandwidthGbps(),
+ NodeResources.DiskSpeed.valueOf(clusterCapacity.diskSpeed()),
+ NodeResources.StorageType.valueOf(clusterCapacity.storageType()),
+ NodeResources.Architecture.valueOf(clusterCapacity.architecture()));
+ }
+
+ private static List<Node> findEmptyOrRemovableHosts(List<Node> provisionedSnapshot) {
+ // Group nodes by parent; no parent means it's a host.
+ var nodesByParent = provisionedSnapshot.stream().collect(groupingBy(Node::parentHostname));
+
+ // Find all hosts that we once thought were empty (first clause), or whose children are now all removable (second clause).
+ return getHosts(nodesByParent).stream()
+ .filter(host -> host.hostEmptyAt().isPresent() || allChildrenCanBeDeprovisioned(nodesByParent, host))
+ .toList();
+ }
+
+ private static List<Node> getHosts(Map<Optional<String>, List<Node>> nodesByParent) {
+ return nodesByParent.get(Optional.<String>empty());
}
+
+ private static List<Node> getChildren(Map<Optional<String>, List<Node>> nodesByParent, Node host) {
+ return nodesByParent.getOrDefault(Optional.of(host.hostname()), List.of());
+ }
+
+ private static boolean allChildrenCanBeDeprovisioned(Map<Optional<String>, List<Node>> nodesByParent, Node host) {
+ return getChildren(nodesByParent, host).stream().allMatch(HostCapacityMaintainer::canDeprovision);
+ }
+
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostResumeProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostResumeProvisioner.java
index e1624183607..fe89ba17469 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostResumeProvisioner.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostResumeProvisioner.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.provision.maintenance;
+import com.yahoo.config.provision.CloudAccount;
import com.yahoo.config.provision.NodeType;
import com.yahoo.jdisc.Metric;
import com.yahoo.transaction.Mutex;
@@ -49,18 +50,15 @@ public class HostResumeProvisioner extends NodeRepositoryMaintainer {
NodeList hosts = allNodes.state(Node.State.provisioned).nodeType(NodeType.host, NodeType.confighost, NodeType.controllerhost);
int failures = 0;
for (Node host : hosts) {
- NodeList children = allNodes.childrenOf(host);
try {
- log.log(Level.INFO, "Provisioning " + host.hostname() + " with " + children.size() + " children");
- HostIpConfig hostIpConfig = hostProvisioner.provision(host, children.asSet());
- setIpConfig(host, children, hostIpConfig);
+ HostIpConfig hostIpConfig = hostProvisioner.provision(host);
+ setIpConfig(host, hostIpConfig);
} catch (IllegalArgumentException | IllegalStateException e) {
- log.log(Level.INFO, "Could not provision " + host.hostname() + " with " + children.size() + " children, will retry in " +
+ log.log(Level.INFO, "Could not provision " + host.hostname() + ", will retry in " +
interval() + ": " + Exceptions.toMessageString(e));
} catch (FatalProvisioningException e) {
failures++;
- log.log(Level.SEVERE, "Failed to provision " + host.hostname() + " with " + children.size() +
- " children, failing out the host recursively", e);
+ log.log(Level.SEVERE, "Failed to provision " + host.hostname() + ", failing out the host recursively", e);
nodeRepository().nodes().failOrMarkRecursively(
host.hostname(), Agent.HostResumeProvisioner, "Failed by HostResumeProvisioner due to provisioning failure");
} catch (RuntimeException e) {
@@ -75,19 +73,17 @@ public class HostResumeProvisioner extends NodeRepositoryMaintainer {
return asSuccessFactorDeviation(hosts.size(), failures);
}
- private void setIpConfig(Node host, NodeList children, HostIpConfig hostIpConfig) {
+ private void setIpConfig(Node host, HostIpConfig hostIpConfig) {
if (hostIpConfig.isEmpty()) return;
- NodeList nodes = NodeList.of(host).and(children);
- for (var node : nodes) {
- verifyDns(node, hostIpConfig.require(node.hostname()));
- }
+ hostIpConfig.asMap().forEach((hostname, ipConfig) ->
+ verifyDns(hostname, host.type(), host.cloudAccount(), ipConfig));
nodeRepository().nodes().setIpConfig(hostIpConfig);
}
/** Verify DNS configuration of given node */
- private void verifyDns(Node node, IP.Config ipConfig) {
+ private void verifyDns(String hostname, NodeType hostType, CloudAccount cloudAccount, IP.Config ipConfig) {
for (String ipAddress : ipConfig.primary()) {
- IP.verifyDns(node.hostname(), ipAddress, node.type(), nodeRepository().nameResolver(), node.cloudAccount(), nodeRepository().zone());
+ IP.verifyDns(hostname, ipAddress, hostType, nodeRepository().nameResolver(), cloudAccount, nodeRepository().zone());
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java
index aa7aac34389..503ac4be86c 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java
@@ -3,6 +3,8 @@ package com.yahoo.vespa.hosted.provision.maintenance;
import com.yahoo.jdisc.Metric;
import com.yahoo.vespa.hosted.provision.Node;
+import com.yahoo.vespa.hosted.provision.Node.State;
+import com.yahoo.vespa.hosted.provision.NodeList;
import com.yahoo.vespa.hosted.provision.NodeRepository;
import com.yahoo.vespa.hosted.provision.node.Agent;
import com.yahoo.vespa.hosted.provision.node.History;
@@ -40,9 +42,9 @@ public class InactiveExpirer extends Expirer {
@Override
protected void expire(List<Node> expired) {
- expired.forEach(node -> {
- nodeRepository.nodes().deallocate(node, Agent.InactiveExpirer, "Expired by InactiveExpirer");
- });
+ nodeRepository.nodes().performOn(NodeList.copyOf(expired),
+ node -> node.state() == State.inactive && isExpired(node),
+ (node, lock) -> nodeRepository.nodes().deallocate(node, Agent.InactiveExpirer, "Expired by InactiveExpirer"));
}
@Override
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java
index 6f06a2ac22e..2484f496ece 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java
@@ -25,6 +25,8 @@ public class ReservationExpirer extends Expirer {
}
@Override
- protected void expire(List<Node> expired) { nodeRepository().nodes().deallocate(expired, Agent.ReservationExpirer, "Expired by ReservationExpirer"); }
+ protected void expire(List<Node> expired) {
+ nodeRepository().nodes().deallocate(expired, Agent.ReservationExpirer, "Expired by ReservationExpirer");
+ }
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java
index cc7db3c138a..1ff6d2b300d 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java
@@ -113,7 +113,7 @@ public record IP() {
*
* @throws IllegalArgumentException if there are IP conflicts with existing nodes
*/
- public static List<Node> verify(List<Node> nodes, LockedNodeList allNodes) {
+ public static LockedNodeList verify(List<Node> nodes, LockedNodeList allNodes) {
NodeList sortedNodes = allNodes.sortedBy(Comparator.comparing(Node::hostname));
for (var node : nodes) {
for (var other : sortedNodes) {
@@ -135,7 +135,7 @@ public record IP() {
other.hostname());
}
}
- return nodes;
+ return allNodes.childList(nodes);
}
/** Returns whether IP address of existing node can be assigned to node */
@@ -152,7 +152,7 @@ public record IP() {
}
public static Node verify(Node node, LockedNodeList allNodes) {
- return verify(List.of(node), allNodes).get(0);
+ return verify(List.of(node), allNodes).asList().get(0);
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java
index b10a371e8bd..490e7b9ac33 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java
@@ -1,7 +1,6 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.provision.node;
-import com.yahoo.collections.ListMap;
import com.yahoo.component.Version;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.ApplicationTransaction;
@@ -10,6 +9,7 @@ import com.yahoo.config.provision.Flavor;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.Zone;
+import com.yahoo.time.TimeBudget;
import com.yahoo.transaction.Mutex;
import com.yahoo.transaction.NestedTransaction;
import com.yahoo.vespa.applicationmodel.HostName;
@@ -17,6 +17,7 @@ import com.yahoo.vespa.applicationmodel.InfrastructureApplication;
import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.NoSuchNodeException;
import com.yahoo.vespa.hosted.provision.Node;
+import com.yahoo.vespa.hosted.provision.Node.State;
import com.yahoo.vespa.hosted.provision.NodeList;
import com.yahoo.vespa.hosted.provision.NodeMutex;
import com.yahoo.vespa.hosted.provision.applications.Applications;
@@ -31,20 +32,26 @@ import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Comparator;
import java.util.EnumSet;
+import java.util.HashSet;
+import java.util.Iterator;
import java.util.List;
-import java.util.Map;
-import java.util.Objects;
+import java.util.NavigableSet;
import java.util.Optional;
import java.util.Set;
+import java.util.TreeSet;
import java.util.function.BiFunction;
+import java.util.function.Function;
import java.util.function.Predicate;
import java.util.logging.Level;
import java.util.logging.Logger;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
import static com.yahoo.vespa.hosted.provision.restapi.NodePatcher.DROP_DOCUMENTS_REPORT;
+import static java.util.Comparator.comparing;
+import static java.util.stream.Collectors.groupingBy;
+import static java.util.stream.Collectors.joining;
/**
* The nodes in the node repo and their state transitions
@@ -148,7 +155,7 @@ public class Nodes {
if (existing.isPresent())
throw new IllegalStateException("Cannot add " + node + ": A node with this name already exists");
}
- return db.addNodesInState(nodes.asList(), Node.State.reserved, Agent.system);
+ return db.addNodesInState(nodes, Node.State.reserved, Agent.system);
}
/**
@@ -157,7 +164,8 @@ public class Nodes {
* with the history of that node.
*/
public List<Node> addNodes(List<Node> nodes, Agent agent) {
- try (Mutex lock = lockUnallocated()) {
+ try (NodeMutexes existingNodesLocks = lockAndGetAll(nodes, Optional.empty()); // Locks for any existing nodes we may remove.
+ Mutex allocationLock = lockUnallocated()) {
List<Node> nodesToAdd = new ArrayList<>();
List<Node> nodesToRemove = new ArrayList<>();
for (int i = 0; i < nodes.size(); i++) {
@@ -194,7 +202,7 @@ public class Nodes {
}
NestedTransaction transaction = new NestedTransaction();
db.removeNodes(nodesToRemove, transaction);
- List<Node> resultingNodes = db.addNodesInState(IP.Config.verify(nodesToAdd, list(lock)), Node.State.provisioned, agent, transaction);
+ List<Node> resultingNodes = db.addNodesInState(IP.Config.verify(nodesToAdd, list(allocationLock)), Node.State.provisioned, agent, transaction);
transaction.commit();
return resultingNodes;
}
@@ -218,7 +226,7 @@ public class Nodes {
}
/** Activate nodes. This method does <b>not</b> lock the node repository. */
- public List<Node> activate(List<Node> nodes, NestedTransaction transaction) {
+ public List<Node> activate(List<Node> nodes, ApplicationTransaction transaction) {
return db.writeTo(Node.State.active, nodes, Agent.application, Optional.empty(), transaction);
}
@@ -229,8 +237,7 @@ public class Nodes {
* @param reusable move the node directly to {@link Node.State#dirty} after removal
*/
public void setRemovable(NodeList nodes, boolean reusable) {
- performOn(nodes, (node, mutex) -> write(node.with(node.allocation().get().removable(true, reusable)),
- mutex));
+ performOn(nodes, (node, mutex) -> write(node.with(node.allocation().get().removable(true, reusable)), mutex));
}
/**
@@ -239,7 +246,7 @@ public class Nodes {
*/
public List<Node> deactivate(List<Node> nodes, ApplicationTransaction transaction) {
if ( ! zone.environment().isProduction() || zone.system().isCd())
- return deallocate(nodes, Agent.application, "Deactivated by application", transaction.nested());
+ return deallocate(nodes, Agent.application, "Deactivated by application", transaction);
NodeList nodeList = NodeList.copyOf(nodes);
NodeList stateless = nodeList.stateless();
@@ -247,9 +254,9 @@ public class Nodes {
NodeList statefulToInactive = stateful.not().reusable();
NodeList statefulToDirty = stateful.reusable();
List<Node> written = new ArrayList<>();
- written.addAll(deallocate(stateless.asList(), Agent.application, "Deactivated by application", transaction.nested()));
- written.addAll(deallocate(statefulToDirty.asList(), Agent.application, "Deactivated by application (recycled)", transaction.nested()));
- written.addAll(db.writeTo(Node.State.inactive, statefulToInactive.asList(), Agent.application, Optional.empty(), transaction.nested()));
+ written.addAll(deallocate(stateless.asList(), Agent.application, "Deactivated by application", transaction));
+ written.addAll(deallocate(statefulToDirty.asList(), Agent.application, "Deactivated by application (recycled)", transaction));
+ written.addAll(db.writeTo(Node.State.inactive, statefulToInactive.asList(), Agent.application, Optional.empty(), transaction));
return written;
}
@@ -258,21 +265,9 @@ public class Nodes {
* transaction commits.
*/
public List<Node> fail(List<Node> nodes, ApplicationTransaction transaction) {
- return fail(nodes, Agent.application, "Failed by application", transaction.nested());
- }
-
- public List<Node> fail(List<Node> nodes, Agent agent, String reason) {
- NestedTransaction transaction = new NestedTransaction();
- nodes = fail(nodes, agent, reason, transaction);
- transaction.commit();
- return nodes;
- }
-
- private List<Node> fail(List<Node> nodes, Agent agent, String reason, NestedTransaction transaction) {
- nodes = nodes.stream()
- .map(n -> n.withWantToFail(false, agent, clock.instant()))
- .toList();
- return db.writeTo(Node.State.failed, nodes, agent, Optional.of(reason), transaction);
+ return db.writeTo(Node.State.failed,
+ nodes.stream().map(n -> n.withWantToFail(false, Agent.application, clock.instant())).toList(),
+ Agent.application, Optional.of("Failed by application"), transaction);
}
/** Move nodes to the dirty state */
@@ -282,40 +277,48 @@ public class Nodes {
public List<Node> deallocateRecursively(String hostname, Agent agent, String reason) {
Node nodeToDirty = node(hostname).orElseThrow(() -> new NoSuchNodeException("Could not deallocate " + hostname + ": Node not found"));
-
- List<Node> nodesToDirty =
- (nodeToDirty.type().isHost() ?
- Stream.concat(list().childrenOf(hostname).asList().stream(), Stream.of(nodeToDirty)) :
- Stream.of(nodeToDirty)).filter(node -> node.state() != Node.State.dirty).toList();
- List<String> hostnamesNotAllowedToDirty = nodesToDirty.stream()
- .filter(node -> node.state() != Node.State.provisioned)
- .filter(node -> node.state() != Node.State.failed)
- .filter(node -> node.state() != Node.State.parked)
- .filter(node -> node.state() != Node.State.breakfixed)
- .map(Node::hostname).toList();
- if ( ! hostnamesNotAllowedToDirty.isEmpty())
- illegal("Could not deallocate " + nodeToDirty + ": " +
- hostnamesNotAllowedToDirty + " are not in states [provisioned, failed, parked, breakfixed]");
-
- return nodesToDirty.stream().map(node -> deallocate(node, agent, reason)).toList();
+ List<Node> nodesToDirty = new ArrayList<>();
+ try (RecursiveNodeMutexes locked = lockAndGetRecursively(hostname, Optional.empty())) {
+ for (NodeMutex child : locked.children())
+ if (child.node().state() != Node.State.dirty)
+ nodesToDirty.add(child.node());
+
+ if (locked.parent().node().state() != State.dirty)
+ nodesToDirty.add(locked.parent().node());
+
+ List<String> hostnamesNotAllowedToDirty = nodesToDirty.stream()
+ .filter(node -> node.state() != Node.State.provisioned)
+ .filter(node -> node.state() != Node.State.failed)
+ .filter(node -> node.state() != Node.State.parked)
+ .filter(node -> node.state() != Node.State.breakfixed)
+ .map(Node::hostname).toList();
+ if ( ! hostnamesNotAllowedToDirty.isEmpty())
+ illegal("Could not deallocate " + nodeToDirty + ": " +
+ hostnamesNotAllowedToDirty + " are not in states [provisioned, failed, parked, breakfixed]");
+
+ return nodesToDirty.stream().map(node -> deallocate(node, agent, reason)).toList();
+ }
}
/**
- * Set a node dirty or parked, allowed if it is in the provisioned, inactive, failed or parked state.
+ * Set a node dirty or parked, allowed if it is in the provisioned, inactive, failed or parked state.
* Use this to clean newly provisioned nodes or to recycle failed nodes which have been repaired or put on hold.
*/
public Node deallocate(Node node, Agent agent, String reason) {
- NestedTransaction transaction = new NestedTransaction();
- Node deallocated = deallocate(node, agent, reason, transaction);
- transaction.commit();
- return deallocated;
+ try (NodeMutex locked = lockAndGetRequired(node)) {
+ NestedTransaction transaction = new NestedTransaction();
+ Node deallocated = deallocate(locked.node(), agent, reason, transaction);
+ transaction.commit();
+ return deallocated;
+ }
}
- public List<Node> deallocate(List<Node> nodes, Agent agent, String reason, NestedTransaction transaction) {
- return nodes.stream().map(node -> deallocate(node, agent, reason, transaction)).toList();
+ public List<Node> deallocate(List<Node> nodes, Agent agent, String reason, ApplicationTransaction transaction) {
+ return nodes.stream().map(node -> deallocate(node, agent, reason, transaction.nested())).toList();
}
- public Node deallocate(Node node, Agent agent, String reason, NestedTransaction transaction) {
+ // Be sure to hold the right lock!
+ private Node deallocate(Node node, Agent agent, String reason, NestedTransaction transaction) {
if (parkOnDeallocationOf(node, agent)) {
return park(node.hostname(), false, agent, reason, transaction);
} else {
@@ -339,7 +342,9 @@ public class Nodes {
}
public Node fail(String hostname, boolean forceDeprovision, Agent agent, String reason) {
- return move(hostname, Node.State.failed, agent, forceDeprovision, Optional.of(reason));
+ try (NodeMutex lock = lockAndGetRequired(hostname)) {
+ return move(hostname, Node.State.failed, agent, forceDeprovision, Optional.of(reason), lock);
+ }
}
/**
@@ -350,14 +355,16 @@ public class Nodes {
* @return all the nodes that were changed by this request
*/
public List<Node> failOrMarkRecursively(String hostname, Agent agent, String reason) {
- NodeList children = list().childrenOf(hostname);
- List<Node> changed = performOn(children, (node, lock) -> failOrMark(node, agent, reason, lock));
-
- if (children.state(Node.State.active).isEmpty())
- changed.add(move(hostname, Node.State.failed, agent, false, Optional.of(reason)));
- else
- changed.addAll(performOn(NodeList.of(node(hostname).orElseThrow()), (node, lock) -> failOrMark(node, agent, reason, lock)));
+ List<Node> changed = new ArrayList<>();
+ try (RecursiveNodeMutexes nodes = lockAndGetRecursively(hostname, Optional.empty())) {
+ for (NodeMutex child : nodes.children())
+ changed.add(failOrMark(child.node(), agent, reason, child));
+ if (changed.stream().noneMatch(child -> child.state() == Node.State.active))
+ changed.add(move(hostname, Node.State.failed, agent, false, Optional.of(reason), nodes.parent()));
+ else
+ changed.add(failOrMark(nodes.parent().node(), agent, reason, nodes.parent()));
+ }
return changed;
}
@@ -367,12 +374,14 @@ public class Nodes {
write(node, lock);
return node;
} else {
- return move(node.hostname(), Node.State.failed, agent, false, Optional.of(reason));
+ return move(node.hostname(), Node.State.failed, agent, false, Optional.of(reason), lock);
}
}
/** Update IP config for nodes in given config */
public void setIpConfig(HostIpConfig hostIpConfig) {
+ // Ideally this should hold the unallocated lock over the entire method, but the unallocated lock must be
+ // taken after the application lock, making this impossible.
Predicate<Node> nodeInConfig = (node) -> hostIpConfig.contains(node.hostname());
performOn(nodeInConfig, (node, lock) -> {
IP.Config ipConfig = hostIpConfig.require(node.hostname());
@@ -387,10 +396,12 @@ public class Nodes {
* @throws NoSuchNodeException if the node is not found
*/
public Node park(String hostname, boolean forceDeprovision, Agent agent, String reason) {
- NestedTransaction transaction = new NestedTransaction();
- Node parked = park(hostname, forceDeprovision, agent, reason, transaction);
- transaction.commit();
- return parked;
+ try (NodeMutex locked = lockAndGetRequired(hostname)) {
+ NestedTransaction transaction = new NestedTransaction();
+ Node parked = park(hostname, forceDeprovision, agent, reason, transaction);
+ transaction.commit();
+ return parked;
+ }
}
private Node park(String hostname, boolean forceDeprovision, Agent agent, String reason, NestedTransaction transaction) {
@@ -413,36 +424,38 @@ public class Nodes {
* @throws NoSuchNodeException if the node is not found
*/
public Node reactivate(String hostname, Agent agent, String reason) {
- return move(hostname, Node.State.active, agent, false, Optional.of(reason));
+ try (NodeMutex lock = lockAndGetRequired(hostname)) {
+ return move(hostname, Node.State.active, agent, false, Optional.of(reason), lock);
+ }
}
/**
* Moves a host to breakfixed state, removing any children.
*/
public List<Node> breakfixRecursively(String hostname, Agent agent, String reason) {
- Node node = requireNode(hostname);
- try (Mutex lock = lockUnallocated()) {
- requireBreakfixable(node);
+ try (RecursiveNodeMutexes locked = lockAndGetRecursively(hostname, Optional.empty())) {
+ requireBreakfixable(locked.parent().node());
NestedTransaction transaction = new NestedTransaction();
- List<Node> removed = removeChildren(node, false, transaction);
- removed.add(move(node.hostname(), Node.State.breakfixed, agent, false, Optional.of(reason), transaction));
+ removeChildren(locked, false, transaction);
+ move(hostname, Node.State.breakfixed, agent, false, Optional.of(reason), transaction);
transaction.commit();
- return removed;
+ return locked.nodes().nodes().stream().map(NodeMutex::node).toList();
}
}
private List<Node> moveRecursively(String hostname, Node.State toState, Agent agent, Optional<String> reason) {
- NestedTransaction transaction = new NestedTransaction();
- List<Node> moved = list().childrenOf(hostname).asList().stream()
- .map(child -> move(child.hostname(), toState, agent, false, reason, transaction))
- .collect(Collectors.toCollection(ArrayList::new));
- moved.add(move(hostname, toState, agent, false, reason, transaction));
- transaction.commit();
- return moved;
+ try (RecursiveNodeMutexes locked = lockAndGetRecursively(hostname, Optional.empty())) {
+ List<Node> moved = new ArrayList<>();
+ NestedTransaction transaction = new NestedTransaction();
+ for (NodeMutex node : locked.nodes().nodes())
+ moved.add(move(node.node().hostname(), toState, agent, false, reason, transaction));
+ transaction.commit();
+ return moved;
+ }
}
/** Move a node to given state */
- private Node move(String hostname, Node.State toState, Agent agent, boolean forceDeprovision, Optional<String> reason) {
+ private Node move(String hostname, Node.State toState, Agent agent, boolean forceDeprovision, Optional<String> reason, Mutex lock) {
NestedTransaction transaction = new NestedTransaction();
Node moved = move(hostname, toState, agent, forceDeprovision, reason, transaction);
transaction.commit();
@@ -451,8 +464,7 @@ public class Nodes {
/** Move a node to given state as part of a transaction */
private Node move(String hostname, Node.State toState, Agent agent, boolean forceDeprovision, Optional<String> reason, NestedTransaction transaction) {
- // TODO: Work out a safe lock acquisition strategy for moves. Lock is only held while adding operations to
- // transaction, but lock must also be held while committing
+ // TODO: This lock is already held here, but we still need to read the node. Perhaps change to requireNode(hostname) later.
try (NodeMutex lock = lockAndGetRequired(hostname)) {
Node node = lock.node();
if (toState == Node.State.active) {
@@ -521,17 +533,18 @@ public class Nodes {
}
public List<Node> removeRecursively(Node node, boolean force) {
- try (Mutex lock = lockUnallocated()) {
- requireRemovable(node, false, force);
+ try (RecursiveNodeMutexes locked = lockAndGetRecursively(node.hostname(), Optional.empty())) {
+ requireRemovable(locked.parent().node(), false, force);
NestedTransaction transaction = new NestedTransaction();
List<Node> removed;
- if (!node.type().isHost()) {
+ if ( ! node.type().isHost()) {
removed = List.of(node);
db.removeNodes(removed, transaction);
- } else {
- removed = removeChildren(node, force, transaction);
+ }
+ else {
+ removeChildren(locked, force, transaction);
move(node.hostname(), Node.State.deprovisioned, Agent.system, false, Optional.empty(), transaction);
- removed.add(node);
+ removed = locked.nodes().nodes().stream().map(NodeMutex::node).toList();
}
transaction.commit();
return removed;
@@ -540,20 +553,22 @@ public class Nodes {
/** Forgets a deprovisioned node. This removes all traces of the node in the node repository. */
public void forget(Node node) {
- if (node.state() != Node.State.deprovisioned)
- throw new IllegalArgumentException(node + " must be deprovisioned before it can be forgotten");
- if (node.status().wantToRebuild())
- throw new IllegalArgumentException(node + " is rebuilding and cannot be forgotten");
- NestedTransaction transaction = new NestedTransaction();
- db.removeNodes(List.of(node), transaction);
- transaction.commit();
+ try (NodeMutex locked = lockAndGetRequired(node.hostname())) {
+ if (node.state() != Node.State.deprovisioned)
+ throw new IllegalArgumentException(node + " must be deprovisioned before it can be forgotten");
+ if (node.status().wantToRebuild())
+ throw new IllegalArgumentException(node + " is rebuilding and cannot be forgotten");
+ NestedTransaction transaction = new NestedTransaction();
+ db.removeNodes(List.of(node), transaction);
+ transaction.commit();
+ }
}
- private List<Node> removeChildren(Node node, boolean force, NestedTransaction transaction) {
- List<Node> children = list().childrenOf(node).asList();
+ private void removeChildren(RecursiveNodeMutexes nodes, boolean force, NestedTransaction transaction) {
+ if (nodes.children().isEmpty()) return;
+ List<Node> children = nodes.children().stream().map(NodeMutex::node).toList();
children.forEach(child -> requireRemovable(child, true, force));
db.removeNodes(children, transaction);
- return new ArrayList<>(children);
}
/**
@@ -715,8 +730,8 @@ public class Nodes {
return db.writeTo(nodes, Agent.system, Optional.empty());
}
- private List<Node> performOn(Predicate<Node> filter, BiFunction<Node, Mutex, Node> action) {
- return performOn(list().matching(filter), action);
+ public List<Node> performOn(Predicate<Node> filter, BiFunction<Node, Mutex, Node> action) {
+ return performOn(list(), filter, action);
}
/**
@@ -725,35 +740,33 @@ public class Nodes {
* @param action the action to perform
* @return the set of nodes on which the action was performed, as they became as a result of the operation
*/
- private List<Node> performOn(NodeList nodes, BiFunction<Node, Mutex, Node> action) {
- List<Node> unallocatedNodes = new ArrayList<>();
- ListMap<ApplicationId, Node> allocatedNodes = new ListMap<>();
+ public List<Node> performOn(NodeList nodes, BiFunction<Node, Mutex, Node> action) {
+ return performOn(nodes, __ -> true, action);
+ }
- // Group matching nodes by the lock needed
- for (Node node : nodes) {
- Optional<ApplicationId> applicationId = applicationIdForLock(node);
- if (applicationId.isPresent())
- allocatedNodes.put(applicationId.get(), node);
- else
- unallocatedNodes.add(node);
- }
+ public List<Node> performOn(NodeList nodes, Predicate<Node> filter, BiFunction<Node, Mutex, Node> action) {
+ List<Node> resultingNodes = new ArrayList<>();
+ nodes.stream().collect(groupingBy(Nodes::applicationIdForLock))
+ .forEach((applicationId, nodeList) -> { // Grouped only to reduce number of lock acquire/release cycles.
+ try (NodeMutexes locked = lockAndGetAll(nodeList, Optional.empty())) {
+ for (NodeMutex node : locked.nodes())
+ if (filter.test(node.node()))
+ resultingNodes.add(action.apply(node.node(), node));
+ }
+ });
+ return resultingNodes;
+ }
+
+ public List<Node> performOnRecursively(NodeList parents, Predicate<RecursiveNodeMutexes> filter, Function<RecursiveNodeMutexes, List<Node>> action) {
+ for (Node node : parents)
+ if (node.parentHostname().isPresent())
+ throw new IllegalArgumentException(node + " is not a parent host");
- // Perform operation while holding appropriate lock
List<Node> resultingNodes = new ArrayList<>();
- try (Mutex lock = lockUnallocated()) {
- for (Node node : unallocatedNodes) {
- Optional<Node> currentNode = db.readNode(node.hostname()); // Re-read while holding lock
- if (currentNode.isEmpty()) continue;
- resultingNodes.add(action.apply(currentNode.get(), lock));
- }
- }
- for (Map.Entry<ApplicationId, List<Node>> applicationNodes : allocatedNodes.entrySet()) {
- try (Mutex lock = applications.lock(applicationNodes.getKey())) {
- for (Node node : applicationNodes.getValue()) {
- Optional<Node> currentNode = db.readNode(node.hostname()); // Re-read while holding lock
- if (currentNode.isEmpty()) continue;
- resultingNodes.add(action.apply(currentNode.get(), lock));
- }
+ for (Node parent : parents) {
+ try (RecursiveNodeMutexes locked = lockAndGetRecursively(parent.hostname(), Optional.empty())) {
+ if (filter.test(locked))
+ resultingNodes.addAll(action.apply(locked));
}
}
return resultingNodes;
@@ -816,9 +829,7 @@ public class Nodes {
return Optional.empty();
}
- if (node.type() != NodeType.tenant ||
- Objects.equals(freshNode.get().allocation().map(Allocation::owner),
- staleNode.allocation().map(Allocation::owner))) {
+ if (applicationIdForLock(freshNode.get()).equals(applicationIdForLock(staleNode))) {
NodeMutex nodeMutex = new NodeMutex(freshNode.get(), lockToClose);
lockToClose = null;
return Optional.of(nodeMutex);
@@ -879,6 +890,168 @@ public class Nodes {
return node(hostname).orElseThrow(() -> new NoSuchNodeException("No node with hostname '" + hostname + "'"));
}
+ /**
+ * Locks the children of the given node, the node itself, and finally takes the unallocated lock.
+ * <br>
+ * When taking multiple locks, it's crucial that we always take them in the same order, to avoid deadlocks.
+ * We want to take the most contended locks last, so that we don't block other operations for longer than necessary.
+ * This method does that, by first taking the locks for any children the given node may have, and then the node itself.
+ * (This is enforced by taking host locks after tenant node locks, in {@link #lockAndGetAll(Collection, Optional)}.)
+ * Finally, the unallocated lock is taken, to ensure no new children are added while we hold this snapshot.
+ * Unfortunately, since that lock is taken last, we may detect new nodes after taking it, and then we have to retry.
+ * Closing the returned {@link RecursiveNodeMutexes} will release all the locks, and the locks should not be closed elsewhere.
+ */
+ public RecursiveNodeMutexes lockAndGetRecursively(String hostname, Optional<Duration> timeout) {
+ TimeBudget budget = TimeBudget.fromNow(clock, timeout.orElse(Duration.ofMinutes(2)));
+ Set<Node> children = new HashSet<>(list().childrenOf(hostname).asList());
+ Optional<Node> node = node(hostname);
+
+ int attempts = 5; // We'll retry locking the whole list of children this many times, in case new children appear.
+ for (int attempt = 0; attempt < attempts; attempt++) {
+ NodeMutexes mutexes = null;
+ Mutex unallocatedLock = null;
+ try {
+ // First, we lock all the children, and the host; then we take the allocation lock to ensure our snapshot is valid.
+ List<Node> nodes = new ArrayList<>(children.size() + 1);
+ nodes.addAll(children);
+ node.ifPresent(nodes::add);
+ mutexes = lockAndGetAll(nodes, budget.timeLeftOrThrow());
+ unallocatedLock = db.lockInactive(budget.timeLeftOrThrow().get());
+ RecursiveNodeMutexes recursive = new RecursiveNodeMutexes(hostname, mutexes, unallocatedLock);
+ Set<Node> freshChildren = list().childrenOf(hostname).asSet();
+ Optional<Node> freshNode = recursive.parent.map(NodeMutex::node);
+ if (children.equals(freshChildren) && node.equals(freshNode)) {
+ // No new nodes have appeared, and none will now, so we have a consistent snapshot.
+ if (node.isEmpty() && ! children.isEmpty())
+ throw new IllegalStateException("node '" + hostname + "' was not found, but it has children: " + children);
+
+ mutexes = null;
+ unallocatedLock = null;
+ return recursive;
+ }
+ else {
+ // New nodes have appeared, so we need to let go of the locks and try again with the new set of nodes.
+ children = freshChildren;
+ node = freshNode;
+ }
+ }
+ finally {
+ if (unallocatedLock != null) unallocatedLock.close();
+ if (mutexes != null) mutexes.close();
+ }
+ }
+ throw new IllegalStateException("giving up (after " + attempts + " attempts) fetching an up to " +
+ "date recursive node set under lock for node " + hostname);
+ }
+
+ /** Locks all nodes in the given list, in a universal order, and returns the locks and nodes required. */
+ public NodeMutexes lockAndRequireAll(Collection<Node> nodes, Optional<Duration> timeout) {
+ return lockAndGetAll(nodes, timeout, true);
+ }
+
+ /** Locks all nodes in the given list, in a universal order, and returns the locks and nodes acquired. */
+ public NodeMutexes lockAndGetAll(Collection<Node> nodes, Optional<Duration> timeout) {
+ return lockAndGetAll(nodes, timeout, false);
+ }
+
+ /** Locks all nodes in the given list, in a universal order, and returns the locks and nodes. */
+ private NodeMutexes lockAndGetAll(Collection<Node> nodes, Optional<Duration> timeout, boolean required) {
+ TimeBudget budget = TimeBudget.fromNow(clock, timeout.orElse(Duration.ofMinutes(2)));
+ Comparator<Node> universalOrder = (a, b) -> {
+ Optional<ApplicationId> idA = applicationIdForLock(a);
+ Optional<ApplicationId> idB = applicationIdForLock(b);
+ if (idA.isPresent() != idB.isPresent()) return idA.isPresent() ? -1 : 1; // Allocated nodes first.
+ if (a.type() != b.type()) return a.type().compareTo(b.type()); // Tenant nodes first among those.
+ if ( ! idA.equals(idB)) return idA.get().compareTo(idB.get()); // Sort primarily by tenant owner id.
+ return a.hostname().compareTo(b.hostname()); // Sort secondarily by hostname.
+ };
+ NavigableSet<NodeMutex> locked = new TreeSet<>(comparing(NodeMutex::node, universalOrder));
+ NavigableSet<Node> unlocked = new TreeSet<>(universalOrder);
+ unlocked.addAll(nodes);
+ try {
+ int attempts = 10; // We'll accept getting the wrong lock at most this many times before giving up.
+ for (int attempt = 0; attempt < attempts; ) {
+ if (unlocked.isEmpty()) {
+ NodeMutexes mutexes = new NodeMutexes(List.copyOf(locked));
+ locked.clear();
+ return mutexes;
+ }
+
+ // If the first node is now earlier in lock order than some other locks we have, we need to close those and re-acquire them.
+ Node next = unlocked.pollFirst();
+ Set<NodeMutex> outOfOrder = locked.tailSet(new NodeMutex(next, () -> { }), false);
+ NodeMutexes.close(outOfOrder.iterator());
+ for (NodeMutex node : outOfOrder) unlocked.add(node.node());
+ outOfOrder.clear();
+
+ Mutex lock = lock(next, budget.timeLeftOrThrow());
+ try {
+ Optional<Node> fresh = node(next.hostname());
+ if (fresh.isEmpty()) {
+ if (required) throw new NoSuchNodeException("No node with hostname '" + next.hostname() + "'");
+ continue; // Node is gone; skip to close lock.
+ }
+
+ if (applicationIdForLock(fresh.get()).equals(applicationIdForLock(next))) {
+ // We held the right lock, so this node is ours now.
+ locked.add(new NodeMutex(fresh.get(), lock));
+ lock = null;
+ }
+ else {
+ // We held the wrong lock, and need to try again.
+ ++attempt;
+ unlocked.add(fresh.get());
+ }
+ }
+ finally {
+ // If we didn't hold the right lock, we must close the wrong one before we continue.
+ if (lock != null) lock.close();
+ }
+ }
+ throw new IllegalStateException("giving up (after " + attempts + " extra attempts) to lock nodes: " +
+ nodes.stream().map(Node::hostname).collect(joining(", ")));
+ }
+ finally {
+ // If we didn't manage to lock all nodes, we must close the ones we did lock before we throw.
+ NodeMutexes.close(locked.iterator());
+ }
+ }
+
+ /** A node with their locks, acquired in a universal order. */
+ public record NodeMutexes(List<NodeMutex> nodes) implements AutoCloseable {
+ @Override public void close() { close(nodes.iterator()); }
+ private static void close(Iterator<NodeMutex> nodes) {
+ if (nodes.hasNext()) try (NodeMutex node = nodes.next()) { close(nodes); }
+ }
+ }
+
+ /** A parent node, all its children, their locks acquired in a universal order, and then the unallocated lock. */
+ public static class RecursiveNodeMutexes implements AutoCloseable {
+
+ private final String hostname;
+ private final NodeMutexes nodes;
+ private final Mutex unallocatedLock;
+ private final List<NodeMutex> children;
+ private final Optional<NodeMutex> parent;
+
+ public RecursiveNodeMutexes(String hostname, NodeMutexes nodes, Mutex unallocatedLock) {
+ this.hostname = hostname;
+ this.nodes = nodes;
+ this.unallocatedLock = unallocatedLock;
+ this.children = nodes.nodes().stream().filter(node -> ! node.node().hostname().equals(hostname)).toList();
+ this.parent = nodes.nodes().stream().filter(node -> node.node().hostname().equals(hostname)).findFirst();
+ }
+
+ /** Any children of the node. */
+ public List<NodeMutex> children() { return children; }
+ /** The node itself, or throws if the node was not found. */
+ public NodeMutex parent() { return parent.orElseThrow(() -> new NoSuchNodeException("No node with hostname '" + hostname + "'")); }
+ /** Empty if the node was not found, or the node, and any children. */
+ public NodeMutexes nodes() { return nodes; }
+ /** Closes the unallocated lock, and all the node locks. */
+ @Override public void close() { try (nodes; unallocatedLock) { } }
+ }
+
/** Returns the application ID that should be used for locking when modifying this node */
private static Optional<ApplicationId> applicationIdForLock(Node node) {
return switch (node.type()) {
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java
index fc008b7b9dc..037338cb2ed 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java
@@ -18,6 +18,7 @@ import com.yahoo.vespa.curator.Lock;
import com.yahoo.vespa.curator.recipes.CuratorCounter;
import com.yahoo.vespa.curator.transaction.CuratorOperations;
import com.yahoo.vespa.curator.transaction.CuratorTransaction;
+import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.applications.Application;
import com.yahoo.vespa.hosted.provision.archive.ArchiveUris;
@@ -105,7 +106,7 @@ public class CuratorDb {
}
/** Adds a set of nodes. Rollbacks/fails transaction if any node is not in the expected state. */
- public List<Node> addNodesInState(List<Node> nodes, Node.State expectedState, Agent agent, NestedTransaction transaction) {
+ public List<Node> addNodesInState(LockedNodeList nodes, Node.State expectedState, Agent agent, NestedTransaction transaction) {
CuratorTransaction curatorTransaction = db.newCuratorTransactionIn(transaction);
for (Node node : nodes) {
if (node.state() != expectedState)
@@ -116,10 +117,10 @@ public class CuratorDb {
curatorTransaction.add(CuratorOperations.create(nodePath(node).getAbsolute(), serialized));
}
transaction.onCommitted(() -> nodes.forEach(node -> log.log(Level.INFO, "Added " + node)));
- return nodes;
+ return nodes.asList();
}
- public List<Node> addNodesInState(List<Node> nodes, Node.State expectedState, Agent agent) {
+ public List<Node> addNodesInState(LockedNodeList nodes, Node.State expectedState, Agent agent) {
NestedTransaction transaction = new NestedTransaction();
List<Node> writtenNodes = addNodesInState(nodes, expectedState, agent, transaction);
transaction.commit();
@@ -175,6 +176,7 @@ public class CuratorDb {
return writtenNodes;
}
}
+
public Node writeTo(Node.State toState, Node node, Agent agent, Optional<String> reason) {
return writeTo(toState, Collections.singletonList(node), agent, reason).get(0);
}
@@ -192,6 +194,12 @@ public class CuratorDb {
*/
public List<Node> writeTo(Node.State toState, List<Node> nodes,
Agent agent, Optional<String> reason,
+ ApplicationTransaction transaction) {
+ return writeTo(toState, nodes, agent, reason, transaction.nested());
+ }
+
+ public List<Node> writeTo(Node.State toState, List<Node> nodes,
+ Agent agent, Optional<String> reason,
NestedTransaction transaction) {
if (nodes.isEmpty()) return nodes;
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java
index caf936e8aeb..c25f33bc8c2 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java
@@ -88,7 +88,7 @@ class Activator {
NodeList activeToRemove = oldActive.matching(node -> ! hostnames.contains(node.hostname()));
remove(activeToRemove, transaction); // TODO: Pass activation time in this call and next line
- nodeRepository.nodes().activate(newActive.asList(), transaction.nested()); // activate also continued active to update node state
+ nodeRepository.nodes().activate(newActive.asList(), transaction); // activate also continued active to update node state
rememberResourceChange(transaction, generation, activationTime,
oldActive.not().retired(),
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/FlavorConfigBuilder.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/FlavorConfigBuilder.java
index 2bc5a0719d9..2e9cca21052 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/FlavorConfigBuilder.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/FlavorConfigBuilder.java
@@ -5,12 +5,18 @@ import com.yahoo.config.provision.Flavor;
import com.yahoo.config.provision.NodeFlavors;
import com.yahoo.config.provisioning.FlavorsConfig;
+import static com.yahoo.config.provision.Flavor.Type.BARE_METAL;
+import static com.yahoo.config.provision.Flavor.Type.DOCKER_CONTAINER;
import static com.yahoo.config.provision.NodeResources.Architecture;
+import static com.yahoo.config.provision.NodeResources.Architecture.arm64;
+import static com.yahoo.config.provision.NodeResources.Architecture.x86_64;
/**
* Simplifies creation of a node-repository config containing flavors.
* This is needed because the config builder API is inconvenient.
*
+ * Note: Flavors added will have fast disk and remote storage unless explicitly specified.
+ *
* @author bratseth
*/
public class FlavorConfigBuilder {
@@ -27,7 +33,7 @@ public class FlavorConfigBuilder {
double disk,
double bandwidth,
Flavor.Type type) {
- return addFlavor(flavorName, cpu, mem, disk, bandwidth, true, true, type, Architecture.x86_64, 0, 0);
+ return addFlavor(flavorName, cpu, mem, disk, bandwidth, true, true, type, x86_64, 0, 0);
}
public FlavorsConfig.Flavor.Builder addFlavor(String flavorName,
@@ -69,31 +75,20 @@ public class FlavorConfigBuilder {
/** Convenience method which creates a node flavors instance from a list of flavor names */
public static NodeFlavors createDummies(String... flavors) {
-
- FlavorConfigBuilder flavorConfigBuilder = new FlavorConfigBuilder();
+ FlavorConfigBuilder builder = new FlavorConfigBuilder();
for (String flavorName : flavors) {
- if (flavorName.equals("docker"))
- flavorConfigBuilder.addFlavor(flavorName, 1., 30., 20., 1.5, Flavor.Type.DOCKER_CONTAINER);
- else if (flavorName.equals("docker2"))
- flavorConfigBuilder.addFlavor(flavorName, 2., 40., 40., 0.5, Flavor.Type.DOCKER_CONTAINER);
- else if (flavorName.equals("host"))
- flavorConfigBuilder.addFlavor(flavorName, 7., 100., 120., 5, Flavor.Type.BARE_METAL);
- else if (flavorName.equals("host2"))
- flavorConfigBuilder.addFlavor(flavorName, 16, 24, 100, 1, Flavor.Type.BARE_METAL);
- else if (flavorName.equals("host3"))
- flavorConfigBuilder.addFlavor(flavorName, 24, 64, 100, 10, Flavor.Type.BARE_METAL);
- else if (flavorName.equals("host4"))
- flavorConfigBuilder.addFlavor(flavorName, 48, 128, 1000, 10, Flavor.Type.BARE_METAL);
- else if (flavorName.equals("devhost"))
- flavorConfigBuilder.addFlavor(flavorName, 4., 80., 100, 10, Flavor.Type.BARE_METAL);
- else if (flavorName.equals("arm64"))
- flavorConfigBuilder.addFlavor(flavorName,2., 30., 20., 3, Flavor.Type.BARE_METAL, Architecture.arm64);
- else if (flavorName.equals("gpu"))
- flavorConfigBuilder.addFlavor(flavorName,4, 16, 125, 10, true, false, Flavor.Type.BARE_METAL, Architecture.x86_64, 1, 16);
- else
- flavorConfigBuilder.addFlavor(flavorName, 1., 30., 20., 3, Flavor.Type.BARE_METAL);
+ switch (flavorName) {
+ case "docker" -> builder.addFlavor(flavorName, 1., 30., 20., 1.5, DOCKER_CONTAINER);
+ case "host" -> builder.addFlavor(flavorName, 7., 100., 120., 5, BARE_METAL);
+ case "host2" -> builder.addFlavor(flavorName, 16, 24, 100, 1, BARE_METAL);
+ case "host3" -> builder.addFlavor(flavorName, 24, 64, 100, 10, BARE_METAL);
+ case "host4" -> builder.addFlavor(flavorName, 48, 128, 1000, 10, BARE_METAL);
+ case "arm64" -> builder.addFlavor(flavorName, 2., 30., 20., 3, BARE_METAL, arm64);
+ case "gpu" -> builder.addFlavor(flavorName, 4, 16, 125, 10, true, false, BARE_METAL, x86_64, 1, 16);
+ default -> builder.addFlavor(flavorName, 1., 30., 20., 3, BARE_METAL);
+ }
}
- return new NodeFlavors(flavorConfigBuilder.build());
+ return new NodeFlavors(builder.build());
}
}
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/HostProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/HostProvisioner.java
index 397eb4d7af9..dd838375a59 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/HostProvisioner.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/HostProvisioner.java
@@ -7,7 +7,6 @@ import com.yahoo.config.provision.NodeAllocationException;
import com.yahoo.vespa.hosted.provision.Node;
import java.util.List;
-import java.util.Set;
import java.util.function.Consumer;
/**
@@ -46,12 +45,11 @@ public interface HostProvisioner {
* Continue provisioning of given list of Nodes.
*
* @param host the host to provision
- * @param children list of all the nodes that run on the given host
* @return IP config for the provisioned host and its children
* @throws FatalProvisioningException if the provisioning has irrecoverably failed and the input nodes
* should be deleted from node-repo.
*/
- HostIpConfig provision(Node host, Set<Node> children) throws FatalProvisioningException;
+ HostIpConfig provision(Node host) throws FatalProvisioningException;
/**
* Deprovisions a given host and resources associated with it and its children (such as DNS entries).
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockHostProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockHostProvisioner.java
index 3d5987cd04d..bc10a97068e 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockHostProvisioner.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockHostProvisioner.java
@@ -97,15 +97,13 @@ public class MockHostProvisioner implements HostProvisioner {
}
@Override
- public HostIpConfig provision(Node host, Set<Node> children) throws FatalProvisioningException {
+ public HostIpConfig provision(Node host) throws FatalProvisioningException {
if (behaviour(Behaviour.failProvisioning)) throw new FatalProvisioningException("Failed to provision node(s)");
if (host.state() != Node.State.provisioned) throw new IllegalStateException("Host to provision must be in " + Node.State.provisioned);
Map<String, IP.Config> result = new HashMap<>();
result.put(host.hostname(), createIpConfig(host));
- for (var child : children) {
- if (child.state() != Node.State.reserved) throw new IllegalStateException("Child to provisioned must be in " + Node.State.reserved);
- result.put(child.hostname(), createIpConfig(child));
- }
+ host.ipConfig().pool().hostnames().forEach(hostname ->
+ result.put(hostname.value(), IP.Config.ofEmptyPool(nameResolver.resolveAll(hostname.value()))));
return new HostIpConfig(result);
}
@@ -199,8 +197,6 @@ public class MockHostProvisioner implements HostProvisioner {
return this;
}
- public Optional<Flavor> getHostFlavor(ClusterSpec.Type type) { return Optional.ofNullable(hostFlavors.get(type)); }
-
public MockHostProvisioner addEvent(HostEvent event) {
hostEvents.add(event);
return this;
@@ -230,18 +226,17 @@ public class MockHostProvisioner implements HostProvisioner {
}
public IP.Config createIpConfig(Node node) {
- if (!node.type().isHost()) {
- return node.ipConfig().withPrimary(nameResolver.resolveAll(node.hostname()));
- }
+ if (!node.type().isHost()) throw new IllegalArgumentException("Node " + node + " is not a host");
int hostIndex = Integer.parseInt(node.hostname().replaceAll("^[a-z]+|-\\d+$", ""));
Set<String> addresses = Set.of("::" + hostIndex + ":0");
Set<String> ipAddressPool = new HashSet<>();
if (!behaviour(Behaviour.failDnsUpdate)) {
nameResolver.addRecord(node.hostname(), addresses.iterator().next());
- for (int i = 1; i <= 2; i++) {
- String ip = "::" + hostIndex + ":" + i;
+ int i = 1;
+ for (HostName hostName : node.ipConfig().pool().hostnames()) {
+ String ip = "::" + hostIndex + ":" + i++;
ipAddressPool.add(ip);
- nameResolver.addRecord(node.hostname() + "-" + i, ip);
+ nameResolver.addRecord(hostName.value(), ip);
}
}
IP.Pool pool = node.ipConfig().pool().withIpAddresses(ipAddressPool);
@@ -250,7 +245,7 @@ public class MockHostProvisioner implements HostProvisioner {
public enum Behaviour {
- /** Fail call to {@link MockHostProvisioner#provision(com.yahoo.vespa.hosted.provision.Node, java.util.Set)} */
+ /** Fail call to {@link MockHostProvisioner#provision(com.yahoo.vespa.hosted.provision.Node)} */
failProvisioning,
/** Fail call to {@link MockHostProvisioner#provisionHosts(HostProvisionRequest, Consumer)} */
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNameResolver.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNameResolver.java
index 94cb05d20cc..722dc5ef96c 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNameResolver.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNameResolver.java
@@ -6,7 +6,6 @@ import com.yahoo.vespa.hosted.provision.persistence.NameResolver;
import java.net.UnknownHostException;
import java.util.Arrays;
-import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.HashSet;
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java
index b7d6e0a9dd9..714374ccb8a 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java
@@ -176,7 +176,7 @@ public class MockNodeRepository extends NodeRepository {
.build());
// Ready all nodes, except 7 and 55
- nodes = nodes().addNodes(nodes, Agent.system);
+ nodes = new ArrayList<>(nodes().addNodes(nodes, Agent.system));
nodes.remove(node7);
nodes.remove(node55);
nodes = nodes().deallocate(nodes, Agent.system, getClass().getSimpleName());
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockProvisioner.java
index 5a9da1e1c3f..d8ad892e210 100644
--- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockProvisioner.java
+++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockProvisioner.java
@@ -18,6 +18,7 @@ import java.util.List;
/**
* @author freva
*/
+@SuppressWarnings("unused") // Injected in container from test code (services.xml)
public class MockProvisioner implements Provisioner {
@Override
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java
index 29ebf1789c0..9c843b3eb01 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java
@@ -141,7 +141,7 @@ public class RealDataScenarioTest {
if (nodeNext.get()) {
String json = input.substring(input.indexOf("{\""), input.lastIndexOf('}') + 1);
Node node = nodeSerializer.fromJson(json.getBytes(UTF_8));
- nodeRepository.database().addNodesInState(List.of(node), node.state(), Agent.system);
+ nodeRepository.database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system);
nodeNext.set(false);
} else {
if (!zkNodePathPattern.matcher(input).matches()) return;
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java
index f8ec271ce5f..523feeeb303 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java
@@ -23,6 +23,7 @@ import com.yahoo.test.ManualClock;
import com.yahoo.vespa.curator.Curator;
import com.yahoo.vespa.curator.mock.MockCurator;
import com.yahoo.vespa.flags.InMemoryFlagSource;
+import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.NodeRepository;
import com.yahoo.vespa.hosted.provision.autoscale.MemoryMetricsDb;
@@ -201,7 +202,7 @@ public class CapacityCheckerTester {
nodeRepository.nodes().addNodes(hostsWithChildren.getOrDefault(tenantHostApp, List.of()), Agent.system);
hostsWithChildren.forEach((applicationId, nodes) -> {
if (applicationId.equals(tenantHostApp)) return;
- nodeRepository.database().addNodesInState(nodes, Node.State.active, Agent.system);
+ nodeRepository.database().addNodesInState(new LockedNodeList(nodes, () -> { }), Node.State.active, Agent.system);
});
nodeRepository.nodes().addNodes(createEmptyHosts(numHosts, numEmptyHosts, emptyHostExcessCapacity, emptyHostExcessIps), Agent.system);
@@ -322,9 +323,9 @@ public class CapacityCheckerTester {
}
}
- nodeRepository.database().addNodesInState(hosts, Node.State.active, Agent.system);
+ nodeRepository.database().addNodesInState(new LockedNodeList(hosts, () -> { }), Node.State.active, Agent.system);
nodes.forEach((application, applicationNodes) -> {
- nodeRepository.database().addNodesInState(applicationNodes, Node.State.active, Agent.system);
+ nodeRepository.database().addNodesInState(new LockedNodeList(applicationNodes, () -> { }), Node.State.active, Agent.system);
});
updateCapacityChecker();
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java
index ddd7413567a..262616d5eac 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java
@@ -6,6 +6,7 @@ import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.Flavor;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
+import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.node.Agent;
import com.yahoo.vespa.hosted.provision.node.Allocation;
@@ -45,7 +46,7 @@ public class DirtyExpirerTest {
false))
.build();
- tester.nodeRepository().database().addNodesInState(List.of(node), node.state(), Agent.system);
+ tester.nodeRepository().database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system);
Duration expiryTimeout = Duration.ofMinutes(30);
DirtyExpirer expirer = new DirtyExpirer(tester.nodeRepository(), expiryTimeout, new TestMetric());
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java
index 66d4b67c7c2..c16ed47a216 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java
@@ -27,6 +27,7 @@ import com.yahoo.test.ManualClock;
import com.yahoo.vespa.flags.InMemoryFlagSource;
import com.yahoo.vespa.flags.PermanentFlags;
import com.yahoo.vespa.flags.custom.ClusterCapacity;
+import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.Node.State;
import com.yahoo.vespa.hosted.provision.NodeList;
@@ -64,6 +65,9 @@ import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Stream;
+import static com.yahoo.config.provision.NodeResources.Architecture.arm64;
+import static com.yahoo.config.provision.NodeResources.DiskSpeed.fast;
+import static com.yahoo.config.provision.NodeResources.StorageType.remote;
import static com.yahoo.vespa.hosted.provision.testutils.MockHostProvisioner.Behaviour;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
@@ -77,11 +81,13 @@ import static org.junit.Assert.fail;
*/
public class HostCapacityMaintainerTest {
+ private DynamicProvisioningTester tester;
+
@Test
public void finds_nodes_that_need_deprovisioning_without_pre_provisioning() {
- var tester = new DynamicProvisioningTester().addInitialNodes();
- assertTrue(tester.nodeRepository.nodes().node("host2").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host3").isPresent());
+ tester = new DynamicProvisioningTester().addInitialNodes();
+ assertNodeExists("host2");
+ assertNodeExists("host3");
tester.maintain();
assertSame(State.deprovisioned, tester.nodeRepository.nodes().node("host2").get().state());
@@ -89,31 +95,30 @@ public class HostCapacityMaintainerTest {
@Test
public void does_not_deprovision_when_preprovisioning_enabled() {
- var tester = new DynamicProvisioningTester().addInitialNodes();
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), List.of(new ClusterCapacity(1, 1.0, 3.0, 2.0, 1.0, "fast", "local", "x86_64")), ClusterCapacity.class);
- Optional<Node> failedHost = tester.nodeRepository.nodes().node("host2");
+ tester = new DynamicProvisioningTester().addInitialNodes();
+ setPreprovisionCapacityFlag(tester, new ClusterCapacity(1, 1.0, 3.0, 2.0, 1.0, "fast", "remote", "x86_64"));
+ Optional<Node> failedHost = node("host2");
assertTrue(failedHost.isPresent());
tester.maintain();
- assertSame("Failed host is deprovisioned", State.deprovisioned, tester.nodeRepository.nodes().node(failedHost.get().hostname()).get().state());
+ assertSame("Failed host is deprovisioned", State.deprovisioned, node(failedHost.get().hostname()).get().state());
assertEquals(1, tester.hostProvisioner.deprovisionedHosts());
}
@Test
public void provision_deficit_and_deprovision_excess() {
- var tester = new DynamicProvisioningTester().addInitialNodes();
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(),
- List.of(new ClusterCapacity(2, 48.0, 128.0, 1000.0, 10.0, "fast", "local", "x86_64"),
- new ClusterCapacity(1, 16.0, 24.0, 100.0, 1.0, "fast", "local", "x86_64")),
- ClusterCapacity.class);
+ tester = new DynamicProvisioningTester().addInitialNodes();
+ setPreprovisionCapacityFlag(tester,
+ new ClusterCapacity(2, 48.0, 128.0, 1000.0, 10.0, "fast", "remote", "x86_64"),
+ new ClusterCapacity(1, 16.0, 24.0, 100.0, 1.0, "fast", "remote", "x86_64"));
assertEquals(0, tester.hostProvisioner.provisionedHosts().size());
assertEquals(9, tester.nodeRepository.nodes().list().size());
- assertTrue(tester.nodeRepository.nodes().node("host2").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host2-1").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host3").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host100").isEmpty());
- assertTrue(tester.nodeRepository.nodes().node("host101").isEmpty());
+ assertNodeExists("host2");
+ assertNodeExists("host2-1");
+ assertNodeExists("host3");
+ assertNodeDoesNotExist("host100");
+ assertNodeDoesNotExist("host101");
tester.maintain();
@@ -121,37 +126,35 @@ public class HostCapacityMaintainerTest {
assertEquals(2, tester.provisionedHostsMatching(new NodeResources(48, 128, 1000, 10)));
NodeList nodesAfter = tester.nodeRepository.nodes().list().not().state(State.deprovisioned);
assertEquals(9, nodesAfter.size()); // 2 removed, 2 added
- assertSame("Failed host 'host2' is deprovisioned", State.deprovisioned, tester.nodeRepository.nodes().node("host2").get().state());
- assertTrue("Node on deprovisioned host removed", tester.nodeRepository.nodes().node("host2-1").isEmpty());
- assertTrue("Host satisfying 16-24-100-1 is kept", tester.nodeRepository.nodes().node("host3").isPresent());
- assertTrue("New 48-128-1000-10 host added", tester.nodeRepository.nodes().node("host100").isPresent());
- assertTrue("New 48-128-1000-10 host added", tester.nodeRepository.nodes().node("host101").isPresent());
+ assertSame("Failed host 'host2' is deprovisioned", State.deprovisioned, node("host2").get().state());
+ assertNodeDoesNotExist("Node on deprovisioned host removed", "host2-1");
+ assertNodeExists("Host satisfying 16-24-100-1 is kept", "host3");
+ assertNodeExists("New 48-128-1000-10 host added", "host100");
+ assertNodeExists("New 48-128-1000-10 host added", "host101");
- Instant deprovisionedAt = tester.nodeRepository.nodes().node("host2").get().history().event(History.Event.Type.deprovisioned).get().at();
+ Instant deprovisionedAt = node("host2").get().history().event(History.Event.Type.deprovisioned).get().at();
tester.provisioningTester.clock().advance(Duration.ofSeconds(1));
tester.maintain();
assertEquals("Host moves to deprovisioned once", deprovisionedAt,
- tester.nodeRepository.nodes().node("host2").get().history()
+ node("host2").get().history()
.event(History.Event.Type.deprovisioned).get().at());
}
@Test
public void preprovision_with_shared_host() {
- var tester = new DynamicProvisioningTester().addInitialNodes();
+ tester = new DynamicProvisioningTester().addInitialNodes();
// Makes provisioned hosts 48-128-1000-10
tester.hostProvisioner.setHostFlavor("host4");
- var clusterCapacity = new ClusterCapacity(2, 1.0, 30.0, 20.0, 3.0, "fast", "local", "x86_64");
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(),
- List.of(clusterCapacity),
- ClusterCapacity.class);
+ var clusterCapacity = new ClusterCapacity(2, 1.0, 30.0, 20.0, 3.0, "fast", "remote", "x86_64");
+ setPreprovisionCapacityFlag(tester, clusterCapacity);
assertEquals(0, tester.hostProvisioner.provisionedHosts().size());
assertEquals(9, tester.nodeRepository.nodes().list().size());
- assertTrue(tester.nodeRepository.nodes().node("host2").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host2-1").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host3").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host100").isEmpty());
+ assertTrue(node("host2").isPresent());
+ assertTrue(node("host2-1").isPresent());
+ assertTrue(node("host3").isPresent());
+ assertTrue(node("host100").isEmpty());
// The first cluster will be allocated to host3 and a new host host100.
// host100 will be a large shared host specified above.
@@ -163,10 +166,7 @@ public class HostCapacityMaintainerTest {
verifyFirstMaintain(tester);
// Add a second cluster equal to the first. It should fit on existing host3 and host100.
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(),
- List.of(clusterCapacity,
- clusterCapacity),
- ClusterCapacity.class);
+ setPreprovisionCapacityFlag(tester, clusterCapacity, clusterCapacity);
tester.maintain();
verifyFirstMaintain(tester);
@@ -177,25 +177,23 @@ public class HostCapacityMaintainerTest {
// in this test, due to skew), so host3 will be deprovisioned when host101 is provisioned.
// host3 is a 24-64-100-10 while host100 is 48-128-1000-10.
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(),
- List.of(clusterCapacity,
- new ClusterCapacity(2, 24.0, 64.0, 100.0, 1.0, "fast", "local", "x86_64")),
- ClusterCapacity.class);
+ setPreprovisionCapacityFlag(tester,
+ clusterCapacity,
+ new ClusterCapacity(2, 24.0, 64.0, 100.0, 1.0, "fast", "remote", "x86_64"));
tester.maintain();
assertEquals(2, tester.hostProvisioner.provisionedHosts().size());
assertEquals(2, tester.provisionedHostsMatching(new NodeResources(48, 128, 1000, 10)));
assertEquals(8, tester.nodeRepository.nodes().list().not().state(State.deprovisioned).size()); // 3 removed, 2 added
- assertSame("preprovision capacity is prefered on shared hosts", State.deprovisioned, tester.nodeRepository.nodes().node("host3").get().state());
- assertTrue(tester.nodeRepository.nodes().node("host100").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host101").isPresent());
+ assertSame("preprovision capacity is preferred on shared hosts", State.deprovisioned, node("host3").get().state());
+ assertTrue(node("host100").isPresent());
+ assertTrue(node("host101").isPresent());
// If the preprovision capacity is reduced, we should see shared hosts deprovisioned.
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(),
- List.of(new ClusterCapacity(1, 1.0, 30.0, 20.0, 3.0, "fast", "local", "x86_64")),
- ClusterCapacity.class);
+ setPreprovisionCapacityFlag(tester,
+ new ClusterCapacity(1, 1.0, 30.0, 20.0, 3.0, "fast", "remote", "x86_64"));
tester.maintain();
@@ -203,45 +201,54 @@ public class HostCapacityMaintainerTest {
1, tester.hostProvisioner.provisionedHosts().size());
assertEquals(1, tester.provisionedHostsMatching(new NodeResources(48, 128, 1000, 10)));
assertEquals(7, tester.nodeRepository.nodes().list().not().state(State.deprovisioned).size()); // 4 removed, 2 added
- if (tester.nodeRepository.nodes().node("host100").isPresent()) {
+ if (node("host100").isPresent()) {
assertSame("host101 is superfluous and should have been deprovisioned", State.deprovisioned,
- tester.nodeRepository.nodes().node("host101").get().state());
+ node("host101").get().state());
} else {
assertTrue("host101 is required for preprovision capacity",
- tester.nodeRepository.nodes().node("host101").isPresent());
+ node("host101").isPresent());
}
+ // If a host with another architecture is added to preprovision capacity, a shared host should be added.
+ setPreprovisionCapacityFlag(tester,
+ new ClusterCapacity(1, 2.0, 30.0, 20.0, 3.0, "fast", "remote", "x86_64"),
+ new ClusterCapacity(1, 2.0, 30.0, 20.0, 3.0, "fast", "remote", "arm64"));
+ tester.hostProvisioner.setHostFlavor("arm64");
+ tester.maintain();
+
+ assertEquals(2, tester.hostProvisioner.provisionedHosts().size());
+ assertEquals(1, tester.provisionedHostsMatching(new NodeResources(48, 128, 1000, 10)));
+ assertEquals(1, tester.provisionedHostsMatching(new NodeResources(2, 30, 20, 3, fast, remote, arm64)));
}
private void verifyFirstMaintain(DynamicProvisioningTester tester) {
- assertEquals(1, tester.hostProvisioner.provisionedHosts().size());
+ assertEquals(tester.hostProvisioner.provisionedHosts().toString(), 1, tester.hostProvisioner.provisionedHosts().size());
assertEquals(1, tester.provisionedHostsMatching(new NodeResources(48, 128, 1000, 10)));
assertEquals(8, tester.nodeRepository.nodes().list().not().state(State.deprovisioned).size()); // 2 removed, 1 added
- assertSame("Failed host 'host2' is deprovisioned", State.deprovisioned, tester.nodeRepository.nodes().node("host2").get().state());
- assertTrue("Node on deprovisioned host removed", tester.nodeRepository.nodes().node("host2-1").isEmpty());
- assertTrue("One 1-30-20-3 node fits on host3", tester.nodeRepository.nodes().node("host3").isPresent());
- assertTrue("New 48-128-1000-10 host added", tester.nodeRepository.nodes().node("host100").isPresent());
+ assertSame("Failed host 'host2' is deprovisioned", State.deprovisioned, node("host2").get().state());
+ assertTrue("Node on deprovisioned host removed", node("host2-1").isEmpty());
+ assertTrue("One 1-30-20-3 node fits on host3", node("host3").isPresent());
+ assertTrue("New 48-128-1000-10 host added", node("host100").isPresent());
}
@Test
public void does_not_remove_if_host_provisioner_failed() {
- var tester = new DynamicProvisioningTester();
+ tester = new DynamicProvisioningTester();
Node host2 = tester.addNode("host2", Optional.empty(), NodeType.host, Node.State.failed, DynamicProvisioningTester.tenantApp);
tester.hostProvisioner.with(Behaviour.failDeprovisioning);
tester.maintain();
- assertTrue(tester.nodeRepository.nodes().node(host2.hostname()).isPresent());
+ assertTrue(node(host2.hostname()).isPresent());
}
@Test
public void test_minimum_capacity() {
- var tester = new DynamicProvisioningTester();
+ tester = new DynamicProvisioningTester();
NodeResources resources1 = new NodeResources(24, 64, 100, 10);
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(),
- List.of(new ClusterCapacity(2, resources1.vcpu(), resources1.memoryGb(), resources1.diskGb(),
- resources1.bandwidthGbps(), resources1.diskSpeed().name(),
- resources1.storageType().name(), resources1.architecture().name())),
- ClusterCapacity.class);
+ setPreprovisionCapacityFlag(tester,
+ new ClusterCapacity(2, resources1.vcpu(), resources1.memoryGb(), resources1.diskGb(),
+ resources1.bandwidthGbps(), resources1.diskSpeed().name(),
+ resources1.storageType().name(), resources1.architecture().name()));
tester.maintain();
// Hosts are provisioned
@@ -252,16 +259,14 @@ public class HostCapacityMaintainerTest {
tester.assertNodesUnchanged();
// Pretend shared-host flag has been set to host4's flavor
- var sharedHostNodeResources = new NodeResources(48, 128, 1000, 10, NodeResources.DiskSpeed.fast, NodeResources.StorageType.remote);
+ var sharedHostNodeResources = new NodeResources(48, 128, 1000, 10, fast, remote);
tester.hostProvisioner.setHostFlavor("host4");
// Next maintenance run does nothing
tester.assertNodesUnchanged();
// Must be able to allocate 2 nodes with "no resource requirement"
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(),
- List.of(new ClusterCapacity(2, 0.0, 0.0, 0.0, 0.0, null, null, null)),
- ClusterCapacity.class);
+ setPreprovisionCapacityFlag(tester, new ClusterCapacity(2, 0.0, 0.0, 0.0, 0.0, null, null, null));
// Next maintenance run does nothing
tester.assertNodesUnchanged();
@@ -280,62 +285,58 @@ public class HostCapacityMaintainerTest {
tester.assertNodesUnchanged();
// Clearing flag does nothing
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), List.of(), ClusterCapacity.class);
+ setPreprovisionCapacityFlag(tester);
tester.assertNodesUnchanged();
// Increasing the capacity provisions additional hosts
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(),
- List.of(new ClusterCapacity(3, 0.0, 0.0, 0.0, 0.0, null, null, null)),
- ClusterCapacity.class);
+ setPreprovisionCapacityFlag(tester, new ClusterCapacity(3, 0.0, 0.0, 0.0, 0.0, null, null, null));
assertEquals(0, tester.provisionedHostsMatching(sharedHostNodeResources));
- assertTrue(tester.nodeRepository.nodes().node("host102").isEmpty());
+ assertTrue(node("host102").isEmpty());
tester.maintain();
assertEquals(1, tester.provisionedHostsMatching(sharedHostNodeResources));
- assertTrue(tester.nodeRepository.nodes().node("host102").isPresent());
+ assertTrue(node("host102").isPresent());
// Next maintenance run does nothing
tester.assertNodesUnchanged();
// Requiring >0 capacity does nothing as long as it fits on the 3 hosts
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(),
- List.of(new ClusterCapacity(3,
- resources1.vcpu() - applicationNodeResources.vcpu(),
- resources1.memoryGb() - applicationNodeResources.memoryGb(),
- resources1.diskGb() - applicationNodeResources.diskGb(),
- resources1.bandwidthGbps() - applicationNodeResources.bandwidthGbps(),
- resources1.diskSpeed().name(),
- resources1.storageType().name(),
- resources1.architecture().name())),
- ClusterCapacity.class);
+ setPreprovisionCapacityFlag(tester,
+ new ClusterCapacity(3,
+ resources1.vcpu() - applicationNodeResources.vcpu(),
+ resources1.memoryGb() - applicationNodeResources.memoryGb(),
+ resources1.diskGb() - applicationNodeResources.diskGb(),
+ resources1.bandwidthGbps() - applicationNodeResources.bandwidthGbps(),
+ resources1.diskSpeed().name(),
+ resources1.storageType().name(),
+ resources1.architecture().name()));
tester.assertNodesUnchanged();
// But requiring a bit more in the cluster => provisioning of 2 shared hosts.
- tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(),
- List.of(new ClusterCapacity(3,
- resources1.vcpu() - applicationNodeResources.vcpu() + 1,
- resources1.memoryGb() - applicationNodeResources.memoryGb() + 1,
- resources1.diskGb() - applicationNodeResources.diskGb() + 1,
- resources1.bandwidthGbps(),
- resources1.diskSpeed().name(),
- resources1.storageType().name(),
- resources1.architecture().name())),
- ClusterCapacity.class);
+ setPreprovisionCapacityFlag(tester,
+ new ClusterCapacity(3,
+ resources1.vcpu() - applicationNodeResources.vcpu() + 1,
+ resources1.memoryGb() - applicationNodeResources.memoryGb() + 1,
+ resources1.diskGb() - applicationNodeResources.diskGb() + 1,
+ resources1.bandwidthGbps(),
+ resources1.diskSpeed().name(),
+ resources1.storageType().name(),
+ resources1.architecture().name()));
assertEquals(1, tester.provisionedHostsMatching(sharedHostNodeResources));
- assertTrue(tester.nodeRepository.nodes().node("host102").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host103").isEmpty());
- assertTrue(tester.nodeRepository.nodes().node("host104").isEmpty());
+ assertTrue(node("host102").isPresent());
+ assertTrue(node("host103").isEmpty());
+ assertTrue(node("host104").isEmpty());
tester.maintain();
assertEquals(3, tester.provisionedHostsMatching(sharedHostNodeResources));
- assertTrue(tester.nodeRepository.nodes().node("host102").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host103").isPresent());
- assertTrue(tester.nodeRepository.nodes().node("host104").isPresent());
+ assertTrue(node("host102").isPresent());
+ assertTrue(node("host103").isPresent());
+ assertTrue(node("host104").isPresent());
}
@Test
public void deprovision_empty_confighost() {
// cfghost1, cfg1, cfghost2, cfg2, cfghost3, and NOT cfg3.
- var tester = new DynamicProvisioningTester();
+ tester = new DynamicProvisioningTester();
tester.addCfghost(1, true);
tester.addCfghost(2, true);
Node cfghost3 = tester.addCfghost(3, false);
@@ -475,8 +476,8 @@ public class HostCapacityMaintainerTest {
@Test
public void custom_cloud_account() {
- DynamicProvisioningTester tester = new DynamicProvisioningTester(Cloud.builder().name(CloudName.AWS).dynamicProvisioning(true).allowEnclave(true).account(CloudAccount.from("001122334455")).build(),
- new MockNameResolver().mockAnyLookup());
+ tester = new DynamicProvisioningTester(Cloud.builder().name(CloudName.AWS).dynamicProvisioning(true).allowEnclave(true).account(CloudAccount.from("001122334455")).build(),
+ new MockNameResolver().mockAnyLookup());
ProvisioningTester provisioningTester = tester.provisioningTester;
ApplicationId applicationId = ApplicationId.from("t1", "a1", "i1");
@@ -513,7 +514,7 @@ public class HostCapacityMaintainerTest {
@Test
public void deprovision_node_when_no_allocation_and_past_ttl() {
- var tester = new DynamicProvisioningTester();
+ tester = new DynamicProvisioningTester();
ManualClock clock = (ManualClock) tester.nodeRepository.clock();
tester.hostProvisioner.with(Behaviour.failProvisioning);
tester.provisioningTester.makeReadyHosts(2, new NodeResources(1, 1, 1, 1)).activateTenantHosts();
@@ -526,61 +527,61 @@ public class HostCapacityMaintainerTest {
// Host is not marked for deprovisioning by maintainer, because child is present
tester.maintain();
- assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision());
- assertEquals(Optional.empty(), tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt());
+ assertFalse(node(host1.hostname()).get().status().wantToDeprovision());
+ assertEquals(Optional.empty(), node(host1.hostname()).get().hostEmptyAt());
// Child is set to deprovision, but turns active
tester.nodeRepository.nodes().park(host11.hostname(), true, Agent.system, "not good");
tester.nodeRepository.nodes().reactivate(host11.hostname(), Agent.operator, "all good");
- assertTrue(tester.nodeRepository.nodes().node(host11.hostname()).get().status().wantToDeprovision());
- assertEquals(State.active, tester.nodeRepository.nodes().node(host11.hostname()).get().state());
+ assertTrue(node(host11.hostname()).get().status().wantToDeprovision());
+ assertEquals(State.active, node(host11.hostname()).get().state());
tester.maintain();
- assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision());
- assertEquals(Optional.empty(), tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt());
+ assertFalse(node(host1.hostname()).get().status().wantToDeprovision());
+ assertEquals(Optional.empty(), node(host1.hostname()).get().hostEmptyAt());
// Child is parked, to make the host effectively empty
tester.nodeRepository.nodes().park(host11.hostname(), true, Agent.system, "not good");
tester.maintain();
- assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision());
+ assertFalse(node(host1.hostname()).get().status().wantToDeprovision());
assertEquals(Optional.of(clock.instant().truncatedTo(ChronoUnit.MILLIS)),
- tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt());
+ node(host1.hostname()).get().hostEmptyAt());
// Some time passes, but not enough for host1 to be deprovisioned
clock.advance(Duration.ofDays(1).minusSeconds(1));
tester.maintain();
- assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision());
+ assertFalse(node(host1.hostname()).get().status().wantToDeprovision());
assertEquals(Optional.of(clock.instant().minus(Duration.ofDays(1).minusSeconds(1)).truncatedTo(ChronoUnit.MILLIS)),
- tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt());
- assertTrue(tester.nodeRepository.nodes().node(host2.hostname()).get().status().wantToDeprovision());
- assertTrue(tester.nodeRepository.nodes().node(host2.hostname()).get().status().wantToRetire());
- assertEquals(State.active, tester.nodeRepository.nodes().node(host2.hostname()).get().state());
+ node(host1.hostname()).get().hostEmptyAt());
+ assertTrue(node(host2.hostname()).get().status().wantToDeprovision());
+ assertTrue(node(host2.hostname()).get().status().wantToRetire());
+ assertEquals(State.active, node(host2.hostname()).get().state());
assertEquals(Optional.of(clock.instant().minus(Duration.ofDays(1).minusSeconds(1)).truncatedTo(ChronoUnit.MILLIS)),
- tester.nodeRepository.nodes().node(host2.hostname()).get().hostEmptyAt());
+ node(host2.hostname()).get().hostEmptyAt());
// Some more time passes, but child is reactivated on host1, rendering the host non-empty again
clock.advance(Duration.ofDays(1));
tester.nodeRepository.nodes().reactivate(host11.hostname(), Agent.operator, "all good");
tester.maintain();
- assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision());
- assertEquals(Optional.empty(), tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt());
+ assertFalse(node(host1.hostname()).get().status().wantToDeprovision());
+ assertEquals(Optional.empty(), node(host1.hostname()).get().hostEmptyAt());
// Child is removed, and host is marked as empty
tester.nodeRepository.database().writeTo(State.deprovisioned, host11, Agent.operator, Optional.empty());
- tester.nodeRepository.nodes().forget(tester.nodeRepository.nodes().node(host11.hostname()).get());
- assertEquals(Optional.empty(), tester.nodeRepository.nodes().node(host11.hostname()));
+ tester.nodeRepository.nodes().forget(node(host11.hostname()).get());
+ assertEquals(Optional.empty(), node(host11.hostname()));
tester.maintain();
- assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision());
+ assertFalse(node(host1.hostname()).get().status().wantToDeprovision());
assertEquals(Optional.of(clock.instant().truncatedTo(ChronoUnit.MILLIS)),
- tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt());
+ node(host1.hostname()).get().hostEmptyAt());
// Enough time passes for the host to be deprovisioned
clock.advance(Duration.ofDays(1));
tester.maintain();
- assertTrue(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision());
- assertTrue(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToRetire());
- assertEquals(State.active, tester.nodeRepository.nodes().node(host1.hostname()).get().state());
+ assertTrue(node(host1.hostname()).get().status().wantToDeprovision());
+ assertTrue(node(host1.hostname()).get().status().wantToRetire());
+ assertEquals(State.active, node(host1.hostname()).get().state());
assertEquals(Optional.of(clock.instant().minus(Duration.ofDays(1)).truncatedTo(ChronoUnit.MILLIS)),
- tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt());
+ node(host1.hostname()).get().hostEmptyAt());
// Let tenant host app redeploy, retiring the obsolete host.
tester.provisioningTester.activateTenantHosts();
@@ -599,7 +600,7 @@ public class HostCapacityMaintainerTest {
@Test
public void deprovision_parked_node_with_allocation() {
- var tester = new DynamicProvisioningTester();
+ tester = new DynamicProvisioningTester();
tester.hostProvisioner.with(Behaviour.failProvisioning);
Node host4 = tester.addNode("host4", Optional.empty(), NodeType.host, Node.State.parked, null, Duration.ofDays(1));
Node host41 = tester.addNode("host4-1", Optional.of("host4"), NodeType.tenant, Node.State.parked, DynamicProvisioningTester.tenantApp);
@@ -609,16 +610,16 @@ public class HostCapacityMaintainerTest {
// Host and children are marked for deprovisioning, bypassing host TTL.
tester.nodeRepository.nodes().deprovision("host4", Agent.operator, Instant.now());
for (var node : List.of(host4, host41, host42, host43)) {
- assertTrue(tester.nodeRepository.nodes().node(node.hostname()).map(n -> n.status().wantToDeprovision()).get());
+ assertTrue(node(node.hostname()).map(n -> n.status().wantToDeprovision()).get());
}
// Host and children remain parked because one child is still active
tester.maintain();
for (var node : List.of(host4, host41)) {
- assertEquals(Node.State.parked, tester.nodeRepository.nodes().node(node.hostname()).get().state());
+ assertEquals(Node.State.parked, node(node.hostname()).get().state());
}
- assertEquals(Node.State.active, tester.nodeRepository.nodes().node(host42.hostname()).get().state());
- assertEquals(Node.State.failed, tester.nodeRepository.nodes().node(host43.hostname()).get().state());
+ assertEquals(Node.State.active, node(host42.hostname()).get().state());
+ assertEquals(Node.State.failed, node(host43.hostname()).get().state());
// Last child is parked
tester.nodeRepository.nodes().park(host42.hostname(), false, Agent.system, getClass().getSimpleName());
@@ -627,9 +628,9 @@ public class HostCapacityMaintainerTest {
tester.maintain();
for (var node : List.of(host4, host41, host42, host43)) {
if (node.type().isHost()) {
- assertSame(node.hostname() + " moved to deprovisioned", State.deprovisioned, tester.nodeRepository.nodes().node(node.hostname()).get().state());
+ assertSame(node.hostname() + " moved to deprovisioned", State.deprovisioned, node(node.hostname()).get().state());
} else {
- assertTrue(node.hostname() + " removed", tester.nodeRepository.nodes().node(node.hostname()).isEmpty());
+ assertTrue(node.hostname() + " removed", node(node.hostname()).isEmpty());
}
}
}
@@ -656,7 +657,7 @@ public class HostCapacityMaintainerTest {
private void assertCfghost3IsActive(DynamicProvisioningTester tester) {
assertEquals(5, tester.nodeRepository.nodes().list(Node.State.active).size());
assertEquals(3, tester.nodeRepository.nodes().list(Node.State.active).nodeType(NodeType.confighost).size());
- Optional<Node> cfghost3 = tester.nodeRepository.nodes().node("cfghost3");
+ Optional<Node> cfghost3 = node("cfghost3");
assertTrue(cfghost3.isPresent());
assertEquals(Node.State.active, cfghost3.get().state());
}
@@ -664,7 +665,35 @@ public class HostCapacityMaintainerTest {
private void assertCfghost3IsDeprovisioned(DynamicProvisioningTester tester) {
assertEquals(4, tester.nodeRepository.nodes().list(Node.State.active).size());
assertEquals(2, tester.nodeRepository.nodes().list(Node.State.active).nodeType(NodeType.confighost).size());
- assertSame(State.deprovisioned, tester.nodeRepository.nodes().node("cfghost3").get().state());
+ assertSame(State.deprovisioned, node("cfghost3").get().state());
+ }
+
+ private static void setPreprovisionCapacityFlag(DynamicProvisioningTester tester, ClusterCapacity... clusterCapacities) {
+ tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), List.of(clusterCapacities), ClusterCapacity.class);
+ }
+
+ private void assertNodeExists(String nodeName) {
+ assertTrue(nodeExists(nodeName));
+ }
+
+ private void assertNodeExists(String message, String nodeName) {
+ assertTrue(message, nodeExists(nodeName));
+ }
+
+ private void assertNodeDoesNotExist(String nodeName) {
+ assertFalse(nodeExists(nodeName));
+ }
+
+ private void assertNodeDoesNotExist(String message, String nodeName) {
+ assertFalse(message, nodeExists(nodeName));
+ }
+
+ private boolean nodeExists(String nodeName) {
+ return node(nodeName).isPresent();
+ }
+
+ private Optional<Node> node(String nodeName) {
+ return tester.nodeRepository.nodes().node(nodeName);
}
private static class DynamicProvisioningTester {
@@ -673,7 +702,7 @@ public class HostCapacityMaintainerTest {
private static final InfraApplication configServerHostApp = new ConfigServerHostApplication();
private static final InfraApplication configServerApp = new ConfigServerApplication();
private static final ApplicationId tenantApp = ApplicationId.from("mytenant", "myapp", "default");
- private static final NodeFlavors flavors = FlavorConfigBuilder.createDummies("default", "docker", "host2", "host3", "host4");
+ private static final NodeFlavors flavors = FlavorConfigBuilder.createDummies("default", "docker", "host2", "host3", "host4", "arm64");
private final InMemoryFlagSource flagSource = new InMemoryFlagSource();
@@ -722,7 +751,7 @@ public class HostCapacityMaintainerTest {
createNode("host4", Optional.empty(), NodeType.host, Node.State.provisioned, null),
createNode("host4-1", Optional.of("host4"), NodeType.tenant, Node.State.reserved, tenantApp),
createNode("host4-2", Optional.of("host4"), NodeType.tenant, Node.State.reserved, tenantApp))
- .forEach(node -> nodeRepository.database().addNodesInState(List.of(node), node.state(), Agent.system));
+ .forEach(node -> nodeRepository.database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system));
return this;
}
@@ -744,7 +773,7 @@ public class HostCapacityMaintainerTest {
private Node addNode(String hostname, Optional<String> parentHostname, NodeType nodeType, Node.State state, ApplicationId application, Duration hostTTL) {
Node node = createNode(hostname, parentHostname, nodeType, state, application, hostTTL);
- return nodeRepository.database().addNodesInState(List.of(node), node.state(), Agent.system).get(0);
+ return nodeRepository.database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system).get(0);
}
private Node createNode(String hostname, Optional<String> parentHostname, NodeType nodeType,
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java
index de2c060a0eb..83aea78ce58 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java
@@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.provision.maintenance;
import com.yahoo.component.Version;
import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.ApplicationTransaction;
import com.yahoo.config.provision.Capacity;
import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.ClusterResources;
@@ -10,11 +11,13 @@ import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.NodeFlavors;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
+import com.yahoo.config.provision.ProvisionLock;
import com.yahoo.jdisc.Metric;
import com.yahoo.transaction.Mutex;
import com.yahoo.transaction.NestedTransaction;
import com.yahoo.vespa.applicationmodel.ApplicationInstance;
import com.yahoo.vespa.applicationmodel.ApplicationInstanceReference;
+import com.yahoo.vespa.applicationmodel.InfrastructureApplication;
import com.yahoo.vespa.curator.stats.LockStats;
import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
@@ -182,7 +185,7 @@ public class MetricsReporterTest {
@Test
public void container_metrics() {
- NodeFlavors nodeFlavors = FlavorConfigBuilder.createDummies("host", "docker", "docker2");
+ NodeFlavors nodeFlavors = FlavorConfigBuilder.createDummies("host", "docker");
ProvisioningTester tester = new ProvisioningTester.Builder().flavors(nodeFlavors.getFlavors()).build();
NodeRepository nodeRepository = tester.nodeRepository();
@@ -210,7 +213,8 @@ public class MetricsReporterTest {
}
NestedTransaction transaction = new NestedTransaction();
- nodeRepository.nodes().activate(nodeRepository.nodes().list().nodeType(NodeType.host).asList(), transaction);
+ nodeRepository.nodes().activate(nodeRepository.nodes().list().nodeType(NodeType.host).asList(),
+ new ApplicationTransaction(new ProvisionLock(InfrastructureApplication.TENANT_HOST.id(), () -> { }), transaction));
transaction.commit();
Orchestrator orchestrator = mock(Orchestrator.class);
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java
index 359f75c27ab..ac1e452d7a5 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java
@@ -9,6 +9,7 @@ import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.RegionName;
import com.yahoo.config.provision.SystemName;
import com.yahoo.config.provision.Zone;
+import com.yahoo.vespa.hosted.provision.LockedNodeList;
import com.yahoo.vespa.hosted.provision.Node;
import com.yahoo.vespa.hosted.provision.node.Agent;
import com.yahoo.vespa.hosted.provision.provisioning.ProvisioningTester;
@@ -45,7 +46,7 @@ public class ProvisionedExpirerTest {
var nodes = IntStream.range(0, 15)
.mapToObj(i -> Node.create("id-" + i, "host-" + i, new Flavor(NodeResources.unspecified()), Node.State.provisioned, NodeType.host).build())
.toList();
- tester.nodeRepository().database().addNodesInState(nodes, Node.State.provisioned, Agent.system);
+ tester.nodeRepository().database().addNodesInState(new LockedNodeList(nodes, () -> { }), Node.State.provisioned, Agent.system);
}
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java
index b54975cbf41..a5ac2be72ee 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.hosted.provision.maintenance;
import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.ApplicationTransaction;
import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.ClusterSpec;
import com.yahoo.config.provision.DockerImage;
@@ -10,6 +11,7 @@ import com.yahoo.config.provision.Flavor;
import com.yahoo.config.provision.NodeFlavors;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
+import com.yahoo.config.provision.ProvisionLock;
import com.yahoo.config.provision.RegionName;
import com.yahoo.config.provision.Zone;
import com.yahoo.test.ManualClock;
@@ -313,7 +315,7 @@ public class SpareCapacityMaintainerTest {
}
private void allocate(ApplicationId application, ClusterSpec clusterSpec, List<Node> nodes) {
- nodes = nodeRepository.nodes().addNodes(nodes, Agent.system);
+ nodes = new ArrayList<>(nodeRepository.nodes().addNodes(nodes, Agent.system));
for (int i = 0; i < nodes.size(); i++) {
Node node = nodes.get(i);
ClusterMembership membership = ClusterMembership.from(clusterSpec, i);
@@ -322,7 +324,7 @@ public class SpareCapacityMaintainerTest {
}
nodes = nodeRepository.nodes().reserve(nodes);
var transaction = new NestedTransaction();
- nodes = nodeRepository.nodes().activate(nodes, transaction);
+ nodes = nodeRepository.nodes().activate(nodes, new ApplicationTransaction(new ProvisionLock(application, () -> { }), transaction));
transaction.commit();
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java
index 47d34a76dd6..478b201d71b 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java
@@ -2,6 +2,7 @@
package com.yahoo.vespa.hosted.provision.provisioning;
import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.ApplicationTransaction;
import com.yahoo.config.provision.Capacity;
import com.yahoo.config.provision.ClusterMembership;
import com.yahoo.config.provision.ClusterResources;
@@ -12,6 +13,7 @@ import com.yahoo.config.provision.HostSpec;
import com.yahoo.config.provision.NodeResources;
import com.yahoo.config.provision.NodeType;
import com.yahoo.config.provision.NodeAllocationException;
+import com.yahoo.config.provision.ProvisionLock;
import com.yahoo.config.provision.RegionName;
import com.yahoo.config.provision.SystemName;
import com.yahoo.config.provision.Zone;
@@ -540,9 +542,9 @@ public class DynamicAllocationTest {
clusterSpec.with(Optional.of(ClusterSpec.Group.from(0))), index); // Need to add group here so that group is serialized in node allocation
Node node1aAllocation = node1a.allocate(id, clusterMembership1, node1a.resources(), Instant.now());
- tester.nodeRepository().nodes().addNodes(Collections.singletonList(node1aAllocation), Agent.system);
+ tester.nodeRepository().nodes().addNodes(List.of(node1aAllocation), Agent.system);
NestedTransaction transaction = new NestedTransaction().add(new CuratorTransaction(tester.getCurator()));
- tester.nodeRepository().nodes().activate(Collections.singletonList(node1aAllocation), transaction);
+ tester.nodeRepository().nodes().activate(List.of(node1aAllocation), new ApplicationTransaction(new ProvisionLock(id, () -> { }), transaction));
transaction.commit();
}
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/HostCapacityTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/HostCapacityTest.java
index ec8a8f637c8..35d6cdbf5d0 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/HostCapacityTest.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/HostCapacityTest.java
@@ -45,7 +45,7 @@ public class HostCapacityTest {
doAnswer(invocation -> ((Flavor)invocation.getArguments()[0]).resources()).when(hostResourcesCalculator).advertisedResourcesOf(any());
// Create flavors
- NodeFlavors nodeFlavors = FlavorConfigBuilder.createDummies("host", "docker", "docker2");
+ NodeFlavors nodeFlavors = FlavorConfigBuilder.createDummies("host", "docker");
// Create hosts
host1 = Node.create("host1", IP.Config.of(Set.of("::1"), createIps(2, 4), List.of()), "host1", nodeFlavors.getFlavorOrThrow("host"), NodeType.host).build();
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java
index 2acbeb00f5f..dd8f97d82de 100644
--- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java
+++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java
@@ -214,7 +214,7 @@ public class ProvisioningTester {
NestedTransaction t = new NestedTransaction();
if (parent.ipConfig().primary().isEmpty())
parent = parent.with(IP.Config.of(Set.of("::" + 0 + ":0"), Set.of("::" + 0 + ":2")));
- nodeRepository.nodes().activate(List.of(parent), t);
+ nodeRepository.nodes().activate(List.of(parent), new ApplicationTransaction(new ProvisionLock(application, () -> { }), t));
t.commit();
}
}
diff --git a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp
index d0c1f99af11..28da6013edb 100644
--- a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp
+++ b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp
@@ -35,6 +35,7 @@ using search::fef::MatchData;
using search::fef::RankSetup;
using search::fef::indexproperties::hitcollector::HeapSize;
using search::fef::indexproperties::hitcollector::ArraySize;
+using search::fef::indexproperties::hitcollector::RankScoreDropLimit;
using search::queryeval::Blueprint;
using search::queryeval::SearchIterator;
using vespalib::Doom;
@@ -239,10 +240,11 @@ Matcher::match(const SearchRequest &request, vespalib::ThreadBundle &threadBundl
const Properties & rankProperties = request.propertiesMap.rankProperties();
uint32_t heapSize = HeapSize::lookup(rankProperties, _rankSetup->getHeapSize());
uint32_t arraySize = ArraySize::lookup(rankProperties, _rankSetup->getArraySize());
+ search::feature_t rank_score_drop_limit = RankScoreDropLimit::lookup(rankProperties, _rankSetup->getRankScoreDropLimit());
- MatchParams params(searchContext.getDocIdLimit(), heapSize, arraySize,
- _rankSetup->getRankScoreDropLimit(), request.offset, request.maxhits,
- !_rankSetup->getSecondPhaseRank().empty(), !willNotNeedRanking(request, groupingContext));
+ MatchParams params(searchContext.getDocIdLimit(), heapSize, arraySize, rank_score_drop_limit,
+ request.offset, request.maxhits, !_rankSetup->getSecondPhaseRank().empty(),
+ !willNotNeedRanking(request, groupingContext));
ResultProcessor rp(attrContext, metaStore, sessionMgr, groupingContext, sessionId,
request.sortSpec, params.offset, params.hits);
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index f3fe86e261f..7d6f2f8790c 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -946,6 +946,8 @@
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRandom()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL1Normalize()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL2Normalize()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorEuclideanDistance()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCosineSimilarity()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()",
@@ -1098,6 +1100,8 @@
"public static final int RANDOM",
"public static final int L1_NORMALIZE",
"public static final int L2_NORMALIZE",
+ "public static final int EUCLIDEAN_DISTANCE",
+ "public static final int COSINE_SIMILARITY",
"public static final int MATMUL",
"public static final int SOFTMAX",
"public static final int XW_PLUS_B",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 42b5f2c191a..41647a5ef5b 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -138,6 +138,8 @@ TOKEN :
<RANDOM: "random"> |
<L1_NORMALIZE: "l1_normalize"> |
<L2_NORMALIZE: "l2_normalize"> |
+ <EUCLIDEAN_DISTANCE: "euclidean_distance"> |
+ <COSINE_SIMILARITY: "cosine_similarity"> |
<MATMUL: "matmul"> |
<SOFTMAX: "softmax"> |
<XW_PLUS_B: "xw_plus_b"> |
@@ -379,6 +381,8 @@ TensorFunctionNode tensorFunction() :
tensorExpression = tensorRandom() |
tensorExpression = tensorL1Normalize() |
tensorExpression = tensorL2Normalize() |
+ tensorExpression = tensorEuclideanDistance() |
+ tensorExpression = tensorCosineSimilarity() |
tensorExpression = tensorMatmul() |
tensorExpression = tensorSoftmax() |
tensorExpression = tensorXwPlusB() |
@@ -544,6 +548,30 @@ TensorFunctionNode tensorL2Normalize() :
{ return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); }
}
+TensorFunctionNode tensorEuclideanDistance() :
+{
+ ExpressionNode tensor1, tensor2;
+ String dimension;
+}
+{
+ <EUCLIDEAN_DISTANCE> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new EuclideanDistance(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ dimension)); }
+}
+
+TensorFunctionNode tensorCosineSimilarity() :
+{
+ ExpressionNode tensor1, tensor2;
+ String dimension;
+}
+{
+ <COSINE_SIMILARITY> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new CosineSimilarity(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ dimension)); }
+}
+
TensorFunctionNode tensorMatmul() :
{
ExpressionNode tensor1, tensor2;
@@ -701,6 +729,8 @@ String tensorFunctionName() :
( <RANDOM> { return token.image; } ) |
( <L1_NORMALIZE> { return token.image; } ) |
( <L2_NORMALIZE> { return token.image; } ) |
+ ( <EUCLIDEAN_DISTANCE> { return token.image; } ) |
+ ( <COSINE_SIMILARITY> { return token.image; } ) |
( <MATMUL> { return token.image; } ) |
( <SOFTMAX> { return token.image; } ) |
( <XW_PLUS_B> { return token.image; } ) |
@@ -1041,4 +1071,4 @@ String label() :
String string() : {}
{
<STRING> { return token.image.substring(1, token.image.length() - 1); }
-} \ No newline at end of file
+}
diff --git a/searchlib/src/tests/attribute/multi_value_mapping/multi_value_mapping_test.cpp b/searchlib/src/tests/attribute/multi_value_mapping/multi_value_mapping_test.cpp
index 4b01808e855..e3e4f391cc4 100644
--- a/searchlib/src/tests/attribute/multi_value_mapping/multi_value_mapping_test.cpp
+++ b/searchlib/src/tests/attribute/multi_value_mapping/multi_value_mapping_test.cpp
@@ -95,7 +95,7 @@ public:
ArrayStoreConfig config(max_array_store_type_id,
ArrayStoreConfig::AllocSpec(0, RefType::offsetSize(), 8_Ki, ALLOC_GROW_FACTOR));
config.enable_free_lists(enable_free_lists);
- _mvMapping = std::make_unique<MvMapping>(config, vespalib::GrowStrategy(), std::make_unique<MemoryAllocatorObserver>(_stats));
+ _mvMapping = std::make_unique<MvMapping>(config, ArrayStoreConfig::default_max_buffer_size, vespalib::GrowStrategy(), std::make_unique<MemoryAllocatorObserver>(_stats));
_attr = std::make_unique<AttributeType>(*_mvMapping);
_maxSmallArraySize = _mvMapping->get_mapper().get_array_size(max_array_store_type_id);
}
@@ -103,7 +103,7 @@ public:
ArrayStoreConfig config(max_array_store_type_id,
ArrayStoreConfig::AllocSpec(min_entries, max_entries, num_entries_for_new_buffer, ALLOC_GROW_FACTOR));
config.enable_free_lists(enable_free_lists);
- _mvMapping = std::make_unique<MvMapping>(config, vespalib::GrowStrategy(), std::make_unique<MemoryAllocatorObserver>(_stats));
+ _mvMapping = std::make_unique<MvMapping>(config, ArrayStoreConfig::default_max_buffer_size, vespalib::GrowStrategy(), std::make_unique<MemoryAllocatorObserver>(_stats));
_attr = std::make_unique<AttributeType>(*_mvMapping);
_maxSmallArraySize = _mvMapping->get_mapper().get_array_size(max_array_store_type_id);
}
diff --git a/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp b/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp
index 2bf2a38b7e6..0d64683b4a6 100644
--- a/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp
+++ b/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp
@@ -565,7 +565,7 @@ TEST("Control static memory usage") {
IDocumentStore &ds = vcs.getStore();
vespalib::MemoryUsage usage = ds.getMemoryUsage();
constexpr size_t mutex_size = sizeof(std::mutex) * 2 * (113 + 1); // sizeof(std::mutex) is platform dependent
- EXPECT_EQUAL(74572 + mutex_size, usage.allocatedBytes());
+ EXPECT_EQUAL(74668 + mutex_size, usage.allocatedBytes());
EXPECT_EQUAL(944u + mutex_size, usage.usedBytes());
}
@@ -575,29 +575,29 @@ TEST("test the update cache strategy") {
for (size_t i(1); i <= 10; i++) {
vcs.write(i);
}
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 0));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 28));
vcs.verifyRead(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 221));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 241));
vcs.write(8);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 221));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 241));
vcs.write(7, 17);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 282));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 302));
vcs.verifyRead(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 1, 282));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 1, 302));
vcs.remove(8);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 1, 282));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 1, 302));
vcs.remove(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 0, 0));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 0, 28));
vcs.write(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 0, 0));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 0, 28));
vcs.verifyRead(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 2, 1, 221));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 2, 1, 241));
vcs.write(7, 17);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 2, 1, 282));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 2, 1, 302));
vcs.recreate();
IDocumentStore & ds2 = vcs.getStore();
vcs.verifyRead(7);
- TEST_DO(verifyCacheStats(ds2.getCacheStats(), 0, 1, 1, 282));
+ TEST_DO(verifyCacheStats(ds2.getCacheStats(), 0, 1, 1, 302));
}
TEST("test the invalidate cache strategy") {
@@ -606,23 +606,23 @@ TEST("test the invalidate cache strategy") {
for (size_t i(1); i <= 10; i++) {
vcs.write(i);
}
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 0));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 28));
vcs.verifyRead(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 221));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 241));
vcs.write(8);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 221));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 241));
vcs.write(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 0, 0));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 0, 28));
vcs.verifyRead(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 1, 221));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 1, 241));
vcs.remove(8);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 1, 221));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 1, 241));
vcs.remove(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 0, 0));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 0, 28));
vcs.write(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 0, 0));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 0, 28));
vcs.verifyRead(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 3, 1, 221));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 3, 1, 241));
}
TEST("test that the integrated visit cache works.") {
@@ -631,12 +631,12 @@ TEST("test that the integrated visit cache works.") {
for (size_t i(1); i <= 100; i++) {
vcs.write(i);
}
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 0));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 28));
for (size_t i(1); i <= 100; i++) {
vcs.verifyRead(i);
}
- constexpr size_t BASE_SZ = 20594;
+ constexpr size_t BASE_SZ = 20602;
TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 100, 100, BASE_SZ));
for (size_t i(1); i <= 100; i++) {
vcs.verifyRead(i);
@@ -646,32 +646,32 @@ TEST("test that the integrated visit cache works.") {
vcs.verifyVisit({7,9,17,19,67,88}, false);
TEST_DO(verifyCacheStats(ds.getCacheStats(), 100, 100, 100, BASE_SZ));
vcs.verifyVisit({7,9,17,19,67,88}, true);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 100, 101, 101, BASE_SZ+557));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 100, 101, 101, BASE_SZ+16));
vcs.verifyVisit({7,9,17,19,67,88}, true);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 101, BASE_SZ+557));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 101, BASE_SZ+16));
vcs.rewrite(8);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 100, BASE_SZ+328)); // From the individual cache.
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 100, BASE_SZ-197)); // From the individual cache.
vcs.rewrite(7);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 98, BASE_SZ-442)); // From the both caches.
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 98, BASE_SZ-166)); // From the both caches.
vcs.verifyVisit({7,9,17,19,67,88}, true);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 102, 99, BASE_SZ+130));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 102, 99, BASE_SZ-410));
vcs.verifyVisit({7,9,17,19,67,88,89}, true);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 103, 99, BASE_SZ+180));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 103, 99, BASE_SZ-406));
vcs.rewrite(17);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 103, 97, BASE_SZ-671));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 103, 97, BASE_SZ-391));
vcs.verifyVisit({7,9,17,19,67,88,89}, true);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 104, 98, BASE_SZ-20));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 104, 98, BASE_SZ-611));
vcs.remove(17);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 104, 97, BASE_SZ-671));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 104, 97, BASE_SZ-391));
vcs.verifyVisit({7,9,17,19,67,88,89}, {7,9,19,67,88,89}, true);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 105, 98, BASE_SZ-70));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 105, 98, BASE_SZ-611));
vcs.verifyVisit({41, 42}, true);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 106, 99, BASE_SZ+230));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 106, 99, BASE_SZ-611));
vcs.verifyVisit({43, 44}, true);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 107, 100, BASE_SZ+540));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 107, 100, BASE_SZ-611));
vcs.verifyVisit({41, 42, 43, 44}, true);
- TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 108, 99, BASE_SZ+360));
+ TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 108, 99, BASE_SZ-611));
}
TEST("testWriteRead") {
diff --git a/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp b/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp
index 26ef57aab65..87420e8939f 100644
--- a/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp
+++ b/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp
@@ -9,6 +9,7 @@ LOG_SETUP("dense_tensor_store_test");
#include <vespa/eval/eval/value.h>
#include <vespa/eval/eval/value_type.h>
#include <vespa/eval/eval/test/value_compare.h>
+#include <vespa/vespalib/util/size_literals.h>
using search::tensor::DenseTensorStore;
using vespalib::eval::SimpleValue;
@@ -90,5 +91,26 @@ TEST("require that array size is calculated correctly")
TEST_DO(assertArraySize("tensor<int8>(x[65])", 96));
}
+void
+assert_max_buffer_entries(const vespalib::string& tensor_type, uint32_t exp_entries)
+{
+ Fixture f(tensor_type);
+ EXPECT_EQUAL(exp_entries, f.store.get_max_buffer_entries());
+}
+
+TEST("require that max entries is calculated correctly")
+{
+ TEST_DO(assert_max_buffer_entries("tensor(x[1])", 1_Mi));
+ TEST_DO(assert_max_buffer_entries("tensor(x[32])", 1_Mi));
+ TEST_DO(assert_max_buffer_entries("tensor(x[64])", 512_Ki));
+ TEST_DO(assert_max_buffer_entries("tensor(x[1024])", 32_Ki));
+ TEST_DO(assert_max_buffer_entries("tensor(x[2048])", 16_Ki));
+ TEST_DO(assert_max_buffer_entries("tensor(x[16777216])", 2));
+ TEST_DO(assert_max_buffer_entries("tensor(x[33554428])", 2));
+ TEST_DO(assert_max_buffer_entries("tensor(x[33554429])", 1));
+ TEST_DO(assert_max_buffer_entries("tensor(x[33554432])", 1));
+ TEST_DO(assert_max_buffer_entries("tensor(x[303554432])", 1));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/searchlib/src/tests/tensor/tensor_buffer_store/tensor_buffer_store_test.cpp b/searchlib/src/tests/tensor/tensor_buffer_store/tensor_buffer_store_test.cpp
index 0a69d8149ed..b42520aa9e0 100644
--- a/searchlib/src/tests/tensor/tensor_buffer_store/tensor_buffer_store_test.cpp
+++ b/searchlib/src/tests/tensor/tensor_buffer_store/tensor_buffer_store_test.cpp
@@ -198,7 +198,7 @@ TEST_F(TensorBufferStoreTest, buffer_handles_range_of_subspaces)
auto buffer_id = ref.buffer_id(offset_bits);
buffers.insert(buffer_id);
}
- EXPECT_EQ(156u, buffers.size());
+ EXPECT_EQ(119u, buffers.size());
uint32_t x = 0;
for (auto ref : refs) {
auto tensor = store.get_tensor(ref);
diff --git a/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp b/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp
index 08c0901de01..fc574ba9b2c 100644
--- a/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp
+++ b/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp
@@ -15,7 +15,7 @@ const vespalib::string tensor_type_2d_mixed_spec("tensor(x{},y[2])");
const vespalib::string float_tensor_type_spec("tensor<float>(y{})");
const vespalib::string tensor_type_dense_spec("tensor(x[2])");
-constexpr double grow_factor = 1.02;
+constexpr double grow_factor = 1.03;
struct TestParam
{
@@ -128,10 +128,10 @@ TensorBufferTypeMapperTest::select_type_ids()
INSTANTIATE_TEST_SUITE_P(TensorBufferTypeMapperMultiTest,
TensorBufferTypeMapperTest,
- testing::Values(TestParam("1d", {8, 16, 32, 40, 64}, {1760, 10880, 76896, 555248, 4020512}, tensor_type_sparse_spec),
- TestParam("1dfloat", {4, 12, 20, 28, 36}, {1728, 11104, 79168, 572128, 4143664}, float_tensor_type_spec),
- TestParam("2d", {8, 24, 40, 56, 80}, {1600, 9184, 63872, 460416, 3332976}, tensor_type_2d_spec),
- TestParam("2dmixed", {8, 24, 48, 64, 96}, {1984, 11472, 79824, 575504, 4166208}, tensor_type_2d_mixed_spec),
+ testing::Values(TestParam("1d", {8, 16, 32, 40, 64}, {2768, 49712, 950768, 18268976, 351101184}, tensor_type_sparse_spec),
+ TestParam("1dfloat", {4, 12, 20, 28, 36}, {2688, 48896, 937248, 18009808, 346121248}, float_tensor_type_spec),
+ TestParam("2d", {8, 24, 40, 56, 80}, {2416, 41392, 790112, 15179616, 291726288}, tensor_type_2d_spec),
+ TestParam("2dmixed", {8, 24, 48, 64, 96}, {3008, 51728, 987632, 18974512, 364657856}, tensor_type_2d_mixed_spec),
TestParam("dense", {8, 24}, {}, tensor_type_dense_spec)),
testing::PrintToStringParamName());
@@ -152,7 +152,7 @@ TEST_P(TensorBufferTypeMapperTest, large_arrays_grows_exponentially)
TEST_P(TensorBufferTypeMapperTest, avoid_array_size_overflow)
{
- TensorBufferTypeMapper mapper(400, 2.0, &_ops);
+ TensorBufferTypeMapper mapper(300, 2.0, &_ops);
EXPECT_GE(30, mapper.get_max_type_id(1000));
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h
index 307d1a0d112..ced076dc632 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h
+++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h
@@ -36,6 +36,7 @@ public:
MultiValueMapping(const MultiValueMapping &) = delete;
MultiValueMapping & operator = (const MultiValueMapping &) = delete;
MultiValueMapping(const vespalib::datastore::ArrayStoreConfig &storeCfg,
+ size_t max_buffer_size,
const vespalib::GrowStrategy &gs,
std::shared_ptr<vespalib::alloc::MemoryAllocator> memory_allocator);
~MultiValueMapping() override;
@@ -79,11 +80,12 @@ public:
static vespalib::datastore::ArrayStoreConfig optimizedConfigForHugePage(size_t max_type_id,
- size_t hugePageSize,
- size_t smallPageSize,
- size_t min_num_entries_for_new_buffer,
- float allocGrowFactor,
- bool enable_free_lists);
+ size_t hugePageSize,
+ size_t smallPageSize,
+ size_t max_buffer_size,
+ size_t min_num_entries_for_new_buffer,
+ float allocGrowFactor,
+ bool enable_free_lists);
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp
index 99808b11e92..64c4777ffda 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp
@@ -9,10 +9,11 @@ namespace search::attribute {
template <typename ElemT, typename RefT>
MultiValueMapping<ElemT,RefT>::MultiValueMapping(const vespalib::datastore::ArrayStoreConfig &storeCfg,
+ size_t max_buffer_size,
const vespalib::GrowStrategy &gs,
std::shared_ptr<vespalib::alloc::MemoryAllocator> memory_allocator)
: MultiValueMappingBase(gs, ArrayStore::getGenerationHolderLocation(_store), memory_allocator),
- _store(storeCfg, std::move(memory_allocator), ArrayStoreTypeMapper(storeCfg.max_type_id(), array_store_grow_factor))
+ _store(storeCfg, std::move(memory_allocator), ArrayStoreTypeMapper(storeCfg.max_type_id(), array_store_grow_factor, max_buffer_size))
{
}
@@ -66,14 +67,15 @@ MultiValueMapping<ElemT, RefT>::getAddressSpaceUsage() const {
template <typename ElemT, typename RefT>
vespalib::datastore::ArrayStoreConfig
MultiValueMapping<ElemT, RefT>::optimizedConfigForHugePage(size_t max_type_id,
- size_t hugePageSize,
- size_t smallPageSize,
- size_t min_num_entries_for_new_buffer,
- float allocGrowFactor,
- bool enable_free_lists)
+ size_t hugePageSize,
+ size_t smallPageSize,
+ size_t max_buffer_size,
+ size_t min_num_entries_for_new_buffer,
+ float allocGrowFactor,
+ bool enable_free_lists)
{
- ArrayStoreTypeMapper mapper(max_type_id, array_store_grow_factor);
- auto result = ArrayStore::optimizedConfigForHugePage(max_type_id, mapper, hugePageSize, smallPageSize, min_num_entries_for_new_buffer, allocGrowFactor);
+ ArrayStoreTypeMapper mapper(max_type_id, array_store_grow_factor, max_buffer_size);
+ auto result = ArrayStore::optimizedConfigForHugePage(max_type_id, mapper, hugePageSize, smallPageSize, max_buffer_size, min_num_entries_for_new_buffer, allocGrowFactor);
result.enable_free_lists(enable_free_lists);
return result;
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp
index d8ada97fa2c..56c6d010582 100644
--- a/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp
@@ -12,6 +12,8 @@
#include <vespa/vespalib/util/memory_allocator.h>
#include <vespa/vespalib/util/stash.h>
+using vespalib::datastore::ArrayStoreConfig;
+
namespace search {
namespace multivalueattribute {
@@ -28,9 +30,11 @@ MultiValueAttribute(const vespalib::string &baseFileName,
_mvMapping(MultiValueMapping::optimizedConfigForHugePage(MultiValueMapping::array_store_max_type_id,
vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ ArrayStoreConfig::default_max_buffer_size,
8 * 1024,
cfg.getGrowStrategy().getMultiValueAllocGrowFactor(),
multivalueattribute::enable_free_lists),
+ ArrayStoreConfig::default_max_buffer_size,
cfg.getGrowStrategy(), this->get_memory_allocator())
{
}
diff --git a/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp b/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp
index cd9e0508344..0c6dd9c75a8 100644
--- a/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp
@@ -5,6 +5,7 @@
#include <cassert>
using vespalib::alloc::MemoryAllocator;
+using vespalib::datastore::ArrayStoreConfig;
using vespalib::datastore::EntryRef;
namespace {
@@ -17,11 +18,12 @@ namespace search::attribute {
RawBufferStore::RawBufferStore(std::shared_ptr<vespalib::alloc::MemoryAllocator> allocator, uint32_t max_small_buffer_type_id, double grow_factor)
: _array_store(ArrayStoreType::optimizedConfigForHugePage(max_small_buffer_type_id,
- TypeMapper(max_small_buffer_type_id, grow_factor),
+ TypeMapper(max_small_buffer_type_id, grow_factor, ArrayStoreConfig::default_max_buffer_size),
MemoryAllocator::HUGEPAGE_SIZE,
MemoryAllocator::PAGE_SIZE,
+ ArrayStoreConfig::default_max_buffer_size,
8_Ki, ALLOC_GROW_FACTOR),
- std::move(allocator), TypeMapper(max_small_buffer_type_id, grow_factor))
+ std::move(allocator), TypeMapper(max_small_buffer_type_id, grow_factor, ArrayStoreConfig::default_max_buffer_size))
{
}
diff --git a/searchlib/src/vespa/searchlib/docstore/visitcache.cpp b/searchlib/src/vespa/searchlib/docstore/visitcache.cpp
index b61cf49c438..322d0eb341b 100644
--- a/searchlib/src/vespa/searchlib/docstore/visitcache.cpp
+++ b/searchlib/src/vespa/searchlib/docstore/visitcache.cpp
@@ -165,6 +165,7 @@ public:
CompressedBlobSet readSet(const KeySet & keys);
void removeKey(uint32_t key);
vespalib::MemoryUsage getStaticMemoryUsage() const override;
+ CacheStats get_stats() const override;
private:
void locateAndInvalidateOtherSubsets(const UniqueLock & cacheGuard, const KeySet & keys);
using IdSet = vespalib::hash_set<uint64_t>;
@@ -267,7 +268,8 @@ VisitCache::remove(uint32_t key) {
CacheStats
VisitCache::getCacheStats() const {
- return _cache->get_stats();
+ CacheStats stats = _cache->get_stats();
+ return stats;
}
VisitCache::Cache::Cache(BackingStore & b, size_t maxBytes) :
@@ -306,19 +308,22 @@ VisitCache::Cache::onRemove(const K & key) {
vespalib::MemoryUsage
VisitCache::Cache::getStaticMemoryUsage() const {
vespalib::MemoryUsage usage = Parent::getStaticMemoryUsage();
- auto cacheGuard = getGuard();
size_t baseSelf = sizeof(_lid2Id) + sizeof(_id2KeySet);
usage.incAllocatedBytes(baseSelf);
- usage.incAllocatedBytes(_lid2Id.capacity() * sizeof(LidUniqueKeySetId::value_type));
- usage.incAllocatedBytes(_id2KeySet.capacity() * sizeof(IdKeySetMap::value_type));
usage.incUsedBytes(baseSelf);
- usage.incUsedBytes(_lid2Id.size() * sizeof(LidUniqueKeySetId::value_type));
- usage.incUsedBytes(_id2KeySet.size() * sizeof(IdKeySetMap::value_type));
+ return usage;
+}
+
+CacheStats
+VisitCache::Cache::get_stats() const {
+ CacheStats stats = Parent::get_stats();
+ auto cacheGuard = getGuard();
+ stats.memory_used += _lid2Id.capacity() * sizeof(LidUniqueKeySetId::value_type);
+ stats.memory_used += _id2KeySet.capacity() * sizeof(IdKeySetMap::value_type);
for (const auto & entry: _id2KeySet) {
- usage.incAllocatedBytes(entry.second.getKeys().capacity() * sizeof(uint32_t));
- usage.incUsedBytes(entry.second.getKeys().size() * sizeof(uint32_t));
+ stats.memory_used += entry.second.getKeys().capacity() * sizeof(uint32_t);
}
- return usage;
+ return stats;
}
}
diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp
index 26a7be005b7..c993cdeb790 100644
--- a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp
+++ b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp
@@ -620,7 +620,13 @@ const feature_t RankScoreDropLimit::DEFAULT_VALUE(-std::numeric_limits<feature_t
feature_t
RankScoreDropLimit::lookup(const Properties &props)
{
- return lookupDouble(props, NAME, DEFAULT_VALUE);
+ return lookup(props, DEFAULT_VALUE);
+}
+
+feature_t
+RankScoreDropLimit::lookup(const Properties &props, feature_t defaultValue)
+{
+ return lookupDouble(props, NAME, defaultValue);
}
} // namspace hitcollector
diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.h b/searchlib/src/vespa/searchlib/fef/indexproperties.h
index 5ff4ea26bd8..53f1789b2e0 100644
--- a/searchlib/src/vespa/searchlib/fef/indexproperties.h
+++ b/searchlib/src/vespa/searchlib/fef/indexproperties.h
@@ -507,6 +507,7 @@ namespace hitcollector {
static const vespalib::string NAME;
static const feature_t DEFAULT_VALUE;
static feature_t lookup(const Properties &props);
+ static feature_t lookup(const Properties &props, feature_t defaultValue);
};
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp
index c51d0ec7fd3..638c254aac1 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp
@@ -34,6 +34,13 @@ size_t my_align(size_t size, size_t alignment) {
return (size - (size % alignment));
}
+size_t
+cap_max_entries(size_t max_entries, size_t max_buffer_size, size_t entry_size)
+{
+ size_t dynamic_max_entries = (max_buffer_size + (entry_size - 1)) / entry_size;
+ return std::min(max_entries, dynamic_max_entries);
+}
+
}
DenseTensorStore::TensorSizeCalc::TensorSizeCalc(const ValueType &type)
@@ -55,7 +62,7 @@ DenseTensorStore::TensorSizeCalc::TensorSizeCalc(const ValueType &type)
}
DenseTensorStore::BufferType::BufferType(const TensorSizeCalc &tensorSizeCalc, std::shared_ptr<vespalib::alloc::MemoryAllocator> allocator)
- : vespalib::datastore::BufferType<char>(tensorSizeCalc.alignedSize(), MIN_BUFFER_ARRAYS, RefType::offsetSize()),
+ : vespalib::datastore::BufferType<char>(tensorSizeCalc.alignedSize(), MIN_BUFFER_ARRAYS, cap_max_entries(RefType::offsetSize(), max_dense_tensor_buffer_size, tensorSizeCalc.alignedSize())),
_allocator(std::move(allocator))
{}
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h
index 0dd483e7f08..b7375c4d2c3 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h
@@ -20,9 +20,11 @@ namespace search::tensor {
class DenseTensorStore : public TensorStore
{
public:
- using RefType = vespalib::datastore::EntryRefT<22>;
+ // 4 Ki buffers of 256 MiB each is 1 TiB.
+ using RefType = vespalib::datastore::EntryRefT<20>;
using DataStoreType = vespalib::datastore::DataStoreT<RefType>;
using ValueType = vespalib::eval::ValueType;
+ static constexpr size_t max_dense_tensor_buffer_size = 256_Mi;
struct TensorSizeCalc
{
@@ -90,8 +92,9 @@ public:
return VectorBundle(getRawBuffer(ref), 1, _subspace_type);
}
const SubspaceType& get_subspace_type() const noexcept { return _subspace_type; }
- // The following method is meant to be used only for unit tests.
+ // The following methods are meant to be used only for unit tests.
uint32_t getArraySize() const { return _bufferType.getArraySize(); }
+ uint32_t get_max_buffer_entries() const noexcept { return _bufferType.get_max_entries(); }
};
}
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index 748a747d515..22a33270a27 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -31,6 +31,7 @@ namespace search::tensor {
using search::AddressSpaceComponents;
using search::StateExplorerUtils;
using search::queryeval::GlobalFilter;
+using vespalib::datastore::ArrayStoreConfig;
using vespalib::datastore::CompactionStrategy;
using vespalib::datastore::EntryRef;
using vespalib::GenericHeader;
@@ -145,25 +146,27 @@ PreparedAddDoc::PreparedAddDoc(PreparedAddDoc&& other) noexcept = default;
}
template <HnswIndexType type>
-vespalib::datastore::ArrayStoreConfig
+ArrayStoreConfig
HnswIndex<type>::make_default_level_array_store_config()
{
return LevelArrayStore::optimizedConfigForHugePage(max_level_array_size,
- vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
- vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
+ vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ ArrayStoreConfig::default_max_buffer_size,
min_num_arrays_for_new_buffer,
alloc_grow_factor).enable_free_lists(true);
}
template <HnswIndexType type>
-vespalib::datastore::ArrayStoreConfig
+ArrayStoreConfig
HnswIndex<type>::make_default_link_array_store_config()
{
return LinkArrayStore::optimizedConfigForHugePage(max_link_array_size,
- vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
- vespalib::alloc::MemoryAllocator::PAGE_SIZE,
- min_num_arrays_for_new_buffer,
- alloc_grow_factor).enable_free_lists(true);
+ vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
+ vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ ArrayStoreConfig::default_max_buffer_size,
+ min_num_arrays_for_new_buffer,
+ alloc_grow_factor).enable_free_lists(true);
}
template <HnswIndexType type>
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp
index a78d9cefc64..cf30d62a0b8 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp
@@ -49,6 +49,7 @@ HnswNodeidMapping::HnswNodeidMapping()
_nodeids(NodeidStore::optimizedConfigForHugePage(max_type_id,
vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ vespalib::datastore::ArrayStoreConfig::default_max_buffer_size,
min_num_arrays_for_new_buffer,
alloc_grow_factor).enable_free_lists(true), {}),
_hold_list(),
diff --git a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp
index 51ebc22c269..1e79ca53e68 100644
--- a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp
@@ -16,7 +16,8 @@ namespace search::tensor {
SerializedFastValueAttribute::SerializedFastValueAttribute(stringref name, const Config &cfg, const NearestNeighborIndexFactory& index_factory)
: TensorAttribute(name, cfg, _tensorBufferStore, index_factory),
- _tensorBufferStore(cfg.tensorType(), get_memory_allocator(), 400u)
+ _tensorBufferStore(cfg.tensorType(), get_memory_allocator(),
+ TensorBufferStore::array_store_max_type_id)
{
}
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp
index ff39c33fc5d..8a7d84010cb 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp
@@ -26,8 +26,6 @@ namespace {
constexpr float ALLOC_GROW_FACTOR = 0.2;
-constexpr double mapper_grow_factor = 1.02;
-
}
TensorBufferStore::TensorBufferStore(const ValueType& tensor_type, std::shared_ptr<MemoryAllocator> allocator, uint32_t max_small_subspaces_type_id)
@@ -35,11 +33,12 @@ TensorBufferStore::TensorBufferStore(const ValueType& tensor_type, std::shared_p
_tensor_type(tensor_type),
_ops(_tensor_type),
_array_store(ArrayStoreType::optimizedConfigForHugePage(max_small_subspaces_type_id,
- TensorBufferTypeMapper(max_small_subspaces_type_id, mapper_grow_factor, &_ops),
+ TensorBufferTypeMapper(max_small_subspaces_type_id, array_store_grow_factor, &_ops),
MemoryAllocator::HUGEPAGE_SIZE,
MemoryAllocator::PAGE_SIZE,
+ vespalib::datastore::ArrayStoreConfig::default_max_buffer_size,
8_Ki, ALLOC_GROW_FACTOR),
- std::move(allocator), TensorBufferTypeMapper(max_small_subspaces_type_id, mapper_grow_factor, &_ops))
+ std::move(allocator), TensorBufferTypeMapper(max_small_subspaces_type_id, array_store_grow_factor, &_ops))
{
}
diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h
index 2e86ff5fb67..3342e9f3d27 100644
--- a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h
+++ b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h
@@ -24,6 +24,10 @@ class TensorBufferStore : public TensorStore
TensorBufferOperations _ops;
ArrayStoreType _array_store;
public:
+
+ static constexpr double array_store_grow_factor = 1.03;
+ static constexpr uint32_t array_store_max_type_id = 300;
+
TensorBufferStore(const vespalib::eval::ValueType& tensor_type, std::shared_ptr<vespalib::alloc::MemoryAllocator> allocator, uint32_t max_small_subspaces_type_id);
~TensorBufferStore();
void holdTensor(EntryRef ref) override;
diff --git a/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp b/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp
index baec5494b36..efa7e18aa33 100644
--- a/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp
+++ b/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp
@@ -40,6 +40,7 @@ vespalib::datastore::ArrayStoreConfig make_default_array_store_config() {
return ReplicaStore::optimizedConfigForHugePage(1023,
vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ vespalib::datastore::ArrayStoreConfig::default_max_buffer_size,
8_Ki, 0.2).enable_free_lists(true);
}
diff --git a/vespa-application-maven-plugin/src/main/java/com/yahoo/container/plugin/mojo/ApplicationMojo.java b/vespa-application-maven-plugin/src/main/java/com/yahoo/container/plugin/mojo/ApplicationMojo.java
index c6d056675a8..aff505f934f 100644
--- a/vespa-application-maven-plugin/src/main/java/com/yahoo/container/plugin/mojo/ApplicationMojo.java
+++ b/vespa-application-maven-plugin/src/main/java/com/yahoo/container/plugin/mojo/ApplicationMojo.java
@@ -70,6 +70,7 @@ public class ApplicationMojo extends AbstractMojo {
vespaversion = project.getPlugin("com.yahoo.vespa:vespa-application-maven-plugin").getVersion();
Version compileVersion = Version.from(vespaversion);
+ if (compileVersion.isSnapshot()) return;
MavenProject current = project;
while (current.getParent() != null && current.getParent().getParentArtifact() != null)
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 676e212f5c6..76d007dd633 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1705,6 +1705,24 @@
],
"fields" : [ ]
},
+ "com.yahoo.tensor.functions.CosineSimilarity" : {
+ "superClass" : "com.yahoo.tensor.functions.TensorFunction",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.functions.TensorFunction, java.lang.String)",
+ "public java.util.List arguments()",
+ "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
+ ],
+ "fields" : [ ]
+ },
"com.yahoo.tensor.functions.Diag" : {
"superClass" : "com.yahoo.tensor.functions.CompositeTensorFunction",
"interfaces" : [ ],
@@ -1740,6 +1758,24 @@
],
"fields" : [ ]
},
+ "com.yahoo.tensor.functions.EuclideanDistance" : {
+ "superClass" : "com.yahoo.tensor.functions.TensorFunction",
+ "interfaces" : [ ],
+ "attributes" : [
+ "public"
+ ],
+ "methods" : [
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.functions.TensorFunction, java.lang.String)",
+ "public java.util.List arguments()",
+ "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
+ ],
+ "fields" : [ ]
+ },
"com.yahoo.tensor.functions.Expand" : {
"superClass" : "com.yahoo.tensor.functions.CompositeTensorFunction",
"interfaces" : [ ],
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
new file mode 100644
index 00000000000..ebb8a11fd8a
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
@@ -0,0 +1,93 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorType.Dimension;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Convenience for cosine similarity between vectors.
+ * cosine_similarity(a, b, mydim) == sum(a*b, mydim) / sqrt(sum(a*a, mydim) * sum(b*b, mydim))
+ * @author arnej
+ */
+public class CosineSimilarity<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> {
+
+ private final TensorFunction<NAMETYPE> arg1;
+ private final TensorFunction<NAMETYPE> arg2;
+ private final String dimension;
+
+ public CosineSimilarity(TensorFunction<NAMETYPE> argument1,
+ TensorFunction<NAMETYPE> argument2,
+ String dimension)
+ {
+ this.arg1 = argument1;
+ this.arg2 = argument2;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return List.of(arg1, arg2); }
+
+ @Override
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if ( arguments.size() != 2)
+ throw new IllegalArgumentException("CosineSimilarity must have 2 arguments, got " + arguments.size());
+ return new CosineSimilarity<>(arguments.get(0), arguments.get(1), dimension);
+ }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ TensorType t1 = arg1.toPrimitive().type(context);
+ TensorType t2 = arg2.toPrimitive().type(context);
+ var d1 = t1.dimension(dimension);
+ var d2 = t2.dimension(dimension);
+ if (d1.isEmpty() || d2.isEmpty()
+ || d1.get().type() != Dimension.Type.indexedBound
+ || d2.get().type() != Dimension.Type.indexedBound
+ || d1.get().size().get() != d2.get().size().get())
+ {
+ throw new IllegalArgumentException("cosine_similarity expects both arguments to have the '"
+ + dimension + "' dimension with same size, but input types were "
+ + t1 + " and " + t2);
+ }
+ // Finds the type this produces by first converting it to a primitive function
+ return toPrimitive().type(context);
+ }
+
+ /** Evaluates this by first converting it to a primitive function */
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ return toPrimitive().evaluate(context);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> a = arg1.toPrimitive();
+ TensorFunction<NAMETYPE> b = arg2.toPrimitive();
+ var aa = new Join<>(a, a, ScalarFunctions.multiply());
+ var ab = new Join<>(a, b, ScalarFunctions.multiply());
+ var bb = new Join<>(b, b, ScalarFunctions.multiply());
+ var dot_aa = new Reduce<>(aa, Reduce.Aggregator.sum, dimension);
+ var dot_ab = new Reduce<>(ab, Reduce.Aggregator.sum, dimension);
+ var dot_bb = new Reduce<>(bb, Reduce.Aggregator.sum, dimension);
+ var aabb = new Join<>(dot_aa, dot_bb, ScalarFunctions.multiply());
+ var sqrt_aabb = new Map<>(aabb, ScalarFunctions.sqrt());
+ return new Join<>(dot_ab, sqrt_aabb, ScalarFunctions.divide());
+ }
+
+ @Override
+ public String toString(ToStringContext<NAMETYPE> context) {
+ return "cosine_similarity(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + dimension + ")";
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash("cosine_similarity", arg1, arg2, dimension); }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
new file mode 100644
index 00000000000..f9fc8e195d3
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
@@ -0,0 +1,89 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorType.Dimension;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Convenience for euclidean distance between vectors.
+ * euclidean_distance(a, b, mydim) == sqrt(sum(pow(a-b, 2), mydim))
+ * @author arnej
+ */
+public class EuclideanDistance<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> {
+
+    private final TensorFunction<NAMETYPE> arg1;
+    private final TensorFunction<NAMETYPE> arg2;
+    private final String dimension;
+
+    public EuclideanDistance(TensorFunction<NAMETYPE> argument1,
+                             TensorFunction<NAMETYPE> argument2,
+                             String dimension)
+    {
+        this.arg1 = argument1;
+        this.arg2 = argument2;
+        this.dimension = dimension;
+    }
+
+    @Override
+    public List<TensorFunction<NAMETYPE>> arguments() { return List.of(arg1, arg2); }
+
+    @Override
+    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+        if (arguments.size() != 2)
+            throw new IllegalArgumentException("EuclideanDistance must have 2 arguments, got " + arguments.size());
+        return new EuclideanDistance<>(arguments.get(0), arguments.get(1), dimension);
+    }
+
+    @Override
+    public TensorType type(TypeContext<NAMETYPE> context) {
+        TensorType t1 = arg1.toPrimitive().type(context);
+        TensorType t2 = arg2.toPrimitive().type(context);
+        var d1 = t1.dimension(dimension);
+        var d2 = t2.dimension(dimension);
+        if (d1.isEmpty() || d2.isEmpty()
+            || d1.get().type() != Dimension.Type.indexedBound
+            || d2.get().type() != Dimension.Type.indexedBound
+            || ! d1.get().size().get().equals(d2.get().size().get())) // sizes are boxed Longs; '!=' would compare identity, failing for sizes >= 128
+        {
+            throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '"
+                                               + dimension + "' dimension with same size, but input types were "
+                                               + t1 + " and " + t2);
+        }
+        // Finds the type this produces by first converting it to a primitive function
+        return toPrimitive().type(context);
+    }
+
+    /** Evaluates this by first converting it to a primitive function */
+    @Override
+    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+        return toPrimitive().evaluate(context);
+    }
+
+    @Override
+    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+        TensorFunction<NAMETYPE> primitive1 = arg1.toPrimitive();
+        TensorFunction<NAMETYPE> primitive2 = arg2.toPrimitive();
+        // this should match the C++ optimized "l2_distance"
+        var diffs = new Join<>(primitive1, primitive2, ScalarFunctions.subtract());
+        var squaredDiffs = new Map<>(diffs, ScalarFunctions.square());
+        var sumOfSquares = new Reduce<>(squaredDiffs, Reduce.Aggregator.sum, dimension);
+        return new Map<>(sumOfSquares, ScalarFunctions.sqrt());
+    }
+
+    @Override
+    public String toString(ToStringContext<NAMETYPE> context) {
+        return "euclidean_distance(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + dimension + ")";
+    }
+
+    @Override
+    public int hashCode() { return Objects.hash("euclidean_distance", arg1, arg2, dimension); }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java
new file mode 100644
index 00000000000..b303e2c1739
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java
@@ -0,0 +1,66 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author arnej
+ */
+public class CosineSimilarityTestCase {
+
+    // Plain 1-d vectors; cosine similarity is symmetric in its arguments.
+    @Test
+    public void testVectorSimilarity() {
+        var vecA = Tensor.from("tensor(x[3]):[ 2.0, 3.0, 6.0]");
+        var vecB = Tensor.from("tensor(x[3]):[-2.0, 0.0, 0.0]");
+        var vecC = Tensor.from("tensor(x[3]):[ 0.0, 4.0, 3.0]");
+        var fn = new CosineSimilarity<>(new ConstantTensor<>(vecA), new ConstantTensor<>(vecB), "x");
+        assertEquals((-2.0 / 7.0), fn.evaluate().asDouble(), 0.000001);
+        fn = new CosineSimilarity<>(new ConstantTensor<>(vecB), new ConstantTensor<>(vecA), "x");
+        assertEquals((-2.0 / 7.0), fn.evaluate().asDouble(), 0.000001);
+        fn = new CosineSimilarity<>(new ConstantTensor<>(vecA), new ConstantTensor<>(vecC), "x");
+        assertEquals((30.0 / 35.0), fn.evaluate().asDouble(), 0.000001);
+        fn = new CosineSimilarity<>(new ConstantTensor<>(vecB), new ConstantTensor<>(vecC), "x");
+        assertEquals(0.0, fn.evaluate().asDouble(), 0.000001);
+    }
+
+    // The reduced (indexed) dimension disappears from the result,
+    // while the mapped dimension is kept.
+    @Test
+    public void testSimilarityInMixed() {
+        var lhs = Tensor.from("tensor(c{},yy[3]):{foo:[3.0, 4.0, 0.0],bar:[0.0, -4.0, 3.0]}");
+        var rhs = Tensor.from("tensor(c{},yy[3]):{foo:[0.0, 4.0, -3.0],bar:[4.0, 0.0, -3.0]}");
+        var fn = new CosineSimilarity<>(new ConstantTensor<>(lhs), new ConstantTensor<>(rhs), "yy");
+        var want = Tensor.from("tensor(c{}):{foo:0.64,bar:-0.36}");
+        assertEquals(want, fn.evaluate());
+    }
+
+    // Verify the exact primitive expansion:
+    //   dot(a,b) / sqrt(dot(a,a) * dot(b,b))
+    @Test
+    public void testExpansion() {
+        var tType = TensorType.fromSpec("tensor(vecdim[128])");
+        var left = new VariableTensor<>("left", tType);
+        var right = new VariableTensor<>("right", tType);
+        var fn = new CosineSimilarity<>(left, right, "vecdim");
+        String expected = "join(" +
+                          "reduce(join(left, right, f(a,b)(a * b)), sum, vecdim), " +
+                          "map(" +
+                          "join(" +
+                          "reduce(join(left, left, f(a,b)(a * b)), sum, vecdim), " +
+                          "reduce(join(right, right, f(a,b)(a * b)), sum, vecdim), " +
+                          "f(a,b)(a * b)), " +
+                          "f(a)(sqrt(a))), " +
+                          "f(a,b)(a / b)" +
+                          ")";
+        assertEquals(expected, fn.toPrimitive().toString());
+    }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java
new file mode 100644
index 00000000000..4fae432b3ca
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java
@@ -0,0 +1,54 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author arnej
+ */
+public class EuclideanDistanceTestCase {
+
+    // Plain 1-d vectors; distance is symmetric in its arguments.
+    @Test
+    public void testVectorDistances() {
+        var vecA = Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]");
+        var vecB = Tensor.from("tensor(x[3]):[4.0, 2.0, 7.0]");
+        var vecC = Tensor.from("tensor(x[3]):[1.0, 6.0, 6.0]");
+        var fn = new EuclideanDistance<>(new ConstantTensor<>(vecA), new ConstantTensor<>(vecB), "x");
+        assertEquals(5.0, fn.evaluate().asDouble(), 0.000001);
+        fn = new EuclideanDistance<>(new ConstantTensor<>(vecB), new ConstantTensor<>(vecA), "x");
+        assertEquals(5.0, fn.evaluate().asDouble(), 0.000001);
+        fn = new EuclideanDistance<>(new ConstantTensor<>(vecC), new ConstantTensor<>(vecA), "x");
+        assertEquals(5.0, fn.evaluate().asDouble(), 0.000001);
+    }
+
+    // The reduced (indexed) dimension disappears from the result,
+    // while the mapped dimension is kept.
+    @Test
+    public void testDistancesInMixed() {
+        var lhs = Tensor.from("tensor(c{},x[3]):{foo:[1.0, 2.0, 3.0],bar:[0.0, 0.0, 0.0]}");
+        var rhs = Tensor.from("tensor(c{},x[3]):{foo:[4.0, 2.0, 7.0],bar:[12.0, 0.0, 5.0]}");
+        var fn = new EuclideanDistance<>(new ConstantTensor<>(lhs), new ConstantTensor<>(rhs), "x");
+        var want = Tensor.from("tensor(c{}):{foo:5.0,bar:13.0}");
+        assertEquals(want, fn.evaluate());
+    }
+
+    // Verify the exact primitive expansion: sqrt(sum(pow(a - b, 2), dim)).
+    @Test
+    public void testExpansion() {
+        var tType = TensorType.fromSpec("tensor(vecdim[128])");
+        var left = new VariableTensor<>("left", tType);
+        var right = new VariableTensor<>("right", tType);
+        var fn = new EuclideanDistance<>(left, right, "vecdim");
+        assertEquals("map(reduce(map(join(left, right, f(a,b)(a - b)), f(a)(a * a)), sum, vecdim), f(a)(sqrt(a)))",
+                     fn.toPrimitive().toString());
+    }
+
+}
diff --git a/vespalib/src/tests/datastore/array_store/array_store_test.cpp b/vespalib/src/tests/datastore/array_store/array_store_test.cpp
index a259fcaa4dc..6e433e48d88 100644
--- a/vespalib/src/tests/datastore/array_store/array_store_test.cpp
+++ b/vespalib/src/tests/datastore/array_store/array_store_test.cpp
@@ -32,7 +32,7 @@ constexpr float ALLOC_GROW_FACTOR = 0.2;
template <typename ElemT>
class MyArrayStoreSimpleTypeMapper : public ArrayStoreSimpleTypeMapper<ElemT> {
public:
- MyArrayStoreSimpleTypeMapper(uint32_t, double)
+ MyArrayStoreSimpleTypeMapper(uint32_t, double, size_t)
: ArrayStoreSimpleTypeMapper<ElemT>()
{
}
@@ -62,7 +62,7 @@ struct ArrayStoreTest : public TestT
bool add_using_allocate;
double type_mapper_grow_factor;
ArrayStoreTest(uint32_t max_type_id = 3, bool enable_free_lists = true, bool add_using_allocate_in = false, double type_mapper_grow_factor_in = 2.0)
- : type_mapper(max_type_id, type_mapper_grow_factor_in),
+ : type_mapper(max_type_id, type_mapper_grow_factor_in, ArrayStoreConfig::default_max_buffer_size),
store(ArrayStoreConfig(max_type_id,
ArrayStoreConfig::AllocSpec(16, RefT::offsetSize(), 8_Ki,
ALLOC_GROW_FACTOR)).enable_free_lists(enable_free_lists),
@@ -74,7 +74,7 @@ struct ArrayStoreTest : public TestT
type_mapper_grow_factor(type_mapper_grow_factor_in)
{}
explicit ArrayStoreTest(const ArrayStoreConfig &storeCfg)
- : type_mapper(storeCfg.max_type_id(), 2.0),
+ : type_mapper(storeCfg.max_type_id(), 2.0, ArrayStoreConfig::default_max_buffer_size),
store(storeCfg, std::make_unique<MemoryAllocatorObserver>(stats), TypeMapperType(type_mapper)),
refStore(),
generation(1),
@@ -280,7 +280,7 @@ TYPED_TEST(NumberStoreTest, control_static_sizes) {
EXPECT_EQ(202140u, usage.allocatedBytes());
EXPECT_EQ(197680u, usage.usedBytes());
} else {
- EXPECT_EQ(202388u, usage.allocatedBytes());
+ EXPECT_EQ(202328u, usage.allocatedBytes());
EXPECT_EQ(197568u, usage.usedBytes());
}
}
@@ -564,10 +564,10 @@ TYPED_TEST(NumberStoreTest, address_space_usage_is_ratio_between_used_arrays_and
* allocated elements = 256 / sizeof(int) = 64.
* limit = 64 / 3 = 21.
*
- * For dynamic buffer 3, we have 16 * 5 * sizeof(int) => 320 -> 512
- * limit = 512 / (5 * 4) = 25
+ * For dynamic buffer 3, we have 16 * 5 * sizeof(int) => 320 -> 512 - 64
+ * limit = (512 - 64) / (5 * 4) = 22
*/
- size_t type_id_3_entries = this->simple_buffers() ? 21 : 25;
+ size_t type_id_3_entries = this->simple_buffers() ? 21 : 22;
size_t expLimit = fourgig - 4 * TestFixture::EntryRefType::offsetSize() + 3 * 16 + type_id_3_entries;
EXPECT_EQ(static_cast<double>(2)/ expLimit, this->store.addressSpaceUsage().usage());
EXPECT_EQ(expLimit, this->store.addressSpaceUsage().limit());
@@ -578,6 +578,7 @@ struct ByteStoreTest : public ArrayStoreTest<testing::Test, uint8_t, EntryRefT<1
optimizedConfigForHugePage(1023,
vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE,
vespalib::alloc::MemoryAllocator::PAGE_SIZE,
+ ArrayStoreConfig::default_max_buffer_size,
8_Ki, ALLOC_GROW_FACTOR)) {}
};
diff --git a/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp b/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp
index 71c1341ae74..3bcc130052d 100644
--- a/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp
+++ b/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp
@@ -22,15 +22,20 @@ struct Fixture
Fixture(uint32_t max_type_id,
size_t hugePageSize,
size_t smallPageSize,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer)
: cfg(ArrayStoreConfig::optimizeForHugePage(max_type_id,
[](size_t type_id) noexcept { return type_id * sizeof(int); },
hugePageSize, smallPageSize,
EntryRefType::offsetSize(),
+ max_buffer_size,
min_num_entries_for_new_buffer,
ALLOC_GROW_FACTOR)) { }
void assertSpec(uint32_t type_id, uint32_t num_entries_for_new_buffer) {
- assertSpec(type_id, AllocSpec(0, EntryRefType::offsetSize(),
+ assertSpec(type_id, EntryRefType::offsetSize(), num_entries_for_new_buffer);
+ }
+ void assertSpec(uint32_t type_id, uint32_t max_entries, uint32_t num_entries_for_new_buffer) {
+ assertSpec(type_id, AllocSpec(0, max_entries,
num_entries_for_new_buffer, ALLOC_GROW_FACTOR));
}
void assertSpec(uint32_t type_id, const AllocSpec &expSpec) {
@@ -50,9 +55,6 @@ makeSpec(size_t min_entries_in_buffer,
return AllocSpec(min_entries_in_buffer, max_entries_in_buffer, num_entries_for_new_buffer, ALLOC_GROW_FACTOR);
}
-constexpr size_t KB = 1024;
-constexpr size_t MB = KB * KB;
-
TEST_F("require that default allocation spec is given for all array sizes", Fixture(3, makeSpec(4, 32, 8)))
{
EXPECT_EQUAL(3u, f.cfg.max_type_id());
@@ -62,26 +64,54 @@ TEST_F("require that default allocation spec is given for all array sizes", Fixt
TEST_DO(f.assertSpec(3, makeSpec(4, 32, 8)));
}
-TEST_F("require that we can generate config optimized for a given huge page", Fixture(1024,
- 2 * MB,
- 4 * KB,
- 8 * KB))
+struct BigBuffersFixture : public Fixture {
+ BigBuffersFixture() : Fixture(1023, 2_Mi, 4_Ki, 1024_Gi, 8_Ki) { }
+};
+
+TEST_F("require that we can generate config optimized for a given huge page without capped buffer sizes", BigBuffersFixture())
+{
+ EXPECT_EQUAL(1023u, f.cfg.max_type_id());
+ TEST_DO(f.assertSpec(0, 8_Ki)); // large arrays
+ TEST_DO(f.assertSpec(1, 256_Ki));
+ TEST_DO(f.assertSpec(2, 256_Ki));
+ TEST_DO(f.assertSpec(3, 168_Ki));
+ TEST_DO(f.assertSpec(4, 128_Ki));
+ TEST_DO(f.assertSpec(5, 100_Ki));
+ TEST_DO(f.assertSpec(6, 84_Ki));
+
+ TEST_DO(f.assertSpec(32, 16_Ki));
+ TEST_DO(f.assertSpec(33, 12_Ki));
+ TEST_DO(f.assertSpec(42, 12_Ki));
+ TEST_DO(f.assertSpec(43, 8_Ki));
+ TEST_DO(f.assertSpec(1022, 8_Ki));
+ TEST_DO(f.assertSpec(1023, 8_Ki));
+}
+
+struct CappedBuffersFixture : public Fixture {
+ CappedBuffersFixture() : Fixture(1023, 2_Mi, 4_Ki, 256_Mi, 8_Ki) { }
+ size_t max_entries(size_t array_size) {
+ auto entry_size = array_size * sizeof(int);
+ return (256_Mi + entry_size - 1) / entry_size;
+ }
+};
+
+TEST_F("require that we can generate config optimized for a given huge page with capped buffer sizes", CappedBuffersFixture())
{
- EXPECT_EQUAL(1_Ki, f.cfg.max_type_id());
- TEST_DO(f.assertSpec(0, 8 * KB)); // large arrays
- TEST_DO(f.assertSpec(1, 256 * KB));
- TEST_DO(f.assertSpec(2, 256 * KB));
- TEST_DO(f.assertSpec(3, 168 * KB));
- TEST_DO(f.assertSpec(4, 128 * KB));
- TEST_DO(f.assertSpec(5, 100 * KB));
- TEST_DO(f.assertSpec(6, 84 * KB));
+ EXPECT_EQUAL(1023u, f.cfg.max_type_id());
+ TEST_DO(f.assertSpec(0, f.max_entries(1023), 8_Ki)); // large arrays
+ TEST_DO(f.assertSpec(1, 256_Ki));
+ TEST_DO(f.assertSpec(2, 256_Ki));
+ TEST_DO(f.assertSpec(3, 168_Ki));
+ TEST_DO(f.assertSpec(4, 128_Ki));
+ TEST_DO(f.assertSpec(5, 100_Ki));
+ TEST_DO(f.assertSpec(6, 84_Ki));
- TEST_DO(f.assertSpec(32, 16 * KB));
- TEST_DO(f.assertSpec(33, 12 * KB));
- TEST_DO(f.assertSpec(42, 12 * KB));
- TEST_DO(f.assertSpec(43, 8 * KB));
- TEST_DO(f.assertSpec(1022, 8 * KB));
- TEST_DO(f.assertSpec(1023, 8 * KB));
+ TEST_DO(f.assertSpec(32, 16_Ki));
+ TEST_DO(f.assertSpec(33, 12_Ki));
+ TEST_DO(f.assertSpec(42, 12_Ki));
+ TEST_DO(f.assertSpec(43, 8_Ki));
+ TEST_DO(f.assertSpec(1022, f.max_entries(1022), 8_Ki));
+ TEST_DO(f.assertSpec(1023, f.max_entries(1023), 8_Ki));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp b/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp
index 7ead0b97269..f53f5a8ff22 100644
--- a/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp
+++ b/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp
@@ -2,10 +2,18 @@
#include <vespa/vespalib/datastore/array_store_dynamic_type_mapper.h>
#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/util/size_literals.h>
+#include <limits>
using vespalib::datastore::ArrayStoreDynamicTypeMapper;
+namespace {
+
constexpr double default_grow_factor = 1.03;
+constexpr size_t default_max_buffer_size = 256_Mi;
+constexpr size_t small_max_buffer_size = 256_Ki;
+constexpr size_t max_max_buffer_size = std::numeric_limits<uint32_t>::max();
+}
template <typename ElemT>
class TestBase : public testing::Test
@@ -19,13 +27,13 @@ protected:
std::vector<size_t> get_large_array_sizes(uint32_t num_large_arrays);
void select_type_ids(std::vector<size_t> array_sizes);
void setup_mapper(uint32_t max_buffer_type_id, double grow_factor);
- static uint32_t calc_max_buffer_type_id(double grow_factor);
+ static uint32_t calc_max_buffer_type_id(double grow_factor, size_t max_buffer_size = default_max_buffer_size);
};
template <typename ElemT>
TestBase<ElemT>::TestBase()
: testing::Test(),
- _mapper(5, default_grow_factor)
+ _mapper(5, default_grow_factor, default_max_buffer_size)
{
}
@@ -36,7 +44,7 @@ template <typename ElemT>
void
TestBase<ElemT>::setup_mapper(uint32_t max_buffer_type_id, double grow_factor)
{
- _mapper = ArrayStoreDynamicTypeMapper<ElemT>(max_buffer_type_id, grow_factor);
+ _mapper = ArrayStoreDynamicTypeMapper<ElemT>(max_buffer_type_id, grow_factor, default_max_buffer_size);
}
template <typename ElemT>
@@ -108,9 +116,9 @@ TestBase<ElemT>::select_type_ids(std::vector<size_t> array_sizes)
template <typename ElemT>
uint32_t
-TestBase<ElemT>::calc_max_buffer_type_id(double grow_factor)
+TestBase<ElemT>::calc_max_buffer_type_id(double grow_factor, size_t max_buffer_size)
{
- ArrayStoreDynamicTypeMapper<ElemT> mapper(1000, grow_factor);
+ ArrayStoreDynamicTypeMapper<ElemT> mapper(1000, grow_factor, max_buffer_size);
return mapper.get_max_type_id(1000);
}
@@ -139,11 +147,13 @@ TEST_F(ArrayStoreDynamicTypeMapperCharTest, large_arrays_grows_exponentially)
TEST_F(ArrayStoreDynamicTypeMapperCharTest, avoid_entry_size_overflow)
{
- EXPECT_EQ(32, calc_max_buffer_type_id(2.0));
- EXPECT_EQ(410, calc_max_buffer_type_id(1.05));
- EXPECT_EQ(507, calc_max_buffer_type_id(1.04));
- EXPECT_EQ(661, calc_max_buffer_type_id(1.03));
- EXPECT_EQ(968, calc_max_buffer_type_id(1.02));
+ EXPECT_EQ(29, calc_max_buffer_type_id(2.0));
+ EXPECT_EQ(367, calc_max_buffer_type_id(1.05));
+ EXPECT_EQ(454, calc_max_buffer_type_id(1.04));
+ EXPECT_EQ(591, calc_max_buffer_type_id(1.03));
+ EXPECT_EQ(357, calc_max_buffer_type_id(1.03, small_max_buffer_size));
+ EXPECT_EQ(661, calc_max_buffer_type_id(1.03, max_max_buffer_size));
+ EXPECT_EQ(863, calc_max_buffer_type_id(1.02));
}
using ArrayStoreDynamicTypeMapperInt32Test = TestBase<int32_t>;
@@ -159,11 +169,13 @@ TEST_F(ArrayStoreDynamicTypeMapperInt32Test, array_sizes_are_calculated)
TEST_F(ArrayStoreDynamicTypeMapperInt32Test, avoid_entry_size_overflow)
{
- EXPECT_EQ(30, calc_max_buffer_type_id(2.0));
- EXPECT_EQ(395, calc_max_buffer_type_id(1.05));
- EXPECT_EQ(487, calc_max_buffer_type_id(1.04));
- EXPECT_EQ(636, calc_max_buffer_type_id(1.03));
- EXPECT_EQ(930, calc_max_buffer_type_id(1.02));
+ EXPECT_EQ(27, calc_max_buffer_type_id(2.0));
+ EXPECT_EQ(337, calc_max_buffer_type_id(1.05));
+ EXPECT_EQ(409, calc_max_buffer_type_id(1.04));
+ EXPECT_EQ(525, calc_max_buffer_type_id(1.03));
+ EXPECT_EQ(291, calc_max_buffer_type_id(1.03, small_max_buffer_size));
+ EXPECT_EQ(596, calc_max_buffer_type_id(1.03, max_max_buffer_size));
+ EXPECT_EQ(744, calc_max_buffer_type_id(1.02));
}
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/vespalib/src/tests/datastore/dynamic_array_buffer_type/dynamic_array_buffer_type_test.cpp b/vespalib/src/tests/datastore/dynamic_array_buffer_type/dynamic_array_buffer_type_test.cpp
index a703d9b18eb..9279aff46b9 100644
--- a/vespalib/src/tests/datastore/dynamic_array_buffer_type/dynamic_array_buffer_type_test.cpp
+++ b/vespalib/src/tests/datastore/dynamic_array_buffer_type/dynamic_array_buffer_type_test.cpp
@@ -5,6 +5,7 @@
#include <ostream>
using vespalib::datastore::ArrayStoreConfig;
+using vespalib::datastore::AtomicEntryRef;
using vespalib::datastore::BufferTypeBase;
using vespalib::datastore::DynamicArrayBufferType;
using vespalib::datastore::EntryCount;
@@ -120,20 +121,23 @@ protected:
BufferType _buffer_type;
size_t _entry_size;
+ size_t _buffer_underflow_size;
size_t _buf_size;
- std::unique_ptr<char[]> _buf;
+ std::unique_ptr<char[]> _buf_alloc;
+ char* _buf;
};
DynamicArrayBufferTypeTest::DynamicArrayBufferTypeTest()
: testing::Test(),
_buffer_type(3, ArrayStoreConfig::AllocSpec(0, 10, 0, 0.2), {}),
_entry_size(_buffer_type.entry_size()),
+ _buffer_underflow_size(_buffer_type.buffer_underflow_size()),
_buf_size(2 * _entry_size),
- _buf(std::make_unique<char[]>(_buf_size))
+ _buf_alloc(std::make_unique<char[]>(_buf_size + _buffer_underflow_size)),
+ _buf(_buf_alloc.get() + _buffer_underflow_size)
{
// Call initialize_reserved_entries to force construction of empty element
- _buffer_type.initialize_reserved_entries(_buf.get(), 1);
- memset(_buf.get(), 55, _buf_size);
+ _buffer_type.initialize_reserved_entries(_buf, 1);
// Reset counts after empty element has been constructed
counts = Counts();
}
@@ -178,7 +182,7 @@ DynamicArrayBufferTypeTest::get_max_vector(const void* buffer, uint32_t offset)
void
DynamicArrayBufferTypeTest::write_entry1()
{
- auto e1 = BufferType::get_entry(_buf.get(), 1, _entry_size);
+ auto e1 = BufferType::get_entry(_buf, 1, _entry_size);
BufferType::set_dynamic_array_size(e1, 2);
new (static_cast<void *>(e1)) WrapInt32(42);
new (static_cast<void *>(e1 + 1)) WrapInt32(47);
@@ -200,50 +204,58 @@ TEST_F(DynamicArrayBufferTypeTest, entry_size_is_calculated)
EXPECT_EQ(16, get_entry_size<int64_t>(1));
EXPECT_EQ(24, get_entry_size<int64_t>(2));
EXPECT_EQ(20, get_entry_size<WrapInt32>(4));
+
+ EXPECT_EQ(1028, get_entry_size<WrapInt32>(256));
+ EXPECT_EQ(1028, get_entry_size<AtomicEntryRef>(256));
+ EXPECT_EQ(1088, get_entry_size<int32_t>(256));
+ EXPECT_EQ(1088, get_entry_size<int64_t>(128));
+ EXPECT_EQ(1088, get_entry_size<float>(256));
+ EXPECT_EQ(1088, get_entry_size<double>(128));
}
TEST_F(DynamicArrayBufferTypeTest, initialize_reserved_entries)
{
- _buffer_type.initialize_reserved_entries(_buf.get(), 2);
- EXPECT_EQ((std::vector<int>{}), get_vector(_buf.get(), 0));
- EXPECT_EQ((std::vector<int>{}), get_vector(_buf.get(), 1));
- EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(_buf.get(), 0));
- EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(_buf.get(), 1));
+ _buffer_type.initialize_reserved_entries(_buf, 2);
+ EXPECT_EQ((std::vector<int>{}), get_vector(_buf, 0));
+ EXPECT_EQ((std::vector<int>{}), get_vector(_buf, 1));
+ EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(_buf, 0));
+ EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(_buf, 1));
EXPECT_EQ(Counts(0, 0, 6, 0, 0), counts);
}
TEST_F(DynamicArrayBufferTypeTest, fallback_copy)
{
- _buffer_type.initialize_reserved_entries(_buf.get(), 1);
+ _buffer_type.initialize_reserved_entries(_buf, 1);
write_entry1();
EXPECT_EQ(Counts(0, 3, 3, 0, 0), counts);
- auto buf2 = std::make_unique<char[]>(_buf_size);
- _buffer_type.fallback_copy(buf2.get(), _buf.get(), 2);
- EXPECT_EQ((std::vector<int>{}), get_vector(buf2.get(), 0));
- EXPECT_EQ((std::vector<int>{42, 47}), get_vector(buf2.get(), 1));
- EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(buf2.get(), 0));
- EXPECT_EQ((std::vector<int>{42, 47, 49}), get_max_vector(buf2.get(), 1));
+ auto buf2_alloc = std::make_unique<char[]>(_buf_size + _buffer_underflow_size);
+ char* buf2 = buf2_alloc.get() + _buffer_underflow_size;
+ _buffer_type.fallback_copy(buf2, _buf, 2);
+ EXPECT_EQ((std::vector<int>{}), get_vector(buf2, 0));
+ EXPECT_EQ((std::vector<int>{42, 47}), get_vector(buf2, 1));
+ EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(buf2, 0));
+ EXPECT_EQ((std::vector<int>{42, 47, 49}), get_max_vector(buf2, 1));
EXPECT_EQ(Counts(0, 3, 9, 0, 0), counts);
}
TEST_F(DynamicArrayBufferTypeTest, destroy_entries)
{
- _buffer_type.initialize_reserved_entries(_buf.get(), 2);
+ _buffer_type.initialize_reserved_entries(_buf, 2);
write_entry1();
- _buffer_type.destroy_entries(_buf.get(), 2);
+ _buffer_type.destroy_entries(_buf, 2);
EXPECT_EQ(Counts(0, 3, 6, 6, 0), counts);
}
TEST_F(DynamicArrayBufferTypeTest, clean_hold)
{
- _buffer_type.initialize_reserved_entries(_buf.get(), 1);
+ _buffer_type.initialize_reserved_entries(_buf, 1);
write_entry1();
MyCleanContext clean_context;
- _buffer_type.clean_hold(_buf.get(), 1, 1, clean_context);
- EXPECT_EQ((std::vector<int>{0, 0}), get_vector(_buf.get(), 1));
- EXPECT_EQ((std::vector<int>{0, 0, 49}), get_max_vector(_buf.get(), 1));
+ _buffer_type.clean_hold(_buf, 1, 1, clean_context);
+ EXPECT_EQ((std::vector<int>{0, 0}), get_vector(_buf, 1));
+ EXPECT_EQ((std::vector<int>{0, 0, 49}), get_max_vector(_buf, 1));
EXPECT_EQ(Counts(0, 3, 3, 0, 2), counts);
- _buffer_type.clean_hold(_buf.get(), 0, 2, clean_context);
+ _buffer_type.clean_hold(_buf, 0, 2, clean_context);
EXPECT_EQ(Counts(0, 3, 3, 0, 4), counts);
}
diff --git a/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp b/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp
index c6f0581bceb..afd18a13d2e 100644
--- a/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp
+++ b/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp
@@ -11,6 +11,7 @@
#include <ranges>
#include <random>
#include <array>
+#include <algorithm>
using namespace vespalib;
using namespace vespalib::test;
@@ -23,6 +24,71 @@ size_t state_loop = 1;
//-----------------------------------------------------------------------------
+/**
+ * Estimates the 80th percentile by throwing away the 2 best samples
+ * in each set of 10 samples, using the best remaining sample as a
+ * representative for the set. Representatives are hierarchically
+ * matched against representatives from other sample sets. Result
+ * extraction is simplified in that it does not try to estimate the
+ * actual 80th percentile, but rather tries to drop the best samples
+ * if possible.
+ *
+ * The goal is to have a more robust way of combining repeated
+ * micro-benchmark samples than simply using minimum time. With simple
+ * single-threaded CPU-bound tasks, minimum time is a good measure of
+ * how expensive something is, but when we start benchmarking
+ * operations that may conflict with themselves, we do not want to
+ * account for being super lucky. However, we still want to account
+ * for the benchmark conditions being as good as possible.
+ **/
+struct Est80P {
+ struct Level {
+ int cnt;
+ std::array<double,3> data;
+ Level(double value) noexcept
+ : cnt(1), data{value, 0.0, 0.0} {}
+ bool empty() const noexcept { return (cnt == 0); }
+ bool full() const noexcept { return (cnt == 10); }
+ void add(double value) noexcept {
+ assert(!full());
+ if (cnt < 3 || data[2] > value) {
+ size_t i = std::min(cnt, 2);
+ while (i > 0 && data[i - 1] > value) {
+ data[i] = data[i - 1];
+ --i;
+ }
+ data[i] = value;
+ }
+ ++cnt;
+ }
+ double get() const noexcept {
+ assert(!empty());
+ return data[std::min(2, cnt - 1)];
+ }
+ void clear() noexcept {
+ cnt = 0;
+ }
+ };
+ std::vector<Level> levels;
+ void add_sample(double value) {
+ for (auto &level: levels) {
+ level.add(value);
+ if (!level.full()) [[likely]] {
+ return;
+ }
+ value = level.get();
+ level.clear();
+ }
+ levels.emplace_back(value);
+ }
+ double get_result() {
+ assert(!levels.empty());
+ return levels.back().get();
+ }
+};
+
+//-----------------------------------------------------------------------------
+
struct DummyLock {
constexpr DummyLock() noexcept {}
// BasicLockable
@@ -158,47 +224,31 @@ double measure_ns(auto &work) {
struct BenchmarkResult {
double cost_ns;
- double range_ns;
- size_t threads;
- BenchmarkResult(size_t num_threads)
- : cost_ns(std::numeric_limits<double>::max()), range_ns(0.0), threads(num_threads) {}
+ BenchmarkResult(double cost_ns_in) : cost_ns(cost_ns_in) {}
void report(vespalib::string desc) {
- if (threads == 1) {
- fprintf(stderr, "%s: cost_ns: %g\n",
- desc.c_str(), cost_ns);
- } else {
- fprintf(stderr, "%s: cost_ns: %g, range_ns: %g (%zu threads)\n",
- desc.c_str(), cost_ns, range_ns, threads);
- }
+ fprintf(stderr, "%s: cost_ns: %g\n", desc.c_str(), cost_ns);
}
void report(vespalib::string name, vespalib::string desc) {
report(name + "(" + desc + ")");
}
};
-struct Meets {
- vespalib::test::ThreadMeets::Avg avg;
- vespalib::test::ThreadMeets::Range<double> range;
- Meets(size_t num_threads) : avg(num_threads), range(num_threads) {}
-};
-
BenchmarkResult benchmark_ns(auto &&work, size_t num_threads = 1) {
- Meets meets(num_threads);
+ Est80P collector;
+ vespalib::test::ThreadMeets::Avg avg(num_threads);
auto entry = [&](Nexus &ctx) {
Timer timer;
BenchmarkResult result(ctx.num_threads());
for (bool once_more = true; ctx.vote(once_more); once_more = (timer.elapsed() < budget)) {
- auto my_ns = measure_ns(work);
- auto cost_ns = meets.avg(my_ns);
- auto range_ns = meets.range(my_ns);
- if (cost_ns < result.cost_ns) {
- result.cost_ns = cost_ns;
- result.range_ns = range_ns;
+ auto cost_ns = avg(measure_ns(work));
+ if (ctx.is_main()) {
+ collector.add_sample(cost_ns);
}
}
- return result;
};
- return Nexus::run(num_threads, entry);
+ Nexus::run(num_threads, entry);
+ auto result = collector.get_result();
+ return {result};
}
//-----------------------------------------------------------------------------
@@ -224,7 +274,7 @@ void estimate_cost() {
//-----------------------------------------------------------------------------
template <typename T>
-void thread_safety_loop(Nexus &ctx, T &lock, MyState &state, Meets &meets, int read_bp) {
+void thread_safety_loop(Nexus &ctx, T &lock, MyState &state, auto &max, int read_bp) {
Rnd rnd(ctx.thread_id());
size_t write_cnt = 0;
size_t bad_reads = 0;
@@ -247,16 +297,11 @@ void thread_safety_loop(Nexus &ctx, T &lock, MyState &state, Meets &meets, int r
}
}
}
- auto t1 = steady_clock::now();
- ctx.barrier();
- auto t2 = steady_clock::now();
- auto my_ms = count_ns(t1 - t0) / 1'000'000.0;
- auto total_ms = count_ns(t2 - t0) / 1'000'000.0;
- auto cost_ms = meets.avg(my_ms);
- auto range_ms = meets.range(my_ms);
- if (ctx.thread_id() == 0) {
- fprintf(stderr, "---> %s with %2zu threads (%5d bp r): avg: %10.2f ms, range: %10.2f ms, max: %10.2f ms\n",
- getClassName(lock).c_str(), ctx.num_threads(), read_bp, cost_ms, range_ms, total_ms);
+ auto my_ms = count_ns(steady_clock::now() - t0) / 1'000'000.0;
+ auto actual_ms = max(my_ms);
+ if (ctx.is_main()) {
+ fprintf(stderr, "---> %s with %2zu threads (%5d bp r): time: %10.2f ms\n",
+ getClassName(lock).c_str(), ctx.num_threads(), read_bp, actual_ms);
}
state.commit_inconsistent_reads(bad_reads);
state.commit_expected_writes(write_cnt);
@@ -290,9 +335,9 @@ void benchmark_lock() {
for (size_t bp: {10000, 9999, 5000, 0}) {
for (size_t num_threads: {8, 4, 2, 1}) {
if (bench || (bp == 9999 && num_threads == 8)) {
- Meets meets(num_threads);
+ vespalib::test::ThreadMeets::Max<double> max(num_threads);
Nexus::run(num_threads, [&](Nexus &ctx) {
- thread_safety_loop(ctx, *lock, *state, meets, bp);
+ thread_safety_loop(ctx, *lock, *state, max, bp);
});
}
}
diff --git a/vespalib/src/vespa/vespalib/datastore/array_store.h b/vespalib/src/vespa/vespalib/datastore/array_store.h
index f5c30c90c5b..7ee63be3848 100644
--- a/vespalib/src/vespa/vespalib/datastore/array_store.h
+++ b/vespalib/src/vespa/vespalib/datastore/array_store.h
@@ -126,7 +126,7 @@ public:
return get_dynamic_array<typename TypeMapper::DynamicBufferType>(bufferAndMeta.get_buffer_acquire(), internalRef.offset(), bufferAndMeta.get_entry_size());
}
}
- return getSmallArray(internalRef, bufferAndMeta.getArraySize());
+ return getSmallArray(internalRef, bufferAndMeta.get_array_size());
} else {
return getLargeArray(internalRef);
}
@@ -196,6 +196,7 @@ public:
static ArrayStoreConfig optimizedConfigForHugePage(uint32_t max_type_id,
size_t hugePageSize,
size_t smallPageSize,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer,
float allocGrowFactor);
@@ -203,6 +204,7 @@ public:
const TypeMapper& mapper,
size_t hugePageSize,
size_t smallPageSize,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer,
float allocGrowFactor);
};
diff --git a/vespalib/src/vespa/vespalib/datastore/array_store.hpp b/vespalib/src/vespa/vespalib/datastore/array_store.hpp
index 211176b8ad0..bfd4ff0430a 100644
--- a/vespalib/src/vespa/vespalib/datastore/array_store.hpp
+++ b/vespalib/src/vespa/vespalib/datastore/array_store.hpp
@@ -252,6 +252,7 @@ ArrayStoreConfig
ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_type_id,
size_t hugePageSize,
size_t smallPageSize,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer,
float allocGrowFactor)
{
@@ -260,6 +261,7 @@ ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_ty
mapper,
hugePageSize,
smallPageSize,
+ max_buffer_size,
min_num_entries_for_new_buffer,
allocGrowFactor);
}
@@ -267,17 +269,19 @@ ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_ty
template <typename ElemT, typename RefT, typename TypeMapperT>
ArrayStoreConfig
ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_type_id,
- const TypeMapper& mapper,
- size_t hugePageSize,
- size_t smallPageSize,
- size_t min_num_entries_for_new_buffer,
- float allocGrowFactor)
+ const TypeMapper& mapper,
+ size_t hugePageSize,
+ size_t smallPageSize,
+ size_t max_buffer_size,
+ size_t min_num_entries_for_new_buffer,
+ float allocGrowFactor)
{
return ArrayStoreConfig::optimizeForHugePage(mapper.get_max_type_id(max_type_id),
[&](uint32_t type_id) noexcept { return mapper.get_entry_size(type_id); },
hugePageSize,
smallPageSize,
RefT::offsetSize(),
+ max_buffer_size,
min_num_entries_for_new_buffer,
allocGrowFactor);
}
diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp b/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp
index c7f0b69a85e..37f6fab96dc 100644
--- a/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp
+++ b/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "array_store_config.h"
+#include <algorithm>
#include <cassert>
namespace vespalib::datastore {
@@ -42,6 +43,13 @@ alignToSmallPageSize(size_t value, size_t minLimit, size_t smallPageSize)
return ((value - minLimit) / smallPageSize) * smallPageSize + minLimit;
}
+size_t
+cap_max_entries(size_t max_entries, size_t max_buffer_size, size_t entry_size)
+{
+ size_t dynamic_max_entries = (max_buffer_size + (entry_size - 1)) / entry_size;
+ return std::min(max_entries, dynamic_max_entries);
+}
+
}
ArrayStoreConfig
@@ -50,17 +58,21 @@ ArrayStoreConfig::optimizeForHugePage(uint32_t max_type_id,
size_t hugePageSize,
size_t smallPageSize,
size_t maxEntryRefOffset,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer,
float allocGrowFactor)
{
AllocSpecVector allocSpecs;
- allocSpecs.emplace_back(0, maxEntryRefOffset, min_num_entries_for_new_buffer, allocGrowFactor); // large array spec;
+ auto entry_size = type_id_to_entry_size(max_type_id);
+ auto capped_max_entries = cap_max_entries(maxEntryRefOffset, max_buffer_size, entry_size);
+ allocSpecs.emplace_back(0, capped_max_entries, min_num_entries_for_new_buffer, allocGrowFactor); // large array spec;
for (uint32_t type_id = 1; type_id <= max_type_id; ++type_id) {
- size_t entry_size = type_id_to_entry_size(type_id);
+ entry_size = type_id_to_entry_size(type_id);
+ capped_max_entries = cap_max_entries(maxEntryRefOffset, max_buffer_size, entry_size);
size_t num_entries_for_new_buffer = hugePageSize / entry_size;
- num_entries_for_new_buffer = capToLimits(num_entries_for_new_buffer, min_num_entries_for_new_buffer, maxEntryRefOffset);
+ num_entries_for_new_buffer = capToLimits(num_entries_for_new_buffer, min_num_entries_for_new_buffer, capped_max_entries);
num_entries_for_new_buffer = alignToSmallPageSize(num_entries_for_new_buffer, min_num_entries_for_new_buffer, smallPageSize);
- allocSpecs.emplace_back(0, maxEntryRefOffset, num_entries_for_new_buffer, allocGrowFactor);
+ allocSpecs.emplace_back(0, capped_max_entries, num_entries_for_new_buffer, allocGrowFactor);
}
return ArrayStoreConfig(allocSpecs);
}
diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_config.h b/vespalib/src/vespa/vespalib/datastore/array_store_config.h
index 3b62609d0f1..3967996c64d 100644
--- a/vespalib/src/vespa/vespalib/datastore/array_store_config.h
+++ b/vespalib/src/vespa/vespalib/datastore/array_store_config.h
@@ -2,6 +2,7 @@
#pragma once
+#include <vespa/vespalib/util/size_literals.h>
#include <cstddef>
#include <cstdint>
#include <functional>
@@ -39,6 +40,8 @@ public:
using AllocSpecVector = std::vector<AllocSpec>;
+ static constexpr size_t default_max_buffer_size = 256_Mi;
+
private:
AllocSpecVector _allocSpecs;
bool _enable_free_lists;
@@ -77,6 +80,7 @@ public:
size_t hugePageSize,
size_t smallPageSize,
size_t maxEntryRefOffset,
+ size_t max_buffer_size,
size_t min_num_entries_for_new_buffer,
float allocGrowFactor);
};
diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.h b/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.h
index 73c998e82a5..6797b2a79b4 100644
--- a/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.h
+++ b/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.h
@@ -36,9 +36,9 @@ public:
using LargeBufferType = vespalib::datastore::LargeArrayBufferType<ElemT>;
ArrayStoreDynamicTypeMapper();
- ArrayStoreDynamicTypeMapper(uint32_t max_buffer_type_id, double grow_factor);
+ ArrayStoreDynamicTypeMapper(uint32_t max_buffer_type_id, double grow_factor, size_t max_buffer_size);
~ArrayStoreDynamicTypeMapper();
- void setup_array_sizes(uint32_t max_buffer_type_id, double grow_factor);
+ void setup_array_sizes(uint32_t max_buffer_type_id, double grow_factor, size_t max_buffer_size);
size_t get_entry_size(uint32_t type_id) const;
bool is_dynamic_buffer(uint32_t type_id) const noexcept { return type_id > _max_static_array_buffer_type_id; }
uint32_t count_dynamic_buffer_types(uint32_t max_type_id) const noexcept { return (max_type_id > _max_static_array_buffer_type_id) ? (max_type_id - _max_static_array_buffer_type_id) : 0u; }
diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.hpp b/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.hpp
index e74cd92e6aa..48de5cf5332 100644
--- a/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.hpp
+++ b/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.hpp
@@ -18,16 +18,16 @@ ArrayStoreDynamicTypeMapper<ElemT>::ArrayStoreDynamicTypeMapper()
}
template <typename ElemT>
-ArrayStoreDynamicTypeMapper<ElemT>::ArrayStoreDynamicTypeMapper(uint32_t max_buffer_type_id, double grow_factor)
+ArrayStoreDynamicTypeMapper<ElemT>::ArrayStoreDynamicTypeMapper(uint32_t max_buffer_type_id, double grow_factor, size_t max_buffer_size)
: ArrayStoreTypeMapper(),
_max_static_array_buffer_type_id(0)
{
- setup_array_sizes(max_buffer_type_id, grow_factor);
+ setup_array_sizes(max_buffer_type_id, grow_factor, max_buffer_size);
}
template <typename ElemT>
void
-ArrayStoreDynamicTypeMapper<ElemT>::setup_array_sizes(uint32_t max_buffer_type_id, double grow_factor)
+ArrayStoreDynamicTypeMapper<ElemT>::setup_array_sizes(uint32_t max_buffer_type_id, double grow_factor, size_t max_buffer_size)
{
_array_sizes.clear();
_array_sizes.reserve(max_buffer_type_id + 1);
@@ -49,7 +49,8 @@ ArrayStoreDynamicTypeMapper<ElemT>::setup_array_sizes(uint32_t max_buffer_type_i
entry_size = array_size * sizeof(ElemT);
}
}
- if (entry_size > std::numeric_limits<uint32_t>::max()) {
+ if (entry_size > std::numeric_limits<uint32_t>::max() ||
+ entry_size >= 2 * max_buffer_size) {
break;
}
_array_sizes.emplace_back(array_size);
diff --git a/vespalib/src/vespa/vespalib/datastore/buffer_type.cpp b/vespalib/src/vespa/vespalib/datastore/buffer_type.cpp
index bb34c5d0f9d..1b087a01c58 100644
--- a/vespalib/src/vespa/vespalib/datastore/buffer_type.cpp
+++ b/vespalib/src/vespa/vespalib/datastore/buffer_type.cpp
@@ -27,6 +27,7 @@ BufferTypeBase::CleanContext::extraBytesCleaned(size_t value)
}
BufferTypeBase::BufferTypeBase(uint32_t entry_size_in,
+ uint32_t buffer_underflow_size_in,
uint32_t arraySize,
uint32_t min_entries,
uint32_t max_entries,
@@ -34,6 +35,7 @@ BufferTypeBase::BufferTypeBase(uint32_t entry_size_in,
float allocGrowFactor) noexcept
: _entry_size(entry_size_in),
_arraySize(arraySize),
+ _buffer_underflow_size(buffer_underflow_size_in),
_min_entries(std::min(min_entries, max_entries)),
_max_entries(max_entries),
_num_entries_for_new_buffer(std::min(num_entries_for_new_buffer, max_entries)),
@@ -46,10 +48,11 @@ BufferTypeBase::BufferTypeBase(uint32_t entry_size_in,
}
BufferTypeBase::BufferTypeBase(uint32_t entry_size_in,
+ uint32_t buffer_underflow_size_in,
uint32_t arraySize,
uint32_t min_entries,
uint32_t max_entries) noexcept
- : BufferTypeBase(entry_size_in, arraySize, min_entries, max_entries, 0u, DEFAULT_ALLOC_GROW_FACTOR)
+ : BufferTypeBase(entry_size_in, buffer_underflow_size_in, arraySize, min_entries, max_entries, 0u, DEFAULT_ALLOC_GROW_FACTOR)
{
}
@@ -117,6 +120,12 @@ BufferTypeBase::get_memory_allocator() const
return nullptr;
}
+bool
+BufferTypeBase::is_dynamic_array_buffer_type() const noexcept
+{
+ return false;
+}
+
void
BufferTypeBase::clamp_max_entries(uint32_t max_entries)
{
diff --git a/vespalib/src/vespa/vespalib/datastore/buffer_type.h b/vespalib/src/vespa/vespalib/datastore/buffer_type.h
index 3370bd47fad..7b23a238ba2 100644
--- a/vespalib/src/vespa/vespalib/datastore/buffer_type.h
+++ b/vespalib/src/vespa/vespalib/datastore/buffer_type.h
@@ -39,8 +39,8 @@ public:
BufferTypeBase & operator=(const BufferTypeBase &rhs) = delete;
BufferTypeBase(BufferTypeBase &&rhs) noexcept = default;
BufferTypeBase & operator=(BufferTypeBase &&rhs) noexcept = default;
- BufferTypeBase(uint32_t entry_size_in, uint32_t arraySize, uint32_t min_entries, uint32_t max_entries) noexcept;
- BufferTypeBase(uint32_t entry_size_in, uint32_t arraySize, uint32_t min_entries, uint32_t max_entries,
+ BufferTypeBase(uint32_t entry_size_in, uint32_t buffer_underflow_size_in, uint32_t arraySize, uint32_t min_entries, uint32_t max_entries) noexcept;
+ BufferTypeBase(uint32_t entry_size_in, uint32_t buffer_underflow_size_in, uint32_t arraySize, uint32_t min_entries, uint32_t max_entries,
uint32_t num_entries_for_new_buffer, float allocGrowFactor) noexcept;
virtual ~BufferTypeBase();
virtual void destroy_entries(void *buffer, EntryCount num_entries) = 0;
@@ -57,6 +57,7 @@ public:
*/
virtual void initialize_reserved_entries(void *buffer, EntryCount reserved_entries) = 0;
size_t entry_size() const noexcept { return _entry_size; }
+ uint32_t buffer_underflow_size() const noexcept { return _buffer_underflow_size; }
virtual void clean_hold(void *buffer, size_t offset, EntryCount num_entries, CleanContext cleanCtx) = 0;
size_t getArraySize() const noexcept { return _arraySize; }
virtual void on_active(uint32_t bufferId, std::atomic<EntryCount>* used_entries, std::atomic<EntryCount>* dead_entries, void* buffer);
@@ -64,6 +65,7 @@ public:
virtual void on_free(EntryCount used_entries);
void resume_primary_buffer(uint32_t buffer_id, std::atomic<EntryCount>* used_entries, std::atomic<EntryCount>* dead_entries);
virtual const alloc::MemoryAllocator* get_memory_allocator() const;
+ virtual bool is_dynamic_array_buffer_type() const noexcept;
/**
* Calculate number of entries to allocate for new buffer given how many free entries are needed.
@@ -114,6 +116,16 @@ protected:
uint32_t _entry_size; // Number of bytes in an allocation unit
uint32_t _arraySize; // Number of elements in an allocation unit
+
+ /*
+ * Buffer underflow size is the size of an area before the start
+ * of the logical buffer that is safe to access (part of the same
+ * memory alloation as the buffer itself). This allows for data
+ * belonging to an entry to be placed at the end of what is normally
+ * the last part of the previos entry (e.g. dynamic array size
+ * for the dynamic array buffer type).
+ */
+ uint32_t _buffer_underflow_size;
uint32_t _min_entries; // Minimum number of entries to allocate in a buffer
uint32_t _max_entries; // Maximum number of entries to allocate in a buffer
// Number of entries needed before allocating a new buffer instead of just resizing the first one
diff --git a/vespalib/src/vespa/vespalib/datastore/buffer_type.hpp b/vespalib/src/vespa/vespalib/datastore/buffer_type.hpp
index 375c832d9fb..00d642be9bc 100644
--- a/vespalib/src/vespa/vespalib/datastore/buffer_type.hpp
+++ b/vespalib/src/vespa/vespalib/datastore/buffer_type.hpp
@@ -8,13 +8,13 @@ namespace vespalib::datastore {
template <typename ElemT, typename EmptyT>
BufferType<ElemT, EmptyT>::BufferType(uint32_t arraySize, uint32_t min_entries, uint32_t max_entries) noexcept
- : BufferTypeBase(arraySize * sizeof(ElemT), arraySize, min_entries, max_entries)
+ : BufferTypeBase(arraySize * sizeof(ElemT), 0u, arraySize, min_entries, max_entries)
{ }
template <typename ElemT, typename EmptyT>
BufferType<ElemT, EmptyT>::BufferType(uint32_t arraySize, uint32_t min_entries, uint32_t max_entries,
uint32_t num_entries_for_new_buffer, float allocGrowFactor) noexcept
- : BufferTypeBase(arraySize * sizeof(ElemT), arraySize, min_entries, max_entries, num_entries_for_new_buffer, allocGrowFactor)
+ : BufferTypeBase(arraySize * sizeof(ElemT), 0u, arraySize, min_entries, max_entries, num_entries_for_new_buffer, allocGrowFactor)
{ }
template <typename ElemT, typename EmptyT>
diff --git a/vespalib/src/vespa/vespalib/datastore/bufferstate.cpp b/vespalib/src/vespa/vespalib/datastore/bufferstate.cpp
index f312596d6f7..e7832a1c4e2 100644
--- a/vespalib/src/vespa/vespalib/datastore/bufferstate.cpp
+++ b/vespalib/src/vespa/vespalib/datastore/bufferstate.cpp
@@ -64,13 +64,14 @@ calc_allocation(uint32_t bufferId,
{
size_t alloc_entries = typeHandler.calc_entries_to_alloc(bufferId, free_entries_needed, resizing);
size_t entry_size = typeHandler.entry_size();
- size_t allocBytes = roundUpToMatchAllocator(alloc_entries * entry_size);
- size_t maxAllocBytes = typeHandler.get_max_entries() * entry_size;
+ auto buffer_underflow_size = typeHandler.buffer_underflow_size();
+ size_t allocBytes = roundUpToMatchAllocator(alloc_entries * entry_size + buffer_underflow_size);
+ size_t maxAllocBytes = typeHandler.get_max_entries() * entry_size + buffer_underflow_size;
if (allocBytes > maxAllocBytes) {
// Ensure that allocated bytes does not exceed the maximum handled by this type.
allocBytes = maxAllocBytes;
}
- size_t adjusted_alloc_entries = allocBytes / entry_size;
+ size_t adjusted_alloc_entries = (allocBytes - buffer_underflow_size) / entry_size;
return AllocResult(adjusted_alloc_entries, allocBytes);
}
@@ -102,7 +103,8 @@ BufferState::on_active(uint32_t bufferId, uint32_t typeId,
_buffer = (allocator != nullptr) ? Alloc::alloc_with_allocator(allocator) : Alloc::alloc(0, MemoryAllocator::HUGEPAGE_SIZE);
_buffer.create(alloc.bytes).swap(_buffer);
assert(_buffer.get() != nullptr || alloc.entries == 0u);
- buffer.store(_buffer.get(), std::memory_order_release);
+ auto buffer_underflow_size = typeHandler->buffer_underflow_size();
+ buffer.store(get_buffer(buffer_underflow_size), std::memory_order_release);
_stats.set_alloc_entries(alloc.entries);
_typeHandler.store(typeHandler, std::memory_order_release);
assert(typeId <= std::numeric_limits<uint16_t>::max());
@@ -117,28 +119,30 @@ void
BufferState::onHold(uint32_t buffer_id)
{
assert(getState() == State::ACTIVE);
- assert(getTypeHandler() != nullptr);
+ auto type_handler = getTypeHandler();
+ assert(type_handler != nullptr);
_state.store(State::HOLD, std::memory_order_release);
_compacting = false;
assert(_stats.dead_entries() <= size());
assert(_stats.hold_entries() <= (size() - _stats.dead_entries()));
_stats.set_dead_entries(0);
_stats.set_hold_entries(size());
- getTypeHandler()->on_hold(buffer_id, &_stats.used_entries_ref(), &_stats.dead_entries_ref());
+ type_handler->on_hold(buffer_id, &_stats.used_entries_ref(), &_stats.dead_entries_ref());
_free_list.disable();
}
void
BufferState::onFree(std::atomic<void*>& buffer)
{
- assert(buffer.load(std::memory_order_relaxed) == _buffer.get());
assert(getState() == State::HOLD);
- assert(_typeHandler != nullptr);
+ auto type_handler = getTypeHandler();
+ assert(type_handler != nullptr);
+ assert(buffer.load(std::memory_order_relaxed) == get_buffer(type_handler->buffer_underflow_size()));
assert(_stats.dead_entries() <= size());
assert(_stats.hold_entries() == (size() - _stats.dead_entries()));
- getTypeHandler()->destroy_entries(buffer, size());
+ type_handler->destroy_entries(buffer, size());
Alloc::alloc().swap(_buffer);
- getTypeHandler()->on_free(size());
+ type_handler->on_free(size());
buffer.store(nullptr, std::memory_order_release);
_stats.clear();
_state.store(State::FREE, std::memory_order_release);
@@ -200,9 +204,11 @@ BufferState::free_entries(EntryRef ref, size_t num_entries, size_t ref_offset)
}
_stats.inc_dead_entries(num_entries);
_stats.dec_hold_entries(num_entries);
- getTypeHandler()->clean_hold(_buffer.get(), ref_offset, num_entries,
- BufferTypeBase::CleanContext(_stats.extra_used_bytes_ref(),
- _stats.extra_hold_bytes_ref()));
+ auto type_handler = getTypeHandler();
+ auto buffer_underflow_size = type_handler->buffer_underflow_size();
+ type_handler->clean_hold(get_buffer(buffer_underflow_size), ref_offset, num_entries,
+ BufferTypeBase::CleanContext(_stats.extra_used_bytes_ref(),
+ _stats.extra_hold_bytes_ref()));
}
void
@@ -212,17 +218,19 @@ BufferState::fallback_resize(uint32_t bufferId,
Alloc &holdBuffer)
{
assert(getState() == State::ACTIVE);
- assert(_typeHandler != nullptr);
+ auto type_handler = getTypeHandler();
+ assert(type_handler != nullptr);
assert(holdBuffer.get() == nullptr);
- AllocResult alloc = calc_allocation(bufferId, *_typeHandler, free_entries_needed, true);
+ auto buffer_underflow_size = type_handler->buffer_underflow_size();
+ AllocResult alloc = calc_allocation(bufferId, *type_handler, free_entries_needed, true);
assert(alloc.entries >= size() + free_entries_needed);
assert(alloc.entries > capacity());
Alloc newBuffer = _buffer.create(alloc.bytes);
- getTypeHandler()->fallback_copy(newBuffer.get(), buffer.load(std::memory_order_relaxed), size());
+ type_handler->fallback_copy(get_buffer(newBuffer, buffer_underflow_size), buffer.load(std::memory_order_relaxed), size());
holdBuffer.swap(_buffer);
std::atomic_thread_fence(std::memory_order_release);
_buffer = std::move(newBuffer);
- buffer.store(_buffer.get(), std::memory_order_release);
+ buffer.store(get_buffer(buffer_underflow_size), std::memory_order_release);
_stats.set_alloc_entries(alloc.entries);
}
diff --git a/vespalib/src/vespa/vespalib/datastore/bufferstate.h b/vespalib/src/vespa/vespalib/datastore/bufferstate.h
index 289be32e19b..01439586f5b 100644
--- a/vespalib/src/vespa/vespalib/datastore/bufferstate.h
+++ b/vespalib/src/vespa/vespalib/datastore/bufferstate.h
@@ -48,6 +48,8 @@ private:
bool _disable_entry_hold_list : 1;
bool _compacting : 1;
+ static void *get_buffer(Alloc& buffer, uint32_t buffer_underflow_size) noexcept { return static_cast<char *>(buffer.get()) + buffer_underflow_size; }
+ void *get_buffer(uint32_t buffer_underflow_size) noexcept { return get_buffer(_buffer, buffer_underflow_size); }
public:
/**
* TODO: Check if per-buffer free lists are useful, or if
@@ -137,24 +139,28 @@ public:
void* get_buffer_relaxed() noexcept { return _buffer.load(std::memory_order_relaxed); }
const void* get_buffer_acquire() const noexcept { return _buffer.load(std::memory_order_acquire); }
uint32_t getTypeId() const { return _typeId; }
- uint32_t getArraySize() const { return _arraySize; }
+ uint32_t get_array_size() const { return _array_size; }
BufferState * get_state_relaxed() { return _state.load(std::memory_order_relaxed); }
const BufferState * get_state_acquire() const { return _state.load(std::memory_order_acquire); }
- uint32_t get_entry_size() const { return get_state_acquire()->getTypeHandler()->entry_size(); }
+ uint32_t get_entry_size() const noexcept { return _entry_size; }
void setTypeId(uint32_t typeId) { _typeId = typeId; }
- void setArraySize(uint32_t arraySize) { _arraySize = arraySize; }
+ void set_array_size(uint32_t arraySize) { _array_size = arraySize; }
+ void set_entry_size(uint32_t entry_size) noexcept { _entry_size = entry_size; }
void set_state(BufferState * state) { _state.store(state, std::memory_order_release); }
private:
BufferAndMeta(void* buffer, BufferState * state, uint32_t typeId, uint32_t arraySize)
: _buffer(buffer),
_state(state),
_typeId(typeId),
- _arraySize(arraySize)
+ _array_size(arraySize)
{ }
std::atomic<void*> _buffer;
std::atomic<BufferState*> _state;
uint32_t _typeId;
- uint32_t _arraySize;
+ union {
+ uint32_t _array_size; // Valid unless buffer type is dynamic array buffer type
+ uint32_t _entry_size; // Valid if buffer type is dynamic array buffer type
+ };
};
}
diff --git a/vespalib/src/vespa/vespalib/datastore/datastorebase.cpp b/vespalib/src/vespa/vespalib/datastore/datastorebase.cpp
index 75ffe855a32..5c88900ae92 100644
--- a/vespalib/src/vespa/vespalib/datastore/datastorebase.cpp
+++ b/vespalib/src/vespa/vespalib/datastore/datastorebase.cpp
@@ -59,7 +59,8 @@ DataStoreBase::FallbackHold::FallbackHold(size_t bytesSize, BufferState::Alloc &
DataStoreBase::FallbackHold::~FallbackHold()
{
- _typeHandler->destroy_entries(_buffer.get(), _used_entries);
+ auto buffer_underflow_size = _typeHandler->buffer_underflow_size();
+ _typeHandler->destroy_entries(static_cast<char *>(_buffer.get()) + buffer_underflow_size, _used_entries);
}
class DataStoreBase::BufferHold : public GenerationHeldBase {
@@ -415,7 +416,11 @@ DataStoreBase::on_active(uint32_t bufferId, uint32_t typeId, size_t entries_need
assert(state->isFree());
state->on_active(bufferId, typeId, _typeHandlers[typeId], entries_needed, bufferMeta.get_atomic_buffer());
bufferMeta.setTypeId(typeId);
- bufferMeta.setArraySize(state->getArraySize());
+ if (_typeHandlers[typeId]->is_dynamic_array_buffer_type()) {
+ bufferMeta.set_entry_size(_typeHandlers[typeId]->entry_size());
+ } else {
+ bufferMeta.set_array_size(state->getArraySize());
+ }
if (_freeListsEnabled && state->isActive() && !state->getCompacting()) {
state->enable_free_list(_free_lists[state->getTypeId()]);
}
diff --git a/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.h b/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.h
index fbd3c2361d1..62b311f6aee 100644
--- a/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.h
+++ b/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.h
@@ -7,6 +7,7 @@
#include "array_store_config.h"
#include <algorithm>
#include <memory>
+#include <type_traits>
namespace vespalib::datastore {
@@ -26,12 +27,14 @@ class DynamicArrayBufferType : public BufferTypeBase
{
using AllocSpec = ArrayStoreConfig::AllocSpec;
std::shared_ptr<alloc::MemoryAllocator> _memory_allocator;
+
public:
using ElemType = ElemT;
static constexpr size_t entry_min_align = std::max(alignof(uint32_t), alignof(ElemT));
using EntryMinAligner = Aligner<entry_min_align>;
- static constexpr size_t entry_bias = EntryMinAligner::align(sizeof(uint32_t));
+ static constexpr uint32_t dynamic_array_buffer_underflow_size = 64u;
+ static constexpr bool align_for_simd = std::disjunction_v<std::is_same<ElemT, double>,std::is_same<ElemT, float>, std::is_same<ElemT, int64_t>, std::is_same<ElemT, int32_t>>;
protected:
static const ElemType& empty_entry() noexcept;
ElemType* get_entry(void *buffer, size_t offset) noexcept { return get_entry(buffer, offset, entry_size()); }
@@ -55,10 +58,11 @@ public:
const vespalib::alloc::MemoryAllocator* get_memory_allocator() const override;
static size_t calc_entry_size(size_t array_size) noexcept;
static size_t calc_array_size(size_t entry_size) noexcept;
- static ElemType* get_entry(void* buffer, size_t offset, uint32_t entry_size) noexcept { return reinterpret_cast<ElemType*>(static_cast<char*>(buffer) + offset * entry_size + entry_bias); }
- static const ElemType* get_entry(const void* buffer, size_t offset, uint32_t entry_size) noexcept { return reinterpret_cast<const ElemType*>(static_cast<const char*>(buffer) + offset * entry_size + entry_bias); }
+ static ElemType* get_entry(void* buffer, size_t offset, uint32_t entry_size) noexcept { return reinterpret_cast<ElemType*>(static_cast<char*>(buffer) + offset * entry_size); }
+ static const ElemType* get_entry(const void* buffer, size_t offset, uint32_t entry_size) noexcept { return reinterpret_cast<const ElemType*>(static_cast<const char*>(buffer) + offset * entry_size); }
static uint32_t get_dynamic_array_size(const ElemType* buffer) noexcept { return *(reinterpret_cast<const uint32_t*>(buffer) - 1); }
static void set_dynamic_array_size(ElemType* buffer, uint32_t array_size) noexcept { *(reinterpret_cast<uint32_t*>(buffer) - 1) = array_size; }
+ bool is_dynamic_array_buffer_type() const noexcept override;
};
extern template class DynamicArrayBufferType<char>;
diff --git a/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.hpp b/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.hpp
index bf3235a6b97..b208fa627e4 100644
--- a/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.hpp
+++ b/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.hpp
@@ -9,7 +9,7 @@ namespace vespalib::datastore {
template <typename ElemT>
DynamicArrayBufferType<ElemT>::DynamicArrayBufferType(uint32_t array_size, const AllocSpec& spec, std::shared_ptr<alloc::MemoryAllocator> memory_allocator) noexcept
- : BufferTypeBase(calc_entry_size(array_size), array_size, spec.min_entries_in_buffer, spec.max_entries_in_buffer, spec.num_entries_for_new_buffer, spec.allocGrowFactor),
+ : BufferTypeBase(calc_entry_size(array_size), dynamic_array_buffer_underflow_size, array_size, spec.min_entries_in_buffer, spec.max_entries_in_buffer, spec.num_entries_for_new_buffer, spec.allocGrowFactor),
_memory_allocator(std::move(memory_allocator))
{
}
@@ -24,14 +24,18 @@ template <typename ElemT>
size_t
DynamicArrayBufferType<ElemT>::calc_entry_size(size_t array_size) noexcept
{
- return EntryMinAligner::align(sizeof(ElemType) * array_size + entry_bias);
+ auto entry_size = EntryMinAligner::align(sizeof(ElemType) * array_size + sizeof(uint32_t));
+ if (align_for_simd && entry_size >= 512) {
+ entry_size = Aligner<64>::align(entry_size);
+ }
+ return entry_size;
}
template <typename ElemT>
size_t
DynamicArrayBufferType<ElemT>::calc_array_size(size_t entry_size) noexcept
{
- return (entry_size - entry_bias) / sizeof(ElemType);
+ return (entry_size - sizeof(uint32_t)) / sizeof(ElemType);
}
template <typename ElemT>
@@ -116,4 +120,11 @@ DynamicArrayBufferType<ElemT>::get_memory_allocator() const
return _memory_allocator.get();
}
+template <typename ElemT>
+bool
+DynamicArrayBufferType<ElemT>::is_dynamic_array_buffer_type() const noexcept
+{
+ return true;
+}
+
}
diff --git a/vespalib/src/vespa/vespalib/datastore/unique_store_string_allocator.h b/vespalib/src/vespa/vespalib/datastore/unique_store_string_allocator.h
index d3348950891..102808f7629 100644
--- a/vespalib/src/vespa/vespalib/datastore/unique_store_string_allocator.h
+++ b/vespalib/src/vespa/vespalib/datastore/unique_store_string_allocator.h
@@ -117,7 +117,7 @@ public:
auto &state = _store.getBufferMeta(iRef.bufferId());
auto type_id = state.getTypeId();
if (type_id != 0) {
- return *reinterpret_cast<const UniqueStoreEntryBase *>(_store.template getEntryArray<char>(iRef, state.getArraySize()));
+ return *reinterpret_cast<const UniqueStoreEntryBase *>(_store.template getEntryArray<char>(iRef, state.get_array_size()));
} else {
return *_store.template getEntry<WrappedExternalEntryType>(iRef);
}
@@ -127,7 +127,7 @@ public:
auto &state = _store.getBufferMeta(iRef.bufferId());
auto type_id = state.getTypeId();
if (type_id != 0) {
- return reinterpret_cast<const UniqueStoreSmallStringEntry *>(_store.template getEntryArray<char>(iRef, state.getArraySize()))->value();
+ return reinterpret_cast<const UniqueStoreSmallStringEntry *>(_store.template getEntryArray<char>(iRef, state.get_array_size()))->value();
} else {
return _store.template getEntry<WrappedExternalEntryType>(iRef)->value().c_str();
}
diff --git a/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h b/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h
index e507132a085..e71dcd3aafb 100644
--- a/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h
+++ b/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h
@@ -28,7 +28,7 @@ protected:
const auto &meta = _store.getBufferMeta(iRef.bufferId());
auto type_id = meta.getTypeId();
if (type_id != 0) {
- return reinterpret_cast<const UniqueStoreSmallStringEntry *>(_store.template getEntryArray<char>(iRef, meta.getArraySize()))->value();
+ return reinterpret_cast<const UniqueStoreSmallStringEntry *>(_store.template getEntryArray<char>(iRef, meta.get_array_size()))->value();
} else {
return _store.template getEntry<WrappedExternalEntryType>(iRef)->value().c_str();
}
diff --git a/vespalib/src/vespa/vespalib/stllike/cache.h b/vespalib/src/vespa/vespalib/stllike/cache.h
index f7456cda197..97ba23aba65 100644
--- a/vespalib/src/vespa/vespalib/stllike/cache.h
+++ b/vespalib/src/vespa/vespalib/stllike/cache.h
@@ -110,7 +110,7 @@ public:
*/
bool hasKey(const K & key) const;
- CacheStats get_stats() const;
+ virtual CacheStats get_stats() const;
size_t getHit() const { return _hit.load(std::memory_order_relaxed); }
size_t getMiss() const { return _miss.load(std::memory_order_relaxed); }
diff --git a/vespalib/src/vespa/vespalib/stllike/cache.hpp b/vespalib/src/vespa/vespalib/stllike/cache.hpp
index 4e7736c9e5f..f767cf9812c 100644
--- a/vespalib/src/vespa/vespalib/stllike/cache.hpp
+++ b/vespalib/src/vespa/vespalib/stllike/cache.hpp
@@ -61,9 +61,7 @@ cache<P>::getStaticMemoryUsage() const {
MemoryUsage usage;
auto cacheGuard = getGuard();
usage.incAllocatedBytes(sizeof(*this));
- usage.incAllocatedBytes(Lru::capacity()*sizeof(typename Lru::value_type));
usage.incUsedBytes(sizeof(*this));
- usage.incUsedBytes(Lru::size()*sizeof(typename Lru::value_type));
return usage;
}
diff --git a/vespalib/src/vespa/vespalib/test/nexus.h b/vespalib/src/vespa/vespalib/test/nexus.h
index 659b563e43c..7be35d42463 100644
--- a/vespalib/src/vespa/vespalib/test/nexus.h
+++ b/vespalib/src/vespa/vespalib/test/nexus.h
@@ -36,6 +36,7 @@ public:
Nexus &operator=(const Nexus &) = delete;
size_t num_threads() const noexcept { return _vote.size(); }
size_t thread_id() const noexcept { return _thread_id; }
+ bool is_main() const noexcept { return _thread_id == 0; }
bool vote(bool my_vote) { return _vote(my_vote); }
void barrier() { REQUIRE_EQ(_vote(true), true); }
struct select_thread_0 {};
diff --git a/vespalib/src/vespa/vespalib/test/thread_meets.h b/vespalib/src/vespa/vespalib/test/thread_meets.h
index 7ef4dcb9921..26ca560641d 100644
--- a/vespalib/src/vespa/vespalib/test/thread_meets.h
+++ b/vespalib/src/vespa/vespalib/test/thread_meets.h
@@ -38,8 +38,8 @@ struct ThreadMeets {
explicit Sum(size_t N) : vespalib::Rendezvous<T,T>(N) {}
T operator()(T value) { return rendezvous(value); }
void mingle() override {
- T acc{};
- for (size_t i = 0; i < size(); ++i) {
+ T acc = in(0);
+ for (size_t i = 1; i < size(); ++i) {
acc += in(i);
}
for (size_t i = 0; i < size(); ++i) {
@@ -47,6 +47,48 @@ struct ThreadMeets {
}
}
};
+ // maximum of values across all threads
+ template <typename T>
+ struct Max : vespalib::Rendezvous<T,T> {
+ using vespalib::Rendezvous<T,T>::in;
+ using vespalib::Rendezvous<T,T>::out;
+ using vespalib::Rendezvous<T,T>::size;
+ using vespalib::Rendezvous<T,T>::rendezvous;
+ explicit Max(size_t N) : vespalib::Rendezvous<T,T>(N) {}
+ T operator()(T value) { return rendezvous(value); }
+ void mingle() override {
+ T max = in(0);
+ for (size_t i = 1; i < size(); ++i) {
+ if (in(i) > max) {
+ max = in(i);
+ }
+ }
+ for (size_t i = 0; i < size(); ++i) {
+ out(i) = max;
+ }
+ }
+ };
+ // minimum of values across all threads
+ template <typename T>
+ struct Min : vespalib::Rendezvous<T,T> {
+ using vespalib::Rendezvous<T,T>::in;
+ using vespalib::Rendezvous<T,T>::out;
+ using vespalib::Rendezvous<T,T>::size;
+ using vespalib::Rendezvous<T,T>::rendezvous;
+ explicit Min(size_t N) : vespalib::Rendezvous<T,T>(N) {}
+ T operator()(T value) { return rendezvous(value); }
+ void mingle() override {
+ T min = in(0);
+ for (size_t i = 1; i < size(); ++i) {
+ if (in(i) < min) {
+ min = in(i);
+ }
+ }
+ for (size_t i = 0; i < size(); ++i) {
+ out(i) = min;
+ }
+ }
+ };
// range of values across all threads
template <typename T>
struct Range : vespalib::Rendezvous<T,T> {
diff --git a/vespamalloc/src/vespamalloc/malloc/threadlist.hpp b/vespamalloc/src/vespamalloc/malloc/threadlist.hpp
index 743090a4e12..e22b93c48fe 100644
--- a/vespamalloc/src/vespamalloc/malloc/threadlist.hpp
+++ b/vespamalloc/src/vespamalloc/malloc/threadlist.hpp
@@ -2,9 +2,14 @@
#pragma once
#include "threadlist.h"
+#include <malloc.h>
namespace vespamalloc {
+namespace {
+ const char * VESPA_MALLOC_MMAP_THRESHOLD = "VESPA_MALLOC_MMAP_THRESHOLD";
+}
+
template <typename MemBlockPtrT, typename ThreadStatT>
ThreadListT<MemBlockPtrT, ThreadStatT>::ThreadListT(AllocPool & allocPool, MMapPool & mmapPool) :
_isThreaded(false),
@@ -13,8 +18,14 @@ ThreadListT<MemBlockPtrT, ThreadStatT>::ThreadListT(AllocPool & allocPool, MMapP
_allocPool(allocPool),
_mmapPool(mmapPool)
{
+ const char * mmapThresholdS = getenv(VESPA_MALLOC_MMAP_THRESHOLD);
+ int mmapThreshold = (mmapThresholdS != nullptr)
+ ? strtol(mmapThresholdS, nullptr, 0)
+ : MMAP_LIMIT_DEFAULT;
for (size_t i = 0; i < getMaxNumThreads(); i++) {
- _threadVector[i].setPool(_allocPool, _mmapPool);
+ auto & thread = _threadVector[i];
+ thread.setPool(_allocPool, _mmapPool);
+ thread.mallopt(M_MMAP_THRESHOLD, mmapThreshold);
}
}
diff --git a/vespamalloc/src/vespamalloc/malloc/threadpool.h b/vespamalloc/src/vespamalloc/malloc/threadpool.h
index 30ece02ba29..750833084ca 100644
--- a/vespamalloc/src/vespamalloc/malloc/threadpool.h
+++ b/vespamalloc/src/vespamalloc/malloc/threadpool.h
@@ -9,6 +9,10 @@
namespace vespamalloc {
+constexpr int MMAP_LIMIT_MIN = 0x100000; // 1M
+constexpr int MMAP_LIMIT_DEFAULT = 0x4000000; // 64M
+constexpr int MMAP_LIMIT_MAX = 0x40000000; // 1G
+
template <typename MemBlockPtrT, typename ThreadStatT >
class ThreadPoolT
{
diff --git a/vespamalloc/src/vespamalloc/malloc/threadpool.hpp b/vespamalloc/src/vespamalloc/malloc/threadpool.hpp
index b5a283f6600..e62fa0f2fdf 100644
--- a/vespamalloc/src/vespamalloc/malloc/threadpool.hpp
+++ b/vespamalloc/src/vespamalloc/malloc/threadpool.hpp
@@ -7,8 +7,10 @@
namespace vespamalloc {
namespace {
- constexpr size_t MMAP_LIMIT_MIN = 0x100000; // 1M
- constexpr size_t MMAP_LIMIT_MAX = 0x40000000; // 1G
+ size_t
+ sanitizeMMapThreshold(int threshold) {
+ return std::min(MMAP_LIMIT_MAX, std::max(MMAP_LIMIT_MIN, threshold));
+ }
}
template <typename MemBlockPtrT, typename ThreadStatT>
@@ -112,9 +114,8 @@ ThreadPoolT<MemBlockPtrT, ThreadStatT>::~ThreadPoolT() = default;
template <typename MemBlockPtrT, typename ThreadStatT >
int ThreadPoolT<MemBlockPtrT, ThreadStatT>::mallopt(int param, int value) {
- size_t limit = value;
if (param == M_MMAP_THRESHOLD) {
- _mmapLimit = std::min(MMAP_LIMIT_MAX, std::max(MMAP_LIMIT_MIN, limit));
+ _mmapLimit = sanitizeMMapThreshold(value);
return 1;
}
return 0;