diff options
135 files changed, 2984 insertions, 990 deletions
diff --git a/clustercontroller-core/pom.xml b/clustercontroller-core/pom.xml index 647d8ca4e64..61176c178c6 100644 --- a/clustercontroller-core/pom.xml +++ b/clustercontroller-core/pom.xml @@ -15,14 +15,7 @@ <dependencies> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>annotations</artifactId> - <version>${project.version}</version> - <scope>provided</scope> - </dependency> - <dependency> - <!-- required for bundle-plugin to generate import-package statements for Java's standard library --> - <groupId>com.yahoo.vespa</groupId> - <artifactId>jdisc_core</artifactId> + <artifactId>container-dev</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java index c823c94afd1..60671c96474 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/NodeStateChangeChecker.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.clustercontroller.core; +import ai.vespa.metrics.StorageMetrics; import com.yahoo.lang.MutableBoolean; import com.yahoo.lang.SettableOptional; import com.yahoo.vdslib.distribution.ConfiguredNode; @@ -44,7 +45,7 @@ import static java.util.logging.Level.FINE; public class NodeStateChangeChecker { private static final Logger log = Logger.getLogger(NodeStateChangeChecker.class.getName()); - private static final String BUCKETS_METRIC_NAME = "vds.datastored.bucket_space.buckets_total"; + private static final String BUCKETS_METRIC_NAME = StorageMetrics.VDS_DATASTORED_BUCKET_SPACE_BUCKETS_TOTAL.baseName(); private static final Map<String, String> BUCKETS_METRIC_DIMENSIONS = Map.of("bucketSpace", "default"); private final int requiredRedundancy; diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java index 65c08e67850..09aac786b2f 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.clustercontroller.core.restapiv2.requests; +import ai.vespa.metrics.StorageMetrics; import com.yahoo.vespa.clustercontroller.core.NodeInfo; import com.yahoo.vespa.clustercontroller.core.RemoteClusterControllerTask; import com.yahoo.vespa.clustercontroller.core.hostinfo.Metrics; @@ -46,17 +47,17 @@ public class NodeStateRequest extends Request<Response.NodeResponse> { } private static void fillInMetricValue(String name, Metrics.Value value, Response.NodeResponse result) { - if (name.equals("vds.datastored.alldisks.docs")) { + if (name.equals(StorageMetrics.VDS_DATASTORED_ALLDISKS_DOCS.baseName())) { if (value.getLast() == null) { return; } result.addMetric("unique-document-count", value.getLast()); - } else if (name.equals("vds.datastored.alldisks.bytes")) { + } else if (name.equals(StorageMetrics.VDS_DATASTORED_ALLDISKS_BYTES.baseName())) { if (value.getLast() == null) { return; } result.addMetric("unique-document-total-size", value.getLast()); - } else if (name.equals("vds.datastored.alldisks.buckets")) { + } else if (name.equals(StorageMetrics.VDS_DATASTORED_ALLDISKS_BUCKETS.baseName())) { if (value.getLast() == null) { return; } diff --git a/clustercontroller-reindexer/src/main/java/ai/vespa/reindexing/ReindexingMetrics.java b/clustercontroller-reindexer/src/main/java/ai/vespa/reindexing/ReindexingMetrics.java index a1aa5287d2f..83be05d970e 100644 --- a/clustercontroller-reindexer/src/main/java/ai/vespa/reindexing/ReindexingMetrics.java +++ b/clustercontroller-reindexer/src/main/java/ai/vespa/reindexing/ReindexingMetrics.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.reindexing; +import ai.vespa.metrics.ClusterControllerMetrics; import com.yahoo.documentapi.ProgressToken; import com.yahoo.jdisc.Metric; @@ -27,7 +28,7 @@ class ReindexingMetrics { void dump(Reindexing reindexing) { reindexing.status().forEach((type, status) -> { - metric.set("reindexing.progress", + metric.set(ClusterControllerMetrics.REINDEXING_PROGRESS.baseName(), status.progress().map(ProgressToken::percentFinished).map(percentage -> percentage * 1e-2) .orElse(status.state() == SUCCESSFUL ? 1.0 : 0.0), metric.createContext(Map.of("clusterid", cluster, "documenttype", type.getName()))); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RedundancyIncreaseValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RedundancyIncreaseValidator.java index 82ad8e5d6e8..47024c1171c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RedundancyIncreaseValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/RedundancyIncreaseValidator.java @@ -35,7 +35,7 @@ public class RedundancyIncreaseValidator implements ChangeValidator { } private int redundancyOf(ContentCluster cluster) { - return cluster.redundancy().finalRedundancy(); + return cluster.getRedundancy().finalRedundancy(); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/first/RedundancyValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/first/RedundancyValidator.java index 5228610537f..2be0f0b8422 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/first/RedundancyValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/first/RedundancyValidator.java @@ -2,7 +2,6 @@ package com.yahoo.vespa.model.application.validation.first; import com.yahoo.config.application.api.ValidationId; -import com.yahoo.config.application.api.ValidationOverrides; import com.yahoo.config.model.api.ConfigChangeAction; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.vespa.model.VespaModel; @@ -10,7 +9,6 @@ import com.yahoo.vespa.model.application.validation.Validator; import com.yahoo.vespa.model.application.validation.change.ChangeValidator; import com.yahoo.vespa.model.content.cluster.ContentCluster; -import java.time.Instant; import java.util.List; import java.util.stream.Stream; @@ -48,7 +46,7 @@ public class RedundancyValidator extends Validator implements ChangeValidator { } private boolean hasRedundancyOne(ContentCluster cluster) { - return cluster != null && cluster.redundancy().finalRedundancy() == 1 && cluster.redundancy().groups() == 1; + return cluster != null && cluster.getRedundancy().finalRedundancy() == 1 && cluster.getRedundancy().groups() == 1; } private void invalidRedundancy(ContentCluster cluster, DeployState deployState) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomSearchTuningBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomSearchTuningBuilder.java index 64592e75c41..a0a4151daf5 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomSearchTuningBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomSearchTuningBuilder.java @@ -298,6 +298,8 @@ public class DomSearchTuningBuilder extends VespaDomBuilder.DomConfigProducerBui for (Element e : XML.getChildren(spec)) { if (equals("concurrency", e)) { sn.feeding.concurrency = asDouble(e); + } else if (equals("niceness", e)) { + sn.feeding.niceness = asDouble(e); } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java b/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java index 2ca6d5d7155..f7d4fe28c6e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/Container.java @@ -109,6 +109,7 @@ public abstract class Container extends AbstractService implements addChild(new SimpleComponent("com.yahoo.container.jdisc.ConfiguredApplication$ApplicationContext")); appendJvmOptions(jvmOmitStackTraceInFastThrowOption(deployState.featureFlags())); + addEnvironmentVariable("VESPA_MALLOC_MMAP_THRESHOLD","0x200000"); } protected String jvmOmitStackTraceInFastThrowOption(ModelContext.FeatureFlags featureFlags) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java index ec7acaf819f..34ea41384bc 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java @@ -270,6 +270,25 @@ public class ContentSearchCluster extends TreeConfigProducer<AnyConfigProducer> clusters.put(sc.getClusterName(), sc); } + /** + * Returns whether the schemas in this cluster use streaming mode. + * + * @return True if this cluster only has schemas with streaming mode, False if it only has schemas + * with indexing, null if it has both or none. + */ + public Boolean isStreaming() { + boolean hasStreaming = false; + boolean hasIndexed = false; + for (var cluster : clusters.values()) { + if (cluster.isStreaming()) + hasStreaming = true; + else + hasIndexed = true; + } + if (hasIndexed == hasStreaming) return null; + return hasStreaming; + } + public List<SearchNode> getSearchNodes() { return hasIndexedCluster() ? getIndexed().getSearchNodes() : nonIndexed; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java index 4aac8bfb647..6f0a03bab60 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/DistributorCluster.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.content; +import ai.vespa.metrics.DistributorMetrics; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.vespa.config.content.core.StorDistributormanagerConfig; import com.yahoo.vespa.config.content.core.StorServerConfig; @@ -135,13 +136,13 @@ public class DistributorCluster extends TreeConfigProducer<Distributor> implemen @Override public void getConfig(MetricsmanagerConfig.Builder builder) { ContentCluster.getMetricBuilder("log", builder). - addedmetrics("vds.distributor.docsstored"). - addedmetrics("vds.distributor.bytesstored"). - addedmetrics("vds.idealstate.delete_bucket.done_ok"). - addedmetrics("vds.idealstate.merge_bucket.done_ok"). - addedmetrics("vds.idealstate.split_bucket.done_ok"). - addedmetrics("vds.idealstate.join_bucket.done_ok"). - addedmetrics("vds.idealstate.buckets_rechecking"); + addedmetrics(DistributorMetrics.VDS_DISTRIBUTOR_DOCSSTORED.baseName()). + addedmetrics(DistributorMetrics.VDS_DISTRIBUTOR_BYTESSTORED.baseName()). + addedmetrics(DistributorMetrics.VDS_IDEALSTATE_DELETE_BUCKET_DONE_OK.baseName()). + addedmetrics(DistributorMetrics.VDS_IDEALSTATE_MERGE_BUCKET_DONE_OK.baseName()). + addedmetrics(DistributorMetrics.VDS_IDEALSTATE_SPLIT_BUCKET_DONE_OK.baseName()). + addedmetrics(DistributorMetrics.VDS_IDEALSTATE_JOIN_BUCKET_DONE_OK.baseName()). + addedmetrics(DistributorMetrics.VDS_IDEALSTATE_BUCKETS_RECHECKING.baseName()); } @Override diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/StorageGroup.java b/config-model/src/main/java/com/yahoo/vespa/model/content/StorageGroup.java index 52b2ce06dfe..6078215f9b6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/StorageGroup.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/StorageGroup.java @@ -171,12 +171,11 @@ public class StorageGroup { } @Override - public boolean equals(Object obj) { - if (obj instanceof StorageGroup) { - StorageGroup rhs = (StorageGroup)obj; - return this.index.equals(rhs.index) && - this.name.equals(rhs.name) && - this.partitions.equals(rhs.partitions); + public boolean equals(Object o) { + if (o instanceof StorageGroup other) { + return this.index.equals(other.index) && + this.name.equals(other.name) && + this.partitions.equals(other.partitions); } return false; } @@ -208,9 +207,7 @@ public class StorageGroup { this.context = context; } - public StorageGroup buildRootGroup(DeployState deployState, - RedundancyBuilder redundancyBuilder, - ContentCluster owner) { + public StorageGroup buildRootGroup(DeployState deployState, ContentCluster owner, Boolean isStreaming) { try { if (owner.isHosted()) validateRedundancyAndGroups(deployState.zone().environment()); @@ -229,7 +226,8 @@ public class StorageGroup { ? groupBuilder.buildHosted(deployState, owner, Optional.empty(), context) : groupBuilder.buildNonHosted(deployState, owner, Optional.empty()); - Redundancy redundancy = redundancyBuilder.build(owner.isHosted(), storageGroup.subgroups.size(), + RedundancyBuilder redundancyBuilder = new RedundancyBuilder(clusterElement); + Redundancy redundancy = redundancyBuilder.build(owner.isHosted(), isStreaming, storageGroup.subgroups.size(), storageGroup.getNumberOfLeafGroups(), storageGroup.countNodes(false)); owner.setRedundancy(redundancy); if (storageGroup.partitions.isEmpty() && (redundancy.groups() > 1)) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java index 2592beca6c6..f792ac3a591 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/ContentCluster.java @@ -114,7 +114,6 @@ public class ContentCluster extends TreeConfigProducer<AnyConfigProducer> implem new SearchDefinitionBuilder().build(deployState.getDocumentModel().getDocumentManager(), documentsElement); String routingSelection = new DocumentSelectionBuilder().build(documentsElement); - RedundancyBuilder redundancyBuilder = new RedundancyBuilder(contentElement); Set<NewDocumentType> globallyDistributedDocuments = new GlobalDistributionBuilder(documentDefinitions).build(documentsElement); String clusterId = getClusterId(contentElement); @@ -133,7 +132,7 @@ public class ContentCluster extends TreeConfigProducer<AnyConfigProducer> implem c.persistenceFactory = new EngineFactoryBuilder().build(contentElement, c); c.storageNodes = new StorageCluster.Builder().build(deployState, c, w3cContentElement); c.distributorNodes = new DistributorCluster.Builder(c).build(deployState, c, w3cContentElement); - c.rootGroup = new StorageGroup.Builder(contentElement, context).buildRootGroup(deployState, redundancyBuilder, c); + c.rootGroup = new StorageGroup.Builder(contentElement, context).buildRootGroup(deployState, c, c.search.isStreaming()); c.clusterControllerConfig = createClusterControllerConfig(contentElement, deployState, c, resourceLimits); validateThatGroupSiblingsAreUnique(c.clusterId, c.rootGroup); c.search.handleRedundancy(c.redundancy); @@ -447,7 +446,7 @@ public class ContentCluster extends TreeConfigProducer<AnyConfigProducer> implem public final ContentSearchCluster getSearch() { return search; } - public Redundancy redundancy() { return redundancy; } + public Redundancy getRedundancy() { return redundancy; } public ContentCluster setRedundancy(Redundancy redundancy) { this.redundancy = redundancy; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java index e7bafdf52e4..d310db067a6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/cluster/RedundancyBuilder.java @@ -20,7 +20,7 @@ public class RedundancyBuilder { // Always global (across groups) private Integer globalMinRedundancy = null; - RedundancyBuilder(ModelElement clusterXml) { + public RedundancyBuilder(ModelElement clusterXml) { ModelElement redundancyElement = clusterXml.child("redundancy"); if (redundancyElement != null) { initialRedundancy = redundancyElement.integerAttribute("reply-after"); @@ -47,22 +47,40 @@ public class RedundancyBuilder { throw new IllegalArgumentException("Either <redundancy> or <min-redundancy> must be set"); } - public Redundancy build(boolean isHosted, int subGroups, int leafGroups, int totalNodes) { + /** + * @param isHosted + * @param isStreaming true if this cluster only has schemas with streaming mode, false if it only has schemas + * without streaming, null if it has both + * @param subGroups + * @param leafGroups + * @param totalNodes + * @return + */ + public Redundancy build(boolean isHosted, Boolean isStreaming, int subGroups, int leafGroups, int totalNodes) { if (isHosted) { if (globalMinRedundancy != null && ( finalRedundancy == null || finalRedundancy * leafGroups < globalMinRedundancy )) initialRedundancy = finalRedundancy = (int)Math.ceil((double)globalMinRedundancy / leafGroups); if (readyCopies == null) { - if (leafGroups > 1) - readyCopies = 1; - else - readyCopies = finalRedundancy > 1 ? 2 : 1; + if (isStreaming == Boolean.TRUE) { + readyCopies = finalRedundancy; + } + else { // If isStreaming is null (mixed mode cluster) there are no good options ... + if (leafGroups > 1) + readyCopies = 1; + else + readyCopies = finalRedundancy > 1 ? 2 : 1; + } } return new Redundancy(initialRedundancy, finalRedundancy, readyCopies, leafGroups, totalNodes); } else { if (globalMinRedundancy != null && ( finalRedundancy == null || finalRedundancy < globalMinRedundancy)) initialRedundancy = finalRedundancy = globalMinRedundancy; - if (readyCopies == null) - readyCopies = finalRedundancy > 1 ? Math.max(subGroups, 2) : 1; + if (readyCopies == null) { + if (isStreaming == Boolean.TRUE) + readyCopies = finalRedundancy; + else // If isStreaming is null (mixed mode cluster) there are no good options ... + readyCopies = finalRedundancy > 1 ? Math.max(subGroups, 2) : 1; + } subGroups = Math.max(1, subGroups); IndexedHierarchicDistributionValidator.validateThatLeafGroupsCountIsAFactorOfRedundancy(finalRedundancy, subGroups); IndexedHierarchicDistributionValidator.validateThatReadyCopiesIsCompatibleWithRedundancy(finalRedundancy, readyCopies, subGroups); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorageCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorageCluster.java index 2d67a344a17..a1e809098f2 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorageCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/storagecluster/StorageCluster.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.content.storagecluster; +import ai.vespa.metrics.StorageMetrics; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.vespa.config.content.core.StorIntegritycheckerConfig; import com.yahoo.vespa.config.content.core.StorBucketmoverConfig; @@ -75,24 +76,24 @@ public class StorageCluster extends TreeConfigProducer<StorageNode> @Override public void getConfig(MetricsmanagerConfig.Builder builder) { ContentCluster.getMetricBuilder("fleetcontroller", builder). - addedmetrics("vds.datastored.alldisks.docs"). - addedmetrics("vds.datastored.alldisks.bytes"). - addedmetrics("vds.datastored.alldisks.buckets"). - addedmetrics("vds.datastored.bucket_space.buckets_total"); + addedmetrics(StorageMetrics.VDS_DATASTORED_ALLDISKS_DOCS.baseName()). + addedmetrics(StorageMetrics.VDS_DATASTORED_ALLDISKS_BYTES.baseName()). + addedmetrics(StorageMetrics.VDS_DATASTORED_ALLDISKS_BUCKETS.baseName()). + addedmetrics(StorageMetrics.VDS_DATASTORED_BUCKET_SPACE_BUCKETS_TOTAL.baseName()); ContentCluster.getMetricBuilder("log", builder). addedmetrics("vds.filestor.allthreads.put"). addedmetrics("vds.filestor.allthreads.get"). addedmetrics("vds.filestor.allthreads.remove"). addedmetrics("vds.filestor.allthreads.update"). - addedmetrics("vds.datastored.alldisks.docs"). - addedmetrics("vds.datastored.alldisks.bytes"). - addedmetrics("vds.filestor.queuesize"). - addedmetrics("vds.filestor.averagequeuewait"). - addedmetrics("vds.visitor.cv_queuewaittime"). - addedmetrics("vds.visitor.allthreads.averagequeuewait"). - addedmetrics("vds.visitor.allthreads.averagevisitorlifetime"). - addedmetrics("vds.visitor.allthreads.created"); + addedmetrics(StorageMetrics.VDS_DATASTORED_ALLDISKS_DOCS.baseName()). + addedmetrics(StorageMetrics.VDS_DATASTORED_ALLDISKS_BYTES.baseName()). + addedmetrics(StorageMetrics.VDS_FILESTOR_QUEUESIZE.baseName()). + addedmetrics(StorageMetrics.VDS_FILESTOR_AVERAGEQUEUEWAIT.baseName()). + addedmetrics(StorageMetrics.VDS_VISITOR_CV_QUEUEWAITTIME.baseName()). + addedmetrics(StorageMetrics.VDS_VISITOR_ALLTHREADS_AVERAGEQUEUEWAIT.baseName()). + addedmetrics(StorageMetrics.VDS_VISITOR_ALLTHREADS_AVERAGEVISITORLIFETIME.baseName()). + addedmetrics(StorageMetrics.VDS_VISITOR_ALLTHREADS_CREATED.baseName()); } public String getClusterName() { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 1984ceadac6..8edd446b209 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -274,7 +274,8 @@ public class OnnxModelInfo { static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { g.writeStartObject(); g.writeStringField("name", valueInfo.getName()); - g.writeStringField("type", onnxValueTypeToString(valueInfo.getType().getTensorType().getElemType())); + var elemType = Onnx.TensorProto.DataType.forNumber(valueInfo.getType().getTensorType().getElemType()); + g.writeStringField("type", onnxValueTypeToString(elemType)); g.writeArrayFieldStart("dim"); for (Onnx.TensorShapeProto.Dimension dim : valueInfo.getType().getTensorType().getShape().getDimList()) { g.writeStartObject(); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java b/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java index 93e3a6e7a19..83eccc8697c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/Tuning.java @@ -371,12 +371,16 @@ public class Tuning extends AnyConfigProducer implements ProtonConfig.Producer { public static class Feeding implements ProtonConfig.Producer { public Double concurrency = null; + public Double niceness = null; @Override public void getConfig(ProtonConfig.Builder builder) { if (concurrency != null) { builder.feeding.concurrency(concurrency); } + if (niceness != null) { + builder.feeding.niceness(niceness); + } } } diff --git a/config-model/src/main/protobuf/onnx.proto b/config-model/src/main/protobuf/onnx.proto index dc6542867e0..1d265ae9f28 100644 --- a/config-model/src/main/protobuf/onnx.proto +++ b/config-model/src/main/protobuf/onnx.proto @@ -3,8 +3,8 @@ // -// Copyright (c) Facebook Inc. and Microsoft Corporation. -// Licensed under the MIT license. +// SPDX-License-Identifier: Apache-2.0 + syntax = "proto2"; @@ -20,23 +20,16 @@ package onnx; // // This document describes the syntax of models and their computation graphs, // as well as the standard data types. Together, they are referred to as the ONNX -// Intermediate Representation, or 'IR' for short. +// Intermediate Representation, or 'IR' for short. // // The normative semantic specification of the ONNX IR is found in docs/IR.md. // Definitions of the built-in neural network operators may be found in docs/Operators.md. // Notes // -// Release -// -// We are still in the very early stage of defining ONNX. The current -// version of ONNX is a starting point. While we are actively working -// towards a complete spec, we would like to get the community involved -// by sharing our working version of ONNX. -// // Protobuf compatibility -// -// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf // that is compatible with both protobuf v2 and v3. This means that we do not use any // protobuf features that are only available in one of the two versions. // @@ -60,22 +53,60 @@ enum Version { _START_VERSION = 0; // The version field is always serialized and we will use it to store the // version that the graph is generated from. This helps us set up version - // control. We should use version as - // xx(major) - xx(minor) - xxxx(bugfix) - // and we are starting with 0x00000001 (0.0.1), which was the - // version we published on Oct 10, 2017. - IR_VERSION_2017_10_10 = 0x00000001; + // control. + // For the IR, we are using simple numbers starting with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; - // IR_VERSION 0.0.2 published on Oct 30, 2017 + // IR_VERSION 2 published on Oct 30, 2017 // - Added type discriminator to AttributeProto to support proto3 users - IR_VERSION_2017_10_30 = 0x00000002; + IR_VERSION_2017_10_30 = 0x0000000000000002; - // IR VERSION 0.0.3 published on Nov 3, 2017 + // IR VERSION 3 published on Nov 3, 2017 // - For operator versioning: // - Added new message OperatorSetIdProto // - Added opset_import in ModelProto // - For vendor extensions, added domain in NodeProto - IR_VERSION = 0x00000003; + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION_2019_9_19 = 0x0000000000000006; + + // IR VERSION 7 published on May 8, 2020 + // - Add support to allow function body graph to rely on multiple external opreator sets. + // - Add a list to promote inference graph's initializers to global and + // mutable variables. Global variables are visible in all graphs of the + // stored models. + // - Add message TrainingInfoProto to store initialization + // method and training algorithm. The execution of TrainingInfoProto + // can modify the values of mutable variables. + // - Implicitly add inference graph into each TrainingInfoProto's algorithm. + IR_VERSION_2020_5_8 = 0x0000000000000007; + + // IR VERSION 8 published on July 30, 2021 + // Introduce TypeProto.SparseTensor + // Introduce TypeProto.Optional + // Added a list of FunctionProtos local to the model + // Deprecated since_version and operator status from FunctionProto + IR_VERSION_2021_7_30 = 0x0000000000000008; + + // IR VERSION 9 published on May 5, 2023 + // Added AttributeProto to FunctionProto so that default attribute values can be set. + // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ. + IR_VERSION = 0x0000000000000009; } // Attributes @@ -95,17 +126,21 @@ message AttributeProto { STRING = 3; TENSOR = 4; GRAPH = 5; + SPARSE_TENSOR = 11; + TYPE_PROTO = 13; FLOATS = 6; INTS = 7; STRINGS = 8; TENSORS = 9; GRAPHS = 10; + SPARSE_TENSORS = 12; + TYPE_PROTOS = 14; } // The name field MUST be present for this version of the IR. optional string name = 1; // namespace Attribute - + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. // In this case, this AttributeProto does not contain data, and it's a reference of attribute // in parent scope. @@ -117,10 +152,10 @@ message AttributeProto { // The type field MUST be present for this version of the IR. // For 0.0.1 versions of the IR, this field was not defined, and - // implementations needed to use has_field hueristics to determine + // implementations needed to use has_field heuristics to determine // which value field was in use. For IR_VERSION 0.0.2 or later, this // field MUST be set and match the f|i|s|t|... field in use. This - // change was made to accomodate proto3 implementations. + // change was made to accommodate proto3 implementations. optional AttributeType type = 20; // discriminator that indicates which field below is in use // Exactly ONE of the following fields must be present for this version of the IR @@ -129,14 +164,18 @@ message AttributeProto { optional bytes s = 4; // UTF-8 string optional TensorProto t = 5; // tensor value optional GraphProto g = 6; // graph + optional SparseTensorProto sparse_tensor = 22; // sparse tensor value // Do not use field below, it's deprecated. // optional ValueProto v = 12; // value - subsumes everything but graph + optional TypeProto tp = 14; // type proto repeated float floats = 7; // list of floats repeated int64 ints = 8; // list of ints repeated bytes strings = 9; // list of UTF-8 strings repeated TensorProto tensors = 10; // list of tensors repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors + repeated TypeProto type_protos = 15;// list of type protos } // Defines information on value, including the name, the type, and @@ -144,7 +183,8 @@ message AttributeProto { message ValueInfoProto { // This field MUST be present in this version of the IR. optional string name = 1; // namespace Value - // This field MUST be present in this version of the IR. + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. optional TypeProto type = 2; // A human-readable documentation for this value. Markdown is allowed. optional string doc_string = 3; @@ -155,7 +195,7 @@ message ValueInfoProto { // Computation graphs are made up of a DAG of nodes, which represent what is // commonly called a "layer" or "pipeline stage" in machine learning frameworks. // -// For example, it can be a node of type "Conv" that takes in an image, a filter +// For example, it can be a node of type "Conv" that takes in an image, a filter // tensor and a bias tensor, and produces the convolved output. message NodeProto { repeated string input = 1; // namespace Value @@ -177,12 +217,130 @@ message NodeProto { optional string doc_string = 6; } +// Training information +// TrainingInfoProto stores information for training a model. +// In particular, this defines two functionalities: an initialization-step +// and a training-algorithm-step. Initialization resets the model +// back to its original state as if no training has been performed. +// Training algorithm improves the model based on input data. +// +// The semantics of the initialization-step is that the initializers +// in ModelProto.graph and in TrainingInfoProto.algorithm are first +// initialized as specified by the initializers in the graph, and then +// updated by the "initialization_binding" in every instance in +// ModelProto.training_info. +// +// The field "algorithm" defines a computation graph which represents a +// training algorithm's step. After the execution of a +// TrainingInfoProto.algorithm, the initializers specified by "update_binding" +// may be immediately updated. If the targeted training algorithm contains +// consecutive update steps (such as block coordinate descent methods), +// the user needs to create a TrainingInfoProto for each step. +message TrainingInfoProto { + // This field describes a graph to compute the initial tensors + // upon starting the training process. Initialization graph has no input + // and can have multiple outputs. Usually, trainable tensors in neural + // networks are randomly initialized. To achieve that, for each tensor, + // the user can put a random number operator such as RandomNormal or + // RandomUniform in TrainingInfoProto.initialization.node and assign its + // random output to the specific tensor using "initialization_binding". + // This graph can also set the initializers in "algorithm" in the same + // TrainingInfoProto; a use case is resetting the number of training + // iteration to zero. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Thus, no initializer would be changed by default. + optional GraphProto initialization = 1; + + // This field represents a training algorithm step. Given required inputs, + // it computes outputs to update initializers in its own or inference graph's + // initializer lists. In general, this field contains loss node, gradient node, + // optimizer node, increment of iteration count. + // + // An execution of the training algorithm step is performed by executing the + // graph obtained by combining the inference graph (namely "ModelProto.graph") + // and the "algorithm" graph. That is, the actual + // input/initializer/output/node/value_info/sparse_initializer list of + // the training graph is the concatenation of + // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer" + // and "algorithm.input/initializer/output/node/value_info/sparse_initializer" + // in that order. This combined graph must satisfy the normal ONNX conditions. + // Now, let's provide a visualization of graph combination for clarity. + // Let the inference graph (i.e., "ModelProto.graph") be + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d + // and the "algorithm" graph be + // tensor_d -> Add -> tensor_e + // The combination process results + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e + // + // Notice that an input of a node in the "algorithm" graph may reference the + // output of a node in the inference graph (but not the other way round). Also, inference + // node cannot reference inputs of "algorithm". With these restrictions, inference graph + // can always be run independently without training information. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Evaluating the default training step never + // update any initializers. + optional GraphProto algorithm = 2; + + // This field specifies the bindings from the outputs of "initialization" to + // some initializers in "ModelProto.graph.initializer" and + // the "algorithm.initializer" in the same TrainingInfoProto. + // See "update_binding" below for details. + // + // By default, this field is empty and no initializer would be changed + // by the execution of "initialization". + repeated StringStringEntryProto initialization_binding = 3; + + // Gradient-based training is usually an iterative procedure. In one gradient + // descent iteration, we apply + // + // x = x - r * g + // + // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is + // gradient of "x" with respect to a chosen loss. To avoid adding assignments + // into the training graph, we split the update equation into + // + // y = x - r * g + // x = y + // + // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To + // tell that "y" should be assigned to "x", the field "update_binding" may + // contain a key-value pair of strings, "x" (key of StringStringEntryProto) + // and "y" (value of StringStringEntryProto). + // For a neural network with multiple trainable (mutable) tensors, there can + // be multiple key-value pairs in "update_binding". + // + // The initializers appears as keys in "update_binding" are considered + // mutable variables. This implies some behaviors + // as described below. + // + // 1. We have only unique keys in all "update_binding"s so that two + // variables may not have the same name. This ensures that one + // variable is assigned up to once. + // 2. The keys must appear in names of "ModelProto.graph.initializer" or + // "TrainingInfoProto.algorithm.initializer". + // 3. The values must be output names of "algorithm" or "ModelProto.graph.output". + // 4. Mutable variables are initialized to the value specified by the + // corresponding initializer, and then potentially updated by + // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. + // + // This field usually contains names of trainable tensors + // (in ModelProto.graph), optimizer states such as momentums in advanced + // stochastic gradient methods (in TrainingInfoProto.graph), + // and number of training iterations (in TrainingInfoProto.graph). + // + // By default, this field is empty and no initializer would be changed + // by the execution of "algorithm". + repeated StringStringEntryProto update_binding = 4; +} + // Models // // ModelProto is a top-level file/container format for bundling a ML model and // associating its computation graph with metadata. // -// The semantics of the model are described by the associated GraphProto. +// The semantics of the model are described by the associated GraphProto's. message ModelProto { // The version of the IR this model targets. See Version enum above. // This field MUST be present. @@ -227,18 +385,58 @@ message ModelProto { // Named metadata values; keys should be distinct. repeated StringStringEntryProto metadata_props = 14; + + // Training-specific information. Sequentially executing all stored + // `TrainingInfoProto.algorithm`s and assigning their outputs following + // the corresponding `TrainingInfoProto.update_binding`s is one training + // iteration. Similarly, to initialize the model + // (as if training hasn't happened), the user should sequentially execute + // all stored `TrainingInfoProto.initialization`s and assigns their outputs + // using `TrainingInfoProto.initialization_binding`s. + // + // If this field is empty, the training behavior of the model is undefined. + repeated TrainingInfoProto training_info = 20; + + // A list of function protos local to the model. + // + // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain". + // In case of any conflicts the behavior (whether the model local functions are given higher priority, + // or standard operator sets are given higher priotity or this is treated as error) is defined by + // the runtimes. + // + // The operator sets imported by FunctionProto should be compatible with the ones + // imported by ModelProto and other model local FunctionProtos. + // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto + // or by 2 FunctionProtos then versions for the operator set may be different but, + // the operator schema returned for op_type, domain, version combination + // for both the versions should be same for every node in the function body. + // + // One FunctionProto can reference other FunctionProto in the model, however, recursive reference + // is not allowed. + repeated FunctionProto functions = 25; }; // StringStringEntryProto follows the pattern for cross-proto-version maps. // See https://developers.google.com/protocol-buffers/docs/proto3#maps message StringStringEntryProto { optional string key = 1; - optional string value= 2; + optional string value = 2; }; +message TensorAnnotation { + optional string tensor_name = 1; + // <key, value> pairs to annotate tensor specified by <tensor_name> above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + // Graphs // -// A graph defines the computational logic of a model and is comprised of a parameterized +// A graph defines the computational logic of a model and is comprised of a parameterized // list of nodes that form a directed acyclic graph based on their inputs and outputs. // This is the equivalent of the "network" or "graph" in many deep learning // frameworks. @@ -250,10 +448,14 @@ message GraphProto { optional string name = 2; // namespace Graph // A list of named tensor values, used to specify constant inputs of the graph. - // Each TensorProto entry must have a distinct name (within the list) that - // also appears in the input list. + // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name. + // The name MUST be unique across both initializer and sparse_initializer, + // but the name MAY also appear in the input list. repeated TensorProto initializer = 5; + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + // A human-readable documentation for this graph. Markdown is allowed. optional string doc_string = 10; @@ -265,13 +467,14 @@ message GraphProto { // must be distinct. It is optional for a value to appear in value_info list. repeated ValueInfoProto value_info = 13; - // DO NOT USE the following fields, they were deprecated from earlier versions. - // repeated string input = 3; - // repeated string output = 4; - // optional int64 ir_version = 6; - // optional int64 producer_version = 7; - // optional string producer_tag = 8; - // optional string domain = 9; + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + reserved 3, 4, 6 to 9; + reserved "ir_version", "producer_version", "producer_tag", "domain"; } // Tensors @@ -291,13 +494,32 @@ message TensorProto { STRING = 8; // string BOOL = 9; // bool - // Advanced types + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. FLOAT16 = 10; + DOUBLE = 11; UINT32 = 12; UINT64 = 13; COMPLEX64 = 14; // complex with float32 real and imaginary components COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Non-IEEE floating-point format based on papers + // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433, + // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf. + // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear. + // The computation usually happens inside a block quantize / dequantize + // fused by the runtime. + FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf + FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero + FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients + FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero + // Future extensions go here. } @@ -305,7 +527,8 @@ message TensorProto { repeated int64 dims = 1; // The data type of the tensor. - optional DataType data_type = 2; + // This field MUST have a valid TensorProto.DataType value + optional int32 data_type = 2; // For very large tensors, we may want to store them in chunks, in which // case the following fields will specify the segment that is stored in @@ -324,17 +547,17 @@ message TensorProto { // For float and complex64 values // Complex64 tensors are encoded as a single array of floats, // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the + // and the corresponding imaginary component appearing in the // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] // is encoded as [1.0, 2.0 ,3.0 ,4.0] // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. repeated float float_data = 4 [packed = true]; - // For int32, uint8, int8, uint16, int16, bool, and float16 values - // float16 values must be bit-wise converted to an uint16_t prior + // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values + // float16 and float8 values must be bit-wise converted to an uint16_t prior // to writing to the buffer. // When this field is present, the data_type field MUST be - // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32 + // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ repeated int32 int32_data = 5 [packed = true]; // For strings. @@ -371,10 +594,32 @@ message TensorProto { // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED optional bytes raw_data = 9; + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + optional DataLocation data_location = 14; + // For double - // Complex64 tensors are encoded as a single array of doubles, + // Complex128 tensors are encoded as a single array of doubles, // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the + // and the corresponding imaginary component appearing in the // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] // is encoded as [1.0, 2.0 ,3.0 ,4.0] // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -386,6 +631,30 @@ message TensorProto { repeated uint64 uint64_data = 11 [packed = true]; } +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + // values must have a non-empty name present which serves as a name for SparseTensorProto + // when used in sparse_initializer list. + optional TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + optional TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + // Defines a tensor shape. A dimension can be either an integer value // or a symbolic variable. A symbolic variable represents an unknown // dimension. @@ -398,36 +667,13 @@ message TensorShapeProto { // Standard denotation can optionally be used to denote tensor // dimensions with standard semantic descriptions to ensure // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. optional string denotation = 3; }; repeated Dimension dim = 1; } -// A set of pre-defined constants to be used as values for -// the standard denotation field in TensorShapeProto.Dimension -// for semantic description of the tensor dimension. -message DenotationConstProto { - // Describe a batch number dimension. - optional string DATA_BATCH = 1 [default = "DATA_BATCH"]; - // Describe a channel dimension. - optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"]; - // Describe a time dimension. - optional string DATA_TIME = 3 [default = "DATA_TIME"]; - // Describe a feature dimension. This is typically a feature - // dimension in RNN and/or spatial dimension in CNN. - optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"]; - // Describe a filter in-channel dimension. This is the dimension - // that is identical (in size) to the channel dimension of the input - // image feature maps. - optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"]; - // Describe a filter out channel dimension. This is the dimension - // that is identical (int size) to the channel dimension of the output - // image feature maps. - optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"]; - // Describe a filter spatial dimension. - optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"]; -} - // Types // // The standard ONNX data types. @@ -435,8 +681,43 @@ message TypeProto { message Tensor { // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + optional int32 elem_type = 1; + optional TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + optional TypeProto elem_type = 1; + }; + + // map<K,V> + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + optional int32 key_type = 1; + // This field MUST be present for this version of the IR. + optional TypeProto value_type = 2; + }; + + // wrapper for Tensor, Sequence, or Map + message Optional { + // The type and optional shape of the element wrapped. + // This field MUST be present for this version of the IR. + // Possible values correspond to OptionalProto.DataType enum + optional TypeProto elem_type = 1; + }; + + + message SparseTensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value // This field MUST be present for this version of the IR. - optional TensorProto.DataType elem_type = 1; + optional int32 elem_type = 1; optional TensorShapeProto shape = 2; } @@ -445,7 +726,31 @@ message TypeProto { // The type of a tensor. Tensor tensor_type = 1; + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + // The type of an optional. + Optional optional_type = 9; + + + // Type of the sparse tensor + SparseTensor sparse_tensor_type = 8; + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + optional string denotation = 6; } // Operator Sets @@ -461,4 +766,70 @@ message OperatorSetIdProto { // The version of the operator set being identified. // This field MUST be present in this version of the IR. optional int64 version = 2; -}
\ No newline at end of file +} + +// Operator/function status. +enum OperatorStatus { + EXPERIMENTAL = 0; + STABLE = 1; +} + +message FunctionProto { + // The name of the function, similar usage of op_type in OperatorProto. + // Combined with FunctionProto.domain, this forms the unique identity of + // the FunctionProto. + optional string name = 1; + + // Deprecated since IR Version 8 + // optional int64 since_version = 2; + reserved 2; + reserved "since_version"; + + // Deprecated since IR Version 8 + // optional OperatorStatus status = 3; + reserved 3; + reserved "status"; + + // The inputs and outputs of the function. + repeated string input = 4; + repeated string output = 5; + + // The attribute parameters of the function. + // It is for function parameters without default values. + repeated string attribute = 6; + + // The attribute protos of the function. + // It is for function attributes with default values. + // A function attribute shall be represented either as + // a string attribute or an AttributeProto, not both. + repeated AttributeProto attribute_proto = 11; + + // The nodes in the function. + repeated NodeProto node = 7; + // A human-readable documentation for this function. Markdown is allowed. + optional string doc_string = 8; + + // The OperatorSets this function body (graph) relies on. + // + // All nodes in the function body (graph) will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. This means at most one version can be relied + // for one domain. + // + // The operator sets imported by FunctionProto should be compatible with the ones + // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto + // and ModelProto then versions for the operator set may be different but, + // the operator schema returned for op_type, domain, version combination + // for both the versions should be same. + + repeated OperatorSetIdProto opset_import = 9; + + // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of + // the FunctionProto. + optional string domain = 10; +} + + +// For using protobuf-lite +option optimize_for = LITE_RUNTIME; + diff --git a/config-model/src/main/resources/schema/content.rnc b/config-model/src/main/resources/schema/content.rnc index 5833b575a74..a73236454c6 100644 --- a/config-model/src/main/resources/schema/content.rnc +++ b/config-model/src/main/resources/schema/content.rnc @@ -370,7 +370,8 @@ Tuning = element tuning { element threads { xsd:nonNegativeInteger }? }? & element feeding { - element concurrency { xsd:double { minInclusive = "0.0" maxInclusive = "1.0" } }? + element concurrency { xsd:double { minInclusive = "0.0" maxInclusive = "1.0" } }? & + element niceness { xsd:double { minInclusive = "0.0" maxInclusive = "1.0" } }? }? & element removed-db { element prune { diff --git a/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java b/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java index 84804bc48fa..19fe9e0038d 100644 --- a/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java @@ -1183,9 +1183,9 @@ public class ModelProvisioningTest { ContentCluster cluster = model.getContentClusters().get("bar"); List<StorageGroup> subGroups = cluster.getRootGroup().getSubgroups(); - assertEquals(2*3, cluster.redundancy().effectiveInitialRedundancy()); // Reduced from 3*3 - assertEquals(2*3, cluster.redundancy().effectiveFinalRedundancy()); // Reduced from 3*4 - assertEquals(2*3, cluster.redundancy().effectiveReadyCopies()); // Reduced from 3*3 + assertEquals(2*3, cluster.getRedundancy().effectiveInitialRedundancy()); // Reduced from 3*3 + assertEquals(2*3, cluster.getRedundancy().effectiveFinalRedundancy()); // Reduced from 3*4 + assertEquals(2*3, cluster.getRedundancy().effectiveReadyCopies()); // Reduced from 3*3 assertEquals("2|2|*", cluster.getRootGroup().getPartitions().get()); // Reduced from 4|4|* assertEquals(0, cluster.getRootGroup().getNodes().size()); assertEquals(3, subGroups.size()); @@ -1257,9 +1257,9 @@ public class ModelProvisioningTest { assertEquals(numberOfHosts, model.getRoot().hostSystem().getHosts().size()); ContentCluster cluster = model.getContentClusters().get("bar"); - assertEquals(2, cluster.redundancy().effectiveInitialRedundancy()); - assertEquals(2, cluster.redundancy().effectiveFinalRedundancy()); - assertEquals(2, cluster.redundancy().effectiveReadyCopies()); + assertEquals(2, cluster.getRedundancy().effectiveInitialRedundancy()); + assertEquals(2, cluster.getRedundancy().effectiveFinalRedundancy()); + assertEquals(2, cluster.getRedundancy().effectiveReadyCopies()); assertEquals("1|*", cluster.getRootGroup().getPartitions().get()); assertEquals(0, cluster.getRootGroup().getNodes().size()); assertEquals(2, cluster.getRootGroup().getSubgroups().size()); @@ -1287,9 +1287,9 @@ public class ModelProvisioningTest { ContentCluster cluster = model.getContentClusters().get("bar"); assertEquals(2, cluster.getStorageCluster().getChildren().size()); - assertEquals(1, cluster.redundancy().effectiveInitialRedundancy()); - assertEquals(1, cluster.redundancy().effectiveFinalRedundancy()); - assertEquals(1, cluster.redundancy().effectiveReadyCopies()); + assertEquals(1, cluster.getRedundancy().effectiveInitialRedundancy()); + assertEquals(1, cluster.getRedundancy().effectiveFinalRedundancy()); + assertEquals(1, cluster.getRedundancy().effectiveReadyCopies()); assertEquals(2, cluster.getRootGroup().getNodes().size()); assertEquals(0, cluster.getRootGroup().getSubgroups().size()); } @@ -1324,9 +1324,9 @@ public class ModelProvisioningTest { assertEquals(numberOfHosts, model.getRoot().hostSystem().getHosts().size()); ContentCluster cluster = model.getContentClusters().get("bar"); - assertEquals(4, cluster.redundancy().effectiveInitialRedundancy()); - assertEquals(4, cluster.redundancy().effectiveFinalRedundancy()); - assertEquals(4, cluster.redundancy().effectiveReadyCopies()); + assertEquals(4, cluster.getRedundancy().effectiveInitialRedundancy()); + assertEquals(4, cluster.getRedundancy().effectiveFinalRedundancy()); + assertEquals(4, cluster.getRedundancy().effectiveReadyCopies()); assertEquals(4, cluster.getSearch().getIndexed().getDispatchSpec().getGroups().size()); assertEquals(4, cluster.getSearch().getIndexed().getSearchableCopies()); assertFalse(cluster.getRootGroup().getPartitions().isPresent()); @@ -1368,9 +1368,9 @@ public class ModelProvisioningTest { assertEquals(numberOfHosts, model.getRoot().hostSystem().getHosts().size()); ContentCluster cluster = model.getContentClusters().get("bar"); - assertEquals(1, cluster.redundancy().effectiveInitialRedundancy()); // Reduced from 3*3 - assertEquals(1, cluster.redundancy().effectiveFinalRedundancy()); // Reduced from 3*4 - assertEquals(1, cluster.redundancy().effectiveReadyCopies()); // Reduced from 3*3 + assertEquals(1, cluster.getRedundancy().effectiveInitialRedundancy()); // Reduced from 3*3 + assertEquals(1, cluster.getRedundancy().effectiveFinalRedundancy()); // Reduced from 3*4 + assertEquals(1, cluster.getRedundancy().effectiveReadyCopies()); // Reduced from 3*3 assertFalse(cluster.getRootGroup().getPartitions().isPresent()); // 1 group - > flattened -> no distribution assertEquals(1, cluster.getRootGroup().getNodes().size()); assertEquals(0, cluster.getRootGroup().getSubgroups().size()); @@ -1473,9 +1473,9 @@ public class ModelProvisioningTest { assertEquals(numberOfHosts, model.getRoot().hostSystem().getHosts().size()); ContentCluster cluster = model.getContentClusters().get("bar"); - assertEquals(1, cluster.redundancy().effectiveInitialRedundancy()); - assertEquals(1, cluster.redundancy().effectiveFinalRedundancy()); - assertEquals(1, cluster.redundancy().effectiveReadyCopies()); + assertEquals(1, cluster.getRedundancy().effectiveInitialRedundancy()); + assertEquals(1, cluster.getRedundancy().effectiveFinalRedundancy()); + assertEquals(1, cluster.getRedundancy().effectiveReadyCopies()); assertEquals(1, cluster.getSearch().getIndexed().getDispatchSpec().getGroups().size()); assertFalse(cluster.getRootGroup().getPartitions().isPresent()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomSchemaTuningBuilderTest.java b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomSchemaTuningBuilderTest.java index e3e9fc1a232..db15d7e0a78 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomSchemaTuningBuilderTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/builder/xml/dom/DomSchemaTuningBuilderTest.java @@ -285,9 +285,13 @@ public class DomSchemaTuningBuilderTest extends DomBuilderTest { void requireThatWeCanParseFeedingTag() { Tuning t = createTuning(parseXml("<feeding>", "<concurrency>0.7</concurrency>", + "<niceness>0.3</niceness>", "</feeding>")); assertEquals(0.7, t.searchNode.feeding.concurrency, DELTA); - assertEquals(getProtonCfg(t).feeding().concurrency(), 0.7, DELTA); + assertEquals(0.3, t.searchNode.feeding.niceness, DELTA); + var cfg = getProtonCfg(t); + assertEquals(cfg.feeding().concurrency(), 0.7, DELTA); + assertEquals(cfg.feeding().niceness(), 0.3, DELTA); } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java index 4ce7119f5f7..73bbd6ee464 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java @@ -38,6 +38,7 @@ import com.yahoo.vespa.model.routing.DocumentProtocol; import com.yahoo.vespa.model.routing.Routing; import com.yahoo.vespa.model.test.utils.ApplicationPackageUtils; import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithMockPkg; +import com.yahoo.yolean.Exceptions; import org.junit.jupiter.api.Test; import java.util.Arrays; @@ -470,7 +471,8 @@ public class ContentClusterTest extends ContentBaseTest { new VespaModelCreatorWithMockPkg(getHosts(), xml, sds).create(); fail("Deploying without redundancy should fail"); } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("Either <redundancy> or <min-redundancy> must be set"), e.getMessage()); + assertEquals("In content cluster 'bar': Either <redundancy> or <min-redundancy> must be set", + Exceptions.toMessageString(e)); } } @@ -478,12 +480,13 @@ public class ContentClusterTest extends ContentBaseTest { void testRedundancyFinalLessThanInitial() { try { parse( - "<content version=\"1.0\" id=\"storage\">\n" + - " <redundancy reply-after=\"4\">2</redundancy>\n" + - " <group>" + - " <node hostalias='node0' distribution-key='0' />" + - " </group>" + - "</content>" + """ + <content version="1.0" id="storage"> + <redundancy reply-after="4">2</redundancy> + <group> + <node hostalias='node0' distribution-key='0' /> + </group> + </content>""" ); fail("no exception thrown"); } catch (Exception e) { /* ignore */ @@ -494,17 +497,18 @@ public class ContentClusterTest extends ContentBaseTest { void testReadyTooHigh() { try { parse( - "<content version=\"1.0\" id=\"storage\">\n" + - " <engine>" + - " <proton>" + - " <searchable-copies>3</searchable-copies>" + - " </proton>" + - " </engine>" + - " <redundancy>2</redundancy>\n" + - " <group>" + - " <node hostalias='node0' distribution-key='0' />" + - " </group>" + - "</content>" + """ + <content version="1.0" id="storage"> + <engine> + <proton> + <searchable-copies>3</searchable-copies> + </proton> + </engine> + <redundancy>2</redundancy> + <group> + <node hostalias='node0' distribution-key='0' /> + </group> + </content>""" ); fail("no exception thrown"); } catch (Exception e) { /* ignore */ @@ -972,15 +976,17 @@ public class ContentClusterTest extends ContentBaseTest { @Test void reserved_document_name_throws_exception() { - String xml = "<content version=\"1.0\" id=\"storage\">" + - " <redundancy>1</redundancy>" + - " <documents>" + - " <document type=\"true\" mode=\"index\"/>" + - " </documents>" + - " <group>" + - " <node distribution-key=\"0\" hostalias=\"mockhost\"/>" + - " </group>" + - "</content>"; + String xml = """ + <content version="1.0" id="storage"> + <redundancy>1</redundancy> + <documents> + <document type="true" mode="index"/> + </documents> + <group> + <node distribution-key="0" hostalias="mockhost"/> + </group> + </content> + """; List<String> sds = ApplicationPackageUtils.generateSchemas("true"); try { @@ -991,6 +997,65 @@ public class ContentClusterTest extends ContentBaseTest { } } + @Test + void default_searchable_copies_indexing() { + String services = """ + <content version="1.0" id="storage"> + <redundancy>3</redundancy> + <documents> + <document type="music" mode="index"/> + </documents> + <group> + <node distribution-key="0" hostalias="mockhost"/> + <node distribution-key="1" hostalias="mockhost"/> + <node distribution-key="2" hostalias="mockhost"/> + </group> + </content> + """; + var model = new VespaModelCreatorWithMockPkg(null, services, ApplicationPackageUtils.generateSchemas("music")).create(); + assertEquals(2, model.getContentClusters().get("storage").getRedundancy().readyCopies()); + } + + @Test + void default_searchable_copies_streaming() { + String services = """ + <content version="1.0" id="storage"> + <redundancy>3</redundancy> + <documents> + <document type="mail" mode="streaming"/> + </documents> + <group> + <node distribution-key="0" hostalias="mockhost"/> + <node distribution-key="1" hostalias="mockhost"/> + <node distribution-key="2" hostalias="mockhost"/> + </group> + </content> + """; + var model = new VespaModelCreatorWithMockPkg(null, services, ApplicationPackageUtils.generateSchemas("mail")).create(); + assertEquals(3, model.getContentClusters().get("storage").getRedundancy().readyCopies()); + } + + /** Here there is no good choice. */ + @Test + void default_searchable_copies_mixed() { + String services = """ + <content version="1.0" id="storage"> + <redundancy>3</redundancy> + <documents> + <document type="music" mode="index"/> + <document type="mail" mode="streaming"/> + </documents> + <group> + <node distribution-key="0" hostalias="mockhost"/> + <node distribution-key="1" hostalias="mockhost"/> + <node distribution-key="2" hostalias="mockhost"/> + </group> + </content> + """; + var model = new VespaModelCreatorWithMockPkg(null, services, ApplicationPackageUtils.generateSchemas("music", "mail")).create(); + assertEquals(2, model.getContentClusters().get("storage").getRedundancy().readyCopies()); + } + private void assertClusterHasBucketSpaceMappings(AllClustersBucketSpacesConfig config, String clusterId, List<String> defaultSpaceTypes, List<String> globalSpaceTypes) { AllClustersBucketSpacesConfig.Cluster cluster = config.cluster(clusterId); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index 36f09f989a7..b662179c418 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -24,7 +24,6 @@ import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.DataplaneToken; import com.yahoo.config.provision.DockerImage; import com.yahoo.config.provision.HostName; -import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.Zone; import com.yahoo.container.jdisc.secretstore.SecretStore; import com.yahoo.vespa.config.server.tenant.SecretStoreExternalIdRetriever; @@ -34,6 +33,7 @@ import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.flags.PermanentFlags; import com.yahoo.vespa.flags.StringFlag; import com.yahoo.vespa.flags.UnboundFlag; + import java.io.File; import java.net.URI; import java.security.cert.X509Certificate; @@ -319,13 +319,7 @@ public class ModelContextImpl implements ModelContext { return flag.bindTo(source) .with(FetchVector.Dimension.APPLICATION_ID, appId.serializedForm()) .with(FetchVector.Dimension.VESPA_VERSION, vespaVersion.toFullString()) - .boxedValue(); - } - - private static <V> V flagValue(FlagSource source, TenantName tenant, Version vespaVersion, UnboundFlag<? extends V, ?, ?> flag) { - return flag.bindTo(source) - .with(FetchVector.Dimension.TENANT_ID, tenant.value()) - .with(FetchVector.Dimension.VESPA_VERSION, vespaVersion.toFullString()) + .with(FetchVector.Dimension.TENANT_ID, appId.tenant().value()) .boxedValue(); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java index 0532a81617f..b2762b2a3d4 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java @@ -98,7 +98,7 @@ public class ApplicationApiHandler extends SessionHandler { "Unable to parse multipart in deploy from tenant '" + tenantName.value() + "': " + Exceptions.toMessageString(e)); var message = "Deploy request from '" + tenantName.value() + "' contains invalid data: " + e.getMessage(); - log.log(INFO, message + ", parts: " + parts, e); + log.log(FINE, message + ", parts: " + parts, e); throw new BadRequestException("Deploy request from '" + tenantName.value() + "' contains invalid data: " + e.getMessage()); } } else { diff --git a/container-core/abi-spec.json b/container-core/abi-spec.json index 236586a4132..572d18b02f3 100644 --- a/container-core/abi-spec.json +++ b/container-core/abi-spec.json @@ -1044,6 +1044,7 @@ "public com.yahoo.jdisc.http.ConnectorConfig$Builder requestHeaderSize(int)", "public com.yahoo.jdisc.http.ConnectorConfig$Builder responseHeaderSize(int)", "public com.yahoo.jdisc.http.ConnectorConfig$Builder acceptQueueSize(int)", + "public com.yahoo.jdisc.http.ConnectorConfig$Builder maxContentSize(long)", "public com.yahoo.jdisc.http.ConnectorConfig$Builder reuseAddress(boolean)", "public com.yahoo.jdisc.http.ConnectorConfig$Builder idleTimeout(double)", "public com.yahoo.jdisc.http.ConnectorConfig$Builder tcpKeepAliveEnabled(boolean)", @@ -1413,6 +1414,7 @@ "public int requestHeaderSize()", "public int responseHeaderSize()", "public int acceptQueueSize()", + "public long maxContentSize()", "public boolean reuseAddress()", "public double idleTimeout()", "public boolean tcpKeepAliveEnabled()", diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java b/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java index e1ec22bd622..af98e380f2a 100644 --- a/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java +++ b/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.container.jdisc.state; +import ai.vespa.metrics.ContainerMetrics; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; @@ -289,7 +290,7 @@ public class StateHandler extends AbstractRequestHandler implements CapabilityRe Tuple latencySeconds = new Tuple(NULL_DIMENSIONS, "latencySeconds", null); for (Map.Entry<MetricDimensions, MetricSet> entry : snapshot) { MetricSet metricSet = entry.getValue(); - MetricValue val = metricSet.get("serverTotalSuccessfulResponseLatency"); + MetricValue val = metricSet.get(ContainerMetrics.SERVER_TOTAL_SUCCESFUL_RESPONSE_LATENCY.baseName()); if (val instanceof GaugeMetric gauge) { latencySeconds.add(GaugeMetric.newInstance(gauge.getLast() / 1000, gauge.getMax() / 1000, @@ -297,7 +298,7 @@ public class StateHandler extends AbstractRequestHandler implements CapabilityRe gauge.getSum() / 1000, gauge.getCount())); } - requestsPerSecond.add(metricSet.get("serverNumSuccessfulResponses")); + requestsPerSecond.add(metricSet.get(ContainerMetrics.SERVER_NUM_SUCCESSFUL_RESPONSES.baseName())); } List<Tuple> lst = new ArrayList<>(); if (requestsPerSecond.val != null) { diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java index 2f2c48e0b48..75ef655c60c 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/ServletRequestReader.java @@ -1,16 +1,19 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jdisc.http.server.jetty; +import com.yahoo.jdisc.Response; import com.yahoo.jdisc.handler.CompletionHandler; import com.yahoo.jdisc.handler.ContentChannel; import jakarta.servlet.ReadListener; import jakarta.servlet.ServletInputStream; import jakarta.servlet.http.HttpServletRequest; +import org.eclipse.jetty.server.Request; import java.io.IOException; import java.nio.ByteBuffer; import java.util.Objects; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; import java.util.logging.Level; import java.util.logging.Logger; @@ -33,6 +36,7 @@ import java.util.logging.Logger; */ class ServletRequestReader { + private enum State { NOT_STARTED, READING, ALL_DATA_READ, REQUEST_CONTENT_CLOSED } @@ -96,12 +100,15 @@ class ServletRequestReader { private final CompletableFuture<Void> finishedFuture = new CompletableFuture<>(); ServletRequestReader( - HttpServletRequest req, + Request req, ContentChannel requestContentChannel, Janitor janitor, RequestMetricReporter metricReporter) { this.req = Objects.requireNonNull(req); - this.requestContentChannel = Objects.requireNonNull(requestContentChannel); + long maxContentSize = RequestUtils.getConnector(req).connectorConfig().maxContentSize(); + this.requestContentChannel = maxContentSize >= 0 + ? new ByteLimitedContentChannel(Objects.requireNonNull(requestContentChannel), maxContentSize) + : Objects.requireNonNull(requestContentChannel); this.janitor = Objects.requireNonNull(janitor); this.metricReporter = Objects.requireNonNull(metricReporter); } @@ -259,4 +266,30 @@ class ServletRequestReader { } } + private static class ByteLimitedContentChannel implements ContentChannel { + private final long maxContentSize; + private final AtomicLong bytesWritten = new AtomicLong(); + private final ContentChannel delegate; + + ByteLimitedContentChannel(ContentChannel delegate, long maxContentSize) { + this.delegate = delegate; + this.maxContentSize = maxContentSize; + } + + @Override + public void write(ByteBuffer buf, CompletionHandler handler) { + long written = bytesWritten.addAndGet(buf.remaining()); + if (written > maxContentSize) { + handler.failed(new RequestException( + Response.Status.REQUEST_TOO_LONG, + "Request content length %d exceeds limit of %d bytes".formatted(written, maxContentSize))); + return; + } + delegate.write(buf, handler); + } + + @Override public void close(CompletionHandler h) { delegate.close(h); } + @Override public void onError(Throwable t) { delegate.onError(t); } + } + } diff --git a/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def b/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def index bdcc3f9e40a..3c01012fd9e 100644 --- a/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def +++ b/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def @@ -22,6 +22,9 @@ responseHeaderSize int default=65536 # The accept queue size (also known as accept backlog). acceptQueueSize int default=0 +# Max content size allowed for requests. Set to -1 to disable. +maxContentSize long default=-1 + # Whether the server socket reuses addresses. reuseAddress bool default=true diff --git a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerTest.java b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerTest.java index 0a697bd8fb3..6f9c854be64 100644 --- a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerTest.java +++ b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/HttpServerTest.java @@ -70,6 +70,7 @@ import static com.yahoo.jdisc.Response.Status.GATEWAY_TIMEOUT; import static com.yahoo.jdisc.Response.Status.INTERNAL_SERVER_ERROR; import static com.yahoo.jdisc.Response.Status.NOT_FOUND; import static com.yahoo.jdisc.Response.Status.OK; +import static com.yahoo.jdisc.Response.Status.REQUEST_TOO_LONG; import static com.yahoo.jdisc.Response.Status.REQUEST_URI_TOO_LONG; import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED; import static com.yahoo.jdisc.Response.Status.UNSUPPORTED_MEDIA_TYPE; @@ -172,11 +173,11 @@ public class HttpServerTest { driver.client() .newGet("/status.html").addHeader("Host", "localhost").addHeader("Host", "vespa.ai").execute() .expectStatusCode(is(BAD_REQUEST)).expectContent(containsString("reason: Duplicate Host Header")); - assertTrue(driver.close()); var aggregator = ResponseMetricAggregator.getBean(driver.server()); var metric = waitForStatistics(aggregator); assertEquals(400, metric.dimensions.statusCode); assertEquals("GET", metric.dimensions.method); + assertTrue(driver.close()); } @Test @@ -795,6 +796,17 @@ public class HttpServerTest { assertTrue(driver.close()); } + @Test + void exceedingMaxContentSizeReturns413() throws IOException { + JettyTestDriver driver = JettyTestDriver.newConfiguredInstance( + new EchoRequestHandler(), + new ServerConfig.Builder(), + new ConnectorConfig.Builder().maxContentSize(4)); + driver.client().newPost("/").setBinaryContent(new byte[4]).execute().expectStatusCode(is(OK)); + driver.client().newPost("/").setBinaryContent(new byte[5]).execute().expectStatusCode(is(REQUEST_TOO_LONG)); + assertTrue(driver.close()); + } + private static JettyTestDriver createSslWithTlsClientAuthenticationEnforcer(Path certificateFile, Path privateKeyFile) { ConnectorConfig.Builder connectorConfig = new ConnectorConfig.Builder() .tlsClientAuthEnforcer( diff --git a/container-dependencies-enforcer/pom.xml b/container-dependencies-enforcer/pom.xml index d63a1867f9a..a016c44f829 100644 --- a/container-dependencies-enforcer/pom.xml +++ b/container-dependencies-enforcer/pom.xml @@ -60,8 +60,6 @@ <rules> <enforceDependencies implementation="com.yahoo.vespa.maven.plugin.enforcer.EnforceDependencies"> <allowed> - <include>*:*:*:test</include> - <include>com.yahoo.vespa:*:*:*</include> <include>aopalliance:aopalliance:${aopalliance.version}:provided</include> <include>com.fasterxml.jackson.core:jackson-annotations:${jackson2.version}:provided</include> <include>com.fasterxml.jackson.core:jackson-core:${jackson2.version}:provided</include> @@ -89,6 +87,121 @@ <include>org.slf4j:slf4j-api:${slf4j.version}:provided</include> <include>org.slf4j:slf4j-jdk14:${slf4j.version}:provided</include> <include>xml-apis:xml-apis:${xml-apis.version}:provided</include> + + <!-- Vespa provided dependencies --> + <include>com.yahoo.vespa:annotations:*:provided</include> + <include>com.yahoo.vespa:component:*:provided</include> + <include>com.yahoo.vespa:config-bundle:*:provided</include> + <include>com.yahoo.vespa:config-lib:*:provided</include> + <include>com.yahoo.vespa:config:*:provided</include> + <include>com.yahoo.vespa:configdefinitions:*:provided</include> + <include>com.yahoo.vespa:configgen:*:provided</include> + <include>com.yahoo.vespa:container-core:*:provided</include> + <include>com.yahoo.vespa:container-dev:*:provided</include> + <include>com.yahoo.vespa:container-disc:*:provided</include> + <include>com.yahoo.vespa:container-documentapi:*:provided</include> + <include>com.yahoo.vespa:container-messagebus:*:provided</include> + <include>com.yahoo.vespa:container-onnxruntime:*:provided</include> + <include>com.yahoo.vespa:container-search-and-docproc:*:provided</include> + <include>com.yahoo.vespa:container-search:*:provided</include> + <include>com.yahoo.vespa:container:*:provided</include> + <include>com.yahoo.vespa:defaults:*:provided</include> + <include>com.yahoo.vespa:docproc:*:provided</include> + <include>com.yahoo.vespa:document:*:provided</include> + <include>com.yahoo.vespa:documentapi:*:provided</include> + <include>com.yahoo.vespa:fileacquirer:*:provided</include> + <include>com.yahoo.vespa:fsa:*:provided</include> + <include>com.yahoo.vespa:hosted-zone-api:*:provided</include> + <include>com.yahoo.vespa:http-utils:*:provided</include> + <include>com.yahoo.vespa:jdisc_core:*:provided</include> + <include>com.yahoo.vespa:jrt:*:provided</include> + <include>com.yahoo.vespa:linguistics:*:provided</include> + <include>com.yahoo.vespa:messagebus:*:provided</include> + <include>com.yahoo.vespa:metrics:*:provided</include> + <include>com.yahoo.vespa:model-evaluation:*:provided</include> + <include>com.yahoo.vespa:opennlp-linguistics:*:provided</include> + <include>com.yahoo.vespa:predicate-search-core:*:provided</include> + <include>com.yahoo.vespa:provided-dependencies:*:provided</include> + <include>com.yahoo.vespa:searchcore:*:provided</include> + <include>com.yahoo.vespa:searchlib:*:provided</include> + <include>com.yahoo.vespa:security-utils:*:provided</include> + <include>com.yahoo.vespa:vdslib:*:provided</include> + <include>com.yahoo.vespa:vespa-3party-bundles:pom:*:provided</include> + <include>com.yahoo.vespa:vespaclient-container-plugin:*:provided</include> + <include>com.yahoo.vespa:vespajlib:*:provided</include> + <include>com.yahoo.vespa:vespalog:*:provided</include> + + <!-- Vespa test dependencies --> + <include>com.yahoo.vespa:airlift-zstd:*:test</include> + <include>com.yahoo.vespa:application:*:test</include> + <include>com.yahoo.vespa:config-application-package:*:test</include> + <include>com.yahoo.vespa:config-model-api:*:test</include> + <include>com.yahoo.vespa:config-model:*:test</include> + <include>com.yahoo.vespa:config-provisioning:*:test</include> + <include>com.yahoo.vespa:container-apache-http-client-bundle:*:test</include> + <include>com.yahoo.vespa:container-test:*:test</include> + <include>com.yahoo.vespa:indexinglanguage:*:test</include> + <include>com.yahoo.vespa:logd:*:test</include> + <include>com.yahoo.vespa:metrics-proxy:*:test</include> + <include>com.yahoo.vespa:model-integration:*:test</include> + <include>com.yahoo.vespa:searchsummary:*:test</include> + <include>com.yahoo.vespa:standalone-container:*:test</include> + <include>com.yahoo.vespa:storage:*:test</include> + <include>com.yahoo.vespa:vespaclient-core:*:test</include> + <include>com.yahoo.vespa:vsm:*:test</include> + + <!-- 3rd party test dependencies --> + <include>com.google.code.findbugs:jsr305:3.0.2:test</include> + <include>com.google.protobuf:protobuf-java:${protobuf.version}:test</include> + <include>com.ibm.icu:icu4j:70.1:test</include> + <include>com.microsoft.onnxruntime:onnxruntime:${onnxruntime.version}:test</include> + <include>com.thaiopensource:jing:20091111:test</include> + <include>commons-codec:commons-codec:${commons-codec.version}:test</include> + <include>io.airlift:airline:0.9:test</include> + <include>io.prometheus:simpleclient:0.6.0:test</include> + <include>io.prometheus:simpleclient_common:0.6.0:test</include> + <include>junit:junit:4.13.2:test</include> + <include>net.java.dev.jna:jna:5.11.0:test</include> + <include>net.openhft:zero-allocation-hashing:jar:0.16:test</include> + <include>org.antlr:antlr-runtime:3.5.3:test</include> + <include>org.antlr:antlr4-runtime:4.11.1:test</include> + <include>org.apache.commons:commons-exec:1.3:test</include> + <include>org.apache.commons:commons-math3:3.6.1:test</include> + <include>org.apache.felix:org.apache.felix.framework:${felix.version}:test</include> + <include>org.apache.felix:org.apache.felix.framework:${felix.version}:test</include> + <include>org.apache.felix:org.apache.felix.log:1.0.1:test</include> + <include>org.apache.httpcomponents.client5:httpclient5:${apache.httpclient5.version}:test</include> + <include>org.apache.httpcomponents.core5:httpcore5:${apache.httpcore5.version}:test</include> + <include>org.apache.httpcomponents.core5:httpcore5-h2:${apache.httpcore5.version}:test</include> + <include>org.apache.httpcomponents:httpclient:${apache.httpclient.version}:test</include> + <include>org.apache.httpcomponents:httpcore:${apache.httpcore.version}:test</include> + <include>org.apache.httpcomponents:httpmime:${apache.httpclient.version}:test</include> + <include>org.apache.opennlp:opennlp-tools:1.9.3:test</include> + <include>org.bouncycastle:bcpkix-jdk18on:${bouncycastle.version}:test</include> + <include>org.bouncycastle:bcprov-jdk18on:${bouncycastle.version}:test</include> + <include>org.bouncycastle:bcutil-jdk18on:${bouncycastle.version}:test</include> + <include>org.eclipse.jetty.http2:http2-common:${jetty.version}:test</include> + <include>org.eclipse.jetty.http2:http2-hpack:${jetty.version}:test</include> + <include>org.eclipse.jetty.http2:http2-server:${jetty.version}:test</include> + <include>org.eclipse.jetty.toolchain:jetty-jakarta-servlet-api:5.0.2:test</include> + <include>org.eclipse.jetty:jetty-alpn-client:${jetty.version}:test</include> + <include>org.eclipse.jetty:jetty-alpn-java-server:${jetty.version}:test</include> + <include>org.eclipse.jetty:jetty-alpn-server:${jetty.version}:test</include> + <include>org.eclipse.jetty:jetty-client:${jetty.version}:test</include> + <include>org.eclipse.jetty:jetty-http:${jetty.version}:test</include> + <include>org.eclipse.jetty:jetty-io:${jetty.version}:test</include> + <include>org.eclipse.jetty:jetty-jmx:${jetty.version}:test</include> + <include>org.eclipse.jetty:jetty-security:${jetty.version}:test</include> + <include>org.eclipse.jetty:jetty-server:${jetty.version}:test</include> + <include>org.eclipse.jetty:jetty-servlet:${jetty.version}:test</include> + <include>org.eclipse.jetty:jetty-util:${jetty.version}:test</include> + <include>org.hamcrest:hamcrest-core:1.3:test</include> + <include>org.hdrhistogram:HdrHistogram:2.1.12:test</include> + <include>org.json:json:${org.json.version}:test</include> <!-- TODO: Remove on Vespa 9 --> + <include>org.lz4:lz4-java:${org.lz4.version}:test</include> + <include>org.osgi:org.osgi.compendium:4.1.0:test</include> + <include>org.osgi:org.osgi.core:4.1.0:test</include> + <include>xerces:xercesImpl:2.12.2:test</include> </allowed> </enforceDependencies> </rules> diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java b/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java index e6af65c0bc8..47050168b80 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java @@ -36,8 +36,8 @@ public class DataplaneProxyService extends AbstractComponent { private final Path root; enum NginxState {INITIALIZING, RUNNING, RELOAD_REQUIRED, STOPPED}; - private NginxState state; - private NginxState wantedState; + private volatile NginxState state; + private volatile NginxState wantedState; private DataplaneProxyConfig cfg; private Path proxyCredentialsCert; @@ -113,35 +113,46 @@ public class DataplaneProxyService extends AbstractComponent { throw new RuntimeException("Error reconfiguring data plane proxy", e); } } - if (wantedState == NginxState.RUNNING) { + NginxState convergeTo = wantedState; + if (convergeTo == NginxState.RUNNING) { boolean nginxRunning = proxyCommands.isRunning(); if (!nginxRunning) { try { proxyCommands.start(nginxConf); - changeState(wantedState); + changeState(convergeTo); } catch (Exception e) { logger.log(Level.INFO, "Failed to start nginx, will retry"); + logger.log(Level.FINE, "Exception from nginx start", e); } - } else if (nginxRunning && state == NginxState.RELOAD_REQUIRED) { - try { - proxyCommands.reload(); - changeState(wantedState); - } catch (Exception e) { - logger.log(Level.INFO, "Failed to reconfigure nginx, will retry."); + } else { + if (state == NginxState.RELOAD_REQUIRED) { + try { + proxyCommands.reload(); + changeState(convergeTo); + } catch (Exception e) { + logger.log(Level.INFO, "Failed to reconfigure nginx, will retry."); + logger.log(Level.FINE, "Exception from nginx reload", e); + } + } else if (state != convergeTo) { + // Already running, but state not updated + changeState(convergeTo); } } - } else if (wantedState == NginxState.STOPPED) { + } else if (convergeTo == NginxState.STOPPED) { if (proxyCommands.isRunning()) { try { proxyCommands.stop(); - changeState(wantedState); - executorService.shutdownNow(); } catch (Exception e) { logger.log(Level.INFO, "Failed to stop nginx, will retry"); + logger.log(Level.FINE, "Exception from nginx stop", e); } } + if (! proxyCommands.isRunning()) { + changeState(convergeTo); + executorService.shutdownNow(); + } } else { - logger.warning("Unknown state " + wantedState); + logger.warning("Unknown state " + convergeTo); } } @@ -150,9 +161,9 @@ public class DataplaneProxyService extends AbstractComponent { super.deconstruct(); wantedState = NginxState.STOPPED; try { - executorService.awaitTermination(5, TimeUnit.MINUTES); + executorService.awaitTermination(30, TimeUnit.SECONDS); } catch (InterruptedException e) { - logger.log(Level.WARNING, "Error shutting down proxy reload thread"); + logger.log(Level.WARNING, "Error shutting down proxy reload thread", e); } } @@ -203,10 +214,12 @@ public class DataplaneProxyService extends AbstractComponent { return template.replaceAll("\\$\\{%s\\}".formatted(key), value); } + // Used for testing NginxState state() { return state; } + // Used for testing NginxState wantedState() { return wantedState; } diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java index 3e6ee3a35a2..501438c8c2e 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java @@ -2,6 +2,7 @@ package com.yahoo.container.jdisc.metric; import ai.vespa.metrics.ContainerMetrics; +import ai.vespa.metrics.ContainerMetrics; import com.yahoo.component.AbstractComponent; import com.yahoo.component.annotation.Inject; import com.yahoo.jdisc.Metric; @@ -141,7 +142,7 @@ public class MetricUpdater extends AbstractComponent { "home", System.getProperty("java.home"), "vendor", System.getProperty("java.vm.vendor"), "arch", System.getProperty("os.arch"))); - metric.set("jdisc.jvm", Runtime.version().feature(), ctx); + metric.set(ContainerMetrics.JDISC_JVM.baseName(), Runtime.version().feature(), ctx); } private void tlsMetrics() { diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java index 947c99adf51..351890e2a3a 100644 --- a/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java +++ b/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java @@ -22,13 +22,16 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class DataplaneProxyServiceTest { private FileSystem fileSystem = Jimfs.newFileSystem(); - DataplaneProxyService.ProxyCommands proxyCommandsMock = Mockito.mock(DataplaneProxyService.ProxyCommands.class); + DataplaneProxyService.ProxyCommands proxyCommandsMock = mock(DataplaneProxyService.ProxyCommands.class); @Test public void starts_and_reloads_if_no_errors() throws IOException { @@ -122,6 +125,35 @@ public class DataplaneProxyServiceTest { assertFalse(proxyCommands.isRunning()); } + @Test + public void stops_executor_when_nginx_stop_throws() throws IOException, InterruptedException { + DataplaneProxyService.ProxyCommands mockProxyCommands = mock(DataplaneProxyService.ProxyCommands.class); + DataplaneProxyService service = dataplaneProxyService(mockProxyCommands); + service.converge(); + when (mockProxyCommands.isRunning()).thenReturn(true); + assertEquals(DataplaneProxyService.NginxState.RUNNING, service.state()); + + reset(proxyCommandsMock); + + when(mockProxyCommands.isRunning()).thenReturn(true).thenReturn(false); + doThrow(new RuntimeException("Failed to stop proxy")).when(proxyCommandsMock).stop(); + Thread thread = new Thread(service::deconstruct);// deconstruct will block until nginx is stopped + thread.start(); + + // Wait for above thread to set the wanted state to STOPPED + while (service.wantedState() != DataplaneProxyService.NginxState.STOPPED) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + } + } + service.converge(); + assertEquals(service.state(), DataplaneProxyService.NginxState.STOPPED); + thread.join(); + + verify(mockProxyCommands, times(1)).stop(); + } + private DataplaneProxyService dataplaneProxyService(DataplaneProxyService.ProxyCommands proxyCommands) throws IOException { Path root = fileSystem.getPath("/opt/vespa"); diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index e439f7905cc..c41c1c79149 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -5414,6 +5414,8 @@ "public java.lang.Integer getRerankCount()", "public void setKeepRankCount(int)", "public java.lang.Integer getKeepRankCount()", + "public void setRankScoreDropLimit(double)", + "public java.lang.Double getRankScoreDropLimit()", "public com.yahoo.prelude.Location getLocation()", "public void setLocation(com.yahoo.prelude.Location)", "public void setLocation(java.lang.String)", @@ -5449,6 +5451,7 @@ "public static final java.lang.String QUERYCACHE", "public static final java.lang.String RERANKCOUNT", "public static final java.lang.String KEEPRANKCOUNT", + "public static final java.lang.String RANKSCOREDROPLIMIT", "public static final java.lang.String MATCH_PHASE", "public static final java.lang.String DIVERSITY", "public static final java.lang.String SOFTTIMEOUT", diff --git a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java index c9e125ab55f..ecce5ddd740 100644 --- a/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java +++ b/container-search/src/main/java/com/yahoo/search/handler/SearchHandler.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.handler; +import ai.vespa.metrics.ContainerMetrics; import ai.vespa.cloud.ZoneInfo; import ai.vespa.metrics.ContainerMetrics; import com.yahoo.collections.Tuple2; @@ -81,7 +82,7 @@ public class SearchHandler extends LoggingRequestHandler { private static final CompoundName FORCE_TIMESTAMPS = CompoundName.from("trace.timestamps"); /** Event name for number of connections to the search subsystem */ - private static final String SEARCH_CONNECTIONS = "search_connections"; + private static final String SEARCH_CONNECTIONS = ContainerMetrics.SEARCH_CONNECTIONS.baseName(); static final String RENDER_LATENCY_METRIC = ContainerMetrics.JDISC_RENDER_LATENCY.baseName(); static final String MIME_DIMENSION = "mime"; static final String RENDERER_DIMENSION = "renderer"; diff --git a/container-search/src/main/java/com/yahoo/search/query/Ranking.java b/container-search/src/main/java/com/yahoo/search/query/Ranking.java index 5426268d173..ac32bec80ef 100644 --- a/container-search/src/main/java/com/yahoo/search/query/Ranking.java +++ b/container-search/src/main/java/com/yahoo/search/query/Ranking.java @@ -42,6 +42,7 @@ public class Ranking implements Cloneable { public static final String QUERYCACHE = "queryCache"; public static final String RERANKCOUNT = "rerankCount"; public static final String KEEPRANKCOUNT = "keepRankCount"; + public static final String RANKSCOREDROPLIMIT = "rankScoreDropLimit"; public static final String MATCH_PHASE = "matchPhase"; public static final String DIVERSITY = "diversity"; public static final String SOFTTIMEOUT = "softtimeout"; @@ -63,6 +64,7 @@ public class Ranking implements Cloneable { argumentType.addField(new FieldDescription(QUERYCACHE, "boolean")); argumentType.addField(new FieldDescription(RERANKCOUNT, "integer")); argumentType.addField(new FieldDescription(KEEPRANKCOUNT, "integer")); + argumentType.addField(new FieldDescription(RANKSCOREDROPLIMIT, "double")); argumentType.addField(new FieldDescription(MATCH_PHASE, new QueryProfileFieldType(MatchPhase.getArgumentType()), "matchPhase")); argumentType.addField(new FieldDescription(DIVERSITY, new QueryProfileFieldType(Diversity.getArgumentType()))); argumentType.addField(new FieldDescription(SOFTTIMEOUT, new QueryProfileFieldType(SoftTimeout.getArgumentType()))); @@ -94,6 +96,7 @@ public class Ranking implements Cloneable { private Integer rerankCount = null; private Integer keepRankCount = null; + private Double rankScoreDropLimit = null; private RankProperties rankProperties = new RankProperties(); @@ -165,6 +168,11 @@ public class Ranking implements Cloneable { /** Returns the keep-rank-count that will be used, or null if not set */ public Integer getKeepRankCount() { return keepRankCount; } + /** Sets the rank-score-drop-limit that will be used, or null if not set */ + public void setRankScoreDropLimit(double rankScoreDropLimit) { this.rankScoreDropLimit = rankScoreDropLimit; } + /** Returns the rank-score-drop-limit that will be used, or null if not set */ + public Double getRankScoreDropLimit() { return rankScoreDropLimit; } + /** Returns the location of this query, or null if none */ public Location getLocation() { return location; } @@ -241,6 +249,8 @@ public class Ranking implements Cloneable { rankProperties.put("vespa.hitcollector.heapsize", rerankCount); if (keepRankCount != null) rankProperties.put("vespa.hitcollector.arraysize", keepRankCount); + if (rankScoreDropLimit != null) + rankProperties.put("vespa.hitcollector.rankscoredroplimit", rankScoreDropLimit); } private void prepareNow(Freshness freshness) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java index da0051c527c..240da3f123f 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java @@ -77,6 +77,7 @@ public class QueryProperties extends Properties { if (key.last().equals(Ranking.QUERYCACHE)) return ranking.getQueryCache(); if (key.last().equals(Ranking.RERANKCOUNT)) return ranking.getRerankCount(); if (key.last().equals(Ranking.KEEPRANKCOUNT)) return ranking.getKeepRankCount(); + if (key.last().equals(Ranking.RANKSCOREDROPLIMIT)) return ranking.getRankScoreDropLimit(); if (key.last().equals(Ranking.LIST_FEATURES)) return ranking.getListFeatures(); } else if (key.size() >= 3 && key.get(1).equals(Ranking.MATCH_PHASE)) { @@ -203,6 +204,8 @@ public class QueryProperties extends Properties { ranking.setRerankCount(asInteger(value, null)); else if (key.last().equals(Ranking.KEEPRANKCOUNT)) ranking.setKeepRankCount(asInteger(value, null)); + else if (key.last().equals(Ranking.RANKSCOREDROPLIMIT)) + ranking.setRankScoreDropLimit(asDouble(value, null)); else if (key.last().equals(Ranking.LIST_FEATURES)) ranking.setListFeatures(asBoolean(value,false)); else diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ContainerLatencySearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ContainerLatencySearcher.java index 742f4b0f889..f510d68d32e 100644 --- a/container-search/src/main/java/com/yahoo/search/searchers/ContainerLatencySearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/ContainerLatencySearcher.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.searchers; +import ai.vespa.metrics.ContainerMetrics; import com.yahoo.component.chain.dependencies.After; import com.yahoo.metrics.simple.Gauge; import com.yahoo.metrics.simple.Point; @@ -21,7 +22,7 @@ public class ContainerLatencySearcher extends Searcher { private final Gauge latencyGauge; public ContainerLatencySearcher(MetricReceiver metrics) { - latencyGauge = metrics.declareGauge("query_container_latency"); + latencyGauge = metrics.declareGauge(ContainerMetrics.QUERY_CONTAINER_LATENCY.baseName()); } @Override diff --git a/container-search/src/main/java/com/yahoo/search/searchers/RateLimitingSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/RateLimitingSearcher.java index 35a3c86f763..846b8881cef 100755 --- a/container-search/src/main/java/com/yahoo/search/searchers/RateLimitingSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/RateLimitingSearcher.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.searchers; +import ai.vespa.metrics.ContainerMetrics; import com.yahoo.component.annotation.Inject; import com.yahoo.cloud.config.ClusterInfoConfig; @@ -60,7 +61,7 @@ public class RateLimitingSearcher extends Searcher { public static final CompoundName idDimensionKey = CompoundName.from("rate.idDimension"); public static final CompoundName dryRunKey = CompoundName.from("rate.dryRun"); - private static final String requestsOverQuotaMetricName = "requestsOverQuota"; + private static final String requestsOverQuotaMetricName = ContainerMetrics.REQUESTS_OVER_QUOTA.baseName(); /** Used to divide quota by nodes. Assumption: All nodes get the same share of traffic. */ private final int nodeCount; diff --git a/container-search/src/test/java/com/yahoo/search/searchers/test/RateLimitingBenchmark.java b/container-search/src/test/java/com/yahoo/search/searchers/test/RateLimitingBenchmark.java index 5537528d36b..76da6407166 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/test/RateLimitingBenchmark.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/test/RateLimitingBenchmark.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.searchers.test; +import ai.vespa.metrics.ContainerMetrics; import com.yahoo.cloud.config.ClusterInfoConfig; import com.yahoo.component.chain.Chain; import com.yahoo.metrics.simple.Bucket; @@ -114,7 +115,7 @@ public class RateLimitingBenchmark { private int rejectedRequests(int id) { Point context = metric.pointBuilder().set("id", toClientId(id)).build(); - UntypedMetric rejectedRequestsMetric = metricSnapshot.getMapForMetric("requestsOverQuota").get(context); + UntypedMetric rejectedRequestsMetric = metricSnapshot.getMapForMetric(ContainerMetrics.REQUESTS_OVER_QUOTA.baseName()).get(context); if (rejectedRequestsMetric == null) return 0; return (int)rejectedRequestsMetric.getCount(); } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java index ffaee34e727..d73a7410cc6 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/application/v4/model/InstanceInformation.java @@ -26,18 +26,21 @@ public class InstanceInformation { public URI url; public String scope; public RoutingMethod routingMethod; + public String auth; @JsonCreator public Endpoint(@JsonProperty("cluster") String cluster , @JsonProperty("tls") boolean tls, @JsonProperty("url") URI url, @JsonProperty("scope") String scope, - @JsonProperty("routingMethod") RoutingMethod routingMethod) { + @JsonProperty("routingMethod") RoutingMethod routingMethod, + @JsonProperty("authMethod") String auth) { this.cluster = cluster; this.tls = tls; this.url = url; this.scope = scope; this.routingMethod = routingMethod; + this.auth = auth; } @Override @@ -47,6 +50,7 @@ public class InstanceInformation { ", tls=" + tls + ", url=" + url + ", scope='" + scope + '\'' + + ", authType='" + auth + '\'' + ", routingMethod=" + routingMethod + '}'; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java index 0f3f9479176..68852f90055 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java @@ -129,7 +129,7 @@ public class EndpointCertificates { } private Optional<EndpointCertificateMetadata> getOrProvision(Instance instance, ZoneId zone, DeploymentSpec deploymentSpec) { - if (useRandomizedCert.with(FetchVector.Dimension.APPLICATION_ID, instance.id().toFullString()).value()) { + if (useRandomizedCert.with(FetchVector.Dimension.APPLICATION_ID, instance.id().serializedForm()).value()) { return Optional.of(assignFromPool(instance, zone)); } Optional<AssignedCertificate> assignedCertificate = curator.readAssignedCertificate(TenantAndApplicationId.from(instance.id()), Optional.of(instance.id().instance())); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index aa3f78f1395..693275987c5 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -139,7 +139,6 @@ import com.yahoo.yolean.Exceptions; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; -import java.io.UncheckedIOException; import java.net.MalformedURLException; import java.net.URI; import java.net.URISyntaxException; @@ -911,14 +910,17 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler { } private HttpResponse listTokens(String tenant, HttpRequest request) { - List<DataplaneTokenVersions> dataplaneTokenVersions = controller.dataplaneTokenService().listTokens(TenantName.from(tenant)); + var tokens = controller.dataplaneTokenService().listTokens(TenantName.from(tenant)) + .stream().sorted(Comparator.comparing(DataplaneTokenVersions::tokenId)).toList(); Slime slime = new Slime(); Cursor tokensArray = slime.setObject().setArray("tokens"); - for (DataplaneTokenVersions token : dataplaneTokenVersions) { + for (DataplaneTokenVersions token : tokens) { Cursor tokenObject = tokensArray.addObject(); tokenObject.setString("id", token.tokenId().value()); Cursor fingerprintsArray = tokenObject.setArray("versions"); - for (DataplaneTokenVersions.Version tokenVersion : token.tokenVersions()) { + var versions = token.tokenVersions().stream() + .sorted(Comparator.comparing(DataplaneTokenVersions.Version::creationTime)).toList(); + for (var tokenVersion : versions) { Cursor fingerprintObject = fingerprintsArray.addObject(); fingerprintObject.setString("fingerprint", tokenVersion.fingerPrint().value()); fingerprintObject.setString("created", tokenVersion.creationTime().toString()); diff --git a/dist/vespa.spec b/dist/vespa.spec index 091a17e822b..c1a409e057f 100644 --- a/dist/vespa.spec +++ b/dist/vespa.spec @@ -90,7 +90,7 @@ BuildRequires: llvm-devel BuildRequires: vespa-boost-devel >= 1.76.0-1 BuildRequires: vespa-openssl-devel >= 1.1.1o-1 %define _use_vespa_openssl 1 -BuildRequires: vespa-gtest = 1.11.0 +BuildRequires: vespa-gtest = 1.13.0 %define _use_vespa_gtest 1 BuildRequires: vespa-lz4-devel >= 1.9.4-1 BuildRequires: vespa-onnxruntime-devel = 1.13.1 @@ -192,7 +192,7 @@ Requires: unzip Requires: zlib Requires: zstd %if 0%{?el8} -Requires: vespa-gtest = 1.11.0 +Requires: vespa-gtest = 1.13.0 %endif %if 0%{?el9} Requires: gtest diff --git a/metrics/src/main/java/ai/vespa/metrics/HostedNodeAdminMetrics.java b/metrics/src/main/java/ai/vespa/metrics/HostedNodeAdminMetrics.java index 927672a43f7..97185e9c703 100644 --- a/metrics/src/main/java/ai/vespa/metrics/HostedNodeAdminMetrics.java +++ b/metrics/src/main/java/ai/vespa/metrics/HostedNodeAdminMetrics.java @@ -30,9 +30,9 @@ public enum HostedNodeAdminMetrics implements VespaMetrics { NET_IN_BYTES("net.in.bytes", Unit.BYTE, "Network bytes received (rxBytes) (COUNT metric)"), NET_IN_ERROR("net.in.errors", Unit.FAILURE, "Network receive errors (rxErrors)"), NET_IN_DROPPED("net.in.dropped", Unit.PACKET, "Inbound network packets dropped (rxDropped)"), - NET_OUT_BYTES("net.in.bytes", Unit.BYTE, "Network bytes sent (txBytes) (COUNT metric)"), - NET_OUT_ERROR("net.in.errors", Unit.FAILURE, "Network send errors (txErrors)"), - NET_OUT_DROPPED("net.in.dropped", Unit.PACKET, "Outbound network packets dropped (txDropped)"), + NET_OUT_BYTES("net.out.bytes", Unit.BYTE, "Network bytes sent (txBytes) (COUNT metric)"), + NET_OUT_ERROR("net.out.errors", Unit.FAILURE, "Network send errors (txErrors)"), + NET_OUT_DROPPED("net.out.dropped", Unit.PACKET, "Outbound network packets dropped (txDropped)"), BANDWIDTH_LIMIT("bandwidth.limit", Unit.BYTE_PER_SECOND, "Available network bandwidth"); private final String name; @@ -58,4 +58,3 @@ public enum HostedNodeAdminMetrics implements VespaMetrics { } } - diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index f12f60dcc8e..f690b8e8c8a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -32,8 +32,9 @@ class TensorConverter { } private static Values readValuesOf(Onnx.TensorProto tensorProto) { + var elemType = Onnx.TensorProto.DataType.forNumber(tensorProto.getDataType()); if (tensorProto.hasRawData()) { - switch (tensorProto.getDataType()) { + switch (elemType) { case BOOL: return new RawBoolValues(tensorProto); case FLOAT: return new RawFloatValues(tensorProto); case DOUBLE: return new RawDoubleValues(tensorProto); @@ -41,7 +42,7 @@ class TensorConverter { case INT64: return new RawLongValues(tensorProto); } } else { - switch (tensorProto.getDataType()) { + switch (elemType) { case FLOAT: return new FloatValues(tensorProto); case DOUBLE: return new DoubleValues(tensorProto); case INT32: return new IntValues(tensorProto); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 35ec1d8c54a..deac950d324 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -37,7 +37,8 @@ class TypeConverter { static OrderedTensorType typeFrom(Onnx.TypeProto type) { String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(type.getTensorType().getElemType())); + var elemType = Onnx.TensorProto.DataType.forNumber(type.getTensorType().getElemType()); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(elemType)); for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); @@ -52,8 +53,8 @@ class TypeConverter { } static OrderedTensorType typeFrom(Onnx.TensorProto tensor) { - return OrderedTensorType.fromDimensionList(toValueType(tensor.getDataType()), - tensor.getDimsList()); + var elemType = Onnx.TensorProto.DataType.forNumber(tensor.getDataType()); + return OrderedTensorType.fromDimensionList(toValueType(elemType), tensor.getDimsList()); } private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { diff --git a/model-integration/src/main/protobuf/onnx.proto b/model-integration/src/main/protobuf/onnx.proto index dc6542867e0..1d265ae9f28 100644 --- a/model-integration/src/main/protobuf/onnx.proto +++ b/model-integration/src/main/protobuf/onnx.proto @@ -3,8 +3,8 @@ // -// Copyright (c) Facebook Inc. and Microsoft Corporation. -// Licensed under the MIT license. +// SPDX-License-Identifier: Apache-2.0 + syntax = "proto2"; @@ -20,23 +20,16 @@ package onnx; // // This document describes the syntax of models and their computation graphs, // as well as the standard data types. Together, they are referred to as the ONNX -// Intermediate Representation, or 'IR' for short. +// Intermediate Representation, or 'IR' for short. // // The normative semantic specification of the ONNX IR is found in docs/IR.md. // Definitions of the built-in neural network operators may be found in docs/Operators.md. // Notes // -// Release -// -// We are still in the very early stage of defining ONNX. The current -// version of ONNX is a starting point. While we are actively working -// towards a complete spec, we would like to get the community involved -// by sharing our working version of ONNX. -// // Protobuf compatibility -// -// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf // that is compatible with both protobuf v2 and v3. This means that we do not use any // protobuf features that are only available in one of the two versions. // @@ -60,22 +53,60 @@ enum Version { _START_VERSION = 0; // The version field is always serialized and we will use it to store the // version that the graph is generated from. This helps us set up version - // control. We should use version as - // xx(major) - xx(minor) - xxxx(bugfix) - // and we are starting with 0x00000001 (0.0.1), which was the - // version we published on Oct 10, 2017. - IR_VERSION_2017_10_10 = 0x00000001; + // control. + // For the IR, we are using simple numbers starting with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; - // IR_VERSION 0.0.2 published on Oct 30, 2017 + // IR_VERSION 2 published on Oct 30, 2017 // - Added type discriminator to AttributeProto to support proto3 users - IR_VERSION_2017_10_30 = 0x00000002; + IR_VERSION_2017_10_30 = 0x0000000000000002; - // IR VERSION 0.0.3 published on Nov 3, 2017 + // IR VERSION 3 published on Nov 3, 2017 // - For operator versioning: // - Added new message OperatorSetIdProto // - Added opset_import in ModelProto // - For vendor extensions, added domain in NodeProto - IR_VERSION = 0x00000003; + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION_2019_9_19 = 0x0000000000000006; + + // IR VERSION 7 published on May 8, 2020 + // - Add support to allow function body graph to rely on multiple external opreator sets. + // - Add a list to promote inference graph's initializers to global and + // mutable variables. Global variables are visible in all graphs of the + // stored models. + // - Add message TrainingInfoProto to store initialization + // method and training algorithm. The execution of TrainingInfoProto + // can modify the values of mutable variables. + // - Implicitly add inference graph into each TrainingInfoProto's algorithm. + IR_VERSION_2020_5_8 = 0x0000000000000007; + + // IR VERSION 8 published on July 30, 2021 + // Introduce TypeProto.SparseTensor + // Introduce TypeProto.Optional + // Added a list of FunctionProtos local to the model + // Deprecated since_version and operator status from FunctionProto + IR_VERSION_2021_7_30 = 0x0000000000000008; + + // IR VERSION 9 published on May 5, 2023 + // Added AttributeProto to FunctionProto so that default attribute values can be set. + // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ. + IR_VERSION = 0x0000000000000009; } // Attributes @@ -95,17 +126,21 @@ message AttributeProto { STRING = 3; TENSOR = 4; GRAPH = 5; + SPARSE_TENSOR = 11; + TYPE_PROTO = 13; FLOATS = 6; INTS = 7; STRINGS = 8; TENSORS = 9; GRAPHS = 10; + SPARSE_TENSORS = 12; + TYPE_PROTOS = 14; } // The name field MUST be present for this version of the IR. optional string name = 1; // namespace Attribute - + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. // In this case, this AttributeProto does not contain data, and it's a reference of attribute // in parent scope. @@ -117,10 +152,10 @@ message AttributeProto { // The type field MUST be present for this version of the IR. // For 0.0.1 versions of the IR, this field was not defined, and - // implementations needed to use has_field hueristics to determine + // implementations needed to use has_field heuristics to determine // which value field was in use. For IR_VERSION 0.0.2 or later, this // field MUST be set and match the f|i|s|t|... field in use. This - // change was made to accomodate proto3 implementations. + // change was made to accommodate proto3 implementations. optional AttributeType type = 20; // discriminator that indicates which field below is in use // Exactly ONE of the following fields must be present for this version of the IR @@ -129,14 +164,18 @@ message AttributeProto { optional bytes s = 4; // UTF-8 string optional TensorProto t = 5; // tensor value optional GraphProto g = 6; // graph + optional SparseTensorProto sparse_tensor = 22; // sparse tensor value // Do not use field below, it's deprecated. // optional ValueProto v = 12; // value - subsumes everything but graph + optional TypeProto tp = 14; // type proto repeated float floats = 7; // list of floats repeated int64 ints = 8; // list of ints repeated bytes strings = 9; // list of UTF-8 strings repeated TensorProto tensors = 10; // list of tensors repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors + repeated TypeProto type_protos = 15;// list of type protos } // Defines information on value, including the name, the type, and @@ -144,7 +183,8 @@ message AttributeProto { message ValueInfoProto { // This field MUST be present in this version of the IR. optional string name = 1; // namespace Value - // This field MUST be present in this version of the IR. + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. optional TypeProto type = 2; // A human-readable documentation for this value. Markdown is allowed. optional string doc_string = 3; @@ -155,7 +195,7 @@ message ValueInfoProto { // Computation graphs are made up of a DAG of nodes, which represent what is // commonly called a "layer" or "pipeline stage" in machine learning frameworks. // -// For example, it can be a node of type "Conv" that takes in an image, a filter +// For example, it can be a node of type "Conv" that takes in an image, a filter // tensor and a bias tensor, and produces the convolved output. message NodeProto { repeated string input = 1; // namespace Value @@ -177,12 +217,130 @@ message NodeProto { optional string doc_string = 6; } +// Training information +// TrainingInfoProto stores information for training a model. +// In particular, this defines two functionalities: an initialization-step +// and a training-algorithm-step. Initialization resets the model +// back to its original state as if no training has been performed. +// Training algorithm improves the model based on input data. +// +// The semantics of the initialization-step is that the initializers +// in ModelProto.graph and in TrainingInfoProto.algorithm are first +// initialized as specified by the initializers in the graph, and then +// updated by the "initialization_binding" in every instance in +// ModelProto.training_info. +// +// The field "algorithm" defines a computation graph which represents a +// training algorithm's step. After the execution of a +// TrainingInfoProto.algorithm, the initializers specified by "update_binding" +// may be immediately updated. If the targeted training algorithm contains +// consecutive update steps (such as block coordinate descent methods), +// the user needs to create a TrainingInfoProto for each step. +message TrainingInfoProto { + // This field describes a graph to compute the initial tensors + // upon starting the training process. Initialization graph has no input + // and can have multiple outputs. Usually, trainable tensors in neural + // networks are randomly initialized. To achieve that, for each tensor, + // the user can put a random number operator such as RandomNormal or + // RandomUniform in TrainingInfoProto.initialization.node and assign its + // random output to the specific tensor using "initialization_binding". + // This graph can also set the initializers in "algorithm" in the same + // TrainingInfoProto; a use case is resetting the number of training + // iteration to zero. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Thus, no initializer would be changed by default. + optional GraphProto initialization = 1; + + // This field represents a training algorithm step. Given required inputs, + // it computes outputs to update initializers in its own or inference graph's + // initializer lists. In general, this field contains loss node, gradient node, + // optimizer node, increment of iteration count. + // + // An execution of the training algorithm step is performed by executing the + // graph obtained by combining the inference graph (namely "ModelProto.graph") + // and the "algorithm" graph. That is, the actual + // input/initializer/output/node/value_info/sparse_initializer list of + // the training graph is the concatenation of + // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer" + // and "algorithm.input/initializer/output/node/value_info/sparse_initializer" + // in that order. This combined graph must satisfy the normal ONNX conditions. + // Now, let's provide a visualization of graph combination for clarity. + // Let the inference graph (i.e., "ModelProto.graph") be + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d + // and the "algorithm" graph be + // tensor_d -> Add -> tensor_e + // The combination process results + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e + // + // Notice that an input of a node in the "algorithm" graph may reference the + // output of a node in the inference graph (but not the other way round). Also, inference + // node cannot reference inputs of "algorithm". With these restrictions, inference graph + // can always be run independently without training information. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Evaluating the default training step never + // update any initializers. + optional GraphProto algorithm = 2; + + // This field specifies the bindings from the outputs of "initialization" to + // some initializers in "ModelProto.graph.initializer" and + // the "algorithm.initializer" in the same TrainingInfoProto. + // See "update_binding" below for details. + // + // By default, this field is empty and no initializer would be changed + // by the execution of "initialization". + repeated StringStringEntryProto initialization_binding = 3; + + // Gradient-based training is usually an iterative procedure. In one gradient + // descent iteration, we apply + // + // x = x - r * g + // + // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is + // gradient of "x" with respect to a chosen loss. To avoid adding assignments + // into the training graph, we split the update equation into + // + // y = x - r * g + // x = y + // + // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To + // tell that "y" should be assigned to "x", the field "update_binding" may + // contain a key-value pair of strings, "x" (key of StringStringEntryProto) + // and "y" (value of StringStringEntryProto). + // For a neural network with multiple trainable (mutable) tensors, there can + // be multiple key-value pairs in "update_binding". + // + // The initializers appears as keys in "update_binding" are considered + // mutable variables. This implies some behaviors + // as described below. + // + // 1. We have only unique keys in all "update_binding"s so that two + // variables may not have the same name. This ensures that one + // variable is assigned up to once. + // 2. The keys must appear in names of "ModelProto.graph.initializer" or + // "TrainingInfoProto.algorithm.initializer". + // 3. The values must be output names of "algorithm" or "ModelProto.graph.output". + // 4. Mutable variables are initialized to the value specified by the + // corresponding initializer, and then potentially updated by + // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. + // + // This field usually contains names of trainable tensors + // (in ModelProto.graph), optimizer states such as momentums in advanced + // stochastic gradient methods (in TrainingInfoProto.graph), + // and number of training iterations (in TrainingInfoProto.graph). + // + // By default, this field is empty and no initializer would be changed + // by the execution of "algorithm". + repeated StringStringEntryProto update_binding = 4; +} + // Models // // ModelProto is a top-level file/container format for bundling a ML model and // associating its computation graph with metadata. // -// The semantics of the model are described by the associated GraphProto. +// The semantics of the model are described by the associated GraphProto's. message ModelProto { // The version of the IR this model targets. See Version enum above. // This field MUST be present. @@ -227,18 +385,58 @@ message ModelProto { // Named metadata values; keys should be distinct. repeated StringStringEntryProto metadata_props = 14; + + // Training-specific information. Sequentially executing all stored + // `TrainingInfoProto.algorithm`s and assigning their outputs following + // the corresponding `TrainingInfoProto.update_binding`s is one training + // iteration. Similarly, to initialize the model + // (as if training hasn't happened), the user should sequentially execute + // all stored `TrainingInfoProto.initialization`s and assigns their outputs + // using `TrainingInfoProto.initialization_binding`s. + // + // If this field is empty, the training behavior of the model is undefined. + repeated TrainingInfoProto training_info = 20; + + // A list of function protos local to the model. + // + // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain". + // In case of any conflicts the behavior (whether the model local functions are given higher priority, + // or standard operator sets are given higher priotity or this is treated as error) is defined by + // the runtimes. + // + // The operator sets imported by FunctionProto should be compatible with the ones + // imported by ModelProto and other model local FunctionProtos. + // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto + // or by 2 FunctionProtos then versions for the operator set may be different but, + // the operator schema returned for op_type, domain, version combination + // for both the versions should be same for every node in the function body. + // + // One FunctionProto can reference other FunctionProto in the model, however, recursive reference + // is not allowed. + repeated FunctionProto functions = 25; }; // StringStringEntryProto follows the pattern for cross-proto-version maps. // See https://developers.google.com/protocol-buffers/docs/proto3#maps message StringStringEntryProto { optional string key = 1; - optional string value= 2; + optional string value = 2; }; +message TensorAnnotation { + optional string tensor_name = 1; + // <key, value> pairs to annotate tensor specified by <tensor_name> above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + // Graphs // -// A graph defines the computational logic of a model and is comprised of a parameterized +// A graph defines the computational logic of a model and is comprised of a parameterized // list of nodes that form a directed acyclic graph based on their inputs and outputs. // This is the equivalent of the "network" or "graph" in many deep learning // frameworks. @@ -250,10 +448,14 @@ message GraphProto { optional string name = 2; // namespace Graph // A list of named tensor values, used to specify constant inputs of the graph. - // Each TensorProto entry must have a distinct name (within the list) that - // also appears in the input list. + // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name. + // The name MUST be unique across both initializer and sparse_initializer, + // but the name MAY also appear in the input list. repeated TensorProto initializer = 5; + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + // A human-readable documentation for this graph. Markdown is allowed. optional string doc_string = 10; @@ -265,13 +467,14 @@ message GraphProto { // must be distinct. It is optional for a value to appear in value_info list. repeated ValueInfoProto value_info = 13; - // DO NOT USE the following fields, they were deprecated from earlier versions. - // repeated string input = 3; - // repeated string output = 4; - // optional int64 ir_version = 6; - // optional int64 producer_version = 7; - // optional string producer_tag = 8; - // optional string domain = 9; + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + reserved 3, 4, 6 to 9; + reserved "ir_version", "producer_version", "producer_tag", "domain"; } // Tensors @@ -291,13 +494,32 @@ message TensorProto { STRING = 8; // string BOOL = 9; // bool - // Advanced types + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. FLOAT16 = 10; + DOUBLE = 11; UINT32 = 12; UINT64 = 13; COMPLEX64 = 14; // complex with float32 real and imaginary components COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Non-IEEE floating-point format based on papers + // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433, + // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf. + // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear. + // The computation usually happens inside a block quantize / dequantize + // fused by the runtime. + FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf + FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero + FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients + FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero + // Future extensions go here. } @@ -305,7 +527,8 @@ message TensorProto { repeated int64 dims = 1; // The data type of the tensor. - optional DataType data_type = 2; + // This field MUST have a valid TensorProto.DataType value + optional int32 data_type = 2; // For very large tensors, we may want to store them in chunks, in which // case the following fields will specify the segment that is stored in @@ -324,17 +547,17 @@ message TensorProto { // For float and complex64 values // Complex64 tensors are encoded as a single array of floats, // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the + // and the corresponding imaginary component appearing in the // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] // is encoded as [1.0, 2.0 ,3.0 ,4.0] // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. repeated float float_data = 4 [packed = true]; - // For int32, uint8, int8, uint16, int16, bool, and float16 values - // float16 values must be bit-wise converted to an uint16_t prior + // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values + // float16 and float8 values must be bit-wise converted to an uint16_t prior // to writing to the buffer. // When this field is present, the data_type field MUST be - // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32 + // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ repeated int32 int32_data = 5 [packed = true]; // For strings. @@ -371,10 +594,32 @@ message TensorProto { // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED optional bytes raw_data = 9; + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + optional DataLocation data_location = 14; + // For double - // Complex64 tensors are encoded as a single array of doubles, + // Complex128 tensors are encoded as a single array of doubles, // with the real components appearing in odd numbered positions, - // and the corresponding imaginary component apparing in the + // and the corresponding imaginary component appearing in the // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] // is encoded as [1.0, 2.0 ,3.0 ,4.0] // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 @@ -386,6 +631,30 @@ message TensorProto { repeated uint64 uint64_data = 11 [packed = true]; } +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + // values must have a non-empty name present which serves as a name for SparseTensorProto + // when used in sparse_initializer list. + optional TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + optional TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + // Defines a tensor shape. A dimension can be either an integer value // or a symbolic variable. A symbolic variable represents an unknown // dimension. @@ -398,36 +667,13 @@ message TensorShapeProto { // Standard denotation can optionally be used to denote tensor // dimensions with standard semantic descriptions to ensure // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. optional string denotation = 3; }; repeated Dimension dim = 1; } -// A set of pre-defined constants to be used as values for -// the standard denotation field in TensorShapeProto.Dimension -// for semantic description of the tensor dimension. -message DenotationConstProto { - // Describe a batch number dimension. - optional string DATA_BATCH = 1 [default = "DATA_BATCH"]; - // Describe a channel dimension. - optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"]; - // Describe a time dimension. - optional string DATA_TIME = 3 [default = "DATA_TIME"]; - // Describe a feature dimension. This is typically a feature - // dimension in RNN and/or spatial dimension in CNN. - optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"]; - // Describe a filter in-channel dimension. This is the dimension - // that is identical (in size) to the channel dimension of the input - // image feature maps. - optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"]; - // Describe a filter out channel dimension. This is the dimension - // that is identical (int size) to the channel dimension of the output - // image feature maps. - optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"]; - // Describe a filter spatial dimension. - optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"]; -} - // Types // // The standard ONNX data types. @@ -435,8 +681,43 @@ message TypeProto { message Tensor { // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + optional int32 elem_type = 1; + optional TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + optional TypeProto elem_type = 1; + }; + + // map<K,V> + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + optional int32 key_type = 1; + // This field MUST be present for this version of the IR. + optional TypeProto value_type = 2; + }; + + // wrapper for Tensor, Sequence, or Map + message Optional { + // The type and optional shape of the element wrapped. + // This field MUST be present for this version of the IR. + // Possible values correspond to OptionalProto.DataType enum + optional TypeProto elem_type = 1; + }; + + + message SparseTensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value // This field MUST be present for this version of the IR. - optional TensorProto.DataType elem_type = 1; + optional int32 elem_type = 1; optional TensorShapeProto shape = 2; } @@ -445,7 +726,31 @@ message TypeProto { // The type of a tensor. Tensor tensor_type = 1; + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + // The type of an optional. + Optional optional_type = 9; + + + // Type of the sparse tensor + SparseTensor sparse_tensor_type = 8; + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + optional string denotation = 6; } // Operator Sets @@ -461,4 +766,70 @@ message OperatorSetIdProto { // The version of the operator set being identified. // This field MUST be present in this version of the IR. optional int64 version = 2; -}
\ No newline at end of file +} + +// Operator/function status. +enum OperatorStatus { + EXPERIMENTAL = 0; + STABLE = 1; +} + +message FunctionProto { + // The name of the function, similar usage of op_type in OperatorProto. + // Combined with FunctionProto.domain, this forms the unique identity of + // the FunctionProto. + optional string name = 1; + + // Deprecated since IR Version 8 + // optional int64 since_version = 2; + reserved 2; + reserved "since_version"; + + // Deprecated since IR Version 8 + // optional OperatorStatus status = 3; + reserved 3; + reserved "status"; + + // The inputs and outputs of the function. + repeated string input = 4; + repeated string output = 5; + + // The attribute parameters of the function. + // It is for function parameters without default values. + repeated string attribute = 6; + + // The attribute protos of the function. + // It is for function attributes with default values. + // A function attribute shall be represented either as + // a string attribute or an AttributeProto, not both. + repeated AttributeProto attribute_proto = 11; + + // The nodes in the function. + repeated NodeProto node = 7; + // A human-readable documentation for this function. Markdown is allowed. + optional string doc_string = 8; + + // The OperatorSets this function body (graph) relies on. + // + // All nodes in the function body (graph) will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. This means at most one version can be relied + // for one domain. + // + // The operator sets imported by FunctionProto should be compatible with the ones + // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto + // and ModelProto then versions for the operator set may be different but, + // the operator schema returned for op_type, domain, version combination + // for both the versions should be same. + + repeated OperatorSetIdProto opset_import = 9; + + // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of + // the FunctionProto. + optional string domain = 10; +} + + +// For using protobuf-lite +option optimize_for = LITE_RUNTIME; + diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index 3ef96cdf166..2b707c3beb3 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -775,10 +775,10 @@ public class OnnxOperationsTestCase { Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder(); tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get())); if (tensor.type().valueType() == TensorType.Value.FLOAT) { - builder.setDataType(Onnx.TensorProto.DataType.FLOAT); + builder.setDataType(Onnx.TensorProto.DataType.FLOAT_VALUE); tensor.valueIterator().forEachRemaining(d -> builder.addFloatData(d.floatValue())); } else { - builder.setDataType(Onnx.TensorProto.DataType.DOUBLE); + builder.setDataType(Onnx.TensorProto.DataType.DOUBLE_VALUE); tensor.valueIterator().forEachRemaining(builder::addDoubleData); } Onnx.TensorProto val = builder.build(); diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java index cd7eee8ba50..ad067110f59 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminImpl.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.nodeadmin; +import ai.vespa.metrics.ContainerMetrics; import com.yahoo.jdisc.Timer; import com.yahoo.vespa.hosted.node.admin.container.ContainerStats; import com.yahoo.vespa.hosted.node.admin.container.metrics.Counter; @@ -80,9 +81,9 @@ public class NodeAdminImpl implements NodeAdmin { new Dimensions(Map.of("src", "node-agents"))); this.procMeminfoReader = procMeminfoReader; - this.jvmHeapUsed = metrics.declareGauge("mem.heap.used"); - this.jvmHeapFree = metrics.declareGauge("mem.heap.free"); - this.jvmHeapTotal = metrics.declareGauge("mem.heap.total"); + this.jvmHeapUsed = metrics.declareGauge(ContainerMetrics.MEM_HEAP_USED.baseName()); + this.jvmHeapFree = metrics.declareGauge(ContainerMetrics.MEM_HEAP_FREE.baseName()); + this.jvmHeapTotal = metrics.declareGauge(ContainerMetrics.MEM_HEAP_TOTAL.baseName()); this.containerCount = metrics.declareGauge("container.count"); this.metrics = metrics; } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java index dc86daf2c67..9bc18533ddf 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/LockedNodeList.java @@ -17,9 +17,16 @@ import java.util.Objects; */ public final class LockedNodeList extends NodeList { + private final Mutex lock; + public LockedNodeList(List<Node> nodes, Mutex lock) { super(nodes, false); - Objects.requireNonNull(lock, "lock must be non-null"); + this.lock = Objects.requireNonNull(lock, "lock must be non-null"); + } + + /** Returns a new LockedNodeList with the for the same lock. */ + public LockedNodeList childList(List<Node> nodes) { + return new LockedNodeList(nodes, lock); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java index 60fd07951c6..20c246b3ebd 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeMutex.java @@ -28,4 +28,5 @@ public class NodeMutex implements Mutex { return new NodeMutex(updatedNode, mutex); } + } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java index d6671d41cbd..9da66413b9c 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeRepository.java @@ -229,7 +229,7 @@ public class NodeRepository extends AbstractComponent { applicationNodes.asList(), Agent.system, Optional.of("Application is removed"), - transaction.nested()); + transaction); applications.remove(transaction); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java index 8766dea3d61..e300591fbb2 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java @@ -3,6 +3,8 @@ package com.yahoo.vespa.hosted.provision.maintenance; import com.yahoo.jdisc.Metric; import com.yahoo.vespa.hosted.provision.Node; +import com.yahoo.vespa.hosted.provision.Node.State; +import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.node.History; @@ -33,8 +35,12 @@ public class DirtyExpirer extends Expirer { @Override protected void expire(List<Node> expired) { - for (Node expiredNode : expired) - nodeRepository().nodes().fail(expiredNode.hostname(), wantToDeprovisionOnExpiry, Agent.DirtyExpirer, "Node is stuck in dirty"); + nodeRepository().nodes().performOn(NodeList.copyOf(expired), + node -> node.state() == State.dirty && isExpired(node), + (node, lock) -> nodeRepository().nodes().fail(node.hostname(), + wantToDeprovisionOnExpiry, + Agent.DirtyExpirer, + "Node is stuck in dirty")); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java index fa3f9435c70..cb0a8005e87 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/FailedExpirer.java @@ -6,13 +6,14 @@ import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.Zone; import com.yahoo.jdisc.Metric; import com.yahoo.vespa.hosted.provision.Node; +import com.yahoo.vespa.hosted.provision.Node.State; import com.yahoo.vespa.hosted.provision.NodeList; +import com.yahoo.vespa.hosted.provision.NodeMutex; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.node.Agent; -import com.yahoo.vespa.hosted.provision.node.History; +import com.yahoo.vespa.hosted.provision.node.History.Event.Type; import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Optional; import java.util.function.Predicate; @@ -67,55 +68,47 @@ public class FailedExpirer extends NodeRepositoryMaintainer { @Override protected double maintain() { - NodeList allNodes = nodeRepository.nodes().list(); - List<Node> remainingNodes = new ArrayList<>(allNodes.state(Node.State.failed) - .nodeType(NodeType.tenant, NodeType.host) - .asList()); + Predicate<Node> isExpired = node -> node.state() == State.failed + && node.history().hasEventBefore(Type.failed, clock().instant().minus(expiryFor(node))); + NodeList allNodes = nodeRepository.nodes().list(); // Stale snapshot, not critical. - recycleIf(node -> node.allocation().isEmpty(), remainingNodes, allNodes); - recycleIf(node -> !node.allocation().get().membership().cluster().isStateful() && - node.history().hasEventBefore(History.Event.Type.failed, clock().instant().minus(statelessExpiry)), - remainingNodes, - allNodes); - recycleIf(node -> node.allocation().get().membership().cluster().isStateful() && - node.history().hasEventBefore(History.Event.Type.failed, clock().instant().minus(statefulExpiry)), - remainingNodes, - allNodes); + nodeRepository.nodes().performOn(allNodes.nodeType(NodeType.tenant), + isExpired, + (node, lock) -> recycle(node, List.of(), allNodes).get()); + + nodeRepository.nodes().performOnRecursively(allNodes.nodeType(NodeType.host), + nodes -> isExpired.test(nodes.parent().node()), + nodes -> recycle(nodes.parent().node(), + nodes.children().stream().map(NodeMutex::node).toList(), + allNodes) + .map(List::of).orElse(List.of())); return 1.0; } - /** Recycle the nodes matching condition, and remove those nodes from the nodes list. */ - private void recycleIf(Predicate<Node> condition, List<Node> failedNodes, NodeList allNodes) { - List<Node> nodesToRecycle = failedNodes.stream().filter(condition).toList(); - failedNodes.removeAll(nodesToRecycle); - recycle(nodesToRecycle, allNodes); + private Duration expiryFor(Node node) { + return node.allocation().isEmpty() ? Duration.ZERO + : node.allocation().get().membership().cluster().isStateful() ? statefulExpiry + : statelessExpiry; } - /** Move eligible nodes to dirty or parked. This may be a subset of the given nodes */ - private void recycle(List<Node> nodes, NodeList allNodes) { - List<Node> nodesToRecycle = new ArrayList<>(); - for (Node candidate : nodes) { - Optional<String> reason = shouldPark(candidate, allNodes); - if (reason.isPresent()) { - List<String> unparkedChildren = candidate.type().isHost() ? - allNodes.childrenOf(candidate) - .not() - .state(Node.State.parked) - .mapToList(Node::hostname) : - List.of(); - - if (unparkedChildren.isEmpty()) { - nodeRepository.nodes().park(candidate.hostname(), true, Agent.FailedExpirer, - "Parked by FailedExpirer due to " + reason.get()); - } else { - log.info(String.format("Expired failed node %s was not parked because of unparked children: %s", - candidate.hostname(), String.join(", ", unparkedChildren))); - } + private Optional<Node> recycle(Node node, List<Node> children, NodeList allNodes) { + Optional<String> reason = shouldPark(node, allNodes); + if (reason.isPresent()) { + List<String> unparkedChildren = children.stream() + .filter(child -> child.state() != Node.State.parked) + .map(Node::hostname) + .toList(); + if (unparkedChildren.isEmpty()) { + return Optional.of(nodeRepository.nodes().park(node.hostname(), true, Agent.FailedExpirer, + "Parked by FailedExpirer due to " + reason.get())); } else { - nodesToRecycle.add(candidate); + log.info(String.format("Expired failed node %s was not parked because of unparked children: %s", + node.hostname(), String.join(", ", unparkedChildren))); + return Optional.empty(); } + } else { + return Optional.of(nodeRepository.nodes().deallocate(node, Agent.FailedExpirer, "Expired by FailedExpirer")); } - nodeRepository.nodes().deallocate(nodesToRecycle, Agent.FailedExpirer, "Expired by FailedExpirer"); } /** Returns whether the node should be parked instead of recycled */ diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainer.java index e9e3fd5179a..d70ee825860 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainer.java @@ -72,7 +72,14 @@ public class HostCapacityMaintainer extends NodeRepositoryMaintainer { protected double maintain() { List<Node> provisionedSnapshot; try { - provisionedSnapshot = provision(nodeRepository().nodes().list()); + NodeList nodes; + // Host and child nodes are written in separate transactions, but both are written while holding the + // unallocated lock. Hold the unallocated lock while reading nodes to ensure we get all the children + // of newly provisioned hosts. + try (Mutex ignored = nodeRepository().nodes().lockUnallocated()) { + nodes = nodeRepository().nodes().list(); + } + provisionedSnapshot = provision(nodes); } catch (NodeAllocationException | IllegalStateException e) { log.log(Level.WARNING, "Failed to allocate preprovisioned capacity and/or find excess hosts: " + e.getMessage()); return 0; // avoid removing excess hosts @@ -85,16 +92,7 @@ public class HostCapacityMaintainer extends NodeRepositoryMaintainer { } private double markForRemoval(List<Node> provisionedSnapshot) { - // Group nodes by parent; no parent means it's a host. - Map<Optional<String>, List<Node>> nodesByParent = provisionedSnapshot.stream().collect(groupingBy(Node::parentHostname)); - - // Find all hosts that we once thought were empty (first clause), or whose children are now all removable (second clause). - List<Node> emptyHosts = nodesByParent.get(Optional.<String>empty()).stream() - .filter(host -> host.hostEmptyAt().isPresent() - || nodesByParent.getOrDefault(Optional.of(host.hostname()), List.of()) - .stream().allMatch(HostCapacityMaintainer::canDeprovision)) - .toList(); - + List<Node> emptyHosts = findEmptyOrRemovableHosts(provisionedSnapshot); if (emptyHosts.isEmpty()) return 1; int attempts = 0, success = 0; @@ -108,18 +106,16 @@ public class HostCapacityMaintainer extends NodeRepositoryMaintainer { // Re-read all nodes under lock and compute the candidates for removal. The actual nodes we want // to mark for removal is the intersection with typeEmptyHosts, which excludes the preprovisioned hosts. Map<Optional<String>, List<Node>> currentNodesByParent = nodeRepository().nodes().list().stream().collect(groupingBy(Node::parentHostname)); - List<Node> candidateHosts = new ArrayList<>(currentNodesByParent.get(Optional.<String>empty())); + List<Node> candidateHosts = new ArrayList<>(getHosts(currentNodesByParent)); candidateHosts.retainAll(typeEmptyHosts); for (Node host : candidateHosts) { attempts++; // Any hosts that are no longer empty should be marked as such, and excluded from removal. - if (currentNodesByParent.getOrDefault(Optional.of(host.hostname()), List.of()) - .stream().anyMatch(n -> ! canDeprovision(n))) { - if (host.hostEmptyAt().isPresent()) { - nodeRepository().nodes().write(host.withHostEmptyAt(null), lock); - } + if (currentNodesByParent.getOrDefault(Optional.of(host.hostname()), List.of()).stream().anyMatch(n -> ! canDeprovision(n)) + && host.hostEmptyAt().isPresent()) { + nodeRepository().nodes().write(host.withHostEmptyAt(null), lock); } // If the host is still empty, we can mark it as empty now, or mark it for removal if it has already expired. else { @@ -282,11 +278,38 @@ public class HostCapacityMaintainer extends NodeRepositoryMaintainer { nodeResources, nodeRepository().clock().instant())) .toList(); - } private static NodeResources toNodeResources(ClusterCapacity clusterCapacity) { - return new NodeResources(clusterCapacity.vcpu(), clusterCapacity.memoryGb(), clusterCapacity.diskGb(), - clusterCapacity.bandwidthGbps()); + return new NodeResources(clusterCapacity.vcpu(), + clusterCapacity.memoryGb(), + clusterCapacity.diskGb(), + clusterCapacity.bandwidthGbps(), + NodeResources.DiskSpeed.valueOf(clusterCapacity.diskSpeed()), + NodeResources.StorageType.valueOf(clusterCapacity.storageType()), + NodeResources.Architecture.valueOf(clusterCapacity.architecture())); + } + + private static List<Node> findEmptyOrRemovableHosts(List<Node> provisionedSnapshot) { + // Group nodes by parent; no parent means it's a host. + var nodesByParent = provisionedSnapshot.stream().collect(groupingBy(Node::parentHostname)); + + // Find all hosts that we once thought were empty (first clause), or whose children are now all removable (second clause). + return getHosts(nodesByParent).stream() + .filter(host -> host.hostEmptyAt().isPresent() || allChildrenCanBeDeprovisioned(nodesByParent, host)) + .toList(); + } + + private static List<Node> getHosts(Map<Optional<String>, List<Node>> nodesByParent) { + return nodesByParent.get(Optional.<String>empty()); } + + private static List<Node> getChildren(Map<Optional<String>, List<Node>> nodesByParent, Node host) { + return nodesByParent.getOrDefault(Optional.of(host.hostname()), List.of()); + } + + private static boolean allChildrenCanBeDeprovisioned(Map<Optional<String>, List<Node>> nodesByParent, Node host) { + return getChildren(nodesByParent, host).stream().allMatch(HostCapacityMaintainer::canDeprovision); + } + } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostResumeProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostResumeProvisioner.java index e1624183607..fe89ba17469 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostResumeProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/HostResumeProvisioner.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision.maintenance; +import com.yahoo.config.provision.CloudAccount; import com.yahoo.config.provision.NodeType; import com.yahoo.jdisc.Metric; import com.yahoo.transaction.Mutex; @@ -49,18 +50,15 @@ public class HostResumeProvisioner extends NodeRepositoryMaintainer { NodeList hosts = allNodes.state(Node.State.provisioned).nodeType(NodeType.host, NodeType.confighost, NodeType.controllerhost); int failures = 0; for (Node host : hosts) { - NodeList children = allNodes.childrenOf(host); try { - log.log(Level.INFO, "Provisioning " + host.hostname() + " with " + children.size() + " children"); - HostIpConfig hostIpConfig = hostProvisioner.provision(host, children.asSet()); - setIpConfig(host, children, hostIpConfig); + HostIpConfig hostIpConfig = hostProvisioner.provision(host); + setIpConfig(host, hostIpConfig); } catch (IllegalArgumentException | IllegalStateException e) { - log.log(Level.INFO, "Could not provision " + host.hostname() + " with " + children.size() + " children, will retry in " + + log.log(Level.INFO, "Could not provision " + host.hostname() + ", will retry in " + interval() + ": " + Exceptions.toMessageString(e)); } catch (FatalProvisioningException e) { failures++; - log.log(Level.SEVERE, "Failed to provision " + host.hostname() + " with " + children.size() + - " children, failing out the host recursively", e); + log.log(Level.SEVERE, "Failed to provision " + host.hostname() + ", failing out the host recursively", e); nodeRepository().nodes().failOrMarkRecursively( host.hostname(), Agent.HostResumeProvisioner, "Failed by HostResumeProvisioner due to provisioning failure"); } catch (RuntimeException e) { @@ -75,19 +73,17 @@ public class HostResumeProvisioner extends NodeRepositoryMaintainer { return asSuccessFactorDeviation(hosts.size(), failures); } - private void setIpConfig(Node host, NodeList children, HostIpConfig hostIpConfig) { + private void setIpConfig(Node host, HostIpConfig hostIpConfig) { if (hostIpConfig.isEmpty()) return; - NodeList nodes = NodeList.of(host).and(children); - for (var node : nodes) { - verifyDns(node, hostIpConfig.require(node.hostname())); - } + hostIpConfig.asMap().forEach((hostname, ipConfig) -> + verifyDns(hostname, host.type(), host.cloudAccount(), ipConfig)); nodeRepository().nodes().setIpConfig(hostIpConfig); } /** Verify DNS configuration of given node */ - private void verifyDns(Node node, IP.Config ipConfig) { + private void verifyDns(String hostname, NodeType hostType, CloudAccount cloudAccount, IP.Config ipConfig) { for (String ipAddress : ipConfig.primary()) { - IP.verifyDns(node.hostname(), ipAddress, node.type(), nodeRepository().nameResolver(), node.cloudAccount(), nodeRepository().zone()); + IP.verifyDns(hostname, ipAddress, hostType, nodeRepository().nameResolver(), cloudAccount, nodeRepository().zone()); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java index aa7aac34389..503ac4be86c 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/InactiveExpirer.java @@ -3,6 +3,8 @@ package com.yahoo.vespa.hosted.provision.maintenance; import com.yahoo.jdisc.Metric; import com.yahoo.vespa.hosted.provision.Node; +import com.yahoo.vespa.hosted.provision.Node.State; +import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.node.History; @@ -40,9 +42,9 @@ public class InactiveExpirer extends Expirer { @Override protected void expire(List<Node> expired) { - expired.forEach(node -> { - nodeRepository.nodes().deallocate(node, Agent.InactiveExpirer, "Expired by InactiveExpirer"); - }); + nodeRepository.nodes().performOn(NodeList.copyOf(expired), + node -> node.state() == State.inactive && isExpired(node), + (node, lock) -> nodeRepository.nodes().deallocate(node, Agent.InactiveExpirer, "Expired by InactiveExpirer")); } @Override diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java index 6f06a2ac22e..2484f496ece 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/ReservationExpirer.java @@ -25,6 +25,8 @@ public class ReservationExpirer extends Expirer { } @Override - protected void expire(List<Node> expired) { nodeRepository().nodes().deallocate(expired, Agent.ReservationExpirer, "Expired by ReservationExpirer"); } + protected void expire(List<Node> expired) { + nodeRepository().nodes().deallocate(expired, Agent.ReservationExpirer, "Expired by ReservationExpirer"); + } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java index cc7db3c138a..1ff6d2b300d 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/IP.java @@ -113,7 +113,7 @@ public record IP() { * * @throws IllegalArgumentException if there are IP conflicts with existing nodes */ - public static List<Node> verify(List<Node> nodes, LockedNodeList allNodes) { + public static LockedNodeList verify(List<Node> nodes, LockedNodeList allNodes) { NodeList sortedNodes = allNodes.sortedBy(Comparator.comparing(Node::hostname)); for (var node : nodes) { for (var other : sortedNodes) { @@ -135,7 +135,7 @@ public record IP() { other.hostname()); } } - return nodes; + return allNodes.childList(nodes); } /** Returns whether IP address of existing node can be assigned to node */ @@ -152,7 +152,7 @@ public record IP() { } public static Node verify(Node node, LockedNodeList allNodes) { - return verify(List.of(node), allNodes).get(0); + return verify(List.of(node), allNodes).asList().get(0); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java index b10a371e8bd..490e7b9ac33 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java @@ -1,7 +1,6 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision.node; -import com.yahoo.collections.ListMap; import com.yahoo.component.Version; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ApplicationTransaction; @@ -10,6 +9,7 @@ import com.yahoo.config.provision.Flavor; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.Zone; +import com.yahoo.time.TimeBudget; import com.yahoo.transaction.Mutex; import com.yahoo.transaction.NestedTransaction; import com.yahoo.vespa.applicationmodel.HostName; @@ -17,6 +17,7 @@ import com.yahoo.vespa.applicationmodel.InfrastructureApplication; import com.yahoo.vespa.hosted.provision.LockedNodeList; import com.yahoo.vespa.hosted.provision.NoSuchNodeException; import com.yahoo.vespa.hosted.provision.Node; +import com.yahoo.vespa.hosted.provision.Node.State; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeMutex; import com.yahoo.vespa.hosted.provision.applications.Applications; @@ -31,20 +32,26 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; +import java.util.Collection; +import java.util.Comparator; import java.util.EnumSet; +import java.util.HashSet; +import java.util.Iterator; import java.util.List; -import java.util.Map; -import java.util.Objects; +import java.util.NavigableSet; import java.util.Optional; import java.util.Set; +import java.util.TreeSet; import java.util.function.BiFunction; +import java.util.function.Function; import java.util.function.Predicate; import java.util.logging.Level; import java.util.logging.Logger; -import java.util.stream.Collectors; -import java.util.stream.Stream; import static com.yahoo.vespa.hosted.provision.restapi.NodePatcher.DROP_DOCUMENTS_REPORT; +import static java.util.Comparator.comparing; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.joining; /** * The nodes in the node repo and their state transitions @@ -148,7 +155,7 @@ public class Nodes { if (existing.isPresent()) throw new IllegalStateException("Cannot add " + node + ": A node with this name already exists"); } - return db.addNodesInState(nodes.asList(), Node.State.reserved, Agent.system); + return db.addNodesInState(nodes, Node.State.reserved, Agent.system); } /** @@ -157,7 +164,8 @@ public class Nodes { * with the history of that node. */ public List<Node> addNodes(List<Node> nodes, Agent agent) { - try (Mutex lock = lockUnallocated()) { + try (NodeMutexes existingNodesLocks = lockAndGetAll(nodes, Optional.empty()); // Locks for any existing nodes we may remove. + Mutex allocationLock = lockUnallocated()) { List<Node> nodesToAdd = new ArrayList<>(); List<Node> nodesToRemove = new ArrayList<>(); for (int i = 0; i < nodes.size(); i++) { @@ -194,7 +202,7 @@ public class Nodes { } NestedTransaction transaction = new NestedTransaction(); db.removeNodes(nodesToRemove, transaction); - List<Node> resultingNodes = db.addNodesInState(IP.Config.verify(nodesToAdd, list(lock)), Node.State.provisioned, agent, transaction); + List<Node> resultingNodes = db.addNodesInState(IP.Config.verify(nodesToAdd, list(allocationLock)), Node.State.provisioned, agent, transaction); transaction.commit(); return resultingNodes; } @@ -218,7 +226,7 @@ public class Nodes { } /** Activate nodes. This method does <b>not</b> lock the node repository. */ - public List<Node> activate(List<Node> nodes, NestedTransaction transaction) { + public List<Node> activate(List<Node> nodes, ApplicationTransaction transaction) { return db.writeTo(Node.State.active, nodes, Agent.application, Optional.empty(), transaction); } @@ -229,8 +237,7 @@ public class Nodes { * @param reusable move the node directly to {@link Node.State#dirty} after removal */ public void setRemovable(NodeList nodes, boolean reusable) { - performOn(nodes, (node, mutex) -> write(node.with(node.allocation().get().removable(true, reusable)), - mutex)); + performOn(nodes, (node, mutex) -> write(node.with(node.allocation().get().removable(true, reusable)), mutex)); } /** @@ -239,7 +246,7 @@ public class Nodes { */ public List<Node> deactivate(List<Node> nodes, ApplicationTransaction transaction) { if ( ! zone.environment().isProduction() || zone.system().isCd()) - return deallocate(nodes, Agent.application, "Deactivated by application", transaction.nested()); + return deallocate(nodes, Agent.application, "Deactivated by application", transaction); NodeList nodeList = NodeList.copyOf(nodes); NodeList stateless = nodeList.stateless(); @@ -247,9 +254,9 @@ public class Nodes { NodeList statefulToInactive = stateful.not().reusable(); NodeList statefulToDirty = stateful.reusable(); List<Node> written = new ArrayList<>(); - written.addAll(deallocate(stateless.asList(), Agent.application, "Deactivated by application", transaction.nested())); - written.addAll(deallocate(statefulToDirty.asList(), Agent.application, "Deactivated by application (recycled)", transaction.nested())); - written.addAll(db.writeTo(Node.State.inactive, statefulToInactive.asList(), Agent.application, Optional.empty(), transaction.nested())); + written.addAll(deallocate(stateless.asList(), Agent.application, "Deactivated by application", transaction)); + written.addAll(deallocate(statefulToDirty.asList(), Agent.application, "Deactivated by application (recycled)", transaction)); + written.addAll(db.writeTo(Node.State.inactive, statefulToInactive.asList(), Agent.application, Optional.empty(), transaction)); return written; } @@ -258,21 +265,9 @@ public class Nodes { * transaction commits. */ public List<Node> fail(List<Node> nodes, ApplicationTransaction transaction) { - return fail(nodes, Agent.application, "Failed by application", transaction.nested()); - } - - public List<Node> fail(List<Node> nodes, Agent agent, String reason) { - NestedTransaction transaction = new NestedTransaction(); - nodes = fail(nodes, agent, reason, transaction); - transaction.commit(); - return nodes; - } - - private List<Node> fail(List<Node> nodes, Agent agent, String reason, NestedTransaction transaction) { - nodes = nodes.stream() - .map(n -> n.withWantToFail(false, agent, clock.instant())) - .toList(); - return db.writeTo(Node.State.failed, nodes, agent, Optional.of(reason), transaction); + return db.writeTo(Node.State.failed, + nodes.stream().map(n -> n.withWantToFail(false, Agent.application, clock.instant())).toList(), + Agent.application, Optional.of("Failed by application"), transaction); } /** Move nodes to the dirty state */ @@ -282,40 +277,48 @@ public class Nodes { public List<Node> deallocateRecursively(String hostname, Agent agent, String reason) { Node nodeToDirty = node(hostname).orElseThrow(() -> new NoSuchNodeException("Could not deallocate " + hostname + ": Node not found")); - - List<Node> nodesToDirty = - (nodeToDirty.type().isHost() ? - Stream.concat(list().childrenOf(hostname).asList().stream(), Stream.of(nodeToDirty)) : - Stream.of(nodeToDirty)).filter(node -> node.state() != Node.State.dirty).toList(); - List<String> hostnamesNotAllowedToDirty = nodesToDirty.stream() - .filter(node -> node.state() != Node.State.provisioned) - .filter(node -> node.state() != Node.State.failed) - .filter(node -> node.state() != Node.State.parked) - .filter(node -> node.state() != Node.State.breakfixed) - .map(Node::hostname).toList(); - if ( ! hostnamesNotAllowedToDirty.isEmpty()) - illegal("Could not deallocate " + nodeToDirty + ": " + - hostnamesNotAllowedToDirty + " are not in states [provisioned, failed, parked, breakfixed]"); - - return nodesToDirty.stream().map(node -> deallocate(node, agent, reason)).toList(); + List<Node> nodesToDirty = new ArrayList<>(); + try (RecursiveNodeMutexes locked = lockAndGetRecursively(hostname, Optional.empty())) { + for (NodeMutex child : locked.children()) + if (child.node().state() != Node.State.dirty) + nodesToDirty.add(child.node()); + + if (locked.parent().node().state() != State.dirty) + nodesToDirty.add(locked.parent().node()); + + List<String> hostnamesNotAllowedToDirty = nodesToDirty.stream() + .filter(node -> node.state() != Node.State.provisioned) + .filter(node -> node.state() != Node.State.failed) + .filter(node -> node.state() != Node.State.parked) + .filter(node -> node.state() != Node.State.breakfixed) + .map(Node::hostname).toList(); + if ( ! hostnamesNotAllowedToDirty.isEmpty()) + illegal("Could not deallocate " + nodeToDirty + ": " + + hostnamesNotAllowedToDirty + " are not in states [provisioned, failed, parked, breakfixed]"); + + return nodesToDirty.stream().map(node -> deallocate(node, agent, reason)).toList(); + } } /** - * Set a node dirty or parked, allowed if it is in the provisioned, inactive, failed or parked state. + * Set a node dirty or parked, allowed if it is in the provisioned, inactive, failed or parked state. * Use this to clean newly provisioned nodes or to recycle failed nodes which have been repaired or put on hold. */ public Node deallocate(Node node, Agent agent, String reason) { - NestedTransaction transaction = new NestedTransaction(); - Node deallocated = deallocate(node, agent, reason, transaction); - transaction.commit(); - return deallocated; + try (NodeMutex locked = lockAndGetRequired(node)) { + NestedTransaction transaction = new NestedTransaction(); + Node deallocated = deallocate(locked.node(), agent, reason, transaction); + transaction.commit(); + return deallocated; + } } - public List<Node> deallocate(List<Node> nodes, Agent agent, String reason, NestedTransaction transaction) { - return nodes.stream().map(node -> deallocate(node, agent, reason, transaction)).toList(); + public List<Node> deallocate(List<Node> nodes, Agent agent, String reason, ApplicationTransaction transaction) { + return nodes.stream().map(node -> deallocate(node, agent, reason, transaction.nested())).toList(); } - public Node deallocate(Node node, Agent agent, String reason, NestedTransaction transaction) { + // Be sure to hold the right lock! + private Node deallocate(Node node, Agent agent, String reason, NestedTransaction transaction) { if (parkOnDeallocationOf(node, agent)) { return park(node.hostname(), false, agent, reason, transaction); } else { @@ -339,7 +342,9 @@ public class Nodes { } public Node fail(String hostname, boolean forceDeprovision, Agent agent, String reason) { - return move(hostname, Node.State.failed, agent, forceDeprovision, Optional.of(reason)); + try (NodeMutex lock = lockAndGetRequired(hostname)) { + return move(hostname, Node.State.failed, agent, forceDeprovision, Optional.of(reason), lock); + } } /** @@ -350,14 +355,16 @@ public class Nodes { * @return all the nodes that were changed by this request */ public List<Node> failOrMarkRecursively(String hostname, Agent agent, String reason) { - NodeList children = list().childrenOf(hostname); - List<Node> changed = performOn(children, (node, lock) -> failOrMark(node, agent, reason, lock)); - - if (children.state(Node.State.active).isEmpty()) - changed.add(move(hostname, Node.State.failed, agent, false, Optional.of(reason))); - else - changed.addAll(performOn(NodeList.of(node(hostname).orElseThrow()), (node, lock) -> failOrMark(node, agent, reason, lock))); + List<Node> changed = new ArrayList<>(); + try (RecursiveNodeMutexes nodes = lockAndGetRecursively(hostname, Optional.empty())) { + for (NodeMutex child : nodes.children()) + changed.add(failOrMark(child.node(), agent, reason, child)); + if (changed.stream().noneMatch(child -> child.state() == Node.State.active)) + changed.add(move(hostname, Node.State.failed, agent, false, Optional.of(reason), nodes.parent())); + else + changed.add(failOrMark(nodes.parent().node(), agent, reason, nodes.parent())); + } return changed; } @@ -367,12 +374,14 @@ public class Nodes { write(node, lock); return node; } else { - return move(node.hostname(), Node.State.failed, agent, false, Optional.of(reason)); + return move(node.hostname(), Node.State.failed, agent, false, Optional.of(reason), lock); } } /** Update IP config for nodes in given config */ public void setIpConfig(HostIpConfig hostIpConfig) { + // Ideally this should hold the unallocated lock over the entire method, but unallocated lock must be taken + // after the application lock, making this impossible Predicate<Node> nodeInConfig = (node) -> hostIpConfig.contains(node.hostname()); performOn(nodeInConfig, (node, lock) -> { IP.Config ipConfig = hostIpConfig.require(node.hostname()); @@ -387,10 +396,12 @@ public class Nodes { * @throws NoSuchNodeException if the node is not found */ public Node park(String hostname, boolean forceDeprovision, Agent agent, String reason) { - NestedTransaction transaction = new NestedTransaction(); - Node parked = park(hostname, forceDeprovision, agent, reason, transaction); - transaction.commit(); - return parked; + try (NodeMutex locked = lockAndGetRequired(hostname)) { + NestedTransaction transaction = new NestedTransaction(); + Node parked = park(hostname, forceDeprovision, agent, reason, transaction); + transaction.commit(); + return parked; + } } private Node park(String hostname, boolean forceDeprovision, Agent agent, String reason, NestedTransaction transaction) { @@ -413,36 +424,38 @@ public class Nodes { * @throws NoSuchNodeException if the node is not found */ public Node reactivate(String hostname, Agent agent, String reason) { - return move(hostname, Node.State.active, agent, false, Optional.of(reason)); + try (NodeMutex lock = lockAndGetRequired(hostname)) { + return move(hostname, Node.State.active, agent, false, Optional.of(reason), lock); + } } /** * Moves a host to breakfixed state, removing any children. */ public List<Node> breakfixRecursively(String hostname, Agent agent, String reason) { - Node node = requireNode(hostname); - try (Mutex lock = lockUnallocated()) { - requireBreakfixable(node); + try (RecursiveNodeMutexes locked = lockAndGetRecursively(hostname, Optional.empty())) { + requireBreakfixable(locked.parent().node()); NestedTransaction transaction = new NestedTransaction(); - List<Node> removed = removeChildren(node, false, transaction); - removed.add(move(node.hostname(), Node.State.breakfixed, agent, false, Optional.of(reason), transaction)); + removeChildren(locked, false, transaction); + move(hostname, Node.State.breakfixed, agent, false, Optional.of(reason), transaction); transaction.commit(); - return removed; + return locked.nodes().nodes().stream().map(NodeMutex::node).toList(); } } private List<Node> moveRecursively(String hostname, Node.State toState, Agent agent, Optional<String> reason) { - NestedTransaction transaction = new NestedTransaction(); - List<Node> moved = list().childrenOf(hostname).asList().stream() - .map(child -> move(child.hostname(), toState, agent, false, reason, transaction)) - .collect(Collectors.toCollection(ArrayList::new)); - moved.add(move(hostname, toState, agent, false, reason, transaction)); - transaction.commit(); - return moved; + try (RecursiveNodeMutexes locked = lockAndGetRecursively(hostname, Optional.empty())) { + List<Node> moved = new ArrayList<>(); + NestedTransaction transaction = new NestedTransaction(); + for (NodeMutex node : locked.nodes().nodes()) + moved.add(move(node.node().hostname(), toState, agent, false, reason, transaction)); + transaction.commit(); + return moved; + } } /** Move a node to given state */ - private Node move(String hostname, Node.State toState, Agent agent, boolean forceDeprovision, Optional<String> reason) { + private Node move(String hostname, Node.State toState, Agent agent, boolean forceDeprovision, Optional<String> reason, Mutex lock) { NestedTransaction transaction = new NestedTransaction(); Node moved = move(hostname, toState, agent, forceDeprovision, reason, transaction); transaction.commit(); @@ -451,8 +464,7 @@ public class Nodes { /** Move a node to given state as part of a transaction */ private Node move(String hostname, Node.State toState, Agent agent, boolean forceDeprovision, Optional<String> reason, NestedTransaction transaction) { - // TODO: Work out a safe lock acquisition strategy for moves. Lock is only held while adding operations to - // transaction, but lock must also be held while committing + // TODO: This lock is already held here, but we still need to read the node. Perhaps change to requireNode(hostname) later. try (NodeMutex lock = lockAndGetRequired(hostname)) { Node node = lock.node(); if (toState == Node.State.active) { @@ -521,17 +533,18 @@ public class Nodes { } public List<Node> removeRecursively(Node node, boolean force) { - try (Mutex lock = lockUnallocated()) { - requireRemovable(node, false, force); + try (RecursiveNodeMutexes locked = lockAndGetRecursively(node.hostname(), Optional.empty())) { + requireRemovable(locked.parent().node(), false, force); NestedTransaction transaction = new NestedTransaction(); List<Node> removed; - if (!node.type().isHost()) { + if ( ! node.type().isHost()) { removed = List.of(node); db.removeNodes(removed, transaction); - } else { - removed = removeChildren(node, force, transaction); + } + else { + removeChildren(locked, force, transaction); move(node.hostname(), Node.State.deprovisioned, Agent.system, false, Optional.empty(), transaction); - removed.add(node); + removed = locked.nodes().nodes().stream().map(NodeMutex::node).toList(); } transaction.commit(); return removed; @@ -540,20 +553,22 @@ public class Nodes { /** Forgets a deprovisioned node. This removes all traces of the node in the node repository. */ public void forget(Node node) { - if (node.state() != Node.State.deprovisioned) - throw new IllegalArgumentException(node + " must be deprovisioned before it can be forgotten"); - if (node.status().wantToRebuild()) - throw new IllegalArgumentException(node + " is rebuilding and cannot be forgotten"); - NestedTransaction transaction = new NestedTransaction(); - db.removeNodes(List.of(node), transaction); - transaction.commit(); + try (NodeMutex locked = lockAndGetRequired(node.hostname())) { + if (node.state() != Node.State.deprovisioned) + throw new IllegalArgumentException(node + " must be deprovisioned before it can be forgotten"); + if (node.status().wantToRebuild()) + throw new IllegalArgumentException(node + " is rebuilding and cannot be forgotten"); + NestedTransaction transaction = new NestedTransaction(); + db.removeNodes(List.of(node), transaction); + transaction.commit(); + } } - private List<Node> removeChildren(Node node, boolean force, NestedTransaction transaction) { - List<Node> children = list().childrenOf(node).asList(); + private void removeChildren(RecursiveNodeMutexes nodes, boolean force, NestedTransaction transaction) { + if (nodes.children().isEmpty()) return; + List<Node> children = nodes.children().stream().map(NodeMutex::node).toList(); children.forEach(child -> requireRemovable(child, true, force)); db.removeNodes(children, transaction); - return new ArrayList<>(children); } /** @@ -715,8 +730,8 @@ public class Nodes { return db.writeTo(nodes, Agent.system, Optional.empty()); } - private List<Node> performOn(Predicate<Node> filter, BiFunction<Node, Mutex, Node> action) { - return performOn(list().matching(filter), action); + public List<Node> performOn(Predicate<Node> filter, BiFunction<Node, Mutex, Node> action) { + return performOn(list(), filter, action); } /** @@ -725,35 +740,33 @@ public class Nodes { * @param action the action to perform * @return the set of nodes on which the action was performed, as they became as a result of the operation */ - private List<Node> performOn(NodeList nodes, BiFunction<Node, Mutex, Node> action) { - List<Node> unallocatedNodes = new ArrayList<>(); - ListMap<ApplicationId, Node> allocatedNodes = new ListMap<>(); + public List<Node> performOn(NodeList nodes, BiFunction<Node, Mutex, Node> action) { + return performOn(nodes, __ -> true, action); + } - // Group matching nodes by the lock needed - for (Node node : nodes) { - Optional<ApplicationId> applicationId = applicationIdForLock(node); - if (applicationId.isPresent()) - allocatedNodes.put(applicationId.get(), node); - else - unallocatedNodes.add(node); - } + public List<Node> performOn(NodeList nodes, Predicate<Node> filter, BiFunction<Node, Mutex, Node> action) { + List<Node> resultingNodes = new ArrayList<>(); + nodes.stream().collect(groupingBy(Nodes::applicationIdForLock)) + .forEach((applicationId, nodeList) -> { // Grouped only to reduce number of lock acquire/release cycles. + try (NodeMutexes locked = lockAndGetAll(nodeList, Optional.empty())) { + for (NodeMutex node : locked.nodes()) + if (filter.test(node.node())) + resultingNodes.add(action.apply(node.node(), node)); + } + }); + return resultingNodes; + } + + public List<Node> performOnRecursively(NodeList parents, Predicate<RecursiveNodeMutexes> filter, Function<RecursiveNodeMutexes, List<Node>> action) { + for (Node node : parents) + if (node.parentHostname().isPresent()) + throw new IllegalArgumentException(node + " is not a parent host"); - // Perform operation while holding appropriate lock List<Node> resultingNodes = new ArrayList<>(); - try (Mutex lock = lockUnallocated()) { - for (Node node : unallocatedNodes) { - Optional<Node> currentNode = db.readNode(node.hostname()); // Re-read while holding lock - if (currentNode.isEmpty()) continue; - resultingNodes.add(action.apply(currentNode.get(), lock)); - } - } - for (Map.Entry<ApplicationId, List<Node>> applicationNodes : allocatedNodes.entrySet()) { - try (Mutex lock = applications.lock(applicationNodes.getKey())) { - for (Node node : applicationNodes.getValue()) { - Optional<Node> currentNode = db.readNode(node.hostname()); // Re-read while holding lock - if (currentNode.isEmpty()) continue; - resultingNodes.add(action.apply(currentNode.get(), lock)); - } + for (Node parent : parents) { + try (RecursiveNodeMutexes locked = lockAndGetRecursively(parent.hostname(), Optional.empty())) { + if (filter.test(locked)) + resultingNodes.addAll(action.apply(locked)); } } return resultingNodes; @@ -816,9 +829,7 @@ public class Nodes { return Optional.empty(); } - if (node.type() != NodeType.tenant || - Objects.equals(freshNode.get().allocation().map(Allocation::owner), - staleNode.allocation().map(Allocation::owner))) { + if (applicationIdForLock(freshNode.get()).equals(applicationIdForLock(staleNode))) { NodeMutex nodeMutex = new NodeMutex(freshNode.get(), lockToClose); lockToClose = null; return Optional.of(nodeMutex); @@ -879,6 +890,168 @@ public class Nodes { return node(hostname).orElseThrow(() -> new NoSuchNodeException("No node with hostname '" + hostname + "'")); } + /** + * Locks the children of the given node, the node itself, and finally takes the unallocated lock. + * <br> + * When taking multiple locks, it's crucial that we always take them in the same order, to avoid deadlocks. + * We want to take the most contended locks last, so that we don't block other operations for longer than necessary. + * This method does that, by first taking the locks for any children the given node may have, and then the node itself. + * (This is enforced by taking host locks after tenant node locks, in {@link #lockAndGetAll(Collection, Optional)}.) + * Finally, the allocation lock is taken, to ensure no new children are added while we hold this snapshot. + * Unfortunately, since that lock is taken last, we may detect new nodes after taking it, and then we have to retry. + * Closing the returned {@link RecursiveNodeMutexes} will release all the locks, and the locks should not be closed elsewhere. + */ + public RecursiveNodeMutexes lockAndGetRecursively(String hostname, Optional<Duration> timeout) { + TimeBudget budget = TimeBudget.fromNow(clock, timeout.orElse(Duration.ofMinutes(2))); + Set<Node> children = new HashSet<>(list().childrenOf(hostname).asList()); + Optional<Node> node = node(hostname); + + int attempts = 5; // We'll retry locking the whole list of children this many times, in case new children appear. + for (int attempt = 0; attempt < attempts; attempt++) { + NodeMutexes mutexes = null; + Mutex unallocatedLock = null; + try { + // First, we lock all the children, and the host; then we take the allocation lock to ensure our snapshot is valid. + List<Node> nodes = new ArrayList<>(children.size() + 1); + nodes.addAll(children); + node.ifPresent(nodes::add); + mutexes = lockAndGetAll(nodes, budget.timeLeftOrThrow()); + unallocatedLock = db.lockInactive(budget.timeLeftOrThrow().get()); + RecursiveNodeMutexes recursive = new RecursiveNodeMutexes(hostname, mutexes, unallocatedLock); + Set<Node> freshChildren = list().childrenOf(hostname).asSet(); + Optional<Node> freshNode = recursive.parent.map(NodeMutex::node); + if (children.equals(freshChildren) && node.equals(freshNode)) { + // No new nodes have appeared, and none will now, so we have a consistent snapshot. + if (node.isEmpty() && ! children.isEmpty()) + throw new IllegalStateException("node '" + hostname + "' was not found, but it has children: " + children); + + mutexes = null; + unallocatedLock = null; + return recursive; + } + else { + // New nodes have appeared, so we need to let go of the locks and try again with the new set of nodes. + children = freshChildren; + node = freshNode; + } + } + finally { + if (unallocatedLock != null) unallocatedLock.close(); + if (mutexes != null) mutexes.close(); + } + } + throw new IllegalStateException("giving up (after " + attempts + " attempts) fetching an up to " + + "date recursive node set under lock for node " + hostname); + } + + /** Locks all nodes in the given list, in a universal order, and returns the locks and nodes required. */ + public NodeMutexes lockAndRequireAll(Collection<Node> nodes, Optional<Duration> timeout) { + return lockAndGetAll(nodes, timeout, true); + } + + /** Locks all nodes in the given list, in a universal order, and returns the locks and nodes acquired. */ + public NodeMutexes lockAndGetAll(Collection<Node> nodes, Optional<Duration> timeout) { + return lockAndGetAll(nodes, timeout, false); + } + + /** Locks all nodes in the given list, in a universal order, and returns the locks and nodes. */ + private NodeMutexes lockAndGetAll(Collection<Node> nodes, Optional<Duration> timeout, boolean required) { + TimeBudget budget = TimeBudget.fromNow(clock, timeout.orElse(Duration.ofMinutes(2))); + Comparator<Node> universalOrder = (a, b) -> { + Optional<ApplicationId> idA = applicationIdForLock(a); + Optional<ApplicationId> idB = applicationIdForLock(b); + if (idA.isPresent() != idB.isPresent()) return idA.isPresent() ? -1 : 1; // Allocated nodes first. + if (a.type() != b.type()) return a.type().compareTo(b.type()); // Tenant nodes first among those. + if ( ! idA.equals(idB)) return idA.get().compareTo(idB.get()); // Sort primarily by tenant owner id. + return a.hostname().compareTo(b.hostname()); // Sort secondarily by hostname. + }; + NavigableSet<NodeMutex> locked = new TreeSet<>(comparing(NodeMutex::node, universalOrder)); + NavigableSet<Node> unlocked = new TreeSet<>(universalOrder); + unlocked.addAll(nodes); + try { + int attempts = 10; // We'll accept getting the wrong lock at most this many times before giving up. + for (int attempt = 0; attempt < attempts; ) { + if (unlocked.isEmpty()) { + NodeMutexes mutexes = new NodeMutexes(List.copyOf(locked)); + locked.clear(); + return mutexes; + } + + // If the first node is now earlier in lock order than some other locks we have, we need to close those and re-acquire them. + Node next = unlocked.pollFirst(); + Set<NodeMutex> outOfOrder = locked.tailSet(new NodeMutex(next, () -> { }), false); + NodeMutexes.close(outOfOrder.iterator()); + for (NodeMutex node : outOfOrder) unlocked.add(node.node()); + outOfOrder.clear(); + + Mutex lock = lock(next, budget.timeLeftOrThrow()); + try { + Optional<Node> fresh = node(next.hostname()); + if (fresh.isEmpty()) { + if (required) throw new NoSuchNodeException("No node with hostname '" + next.hostname() + "'"); + continue; // Node is gone; skip to close lock. + } + + if (applicationIdForLock(fresh.get()).equals(applicationIdForLock(next))) { + // We held the right lock, so this node is ours now. + locked.add(new NodeMutex(fresh.get(), lock)); + lock = null; + } + else { + // We held the wrong lock, and need to try again. + ++attempt; + unlocked.add(fresh.get()); + } + } + finally { + // If we didn't hold the right lock, we must close the wrong one before we continue. + if (lock != null) lock.close(); + } + } + throw new IllegalStateException("giving up (after " + attempts + " extra attempts) to lock nodes: " + + nodes.stream().map(Node::hostname).collect(joining(", "))); + } + finally { + // If we didn't manage to lock all nodes, we must close the ones we did lock before we throw. + NodeMutexes.close(locked.iterator()); + } + } + + /** A node with their locks, acquired in a universal order. */ + public record NodeMutexes(List<NodeMutex> nodes) implements AutoCloseable { + @Override public void close() { close(nodes.iterator()); } + private static void close(Iterator<NodeMutex> nodes) { + if (nodes.hasNext()) try (NodeMutex node = nodes.next()) { close(nodes); } + } + } + + /** A parent node, all its children, their locks acquired in a universal order, and then the unallocated lock. */ + public static class RecursiveNodeMutexes implements AutoCloseable { + + private final String hostname; + private final NodeMutexes nodes; + private final Mutex unallocatedLock; + private final List<NodeMutex> children; + private final Optional<NodeMutex> parent; + + public RecursiveNodeMutexes(String hostname, NodeMutexes nodes, Mutex unallocatedLock) { + this.hostname = hostname; + this.nodes = nodes; + this.unallocatedLock = unallocatedLock; + this.children = nodes.nodes().stream().filter(node -> ! node.node().hostname().equals(hostname)).toList(); + this.parent = nodes.nodes().stream().filter(node -> node.node().hostname().equals(hostname)).findFirst(); + } + + /** Any children of the node. */ + public List<NodeMutex> children() { return children; } + /** The node itself, or throws if the node was not found. */ + public NodeMutex parent() { return parent.orElseThrow(() -> new NoSuchNodeException("No node with hostname '" + hostname + "'")); } + /** Empty if the node was not found, or the node, and any children. */ + public NodeMutexes nodes() { return nodes; } + /** Closes the allocation lock, and all the node locks. */ + @Override public void close() { try (nodes; unallocatedLock) { } } + } + /** Returns the application ID that should be used for locking when modifying this node */ private static Optional<ApplicationId> applicationIdForLock(Node node) { return switch (node.type()) { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java index fc008b7b9dc..037338cb2ed 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java @@ -18,6 +18,7 @@ import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.curator.recipes.CuratorCounter; import com.yahoo.vespa.curator.transaction.CuratorOperations; import com.yahoo.vespa.curator.transaction.CuratorTransaction; +import com.yahoo.vespa.hosted.provision.LockedNodeList; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.applications.Application; import com.yahoo.vespa.hosted.provision.archive.ArchiveUris; @@ -105,7 +106,7 @@ public class CuratorDb { } /** Adds a set of nodes. Rollbacks/fails transaction if any node is not in the expected state. */ - public List<Node> addNodesInState(List<Node> nodes, Node.State expectedState, Agent agent, NestedTransaction transaction) { + public List<Node> addNodesInState(LockedNodeList nodes, Node.State expectedState, Agent agent, NestedTransaction transaction) { CuratorTransaction curatorTransaction = db.newCuratorTransactionIn(transaction); for (Node node : nodes) { if (node.state() != expectedState) @@ -116,10 +117,10 @@ public class CuratorDb { curatorTransaction.add(CuratorOperations.create(nodePath(node).getAbsolute(), serialized)); } transaction.onCommitted(() -> nodes.forEach(node -> log.log(Level.INFO, "Added " + node))); - return nodes; + return nodes.asList(); } - public List<Node> addNodesInState(List<Node> nodes, Node.State expectedState, Agent agent) { + public List<Node> addNodesInState(LockedNodeList nodes, Node.State expectedState, Agent agent) { NestedTransaction transaction = new NestedTransaction(); List<Node> writtenNodes = addNodesInState(nodes, expectedState, agent, transaction); transaction.commit(); @@ -175,6 +176,7 @@ public class CuratorDb { return writtenNodes; } } + public Node writeTo(Node.State toState, Node node, Agent agent, Optional<String> reason) { return writeTo(toState, Collections.singletonList(node), agent, reason).get(0); } @@ -192,6 +194,12 @@ public class CuratorDb { */ public List<Node> writeTo(Node.State toState, List<Node> nodes, Agent agent, Optional<String> reason, + ApplicationTransaction transaction) { + return writeTo(toState, nodes, agent, reason, transaction.nested()); + } + + public List<Node> writeTo(Node.State toState, List<Node> nodes, + Agent agent, Optional<String> reason, NestedTransaction transaction) { if (nodes.isEmpty()) return nodes; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java index caf936e8aeb..c25f33bc8c2 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Activator.java @@ -88,7 +88,7 @@ class Activator { NodeList activeToRemove = oldActive.matching(node -> ! hostnames.contains(node.hostname())); remove(activeToRemove, transaction); // TODO: Pass activation time in this call and next line - nodeRepository.nodes().activate(newActive.asList(), transaction.nested()); // activate also continued active to update node state + nodeRepository.nodes().activate(newActive.asList(), transaction); // activate also continued active to update node state rememberResourceChange(transaction, generation, activationTime, oldActive.not().retired(), diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/FlavorConfigBuilder.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/FlavorConfigBuilder.java index 2bc5a0719d9..2e9cca21052 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/FlavorConfigBuilder.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/FlavorConfigBuilder.java @@ -5,12 +5,18 @@ import com.yahoo.config.provision.Flavor; import com.yahoo.config.provision.NodeFlavors; import com.yahoo.config.provisioning.FlavorsConfig; +import static com.yahoo.config.provision.Flavor.Type.BARE_METAL; +import static com.yahoo.config.provision.Flavor.Type.DOCKER_CONTAINER; import static com.yahoo.config.provision.NodeResources.Architecture; +import static com.yahoo.config.provision.NodeResources.Architecture.arm64; +import static com.yahoo.config.provision.NodeResources.Architecture.x86_64; /** * Simplifies creation of a node-repository config containing flavors. * This is needed because the config builder API is inconvenient. * + * Note: Flavors added will have fast disk and remote storage unless explicitly specified. + * * @author bratseth */ public class FlavorConfigBuilder { @@ -27,7 +33,7 @@ public class FlavorConfigBuilder { double disk, double bandwidth, Flavor.Type type) { - return addFlavor(flavorName, cpu, mem, disk, bandwidth, true, true, type, Architecture.x86_64, 0, 0); + return addFlavor(flavorName, cpu, mem, disk, bandwidth, true, true, type, x86_64, 0, 0); } public FlavorsConfig.Flavor.Builder addFlavor(String flavorName, @@ -69,31 +75,20 @@ public class FlavorConfigBuilder { /** Convenience method which creates a node flavors instance from a list of flavor names */ public static NodeFlavors createDummies(String... flavors) { - - FlavorConfigBuilder flavorConfigBuilder = new FlavorConfigBuilder(); + FlavorConfigBuilder builder = new FlavorConfigBuilder(); for (String flavorName : flavors) { - if (flavorName.equals("docker")) - flavorConfigBuilder.addFlavor(flavorName, 1., 30., 20., 1.5, Flavor.Type.DOCKER_CONTAINER); - else if (flavorName.equals("docker2")) - flavorConfigBuilder.addFlavor(flavorName, 2., 40., 40., 0.5, Flavor.Type.DOCKER_CONTAINER); - else if (flavorName.equals("host")) - flavorConfigBuilder.addFlavor(flavorName, 7., 100., 120., 5, Flavor.Type.BARE_METAL); - else if (flavorName.equals("host2")) - flavorConfigBuilder.addFlavor(flavorName, 16, 24, 100, 1, Flavor.Type.BARE_METAL); - else if (flavorName.equals("host3")) - flavorConfigBuilder.addFlavor(flavorName, 24, 64, 100, 10, Flavor.Type.BARE_METAL); - else if (flavorName.equals("host4")) - flavorConfigBuilder.addFlavor(flavorName, 48, 128, 1000, 10, Flavor.Type.BARE_METAL); - else if (flavorName.equals("devhost")) - flavorConfigBuilder.addFlavor(flavorName, 4., 80., 100, 10, Flavor.Type.BARE_METAL); - else if (flavorName.equals("arm64")) - flavorConfigBuilder.addFlavor(flavorName,2., 30., 20., 3, Flavor.Type.BARE_METAL, Architecture.arm64); - else if (flavorName.equals("gpu")) - flavorConfigBuilder.addFlavor(flavorName,4, 16, 125, 10, true, false, Flavor.Type.BARE_METAL, Architecture.x86_64, 1, 16); - else - flavorConfigBuilder.addFlavor(flavorName, 1., 30., 20., 3, Flavor.Type.BARE_METAL); + switch (flavorName) { + case "docker" -> builder.addFlavor(flavorName, 1., 30., 20., 1.5, DOCKER_CONTAINER); + case "host" -> builder.addFlavor(flavorName, 7., 100., 120., 5, BARE_METAL); + case "host2" -> builder.addFlavor(flavorName, 16, 24, 100, 1, BARE_METAL); + case "host3" -> builder.addFlavor(flavorName, 24, 64, 100, 10, BARE_METAL); + case "host4" -> builder.addFlavor(flavorName, 48, 128, 1000, 10, BARE_METAL); + case "arm64" -> builder.addFlavor(flavorName, 2., 30., 20., 3, BARE_METAL, arm64); + case "gpu" -> builder.addFlavor(flavorName, 4, 16, 125, 10, true, false, BARE_METAL, x86_64, 1, 16); + default -> builder.addFlavor(flavorName, 1., 30., 20., 3, BARE_METAL); + } } - return new NodeFlavors(flavorConfigBuilder.build()); + return new NodeFlavors(builder.build()); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/HostProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/HostProvisioner.java index 397eb4d7af9..dd838375a59 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/HostProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/HostProvisioner.java @@ -7,7 +7,6 @@ import com.yahoo.config.provision.NodeAllocationException; import com.yahoo.vespa.hosted.provision.Node; import java.util.List; -import java.util.Set; import java.util.function.Consumer; /** @@ -46,12 +45,11 @@ public interface HostProvisioner { * Continue provisioning of given list of Nodes. * * @param host the host to provision - * @param children list of all the nodes that run on the given host * @return IP config for the provisioned host and its children * @throws FatalProvisioningException if the provisioning has irrecoverably failed and the input nodes * should be deleted from node-repo. */ - HostIpConfig provision(Node host, Set<Node> children) throws FatalProvisioningException; + HostIpConfig provision(Node host) throws FatalProvisioningException; /** * Deprovisions a given host and resources associated with it and its children (such as DNS entries). diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockHostProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockHostProvisioner.java index 3d5987cd04d..bc10a97068e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockHostProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockHostProvisioner.java @@ -97,15 +97,13 @@ public class MockHostProvisioner implements HostProvisioner { } @Override - public HostIpConfig provision(Node host, Set<Node> children) throws FatalProvisioningException { + public HostIpConfig provision(Node host) throws FatalProvisioningException { if (behaviour(Behaviour.failProvisioning)) throw new FatalProvisioningException("Failed to provision node(s)"); if (host.state() != Node.State.provisioned) throw new IllegalStateException("Host to provision must be in " + Node.State.provisioned); Map<String, IP.Config> result = new HashMap<>(); result.put(host.hostname(), createIpConfig(host)); - for (var child : children) { - if (child.state() != Node.State.reserved) throw new IllegalStateException("Child to provisioned must be in " + Node.State.reserved); - result.put(child.hostname(), createIpConfig(child)); - } + host.ipConfig().pool().hostnames().forEach(hostname -> + result.put(hostname.value(), IP.Config.ofEmptyPool(nameResolver.resolveAll(hostname.value())))); return new HostIpConfig(result); } @@ -199,8 +197,6 @@ public class MockHostProvisioner implements HostProvisioner { return this; } - public Optional<Flavor> getHostFlavor(ClusterSpec.Type type) { return Optional.ofNullable(hostFlavors.get(type)); } - public MockHostProvisioner addEvent(HostEvent event) { hostEvents.add(event); return this; @@ -230,18 +226,17 @@ public class MockHostProvisioner implements HostProvisioner { } public IP.Config createIpConfig(Node node) { - if (!node.type().isHost()) { - return node.ipConfig().withPrimary(nameResolver.resolveAll(node.hostname())); - } + if (!node.type().isHost()) throw new IllegalArgumentException("Node " + node + " is not a host"); int hostIndex = Integer.parseInt(node.hostname().replaceAll("^[a-z]+|-\\d+$", "")); Set<String> addresses = Set.of("::" + hostIndex + ":0"); Set<String> ipAddressPool = new HashSet<>(); if (!behaviour(Behaviour.failDnsUpdate)) { nameResolver.addRecord(node.hostname(), addresses.iterator().next()); - for (int i = 1; i <= 2; i++) { - String ip = "::" + hostIndex + ":" + i; + int i = 1; + for (HostName hostName : node.ipConfig().pool().hostnames()) { + String ip = "::" + hostIndex + ":" + i++; ipAddressPool.add(ip); - nameResolver.addRecord(node.hostname() + "-" + i, ip); + nameResolver.addRecord(hostName.value(), ip); } } IP.Pool pool = node.ipConfig().pool().withIpAddresses(ipAddressPool); @@ -250,7 +245,7 @@ public class MockHostProvisioner implements HostProvisioner { public enum Behaviour { - /** Fail call to {@link MockHostProvisioner#provision(com.yahoo.vespa.hosted.provision.Node, java.util.Set)} */ + /** Fail call to {@link MockHostProvisioner#provision(com.yahoo.vespa.hosted.provision.Node)} */ failProvisioning, /** Fail call to {@link MockHostProvisioner#provisionHosts(HostProvisionRequest, Consumer)} */ diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNameResolver.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNameResolver.java index 94cb05d20cc..722dc5ef96c 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNameResolver.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNameResolver.java @@ -6,7 +6,6 @@ import com.yahoo.vespa.hosted.provision.persistence.NameResolver; import java.net.UnknownHostException; import java.util.Arrays; -import java.util.Collections; import java.util.EnumSet; import java.util.HashMap; import java.util.HashSet; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java index b7d6e0a9dd9..714374ccb8a 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java @@ -176,7 +176,7 @@ public class MockNodeRepository extends NodeRepository { .build()); // Ready all nodes, except 7 and 55 - nodes = nodes().addNodes(nodes, Agent.system); + nodes = new ArrayList<>(nodes().addNodes(nodes, Agent.system)); nodes.remove(node7); nodes.remove(node55); nodes = nodes().deallocate(nodes, Agent.system, getClass().getSimpleName()); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockProvisioner.java index 5a9da1e1c3f..d8ad892e210 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockProvisioner.java @@ -18,6 +18,7 @@ import java.util.List; /** * @author freva */ +@SuppressWarnings("unused") // Injected in container from test code (services.xml) public class MockProvisioner implements Provisioner { @Override diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java index 29ebf1789c0..9c843b3eb01 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/RealDataScenarioTest.java @@ -141,7 +141,7 @@ public class RealDataScenarioTest { if (nodeNext.get()) { String json = input.substring(input.indexOf("{\""), input.lastIndexOf('}') + 1); Node node = nodeSerializer.fromJson(json.getBytes(UTF_8)); - nodeRepository.database().addNodesInState(List.of(node), node.state(), Agent.system); + nodeRepository.database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system); nodeNext.set(false); } else { if (!zkNodePathPattern.matcher(input).matches()) return; diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java index f8ec271ce5f..523feeeb303 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/CapacityCheckerTester.java @@ -23,6 +23,7 @@ import com.yahoo.test.ManualClock; import com.yahoo.vespa.curator.Curator; import com.yahoo.vespa.curator.mock.MockCurator; import com.yahoo.vespa.flags.InMemoryFlagSource; +import com.yahoo.vespa.hosted.provision.LockedNodeList; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.autoscale.MemoryMetricsDb; @@ -201,7 +202,7 @@ public class CapacityCheckerTester { nodeRepository.nodes().addNodes(hostsWithChildren.getOrDefault(tenantHostApp, List.of()), Agent.system); hostsWithChildren.forEach((applicationId, nodes) -> { if (applicationId.equals(tenantHostApp)) return; - nodeRepository.database().addNodesInState(nodes, Node.State.active, Agent.system); + nodeRepository.database().addNodesInState(new LockedNodeList(nodes, () -> { }), Node.State.active, Agent.system); }); nodeRepository.nodes().addNodes(createEmptyHosts(numHosts, numEmptyHosts, emptyHostExcessCapacity, emptyHostExcessIps), Agent.system); @@ -322,9 +323,9 @@ public class CapacityCheckerTester { } } - nodeRepository.database().addNodesInState(hosts, Node.State.active, Agent.system); + nodeRepository.database().addNodesInState(new LockedNodeList(hosts, () -> { }), Node.State.active, Agent.system); nodes.forEach((application, applicationNodes) -> { - nodeRepository.database().addNodesInState(applicationNodes, Node.State.active, Agent.system); + nodeRepository.database().addNodesInState(new LockedNodeList(applicationNodes, () -> { }), Node.State.active, Agent.system); }); updateCapacityChecker(); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java index ddd7413567a..262616d5eac 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java @@ -6,6 +6,7 @@ import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.Flavor; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; +import com.yahoo.vespa.hosted.provision.LockedNodeList; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.node.Allocation; @@ -45,7 +46,7 @@ public class DirtyExpirerTest { false)) .build(); - tester.nodeRepository().database().addNodesInState(List.of(node), node.state(), Agent.system); + tester.nodeRepository().database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system); Duration expiryTimeout = Duration.ofMinutes(30); DirtyExpirer expirer = new DirtyExpirer(tester.nodeRepository(), expiryTimeout, new TestMetric()); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java index 66d4b67c7c2..c16ed47a216 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/HostCapacityMaintainerTest.java @@ -27,6 +27,7 @@ import com.yahoo.test.ManualClock; import com.yahoo.vespa.flags.InMemoryFlagSource; import com.yahoo.vespa.flags.PermanentFlags; import com.yahoo.vespa.flags.custom.ClusterCapacity; +import com.yahoo.vespa.hosted.provision.LockedNodeList; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.Node.State; import com.yahoo.vespa.hosted.provision.NodeList; @@ -64,6 +65,9 @@ import java.util.Set; import java.util.function.Supplier; import java.util.stream.Stream; +import static com.yahoo.config.provision.NodeResources.Architecture.arm64; +import static com.yahoo.config.provision.NodeResources.DiskSpeed.fast; +import static com.yahoo.config.provision.NodeResources.StorageType.remote; import static com.yahoo.vespa.hosted.provision.testutils.MockHostProvisioner.Behaviour; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -77,11 +81,13 @@ import static org.junit.Assert.fail; */ public class HostCapacityMaintainerTest { + private DynamicProvisioningTester tester; + @Test public void finds_nodes_that_need_deprovisioning_without_pre_provisioning() { - var tester = new DynamicProvisioningTester().addInitialNodes(); - assertTrue(tester.nodeRepository.nodes().node("host2").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host3").isPresent()); + tester = new DynamicProvisioningTester().addInitialNodes(); + assertNodeExists("host2"); + assertNodeExists("host3"); tester.maintain(); assertSame(State.deprovisioned, tester.nodeRepository.nodes().node("host2").get().state()); @@ -89,31 +95,30 @@ public class HostCapacityMaintainerTest { @Test public void does_not_deprovision_when_preprovisioning_enabled() { - var tester = new DynamicProvisioningTester().addInitialNodes(); - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), List.of(new ClusterCapacity(1, 1.0, 3.0, 2.0, 1.0, "fast", "local", "x86_64")), ClusterCapacity.class); - Optional<Node> failedHost = tester.nodeRepository.nodes().node("host2"); + tester = new DynamicProvisioningTester().addInitialNodes(); + setPreprovisionCapacityFlag(tester, new ClusterCapacity(1, 1.0, 3.0, 2.0, 1.0, "fast", "remote", "x86_64")); + Optional<Node> failedHost = node("host2"); assertTrue(failedHost.isPresent()); tester.maintain(); - assertSame("Failed host is deprovisioned", State.deprovisioned, tester.nodeRepository.nodes().node(failedHost.get().hostname()).get().state()); + assertSame("Failed host is deprovisioned", State.deprovisioned, node(failedHost.get().hostname()).get().state()); assertEquals(1, tester.hostProvisioner.deprovisionedHosts()); } @Test public void provision_deficit_and_deprovision_excess() { - var tester = new DynamicProvisioningTester().addInitialNodes(); - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), - List.of(new ClusterCapacity(2, 48.0, 128.0, 1000.0, 10.0, "fast", "local", "x86_64"), - new ClusterCapacity(1, 16.0, 24.0, 100.0, 1.0, "fast", "local", "x86_64")), - ClusterCapacity.class); + tester = new DynamicProvisioningTester().addInitialNodes(); + setPreprovisionCapacityFlag(tester, + new ClusterCapacity(2, 48.0, 128.0, 1000.0, 10.0, "fast", "remote", "x86_64"), + new ClusterCapacity(1, 16.0, 24.0, 100.0, 1.0, "fast", "remote", "x86_64")); assertEquals(0, tester.hostProvisioner.provisionedHosts().size()); assertEquals(9, tester.nodeRepository.nodes().list().size()); - assertTrue(tester.nodeRepository.nodes().node("host2").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host2-1").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host3").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host100").isEmpty()); - assertTrue(tester.nodeRepository.nodes().node("host101").isEmpty()); + assertNodeExists("host2"); + assertNodeExists("host2-1"); + assertNodeExists("host3"); + assertNodeDoesNotExist("host100"); + assertNodeDoesNotExist("host101"); tester.maintain(); @@ -121,37 +126,35 @@ public class HostCapacityMaintainerTest { assertEquals(2, tester.provisionedHostsMatching(new NodeResources(48, 128, 1000, 10))); NodeList nodesAfter = tester.nodeRepository.nodes().list().not().state(State.deprovisioned); assertEquals(9, nodesAfter.size()); // 2 removed, 2 added - assertSame("Failed host 'host2' is deprovisioned", State.deprovisioned, tester.nodeRepository.nodes().node("host2").get().state()); - assertTrue("Node on deprovisioned host removed", tester.nodeRepository.nodes().node("host2-1").isEmpty()); - assertTrue("Host satisfying 16-24-100-1 is kept", tester.nodeRepository.nodes().node("host3").isPresent()); - assertTrue("New 48-128-1000-10 host added", tester.nodeRepository.nodes().node("host100").isPresent()); - assertTrue("New 48-128-1000-10 host added", tester.nodeRepository.nodes().node("host101").isPresent()); + assertSame("Failed host 'host2' is deprovisioned", State.deprovisioned, node("host2").get().state()); + assertNodeDoesNotExist("Node on deprovisioned host removed", "host2-1"); + assertNodeExists("Host satisfying 16-24-100-1 is kept", "host3"); + assertNodeExists("New 48-128-1000-10 host added", "host100"); + assertNodeExists("New 48-128-1000-10 host added", "host100"); - Instant deprovisionedAt = tester.nodeRepository.nodes().node("host2").get().history().event(History.Event.Type.deprovisioned).get().at(); + Instant deprovisionedAt = node("host2").get().history().event(History.Event.Type.deprovisioned).get().at(); tester.provisioningTester.clock().advance(Duration.ofSeconds(1)); tester.maintain(); assertEquals("Host moves to deprovisioned once", deprovisionedAt, - tester.nodeRepository.nodes().node("host2").get().history() + node("host2").get().history() .event(History.Event.Type.deprovisioned).get().at()); } @Test public void preprovision_with_shared_host() { - var tester = new DynamicProvisioningTester().addInitialNodes(); + tester = new DynamicProvisioningTester().addInitialNodes(); // Makes provisioned hosts 48-128-1000-10 tester.hostProvisioner.setHostFlavor("host4"); - var clusterCapacity = new ClusterCapacity(2, 1.0, 30.0, 20.0, 3.0, "fast", "local", "x86_64"); - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), - List.of(clusterCapacity), - ClusterCapacity.class); + var clusterCapacity = new ClusterCapacity(2, 1.0, 30.0, 20.0, 3.0, "fast", "remote", "x86_64"); + setPreprovisionCapacityFlag(tester, clusterCapacity); assertEquals(0, tester.hostProvisioner.provisionedHosts().size()); assertEquals(9, tester.nodeRepository.nodes().list().size()); - assertTrue(tester.nodeRepository.nodes().node("host2").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host2-1").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host3").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host100").isEmpty()); + assertTrue(node("host2").isPresent()); + assertTrue(node("host2-1").isPresent()); + assertTrue(node("host3").isPresent()); + assertTrue(node("host100").isEmpty()); // The first cluster will be allocated to host3 and a new host host100. // host100 will be a large shared host specified above. @@ -163,10 +166,7 @@ public class HostCapacityMaintainerTest { verifyFirstMaintain(tester); // Add a second cluster equal to the first. It should fit on existing host3 and host100. - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), - List.of(clusterCapacity, - clusterCapacity), - ClusterCapacity.class); + setPreprovisionCapacityFlag(tester, clusterCapacity, clusterCapacity); tester.maintain(); verifyFirstMaintain(tester); @@ -177,25 +177,23 @@ public class HostCapacityMaintainerTest { // in this test, due to skew), so host3 will be deprovisioned when host101 is provisioned. // host3 is a 24-64-100-10 while host100 is 48-128-1000-10. - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), - List.of(clusterCapacity, - new ClusterCapacity(2, 24.0, 64.0, 100.0, 1.0, "fast", "local", "x86_64")), - ClusterCapacity.class); + setPreprovisionCapacityFlag(tester, + clusterCapacity, + new ClusterCapacity(2, 24.0, 64.0, 100.0, 1.0, "fast", "remote", "x86_64")); tester.maintain(); assertEquals(2, tester.hostProvisioner.provisionedHosts().size()); assertEquals(2, tester.provisionedHostsMatching(new NodeResources(48, 128, 1000, 10))); assertEquals(8, tester.nodeRepository.nodes().list().not().state(State.deprovisioned).size()); // 3 removed, 2 added - assertSame("preprovision capacity is prefered on shared hosts", State.deprovisioned, tester.nodeRepository.nodes().node("host3").get().state()); - assertTrue(tester.nodeRepository.nodes().node("host100").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host101").isPresent()); + assertSame("preprovision capacity is prefered on shared hosts", State.deprovisioned, node("host3").get().state()); + assertTrue(node("host100").isPresent()); + assertTrue(node("host101").isPresent()); // If the preprovision capacity is reduced, we should see shared hosts deprovisioned. - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), - List.of(new ClusterCapacity(1, 1.0, 30.0, 20.0, 3.0, "fast", "local", "x86_64")), - ClusterCapacity.class); + setPreprovisionCapacityFlag(tester, + new ClusterCapacity(1, 1.0, 30.0, 20.0, 3.0, "fast", "remote", "x86_64")); tester.maintain(); @@ -203,45 +201,54 @@ public class HostCapacityMaintainerTest { 1, tester.hostProvisioner.provisionedHosts().size()); assertEquals(1, tester.provisionedHostsMatching(new NodeResources(48, 128, 1000, 10))); assertEquals(7, tester.nodeRepository.nodes().list().not().state(State.deprovisioned).size()); // 4 removed, 2 added - if (tester.nodeRepository.nodes().node("host100").isPresent()) { + if (node("host100").isPresent()) { assertSame("host101 is superfluous and should have been deprovisioned", State.deprovisioned, - tester.nodeRepository.nodes().node("host101").get().state()); + node("host101").get().state()); } else { assertTrue("host101 is required for preprovision capacity", - tester.nodeRepository.nodes().node("host101").isPresent()); + node("host101").isPresent()); } + // If a host with another architecture is added to preprovision capacity, a shared host should be added. + setPreprovisionCapacityFlag(tester, + new ClusterCapacity(1, 2.0, 30.0, 20.0, 3.0, "fast", "remote", "x86_64"), + new ClusterCapacity(1, 2.0, 30.0, 20.0, 3.0, "fast", "remote", "arm64")); + tester.hostProvisioner.setHostFlavor("arm64"); + tester.maintain(); + + assertEquals(2, tester.hostProvisioner.provisionedHosts().size()); + assertEquals(1, tester.provisionedHostsMatching(new NodeResources(48, 128, 1000, 10))); + assertEquals(1, tester.provisionedHostsMatching(new NodeResources(2, 30, 20, 3, fast, remote, arm64))); } private void verifyFirstMaintain(DynamicProvisioningTester tester) { - assertEquals(1, tester.hostProvisioner.provisionedHosts().size()); + assertEquals(tester.hostProvisioner.provisionedHosts().toString(), 1, tester.hostProvisioner.provisionedHosts().size()); assertEquals(1, tester.provisionedHostsMatching(new NodeResources(48, 128, 1000, 10))); assertEquals(8, tester.nodeRepository.nodes().list().not().state(State.deprovisioned).size()); // 2 removed, 1 added - assertSame("Failed host 'host2' is deprovisioned", State.deprovisioned, tester.nodeRepository.nodes().node("host2").get().state()); - assertTrue("Node on deprovisioned host removed", tester.nodeRepository.nodes().node("host2-1").isEmpty()); - assertTrue("One 1-30-20-3 node fits on host3", tester.nodeRepository.nodes().node("host3").isPresent()); - assertTrue("New 48-128-1000-10 host added", tester.nodeRepository.nodes().node("host100").isPresent()); + assertSame("Failed host 'host2' is deprovisioned", State.deprovisioned, node("host2").get().state()); + assertTrue("Node on deprovisioned host removed", node("host2-1").isEmpty()); + assertTrue("One 1-30-20-3 node fits on host3", node("host3").isPresent()); + assertTrue("New 48-128-1000-10 host added", node("host100").isPresent()); } @Test public void does_not_remove_if_host_provisioner_failed() { - var tester = new DynamicProvisioningTester(); + tester = new DynamicProvisioningTester(); Node host2 = tester.addNode("host2", Optional.empty(), NodeType.host, Node.State.failed, DynamicProvisioningTester.tenantApp); tester.hostProvisioner.with(Behaviour.failDeprovisioning); tester.maintain(); - assertTrue(tester.nodeRepository.nodes().node(host2.hostname()).isPresent()); + assertTrue(node(host2.hostname()).isPresent()); } @Test public void test_minimum_capacity() { - var tester = new DynamicProvisioningTester(); + tester = new DynamicProvisioningTester(); NodeResources resources1 = new NodeResources(24, 64, 100, 10); - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), - List.of(new ClusterCapacity(2, resources1.vcpu(), resources1.memoryGb(), resources1.diskGb(), - resources1.bandwidthGbps(), resources1.diskSpeed().name(), - resources1.storageType().name(), resources1.architecture().name())), - ClusterCapacity.class); + setPreprovisionCapacityFlag(tester, + new ClusterCapacity(2, resources1.vcpu(), resources1.memoryGb(), resources1.diskGb(), + resources1.bandwidthGbps(), resources1.diskSpeed().name(), + resources1.storageType().name(), resources1.architecture().name())); tester.maintain(); // Hosts are provisioned @@ -252,16 +259,14 @@ public class HostCapacityMaintainerTest { tester.assertNodesUnchanged(); // Pretend shared-host flag has been set to host4's flavor - var sharedHostNodeResources = new NodeResources(48, 128, 1000, 10, NodeResources.DiskSpeed.fast, NodeResources.StorageType.remote); + var sharedHostNodeResources = new NodeResources(48, 128, 1000, 10, fast, remote); tester.hostProvisioner.setHostFlavor("host4"); // Next maintenance run does nothing tester.assertNodesUnchanged(); // Must be able to allocate 2 nodes with "no resource requirement" - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), - List.of(new ClusterCapacity(2, 0.0, 0.0, 0.0, 0.0, null, null, null)), - ClusterCapacity.class); + setPreprovisionCapacityFlag(tester, new ClusterCapacity(2, 0.0, 0.0, 0.0, 0.0, null, null, null)); // Next maintenance run does nothing tester.assertNodesUnchanged(); @@ -280,62 +285,58 @@ public class HostCapacityMaintainerTest { tester.assertNodesUnchanged(); // Clearing flag does nothing - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), List.of(), ClusterCapacity.class); + setPreprovisionCapacityFlag(tester); tester.assertNodesUnchanged(); // Increasing the capacity provisions additional hosts - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), - List.of(new ClusterCapacity(3, 0.0, 0.0, 0.0, 0.0, null, null, null)), - ClusterCapacity.class); + setPreprovisionCapacityFlag(tester, new ClusterCapacity(3, 0.0, 0.0, 0.0, 0.0, null, null, null)); assertEquals(0, tester.provisionedHostsMatching(sharedHostNodeResources)); - assertTrue(tester.nodeRepository.nodes().node("host102").isEmpty()); + assertTrue(node("host102").isEmpty()); tester.maintain(); assertEquals(1, tester.provisionedHostsMatching(sharedHostNodeResources)); - assertTrue(tester.nodeRepository.nodes().node("host102").isPresent()); + assertTrue(node("host102").isPresent()); // Next maintenance run does nothing tester.assertNodesUnchanged(); // Requiring >0 capacity does nothing as long as it fits on the 3 hosts - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), - List.of(new ClusterCapacity(3, - resources1.vcpu() - applicationNodeResources.vcpu(), - resources1.memoryGb() - applicationNodeResources.memoryGb(), - resources1.diskGb() - applicationNodeResources.diskGb(), - resources1.bandwidthGbps() - applicationNodeResources.bandwidthGbps(), - resources1.diskSpeed().name(), - resources1.storageType().name(), - resources1.architecture().name())), - ClusterCapacity.class); + setPreprovisionCapacityFlag(tester, + new ClusterCapacity(3, + resources1.vcpu() - applicationNodeResources.vcpu(), + resources1.memoryGb() - applicationNodeResources.memoryGb(), + resources1.diskGb() - applicationNodeResources.diskGb(), + resources1.bandwidthGbps() - applicationNodeResources.bandwidthGbps(), + resources1.diskSpeed().name(), + resources1.storageType().name(), + resources1.architecture().name())); tester.assertNodesUnchanged(); // But requiring a bit more in the cluster => provisioning of 2 shared hosts. - tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), - List.of(new ClusterCapacity(3, - resources1.vcpu() - applicationNodeResources.vcpu() + 1, - resources1.memoryGb() - applicationNodeResources.memoryGb() + 1, - resources1.diskGb() - applicationNodeResources.diskGb() + 1, - resources1.bandwidthGbps(), - resources1.diskSpeed().name(), - resources1.storageType().name(), - resources1.architecture().name())), - ClusterCapacity.class); + setPreprovisionCapacityFlag(tester, + new ClusterCapacity(3, + resources1.vcpu() - applicationNodeResources.vcpu() + 1, + resources1.memoryGb() - applicationNodeResources.memoryGb() + 1, + resources1.diskGb() - applicationNodeResources.diskGb() + 1, + resources1.bandwidthGbps(), + resources1.diskSpeed().name(), + resources1.storageType().name(), + resources1.architecture().name())); assertEquals(1, tester.provisionedHostsMatching(sharedHostNodeResources)); - assertTrue(tester.nodeRepository.nodes().node("host102").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host103").isEmpty()); - assertTrue(tester.nodeRepository.nodes().node("host104").isEmpty()); + assertTrue(node("host102").isPresent()); + assertTrue(node("host103").isEmpty()); + assertTrue(node("host104").isEmpty()); tester.maintain(); assertEquals(3, tester.provisionedHostsMatching(sharedHostNodeResources)); - assertTrue(tester.nodeRepository.nodes().node("host102").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host103").isPresent()); - assertTrue(tester.nodeRepository.nodes().node("host104").isPresent()); + assertTrue(node("host102").isPresent()); + assertTrue(node("host103").isPresent()); + assertTrue(node("host104").isPresent()); } @Test public void deprovision_empty_confighost() { // cfghost1, cfg1, cfghost2, cfg2, cfghost3, and NOT cfg3. - var tester = new DynamicProvisioningTester(); + tester = new DynamicProvisioningTester(); tester.addCfghost(1, true); tester.addCfghost(2, true); Node cfghost3 = tester.addCfghost(3, false); @@ -475,8 +476,8 @@ public class HostCapacityMaintainerTest { @Test public void custom_cloud_account() { - DynamicProvisioningTester tester = new DynamicProvisioningTester(Cloud.builder().name(CloudName.AWS).dynamicProvisioning(true).allowEnclave(true).account(CloudAccount.from("001122334455")).build(), - new MockNameResolver().mockAnyLookup()); + tester = new DynamicProvisioningTester(Cloud.builder().name(CloudName.AWS).dynamicProvisioning(true).allowEnclave(true).account(CloudAccount.from("001122334455")).build(), + new MockNameResolver().mockAnyLookup()); ProvisioningTester provisioningTester = tester.provisioningTester; ApplicationId applicationId = ApplicationId.from("t1", "a1", "i1"); @@ -513,7 +514,7 @@ public class HostCapacityMaintainerTest { @Test public void deprovision_node_when_no_allocation_and_past_ttl() { - var tester = new DynamicProvisioningTester(); + tester = new DynamicProvisioningTester(); ManualClock clock = (ManualClock) tester.nodeRepository.clock(); tester.hostProvisioner.with(Behaviour.failProvisioning); tester.provisioningTester.makeReadyHosts(2, new NodeResources(1, 1, 1, 1)).activateTenantHosts(); @@ -526,61 +527,61 @@ public class HostCapacityMaintainerTest { // Host is not marked for deprovisioning by maintainer, because child is present tester.maintain(); - assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision()); - assertEquals(Optional.empty(), tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt()); + assertFalse(node(host1.hostname()).get().status().wantToDeprovision()); + assertEquals(Optional.empty(), node(host1.hostname()).get().hostEmptyAt()); // Child is set to deprovision, but turns active tester.nodeRepository.nodes().park(host11.hostname(), true, Agent.system, "not good"); tester.nodeRepository.nodes().reactivate(host11.hostname(), Agent.operator, "all good"); - assertTrue(tester.nodeRepository.nodes().node(host11.hostname()).get().status().wantToDeprovision()); - assertEquals(State.active, tester.nodeRepository.nodes().node(host11.hostname()).get().state()); + assertTrue(node(host11.hostname()).get().status().wantToDeprovision()); + assertEquals(State.active, node(host11.hostname()).get().state()); tester.maintain(); - assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision()); - assertEquals(Optional.empty(), tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt()); + assertFalse(node(host1.hostname()).get().status().wantToDeprovision()); + assertEquals(Optional.empty(), node(host1.hostname()).get().hostEmptyAt()); // Child is parked, to make the host effectively empty tester.nodeRepository.nodes().park(host11.hostname(), true, Agent.system, "not good"); tester.maintain(); - assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision()); + assertFalse(node(host1.hostname()).get().status().wantToDeprovision()); assertEquals(Optional.of(clock.instant().truncatedTo(ChronoUnit.MILLIS)), - tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt()); + node(host1.hostname()).get().hostEmptyAt()); // Some time passes, but not enough for host1 to be deprovisioned clock.advance(Duration.ofDays(1).minusSeconds(1)); tester.maintain(); - assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision()); + assertFalse(node(host1.hostname()).get().status().wantToDeprovision()); assertEquals(Optional.of(clock.instant().minus(Duration.ofDays(1).minusSeconds(1)).truncatedTo(ChronoUnit.MILLIS)), - tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt()); - assertTrue(tester.nodeRepository.nodes().node(host2.hostname()).get().status().wantToDeprovision()); - assertTrue(tester.nodeRepository.nodes().node(host2.hostname()).get().status().wantToRetire()); - assertEquals(State.active, tester.nodeRepository.nodes().node(host2.hostname()).get().state()); + node(host1.hostname()).get().hostEmptyAt()); + assertTrue(node(host2.hostname()).get().status().wantToDeprovision()); + assertTrue(node(host2.hostname()).get().status().wantToRetire()); + assertEquals(State.active, node(host2.hostname()).get().state()); assertEquals(Optional.of(clock.instant().minus(Duration.ofDays(1).minusSeconds(1)).truncatedTo(ChronoUnit.MILLIS)), - tester.nodeRepository.nodes().node(host2.hostname()).get().hostEmptyAt()); + node(host2.hostname()).get().hostEmptyAt()); // Some more time passes, but child is reactivated on host1, rendering the host non-empty again clock.advance(Duration.ofDays(1)); tester.nodeRepository.nodes().reactivate(host11.hostname(), Agent.operator, "all good"); tester.maintain(); - assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision()); - assertEquals(Optional.empty(), tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt()); + assertFalse(node(host1.hostname()).get().status().wantToDeprovision()); + assertEquals(Optional.empty(), node(host1.hostname()).get().hostEmptyAt()); // Child is removed, and host is marked as empty tester.nodeRepository.database().writeTo(State.deprovisioned, host11, Agent.operator, Optional.empty()); - tester.nodeRepository.nodes().forget(tester.nodeRepository.nodes().node(host11.hostname()).get()); - assertEquals(Optional.empty(), tester.nodeRepository.nodes().node(host11.hostname())); + tester.nodeRepository.nodes().forget(node(host11.hostname()).get()); + assertEquals(Optional.empty(), node(host11.hostname())); tester.maintain(); - assertFalse(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision()); + assertFalse(node(host1.hostname()).get().status().wantToDeprovision()); assertEquals(Optional.of(clock.instant().truncatedTo(ChronoUnit.MILLIS)), - tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt()); + node(host1.hostname()).get().hostEmptyAt()); // Enough time passes for the host to be deprovisioned clock.advance(Duration.ofDays(1)); tester.maintain(); - assertTrue(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToDeprovision()); - assertTrue(tester.nodeRepository.nodes().node(host1.hostname()).get().status().wantToRetire()); - assertEquals(State.active, tester.nodeRepository.nodes().node(host1.hostname()).get().state()); + assertTrue(node(host1.hostname()).get().status().wantToDeprovision()); + assertTrue(node(host1.hostname()).get().status().wantToRetire()); + assertEquals(State.active, node(host1.hostname()).get().state()); assertEquals(Optional.of(clock.instant().minus(Duration.ofDays(1)).truncatedTo(ChronoUnit.MILLIS)), - tester.nodeRepository.nodes().node(host1.hostname()).get().hostEmptyAt()); + node(host1.hostname()).get().hostEmptyAt()); // Let tenant host app redeploy, retiring the obsolete host. tester.provisioningTester.activateTenantHosts(); @@ -599,7 +600,7 @@ public class HostCapacityMaintainerTest { @Test public void deprovision_parked_node_with_allocation() { - var tester = new DynamicProvisioningTester(); + tester = new DynamicProvisioningTester(); tester.hostProvisioner.with(Behaviour.failProvisioning); Node host4 = tester.addNode("host4", Optional.empty(), NodeType.host, Node.State.parked, null, Duration.ofDays(1)); Node host41 = tester.addNode("host4-1", Optional.of("host4"), NodeType.tenant, Node.State.parked, DynamicProvisioningTester.tenantApp); @@ -609,16 +610,16 @@ public class HostCapacityMaintainerTest { // Host and children are marked for deprovisioning, bypassing host TTL. tester.nodeRepository.nodes().deprovision("host4", Agent.operator, Instant.now()); for (var node : List.of(host4, host41, host42, host43)) { - assertTrue(tester.nodeRepository.nodes().node(node.hostname()).map(n -> n.status().wantToDeprovision()).get()); + assertTrue(node(node.hostname()).map(n -> n.status().wantToDeprovision()).get()); } // Host and children remain parked because one child is still active tester.maintain(); for (var node : List.of(host4, host41)) { - assertEquals(Node.State.parked, tester.nodeRepository.nodes().node(node.hostname()).get().state()); + assertEquals(Node.State.parked, node(node.hostname()).get().state()); } - assertEquals(Node.State.active, tester.nodeRepository.nodes().node(host42.hostname()).get().state()); - assertEquals(Node.State.failed, tester.nodeRepository.nodes().node(host43.hostname()).get().state()); + assertEquals(Node.State.active, node(host42.hostname()).get().state()); + assertEquals(Node.State.failed, node(host43.hostname()).get().state()); // Last child is parked tester.nodeRepository.nodes().park(host42.hostname(), false, Agent.system, getClass().getSimpleName()); @@ -627,9 +628,9 @@ public class HostCapacityMaintainerTest { tester.maintain(); for (var node : List.of(host4, host41, host42, host43)) { if (node.type().isHost()) { - assertSame(node.hostname() + " moved to deprovisioned", State.deprovisioned, tester.nodeRepository.nodes().node(node.hostname()).get().state()); + assertSame(node.hostname() + " moved to deprovisioned", State.deprovisioned, node(node.hostname()).get().state()); } else { - assertTrue(node.hostname() + " removed", tester.nodeRepository.nodes().node(node.hostname()).isEmpty()); + assertTrue(node.hostname() + " removed", node(node.hostname()).isEmpty()); } } } @@ -656,7 +657,7 @@ public class HostCapacityMaintainerTest { private void assertCfghost3IsActive(DynamicProvisioningTester tester) { assertEquals(5, tester.nodeRepository.nodes().list(Node.State.active).size()); assertEquals(3, tester.nodeRepository.nodes().list(Node.State.active).nodeType(NodeType.confighost).size()); - Optional<Node> cfghost3 = tester.nodeRepository.nodes().node("cfghost3"); + Optional<Node> cfghost3 = node("cfghost3"); assertTrue(cfghost3.isPresent()); assertEquals(Node.State.active, cfghost3.get().state()); } @@ -664,7 +665,35 @@ public class HostCapacityMaintainerTest { private void assertCfghost3IsDeprovisioned(DynamicProvisioningTester tester) { assertEquals(4, tester.nodeRepository.nodes().list(Node.State.active).size()); assertEquals(2, tester.nodeRepository.nodes().list(Node.State.active).nodeType(NodeType.confighost).size()); - assertSame(State.deprovisioned, tester.nodeRepository.nodes().node("cfghost3").get().state()); + assertSame(State.deprovisioned, node("cfghost3").get().state()); + } + + private static void setPreprovisionCapacityFlag(DynamicProvisioningTester tester, ClusterCapacity... clusterCapacities) { + tester.flagSource.withListFlag(PermanentFlags.PREPROVISION_CAPACITY.id(), List.of(clusterCapacities), ClusterCapacity.class); + } + + private void assertNodeExists(String nodeName) { + assertTrue(nodeExists(nodeName)); + } + + private void assertNodeExists(String message, String nodeName) { + assertTrue(message, nodeExists(nodeName)); + } + + private void assertNodeDoesNotExist(String nodeName) { + assertFalse(nodeExists(nodeName)); + } + + private void assertNodeDoesNotExist(String message, String nodeName) { + assertFalse(message, nodeExists(nodeName)); + } + + private boolean nodeExists(String nodeName) { + return node(nodeName).isPresent(); + } + + private Optional<Node> node(String nodeName) { + return tester.nodeRepository.nodes().node(nodeName); } private static class DynamicProvisioningTester { @@ -673,7 +702,7 @@ public class HostCapacityMaintainerTest { private static final InfraApplication configServerHostApp = new ConfigServerHostApplication(); private static final InfraApplication configServerApp = new ConfigServerApplication(); private static final ApplicationId tenantApp = ApplicationId.from("mytenant", "myapp", "default"); - private static final NodeFlavors flavors = FlavorConfigBuilder.createDummies("default", "docker", "host2", "host3", "host4"); + private static final NodeFlavors flavors = FlavorConfigBuilder.createDummies("default", "docker", "host2", "host3", "host4", "arm64"); private final InMemoryFlagSource flagSource = new InMemoryFlagSource(); @@ -722,7 +751,7 @@ public class HostCapacityMaintainerTest { createNode("host4", Optional.empty(), NodeType.host, Node.State.provisioned, null), createNode("host4-1", Optional.of("host4"), NodeType.tenant, Node.State.reserved, tenantApp), createNode("host4-2", Optional.of("host4"), NodeType.tenant, Node.State.reserved, tenantApp)) - .forEach(node -> nodeRepository.database().addNodesInState(List.of(node), node.state(), Agent.system)); + .forEach(node -> nodeRepository.database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system)); return this; } @@ -744,7 +773,7 @@ public class HostCapacityMaintainerTest { private Node addNode(String hostname, Optional<String> parentHostname, NodeType nodeType, Node.State state, ApplicationId application, Duration hostTTL) { Node node = createNode(hostname, parentHostname, nodeType, state, application, hostTTL); - return nodeRepository.database().addNodesInState(List.of(node), node.state(), Agent.system).get(0); + return nodeRepository.database().addNodesInState(new LockedNodeList(List.of(node), () -> { }), node.state(), Agent.system).get(0); } private Node createNode(String hostname, Optional<String> parentHostname, NodeType nodeType, diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java index de2c060a0eb..83aea78ce58 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.provision.maintenance; import com.yahoo.component.Version; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ApplicationTransaction; import com.yahoo.config.provision.Capacity; import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.ClusterResources; @@ -10,11 +11,13 @@ import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.NodeFlavors; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; +import com.yahoo.config.provision.ProvisionLock; import com.yahoo.jdisc.Metric; import com.yahoo.transaction.Mutex; import com.yahoo.transaction.NestedTransaction; import com.yahoo.vespa.applicationmodel.ApplicationInstance; import com.yahoo.vespa.applicationmodel.ApplicationInstanceReference; +import com.yahoo.vespa.applicationmodel.InfrastructureApplication; import com.yahoo.vespa.curator.stats.LockStats; import com.yahoo.vespa.hosted.provision.LockedNodeList; import com.yahoo.vespa.hosted.provision.Node; @@ -182,7 +185,7 @@ public class MetricsReporterTest { @Test public void container_metrics() { - NodeFlavors nodeFlavors = FlavorConfigBuilder.createDummies("host", "docker", "docker2"); + NodeFlavors nodeFlavors = FlavorConfigBuilder.createDummies("host", "docker"); ProvisioningTester tester = new ProvisioningTester.Builder().flavors(nodeFlavors.getFlavors()).build(); NodeRepository nodeRepository = tester.nodeRepository(); @@ -210,7 +213,8 @@ public class MetricsReporterTest { } NestedTransaction transaction = new NestedTransaction(); - nodeRepository.nodes().activate(nodeRepository.nodes().list().nodeType(NodeType.host).asList(), transaction); + nodeRepository.nodes().activate(nodeRepository.nodes().list().nodeType(NodeType.host).asList(), + new ApplicationTransaction(new ProvisionLock(InfrastructureApplication.TENANT_HOST.id(), () -> { }), transaction)); transaction.commit(); Orchestrator orchestrator = mock(Orchestrator.class); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java index 359f75c27ab..ac1e452d7a5 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ProvisionedExpirerTest.java @@ -9,6 +9,7 @@ import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.Zone; +import com.yahoo.vespa.hosted.provision.LockedNodeList; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.provisioning.ProvisioningTester; @@ -45,7 +46,7 @@ public class ProvisionedExpirerTest { var nodes = IntStream.range(0, 15) .mapToObj(i -> Node.create("id-" + i, "host-" + i, new Flavor(NodeResources.unspecified()), Node.State.provisioned, NodeType.host).build()) .toList(); - tester.nodeRepository().database().addNodesInState(nodes, Node.State.provisioned, Agent.system); + tester.nodeRepository().database().addNodesInState(new LockedNodeList(nodes, () -> { }), Node.State.provisioned, Agent.system); } } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java index b54975cbf41..a5ac2be72ee 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/SpareCapacityMaintainerTest.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.provision.maintenance; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ApplicationTransaction; import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.DockerImage; @@ -10,6 +11,7 @@ import com.yahoo.config.provision.Flavor; import com.yahoo.config.provision.NodeFlavors; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; +import com.yahoo.config.provision.ProvisionLock; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.Zone; import com.yahoo.test.ManualClock; @@ -313,7 +315,7 @@ public class SpareCapacityMaintainerTest { } private void allocate(ApplicationId application, ClusterSpec clusterSpec, List<Node> nodes) { - nodes = nodeRepository.nodes().addNodes(nodes, Agent.system); + nodes = new ArrayList<>(nodeRepository.nodes().addNodes(nodes, Agent.system)); for (int i = 0; i < nodes.size(); i++) { Node node = nodes.get(i); ClusterMembership membership = ClusterMembership.from(clusterSpec, i); @@ -322,7 +324,7 @@ public class SpareCapacityMaintainerTest { } nodes = nodeRepository.nodes().reserve(nodes); var transaction = new NestedTransaction(); - nodes = nodeRepository.nodes().activate(nodes, transaction); + nodes = nodeRepository.nodes().activate(nodes, new ApplicationTransaction(new ProvisionLock(application, () -> { }), transaction)); transaction.commit(); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java index 47d34a76dd6..478b201d71b 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicAllocationTest.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.provision.provisioning; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ApplicationTransaction; import com.yahoo.config.provision.Capacity; import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.ClusterResources; @@ -12,6 +13,7 @@ import com.yahoo.config.provision.HostSpec; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.NodeAllocationException; +import com.yahoo.config.provision.ProvisionLock; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.Zone; @@ -540,9 +542,9 @@ public class DynamicAllocationTest { clusterSpec.with(Optional.of(ClusterSpec.Group.from(0))), index); // Need to add group here so that group is serialized in node allocation Node node1aAllocation = node1a.allocate(id, clusterMembership1, node1a.resources(), Instant.now()); - tester.nodeRepository().nodes().addNodes(Collections.singletonList(node1aAllocation), Agent.system); + tester.nodeRepository().nodes().addNodes(List.of(node1aAllocation), Agent.system); NestedTransaction transaction = new NestedTransaction().add(new CuratorTransaction(tester.getCurator())); - tester.nodeRepository().nodes().activate(Collections.singletonList(node1aAllocation), transaction); + tester.nodeRepository().nodes().activate(List.of(node1aAllocation), new ApplicationTransaction(new ProvisionLock(id, () -> { }), transaction)); transaction.commit(); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/HostCapacityTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/HostCapacityTest.java index ec8a8f637c8..35d6cdbf5d0 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/HostCapacityTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/HostCapacityTest.java @@ -45,7 +45,7 @@ public class HostCapacityTest { doAnswer(invocation -> ((Flavor)invocation.getArguments()[0]).resources()).when(hostResourcesCalculator).advertisedResourcesOf(any()); // Create flavors - NodeFlavors nodeFlavors = FlavorConfigBuilder.createDummies("host", "docker", "docker2"); + NodeFlavors nodeFlavors = FlavorConfigBuilder.createDummies("host", "docker"); // Create hosts host1 = Node.create("host1", IP.Config.of(Set.of("::1"), createIps(2, 4), List.of()), "host1", nodeFlavors.getFlavorOrThrow("host"), NodeType.host).build(); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java index 2acbeb00f5f..dd8f97d82de 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java @@ -214,7 +214,7 @@ public class ProvisioningTester { NestedTransaction t = new NestedTransaction(); if (parent.ipConfig().primary().isEmpty()) parent = parent.with(IP.Config.of(Set.of("::" + 0 + ":0"), Set.of("::" + 0 + ":2"))); - nodeRepository.nodes().activate(List.of(parent), t); + nodeRepository.nodes().activate(List.of(parent), new ApplicationTransaction(new ProvisionLock(application, () -> { }), t)); t.commit(); } } diff --git a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp index d0c1f99af11..28da6013edb 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp @@ -35,6 +35,7 @@ using search::fef::MatchData; using search::fef::RankSetup; using search::fef::indexproperties::hitcollector::HeapSize; using search::fef::indexproperties::hitcollector::ArraySize; +using search::fef::indexproperties::hitcollector::RankScoreDropLimit; using search::queryeval::Blueprint; using search::queryeval::SearchIterator; using vespalib::Doom; @@ -239,10 +240,11 @@ Matcher::match(const SearchRequest &request, vespalib::ThreadBundle &threadBundl const Properties & rankProperties = request.propertiesMap.rankProperties(); uint32_t heapSize = HeapSize::lookup(rankProperties, _rankSetup->getHeapSize()); uint32_t arraySize = ArraySize::lookup(rankProperties, _rankSetup->getArraySize()); + search::feature_t rank_score_drop_limit = RankScoreDropLimit::lookup(rankProperties, _rankSetup->getRankScoreDropLimit()); - MatchParams params(searchContext.getDocIdLimit(), heapSize, arraySize, - _rankSetup->getRankScoreDropLimit(), request.offset, request.maxhits, - !_rankSetup->getSecondPhaseRank().empty(), !willNotNeedRanking(request, groupingContext)); + MatchParams params(searchContext.getDocIdLimit(), heapSize, arraySize, rank_score_drop_limit, + request.offset, request.maxhits, !_rankSetup->getSecondPhaseRank().empty(), + !willNotNeedRanking(request, groupingContext)); ResultProcessor rp(attrContext, metaStore, sessionMgr, groupingContext, sessionId, request.sortSpec, params.offset, params.hits); diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index f3fe86e261f..7d6f2f8790c 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -946,6 +946,8 @@ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRandom()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL1Normalize()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL2Normalize()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorEuclideanDistance()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCosineSimilarity()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()", @@ -1098,6 +1100,8 @@ "public static final int RANDOM", "public static final int L1_NORMALIZE", "public static final int L2_NORMALIZE", + "public static final int EUCLIDEAN_DISTANCE", + "public static final int COSINE_SIMILARITY", "public static final int MATMUL", "public static final int SOFTMAX", "public static final int XW_PLUS_B", diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 42b5f2c191a..41647a5ef5b 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -138,6 +138,8 @@ TOKEN : <RANDOM: "random"> | <L1_NORMALIZE: "l1_normalize"> | <L2_NORMALIZE: "l2_normalize"> | + <EUCLIDEAN_DISTANCE: "euclidean_distance"> | + <COSINE_SIMILARITY: "cosine_similarity"> | <MATMUL: "matmul"> | <SOFTMAX: "softmax"> | <XW_PLUS_B: "xw_plus_b"> | @@ -379,6 +381,8 @@ TensorFunctionNode tensorFunction() : tensorExpression = tensorRandom() | tensorExpression = tensorL1Normalize() | tensorExpression = tensorL2Normalize() | + tensorExpression = tensorEuclideanDistance() | + tensorExpression = tensorCosineSimilarity() | tensorExpression = tensorMatmul() | tensorExpression = tensorSoftmax() | tensorExpression = tensorXwPlusB() | @@ -544,6 +548,30 @@ TensorFunctionNode tensorL2Normalize() : { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } +TensorFunctionNode tensorEuclideanDistance() : +{ + ExpressionNode tensor1, tensor2; + String dimension; +} +{ + <EUCLIDEAN_DISTANCE> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new EuclideanDistance(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + dimension)); } +} + +TensorFunctionNode tensorCosineSimilarity() : +{ + ExpressionNode tensor1, tensor2; + String dimension; +} +{ + <COSINE_SIMILARITY> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new CosineSimilarity(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + dimension)); } +} + TensorFunctionNode tensorMatmul() : { ExpressionNode tensor1, tensor2; @@ -701,6 +729,8 @@ String tensorFunctionName() : ( <RANDOM> { return token.image; } ) | ( <L1_NORMALIZE> { return token.image; } ) | ( <L2_NORMALIZE> { return token.image; } ) | + ( <EUCLIDEAN_DISTANCE> { return token.image; } ) | + ( <COSINE_SIMILARITY> { return token.image; } ) | ( <MATMUL> { return token.image; } ) | ( <SOFTMAX> { return token.image; } ) | ( <XW_PLUS_B> { return token.image; } ) | @@ -1041,4 +1071,4 @@ String label() : String string() : {} { <STRING> { return token.image.substring(1, token.image.length() - 1); } -}
\ No newline at end of file +} diff --git a/searchlib/src/tests/attribute/multi_value_mapping/multi_value_mapping_test.cpp b/searchlib/src/tests/attribute/multi_value_mapping/multi_value_mapping_test.cpp index 4b01808e855..e3e4f391cc4 100644 --- a/searchlib/src/tests/attribute/multi_value_mapping/multi_value_mapping_test.cpp +++ b/searchlib/src/tests/attribute/multi_value_mapping/multi_value_mapping_test.cpp @@ -95,7 +95,7 @@ public: ArrayStoreConfig config(max_array_store_type_id, ArrayStoreConfig::AllocSpec(0, RefType::offsetSize(), 8_Ki, ALLOC_GROW_FACTOR)); config.enable_free_lists(enable_free_lists); - _mvMapping = std::make_unique<MvMapping>(config, vespalib::GrowStrategy(), std::make_unique<MemoryAllocatorObserver>(_stats)); + _mvMapping = std::make_unique<MvMapping>(config, ArrayStoreConfig::default_max_buffer_size, vespalib::GrowStrategy(), std::make_unique<MemoryAllocatorObserver>(_stats)); _attr = std::make_unique<AttributeType>(*_mvMapping); _maxSmallArraySize = _mvMapping->get_mapper().get_array_size(max_array_store_type_id); } @@ -103,7 +103,7 @@ public: ArrayStoreConfig config(max_array_store_type_id, ArrayStoreConfig::AllocSpec(min_entries, max_entries, num_entries_for_new_buffer, ALLOC_GROW_FACTOR)); config.enable_free_lists(enable_free_lists); - _mvMapping = std::make_unique<MvMapping>(config, vespalib::GrowStrategy(), std::make_unique<MemoryAllocatorObserver>(_stats)); + _mvMapping = std::make_unique<MvMapping>(config, ArrayStoreConfig::default_max_buffer_size, vespalib::GrowStrategy(), std::make_unique<MemoryAllocatorObserver>(_stats)); _attr = std::make_unique<AttributeType>(*_mvMapping); _maxSmallArraySize = _mvMapping->get_mapper().get_array_size(max_array_store_type_id); } diff --git a/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp b/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp index 2bf2a38b7e6..0d64683b4a6 100644 --- a/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp +++ b/searchlib/src/tests/docstore/logdatastore/logdatastore_test.cpp @@ -565,7 +565,7 @@ TEST("Control static memory usage") { IDocumentStore &ds = vcs.getStore(); vespalib::MemoryUsage usage = ds.getMemoryUsage(); constexpr size_t mutex_size = sizeof(std::mutex) * 2 * (113 + 1); // sizeof(std::mutex) is platform dependent - EXPECT_EQUAL(74572 + mutex_size, usage.allocatedBytes()); + EXPECT_EQUAL(74668 + mutex_size, usage.allocatedBytes()); EXPECT_EQUAL(944u + mutex_size, usage.usedBytes()); } @@ -575,29 +575,29 @@ TEST("test the update cache strategy") { for (size_t i(1); i <= 10; i++) { vcs.write(i); } - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 0)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 28)); vcs.verifyRead(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 221)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 241)); vcs.write(8); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 221)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 241)); vcs.write(7, 17); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 282)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 302)); vcs.verifyRead(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 1, 282)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 1, 302)); vcs.remove(8); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 1, 282)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 1, 302)); vcs.remove(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 0, 0)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 0, 28)); vcs.write(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 0, 0)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 1, 0, 28)); vcs.verifyRead(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 2, 1, 221)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 2, 1, 241)); vcs.write(7, 17); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 2, 1, 282)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 1, 2, 1, 302)); vcs.recreate(); IDocumentStore & ds2 = vcs.getStore(); vcs.verifyRead(7); - TEST_DO(verifyCacheStats(ds2.getCacheStats(), 0, 1, 1, 282)); + TEST_DO(verifyCacheStats(ds2.getCacheStats(), 0, 1, 1, 302)); } TEST("test the invalidate cache strategy") { @@ -606,23 +606,23 @@ TEST("test the invalidate cache strategy") { for (size_t i(1); i <= 10; i++) { vcs.write(i); } - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 0)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 28)); vcs.verifyRead(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 221)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 241)); vcs.write(8); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 221)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 1, 241)); vcs.write(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 0, 0)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 1, 0, 28)); vcs.verifyRead(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 1, 221)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 1, 241)); vcs.remove(8); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 1, 221)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 1, 241)); vcs.remove(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 0, 0)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 0, 28)); vcs.write(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 0, 0)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 2, 0, 28)); vcs.verifyRead(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 3, 1, 221)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 3, 1, 241)); } TEST("test that the integrated visit cache works.") { @@ -631,12 +631,12 @@ TEST("test that the integrated visit cache works.") { for (size_t i(1); i <= 100; i++) { vcs.write(i); } - TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 0)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 0, 0, 28)); for (size_t i(1); i <= 100; i++) { vcs.verifyRead(i); } - constexpr size_t BASE_SZ = 20594; + constexpr size_t BASE_SZ = 20602; TEST_DO(verifyCacheStats(ds.getCacheStats(), 0, 100, 100, BASE_SZ)); for (size_t i(1); i <= 100; i++) { vcs.verifyRead(i); @@ -646,32 +646,32 @@ TEST("test that the integrated visit cache works.") { vcs.verifyVisit({7,9,17,19,67,88}, false); TEST_DO(verifyCacheStats(ds.getCacheStats(), 100, 100, 100, BASE_SZ)); vcs.verifyVisit({7,9,17,19,67,88}, true); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 100, 101, 101, BASE_SZ+557)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 100, 101, 101, BASE_SZ+16)); vcs.verifyVisit({7,9,17,19,67,88}, true); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 101, BASE_SZ+557)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 101, BASE_SZ+16)); vcs.rewrite(8); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 100, BASE_SZ+328)); // From the individual cache. + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 100, BASE_SZ-197)); // From the individual cache. vcs.rewrite(7); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 98, BASE_SZ-442)); // From the both caches. + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 101, 98, BASE_SZ-166)); // From the both caches. vcs.verifyVisit({7,9,17,19,67,88}, true); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 102, 99, BASE_SZ+130)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 102, 99, BASE_SZ-410)); vcs.verifyVisit({7,9,17,19,67,88,89}, true); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 103, 99, BASE_SZ+180)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 103, 99, BASE_SZ-406)); vcs.rewrite(17); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 103, 97, BASE_SZ-671)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 103, 97, BASE_SZ-391)); vcs.verifyVisit({7,9,17,19,67,88,89}, true); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 104, 98, BASE_SZ-20)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 104, 98, BASE_SZ-611)); vcs.remove(17); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 104, 97, BASE_SZ-671)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 104, 97, BASE_SZ-391)); vcs.verifyVisit({7,9,17,19,67,88,89}, {7,9,19,67,88,89}, true); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 105, 98, BASE_SZ-70)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 105, 98, BASE_SZ-611)); vcs.verifyVisit({41, 42}, true); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 106, 99, BASE_SZ+230)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 106, 99, BASE_SZ-611)); vcs.verifyVisit({43, 44}, true); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 107, 100, BASE_SZ+540)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 107, 100, BASE_SZ-611)); vcs.verifyVisit({41, 42, 43, 44}, true); - TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 108, 99, BASE_SZ+360)); + TEST_DO(verifyCacheStats(ds.getCacheStats(), 101, 108, 99, BASE_SZ-611)); } TEST("testWriteRead") { diff --git a/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp b/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp index 26ef57aab65..87420e8939f 100644 --- a/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp +++ b/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp @@ -9,6 +9,7 @@ LOG_SETUP("dense_tensor_store_test"); #include <vespa/eval/eval/value.h> #include <vespa/eval/eval/value_type.h> #include <vespa/eval/eval/test/value_compare.h> +#include <vespa/vespalib/util/size_literals.h> using search::tensor::DenseTensorStore; using vespalib::eval::SimpleValue; @@ -90,5 +91,26 @@ TEST("require that array size is calculated correctly") TEST_DO(assertArraySize("tensor<int8>(x[65])", 96)); } +void +assert_max_buffer_entries(const vespalib::string& tensor_type, uint32_t exp_entries) +{ + Fixture f(tensor_type); + EXPECT_EQUAL(exp_entries, f.store.get_max_buffer_entries()); +} + +TEST("require that max entries is calculated correctly") +{ + TEST_DO(assert_max_buffer_entries("tensor(x[1])", 1_Mi)); + TEST_DO(assert_max_buffer_entries("tensor(x[32])", 1_Mi)); + TEST_DO(assert_max_buffer_entries("tensor(x[64])", 512_Ki)); + TEST_DO(assert_max_buffer_entries("tensor(x[1024])", 32_Ki)); + TEST_DO(assert_max_buffer_entries("tensor(x[1024])", 32_Ki)); + TEST_DO(assert_max_buffer_entries("tensor(x[16777216])", 2)); + TEST_DO(assert_max_buffer_entries("tensor(x[33554428])", 2)); + TEST_DO(assert_max_buffer_entries("tensor(x[33554429])", 1)); + TEST_DO(assert_max_buffer_entries("tensor(x[33554432])", 1)); + TEST_DO(assert_max_buffer_entries("tensor(x[303554432])", 1)); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/tests/tensor/tensor_buffer_store/tensor_buffer_store_test.cpp b/searchlib/src/tests/tensor/tensor_buffer_store/tensor_buffer_store_test.cpp index 0a69d8149ed..b42520aa9e0 100644 --- a/searchlib/src/tests/tensor/tensor_buffer_store/tensor_buffer_store_test.cpp +++ b/searchlib/src/tests/tensor/tensor_buffer_store/tensor_buffer_store_test.cpp @@ -198,7 +198,7 @@ TEST_F(TensorBufferStoreTest, buffer_handles_range_of_subspaces) auto buffer_id = ref.buffer_id(offset_bits); buffers.insert(buffer_id); } - EXPECT_EQ(156u, buffers.size()); + EXPECT_EQ(119u, buffers.size()); uint32_t x = 0; for (auto ref : refs) { auto tensor = store.get_tensor(ref); diff --git a/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp b/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp index 08c0901de01..fc574ba9b2c 100644 --- a/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp +++ b/searchlib/src/tests/tensor/tensor_buffer_type_mapper/tensor_buffer_type_mapper_test.cpp @@ -15,7 +15,7 @@ const vespalib::string tensor_type_2d_mixed_spec("tensor(x{},y[2])"); const vespalib::string float_tensor_type_spec("tensor<float>(y{})"); const vespalib::string tensor_type_dense_spec("tensor(x[2])"); -constexpr double grow_factor = 1.02; +constexpr double grow_factor = 1.03; struct TestParam { @@ -128,10 +128,10 @@ TensorBufferTypeMapperTest::select_type_ids() INSTANTIATE_TEST_SUITE_P(TensorBufferTypeMapperMultiTest, TensorBufferTypeMapperTest, - testing::Values(TestParam("1d", {8, 16, 32, 40, 64}, {1760, 10880, 76896, 555248, 4020512}, tensor_type_sparse_spec), - TestParam("1dfloat", {4, 12, 20, 28, 36}, {1728, 11104, 79168, 572128, 4143664}, float_tensor_type_spec), - TestParam("2d", {8, 24, 40, 56, 80}, {1600, 9184, 63872, 460416, 3332976}, tensor_type_2d_spec), - TestParam("2dmixed", {8, 24, 48, 64, 96}, {1984, 11472, 79824, 575504, 4166208}, tensor_type_2d_mixed_spec), + testing::Values(TestParam("1d", {8, 16, 32, 40, 64}, {2768, 49712, 950768, 18268976, 351101184}, tensor_type_sparse_spec), + TestParam("1dfloat", {4, 12, 20, 28, 36}, {2688, 48896, 937248, 18009808, 346121248}, float_tensor_type_spec), + TestParam("2d", {8, 24, 40, 56, 80}, {2416, 41392, 790112, 15179616, 291726288}, tensor_type_2d_spec), + TestParam("2dmixed", {8, 24, 48, 64, 96}, {3008, 51728, 987632, 18974512, 364657856}, tensor_type_2d_mixed_spec), TestParam("dense", {8, 24}, {}, tensor_type_dense_spec)), testing::PrintToStringParamName()); @@ -152,7 +152,7 @@ TEST_P(TensorBufferTypeMapperTest, large_arrays_grows_exponentially) TEST_P(TensorBufferTypeMapperTest, avoid_array_size_overflow) { - TensorBufferTypeMapper mapper(400, 2.0, &_ops); + TensorBufferTypeMapper mapper(300, 2.0, &_ops); EXPECT_GE(30, mapper.get_max_type_id(1000)); } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h index 307d1a0d112..ced076dc632 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.h @@ -36,6 +36,7 @@ public: MultiValueMapping(const MultiValueMapping &) = delete; MultiValueMapping & operator = (const MultiValueMapping &) = delete; MultiValueMapping(const vespalib::datastore::ArrayStoreConfig &storeCfg, + size_t max_buffer_size, const vespalib::GrowStrategy &gs, std::shared_ptr<vespalib::alloc::MemoryAllocator> memory_allocator); ~MultiValueMapping() override; @@ -79,11 +80,12 @@ public: static vespalib::datastore::ArrayStoreConfig optimizedConfigForHugePage(size_t max_type_id, - size_t hugePageSize, - size_t smallPageSize, - size_t min_num_entries_for_new_buffer, - float allocGrowFactor, - bool enable_free_lists); + size_t hugePageSize, + size_t smallPageSize, + size_t max_buffer_size, + size_t min_num_entries_for_new_buffer, + float allocGrowFactor, + bool enable_free_lists); }; } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp index 99808b11e92..64c4777ffda 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping.hpp @@ -9,10 +9,11 @@ namespace search::attribute { template <typename ElemT, typename RefT> MultiValueMapping<ElemT,RefT>::MultiValueMapping(const vespalib::datastore::ArrayStoreConfig &storeCfg, + size_t max_buffer_size, const vespalib::GrowStrategy &gs, std::shared_ptr<vespalib::alloc::MemoryAllocator> memory_allocator) : MultiValueMappingBase(gs, ArrayStore::getGenerationHolderLocation(_store), memory_allocator), - _store(storeCfg, std::move(memory_allocator), ArrayStoreTypeMapper(storeCfg.max_type_id(), array_store_grow_factor)) + _store(storeCfg, std::move(memory_allocator), ArrayStoreTypeMapper(storeCfg.max_type_id(), array_store_grow_factor, max_buffer_size)) { } @@ -66,14 +67,15 @@ MultiValueMapping<ElemT, RefT>::getAddressSpaceUsage() const { template <typename ElemT, typename RefT> vespalib::datastore::ArrayStoreConfig MultiValueMapping<ElemT, RefT>::optimizedConfigForHugePage(size_t max_type_id, - size_t hugePageSize, - size_t smallPageSize, - size_t min_num_entries_for_new_buffer, - float allocGrowFactor, - bool enable_free_lists) + size_t hugePageSize, + size_t smallPageSize, + size_t max_buffer_size, + size_t min_num_entries_for_new_buffer, + float allocGrowFactor, + bool enable_free_lists) { - ArrayStoreTypeMapper mapper(max_type_id, array_store_grow_factor); - auto result = ArrayStore::optimizedConfigForHugePage(max_type_id, mapper, hugePageSize, smallPageSize, min_num_entries_for_new_buffer, allocGrowFactor); + ArrayStoreTypeMapper mapper(max_type_id, array_store_grow_factor, max_buffer_size); + auto result = ArrayStore::optimizedConfigForHugePage(max_type_id, mapper, hugePageSize, smallPageSize, max_buffer_size, min_num_entries_for_new_buffer, allocGrowFactor); result.enable_free_lists(enable_free_lists); return result; } diff --git a/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp index d8ada97fa2c..56c6d010582 100644 --- a/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multivalueattribute.hpp @@ -12,6 +12,8 @@ #include <vespa/vespalib/util/memory_allocator.h> #include <vespa/vespalib/util/stash.h> +using vespalib::datastore::ArrayStoreConfig; + namespace search { namespace multivalueattribute { @@ -28,9 +30,11 @@ MultiValueAttribute(const vespalib::string &baseFileName, _mvMapping(MultiValueMapping::optimizedConfigForHugePage(MultiValueMapping::array_store_max_type_id, vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, vespalib::alloc::MemoryAllocator::PAGE_SIZE, + ArrayStoreConfig::default_max_buffer_size, 8 * 1024, cfg.getGrowStrategy().getMultiValueAllocGrowFactor(), multivalueattribute::enable_free_lists), + ArrayStoreConfig::default_max_buffer_size, cfg.getGrowStrategy(), this->get_memory_allocator()) { } diff --git a/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp b/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp index cd9e0508344..0c6dd9c75a8 100644 --- a/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp +++ b/searchlib/src/vespa/searchlib/attribute/raw_buffer_store.cpp @@ -5,6 +5,7 @@ #include <cassert> using vespalib::alloc::MemoryAllocator; +using vespalib::datastore::ArrayStoreConfig; using vespalib::datastore::EntryRef; namespace { @@ -17,11 +18,12 @@ namespace search::attribute { RawBufferStore::RawBufferStore(std::shared_ptr<vespalib::alloc::MemoryAllocator> allocator, uint32_t max_small_buffer_type_id, double grow_factor) : _array_store(ArrayStoreType::optimizedConfigForHugePage(max_small_buffer_type_id, - TypeMapper(max_small_buffer_type_id, grow_factor), + TypeMapper(max_small_buffer_type_id, grow_factor, ArrayStoreConfig::default_max_buffer_size), MemoryAllocator::HUGEPAGE_SIZE, MemoryAllocator::PAGE_SIZE, + ArrayStoreConfig::default_max_buffer_size, 8_Ki, ALLOC_GROW_FACTOR), - std::move(allocator), TypeMapper(max_small_buffer_type_id, grow_factor)) + std::move(allocator), TypeMapper(max_small_buffer_type_id, grow_factor, ArrayStoreConfig::default_max_buffer_size)) { } diff --git a/searchlib/src/vespa/searchlib/docstore/visitcache.cpp b/searchlib/src/vespa/searchlib/docstore/visitcache.cpp index b61cf49c438..322d0eb341b 100644 --- a/searchlib/src/vespa/searchlib/docstore/visitcache.cpp +++ b/searchlib/src/vespa/searchlib/docstore/visitcache.cpp @@ -165,6 +165,7 @@ public: CompressedBlobSet readSet(const KeySet & keys); void removeKey(uint32_t key); vespalib::MemoryUsage getStaticMemoryUsage() const override; + CacheStats get_stats() const override; private: void locateAndInvalidateOtherSubsets(const UniqueLock & cacheGuard, const KeySet & keys); using IdSet = vespalib::hash_set<uint64_t>; @@ -267,7 +268,8 @@ VisitCache::remove(uint32_t key) { CacheStats VisitCache::getCacheStats() const { - return _cache->get_stats(); + CacheStats stats = _cache->get_stats(); + return stats; } VisitCache::Cache::Cache(BackingStore & b, size_t maxBytes) : @@ -306,19 +308,22 @@ VisitCache::Cache::onRemove(const K & key) { vespalib::MemoryUsage VisitCache::Cache::getStaticMemoryUsage() const { vespalib::MemoryUsage usage = Parent::getStaticMemoryUsage(); - auto cacheGuard = getGuard(); size_t baseSelf = sizeof(_lid2Id) + sizeof(_id2KeySet); usage.incAllocatedBytes(baseSelf); - usage.incAllocatedBytes(_lid2Id.capacity() * sizeof(LidUniqueKeySetId::value_type)); - usage.incAllocatedBytes(_id2KeySet.capacity() * sizeof(IdKeySetMap::value_type)); usage.incUsedBytes(baseSelf); - usage.incUsedBytes(_lid2Id.size() * sizeof(LidUniqueKeySetId::value_type)); - usage.incUsedBytes(_id2KeySet.size() * sizeof(IdKeySetMap::value_type)); + return usage; +} + +CacheStats +VisitCache::Cache::get_stats() const { + CacheStats stats = Parent::get_stats(); + auto cacheGuard = getGuard(); + stats.memory_used += _lid2Id.capacity() * sizeof(LidUniqueKeySetId::value_type); + stats.memory_used += _id2KeySet.capacity() * sizeof(IdKeySetMap::value_type); for (const auto & entry: _id2KeySet) { - usage.incAllocatedBytes(entry.second.getKeys().capacity() * sizeof(uint32_t)); - usage.incUsedBytes(entry.second.getKeys().size() * sizeof(uint32_t)); + stats.memory_used = entry.second.getKeys().capacity() * sizeof(uint32_t); } - return usage; + return stats; } } diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp index 26a7be005b7..c993cdeb790 100644 --- a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp +++ b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp @@ -620,7 +620,13 @@ const feature_t RankScoreDropLimit::DEFAULT_VALUE(-std::numeric_limits<feature_t feature_t RankScoreDropLimit::lookup(const Properties &props) { - return lookupDouble(props, NAME, DEFAULT_VALUE); + return lookup(props, DEFAULT_VALUE); +} + +feature_t +RankScoreDropLimit::lookup(const Properties &props, feature_t defaultValue) +{ + return lookupDouble(props, NAME, defaultValue); } } // namspace hitcollector diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.h b/searchlib/src/vespa/searchlib/fef/indexproperties.h index 5ff4ea26bd8..53f1789b2e0 100644 --- a/searchlib/src/vespa/searchlib/fef/indexproperties.h +++ b/searchlib/src/vespa/searchlib/fef/indexproperties.h @@ -507,6 +507,7 @@ namespace hitcollector { static const vespalib::string NAME; static const feature_t DEFAULT_VALUE; static feature_t lookup(const Properties &props); + static feature_t lookup(const Properties &props, feature_t defaultValue); }; diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp index c51d0ec7fd3..638c254aac1 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp @@ -34,6 +34,13 @@ size_t my_align(size_t size, size_t alignment) { return (size - (size % alignment)); } +size_t +cap_max_entries(size_t max_entries, size_t max_buffer_size, size_t entry_size) +{ + size_t dynamic_max_entries = (max_buffer_size + (entry_size - 1)) / entry_size; + return std::min(max_entries, dynamic_max_entries); +} + } DenseTensorStore::TensorSizeCalc::TensorSizeCalc(const ValueType &type) @@ -55,7 +62,7 @@ DenseTensorStore::TensorSizeCalc::TensorSizeCalc(const ValueType &type) } DenseTensorStore::BufferType::BufferType(const TensorSizeCalc &tensorSizeCalc, std::shared_ptr<vespalib::alloc::MemoryAllocator> allocator) - : vespalib::datastore::BufferType<char>(tensorSizeCalc.alignedSize(), MIN_BUFFER_ARRAYS, RefType::offsetSize()), + : vespalib::datastore::BufferType<char>(tensorSizeCalc.alignedSize(), MIN_BUFFER_ARRAYS, cap_max_entries(RefType::offsetSize(), max_dense_tensor_buffer_size, tensorSizeCalc.alignedSize())), _allocator(std::move(allocator)) {} diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h index 0dd483e7f08..b7375c4d2c3 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h @@ -20,9 +20,11 @@ namespace search::tensor { class DenseTensorStore : public TensorStore { public: - using RefType = vespalib::datastore::EntryRefT<22>; + // 4 Ki buffers of 256 MiB each is 1 TiB. + using RefType = vespalib::datastore::EntryRefT<20>; using DataStoreType = vespalib::datastore::DataStoreT<RefType>; using ValueType = vespalib::eval::ValueType; + static constexpr size_t max_dense_tensor_buffer_size = 256_Mi; struct TensorSizeCalc { @@ -90,8 +92,9 @@ public: return VectorBundle(getRawBuffer(ref), 1, _subspace_type); } const SubspaceType& get_subspace_type() const noexcept { return _subspace_type; } - // The following method is meant to be used only for unit tests. + // The following methods are meant to be used only for unit tests. uint32_t getArraySize() const { return _bufferType.getArraySize(); } + uint32_t get_max_buffer_entries() const noexcept { return _bufferType.get_max_entries(); } }; } diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 748a747d515..22a33270a27 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -31,6 +31,7 @@ namespace search::tensor { using search::AddressSpaceComponents; using search::StateExplorerUtils; using search::queryeval::GlobalFilter; +using vespalib::datastore::ArrayStoreConfig; using vespalib::datastore::CompactionStrategy; using vespalib::datastore::EntryRef; using vespalib::GenericHeader; @@ -145,25 +146,27 @@ PreparedAddDoc::PreparedAddDoc(PreparedAddDoc&& other) noexcept = default; } template <HnswIndexType type> -vespalib::datastore::ArrayStoreConfig +ArrayStoreConfig HnswIndex<type>::make_default_level_array_store_config() { return LevelArrayStore::optimizedConfigForHugePage(max_level_array_size, - vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, - vespalib::alloc::MemoryAllocator::PAGE_SIZE, + vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, + vespalib::alloc::MemoryAllocator::PAGE_SIZE, + ArrayStoreConfig::default_max_buffer_size, min_num_arrays_for_new_buffer, alloc_grow_factor).enable_free_lists(true); } template <HnswIndexType type> -vespalib::datastore::ArrayStoreConfig +ArrayStoreConfig HnswIndex<type>::make_default_link_array_store_config() { return LinkArrayStore::optimizedConfigForHugePage(max_link_array_size, - vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, - vespalib::alloc::MemoryAllocator::PAGE_SIZE, - min_num_arrays_for_new_buffer, - alloc_grow_factor).enable_free_lists(true); + vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, + vespalib::alloc::MemoryAllocator::PAGE_SIZE, + ArrayStoreConfig::default_max_buffer_size, + min_num_arrays_for_new_buffer, + alloc_grow_factor).enable_free_lists(true); } template <HnswIndexType type> diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp index a78d9cefc64..cf30d62a0b8 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_nodeid_mapping.cpp @@ -49,6 +49,7 @@ HnswNodeidMapping::HnswNodeidMapping() _nodeids(NodeidStore::optimizedConfigForHugePage(max_type_id, vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, vespalib::alloc::MemoryAllocator::PAGE_SIZE, + vespalib::datastore::ArrayStoreConfig::default_max_buffer_size, min_num_arrays_for_new_buffer, alloc_grow_factor).enable_free_lists(true), {}), _hold_list(), diff --git a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp index 51ebc22c269..1e79ca53e68 100644 --- a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp @@ -16,7 +16,8 @@ namespace search::tensor { SerializedFastValueAttribute::SerializedFastValueAttribute(stringref name, const Config &cfg, const NearestNeighborIndexFactory& index_factory) : TensorAttribute(name, cfg, _tensorBufferStore, index_factory), - _tensorBufferStore(cfg.tensorType(), get_memory_allocator(), 400u) + _tensorBufferStore(cfg.tensorType(), get_memory_allocator(), + TensorBufferStore::array_store_max_type_id) { } diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp index ff39c33fc5d..8a7d84010cb 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.cpp @@ -26,8 +26,6 @@ namespace { constexpr float ALLOC_GROW_FACTOR = 0.2; -constexpr double mapper_grow_factor = 1.02; - } TensorBufferStore::TensorBufferStore(const ValueType& tensor_type, std::shared_ptr<MemoryAllocator> allocator, uint32_t max_small_subspaces_type_id) @@ -35,11 +33,12 @@ TensorBufferStore::TensorBufferStore(const ValueType& tensor_type, std::shared_p _tensor_type(tensor_type), _ops(_tensor_type), _array_store(ArrayStoreType::optimizedConfigForHugePage(max_small_subspaces_type_id, - TensorBufferTypeMapper(max_small_subspaces_type_id, mapper_grow_factor, &_ops), + TensorBufferTypeMapper(max_small_subspaces_type_id, array_store_grow_factor, &_ops), MemoryAllocator::HUGEPAGE_SIZE, MemoryAllocator::PAGE_SIZE, + vespalib::datastore::ArrayStoreConfig::default_max_buffer_size, 8_Ki, ALLOC_GROW_FACTOR), - std::move(allocator), TensorBufferTypeMapper(max_small_subspaces_type_id, mapper_grow_factor, &_ops)) + std::move(allocator), TensorBufferTypeMapper(max_small_subspaces_type_id, array_store_grow_factor, &_ops)) { } diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h index 2e86ff5fb67..3342e9f3d27 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h +++ b/searchlib/src/vespa/searchlib/tensor/tensor_buffer_store.h @@ -24,6 +24,10 @@ class TensorBufferStore : public TensorStore TensorBufferOperations _ops; ArrayStoreType _array_store; public: + + static constexpr double array_store_grow_factor = 1.03; + static constexpr uint32_t array_store_max_type_id = 300; + TensorBufferStore(const vespalib::eval::ValueType& tensor_type, std::shared_ptr<vespalib::alloc::MemoryAllocator> allocator, uint32_t max_small_subspaces_type_id); ~TensorBufferStore(); void holdTensor(EntryRef ref) override; diff --git a/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp b/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp index baec5494b36..efa7e18aa33 100644 --- a/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp +++ b/storage/src/vespa/storage/bucketdb/btree_bucket_database.cpp @@ -40,6 +40,7 @@ vespalib::datastore::ArrayStoreConfig make_default_array_store_config() { return ReplicaStore::optimizedConfigForHugePage(1023, vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, vespalib::alloc::MemoryAllocator::PAGE_SIZE, + vespalib::datastore::ArrayStoreConfig::default_max_buffer_size, 8_Ki, 0.2).enable_free_lists(true); } diff --git a/vespa-application-maven-plugin/src/main/java/com/yahoo/container/plugin/mojo/ApplicationMojo.java b/vespa-application-maven-plugin/src/main/java/com/yahoo/container/plugin/mojo/ApplicationMojo.java index c6d056675a8..aff505f934f 100644 --- a/vespa-application-maven-plugin/src/main/java/com/yahoo/container/plugin/mojo/ApplicationMojo.java +++ b/vespa-application-maven-plugin/src/main/java/com/yahoo/container/plugin/mojo/ApplicationMojo.java @@ -70,6 +70,7 @@ public class ApplicationMojo extends AbstractMojo { vespaversion = project.getPlugin("com.yahoo.vespa:vespa-application-maven-plugin").getVersion(); Version compileVersion = Version.from(vespaversion); + if (compileVersion.isSnapshot()) return; MavenProject current = project; while (current.getParent() != null && current.getParent().getParentArtifact() != null) diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 676e212f5c6..76d007dd633 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1705,6 +1705,24 @@ ], "fields" : [ ] }, + "com.yahoo.tensor.functions.CosineSimilarity" : { + "superClass" : "com.yahoo.tensor.functions.TensorFunction", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.functions.TensorFunction, java.lang.String)", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public int hashCode()" + ], + "fields" : [ ] + }, "com.yahoo.tensor.functions.Diag" : { "superClass" : "com.yahoo.tensor.functions.CompositeTensorFunction", "interfaces" : [ ], @@ -1740,6 +1758,24 @@ ], "fields" : [ ] }, + "com.yahoo.tensor.functions.EuclideanDistance" : { + "superClass" : "com.yahoo.tensor.functions.TensorFunction", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.functions.TensorFunction, java.lang.String)", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public int hashCode()" + ], + "fields" : [ ] + }, "com.yahoo.tensor.functions.Expand" : { "superClass" : "com.yahoo.tensor.functions.CompositeTensorFunction", "interfaces" : [ ], diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java new file mode 100644 index 00000000000..ebb8a11fd8a --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java @@ -0,0 +1,93 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TensorType.Dimension; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * Convenience for cosine similarity between vectors. + * cosine_similarity(a, b, mydim) == sum(a*b, mydim) / sqrt(sum(a*a, mydim) * sum(b*b, mydim)) + * @author arnej + */ +public class CosineSimilarity<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> arg1; + private final TensorFunction<NAMETYPE> arg2; + private final String dimension; + + public CosineSimilarity(TensorFunction<NAMETYPE> argument1, + TensorFunction<NAMETYPE> argument2, + String dimension) + { + this.arg1 = argument1; + this.arg2 = argument2; + this.dimension = dimension; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return List.of(arg1, arg2); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if ( arguments.size() != 2) + throw new IllegalArgumentException("CosineSimilarity must have 2 arguments, got " + arguments.size()); + return new CosineSimilarity<>(arguments.get(0), arguments.get(1), dimension); + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + TensorType t1 = arg1.toPrimitive().type(context); + TensorType t2 = arg2.toPrimitive().type(context); + var d1 = t1.dimension(dimension); + var d2 = t2.dimension(dimension); + if (d1.isEmpty() || d2.isEmpty() + || d1.get().type() != Dimension.Type.indexedBound + || d2.get().type() != Dimension.Type.indexedBound + || d1.get().size().get() != d2.get().size().get()) + { + throw new IllegalArgumentException("cosine_similarity expects both arguments to have the '" + + dimension + "' dimension with same size, but input types were " + + t1 + " and " + t2); + } + // Finds the type this produces by first converting it to a primitive function + return toPrimitive().type(context); + } + + /** Evaluates this by first converting it to a primitive function */ + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + return toPrimitive().evaluate(context); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> a = arg1.toPrimitive(); + TensorFunction<NAMETYPE> b = arg2.toPrimitive(); + var aa = new Join<>(a, a, ScalarFunctions.multiply()); + var ab = new Join<>(a, b, ScalarFunctions.multiply()); + var bb = new Join<>(b, b, ScalarFunctions.multiply()); + var dot_aa = new Reduce<>(aa, Reduce.Aggregator.sum, dimension); + var dot_ab = new Reduce<>(ab, Reduce.Aggregator.sum, dimension); + var dot_bb = new Reduce<>(bb, Reduce.Aggregator.sum, dimension); + var aabb = new Join<>(dot_aa, dot_bb, ScalarFunctions.multiply()); + var sqrt_aabb = new Map<>(aabb, ScalarFunctions.sqrt()); + return new Join<>(dot_ab, sqrt_aabb, ScalarFunctions.divide()); + } + + @Override + public String toString(ToStringContext<NAMETYPE> context) { + return "cosine_similarity(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + dimension + ")"; + } + + @Override + public int hashCode() { return Objects.hash("cosine_similarity", arg1, arg2, dimension); } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java new file mode 100644 index 00000000000..f9fc8e195d3 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java @@ -0,0 +1,89 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TensorType.Dimension; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * Convenience for euclidean distance between vectors. + * euclidean_distance(a, b, mydim) == sqrt(sum(pow(a-b, 2), mydim)) + * @author arnej + */ +public class EuclideanDistance<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> arg1; + private final TensorFunction<NAMETYPE> arg2; + private final String dimension; + + public EuclideanDistance(TensorFunction<NAMETYPE> argument1, + TensorFunction<NAMETYPE> argument2, + String dimension) + { + this.arg1 = argument1; + this.arg2 = argument2; + this.dimension = dimension; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return List.of(arg1, arg2); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if ( arguments.size() != 2) + throw new IllegalArgumentException("EuclideanDistance must have 2 arguments, got " + arguments.size()); + return new EuclideanDistance<>(arguments.get(0), arguments.get(1), dimension); + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + TensorType t1 = arg1.toPrimitive().type(context); + TensorType t2 = arg2.toPrimitive().type(context); + var d1 = t1.dimension(dimension); + var d2 = t2.dimension(dimension); + if (d1.isEmpty() || d2.isEmpty() + || d1.get().type() != Dimension.Type.indexedBound + || d2.get().type() != Dimension.Type.indexedBound + || d1.get().size().get() != d2.get().size().get()) + { + throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '" + + dimension + "' dimension with same size, but input types were " + + t1 + " and " + t2); + } + // Finds the type this produces by first converting it to a primitive function + return toPrimitive().type(context); + } + + /** Evaluates this by first converting it to a primitive function */ + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + return toPrimitive().evaluate(context); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitive1 = arg1.toPrimitive(); + TensorFunction<NAMETYPE> primitive2 = arg2.toPrimitive(); + // this should match the C++ optimized "l2_distance" + var diffs = new Join<>(primitive1, primitive2, ScalarFunctions.subtract()); + var squaredDiffs = new Map<>(diffs, ScalarFunctions.square()); + var sumOfSquares = new Reduce<>(squaredDiffs, Reduce.Aggregator.sum, dimension); + return new Map<>(sumOfSquares, ScalarFunctions.sqrt()); + } + + @Override + public String toString(ToStringContext<NAMETYPE> context) { + return "euclidean_distance(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + dimension + ")"; + } + + @Override + public int hashCode() { return Objects.hash("euclidean_distance", arg1, arg2, dimension); } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java new file mode 100644 index 00000000000..b303e2c1739 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java @@ -0,0 +1,66 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * @author arnej + */ +public class CosineSimilarityTestCase { + + @Test + public void testVectorSimilarity() { + var a = Tensor.from("tensor(x[3]):[ 2.0, 3.0, 6.0]"); + var b = Tensor.from("tensor(x[3]):[-2.0, 0.0, 0.0]"); + var c = Tensor.from("tensor(x[3]):[ 0.0, 4.0, 3.0]"); + var op = new CosineSimilarity<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + assertEquals((-2.0 / 7.0), result.asDouble(), 0.000001); + op = new CosineSimilarity<>(new ConstantTensor<>(b), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals((-2.0 / 7.0), result.asDouble(), 0.000001); + op = new CosineSimilarity<>(new ConstantTensor<>(a), new ConstantTensor<>(c), "x"); + result = op.evaluate(); + assertEquals((30.0 / 35.0), result.asDouble(), 0.000001); + op = new CosineSimilarity<>(new ConstantTensor<>(b), new ConstantTensor<>(c), "x"); + result = op.evaluate(); + assertEquals(0.0, result.asDouble(), 0.000001); + } + + @Test + public void testSimilarityInMixed() { + var a = Tensor.from("tensor(c{},yy[3]):{foo:[3.0, 4.0, 0.0],bar:[0.0, -4.0, 3.0]}"); + var b = Tensor.from("tensor(c{},yy[3]):{foo:[0.0, 4.0, -3.0],bar:[4.0, 0.0, -3.0]}"); + var op = new CosineSimilarity<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "yy"); + Tensor result = op.evaluate(); + var expect = Tensor.from("tensor(c{}):{foo:0.64,bar:-0.36}"); + assertEquals(expect, result); + } + + @Test + public void testExpansion() { + var tType = TensorType.fromSpec("tensor(vecdim[128])"); + var a = new VariableTensor<>("left", tType); + var b = new VariableTensor<>("right", tType); + var op = new CosineSimilarity<>(a, b, "vecdim"); + assertEquals("join(" + + ( "reduce(join(left, right, f(a,b)(a * b)), sum, vecdim), " + + "map(" + + ( "join(" + + ( "reduce(join(left, left, f(a,b)(a * b)), sum, vecdim), " + + "reduce(join(right, right, f(a,b)(a * b)), sum, vecdim), " + + "f(a,b)(a * b)), " ) + + "f(a)(sqrt(a))), " ) + + "f(a,b)(a / b)" ) + + ")", + op.toPrimitive().toString()); + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java new file mode 100644 index 00000000000..4fae432b3ca --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java @@ -0,0 +1,54 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * @author arnej + */ +public class EuclideanDistanceTestCase { + + @Test + public void testVectorDistances() { + var a = Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]"); + var b = Tensor.from("tensor(x[3]):[4.0, 2.0, 7.0]"); + var c = Tensor.from("tensor(x[3]):[1.0, 6.0, 6.0]"); + var op = new EuclideanDistance<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + op = new EuclideanDistance<>(new ConstantTensor<>(b), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + op = new EuclideanDistance<>(new ConstantTensor<>(c), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + } + + @Test + public void testDistancesInMixed() { + var a = Tensor.from("tensor(c{},x[3]):{foo:[1.0, 2.0, 3.0],bar:[0.0, 0.0, 0.0]}"); + var b = Tensor.from("tensor(c{},x[3]):{foo:[4.0, 2.0, 7.0],bar:[12.0, 0.0, 5.0]}"); + var op = new EuclideanDistance<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + var expect = Tensor.from("tensor(c{}):{foo:5.0,bar:13.0}"); + assertEquals(expect, result); + } + + @Test + public void testExpansion() { + var tType = TensorType.fromSpec("tensor(vecdim[128])"); + var a = new VariableTensor<>("left", tType); + var b = new VariableTensor<>("right", tType); + var op = new EuclideanDistance<>(a, b, "vecdim"); + assertEquals("map(reduce(map(join(left, right, f(a,b)(a - b)), f(a)(a * a)), sum, vecdim), f(a)(sqrt(a)))", + op.toPrimitive().toString()); + } + +} diff --git a/vespalib/src/tests/datastore/array_store/array_store_test.cpp b/vespalib/src/tests/datastore/array_store/array_store_test.cpp index a259fcaa4dc..6e433e48d88 100644 --- a/vespalib/src/tests/datastore/array_store/array_store_test.cpp +++ b/vespalib/src/tests/datastore/array_store/array_store_test.cpp @@ -32,7 +32,7 @@ constexpr float ALLOC_GROW_FACTOR = 0.2; template <typename ElemT> class MyArrayStoreSimpleTypeMapper : public ArrayStoreSimpleTypeMapper<ElemT> { public: - MyArrayStoreSimpleTypeMapper(uint32_t, double) + MyArrayStoreSimpleTypeMapper(uint32_t, double, size_t) : ArrayStoreSimpleTypeMapper<ElemT>() { } @@ -62,7 +62,7 @@ struct ArrayStoreTest : public TestT bool add_using_allocate; double type_mapper_grow_factor; ArrayStoreTest(uint32_t max_type_id = 3, bool enable_free_lists = true, bool add_using_allocate_in = false, double type_mapper_grow_factor_in = 2.0) - : type_mapper(max_type_id, type_mapper_grow_factor_in), + : type_mapper(max_type_id, type_mapper_grow_factor_in, ArrayStoreConfig::default_max_buffer_size), store(ArrayStoreConfig(max_type_id, ArrayStoreConfig::AllocSpec(16, RefT::offsetSize(), 8_Ki, ALLOC_GROW_FACTOR)).enable_free_lists(enable_free_lists), @@ -74,7 +74,7 @@ struct ArrayStoreTest : public TestT type_mapper_grow_factor(type_mapper_grow_factor_in) {} explicit ArrayStoreTest(const ArrayStoreConfig &storeCfg) - : type_mapper(storeCfg.max_type_id(), 2.0), + : type_mapper(storeCfg.max_type_id(), 2.0, ArrayStoreConfig::default_max_buffer_size), store(storeCfg, std::make_unique<MemoryAllocatorObserver>(stats), TypeMapperType(type_mapper)), refStore(), generation(1), @@ -280,7 +280,7 @@ TYPED_TEST(NumberStoreTest, control_static_sizes) { EXPECT_EQ(202140u, usage.allocatedBytes()); EXPECT_EQ(197680u, usage.usedBytes()); } else { - EXPECT_EQ(202388u, usage.allocatedBytes()); + EXPECT_EQ(202328u, usage.allocatedBytes()); EXPECT_EQ(197568u, usage.usedBytes()); } } @@ -564,10 +564,10 @@ TYPED_TEST(NumberStoreTest, address_space_usage_is_ratio_between_used_arrays_and * allocated elements = 256 / sizeof(int) = 64. * limit = 64 / 3 = 21. * - * For dynamic buffer 3, we have 16 * 5 * sizeof(int) => 320 -> 512 - * limit = 512 / (5 * 4) = 25 + * For dynamic buffer 3, we have 16 * 5 * sizeof(int) => 320 -> 512 - 64 + * limit = (512 -64) / (5 * 4) = 22 */ - size_t type_id_3_entries = this->simple_buffers() ? 21 : 25; + size_t type_id_3_entries = this->simple_buffers() ? 21 : 22; size_t expLimit = fourgig - 4 * TestFixture::EntryRefType::offsetSize() + 3 * 16 + type_id_3_entries; EXPECT_EQ(static_cast<double>(2)/ expLimit, this->store.addressSpaceUsage().usage()); EXPECT_EQ(expLimit, this->store.addressSpaceUsage().limit()); @@ -578,6 +578,7 @@ struct ByteStoreTest : public ArrayStoreTest<testing::Test, uint8_t, EntryRefT<1 optimizedConfigForHugePage(1023, vespalib::alloc::MemoryAllocator::HUGEPAGE_SIZE, vespalib::alloc::MemoryAllocator::PAGE_SIZE, + ArrayStoreConfig::default_max_buffer_size, 8_Ki, ALLOC_GROW_FACTOR)) {} }; diff --git a/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp b/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp index 71c1341ae74..3bcc130052d 100644 --- a/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp +++ b/vespalib/src/tests/datastore/array_store_config/array_store_config_test.cpp @@ -22,15 +22,20 @@ struct Fixture Fixture(uint32_t max_type_id, size_t hugePageSize, size_t smallPageSize, + size_t max_buffer_size, size_t min_num_entries_for_new_buffer) : cfg(ArrayStoreConfig::optimizeForHugePage(max_type_id, [](size_t type_id) noexcept { return type_id * sizeof(int); }, hugePageSize, smallPageSize, EntryRefType::offsetSize(), + max_buffer_size, min_num_entries_for_new_buffer, ALLOC_GROW_FACTOR)) { } void assertSpec(uint32_t type_id, uint32_t num_entries_for_new_buffer) { - assertSpec(type_id, AllocSpec(0, EntryRefType::offsetSize(), + assertSpec(type_id, EntryRefType::offsetSize(), num_entries_for_new_buffer); + } + void assertSpec(uint32_t type_id, uint32_t max_entries, uint32_t num_entries_for_new_buffer) { + assertSpec(type_id, AllocSpec(0, max_entries, num_entries_for_new_buffer, ALLOC_GROW_FACTOR)); } void assertSpec(uint32_t type_id, const AllocSpec &expSpec) { @@ -50,9 +55,6 @@ makeSpec(size_t min_entries_in_buffer, return AllocSpec(min_entries_in_buffer, max_entries_in_buffer, num_entries_for_new_buffer, ALLOC_GROW_FACTOR); } -constexpr size_t KB = 1024; -constexpr size_t MB = KB * KB; - TEST_F("require that default allocation spec is given for all array sizes", Fixture(3, makeSpec(4, 32, 8))) { EXPECT_EQUAL(3u, f.cfg.max_type_id()); @@ -62,26 +64,54 @@ TEST_F("require that default allocation spec is given for all array sizes", Fixt TEST_DO(f.assertSpec(3, makeSpec(4, 32, 8))); } -TEST_F("require that we can generate config optimized for a given huge page", Fixture(1024, - 2 * MB, - 4 * KB, - 8 * KB)) +struct BigBuffersFixture : public Fixture { + BigBuffersFixture() : Fixture(1023, 2_Mi, 4_Ki, 1024_Gi, 8_Ki) { } +}; + +TEST_F("require that we can generate config optimized for a given huge page without capped buffer sizes", BigBuffersFixture()) +{ + EXPECT_EQUAL(1023u, f.cfg.max_type_id()); + TEST_DO(f.assertSpec(0, 8_Ki)); // large arrays + TEST_DO(f.assertSpec(1, 256_Ki)); + TEST_DO(f.assertSpec(2, 256_Ki)); + TEST_DO(f.assertSpec(3, 168_Ki)); + TEST_DO(f.assertSpec(4, 128_Ki)); + TEST_DO(f.assertSpec(5, 100_Ki)); + TEST_DO(f.assertSpec(6, 84_Ki)); + + TEST_DO(f.assertSpec(32, 16_Ki)); + TEST_DO(f.assertSpec(33, 12_Ki)); + TEST_DO(f.assertSpec(42, 12_Ki)); + TEST_DO(f.assertSpec(43, 8_Ki)); + TEST_DO(f.assertSpec(1022, 8_Ki)); + TEST_DO(f.assertSpec(1023, 8_Ki)); +} + +struct CappedBuffersFixture : public Fixture { + CappedBuffersFixture() : Fixture(1023, 2_Mi, 4_Ki, 256_Mi, 8_Ki) { } + size_t max_entries(size_t array_size) { + auto entry_size = array_size * sizeof(int); + return (256_Mi + entry_size - 1) / entry_size; + } +}; + +TEST_F("require that we can generate config optimized for a given huge page with capped buffer sizes", CappedBuffersFixture()) { - EXPECT_EQUAL(1_Ki, f.cfg.max_type_id()); - TEST_DO(f.assertSpec(0, 8 * KB)); // large arrays - TEST_DO(f.assertSpec(1, 256 * KB)); - TEST_DO(f.assertSpec(2, 256 * KB)); - TEST_DO(f.assertSpec(3, 168 * KB)); - TEST_DO(f.assertSpec(4, 128 * KB)); - TEST_DO(f.assertSpec(5, 100 * KB)); - TEST_DO(f.assertSpec(6, 84 * KB)); + EXPECT_EQUAL(1023u, f.cfg.max_type_id()); + TEST_DO(f.assertSpec(0, f.max_entries(1023), 8_Ki)); // large arrays + TEST_DO(f.assertSpec(1, 256_Ki)); + TEST_DO(f.assertSpec(2, 256_Ki)); + TEST_DO(f.assertSpec(3, 168_Ki)); + TEST_DO(f.assertSpec(4, 128_Ki)); + TEST_DO(f.assertSpec(5, 100_Ki)); + TEST_DO(f.assertSpec(6, 84_Ki)); - TEST_DO(f.assertSpec(32, 16 * KB)); - TEST_DO(f.assertSpec(33, 12 * KB)); - TEST_DO(f.assertSpec(42, 12 * KB)); - TEST_DO(f.assertSpec(43, 8 * KB)); - TEST_DO(f.assertSpec(1022, 8 * KB)); - TEST_DO(f.assertSpec(1023, 8 * KB)); + TEST_DO(f.assertSpec(32, 16_Ki)); + TEST_DO(f.assertSpec(33, 12_Ki)); + TEST_DO(f.assertSpec(42, 12_Ki)); + TEST_DO(f.assertSpec(43, 8_Ki)); + TEST_DO(f.assertSpec(1022, f.max_entries(1022), 8_Ki)); + TEST_DO(f.assertSpec(1023, f.max_entries(1023), 8_Ki)); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp b/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp index 7ead0b97269..f53f5a8ff22 100644 --- a/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp +++ b/vespalib/src/tests/datastore/array_store_dynamic_type_mapper/array_store_dynamic_type_mapper_test.cpp @@ -2,10 +2,18 @@ #include <vespa/vespalib/datastore/array_store_dynamic_type_mapper.h> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/size_literals.h> +#include <limits> using vespalib::datastore::ArrayStoreDynamicTypeMapper; +namespace { + constexpr double default_grow_factor = 1.03; +constexpr size_t default_max_buffer_size = 256_Mi; +constexpr size_t small_max_buffer_size = 256_Ki; +constexpr size_t max_max_buffer_size = std::numeric_limits<uint32_t>::max(); +} template <typename ElemT> class TestBase : public testing::Test @@ -19,13 +27,13 @@ protected: std::vector<size_t> get_large_array_sizes(uint32_t num_large_arrays); void select_type_ids(std::vector<size_t> array_sizes); void setup_mapper(uint32_t max_buffer_type_id, double grow_factor); - static uint32_t calc_max_buffer_type_id(double grow_factor); + static uint32_t calc_max_buffer_type_id(double grow_factor, size_t max_buffer_size = default_max_buffer_size); }; template <typename ElemT> TestBase<ElemT>::TestBase() : testing::Test(), - _mapper(5, default_grow_factor) + _mapper(5, default_grow_factor, default_max_buffer_size) { } @@ -36,7 +44,7 @@ template <typename ElemT> void TestBase<ElemT>::setup_mapper(uint32_t max_buffer_type_id, double grow_factor) { - _mapper = ArrayStoreDynamicTypeMapper<ElemT>(max_buffer_type_id, grow_factor); + _mapper = ArrayStoreDynamicTypeMapper<ElemT>(max_buffer_type_id, grow_factor, default_max_buffer_size); } template <typename ElemT> @@ -108,9 +116,9 @@ TestBase<ElemT>::select_type_ids(std::vector<size_t> array_sizes) template <typename ElemT> uint32_t -TestBase<ElemT>::calc_max_buffer_type_id(double grow_factor) +TestBase<ElemT>::calc_max_buffer_type_id(double grow_factor, size_t max_buffer_size) { - ArrayStoreDynamicTypeMapper<ElemT> mapper(1000, grow_factor); + ArrayStoreDynamicTypeMapper<ElemT> mapper(1000, grow_factor, max_buffer_size); return mapper.get_max_type_id(1000); } @@ -139,11 +147,13 @@ TEST_F(ArrayStoreDynamicTypeMapperCharTest, large_arrays_grows_exponentially) TEST_F(ArrayStoreDynamicTypeMapperCharTest, avoid_entry_size_overflow) { - EXPECT_EQ(32, calc_max_buffer_type_id(2.0)); - EXPECT_EQ(410, calc_max_buffer_type_id(1.05)); - EXPECT_EQ(507, calc_max_buffer_type_id(1.04)); - EXPECT_EQ(661, calc_max_buffer_type_id(1.03)); - EXPECT_EQ(968, calc_max_buffer_type_id(1.02)); + EXPECT_EQ(29, calc_max_buffer_type_id(2.0)); + EXPECT_EQ(367, calc_max_buffer_type_id(1.05)); + EXPECT_EQ(454, calc_max_buffer_type_id(1.04)); + EXPECT_EQ(591, calc_max_buffer_type_id(1.03)); + EXPECT_EQ(357, calc_max_buffer_type_id(1.03, small_max_buffer_size)); + EXPECT_EQ(661, calc_max_buffer_type_id(1.03, max_max_buffer_size)); + EXPECT_EQ(863, calc_max_buffer_type_id(1.02)); } using ArrayStoreDynamicTypeMapperInt32Test = TestBase<int32_t>; @@ -159,11 +169,13 @@ TEST_F(ArrayStoreDynamicTypeMapperInt32Test, array_sizes_are_calculated) TEST_F(ArrayStoreDynamicTypeMapperInt32Test, avoid_entry_size_overflow) { - EXPECT_EQ(30, calc_max_buffer_type_id(2.0)); - EXPECT_EQ(395, calc_max_buffer_type_id(1.05)); - EXPECT_EQ(487, calc_max_buffer_type_id(1.04)); - EXPECT_EQ(636, calc_max_buffer_type_id(1.03)); - EXPECT_EQ(930, calc_max_buffer_type_id(1.02)); + EXPECT_EQ(27, calc_max_buffer_type_id(2.0)); + EXPECT_EQ(337, calc_max_buffer_type_id(1.05)); + EXPECT_EQ(409, calc_max_buffer_type_id(1.04)); + EXPECT_EQ(525, calc_max_buffer_type_id(1.03)); + EXPECT_EQ(291, calc_max_buffer_type_id(1.03, small_max_buffer_size)); + EXPECT_EQ(596, calc_max_buffer_type_id(1.03, max_max_buffer_size)); + EXPECT_EQ(744, calc_max_buffer_type_id(1.02)); } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/datastore/dynamic_array_buffer_type/dynamic_array_buffer_type_test.cpp b/vespalib/src/tests/datastore/dynamic_array_buffer_type/dynamic_array_buffer_type_test.cpp index a703d9b18eb..9279aff46b9 100644 --- a/vespalib/src/tests/datastore/dynamic_array_buffer_type/dynamic_array_buffer_type_test.cpp +++ b/vespalib/src/tests/datastore/dynamic_array_buffer_type/dynamic_array_buffer_type_test.cpp @@ -5,6 +5,7 @@ #include <ostream> using vespalib::datastore::ArrayStoreConfig; +using vespalib::datastore::AtomicEntryRef; using vespalib::datastore::BufferTypeBase; using vespalib::datastore::DynamicArrayBufferType; using vespalib::datastore::EntryCount; @@ -120,20 +121,23 @@ protected: BufferType _buffer_type; size_t _entry_size; + size_t _buffer_underflow_size; size_t _buf_size; - std::unique_ptr<char[]> _buf; + std::unique_ptr<char[]> _buf_alloc; + char* _buf; }; DynamicArrayBufferTypeTest::DynamicArrayBufferTypeTest() : testing::Test(), _buffer_type(3, ArrayStoreConfig::AllocSpec(0, 10, 0, 0.2), {}), _entry_size(_buffer_type.entry_size()), + _buffer_underflow_size(_buffer_type.buffer_underflow_size()), _buf_size(2 * _entry_size), - _buf(std::make_unique<char[]>(_buf_size)) + _buf_alloc(std::make_unique<char[]>(_buf_size + _buffer_underflow_size)), + _buf(_buf_alloc.get() + _buffer_underflow_size) { // Call initialize_reserved_entries to force construction of empty element - _buffer_type.initialize_reserved_entries(_buf.get(), 1); - memset(_buf.get(), 55, _buf_size); + _buffer_type.initialize_reserved_entries(_buf, 1); // Reset counts after empty element has been constructed counts = Counts(); } @@ -178,7 +182,7 @@ DynamicArrayBufferTypeTest::get_max_vector(const void* buffer, uint32_t offset) void DynamicArrayBufferTypeTest::write_entry1() { - auto e1 = BufferType::get_entry(_buf.get(), 1, _entry_size); + auto e1 = BufferType::get_entry(_buf, 1, _entry_size); BufferType::set_dynamic_array_size(e1, 2); new (static_cast<void *>(e1)) WrapInt32(42); new (static_cast<void *>(e1 + 1)) WrapInt32(47); @@ -200,50 +204,58 @@ TEST_F(DynamicArrayBufferTypeTest, entry_size_is_calculated) EXPECT_EQ(16, get_entry_size<int64_t>(1)); EXPECT_EQ(24, get_entry_size<int64_t>(2)); EXPECT_EQ(20, get_entry_size<WrapInt32>(4)); + + EXPECT_EQ(1028, get_entry_size<WrapInt32>(256)); + EXPECT_EQ(1028, get_entry_size<AtomicEntryRef>(256)); + EXPECT_EQ(1088, get_entry_size<int32_t>(256)); + EXPECT_EQ(1088, get_entry_size<int64_t>(128)); + EXPECT_EQ(1088, get_entry_size<float>(256)); + EXPECT_EQ(1088, get_entry_size<double>(128)); } TEST_F(DynamicArrayBufferTypeTest, initialize_reserved_entries) { - _buffer_type.initialize_reserved_entries(_buf.get(), 2); - EXPECT_EQ((std::vector<int>{}), get_vector(_buf.get(), 0)); - EXPECT_EQ((std::vector<int>{}), get_vector(_buf.get(), 1)); - EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(_buf.get(), 0)); - EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(_buf.get(), 1)); + _buffer_type.initialize_reserved_entries(_buf, 2); + EXPECT_EQ((std::vector<int>{}), get_vector(_buf, 0)); + EXPECT_EQ((std::vector<int>{}), get_vector(_buf, 1)); + EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(_buf, 0)); + EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(_buf, 1)); EXPECT_EQ(Counts(0, 0, 6, 0, 0), counts); } TEST_F(DynamicArrayBufferTypeTest, fallback_copy) { - _buffer_type.initialize_reserved_entries(_buf.get(), 1); + _buffer_type.initialize_reserved_entries(_buf, 1); write_entry1(); EXPECT_EQ(Counts(0, 3, 3, 0, 0), counts); - auto buf2 = std::make_unique<char[]>(_buf_size); - _buffer_type.fallback_copy(buf2.get(), _buf.get(), 2); - EXPECT_EQ((std::vector<int>{}), get_vector(buf2.get(), 0)); - EXPECT_EQ((std::vector<int>{42, 47}), get_vector(buf2.get(), 1)); - EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(buf2.get(), 0)); - EXPECT_EQ((std::vector<int>{42, 47, 49}), get_max_vector(buf2.get(), 1)); + auto buf2_alloc = std::make_unique<char[]>(_buf_size + _buffer_underflow_size); + char* buf2 = buf2_alloc.get() + _buffer_underflow_size; + _buffer_type.fallback_copy(buf2, _buf, 2); + EXPECT_EQ((std::vector<int>{}), get_vector(buf2, 0)); + EXPECT_EQ((std::vector<int>{42, 47}), get_vector(buf2, 1)); + EXPECT_EQ((std::vector<int>{0, 0, 0}), get_max_vector(buf2, 0)); + EXPECT_EQ((std::vector<int>{42, 47, 49}), get_max_vector(buf2, 1)); EXPECT_EQ(Counts(0, 3, 9, 0, 0), counts); } TEST_F(DynamicArrayBufferTypeTest, destroy_entries) { - _buffer_type.initialize_reserved_entries(_buf.get(), 2); + _buffer_type.initialize_reserved_entries(_buf, 2); write_entry1(); - _buffer_type.destroy_entries(_buf.get(), 2); + _buffer_type.destroy_entries(_buf, 2); EXPECT_EQ(Counts(0, 3, 6, 6, 0), counts); } TEST_F(DynamicArrayBufferTypeTest, clean_hold) { - _buffer_type.initialize_reserved_entries(_buf.get(), 1); + _buffer_type.initialize_reserved_entries(_buf, 1); write_entry1(); MyCleanContext clean_context; - _buffer_type.clean_hold(_buf.get(), 1, 1, clean_context); - EXPECT_EQ((std::vector<int>{0, 0}), get_vector(_buf.get(), 1)); - EXPECT_EQ((std::vector<int>{0, 0, 49}), get_max_vector(_buf.get(), 1)); + _buffer_type.clean_hold(_buf, 1, 1, clean_context); + EXPECT_EQ((std::vector<int>{0, 0}), get_vector(_buf, 1)); + EXPECT_EQ((std::vector<int>{0, 0, 49}), get_max_vector(_buf, 1)); EXPECT_EQ(Counts(0, 3, 3, 0, 2), counts); - _buffer_type.clean_hold(_buf.get(), 0, 2, clean_context); + _buffer_type.clean_hold(_buf, 0, 2, clean_context); EXPECT_EQ(Counts(0, 3, 3, 0, 4), counts); } diff --git a/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp b/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp index c6f0581bceb..afd18a13d2e 100644 --- a/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp +++ b/vespalib/src/tests/rw_spin_lock/rw_spin_lock_test.cpp @@ -11,6 +11,7 @@ #include <ranges> #include <random> #include <array> +#include <algorithm> using namespace vespalib; using namespace vespalib::test; @@ -23,6 +24,71 @@ size_t state_loop = 1; //----------------------------------------------------------------------------- +/** + * Estimates the 80th percentile by throwing away the 2 best samples + * in each set of 10 samples, using the best remaining sample as a + * representative for the set. Representatives are hierarchically + * matched against representatives from other sample sets. Result + * extraction is simplified in that it does not try to estimate the + * actual 80th percentile, but rather tries to drop the best samples + * if possible. + * + * The goal is to have a more robust way of combining repeated + * micro-benchmark samples than simply using minimum time. With simple + * single-threaded CPU-bound tasks, minimum time is a good measure of + * how expensive something is, but when we start benchmarking + * operations that may conflict with themselves, we do not want to + * account for being super lucky. However, we still want to account + * for the benchmark conditions being as good as possible. + **/ +struct Est80P { + struct Level { + int cnt; + std::array<double,3> data; + Level(double value) noexcept + : cnt(1), data{value, 0.0, 0.0} {} + bool empty() const noexcept { return (cnt == 0); } + bool full() const noexcept { return (cnt == 10); } + void add(double value) noexcept { + assert(!full()); + if (cnt < 3 || data[2] > value) { + size_t i = std::min(cnt, 2); + while (i > 0 && data[i - 1] > value) { + data[i] = data[i - 1]; + --i; + } + data[i] = value; + } + ++cnt; + } + double get() const noexcept { + assert(!empty()); + return data[std::min(2, cnt - 1)]; + } + void clear() noexcept { + cnt = 0; + } + }; + std::vector<Level> levels; + void add_sample(double value) { + for (auto &level: levels) { + level.add(value); + if (!level.full()) [[likely]] { + return; + } + value = level.get(); + level.clear(); + } + levels.emplace_back(value); + } + double get_result() { + assert(!levels.empty()); + return levels.back().get(); + } +}; + +//----------------------------------------------------------------------------- + struct DummyLock { constexpr DummyLock() noexcept {} // BasicLockable @@ -158,47 +224,31 @@ double measure_ns(auto &work) { struct BenchmarkResult { double cost_ns; - double range_ns; - size_t threads; - BenchmarkResult(size_t num_threads) - : cost_ns(std::numeric_limits<double>::max()), range_ns(0.0), threads(num_threads) {} + BenchmarkResult(double cost_ns_in) : cost_ns(cost_ns_in) {} void report(vespalib::string desc) { - if (threads == 1) { - fprintf(stderr, "%s: cost_ns: %g\n", - desc.c_str(), cost_ns); - } else { - fprintf(stderr, "%s: cost_ns: %g, range_ns: %g (%zu threads)\n", - desc.c_str(), cost_ns, range_ns, threads); - } + fprintf(stderr, "%s: cost_ns: %g\n", desc.c_str(), cost_ns); } void report(vespalib::string name, vespalib::string desc) { report(name + "(" + desc + ")"); } }; -struct Meets { - vespalib::test::ThreadMeets::Avg avg; - vespalib::test::ThreadMeets::Range<double> range; - Meets(size_t num_threads) : avg(num_threads), range(num_threads) {} -}; - BenchmarkResult benchmark_ns(auto &&work, size_t num_threads = 1) { - Meets meets(num_threads); + Est80P collector; + vespalib::test::ThreadMeets::Avg avg(num_threads); auto entry = [&](Nexus &ctx) { Timer timer; BenchmarkResult result(ctx.num_threads()); for (bool once_more = true; ctx.vote(once_more); once_more = (timer.elapsed() < budget)) { - auto my_ns = measure_ns(work); - auto cost_ns = meets.avg(my_ns); - auto range_ns = meets.range(my_ns); - if (cost_ns < result.cost_ns) { - result.cost_ns = cost_ns; - result.range_ns = range_ns; + auto cost_ns = avg(measure_ns(work)); + if (ctx.is_main()) { + collector.add_sample(cost_ns); } } - return result; }; - return Nexus::run(num_threads, entry); + Nexus::run(num_threads, entry); + auto result = collector.get_result(); + return {result}; } //----------------------------------------------------------------------------- @@ -224,7 +274,7 @@ void estimate_cost() { //----------------------------------------------------------------------------- template <typename T> -void thread_safety_loop(Nexus &ctx, T &lock, MyState &state, Meets &meets, int read_bp) { +void thread_safety_loop(Nexus &ctx, T &lock, MyState &state, auto &max, int read_bp) { Rnd rnd(ctx.thread_id()); size_t write_cnt = 0; size_t bad_reads = 0; @@ -247,16 +297,11 @@ void thread_safety_loop(Nexus &ctx, T &lock, MyState &state, Meets &meets, int r } } } - auto t1 = steady_clock::now(); - ctx.barrier(); - auto t2 = steady_clock::now(); - auto my_ms = count_ns(t1 - t0) / 1'000'000.0; - auto total_ms = count_ns(t2 - t0) / 1'000'000.0; - auto cost_ms = meets.avg(my_ms); - auto range_ms = meets.range(my_ms); - if (ctx.thread_id() == 0) { - fprintf(stderr, "---> %s with %2zu threads (%5d bp r): avg: %10.2f ms, range: %10.2f ms, max: %10.2f ms\n", - getClassName(lock).c_str(), ctx.num_threads(), read_bp, cost_ms, range_ms, total_ms); + auto my_ms = count_ns(steady_clock::now() - t0) / 1'000'000.0; + auto actual_ms = max(my_ms); + if (ctx.is_main()) { + fprintf(stderr, "---> %s with %2zu threads (%5d bp r): time: %10.2f ms\n", + getClassName(lock).c_str(), ctx.num_threads(), read_bp, actual_ms); } state.commit_inconsistent_reads(bad_reads); state.commit_expected_writes(write_cnt); @@ -290,9 +335,9 @@ void benchmark_lock() { for (size_t bp: {10000, 9999, 5000, 0}) { for (size_t num_threads: {8, 4, 2, 1}) { if (bench || (bp == 9999 && num_threads == 8)) { - Meets meets(num_threads); + vespalib::test::ThreadMeets::Max<double> max(num_threads); Nexus::run(num_threads, [&](Nexus &ctx) { - thread_safety_loop(ctx, *lock, *state, meets, bp); + thread_safety_loop(ctx, *lock, *state, max, bp); }); } } diff --git a/vespalib/src/vespa/vespalib/datastore/array_store.h b/vespalib/src/vespa/vespalib/datastore/array_store.h index f5c30c90c5b..7ee63be3848 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store.h +++ b/vespalib/src/vespa/vespalib/datastore/array_store.h @@ -126,7 +126,7 @@ public: return get_dynamic_array<typename TypeMapper::DynamicBufferType>(bufferAndMeta.get_buffer_acquire(), internalRef.offset(), bufferAndMeta.get_entry_size()); } } - return getSmallArray(internalRef, bufferAndMeta.getArraySize()); + return getSmallArray(internalRef, bufferAndMeta.get_array_size()); } else { return getLargeArray(internalRef); } @@ -196,6 +196,7 @@ public: static ArrayStoreConfig optimizedConfigForHugePage(uint32_t max_type_id, size_t hugePageSize, size_t smallPageSize, + size_t max_buffer_size, size_t min_num_entries_for_new_buffer, float allocGrowFactor); @@ -203,6 +204,7 @@ public: const TypeMapper& mapper, size_t hugePageSize, size_t smallPageSize, + size_t max_buffer_size, size_t min_num_entries_for_new_buffer, float allocGrowFactor); }; diff --git a/vespalib/src/vespa/vespalib/datastore/array_store.hpp b/vespalib/src/vespa/vespalib/datastore/array_store.hpp index 211176b8ad0..bfd4ff0430a 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store.hpp +++ b/vespalib/src/vespa/vespalib/datastore/array_store.hpp @@ -252,6 +252,7 @@ ArrayStoreConfig ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_type_id, size_t hugePageSize, size_t smallPageSize, + size_t max_buffer_size, size_t min_num_entries_for_new_buffer, float allocGrowFactor) { @@ -260,6 +261,7 @@ ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_ty mapper, hugePageSize, smallPageSize, + max_buffer_size, min_num_entries_for_new_buffer, allocGrowFactor); } @@ -267,17 +269,19 @@ ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_ty template <typename ElemT, typename RefT, typename TypeMapperT> ArrayStoreConfig ArrayStore<ElemT, RefT, TypeMapperT>::optimizedConfigForHugePage(uint32_t max_type_id, - const TypeMapper& mapper, - size_t hugePageSize, - size_t smallPageSize, - size_t min_num_entries_for_new_buffer, - float allocGrowFactor) + const TypeMapper& mapper, + size_t hugePageSize, + size_t smallPageSize, + size_t max_buffer_size, + size_t min_num_entries_for_new_buffer, + float allocGrowFactor) { return ArrayStoreConfig::optimizeForHugePage(mapper.get_max_type_id(max_type_id), [&](uint32_t type_id) noexcept { return mapper.get_entry_size(type_id); }, hugePageSize, smallPageSize, RefT::offsetSize(), + max_buffer_size, min_num_entries_for_new_buffer, allocGrowFactor); } diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp b/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp index c7f0b69a85e..37f6fab96dc 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp +++ b/vespalib/src/vespa/vespalib/datastore/array_store_config.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "array_store_config.h" +#include <algorithm> #include <cassert> namespace vespalib::datastore { @@ -42,6 +43,13 @@ alignToSmallPageSize(size_t value, size_t minLimit, size_t smallPageSize) return ((value - minLimit) / smallPageSize) * smallPageSize + minLimit; } +size_t +cap_max_entries(size_t max_entries, size_t max_buffer_size, size_t entry_size) +{ + size_t dynamic_max_entries = (max_buffer_size + (entry_size - 1)) / entry_size; + return std::min(max_entries, dynamic_max_entries); +} + } ArrayStoreConfig @@ -50,17 +58,21 @@ ArrayStoreConfig::optimizeForHugePage(uint32_t max_type_id, size_t hugePageSize, size_t smallPageSize, size_t maxEntryRefOffset, + size_t max_buffer_size, size_t min_num_entries_for_new_buffer, float allocGrowFactor) { AllocSpecVector allocSpecs; - allocSpecs.emplace_back(0, maxEntryRefOffset, min_num_entries_for_new_buffer, allocGrowFactor); // large array spec; + auto entry_size = type_id_to_entry_size(max_type_id); + auto capped_max_entries = cap_max_entries(maxEntryRefOffset, max_buffer_size, entry_size); + allocSpecs.emplace_back(0, capped_max_entries, min_num_entries_for_new_buffer, allocGrowFactor); // large array spec; for (uint32_t type_id = 1; type_id <= max_type_id; ++type_id) { - size_t entry_size = type_id_to_entry_size(type_id); + entry_size = type_id_to_entry_size(type_id); + capped_max_entries = cap_max_entries(maxEntryRefOffset, max_buffer_size, entry_size); size_t num_entries_for_new_buffer = hugePageSize / entry_size; - num_entries_for_new_buffer = capToLimits(num_entries_for_new_buffer, min_num_entries_for_new_buffer, maxEntryRefOffset); + num_entries_for_new_buffer = capToLimits(num_entries_for_new_buffer, min_num_entries_for_new_buffer, capped_max_entries); num_entries_for_new_buffer = alignToSmallPageSize(num_entries_for_new_buffer, min_num_entries_for_new_buffer, smallPageSize); - allocSpecs.emplace_back(0, maxEntryRefOffset, num_entries_for_new_buffer, allocGrowFactor); + allocSpecs.emplace_back(0, capped_max_entries, num_entries_for_new_buffer, allocGrowFactor); } return ArrayStoreConfig(allocSpecs); } diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_config.h b/vespalib/src/vespa/vespalib/datastore/array_store_config.h index 3b62609d0f1..3967996c64d 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store_config.h +++ b/vespalib/src/vespa/vespalib/datastore/array_store_config.h @@ -2,6 +2,7 @@ #pragma once +#include <vespa/vespalib/util/size_literals.h> #include <cstddef> #include <cstdint> #include <functional> @@ -39,6 +40,8 @@ public: using AllocSpecVector = std::vector<AllocSpec>; + static constexpr size_t default_max_buffer_size = 256_Mi; + private: AllocSpecVector _allocSpecs; bool _enable_free_lists; @@ -77,6 +80,7 @@ public: size_t hugePageSize, size_t smallPageSize, size_t maxEntryRefOffset, + size_t max_buffer_size, size_t min_num_entries_for_new_buffer, float allocGrowFactor); }; diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.h b/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.h index 73c998e82a5..6797b2a79b4 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.h +++ b/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.h @@ -36,9 +36,9 @@ public: using LargeBufferType = vespalib::datastore::LargeArrayBufferType<ElemT>; ArrayStoreDynamicTypeMapper(); - ArrayStoreDynamicTypeMapper(uint32_t max_buffer_type_id, double grow_factor); + ArrayStoreDynamicTypeMapper(uint32_t max_buffer_type_id, double grow_factor, size_t max_buffer_size); ~ArrayStoreDynamicTypeMapper(); - void setup_array_sizes(uint32_t max_buffer_type_id, double grow_factor); + void setup_array_sizes(uint32_t max_buffer_type_id, double grow_factor, size_t max_buffer_size); size_t get_entry_size(uint32_t type_id) const; bool is_dynamic_buffer(uint32_t type_id) const noexcept { return type_id > _max_static_array_buffer_type_id; } uint32_t count_dynamic_buffer_types(uint32_t max_type_id) const noexcept { return (max_type_id > _max_static_array_buffer_type_id) ? (max_type_id - _max_static_array_buffer_type_id) : 0u; } diff --git a/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.hpp b/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.hpp index e74cd92e6aa..48de5cf5332 100644 --- a/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.hpp +++ b/vespalib/src/vespa/vespalib/datastore/array_store_dynamic_type_mapper.hpp @@ -18,16 +18,16 @@ ArrayStoreDynamicTypeMapper<ElemT>::ArrayStoreDynamicTypeMapper() } template <typename ElemT> -ArrayStoreDynamicTypeMapper<ElemT>::ArrayStoreDynamicTypeMapper(uint32_t max_buffer_type_id, double grow_factor) +ArrayStoreDynamicTypeMapper<ElemT>::ArrayStoreDynamicTypeMapper(uint32_t max_buffer_type_id, double grow_factor, size_t max_buffer_size) : ArrayStoreTypeMapper(), _max_static_array_buffer_type_id(0) { - setup_array_sizes(max_buffer_type_id, grow_factor); + setup_array_sizes(max_buffer_type_id, grow_factor, max_buffer_size); } template <typename ElemT> void -ArrayStoreDynamicTypeMapper<ElemT>::setup_array_sizes(uint32_t max_buffer_type_id, double grow_factor) +ArrayStoreDynamicTypeMapper<ElemT>::setup_array_sizes(uint32_t max_buffer_type_id, double grow_factor, size_t max_buffer_size) { _array_sizes.clear(); _array_sizes.reserve(max_buffer_type_id + 1); @@ -49,7 +49,8 @@ ArrayStoreDynamicTypeMapper<ElemT>::setup_array_sizes(uint32_t max_buffer_type_i entry_size = array_size * sizeof(ElemT); } } - if (entry_size > std::numeric_limits<uint32_t>::max()) { + if (entry_size > std::numeric_limits<uint32_t>::max() || + entry_size >= 2 * max_buffer_size) { break; } _array_sizes.emplace_back(array_size); diff --git a/vespalib/src/vespa/vespalib/datastore/buffer_type.cpp b/vespalib/src/vespa/vespalib/datastore/buffer_type.cpp index bb34c5d0f9d..1b087a01c58 100644 --- a/vespalib/src/vespa/vespalib/datastore/buffer_type.cpp +++ b/vespalib/src/vespa/vespalib/datastore/buffer_type.cpp @@ -27,6 +27,7 @@ BufferTypeBase::CleanContext::extraBytesCleaned(size_t value) } BufferTypeBase::BufferTypeBase(uint32_t entry_size_in, + uint32_t buffer_underflow_size_in, uint32_t arraySize, uint32_t min_entries, uint32_t max_entries, @@ -34,6 +35,7 @@ BufferTypeBase::BufferTypeBase(uint32_t entry_size_in, float allocGrowFactor) noexcept : _entry_size(entry_size_in), _arraySize(arraySize), + _buffer_underflow_size(buffer_underflow_size_in), _min_entries(std::min(min_entries, max_entries)), _max_entries(max_entries), _num_entries_for_new_buffer(std::min(num_entries_for_new_buffer, max_entries)), @@ -46,10 +48,11 @@ BufferTypeBase::BufferTypeBase(uint32_t entry_size_in, } BufferTypeBase::BufferTypeBase(uint32_t entry_size_in, + uint32_t buffer_underflow_size_in, uint32_t arraySize, uint32_t min_entries, uint32_t max_entries) noexcept - : BufferTypeBase(entry_size_in, arraySize, min_entries, max_entries, 0u, DEFAULT_ALLOC_GROW_FACTOR) + : BufferTypeBase(entry_size_in, buffer_underflow_size_in, arraySize, min_entries, max_entries, 0u, DEFAULT_ALLOC_GROW_FACTOR) { } @@ -117,6 +120,12 @@ BufferTypeBase::get_memory_allocator() const return nullptr; } +bool +BufferTypeBase::is_dynamic_array_buffer_type() const noexcept +{ + return false; +} + void BufferTypeBase::clamp_max_entries(uint32_t max_entries) { diff --git a/vespalib/src/vespa/vespalib/datastore/buffer_type.h b/vespalib/src/vespa/vespalib/datastore/buffer_type.h index 3370bd47fad..7b23a238ba2 100644 --- a/vespalib/src/vespa/vespalib/datastore/buffer_type.h +++ b/vespalib/src/vespa/vespalib/datastore/buffer_type.h @@ -39,8 +39,8 @@ public: BufferTypeBase & operator=(const BufferTypeBase &rhs) = delete; BufferTypeBase(BufferTypeBase &&rhs) noexcept = default; BufferTypeBase & operator=(BufferTypeBase &&rhs) noexcept = default; - BufferTypeBase(uint32_t entry_size_in, uint32_t arraySize, uint32_t min_entries, uint32_t max_entries) noexcept; - BufferTypeBase(uint32_t entry_size_in, uint32_t arraySize, uint32_t min_entries, uint32_t max_entries, + BufferTypeBase(uint32_t entry_size_in, uint32_t buffer_underflow_size_in, uint32_t arraySize, uint32_t min_entries, uint32_t max_entries) noexcept; + BufferTypeBase(uint32_t entry_size_in, uint32_t buffer_underflow_size_in, uint32_t arraySize, uint32_t min_entries, uint32_t max_entries, uint32_t num_entries_for_new_buffer, float allocGrowFactor) noexcept; virtual ~BufferTypeBase(); virtual void destroy_entries(void *buffer, EntryCount num_entries) = 0; @@ -57,6 +57,7 @@ public: */ virtual void initialize_reserved_entries(void *buffer, EntryCount reserved_entries) = 0; size_t entry_size() const noexcept { return _entry_size; } + uint32_t buffer_underflow_size() const noexcept { return _buffer_underflow_size; } virtual void clean_hold(void *buffer, size_t offset, EntryCount num_entries, CleanContext cleanCtx) = 0; size_t getArraySize() const noexcept { return _arraySize; } virtual void on_active(uint32_t bufferId, std::atomic<EntryCount>* used_entries, std::atomic<EntryCount>* dead_entries, void* buffer); @@ -64,6 +65,7 @@ public: virtual void on_free(EntryCount used_entries); void resume_primary_buffer(uint32_t buffer_id, std::atomic<EntryCount>* used_entries, std::atomic<EntryCount>* dead_entries); virtual const alloc::MemoryAllocator* get_memory_allocator() const; + virtual bool is_dynamic_array_buffer_type() const noexcept; /** * Calculate number of entries to allocate for new buffer given how many free entries are needed. @@ -114,6 +116,16 @@ protected: uint32_t _entry_size; // Number of bytes in an allocation unit uint32_t _arraySize; // Number of elements in an allocation unit + + /* + * Buffer underflow size is the size of an area before the start + * of the logical buffer that is safe to access (part of the same + * memory alloation as the buffer itself). This allows for data + * belonging to an entry to be placed at the end of what is normally + * the last part of the previos entry (e.g. dynamic array size + * for the dynamic array buffer type). + */ + uint32_t _buffer_underflow_size; uint32_t _min_entries; // Minimum number of entries to allocate in a buffer uint32_t _max_entries; // Maximum number of entries to allocate in a buffer // Number of entries needed before allocating a new buffer instead of just resizing the first one diff --git a/vespalib/src/vespa/vespalib/datastore/buffer_type.hpp b/vespalib/src/vespa/vespalib/datastore/buffer_type.hpp index 375c832d9fb..00d642be9bc 100644 --- a/vespalib/src/vespa/vespalib/datastore/buffer_type.hpp +++ b/vespalib/src/vespa/vespalib/datastore/buffer_type.hpp @@ -8,13 +8,13 @@ namespace vespalib::datastore { template <typename ElemT, typename EmptyT> BufferType<ElemT, EmptyT>::BufferType(uint32_t arraySize, uint32_t min_entries, uint32_t max_entries) noexcept - : BufferTypeBase(arraySize * sizeof(ElemT), arraySize, min_entries, max_entries) + : BufferTypeBase(arraySize * sizeof(ElemT), 0u, arraySize, min_entries, max_entries) { } template <typename ElemT, typename EmptyT> BufferType<ElemT, EmptyT>::BufferType(uint32_t arraySize, uint32_t min_entries, uint32_t max_entries, uint32_t num_entries_for_new_buffer, float allocGrowFactor) noexcept - : BufferTypeBase(arraySize * sizeof(ElemT), arraySize, min_entries, max_entries, num_entries_for_new_buffer, allocGrowFactor) + : BufferTypeBase(arraySize * sizeof(ElemT), 0u, arraySize, min_entries, max_entries, num_entries_for_new_buffer, allocGrowFactor) { } template <typename ElemT, typename EmptyT> diff --git a/vespalib/src/vespa/vespalib/datastore/bufferstate.cpp b/vespalib/src/vespa/vespalib/datastore/bufferstate.cpp index f312596d6f7..e7832a1c4e2 100644 --- a/vespalib/src/vespa/vespalib/datastore/bufferstate.cpp +++ b/vespalib/src/vespa/vespalib/datastore/bufferstate.cpp @@ -64,13 +64,14 @@ calc_allocation(uint32_t bufferId, { size_t alloc_entries = typeHandler.calc_entries_to_alloc(bufferId, free_entries_needed, resizing); size_t entry_size = typeHandler.entry_size(); - size_t allocBytes = roundUpToMatchAllocator(alloc_entries * entry_size); - size_t maxAllocBytes = typeHandler.get_max_entries() * entry_size; + auto buffer_underflow_size = typeHandler.buffer_underflow_size(); + size_t allocBytes = roundUpToMatchAllocator(alloc_entries * entry_size + buffer_underflow_size); + size_t maxAllocBytes = typeHandler.get_max_entries() * entry_size + buffer_underflow_size; if (allocBytes > maxAllocBytes) { // Ensure that allocated bytes does not exceed the maximum handled by this type. allocBytes = maxAllocBytes; } - size_t adjusted_alloc_entries = allocBytes / entry_size; + size_t adjusted_alloc_entries = (allocBytes - buffer_underflow_size) / entry_size; return AllocResult(adjusted_alloc_entries, allocBytes); } @@ -102,7 +103,8 @@ BufferState::on_active(uint32_t bufferId, uint32_t typeId, _buffer = (allocator != nullptr) ? Alloc::alloc_with_allocator(allocator) : Alloc::alloc(0, MemoryAllocator::HUGEPAGE_SIZE); _buffer.create(alloc.bytes).swap(_buffer); assert(_buffer.get() != nullptr || alloc.entries == 0u); - buffer.store(_buffer.get(), std::memory_order_release); + auto buffer_underflow_size = typeHandler->buffer_underflow_size(); + buffer.store(get_buffer(buffer_underflow_size), std::memory_order_release); _stats.set_alloc_entries(alloc.entries); _typeHandler.store(typeHandler, std::memory_order_release); assert(typeId <= std::numeric_limits<uint16_t>::max()); @@ -117,28 +119,30 @@ void BufferState::onHold(uint32_t buffer_id) { assert(getState() == State::ACTIVE); - assert(getTypeHandler() != nullptr); + auto type_handler = getTypeHandler(); + assert(type_handler != nullptr); _state.store(State::HOLD, std::memory_order_release); _compacting = false; assert(_stats.dead_entries() <= size()); assert(_stats.hold_entries() <= (size() - _stats.dead_entries())); _stats.set_dead_entries(0); _stats.set_hold_entries(size()); - getTypeHandler()->on_hold(buffer_id, &_stats.used_entries_ref(), &_stats.dead_entries_ref()); + type_handler->on_hold(buffer_id, &_stats.used_entries_ref(), &_stats.dead_entries_ref()); _free_list.disable(); } void BufferState::onFree(std::atomic<void*>& buffer) { - assert(buffer.load(std::memory_order_relaxed) == _buffer.get()); assert(getState() == State::HOLD); - assert(_typeHandler != nullptr); + auto type_handler = getTypeHandler(); + assert(type_handler != nullptr); + assert(buffer.load(std::memory_order_relaxed) == get_buffer(type_handler->buffer_underflow_size())); assert(_stats.dead_entries() <= size()); assert(_stats.hold_entries() == (size() - _stats.dead_entries())); - getTypeHandler()->destroy_entries(buffer, size()); + type_handler->destroy_entries(buffer, size()); Alloc::alloc().swap(_buffer); - getTypeHandler()->on_free(size()); + type_handler->on_free(size()); buffer.store(nullptr, std::memory_order_release); _stats.clear(); _state.store(State::FREE, std::memory_order_release); @@ -200,9 +204,11 @@ BufferState::free_entries(EntryRef ref, size_t num_entries, size_t ref_offset) } _stats.inc_dead_entries(num_entries); _stats.dec_hold_entries(num_entries); - getTypeHandler()->clean_hold(_buffer.get(), ref_offset, num_entries, - BufferTypeBase::CleanContext(_stats.extra_used_bytes_ref(), - _stats.extra_hold_bytes_ref())); + auto type_handler = getTypeHandler(); + auto buffer_underflow_size = type_handler->buffer_underflow_size(); + type_handler->clean_hold(get_buffer(buffer_underflow_size), ref_offset, num_entries, + BufferTypeBase::CleanContext(_stats.extra_used_bytes_ref(), + _stats.extra_hold_bytes_ref())); } void @@ -212,17 +218,19 @@ BufferState::fallback_resize(uint32_t bufferId, Alloc &holdBuffer) { assert(getState() == State::ACTIVE); - assert(_typeHandler != nullptr); + auto type_handler = getTypeHandler(); + assert(type_handler != nullptr); assert(holdBuffer.get() == nullptr); - AllocResult alloc = calc_allocation(bufferId, *_typeHandler, free_entries_needed, true); + auto buffer_underflow_size = type_handler->buffer_underflow_size(); + AllocResult alloc = calc_allocation(bufferId, *type_handler, free_entries_needed, true); assert(alloc.entries >= size() + free_entries_needed); assert(alloc.entries > capacity()); Alloc newBuffer = _buffer.create(alloc.bytes); - getTypeHandler()->fallback_copy(newBuffer.get(), buffer.load(std::memory_order_relaxed), size()); + type_handler->fallback_copy(get_buffer(newBuffer, buffer_underflow_size), buffer.load(std::memory_order_relaxed), size()); holdBuffer.swap(_buffer); std::atomic_thread_fence(std::memory_order_release); _buffer = std::move(newBuffer); - buffer.store(_buffer.get(), std::memory_order_release); + buffer.store(get_buffer(buffer_underflow_size), std::memory_order_release); _stats.set_alloc_entries(alloc.entries); } diff --git a/vespalib/src/vespa/vespalib/datastore/bufferstate.h b/vespalib/src/vespa/vespalib/datastore/bufferstate.h index 289be32e19b..01439586f5b 100644 --- a/vespalib/src/vespa/vespalib/datastore/bufferstate.h +++ b/vespalib/src/vespa/vespalib/datastore/bufferstate.h @@ -48,6 +48,8 @@ private: bool _disable_entry_hold_list : 1; bool _compacting : 1; + static void *get_buffer(Alloc& buffer, uint32_t buffer_underflow_size) noexcept { return static_cast<char *>(buffer.get()) + buffer_underflow_size; } + void *get_buffer(uint32_t buffer_underflow_size) noexcept { return get_buffer(_buffer, buffer_underflow_size); } public: /** * TODO: Check if per-buffer free lists are useful, or if @@ -137,24 +139,28 @@ public: void* get_buffer_relaxed() noexcept { return _buffer.load(std::memory_order_relaxed); } const void* get_buffer_acquire() const noexcept { return _buffer.load(std::memory_order_acquire); } uint32_t getTypeId() const { return _typeId; } - uint32_t getArraySize() const { return _arraySize; } + uint32_t get_array_size() const { return _array_size; } BufferState * get_state_relaxed() { return _state.load(std::memory_order_relaxed); } const BufferState * get_state_acquire() const { return _state.load(std::memory_order_acquire); } - uint32_t get_entry_size() const { return get_state_acquire()->getTypeHandler()->entry_size(); } + uint32_t get_entry_size() const noexcept { return _entry_size; } void setTypeId(uint32_t typeId) { _typeId = typeId; } - void setArraySize(uint32_t arraySize) { _arraySize = arraySize; } + void set_array_size(uint32_t arraySize) { _array_size = arraySize; } + void set_entry_size(uint32_t entry_size) noexcept { _entry_size = entry_size; } void set_state(BufferState * state) { _state.store(state, std::memory_order_release); } private: BufferAndMeta(void* buffer, BufferState * state, uint32_t typeId, uint32_t arraySize) : _buffer(buffer), _state(state), _typeId(typeId), - _arraySize(arraySize) + _array_size(arraySize) { } std::atomic<void*> _buffer; std::atomic<BufferState*> _state; uint32_t _typeId; - uint32_t _arraySize; + union { + uint32_t _array_size; // Valid unless buffer type is dynamic array buffer type + uint32_t _entry_size; // Valid if buffer type is dynamic array buffer type + }; }; } diff --git a/vespalib/src/vespa/vespalib/datastore/datastorebase.cpp b/vespalib/src/vespa/vespalib/datastore/datastorebase.cpp index 75ffe855a32..5c88900ae92 100644 --- a/vespalib/src/vespa/vespalib/datastore/datastorebase.cpp +++ b/vespalib/src/vespa/vespalib/datastore/datastorebase.cpp @@ -59,7 +59,8 @@ DataStoreBase::FallbackHold::FallbackHold(size_t bytesSize, BufferState::Alloc & DataStoreBase::FallbackHold::~FallbackHold() { - _typeHandler->destroy_entries(_buffer.get(), _used_entries); + auto buffer_underflow_size = _typeHandler->buffer_underflow_size(); + _typeHandler->destroy_entries(static_cast<char *>(_buffer.get()) + buffer_underflow_size, _used_entries); } class DataStoreBase::BufferHold : public GenerationHeldBase { @@ -415,7 +416,11 @@ DataStoreBase::on_active(uint32_t bufferId, uint32_t typeId, size_t entries_need assert(state->isFree()); state->on_active(bufferId, typeId, _typeHandlers[typeId], entries_needed, bufferMeta.get_atomic_buffer()); bufferMeta.setTypeId(typeId); - bufferMeta.setArraySize(state->getArraySize()); + if (_typeHandlers[typeId]->is_dynamic_array_buffer_type()) { + bufferMeta.set_entry_size(_typeHandlers[typeId]->entry_size()); + } else { + bufferMeta.set_array_size(state->getArraySize()); + } if (_freeListsEnabled && state->isActive() && !state->getCompacting()) { state->enable_free_list(_free_lists[state->getTypeId()]); } diff --git a/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.h b/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.h index fbd3c2361d1..62b311f6aee 100644 --- a/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.h +++ b/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.h @@ -7,6 +7,7 @@ #include "array_store_config.h" #include <algorithm> #include <memory> +#include <type_traits> namespace vespalib::datastore { @@ -26,12 +27,14 @@ class DynamicArrayBufferType : public BufferTypeBase { using AllocSpec = ArrayStoreConfig::AllocSpec; std::shared_ptr<alloc::MemoryAllocator> _memory_allocator; + public: using ElemType = ElemT; static constexpr size_t entry_min_align = std::max(alignof(uint32_t), alignof(ElemT)); using EntryMinAligner = Aligner<entry_min_align>; - static constexpr size_t entry_bias = EntryMinAligner::align(sizeof(uint32_t)); + static constexpr uint32_t dynamic_array_buffer_underflow_size = 64u; + static constexpr bool align_for_simd = std::disjunction_v<std::is_same<ElemT, double>,std::is_same<ElemT, float>, std::is_same<ElemT, int64_t>, std::is_same<ElemT, int32_t>>; protected: static const ElemType& empty_entry() noexcept; ElemType* get_entry(void *buffer, size_t offset) noexcept { return get_entry(buffer, offset, entry_size()); } @@ -55,10 +58,11 @@ public: const vespalib::alloc::MemoryAllocator* get_memory_allocator() const override; static size_t calc_entry_size(size_t array_size) noexcept; static size_t calc_array_size(size_t entry_size) noexcept; - static ElemType* get_entry(void* buffer, size_t offset, uint32_t entry_size) noexcept { return reinterpret_cast<ElemType*>(static_cast<char*>(buffer) + offset * entry_size + entry_bias); } - static const ElemType* get_entry(const void* buffer, size_t offset, uint32_t entry_size) noexcept { return reinterpret_cast<const ElemType*>(static_cast<const char*>(buffer) + offset * entry_size + entry_bias); } + static ElemType* get_entry(void* buffer, size_t offset, uint32_t entry_size) noexcept { return reinterpret_cast<ElemType*>(static_cast<char*>(buffer) + offset * entry_size); } + static const ElemType* get_entry(const void* buffer, size_t offset, uint32_t entry_size) noexcept { return reinterpret_cast<const ElemType*>(static_cast<const char*>(buffer) + offset * entry_size); } static uint32_t get_dynamic_array_size(const ElemType* buffer) noexcept { return *(reinterpret_cast<const uint32_t*>(buffer) - 1); } static void set_dynamic_array_size(ElemType* buffer, uint32_t array_size) noexcept { *(reinterpret_cast<uint32_t*>(buffer) - 1) = array_size; } + bool is_dynamic_array_buffer_type() const noexcept override; }; extern template class DynamicArrayBufferType<char>; diff --git a/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.hpp b/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.hpp index bf3235a6b97..b208fa627e4 100644 --- a/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.hpp +++ b/vespalib/src/vespa/vespalib/datastore/dynamic_array_buffer_type.hpp @@ -9,7 +9,7 @@ namespace vespalib::datastore { template <typename ElemT> DynamicArrayBufferType<ElemT>::DynamicArrayBufferType(uint32_t array_size, const AllocSpec& spec, std::shared_ptr<alloc::MemoryAllocator> memory_allocator) noexcept - : BufferTypeBase(calc_entry_size(array_size), array_size, spec.min_entries_in_buffer, spec.max_entries_in_buffer, spec.num_entries_for_new_buffer, spec.allocGrowFactor), + : BufferTypeBase(calc_entry_size(array_size), dynamic_array_buffer_underflow_size, array_size, spec.min_entries_in_buffer, spec.max_entries_in_buffer, spec.num_entries_for_new_buffer, spec.allocGrowFactor), _memory_allocator(std::move(memory_allocator)) { } @@ -24,14 +24,18 @@ template <typename ElemT> size_t DynamicArrayBufferType<ElemT>::calc_entry_size(size_t array_size) noexcept { - return EntryMinAligner::align(sizeof(ElemType) * array_size + entry_bias); + auto entry_size = EntryMinAligner::align(sizeof(ElemType) * array_size + sizeof(uint32_t)); + if (align_for_simd && entry_size >= 512) { + entry_size = Aligner<64>::align(entry_size); + } + return entry_size; } template <typename ElemT> size_t DynamicArrayBufferType<ElemT>::calc_array_size(size_t entry_size) noexcept { - return (entry_size - entry_bias) / sizeof(ElemType); + return (entry_size - sizeof(uint32_t)) / sizeof(ElemType); } template <typename ElemT> @@ -116,4 +120,11 @@ DynamicArrayBufferType<ElemT>::get_memory_allocator() const return _memory_allocator.get(); } +template <typename ElemT> +bool +DynamicArrayBufferType<ElemT>::is_dynamic_array_buffer_type() const noexcept +{ + return true; +} + } diff --git a/vespalib/src/vespa/vespalib/datastore/unique_store_string_allocator.h b/vespalib/src/vespa/vespalib/datastore/unique_store_string_allocator.h index d3348950891..102808f7629 100644 --- a/vespalib/src/vespa/vespalib/datastore/unique_store_string_allocator.h +++ b/vespalib/src/vespa/vespalib/datastore/unique_store_string_allocator.h @@ -117,7 +117,7 @@ public: auto &state = _store.getBufferMeta(iRef.bufferId()); auto type_id = state.getTypeId(); if (type_id != 0) { - return *reinterpret_cast<const UniqueStoreEntryBase *>(_store.template getEntryArray<char>(iRef, state.getArraySize())); + return *reinterpret_cast<const UniqueStoreEntryBase *>(_store.template getEntryArray<char>(iRef, state.get_array_size())); } else { return *_store.template getEntry<WrappedExternalEntryType>(iRef); } @@ -127,7 +127,7 @@ public: auto &state = _store.getBufferMeta(iRef.bufferId()); auto type_id = state.getTypeId(); if (type_id != 0) { - return reinterpret_cast<const UniqueStoreSmallStringEntry *>(_store.template getEntryArray<char>(iRef, state.getArraySize()))->value(); + return reinterpret_cast<const UniqueStoreSmallStringEntry *>(_store.template getEntryArray<char>(iRef, state.get_array_size()))->value(); } else { return _store.template getEntry<WrappedExternalEntryType>(iRef)->value().c_str(); } diff --git a/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h b/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h index e507132a085..e71dcd3aafb 100644 --- a/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h +++ b/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h @@ -28,7 +28,7 @@ protected: const auto &meta = _store.getBufferMeta(iRef.bufferId()); auto type_id = meta.getTypeId(); if (type_id != 0) { - return reinterpret_cast<const UniqueStoreSmallStringEntry *>(_store.template getEntryArray<char>(iRef, meta.getArraySize()))->value(); + return reinterpret_cast<const UniqueStoreSmallStringEntry *>(_store.template getEntryArray<char>(iRef, meta.get_array_size()))->value(); } else { return _store.template getEntry<WrappedExternalEntryType>(iRef)->value().c_str(); } diff --git a/vespalib/src/vespa/vespalib/stllike/cache.h b/vespalib/src/vespa/vespalib/stllike/cache.h index f7456cda197..97ba23aba65 100644 --- a/vespalib/src/vespa/vespalib/stllike/cache.h +++ b/vespalib/src/vespa/vespalib/stllike/cache.h @@ -110,7 +110,7 @@ public: */ bool hasKey(const K & key) const; - CacheStats get_stats() const; + virtual CacheStats get_stats() const; size_t getHit() const { return _hit.load(std::memory_order_relaxed); } size_t getMiss() const { return _miss.load(std::memory_order_relaxed); } diff --git a/vespalib/src/vespa/vespalib/stllike/cache.hpp b/vespalib/src/vespa/vespalib/stllike/cache.hpp index 4e7736c9e5f..f767cf9812c 100644 --- a/vespalib/src/vespa/vespalib/stllike/cache.hpp +++ b/vespalib/src/vespa/vespalib/stllike/cache.hpp @@ -61,9 +61,7 @@ cache<P>::getStaticMemoryUsage() const { MemoryUsage usage; auto cacheGuard = getGuard(); usage.incAllocatedBytes(sizeof(*this)); - usage.incAllocatedBytes(Lru::capacity()*sizeof(typename Lru::value_type)); usage.incUsedBytes(sizeof(*this)); - usage.incUsedBytes(Lru::size()*sizeof(typename Lru::value_type)); return usage; } diff --git a/vespalib/src/vespa/vespalib/test/nexus.h b/vespalib/src/vespa/vespalib/test/nexus.h index 659b563e43c..7be35d42463 100644 --- a/vespalib/src/vespa/vespalib/test/nexus.h +++ b/vespalib/src/vespa/vespalib/test/nexus.h @@ -36,6 +36,7 @@ public: Nexus &operator=(const Nexus &) = delete; size_t num_threads() const noexcept { return _vote.size(); } size_t thread_id() const noexcept { return _thread_id; } + bool is_main() const noexcept { return _thread_id == 0; } bool vote(bool my_vote) { return _vote(my_vote); } void barrier() { REQUIRE_EQ(_vote(true), true); } struct select_thread_0 {}; diff --git a/vespalib/src/vespa/vespalib/test/thread_meets.h b/vespalib/src/vespa/vespalib/test/thread_meets.h index 7ef4dcb9921..26ca560641d 100644 --- a/vespalib/src/vespa/vespalib/test/thread_meets.h +++ b/vespalib/src/vespa/vespalib/test/thread_meets.h @@ -38,8 +38,8 @@ struct ThreadMeets { explicit Sum(size_t N) : vespalib::Rendezvous<T,T>(N) {} T operator()(T value) { return rendezvous(value); } void mingle() override { - T acc{}; - for (size_t i = 0; i < size(); ++i) { + T acc = in(0); + for (size_t i = 1; i < size(); ++i) { acc += in(i); } for (size_t i = 0; i < size(); ++i) { @@ -47,6 +47,48 @@ struct ThreadMeets { } } }; + // maximum of values across all threads + template <typename T> + struct Max : vespalib::Rendezvous<T,T> { + using vespalib::Rendezvous<T,T>::in; + using vespalib::Rendezvous<T,T>::out; + using vespalib::Rendezvous<T,T>::size; + using vespalib::Rendezvous<T,T>::rendezvous; + explicit Max(size_t N) : vespalib::Rendezvous<T,T>(N) {} + T operator()(T value) { return rendezvous(value); } + void mingle() override { + T max = in(0); + for (size_t i = 1; i < size(); ++i) { + if (in(i) > max) { + max = in(i); + } + } + for (size_t i = 0; i < size(); ++i) { + out(i) = max; + } + } + }; + // minimum of values across all threads + template <typename T> + struct Min : vespalib::Rendezvous<T,T> { + using vespalib::Rendezvous<T,T>::in; + using vespalib::Rendezvous<T,T>::out; + using vespalib::Rendezvous<T,T>::size; + using vespalib::Rendezvous<T,T>::rendezvous; + explicit Min(size_t N) : vespalib::Rendezvous<T,T>(N) {} + T operator()(T value) { return rendezvous(value); } + void mingle() override { + T min = in(0); + for (size_t i = 1; i < size(); ++i) { + if (in(i) < min) { + min = in(i); + } + } + for (size_t i = 0; i < size(); ++i) { + out(i) = min; + } + } + }; // range of values across all threads template <typename T> struct Range : vespalib::Rendezvous<T,T> { diff --git a/vespamalloc/src/vespamalloc/malloc/threadlist.hpp b/vespamalloc/src/vespamalloc/malloc/threadlist.hpp index 743090a4e12..e22b93c48fe 100644 --- a/vespamalloc/src/vespamalloc/malloc/threadlist.hpp +++ b/vespamalloc/src/vespamalloc/malloc/threadlist.hpp @@ -2,9 +2,14 @@ #pragma once #include "threadlist.h" +#include <malloc.h> namespace vespamalloc { +namespace { + const char * VESPA_MALLOC_MMAP_THRESHOLD = "VESPA_MALLOC_MMAP_THRESHOLD"; +} + template <typename MemBlockPtrT, typename ThreadStatT> ThreadListT<MemBlockPtrT, ThreadStatT>::ThreadListT(AllocPool & allocPool, MMapPool & mmapPool) : _isThreaded(false), @@ -13,8 +18,14 @@ ThreadListT<MemBlockPtrT, ThreadStatT>::ThreadListT(AllocPool & allocPool, MMapP _allocPool(allocPool), _mmapPool(mmapPool) { + const char * mmapThresholdS = getenv(VESPA_MALLOC_MMAP_THRESHOLD); + int mmapThreshold = (mmapThresholdS != nullptr) + ? strtol(mmapThresholdS, nullptr, 0) + : MMAP_LIMIT_DEFAULT; for (size_t i = 0; i < getMaxNumThreads(); i++) { - _threadVector[i].setPool(_allocPool, _mmapPool); + auto & thread = _threadVector[i]; + thread.setPool(_allocPool, _mmapPool); + thread.mallopt(M_MMAP_THRESHOLD, mmapThreshold); } } diff --git a/vespamalloc/src/vespamalloc/malloc/threadpool.h b/vespamalloc/src/vespamalloc/malloc/threadpool.h index 30ece02ba29..750833084ca 100644 --- a/vespamalloc/src/vespamalloc/malloc/threadpool.h +++ b/vespamalloc/src/vespamalloc/malloc/threadpool.h @@ -9,6 +9,10 @@ namespace vespamalloc { +constexpr int MMAP_LIMIT_MIN = 0x100000; // 1M +constexpr int MMAP_LIMIT_DEFAULT = 0x4000000; // 64M +constexpr int MMAP_LIMIT_MAX = 0x40000000; // 1G + template <typename MemBlockPtrT, typename ThreadStatT > class ThreadPoolT { diff --git a/vespamalloc/src/vespamalloc/malloc/threadpool.hpp b/vespamalloc/src/vespamalloc/malloc/threadpool.hpp index b5a283f6600..e62fa0f2fdf 100644 --- a/vespamalloc/src/vespamalloc/malloc/threadpool.hpp +++ b/vespamalloc/src/vespamalloc/malloc/threadpool.hpp @@ -7,8 +7,10 @@ namespace vespamalloc { namespace { - constexpr size_t MMAP_LIMIT_MIN = 0x100000; // 1M - constexpr size_t MMAP_LIMIT_MAX = 0x40000000; // 1G + size_t + sanitizeMMapThreshold(int threshold) { + return std::min(MMAP_LIMIT_MAX, std::max(MMAP_LIMIT_MIN, threshold)); + } } template <typename MemBlockPtrT, typename ThreadStatT> @@ -112,9 +114,8 @@ ThreadPoolT<MemBlockPtrT, ThreadStatT>::~ThreadPoolT() = default; template <typename MemBlockPtrT, typename ThreadStatT > int ThreadPoolT<MemBlockPtrT, ThreadStatT>::mallopt(int param, int value) { - size_t limit = value; if (param == M_MMAP_THRESHOLD) { - _mmapLimit = std::min(MMAP_LIMIT_MAX, std::max(MMAP_LIMIT_MIN, limit)); + _mmapLimit = sanitizeMMapThreshold(value); return 1; } return 0; |