diff options
author | Øyvind Grønnesby <oyving@yahooinc.com> | 2023-09-26 19:05:27 +0200 |
---|---|---|
committer | Øyvind Grønnesby <oyving@yahooinc.com> | 2023-09-26 19:05:27 +0200 |
commit | 49ee4fa2b5230ffe8f3a0b39d9b34880d2191a2a (patch) | |
tree | 807c5afb5ef993fce87b6a90222b92973484a224 | |
parent | 66a1cd6927cdf15e50bb1c477642912ab8d68d6c (diff) | |
parent | a3f1ddde551d7c3092bcbdfb745e4e178da9be0f (diff) |
Merge remote-tracking branch 'origin/master' into ogronnesby/billing-report-customer
169 files changed, 2690 insertions, 1070 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ApplicationClusterEndpoint.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ApplicationClusterEndpoint.java index 215439aa42a..0276985d6a6 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ApplicationClusterEndpoint.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ApplicationClusterEndpoint.java @@ -2,8 +2,15 @@ package com.yahoo.config.model.api; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.config.provision.SystemName; + import java.util.List; import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Represents one endpoint for an application cluster @@ -147,21 +154,77 @@ public class ApplicationClusterEndpoint { } - public record DnsName(String name) implements Comparable<DnsName> { + public static class DnsName implements Comparable<DnsName> { + + private static final int MAX_LABEL_LENGTH = 63; + + private final String name; + + private DnsName(String name) { + this.name = name; + } public String value() { return name; } + // TODO(mpolden): Remove when config-models < 8.232 are gone + public static DnsName sharedL4NameFrom(SystemName systemName, ClusterSpec.Id cluster, ApplicationId applicationId, String suffix) { + String name = dnsParts(systemName, cluster, applicationId) + .filter(Objects::nonNull) // remove null values that were "default" + .map(DnsName::sanitize) + .collect(Collectors.joining(".")); + return new DnsName(name + suffix); + } + public static DnsName from(String name) { return new DnsName(name); } + private static Stream<String> dnsParts(SystemName systemName, ClusterSpec.Id cluster, ApplicationId applicationId) { + return Stream.of( + nullIfDefault(cluster.value()), + systemPart(systemName), + nullIfDefault(applicationId.instance().value()), + applicationId.application().value(), + applicationId.tenant().value() + ); + } + + /** + * Remove any invalid characters from the hostnames + */ + private static String sanitize(String id) { + return shortenIfNeeded(id.toLowerCase() + .replace('_', '-') + .replaceAll("[^a-z0-9-]*", "")); + } + + /** + * Truncate the given string at the front so its length does not exceed 63 characters. + */ + private static String shortenIfNeeded(String id) { + return id.substring(Math.max(0, id.length() - MAX_LABEL_LENGTH)); + } + + private static String nullIfDefault(String string) { + return Optional.of(string).filter(s -> !s.equals("default")).orElse(null); + } + + private static String systemPart(SystemName systemName) { + return "cd".equals(systemName.value()) ? systemName.value() : null; + } + + @Override + public String toString() { + return "DnsName{" + + "name='" + name + '\'' + + '}'; + } + @Override public int compareTo(DnsName o) { return name.compareTo(o.name); } - } - } diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java index 446c32801e0..57d013ebd01 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java @@ -47,6 +47,7 @@ public interface ModelContext { default Optional<? extends Reindexing> reindexing() { return Optional.empty(); } Properties properties(); default Optional<File> appDir() { return Optional.empty();} + OnnxModelCost onnxModelCost(); /** The Docker image repo we want to use for images for this deployment (optional, will use default if empty) */ default Optional<DockerImage> wantedDockerImageRepo() { return Optional.empty(); } @@ -118,6 +119,7 @@ public interface ModelContext { @ModelFeatureFlag(owners = {"jonmv"}) default boolean useReconfigurableDispatcher() { return false; } @ModelFeatureFlag(owners = {"vekterli"}) default int contentLayerMetadataFeatureLevel() { return 0; } @ModelFeatureFlag(owners = {"bjorncs"}) default boolean dynamicHeapSize() { return false; } + @ModelFeatureFlag(owners = {"hmusum"}) default String unknownConfigDefinition() { return "log"; } } /** Warning: As elsewhere in this package, do not make backwards incompatible changes that will break old config models! */ diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java index 422ceba8074..595cd97e6b6 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java @@ -4,6 +4,7 @@ package com.yahoo.config.model.api; import com.yahoo.config.ModelReference; import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; /** @@ -11,7 +12,7 @@ import com.yahoo.config.application.api.DeployLogger; */ public interface OnnxModelCost { - Calculator newCalculator(DeployLogger logger); + Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger); interface Calculator { long aggregatedModelCostInBytes(); @@ -20,7 +21,7 @@ public interface OnnxModelCost { } static OnnxModelCost disabled() { - return (__) -> new Calculator() { + return (__, ___) -> new Calculator() { @Override public long aggregatedModelCostInBytes() { return 0; } @Override public void registerModel(ApplicationFile path) {} @Override public void registerModel(ModelReference ref) {} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java deleted file mode 100644 index 76733872882..00000000000 --- a/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.vespa.model; - -import com.yahoo.config.ModelReference; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.DeployLogger; -import com.yahoo.config.model.api.OnnxModelCost; - -import java.io.IOException; -import java.net.URI; -import java.net.http.HttpClient; -import java.net.http.HttpRequest; -import java.net.http.HttpResponse; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.time.Duration; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; -import java.util.logging.Level; - -import static com.yahoo.yolean.Exceptions.uncheck; - -/** - * Aggregates estimated footprint of configured ONNX models. - * - * @author bjorncs - */ -public class DefaultOnnxModelCost implements OnnxModelCost { - - @Override - public Calculator newCalculator(DeployLogger logger) { - return new CalculatorImpl(logger); - } - - private static class CalculatorImpl implements Calculator { - private final DeployLogger log; - - private final ConcurrentMap<String, Long> modelCost = new ConcurrentHashMap<>(); - - private CalculatorImpl(DeployLogger log) { - this.log = log; - } - - @Override - public long aggregatedModelCostInBytes() { - return modelCost.values().stream().mapToLong(Long::longValue).sum(); - } - - @Override - public void registerModel(ApplicationFile f) { - String path = f.getPath().getRelative(); - if (alreadyAnalyzed(path)) return; - log.log(Level.FINE, () -> "Register model '%s'".formatted(path)); - deductJvmHeapSizeWithModelCost(f.exists() ? f.getSize() : 0, path); - } - - @Override - public void registerModel(ModelReference ref) { - log.log(Level.FINE, () -> "Register model '%s'".formatted(ref.toString())); - if (ref.path().isPresent()) { - var path = Paths.get(ref.path().get().value()); - var source = path.getFileName().toString(); - if (alreadyAnalyzed(source)) return; - deductJvmHeapSizeWithModelCost(uncheck(() -> Files.exists(path) ? Files.size(path) : 0), source); - } else if (ref.url().isPresent()) deductJvmHeapSizeWithModelCost(URI.create(ref.url().get().value())); - else throw new IllegalStateException(ref.toString()); - } - - private void deductJvmHeapSizeWithModelCost(URI uri) { - if (alreadyAnalyzed(uri.toString())) return; - if (uri.getScheme().equals("http") || uri.getScheme().equals("https")) { - try { - var timeout = Duration.ofSeconds(3); - var httpClient = HttpClient.newBuilder().connectTimeout(timeout).build(); - var request = HttpRequest.newBuilder(uri).timeout(timeout).method("HEAD", HttpRequest.BodyPublishers.noBody()).build(); - var response = httpClient.send(request, HttpResponse.BodyHandlers.discarding()); - var contentLength = response.headers().firstValue("Content-Length").orElse("0"); - log.log(Level.FINE, () -> "Got content length '%s' for '%s'".formatted(contentLength, uri)); - deductJvmHeapSizeWithModelCost(Long.parseLong(contentLength), uri.toString()); - } catch (IllegalArgumentException | InterruptedException | IOException e) { - log.log(Level.INFO, () -> "Failed to get model size for '%s': %s".formatted(uri, e.getMessage()), e); - } - } - } - - private void deductJvmHeapSizeWithModelCost(long size, String source) { - long fallbackModelSize = 1024*1024*1024; - long estimatedCost = Math.max(300*1024*1024, (long) (1.4D * (size > 0 ? size : fallbackModelSize) + 100*1024*1024)); - log.log(Level.FINE, () -> - "Estimated %s footprint for model of size %s ('%s')".formatted(mb(estimatedCost), mb(size), source)); - modelCost.put(source, estimatedCost); - } - - private boolean alreadyAnalyzed(String source) { return modelCost.containsKey(source); } - - private static String mb(long bytes) { return "%dMB".formatted(bytes / (1024*1024)); } - } -} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java index 727a18aee2c..269cb2dfa08 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java @@ -7,8 +7,8 @@ import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import ai.vespa.rankingexpression.importer.vespa.VespaImporter; import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; -import com.yahoo.component.annotation.Inject; import com.yahoo.component.Version; +import com.yahoo.component.annotation.Inject; import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.ValidationOverrides; @@ -21,7 +21,6 @@ import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.api.ModelCreateResult; import com.yahoo.config.model.api.ModelFactory; -import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.api.ValidationParameters; import com.yahoo.config.model.application.provider.ApplicationPackageXmlFilesValidator; import com.yahoo.config.model.builder.xml.ConfigModelBuilder; @@ -199,7 +198,7 @@ public class VespaModelFactory implements ModelFactory { .now(clock.instant()) .wantedNodeVespaVersion(modelContext.wantedNodeVespaVersion()) .wantedDockerImageRepo(modelContext.wantedDockerImageRepo()) - .onnxModelCost(modelContext.properties().hostedVespa() ? new DefaultOnnxModelCost() : OnnxModelCost.disabled()); + .onnxModelCost(modelContext.onnxModelCost()); modelContext.previousModel().ifPresent(builder::previousModel); modelContext.reindexing().ifPresent(builder::reindexing); return builder.build(validationParameters); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerCluster.java index aa3b8b3b821..d10a631fb90 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsProxyContainerCluster.java @@ -159,6 +159,7 @@ public class MetricsProxyContainerCluster extends ContainerCluster<MetricsProxyC builder.consumer.add(toConsumerBuilder(MetricsConsumer.defaultConsumer)); builder.consumer.add(toConsumerBuilder(newDefaultConsumer())); + if (isHostedVespa()) builder.consumer.add(toConsumerBuilder(MetricsConsumer.vespa9)); getAdmin() .map(Admin::getAmendedMetricsConsumers) .map(consumers -> consumers.stream().map(ConsumersConfigGenerator::toConsumerBuilder).toList()) diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/MetricsConsumer.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/MetricsConsumer.java index cfe3c01e03a..987812f11ad 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/MetricsConsumer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/MetricsConsumer.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.model.admin.monitoring; import ai.vespa.metrics.set.Metric; import ai.vespa.metrics.set.MetricSet; +import ai.vespa.metrics.set.Vespa9VespaMetricSet; import ai.vespa.metricsproxy.core.VespaMetrics; import ai.vespa.metricsproxy.http.ValuesFetcher; @@ -41,6 +42,9 @@ public class MetricsConsumer { public static final MetricsConsumer vespaCloud = consumer("vespa-cloud", vespaMetricSet, systemMetricSet, networkMetricSet); + public static final MetricsConsumer vespa9 = + consumer("Vespa9", Vespa9VespaMetricSet.vespa9vespaMetricSet, systemMetricSet, networkMetricSet); + private final String id; private final MetricSet metricSet; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java index 2c5e0db14b9..9e231239521 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidator.java @@ -28,7 +28,7 @@ public class JvmHeapSizeValidator extends Validator { } long jvmModelCost = appCluster.onnxModelCost().aggregatedModelCostInBytes(); if (jvmModelCost > 0) { - int percentLimit = 10; + int percentLimit = 15; if (mp.percentage() < percentLimit) { throw new IllegalArgumentException( ("Allocated percentage of memory of JVM in cluster '%s' is too low (%d%% < %d%%). " + @@ -36,7 +36,7 @@ public class JvmHeapSizeValidator extends Validator { "You may override this validation by specifying 'allocated-memory' (https://docs.vespa.ai/en/performance/container-tuning.html#jvm-heap-size).") .formatted(clusterId, mp.percentage(), percentLimit, jvmModelCost / (1024D * 1024 * 1024))); } - double gbLimit = 0.4; + double gbLimit = 0.6; double availableMemoryGb = mp.availableMemoryGb().getAsDouble(); if (availableMemoryGb < gbLimit) { throw new IllegalArgumentException( diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/UrlConfigValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/UrlConfigValidator.java new file mode 100644 index 00000000000..d9dd3729bd3 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/UrlConfigValidator.java @@ -0,0 +1,50 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.application.validation; + +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; + +/** + * Validates that config using s3:// urls is used in public system and with nodes that are exclusive. + * + * @author hmusum + */ +public class UrlConfigValidator extends Validator { + + @Override + public void validate(VespaModel model, DeployState state) { + if (! state.isHostedTenantApplication(model.getAdmin().getApplicationType())) return; + + model.getContainerClusters().forEach((__, cluster) -> { + var isExclusive = hasExclusiveNodes(model, cluster); + validateS3UlsInConfig(state, cluster, isExclusive); + }); + } + + private static boolean hasExclusiveNodes(VespaModel model, ApplicationContainerCluster cluster) { + return model.hostSystem().getHosts() + .stream() + .flatMap(hostResource -> hostResource.spec().membership().stream()) + .filter(membership -> membership.cluster().id().equals(cluster.id())) + .anyMatch(membership -> membership.cluster().isExclusive()); + } + + private static void validateS3UlsInConfig(DeployState state, ApplicationContainerCluster cluster, boolean isExclusive) { + if (hasS3UrlInConfig(cluster)) { + // TODO: Would be even better if we could add which config/field the url is set for in the error message + String message = "Found s3:// urls in config for container cluster " + cluster.getName(); + if ( ! state.zone().system().isPublic()) + throw new IllegalArgumentException(message + ". This is only supported in public systems"); + else if ( ! isExclusive) + throw new IllegalArgumentException(message + ". Nodes in the cluster need to be 'exclusive'," + + " see https://cloud.vespa.ai/en/reference/services#nodes"); + } + } + + private static boolean hasS3UrlInConfig(ApplicationContainerCluster cluster) { + return cluster.userConfiguredUrls().all().stream() + .anyMatch(url -> url.startsWith("s3://")); + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java index b9ecf7c2d22..30aafe67be7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validation.java @@ -87,6 +87,7 @@ public class Validation { new AccessControlFilterExcludeValidator().validate(model, deployState); new CloudUserFilterValidator().validate(model, deployState); new CloudHttpConnectorValidator().validate(model, deployState); + new UrlConfigValidator().validate(model, deployState); new JvmHeapSizeValidator().validate(model, deployState); additionalValidators.forEach(v -> v.validate(model, deployState)); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java index da6e3387d6a..ac679cc406c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java @@ -43,6 +43,7 @@ import com.yahoo.vespa.model.filedistribution.UserConfiguredFiles; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Optional; @@ -104,6 +105,8 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat private List<ApplicationClusterEndpoint> endpoints = List.of(); + private final UserConfiguredUrls userConfiguredUrls = new UserConfiguredUrls(); + public ApplicationContainerCluster(TreeConfigProducer<?> parent, String configSubId, String clusterId, DeployState deployState) { super(parent, configSubId, clusterId, deployState, true, 10); this.tlsClientAuthority = deployState.tlsClientAuthority(); @@ -130,10 +133,13 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat heapSizePercentageOfAvailableMemory = deployState.featureFlags().heapSizePercentage() > 0 ? Math.min(99, deployState.featureFlags().heapSizePercentage()) : defaultHeapSizePercentageOfAvailableMemory; - onnxModelCost = deployState.onnxModelCost().newCalculator(deployState.getDeployLogger()); + onnxModelCost = deployState.onnxModelCost().newCalculator( + deployState.getApplicationPackage(), deployState.getDeployLogger()); logger = deployState.getDeployLogger(); } + public UserConfiguredUrls userConfiguredUrls() { return userConfiguredUrls; } + @Override protected void doPrepare(DeployState deployState) { super.doPrepare(deployState); @@ -154,7 +160,10 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat if (containers.isEmpty()) return; // Files referenced from user configs to all components. - UserConfiguredFiles files = new UserConfiguredFiles(deployState.getFileRegistry(), deployState.getDeployLogger()); + UserConfiguredFiles files = new UserConfiguredFiles(deployState.getFileRegistry(), + deployState.getDeployLogger(), + deployState.featureFlags(), + userConfiguredUrls); for (Component<?, ?> component : getAllComponents()) { files.register(component); } @@ -382,4 +391,14 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat } } + public static class UserConfiguredUrls { + + private final Set<String> urls = new HashSet<>(); + + public void add(String url) { urls.add(url); } + + public Set<String> all() { return urls; } + + } + } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFiles.java b/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFiles.java index 8bed5e64bf5..0095dec8079 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFiles.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFiles.java @@ -5,12 +5,14 @@ import com.yahoo.config.FileReference; import com.yahoo.config.ModelReference; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.application.api.FileRegistry; +import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.producer.AnyConfigProducer; import com.yahoo.config.model.producer.UserConfigRepo; import com.yahoo.path.Path; import com.yahoo.vespa.config.ConfigDefinition; import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.vespa.config.ConfigPayloadBuilder; + import com.yahoo.yolean.Exceptions; import java.io.File; @@ -21,6 +23,8 @@ import java.util.Map; import java.util.Optional; import java.util.logging.Level; +import static com.yahoo.vespa.model.container.ApplicationContainerCluster.UserConfiguredUrls; + /** * Utility methods for registering file distribution of files/paths/urls/models defined by the user. * @@ -30,10 +34,16 @@ public class UserConfiguredFiles implements Serializable { private final FileRegistry fileRegistry; private final DeployLogger logger; + private final UserConfiguredUrls userConfiguredUrls; + private final String unknownConfigDefinition; - public UserConfiguredFiles(FileRegistry fileRegistry, DeployLogger logger) { + public UserConfiguredFiles(FileRegistry fileRegistry, DeployLogger logger, + ModelContext.FeatureFlags featureFlags, + UserConfiguredUrls userConfiguredUrls) { this.fileRegistry = fileRegistry; this.logger = logger; + this.userConfiguredUrls = userConfiguredUrls; + this.unknownConfigDefinition = featureFlags.unknownConfigDefinition(); } /** @@ -56,9 +66,12 @@ public class UserConfiguredFiles implements Serializable { private void register(ConfigPayloadBuilder builder, Map<Path, FileReference> registeredFiles, ConfigDefinitionKey key) { ConfigDefinition configDefinition = builder.getConfigDefinition(); if (configDefinition == null) { - // TODO: throw new IllegalArgumentException("Unable to find config definition for " + builder); - logger.logApplicationPackage(Level.INFO, "Unable to find config definition " + key + - ". Will not register files for file distribution for this config"); + String message = "Unable to find config definition " + key + ". Will not register files for file distribution for this config"; + switch (unknownConfigDefinition) { + case "log" -> logger.logApplicationPackage(Level.INFO, message); + case "warning" -> logger.logApplicationPackage(Level.WARNING, message); + case "fail" -> throw new IllegalArgumentException("Unable to find config definition for " + key); + } return; } @@ -133,7 +146,10 @@ public class UserConfiguredFiles implements Serializable { Path path; if (isModelType) { var modelReference = ModelReference.valueOf(builder.getValue()); - if (modelReference.path().isEmpty()) return; + if (modelReference.path().isEmpty()) { + modelReference.url().ifPresent(url -> userConfiguredUrls.add(url.value())); + return; + } path = Path.fromString(modelReference.path().get().value()); } else { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java index 7c86267c1b6..39a8e16fad5 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java @@ -29,6 +29,7 @@ import java.util.Map; public class OnnxModelProbe { private static final String binary = "vespa-analyze-onnx-model"; + private static final ObjectMapper jsonParser = new ObjectMapper(); static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) { TensorType outputType = TensorType.empty; @@ -41,8 +42,9 @@ public class OnnxModelProbe { // Otherwise, run vespa-analyze-onnx-model if the model is available if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) { String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes); - String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); + var jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); outputType = outputTypeFromJson(jsonOutput, outputName); + writeMemoryStats(app, modelPath, MemoryStats.fromJson(jsonOutput)); if ( ! outputType.equals(TensorType.empty)) { writeProbedOutputType(app, modelPath, contextKey, outputType); } @@ -53,6 +55,16 @@ public class OnnxModelProbe { return outputType; } + private static void writeMemoryStats(ApplicationPackage app, Path modelPath, MemoryStats memoryStats) throws IOException { + String path = app.getFileReference(memoryStatsPath(modelPath)).getAbsolutePath(); + IOUtils.writeFile(path, memoryStats.toJson().toPrettyString(), false); + } + + private static Path memoryStatsPath(Path modelPath) { + var fileName = OnnxModelInfo.asValidIdentifier(modelPath.getRelative()) + ".memory_stats"; + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); + } + private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) { StringBuilder key = new StringBuilder().append(onnxName).append(":"); inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey()) @@ -95,9 +107,7 @@ public class OnnxModelProbe { return TensorType.empty; } - private static TensorType outputTypeFromJson(String json, String outputName) throws IOException { - ObjectMapper m = new ObjectMapper(); - JsonNode root = m.readTree(json); + private static TensorType outputTypeFromJson(JsonNode root, String outputName) throws IOException { if ( ! root.isObject() || ! root.has("outputs")) { return TensorType.empty; } @@ -123,7 +133,7 @@ public class OnnxModelProbe { return out.toString(); } - private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException { + private static JsonNode callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException { StringBuilder output = new StringBuilder(); ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types"); @@ -148,7 +158,16 @@ public class OnnxModelProbe { throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". " + "Output: '" + output + "'"); } - return output.toString(); + return jsonParser.readTree(output.toString()); + } + + public record MemoryStats(long vmSize, long vmRss) { + static MemoryStats fromJson(JsonNode json) { + return new MemoryStats(json.get("vm_size").asLong(), json.get("vm_rss").asLong()); + } + JsonNode toJson() { + return jsonParser.createObjectNode().put("vm_size", vmSize).put("vm_rss", vmRss); + } } } diff --git a/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java b/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java index af05a144b79..3f5173a3ae9 100644 --- a/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java +++ b/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java @@ -10,6 +10,7 @@ import com.yahoo.config.model.api.ConfigDefinitionRepo; import com.yahoo.config.model.api.HostProvisioner; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ModelContext; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.api.Provisioned; import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.model.application.provider.MockFileRegistry; @@ -84,4 +85,6 @@ public class MockModelContext implements ModelContext { public ExecutorService getExecutor() { return new InThreadExecutorService(); } + + @Override public OnnxModelCost onnxModelCost() { return OnnxModelCost.disabled(); } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java b/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java index 49019e47bc2..eae4f12f62c 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/admin/metricsproxy/MetricsConsumersTest.java @@ -52,11 +52,12 @@ public class MetricsConsumersTest { @Test void consumers_are_set_up_for_hosted() { ConsumersConfig config = consumersConfigFromXml(servicesWithAdminOnly(), hosted); - assertEquals(4, config.consumer().size()); + assertEquals(5, config.consumer().size()); assertEquals(MetricsConsumer.vespa.id(), config.consumer(0).name()); assertEquals(MetricsConsumer.autoscaling.id(), config.consumer(1).name()); assertEquals(MetricsConsumer.defaultConsumer.id(), config.consumer(2).name()); assertEquals(MetricsProxyContainerCluster.NEW_DEFAULT_CONSUMER_ID, config.consumer(3).name()); + assertEquals(MetricsConsumer.vespa9.id(), config.consumer(4).name()); } @Test @@ -124,7 +125,7 @@ public class MetricsConsumersTest { ); VespaModel hostedModel = getModel(services, hosted); ConsumersConfig config = consumersConfigFromModel(hostedModel); - assertEquals(4, config.consumer().size()); + assertEquals(5, config.consumer().size()); // All default metrics are retained ConsumersConfig.Consumer vespaConsumer = config.consumer(0); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java index 086f2fe778f..447614b8396 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java @@ -4,6 +4,7 @@ package com.yahoo.vespa.model.application.validation; import com.yahoo.config.ModelReference; import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.OnnxModelCost; @@ -34,16 +35,16 @@ class JvmHeapSizeValidatorTest { var deployState = createDeployState(8, 7L * 1024 * 1024 * 1024); var model = new VespaModel(new NullConfigModelRegistry(), deployState); var e = assertThrows(IllegalArgumentException.class, () -> new JvmHeapSizeValidator().validate(model, deployState)); - String expectedMessage = "Allocated percentage of memory of JVM in cluster 'container' is too low (3% < 10%). Estimated cost of ONNX models is 7.00GB"; + String expectedMessage = "Allocated percentage of memory of JVM in cluster 'container' is too low (3% < 15%). Estimated cost of ONNX models is 7.00GB"; assertTrue(e.getMessage().contains(expectedMessage), e.getMessage()); } @Test void fails_on_too_low_heap_size() throws IOException, SAXException { - var deployState = createDeployState(2, 1024L * 1024 * 1024); + var deployState = createDeployState(2.2, 1024L * 1024 * 1024); var model = new VespaModel(new NullConfigModelRegistry(), deployState); var e = assertThrows(IllegalArgumentException.class, () -> new JvmHeapSizeValidator().validate(model, deployState)); - String expectedMessage = "Allocated memory to JVM in cluster 'container' is too low (0.30GB < 0.40GB). Estimated cost of ONNX models is 1.00GB."; + String expectedMessage = "Allocated memory to JVM in cluster 'container' is too low (0.50GB < 0.60GB). Estimated cost of ONNX models is 1.00GB."; assertTrue(e.getMessage().contains(expectedMessage), e.getMessage()); } @@ -112,7 +113,7 @@ class JvmHeapSizeValidatorTest { ModelCostDummy(long modelCost) { this.modelCost = modelCost; } - @Override public Calculator newCalculator(DeployLogger logger) { return this; } + @Override public Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger) { return this; } @Override public long aggregatedModelCostInBytes() { return totalCost.get(); } @Override public void registerModel(ApplicationFile path) {} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/UrlConfigValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/UrlConfigValidatorTest.java new file mode 100644 index 00000000000..cef4d8c27dd --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/UrlConfigValidatorTest.java @@ -0,0 +1,107 @@ +package com.yahoo.vespa.model.application.validation; + +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.NullConfigModelRegistry; +import com.yahoo.config.model.api.ConfigDefinitionRepo; +import com.yahoo.config.model.application.provider.MockFileRegistry; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.config.model.test.MockApplicationPackage; +import com.yahoo.config.provision.RegionName; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.Zone; +import com.yahoo.embedding.BertBaseEmbedderConfig; +import com.yahoo.vespa.config.ConfigDefinitionKey; +import com.yahoo.vespa.config.buildergen.ConfigDefinition; +import com.yahoo.vespa.model.VespaModel; +import org.junit.jupiter.api.Test; +import org.xml.sax.SAXException; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static com.yahoo.config.provision.Environment.prod; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class UrlConfigValidatorTest { + + @Test + void failsWhenContainerNodesNotExclusive() throws IOException, SAXException { + runValidatorOnApp(true, SystemName.Public); // Exclusive nodes in public => success + + assertEquals("Found s3:// urls in config for container cluster default. This is only supported in public systems", + assertThrows(IllegalArgumentException.class, + () -> runValidatorOnApp(false, SystemName.main)) + .getMessage()); + + assertEquals("Found s3:// urls in config for container cluster default. This is only supported in public systems", + assertThrows(IllegalArgumentException.class, + () -> runValidatorOnApp(true, SystemName.main)) + .getMessage()); + + assertEquals("Found s3:// urls in config for container cluster default. Nodes in the cluster need to be 'exclusive'," + + " see https://cloud.vespa.ai/en/reference/services#nodes", + assertThrows(IllegalArgumentException.class, + () -> runValidatorOnApp(false, SystemName.Public)) + .getMessage()); + } + + private static String containerXml(boolean isExclusive) { + return """ + <container version='1.0' id='default'> + <component id='transformer' class='ai.vespa.embedding.BertBaseEmbedder' bundle='model-integration'> + <config name='embedding.bert-base-embedder'> + <transformerModel url='s3://models/minilm-l6-v2/sentence_all_MiniLM_L6_v2.onnx' path='foo'/> + <tokenizerVocab url='s3://models/bert-base-uncased.txt'/> + </config> + </component> + <search/> + <document-api/> + <nodes count='2' exclusive='%s' /> + </container> + """.formatted(Boolean.toString(isExclusive)); + } + + private static void runValidatorOnApp(boolean isExclusive, SystemName systemName) throws IOException, SAXException { + String container = containerXml(isExclusive); + String servicesXml = """ + <services version='1.0'> + %s + </services> + """.formatted(container); + ApplicationPackage app = new MockApplicationPackage.Builder() + .withServices(servicesXml) + .build(); + DeployState deployState = createDeployState(app, systemName); + VespaModel model = new VespaModel(new NullConfigModelRegistry(), deployState); + new UrlConfigValidator().validate(model, deployState); + } + + private static DeployState createDeployState(ApplicationPackage app, SystemName systemName) { + boolean isHosted = true; + var builder = new DeployState.Builder() + .applicationPackage(app) + .zone(new Zone(systemName, prod, RegionName.from("us-east-3"))) + .properties(new TestProperties().setHostedVespa(isHosted)) + .fileRegistry(new MockFileRegistry()); + + Map<ConfigDefinitionKey, ConfigDefinition> defs = new HashMap<>(); + defs.put(new ConfigDefinitionKey(BertBaseEmbedderConfig.CONFIG_DEF_NAME, BertBaseEmbedderConfig.CONFIG_DEF_NAMESPACE), + new ConfigDefinition(BertBaseEmbedderConfig.CONFIG_DEF_NAME, BertBaseEmbedderConfig.CONFIG_DEF_SCHEMA)); + builder.configDefinitionRepo(new ConfigDefinitionRepo() { + @Override + public Map<ConfigDefinitionKey, com.yahoo.vespa.config.buildergen.ConfigDefinition> getConfigDefinitions() { + return defs; + } + + @Override + public com.yahoo.vespa.config.buildergen.ConfigDefinition get(ConfigDefinitionKey key) { + return defs.get(key); + } + }); + return builder.build(); + } + +} diff --git a/config-model/src/test/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFilesTest.java b/config-model/src/test/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFilesTest.java index bb5ba840c2c..653bdbccf15 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFilesTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFilesTest.java @@ -7,12 +7,14 @@ import com.yahoo.config.ModelReference; import com.yahoo.config.UrlReference; import com.yahoo.config.application.api.FileRegistry; import com.yahoo.config.model.application.provider.BaseDeployLogger; +import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.producer.UserConfigRepo; import com.yahoo.config.model.test.MockRoot; import com.yahoo.vespa.config.ConfigDefinition; import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.vespa.config.ConfigPayloadBuilder; import com.yahoo.vespa.model.SimpleConfigProducer; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; @@ -68,7 +70,10 @@ public class UserConfiguredFilesTest { } private UserConfiguredFiles userConfiguredFiles() { - return new UserConfiguredFiles(fileRegistry, new BaseDeployLogger()); + return new UserConfiguredFiles(fileRegistry, + new BaseDeployLogger(), + new TestProperties(), + new ApplicationContainerCluster.UserConfiguredUrls()); } @BeforeEach diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/WireguardKeyWithTimestamp.java b/config-provisioning/src/main/java/com/yahoo/config/provision/WireguardKeyWithTimestamp.java new file mode 100644 index 00000000000..ecc1cf71113 --- /dev/null +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/WireguardKeyWithTimestamp.java @@ -0,0 +1,39 @@ +package com.yahoo.config.provision; + +import com.yahoo.jdisc.Timer; + +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Random; + +/** + * @author gjoranv + */ +public record WireguardKeyWithTimestamp(WireguardKey key, Instant timestamp) { + + public static final int KEY_ROTATION_BASE = 60; + public static final int KEY_ROTATION_VARIANCE = 10; + public static final int KEY_EXPIRY = KEY_ROTATION_BASE + KEY_ROTATION_VARIANCE + 5; + + public WireguardKeyWithTimestamp { + if (key == null) throw new IllegalArgumentException("Wireguard key cannot be null"); + if (timestamp == null) timestamp = Instant.EPOCH; + } + + public static WireguardKeyWithTimestamp from(String key, long msTimestamp) { + return new WireguardKeyWithTimestamp(WireguardKey.from(key), Instant.ofEpochMilli(msTimestamp)); + } + + public boolean isDueForRotation(Timer timer, ChronoUnit unit, Random random) { + return timer.currentTime().isAfter(keyRotationDueAt(unit, random)); + } + + public boolean hasExpired(Timer timer, ChronoUnit unit) { + return timer.currentTime().isAfter(timestamp.plus(KEY_EXPIRY, unit)); + } + + private Instant keyRotationDueAt(ChronoUnit unit, Random random) { + return timestamp.plus(KEY_ROTATION_BASE + random.nextInt(KEY_ROTATION_VARIANCE), unit); + } + +} diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/ZoneEndpoint.java b/config-provisioning/src/main/java/com/yahoo/config/provision/ZoneEndpoint.java index 5d5757ec79a..2959815dd28 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/ZoneEndpoint.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/ZoneEndpoint.java @@ -14,9 +14,16 @@ public class ZoneEndpoint { /** * Endpoint service generation. - * Bump this to provision new services, whenever we change regional endpoint names. - * This will cause new endpoint services to be provisioned, with new domain names. - * TODO: wire multiple service IDs to and through the controller. + * <p> + * This is used to transition to a new set of endpoint services, with new domain names. + * The procedure is: + * <ol> + * <li>Start using new endpoint names (in controller code), for <em>all</em> applications.</li> + * <li>Bump the generation counter here; this causes new services to be provisioned.</li> + * <li>Controller configures the new services with the new endpoint names.</li> + * <li>Let users migrate to the new endpoint names.</li> + * <li>Currently missing: clean up obsolete, unused endpoint services.</li> + * </ol> */ public static final int generation = 0; public static final ZoneEndpoint defaultEndpoint = new ZoneEndpoint(true, false, List.of()); diff --git a/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java b/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java index 3b24e1c1b8d..4f8d42e895b 100644 --- a/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java +++ b/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java @@ -111,7 +111,7 @@ public class FlagsHandlerTest { }, { "type": "blacklist", - "dimension": "application", + "dimension": "instance", "values": [ "app1", "app2" ] } ], @@ -127,7 +127,7 @@ public class FlagsHandlerTest { // GET on id2 should now return what was put verifySuccessfulRequest(Method.GET, "/data/" + FLAG2.id(), "", - "{\"id\":\"id2\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"hostname\",\"values\":[\"host1\",\"host2\"]},{\"type\":\"blacklist\",\"dimension\":\"application\",\"values\":[\"app1\",\"app2\"]}],\"value\":true}],\"attributes\":{\"zone\":\"zone1\"}}"); + "{\"id\":\"id2\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"hostname\",\"values\":[\"host1\",\"host2\"]},{\"type\":\"blacklist\",\"dimension\":\"instance\",\"values\":[\"app1\",\"app2\"]}],\"value\":true}],\"attributes\":{\"zone\":\"zone1\"}}"); // The list of flag data should return id1 and id2 verifySuccessfulRequest(Method.GET, "/data", @@ -153,7 +153,7 @@ public class FlagsHandlerTest { // Get all recursivelly displays all flag data verifySuccessfulRequest(Method.GET, "/data?recursive=true", "", - "{\"flags\":[{\"id\":\"id1\",\"rules\":[{\"value\":false}]},{\"id\":\"id2\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"hostname\",\"values\":[\"host1\",\"host2\"]},{\"type\":\"blacklist\",\"dimension\":\"application\",\"values\":[\"app1\",\"app2\"]}],\"value\":true}],\"attributes\":{\"zone\":\"zone1\"}}]}"); + "{\"flags\":[{\"id\":\"id1\",\"rules\":[{\"value\":false}]},{\"id\":\"id2\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"hostname\",\"values\":[\"host1\",\"host2\"]},{\"type\":\"blacklist\",\"dimension\":\"instance\",\"values\":[\"app1\",\"app2\"]}],\"value\":true}],\"attributes\":{\"zone\":\"zone1\"}}]}"); // Deleting both flags verifySuccessfulRequest(Method.DELETE, "/data/" + FLAG1.id(), "", ""); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java index 9533f04107d..63faf806e9c 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java @@ -1030,21 +1030,21 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye private Session validateThatLocalSessionIsNotActive(Tenant tenant, long sessionId) { Session session = getLocalSession(tenant, sessionId); if (Session.Status.ACTIVATE.equals(session.getStatus())) { - throw new IllegalArgumentException("Session is active: " + sessionId); + throw new IllegalArgumentException("Session " + sessionId + " for '" + tenant.getName() + "' is active"); } return session; } private Session getLocalSession(Tenant tenant, long sessionId) { Session session = tenant.getSessionRepository().getLocalSession(sessionId); - if (session == null) throw new NotFoundException("Session " + sessionId + " was not found"); + if (session == null) throw new NotFoundException("Local session " + sessionId + " for '" + tenant.getName() + "' was not found"); return session; } private RemoteSession getRemoteSession(Tenant tenant, long sessionId) { RemoteSession session = tenant.getSessionRepository().getRemoteSession(sessionId); - if (session == null) throw new NotFoundException("Session " + sessionId + " was not found"); + if (session == null) throw new NotFoundException("Remote session " + sessionId + " for '" + tenant.getName() + "' was not found"); return session; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/FallbackOnnxModelCostProvider.java b/configserver/src/main/java/com/yahoo/vespa/config/server/FallbackOnnxModelCostProvider.java new file mode 100644 index 00000000000..57cfb1cd43b --- /dev/null +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/FallbackOnnxModelCostProvider.java @@ -0,0 +1,16 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.config.server; + +import com.yahoo.config.model.api.OnnxModelCost; +import com.yahoo.container.di.componentgraph.Provider; + +/** + * Default provider that provides a disabled {@link OnnxModelCost} instance. + * + * @author bjorncs + */ +public class FallbackOnnxModelCostProvider implements Provider<OnnxModelCost> { + @Override public OnnxModelCost get() { return OnnxModelCost.disabled(); } + @Override public void deconstruct() {} +} diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java index 693252da43a..d86e0e3c340 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java @@ -26,6 +26,7 @@ import com.yahoo.vespa.config.server.tenant.TenantRepository; import com.yahoo.vespa.curator.CompletionTimeoutException; import com.yahoo.vespa.curator.Curator; import com.yahoo.vespa.curator.Lock; +import com.yahoo.vespa.curator.transaction.CuratorOperations; import com.yahoo.vespa.curator.transaction.CuratorTransaction; import com.yahoo.vespa.flags.FlagSource; import com.yahoo.vespa.flags.ListFlag; @@ -430,11 +431,19 @@ public class TenantApplications implements RequestHandler, HostValidator { public TenantFileSystemDirs getTenantFileSystemDirs() { return tenantFileSystemDirs; } public CompletionWaiter createRemoveApplicationWaiter(ApplicationId applicationId) { - return RemoveApplicationWaiter.createAndInitialize(curator, applicationId, serverId); + var barrierPath = barrierPath(applicationId); + return RemoveApplicationWaiter.createAndInitialize(curator, barrierPath, serverId); } public CompletionWaiter getRemoveApplicationWaiter(ApplicationId applicationId) { - return RemoveApplicationWaiter.create(curator, applicationId, serverId); + var barrierPath = barrierPath(applicationId); + return RemoveApplicationWaiter.create(curator, barrierPath, serverId); + } + + private Path barrierPath(ApplicationId applicationId) { + return TenantRepository.getBarriersPath().append(applicationId.tenant().value()) + .append("delete-application") + .append(applicationId.serializedForm()); } /** @@ -453,14 +462,12 @@ public class TenantApplications implements RequestHandler, HostValidator { private final Duration waitForAll; private final Clock clock = Clock.systemUTC(); - RemoveApplicationWaiter(Curator curator, ApplicationId applicationId, String serverId) { - this(curator, applicationId, serverId, waitForAllDefault); + RemoveApplicationWaiter(Curator curator, Path barrierPath, String serverId) { + this(curator, barrierPath, serverId, waitForAllDefault); } - RemoveApplicationWaiter(Curator curator, ApplicationId applicationId, String serverId, Duration waitForAll) { - this.barrierPath = TenantRepository.getBarriersPath().append(applicationId.tenant().value()) - .append("delete-application") - .append(applicationId.serializedForm()); + RemoveApplicationWaiter(Curator curator, Path barrierPath, String serverId, Duration waitForAll) { + this.barrierPath = barrierPath; this.waiterNode = barrierPath.append(serverId); this.curator = curator; this.waitForAll = waitForAll; @@ -542,34 +549,29 @@ public class TenantApplications implements RequestHandler, HostValidator { @Override public String toString() { return "'" + barrierPath + "', " + barrierMemberCount() + " members"; } - public static CompletionWaiter create(Curator curator, ApplicationId applicationId, String serverId) { - return new RemoveApplicationWaiter(curator, applicationId, serverId); + public static CompletionWaiter create(Curator curator, Path barrierPath, String serverId) { + return new RemoveApplicationWaiter(curator, barrierPath, serverId); } - public static CompletionWaiter create(Curator curator, ApplicationId applicationId, String serverId, Duration waitForAll) { - return new RemoveApplicationWaiter(curator, applicationId, serverId, waitForAll); + public static CompletionWaiter create(Curator curator, Path barrierPath, String serverId, Duration waitForAll) { + return new RemoveApplicationWaiter(curator, barrierPath, serverId, waitForAll); } - public static CompletionWaiter createAndInitialize(Curator curator, ApplicationId applicationId, String serverId) { - return createAndInitialize(curator, applicationId, serverId, waitForAllDefault); + public static CompletionWaiter createAndInitialize(Curator curator, Path barrierPath, String serverId) { + return createAndInitialize(curator, barrierPath, serverId, waitForAllDefault); } - public static CompletionWaiter createAndInitialize(Curator curator, ApplicationId applicationId, String serverId, Duration waitForAll) { - RemoveApplicationWaiter waiter = new RemoveApplicationWaiter(curator, applicationId, serverId, waitForAll); - - // Cleanup and create a new barrier path - Path barrierPath = waiter.barrierPath(); + public static CompletionWaiter createAndInitialize(Curator curator, Path barrierPath, String serverId, Duration waitForAll) { + // Note: Should be done atomically, but unable to that when path may not exist before delete + // and create should be able to create any missing parent paths curator.delete(barrierPath); - curator.create(barrierPath.getParentPath()); - curator.createAtomically(barrierPath); + curator.create(barrierPath); - return waiter; + return new RemoveApplicationWaiter(curator, barrierPath, serverId, waitForAll); } private int barrierMemberCount() { return (curator.zooKeeperEnsembleCount() / 2) + 1; /* majority */ } - private Path barrierPath() { return barrierPath; } - } } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index 3e33b345437..cdc16c8f82f 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -13,6 +13,7 @@ import com.yahoo.config.model.api.EndpointCertificateSecrets; import com.yahoo.config.model.api.HostProvisioner; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ModelContext; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.api.Provisioned; import com.yahoo.config.model.api.Quota; import com.yahoo.config.model.api.Reindexing; @@ -66,6 +67,7 @@ public class ModelContextImpl implements ModelContext { private final Optional<? extends Reindexing> reindexing; private final ModelContext.Properties properties; private final Optional<File> appDir; + private final OnnxModelCost onnxModelCost; private final Optional<DockerImage> wantedDockerImageRepository; @@ -92,6 +94,7 @@ public class ModelContextImpl implements ModelContext { Provisioned provisioned, ModelContext.Properties properties, Optional<File> appDir, + OnnxModelCost onnxModelCost, Optional<DockerImage> wantedDockerImageRepository, Version modelVespaVersion, Version wantedNodeVespaVersion) { @@ -109,6 +112,7 @@ public class ModelContextImpl implements ModelContext { this.wantedDockerImageRepository = wantedDockerImageRepository; this.modelVespaVersion = modelVespaVersion; this.wantedNodeVespaVersion = wantedNodeVespaVersion; + this.onnxModelCost = onnxModelCost; } @Override @@ -150,6 +154,8 @@ public class ModelContextImpl implements ModelContext { @Override public Optional<File> appDir() { return appDir; } + @Override public OnnxModelCost onnxModelCost() { return onnxModelCost; } + @Override public Optional<DockerImage> wantedDockerImageRepo() { return wantedDockerImageRepository; } @@ -202,6 +208,7 @@ public class ModelContextImpl implements ModelContext { private final boolean useReconfigurableDispatcher; private final int contentLayerMetadataFeatureLevel; private final boolean dynamicHeapSize; + private final String unknownConfigDefinition; public FeatureFlags(FlagSource source, ApplicationId appId, Version version) { this.defaultTermwiseLimit = flagValue(source, appId, version, Flags.DEFAULT_TERM_WISE_LIMIT); @@ -245,6 +252,7 @@ public class ModelContextImpl implements ModelContext { this.useReconfigurableDispatcher = flagValue(source, appId, version, Flags.USE_RECONFIGURABLE_DISPATCHER); this.contentLayerMetadataFeatureLevel = flagValue(source, appId, version, Flags.CONTENT_LAYER_METADATA_FEATURE_LEVEL); this.dynamicHeapSize = flagValue(source, appId, version, Flags.DYNAMIC_HEAP_SIZE); + this.unknownConfigDefinition = flagValue(source, appId, version, Flags.UNKNOWN_CONFIG_DEFINITION); } @Override public int heapSizePercentage() { return heapPercentage; } @@ -296,6 +304,7 @@ public class ModelContextImpl implements ModelContext { @Override public boolean useReconfigurableDispatcher() { return useReconfigurableDispatcher; } @Override public int contentLayerMetadataFeatureLevel() { return contentLayerMetadataFeatureLevel; } @Override public boolean dynamicHeapSize() { return dynamicHeapSize; } + @Override public String unknownConfigDefinition() { return unknownConfigDefinition; } private static <V> V flagValue(FlagSource source, ApplicationId appId, Version vespaVersion, UnboundFlag<? extends V, ?, ?> flag) { return flag.bindTo(source) diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java index 328bd143d81..d302e0e8008 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java @@ -9,6 +9,7 @@ import com.yahoo.config.model.api.ConfigDefinitionRepo; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.api.ModelFactory; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.api.Provisioned; import com.yahoo.config.model.application.provider.MockFileRegistry; import com.yahoo.config.provision.ApplicationId; @@ -58,6 +59,7 @@ public class ActivatedModelsBuilder extends ModelsBuilder<Application> { private final FlagSource flagSource; private final SecretStore secretStore; private final ExecutorService executor; + private final OnnxModelCost onnxModelCost; public ActivatedModelsBuilder(TenantName tenant, long applicationGeneration, @@ -72,7 +74,8 @@ public class ActivatedModelsBuilder extends ModelsBuilder<Application> { ConfigserverConfig configserverConfig, Zone zone, ModelFactoryRegistry modelFactoryRegistry, - ConfigDefinitionRepo configDefinitionRepo) { + ConfigDefinitionRepo configDefinitionRepo, + OnnxModelCost onnxModelCost) { super(modelFactoryRegistry, configserverConfig, zone, hostProvisionerProvider, new SilentDeployLogger()); this.tenant = tenant; this.applicationGeneration = applicationGeneration; @@ -84,6 +87,7 @@ public class ActivatedModelsBuilder extends ModelsBuilder<Application> { this.flagSource = flagSource; this.secretStore = secretStore; this.executor = executor; + this.onnxModelCost = onnxModelCost; } @Override @@ -108,6 +112,7 @@ public class ActivatedModelsBuilder extends ModelsBuilder<Application> { provisioned, modelContextProperties, Optional.empty(), + onnxModelCost, wantedDockerImageRepository, modelFactory.version(), wantedNodeVespaVersion); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java index 4faa475fa08..57c766bb9c2 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java @@ -207,11 +207,12 @@ public abstract class ModelsBuilder<MODELRESULT extends ModelResult> { builtModelVersions.add(modelVersion); } catch (RuntimeException e) { // allow failure to create old config models if there is a validation override that allow skipping old - // config models or we're manually deploying + // config models, or we're manually deploying if (builtModelVersions.size() > 0 && ( builtModelVersions.get(0).getModel().skipOldConfigModels(now) || zone().environment().isManuallyDeployed())) - log.log(Level.INFO, applicationId + ": Failed to build version " + version + - ", but allow failure due to validation override or manual deployment"); + log.log(Level.WARNING, applicationId + ": Failed to build version " + version + + ", but allow failure due to validation override or manual deployment:" + + Exceptions.toMessageString(e)); else { log.log(Level.SEVERE, applicationId + ": Failed to build version " + version); throw e; diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java index af611b131f6..a3f0284890c 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java @@ -16,6 +16,7 @@ import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.api.ModelCreateResult; import com.yahoo.config.model.api.ModelFactory; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.api.Provisioned; import com.yahoo.config.model.api.ValidationParameters; import com.yahoo.config.model.api.ValidationParameters.IgnoreValidationErrors; @@ -69,6 +70,7 @@ public class PreparedModelsBuilder extends ModelsBuilder<PreparedModelsBuilder.P private final Optional<ApplicationVersions> activeApplicationVersions; private final Curator curator; private final ExecutorService executor; + private final OnnxModelCost onnxModelCost; public PreparedModelsBuilder(ModelFactoryRegistry modelFactoryRegistry, FlagSource flagSource, @@ -85,7 +87,8 @@ public class PreparedModelsBuilder extends ModelsBuilder<PreparedModelsBuilder.P PrepareParams params, Optional<ApplicationVersions> activeApplicationVersions, ConfigserverConfig configserverConfig, - Zone zone) { + Zone zone, + OnnxModelCost onnxModelCost) { super(modelFactoryRegistry, configserverConfig, zone, hostProvisionerProvider, deployLogger); this.flagSource = flagSource; this.secretStore = secretStore; @@ -98,6 +101,7 @@ public class PreparedModelsBuilder extends ModelsBuilder<PreparedModelsBuilder.P this.params = params; this.activeApplicationVersions = activeApplicationVersions; this.executor = executor; + this.onnxModelCost = onnxModelCost; } @Override @@ -123,6 +127,7 @@ public class PreparedModelsBuilder extends ModelsBuilder<PreparedModelsBuilder.P provisioned, createModelContextProperties(modelFactory.version(), applicationPackage), getAppDir(applicationPackage), + onnxModelCost, wantedDockerImageRepository, modelVersion, wantedNodeVespaVersion); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java index aeff97169f4..67872865106 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java @@ -18,6 +18,7 @@ import com.yahoo.config.model.api.ContainerEndpoint; import com.yahoo.config.model.api.EndpointCertificateMetadata; import com.yahoo.config.model.api.EndpointCertificateSecrets; import com.yahoo.config.model.api.FileDistribution; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.api.Quota; import com.yahoo.config.model.api.TenantSecretStore; import com.yahoo.config.provision.AllocatedHosts; @@ -93,6 +94,7 @@ public class SessionPreparer { private final FlagSource flagSource; private final ExecutorService executor; private final BooleanFlag writeSessionData; + private final OnnxModelCost onnxModelCost; public SessionPreparer(ModelFactoryRegistry modelFactoryRegistry, FileDistributionFactory fileDistributionFactory, @@ -103,7 +105,8 @@ public class SessionPreparer { Curator curator, Zone zone, FlagSource flagSource, - SecretStore secretStore) { + SecretStore secretStore, + OnnxModelCost onnxModelCost) { this.modelFactoryRegistry = modelFactoryRegistry; this.fileDistributionFactory = fileDistributionFactory; this.hostProvisionerProvider = hostProvisionerProvider; @@ -115,6 +118,7 @@ public class SessionPreparer { this.flagSource = flagSource; this.executor = executor; this.writeSessionData = Flags.WRITE_CONFIG_SERVER_SESSION_DATA_AS_ONE_BLOB.bindTo(flagSource); + this.onnxModelCost = onnxModelCost; } ExecutorService getExecutor() { return executor; } @@ -134,7 +138,8 @@ public class SessionPreparer { ApplicationId applicationId = params.getApplicationId(); Preparation preparation = new Preparation(hostValidator, logger, params, activeApplicationVersions, TenantRepository.getTenantPath(applicationId.tenant()), - serverDbSessionDir, applicationPackage, sessionZooKeeperClient); + serverDbSessionDir, applicationPackage, sessionZooKeeperClient, + onnxModelCost); preparation.preprocess(); try { AllocatedHosts allocatedHosts = preparation.buildModels(now); @@ -186,7 +191,7 @@ public class SessionPreparer { Preparation(HostValidator hostValidator, DeployLogger logger, PrepareParams params, Optional<ApplicationVersions> activeApplicationVersions, Path tenantPath, File serverDbSessionDir, ApplicationPackage applicationPackage, - SessionZooKeeperClient sessionZooKeeperClient) { + SessionZooKeeperClient sessionZooKeeperClient, OnnxModelCost onnxModelCost) { this.logger = logger; this.params = params; this.applicationPackage = applicationPackage; @@ -219,7 +224,8 @@ public class SessionPreparer { params, activeApplicationVersions, configserverConfig, - zone); + zone, + onnxModelCost); } void checkTimeout(String step) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java index 3b57945b21d..eb07e3010c6 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java @@ -9,6 +9,7 @@ import com.yahoo.concurrent.StripedExecutor; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.api.ConfigDefinitionRepo; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.application.provider.DeployData; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.provision.ApplicationId; @@ -118,6 +119,7 @@ public class SessionRepository { private final SessionPreparer sessionPreparer; private final Path sessionsPath; private final TenantName tenantName; + private final OnnxModelCost onnxModelCost; private final SessionCounter sessionCounter; private final SecretStore secretStore; private final HostProvisionerProvider hostProvisionerProvider; @@ -147,8 +149,10 @@ public class SessionRepository { Clock clock, ModelFactoryRegistry modelFactoryRegistry, ConfigDefinitionRepo configDefinitionRepo, - int maxNodeSize) { + int maxNodeSize, + OnnxModelCost onnxModelCost) { this.tenantName = tenantName; + this.onnxModelCost = onnxModelCost; sessionCounter = new SessionCounter(curator, tenantName); this.sessionsPath = TenantRepository.getSessionsPath(tenantName); this.clock = clock; @@ -553,7 +557,8 @@ public class SessionRepository { configserverConfig, zone, modelFactoryRegistry, - configDefinitionRepo); + configDefinitionRepo, + onnxModelCost); return ApplicationVersions.fromList(builder.buildModels(session.getApplicationId(), session.getDockerImageRepository(), session.getVespaVersion(), diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClient.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClient.java index 2bc8cb5bc0a..378cd9bdb8c 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClient.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionZooKeeperClient.java @@ -118,21 +118,21 @@ public class SessionZooKeeperClient { public long sessionId() { return sessionId; } - public CompletionWaiter createActiveWaiter() { return createCompletionWaiter(getWaiterPath(ACTIVE_BARRIER)); } + public CompletionWaiter createActiveWaiter() { return createCompletionWaiter(barrierPath(ACTIVE_BARRIER)); } - CompletionWaiter createPrepareWaiter() { return createCompletionWaiter(getWaiterPath(PREPARE_BARRIER)); } + CompletionWaiter createPrepareWaiter() { return createCompletionWaiter(barrierPath(PREPARE_BARRIER)); } - CompletionWaiter getPrepareWaiter() { return getCompletionWaiter(getWaiterPath(PREPARE_BARRIER)); } + CompletionWaiter getPrepareWaiter() { return getCompletionWaiter(barrierPath(PREPARE_BARRIER)); } - CompletionWaiter getActiveWaiter() { return getCompletionWaiter(getWaiterPath(ACTIVE_BARRIER)); } + CompletionWaiter getActiveWaiter() { return getCompletionWaiter(barrierPath(ACTIVE_BARRIER)); } - CompletionWaiter getUploadWaiter() { return getCompletionWaiter(getWaiterPath(UPLOAD_BARRIER)); } + CompletionWaiter getUploadWaiter() { return getCompletionWaiter(barrierPath(UPLOAD_BARRIER)); } private static final String PREPARE_BARRIER = "prepareBarrier"; private static final String ACTIVE_BARRIER = "activeBarrier"; private static final String UPLOAD_BARRIER = "uploadBarrier"; - private Path getWaiterPath(String barrierName) { + private Path barrierPath(String barrierName) { return sessionPath.append(barrierName); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java index ba09b3de365..ea53c8aa2bb 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java @@ -11,6 +11,7 @@ import com.yahoo.concurrent.Locks; import com.yahoo.concurrent.StripedExecutor; import com.yahoo.concurrent.ThreadFactoryFactory; import com.yahoo.config.model.api.ConfigDefinitionRepo; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.Zone; @@ -119,6 +120,7 @@ public class TenantRepository { new ScheduledThreadPoolExecutor(1, new DaemonThreadFactory("check for removed applications")); private final Curator.DirectoryCache directoryCache; private final ZookeeperServerConfig zookeeperServerConfig; + private final OnnxModelCost onnxModelCost; /** * Creates a new tenant repository @@ -138,7 +140,8 @@ public class TenantRepository { ConfigActivationListener configActivationListener, TenantListener tenantListener, ZookeeperServerConfig zookeeperServerConfig, - FileDirectory fileDirectory) { + FileDirectory fileDirectory, + OnnxModelCost onnxModelCost) { this(hostRegistry, curator, metrics, @@ -157,7 +160,8 @@ public class TenantRepository { configDefinitionRepo, configActivationListener, tenantListener, - zookeeperServerConfig); + zookeeperServerConfig, + onnxModelCost); } public TenantRepository(HostRegistry hostRegistry, @@ -178,7 +182,8 @@ public class TenantRepository { ConfigDefinitionRepo configDefinitionRepo, ConfigActivationListener configActivationListener, TenantListener tenantListener, - ZookeeperServerConfig zookeeperServerConfig) { + ZookeeperServerConfig zookeeperServerConfig, + OnnxModelCost onnxModelCost) { this.hostRegistry = hostRegistry; this.configserverConfig = configserverConfig; this.curator = curator; @@ -201,6 +206,7 @@ public class TenantRepository { this.zookeeperServerConfig = zookeeperServerConfig; // This we should control with a feature flag. this.deployHelperExecutor = createModelBuilderExecutor(); + this.onnxModelCost = onnxModelCost; curator.framework().getConnectionStateListenable().addListener(this::stateChanged); @@ -353,7 +359,8 @@ public class TenantRepository { curator, zone, flagSource, - secretStore); + secretStore, + onnxModelCost); SessionRepository sessionRepository = new SessionRepository(tenantName, applicationRepo, sessionPreparer, @@ -371,7 +378,8 @@ public class TenantRepository { clock, modelFactoryRegistry, configDefinitionRepo, - zookeeperServerConfig.juteMaxBuffer()); + zookeeperServerConfig.juteMaxBuffer(), + onnxModelCost); log.log(Level.FINE, "Adding tenant '" + tenantName + "'" + ", created " + created + ". Bootstrapping in " + Duration.between(start, clock.instant())); Tenant tenant = new Tenant(tenantName, sessionRepository, applicationRepo, created); diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml index 02481291213..a1e9bc3054b 100644 --- a/configserver/src/main/resources/configserver-app/services.xml +++ b/configserver/src/main/resources/configserver-app/services.xml @@ -26,6 +26,7 @@ <component id="com.yahoo.vespa.config.server.tenant.TenantRepository" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.host.HostRegistry" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.ApplicationRepository" bundle="configserver" /> + <component id="com.yahoo.vespa.config.server.FallbackOnnxModelCostProvider" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.HealthCheckerProviderProvider" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.version.VersionState" bundle="configserver" /> <component id="com.yahoo.config.provision.Zone" bundle="config-provisioning" /> diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java index 104727cb4f3..333dae94769 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java @@ -605,7 +605,7 @@ public class ApplicationRepositoryTest { long sessionId = result.sessionId(); exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Session is active: 2"); + exceptionRule.expectMessage("Session 2 for 'test1' is active"); applicationRepository.prepare(sessionId, prepareParams()); exceptionRule.expect(IllegalArgumentException.class); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/ModelContextImplTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/ModelContextImplTest.java index f5cd56707b3..fccb6785cb8 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/ModelContextImplTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/ModelContextImplTest.java @@ -9,6 +9,7 @@ import com.yahoo.config.model.api.ApplicationClusterEndpoint; import com.yahoo.config.model.api.ContainerEndpoint; import com.yahoo.config.model.api.HostProvisioner; import com.yahoo.config.model.api.ModelContext; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.api.Provisioned; import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.model.application.provider.MockFileRegistry; @@ -78,6 +79,7 @@ public class ModelContextImplTest { Optional.empty(), List.of()), Optional.empty(), + OnnxModelCost.disabled(), Optional.empty(), new Version(7), new Version(8)); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java index 765523177a9..88aed6b058c 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/SessionPrepareHandlerTest.java @@ -92,7 +92,7 @@ public class SessionPrepareHandlerTest extends SessionHandlerTest { public void require_error_when_session_id_does_not_exist() throws Exception { // No session with this id exists HttpResponse response = request(HttpRequest.Method.PUT, 9999L); - assertHttpStatusCodeErrorCodeAndMessage(response, NOT_FOUND, HttpErrorResponse.ErrorCode.NOT_FOUND, "Session 9999 was not found"); + assertHttpStatusCodeErrorCodeAndMessage(response, NOT_FOUND, HttpErrorResponse.ErrorCode.NOT_FOUND, "Local session 9999 for 'test' was not found"); } @Test @@ -180,7 +180,7 @@ public class SessionPrepareHandlerTest extends SessionHandlerTest { HttpResponse getResponse = request(HttpRequest.Method.GET, 9999L); assertHttpStatusCodeErrorCodeAndMessage(getResponse, NOT_FOUND, HttpErrorResponse.ErrorCode.NOT_FOUND, - "Session 9999 was not found"); + "Remote session 9999 for 'test' was not found"); } @Test diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java index 0158aa1961d..6dbb0d72c87 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java @@ -10,6 +10,7 @@ import com.yahoo.config.model.api.ApplicationClusterEndpoint; import com.yahoo.config.model.api.ContainerEndpoint; import com.yahoo.config.model.api.EndpointCertificateSecrets; import com.yahoo.config.model.api.ModelContext; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.provision.ApplicationId; @@ -132,7 +133,8 @@ public class SessionPreparerTest { curator, zone, flagSource, - secretStore); + secretStore, + OnnxModelCost.disabled()); } @Test(expected = InvalidApplicationException.class) diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java index 02ee3202475..1417df73cfc 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TenantRepositoryTest.java @@ -6,6 +6,7 @@ import com.yahoo.cloud.config.ZookeeperServerConfig; import com.yahoo.component.Version; import com.yahoo.concurrent.InThreadExecutorService; import com.yahoo.concurrent.StripedExecutor; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ApplicationName; @@ -230,7 +231,8 @@ public class TenantRepositoryTest { new TestConfigDefinitionRepo(), new TenantApplicationsTest.MockConfigActivationListener(), new MockTenantListener(), - new ZookeeperServerConfig.Builder().myid(0).build()); + new ZookeeperServerConfig.Builder().myid(0).build(), + OnnxModelCost.disabled()); } @Override diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestTenantRepository.java b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestTenantRepository.java index dd982ccbd72..0419a313dea 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestTenantRepository.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/tenant/TestTenantRepository.java @@ -6,6 +6,7 @@ import com.yahoo.cloud.config.ZookeeperServerConfig; import com.yahoo.concurrent.InThreadExecutorService; import com.yahoo.concurrent.StripedExecutor; import com.yahoo.config.model.api.ConfigDefinitionRepo; +import com.yahoo.config.model.api.OnnxModelCost; import com.yahoo.config.provision.Zone; import com.yahoo.vespa.config.server.ConfigServerDB; import com.yahoo.vespa.config.server.MockSecretStore; @@ -64,7 +65,8 @@ public class TestTenantRepository extends TenantRepository { configDefinitionRepo, configActivationListener, tenantListener, - new ZookeeperServerConfig.Builder().myid(0).build()); + new ZookeeperServerConfig.Builder().myid(0).build(), + OnnxModelCost.disabled()); } public static class Builder { diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/archive/ArchiveService.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/archive/ArchiveService.java index ed965f4331e..66cf3eef954 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/archive/ArchiveService.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/archive/ArchiveService.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.archive; +import com.yahoo.component.Version; import com.yahoo.config.provision.CloudAccount; import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.zone.ZoneId; @@ -10,6 +11,7 @@ import java.net.URI; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Stream; /** * Service that manages archive storage URIs for tenant nodes. @@ -28,4 +30,19 @@ public interface ArchiveService { Optional<String> findEnclaveArchiveBucket(ZoneId zoneId, CloudAccount cloudAccount); URI bucketURI(ZoneId zoneId, String bucketName); + + /** + * @return the version of the template that was used during the last apply for the given cloud account, + * or {@link Version#emptyVersion} if the version tag was not present or invalid, + * or {@link Optional#empty()} if the we have no access to the cloud account (template probably not applied yet) + */ + Optional<Version> getEnclaveTemplateVersion(CloudAccount cloudAccount); + + static Stream<Version> parseVersion(String versionString) { + try { + return Stream.of(Version.fromString(versionString)); + } catch (IllegalArgumentException e) { + return Stream.empty(); + } + } } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/archive/MockArchiveService.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/archive/MockArchiveService.java index 7461d3aa47e..4e6e71ca855 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/archive/MockArchiveService.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/archive/MockArchiveService.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.archive; +import com.yahoo.component.Version; import com.yahoo.config.provision.CloudAccount; import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.zone.ZoneId; @@ -59,6 +60,11 @@ public class MockArchiveService implements ArchiveService { return URI.create(String.format("s3://%s/", bucketName)); } + @Override + public Optional<Version> getEnclaveTemplateVersion(CloudAccount cloudAccount) { + return Optional.of(new Version(1, 2, 3)); + } + public void setEnclaveArchiveBucket(ZoneId zoneId, CloudAccount cloudAccount, String bucketName) { removeEnclaveArchiveBucket(zoneId, cloudAccount); diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java index e661c88e117..856af9f4132 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java @@ -258,15 +258,15 @@ public class SystemFlagsDataArchive { root = mapper.readTree(fileContent); // TODO (mortent): Remove this after completing migration of APPLICATION_ID dimension // replace "application" with "instance" for all dimension fields -// List<JsonNode> dimensionParents = root.findParents("dimension"); -// for (JsonNode parentNode : dimensionParents) { -// JsonNode dimension = parentNode.get("dimension"); -// if (dimension.isTextual() && "application".equals(dimension.textValue())) { -// ObjectNode parent = (ObjectNode) parentNode; -// parent.remove("dimension"); -// parent.put("dimension", "instance"); -// } -// } + List<JsonNode> dimensionParents = root.findParents("dimension"); + for (JsonNode parentNode : dimensionParents) { + JsonNode dimension = parentNode.get("dimension"); + if (dimension.isTextual() && "application".equals(dimension.textValue())) { + ObjectNode parent = (ObjectNode) parentNode; + parent.remove("dimension"); + parent.put("dimension", "instance"); + } + } } catch (JsonProcessingException e) { throw new FlagValidationException("Invalid JSON: " + e.getMessage()); } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java index 53a3f431de7..0754a5ed49f 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/AthenzTenant.java @@ -8,6 +8,7 @@ import com.yahoo.vespa.hosted.controller.api.identifiers.PropertyId; import com.yahoo.vespa.hosted.controller.api.integration.organization.Contact; import java.time.Instant; +import java.util.List; import java.util.Objects; import java.util.Optional; @@ -27,8 +28,9 @@ public class AthenzTenant extends Tenant { * Use {@link #create(TenantName, AthenzDomain, Property, Optional, Instant)}. * */ public AthenzTenant(TenantName name, AthenzDomain domain, Property property, Optional<PropertyId> propertyId, - Optional<Contact> contact, Instant createdAt, LastLoginInfo lastLoginInfo, Instant tenantRolesLastMaintained) { - super(name, createdAt, lastLoginInfo, contact, tenantRolesLastMaintained); + Optional<Contact> contact, Instant createdAt, LastLoginInfo lastLoginInfo, Instant tenantRolesLastMaintained, + List<CloudAccountInfo> cloudAccounts) { + super(name, createdAt, lastLoginInfo, contact, tenantRolesLastMaintained, cloudAccounts); this.domain = Objects.requireNonNull(domain, "domain must be non-null"); this.property = Objects.requireNonNull(property, "property must be non-null"); this.propertyId = Objects.requireNonNull(propertyId, "propertyId must be non-null"); @@ -62,7 +64,7 @@ public class AthenzTenant extends Tenant { /** Create a new Athenz tenant */ public static AthenzTenant create(TenantName name, AthenzDomain domain, Property property, Optional<PropertyId> propertyId, Instant createdAt) { - return new AthenzTenant(requireName(name), domain, property, propertyId, Optional.empty(), createdAt, LastLoginInfo.EMPTY, Instant.EPOCH); + return new AthenzTenant(requireName(name), domain, property, propertyId, Optional.empty(), createdAt, LastLoginInfo.EMPTY, Instant.EPOCH, List.of()); } @Override diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/CloudAccountInfo.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/CloudAccountInfo.java new file mode 100644 index 00000000000..430f5770165 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/CloudAccountInfo.java @@ -0,0 +1,19 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.tenant; + +import com.yahoo.component.Version; +import com.yahoo.config.provision.CloudAccount; + +import java.util.Objects; + +/** + * @author freva + */ +public record CloudAccountInfo(CloudAccount cloudAccount, Version templateVersion) { + + public CloudAccountInfo { + Objects.requireNonNull(cloudAccount, "cloudAccount must be non-null"); + Objects.requireNonNull(templateVersion, "templateVersion must be non-null"); + } + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/CloudTenant.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/CloudTenant.java index 4d7aee7b604..173d3e1950e 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/CloudTenant.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/CloudTenant.java @@ -34,8 +34,8 @@ public class CloudTenant extends Tenant { BiMap<PublicKey, SimplePrincipal> developerKeys, TenantInfo info, List<TenantSecretStore> tenantSecretStores, ArchiveAccess archiveAccess, Optional<Instant> invalidateUserSessionsBefore, Instant tenantRoleLastMaintained, - Optional<BillingReference> billingReference) { - super(name, createdAt, lastLoginInfo, Optional.empty(), tenantRoleLastMaintained); + List<CloudAccountInfo> cloudAccounts, Optional<BillingReference> billingReference) { + super(name, createdAt, lastLoginInfo, Optional.empty(), tenantRoleLastMaintained, cloudAccounts); this.creator = creator; this.developerKeys = developerKeys; this.info = Objects.requireNonNull(info); @@ -51,7 +51,8 @@ public class CloudTenant extends Tenant { createdAt, LastLoginInfo.EMPTY, Optional.ofNullable(creator).map(SimplePrincipal::of), - ImmutableBiMap.of(), TenantInfo.empty(), List.of(), new ArchiveAccess(), Optional.empty(), Instant.EPOCH, Optional.empty()); + ImmutableBiMap.of(), TenantInfo.empty(), List.of(), new ArchiveAccess(), Optional.empty(), + Instant.EPOCH, List.of(), Optional.empty()); } /** The user that created the tenant */ diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/DeletedTenant.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/DeletedTenant.java index b58fdf81278..30ce5d5a3b2 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/DeletedTenant.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/DeletedTenant.java @@ -4,6 +4,7 @@ package com.yahoo.vespa.hosted.controller.tenant; import com.yahoo.config.provision.TenantName; import java.time.Instant; +import java.util.List; import java.util.Objects; import java.util.Optional; @@ -17,7 +18,7 @@ public class DeletedTenant extends Tenant { private final Instant deletedAt; public DeletedTenant(TenantName name, Instant createdAt, Instant deletedAt) { - super(name, createdAt, LastLoginInfo.EMPTY, Optional.empty(), Instant.EPOCH); + super(name, createdAt, LastLoginInfo.EMPTY, Optional.empty(), Instant.EPOCH, List.of()); this.deletedAt = Objects.requireNonNull(deletedAt, "deletedAt must be non-null"); } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java index a4500991bf2..8b1c6b3ebde 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/tenant/Tenant.java @@ -5,6 +5,7 @@ import com.yahoo.config.provision.TenantName; import com.yahoo.vespa.hosted.controller.api.integration.organization.Contact; import java.time.Instant; +import java.util.List; import java.util.Objects; import java.util.Optional; @@ -20,13 +21,15 @@ public abstract class Tenant { private final LastLoginInfo lastLoginInfo; private final Optional<Contact> contact; private final Instant tenantRolesLastMaintained; + private final List<CloudAccountInfo> cloudAccounts; - Tenant(TenantName name, Instant createdAt, LastLoginInfo lastLoginInfo, Optional<Contact> contact, Instant tenantRolesLastMaintained) { + Tenant(TenantName name, Instant createdAt, LastLoginInfo lastLoginInfo, Optional<Contact> contact, Instant tenantRolesLastMaintained, List<CloudAccountInfo> cloudAccounts) { this.name = name; this.createdAt = createdAt; this.lastLoginInfo = lastLoginInfo; this.contact = contact; this.tenantRolesLastMaintained = tenantRolesLastMaintained; + this.cloudAccounts = cloudAccounts; } /** Name of this tenant */ @@ -53,6 +56,10 @@ public abstract class Tenant { return tenantRolesLastMaintained; } + public List<CloudAccountInfo> cloudAccounts() { + return cloudAccounts; + } + public abstract Type type(); @Override diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java index aba6cfbfeac..373f8ba9de2 100644 --- a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java +++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java @@ -245,7 +245,7 @@ public class SystemFlagsDataArchiveTest { "conditions": [ { "type": "whitelist", - "dimension": "application", + "dimension": "instance", "values": [ "f:o:o" ] } ], @@ -287,7 +287,7 @@ public class SystemFlagsDataArchiveTest { { "comment": "bar", "type": "whitelist", - "dimension": "application", + "dimension": "instance", "values": [ "f:o:o" ] } ], @@ -308,7 +308,7 @@ public class SystemFlagsDataArchiveTest { @Test void normalize_json_succeed_on_valid_values() { addFile(Condition.Type.WHITELIST, "application", "a:b:c"); -// addFile(Condition.Type.WHITELIST, "instance", "a:b:c"); + addFile(Condition.Type.WHITELIST, "instance", "a:b:c"); addFile(Condition.Type.WHITELIST, "cloud", "yahoo"); addFile(Condition.Type.WHITELIST, "cloud", "aws"); addFile(Condition.Type.WHITELIST, "cloud", "gcp"); @@ -362,7 +362,7 @@ public class SystemFlagsDataArchiveTest { @Test void normalize_json_fail_on_invalid_values() { - failAddFile(Condition.Type.WHITELIST, "application", "a.b.c", "In file flags/temporary/foo/default.json: Invalid application 'a.b.c' in whitelist condition: Application ids must be on the form tenant:application:instance, but was a.b.c"); + failAddFile(Condition.Type.WHITELIST, "application", "a.b.c", "In file flags/temporary/foo/default.json: Invalid instance 'a.b.c' in whitelist condition: Application ids must be on the form tenant:application:instance, but was a.b.c"); failAddFile(Condition.Type.WHITELIST, "cloud", "foo", "In file flags/temporary/foo/default.json: Unknown cloud: foo"); // cluster-id: any String is valid failAddFile(Condition.Type.WHITELIST, "cluster-type", "foo", "In file flags/temporary/foo/default.json: Invalid cluster-type 'foo' in whitelist condition: Illegal cluster type 'foo'"); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedTenant.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedTenant.java index 6ec732a3815..7d19acfce80 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedTenant.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/LockedTenant.java @@ -16,6 +16,7 @@ import com.yahoo.vespa.hosted.controller.api.role.SimplePrincipal; import com.yahoo.vespa.hosted.controller.tenant.ArchiveAccess; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; import com.yahoo.vespa.hosted.controller.tenant.BillingReference; +import com.yahoo.vespa.hosted.controller.tenant.CloudAccountInfo; import com.yahoo.vespa.hosted.controller.tenant.CloudTenant; import com.yahoo.vespa.hosted.controller.tenant.DeletedTenant; import com.yahoo.vespa.hosted.controller.tenant.LastLoginInfo; @@ -43,12 +44,14 @@ public abstract class LockedTenant { final Instant createdAt; final LastLoginInfo lastLoginInfo; final Instant tenantRolesLastMaintained; + final List<CloudAccountInfo> cloudAccounts; - private LockedTenant(TenantName name, Instant createdAt, LastLoginInfo lastLoginInfo, Instant tenantRolesLastMaintained) { + private LockedTenant(TenantName name, Instant createdAt, LastLoginInfo lastLoginInfo, Instant tenantRolesLastMaintained, List<CloudAccountInfo> cloudAccounts) { this.name = requireNonNull(name); this.createdAt = requireNonNull(createdAt); this.lastLoginInfo = requireNonNull(lastLoginInfo); this.tenantRolesLastMaintained = requireNonNull(tenantRolesLastMaintained); + this.cloudAccounts = requireNonNull(cloudAccounts); } static LockedTenant of(Tenant tenant, Mutex lock) { @@ -66,6 +69,8 @@ public abstract class LockedTenant { public abstract LockedTenant with(Instant tenantRolesLastMaintained); + public abstract LockedTenant withCloudAccounts(List<CloudAccountInfo> cloudAccounts); + public Deleted deleted(Instant deletedAt) { return new Deleted(new DeletedTenant(name, createdAt, deletedAt)); } @@ -85,8 +90,8 @@ public abstract class LockedTenant { private final Optional<Contact> contact; private Athenz(TenantName name, AthenzDomain domain, Property property, Optional<PropertyId> propertyId, - Optional<Contact> contact, Instant createdAt, LastLoginInfo lastLoginInfo, Instant tenantRolesLastMaintained) { - super(name, createdAt, lastLoginInfo, tenantRolesLastMaintained); + Optional<Contact> contact, Instant createdAt, LastLoginInfo lastLoginInfo, Instant tenantRolesLastMaintained, List<CloudAccountInfo> cloudAccounts) { + super(name, createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccounts); this.domain = domain; this.property = property; this.propertyId = propertyId; @@ -94,38 +99,43 @@ public abstract class LockedTenant { } private Athenz(AthenzTenant tenant) { - this(tenant.name(), tenant.domain(), tenant.property(), tenant.propertyId(), tenant.contact(), tenant.createdAt(), tenant.lastLoginInfo(), tenant.tenantRolesLastMaintained()); + this(tenant.name(), tenant.domain(), tenant.property(), tenant.propertyId(), tenant.contact(), tenant.createdAt(), tenant.lastLoginInfo(), tenant.tenantRolesLastMaintained(), tenant.cloudAccounts()); } @Override public AthenzTenant get() { - return new AthenzTenant(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained); + return new AthenzTenant(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccounts); } public Athenz with(AthenzDomain domain) { - return new Athenz(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained); + return new Athenz(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccounts); } public Athenz with(Property property) { - return new Athenz(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained); + return new Athenz(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccounts); } public Athenz with(PropertyId propertyId) { - return new Athenz(name, domain, property, Optional.of(propertyId), contact, createdAt, lastLoginInfo, tenantRolesLastMaintained); + return new Athenz(name, domain, property, Optional.of(propertyId), contact, createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccounts); } public Athenz with(Contact contact) { - return new Athenz(name, domain, property, propertyId, Optional.of(contact), createdAt, lastLoginInfo, tenantRolesLastMaintained); + return new Athenz(name, domain, property, propertyId, Optional.of(contact), createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccounts); } @Override public LockedTenant with(LastLoginInfo lastLoginInfo) { - return new Athenz(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained); + return new Athenz(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccounts); } @Override public LockedTenant with(Instant tenantRolesLastMaintained) { - return new Athenz(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained); + return new Athenz(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccounts); + } + + @Override + public LockedTenant withCloudAccounts(List<CloudAccountInfo> cloudAccounts) { + return new Athenz(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccounts); } } @@ -146,8 +156,8 @@ public abstract class LockedTenant { BiMap<PublicKey, SimplePrincipal> developerKeys, TenantInfo info, List<TenantSecretStore> tenantSecretStores, ArchiveAccess archiveAccess, Optional<Instant> invalidateUserSessionsBefore, Instant tenantRolesLastMaintained, - Optional<BillingReference> billingReference) { - super(name, createdAt, lastLoginInfo, tenantRolesLastMaintained); + List<CloudAccountInfo> cloudAccounts, Optional<BillingReference> billingReference) { + super(name, createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccounts); this.developerKeys = ImmutableBiMap.copyOf(developerKeys); this.creator = creator; this.info = info; @@ -158,12 +168,12 @@ public abstract class LockedTenant { } private Cloud(CloudTenant tenant) { - this(tenant.name(), tenant.createdAt(), tenant.lastLoginInfo(), tenant.creator(), tenant.developerKeys(), tenant.info(), tenant.tenantSecretStores(), tenant.archiveAccess(), tenant.invalidateUserSessionsBefore(), tenant.tenantRolesLastMaintained(), tenant.billingReference()); + this(tenant.name(), tenant.createdAt(), tenant.lastLoginInfo(), tenant.creator(), tenant.developerKeys(), tenant.info(), tenant.tenantSecretStores(), tenant.archiveAccess(), tenant.invalidateUserSessionsBefore(), tenant.tenantRolesLastMaintained(), tenant.cloudAccounts(), tenant.billingReference()); } @Override public CloudTenant get() { - return new CloudTenant(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, billingReference); + return new CloudTenant(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccounts, billingReference); } public Cloud withDeveloperKey(PublicKey key, Principal principal) { @@ -174,51 +184,56 @@ public abstract class LockedTenant { if (keys.inverse().containsKey(simplePrincipal)) throw new IllegalArgumentException(principal + " is already associated with key " + KeyUtils.toPem(keys.inverse().get(simplePrincipal))); keys.put(key, simplePrincipal); - return new Cloud(name, createdAt, lastLoginInfo, creator, keys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, billingReference); + return new Cloud(name, createdAt, lastLoginInfo, creator, keys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccounts, billingReference); } public Cloud withoutDeveloperKey(PublicKey key) { BiMap<PublicKey, SimplePrincipal> keys = HashBiMap.create(developerKeys); keys.remove(key); - return new Cloud(name, createdAt, lastLoginInfo, creator, keys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, billingReference); + return new Cloud(name, createdAt, lastLoginInfo, creator, keys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccounts, billingReference); } public Cloud withInfo(TenantInfo newInfo) { - return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, newInfo, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, billingReference); + return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, newInfo, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccounts, billingReference); } @Override public LockedTenant with(LastLoginInfo lastLoginInfo) { - return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, billingReference); + return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccounts, billingReference); } public Cloud withSecretStore(TenantSecretStore tenantSecretStore) { ArrayList<TenantSecretStore> secretStores = new ArrayList<>(tenantSecretStores); secretStores.add(tenantSecretStore); - return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, secretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, billingReference); + return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, secretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccounts, billingReference); } public Cloud withoutSecretStore(TenantSecretStore tenantSecretStore) { ArrayList<TenantSecretStore> secretStores = new ArrayList<>(tenantSecretStores); secretStores.remove(tenantSecretStore); - return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, secretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, billingReference); + return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, secretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccounts, billingReference); } public Cloud withArchiveAccess(ArchiveAccess archiveAccess) { - return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore,tenantRolesLastMaintained, billingReference); + return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore,tenantRolesLastMaintained, cloudAccounts, billingReference); } public Cloud withInvalidateUserSessionsBefore(Instant invalidateUserSessionsBefore) { - return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, Optional.of(invalidateUserSessionsBefore), tenantRolesLastMaintained, billingReference); + return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, Optional.of(invalidateUserSessionsBefore), tenantRolesLastMaintained, cloudAccounts, billingReference); } @Override public LockedTenant with(Instant tenantRolesLastMaintained) { - return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, billingReference); + return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccounts, billingReference); + } + + @Override + public LockedTenant withCloudAccounts(List<CloudAccountInfo> cloudAccounts) { + return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccounts, billingReference); } public Cloud with(BillingReference billingReference) { - return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, Optional.of(billingReference)); + return new Cloud(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccounts, Optional.of(billingReference)); } } @@ -229,7 +244,7 @@ public abstract class LockedTenant { private final Instant deletedAt; private Deleted(DeletedTenant tenant) { - super(tenant.name(), tenant.createdAt(), tenant.lastLoginInfo(), Instant.EPOCH); + super(tenant.name(), tenant.createdAt(), tenant.lastLoginInfo(), Instant.EPOCH, List.of()); this.deletedAt = tenant.deletedAt(); } @@ -247,6 +262,11 @@ public abstract class LockedTenant { public LockedTenant with(Instant tenantRolesLastMaintained) { return this; } + + @Override + public LockedTenant withCloudAccounts(List<CloudAccountInfo> cloudAccounts) { + return this; + } } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java index 091836a1eea..27cd5e7e576 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java @@ -525,7 +525,8 @@ public class RoutingController { } public boolean generatedEndpointsEnabled(ApplicationId instance) { - return randomizedEndpoints.with(FetchVector.Dimension.INSTANCE_ID, instance.serializedForm()).value(); + return randomizedEndpoints.with(FetchVector.Dimension.INSTANCE_ID, instance.serializedForm()) + .with(FetchVector.Dimension.TENANT_ID, instance.tenant().value()).value(); } private static void requireGeneratedEndpoints(GeneratedEndpointList generatedEndpoints, boolean declared) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java index bf2f2ab90eb..d11540b28dd 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/TenantController.java @@ -12,6 +12,7 @@ import com.yahoo.vespa.hosted.controller.persistence.CuratorDb; import com.yahoo.vespa.hosted.controller.security.AccessControl; import com.yahoo.vespa.hosted.controller.security.Credentials; import com.yahoo.vespa.hosted.controller.security.TenantSpec; +import com.yahoo.vespa.hosted.controller.tenant.CloudAccountInfo; import com.yahoo.vespa.hosted.controller.tenant.DeletedTenant; import com.yahoo.vespa.hosted.controller.tenant.LastLoginInfo; import com.yahoo.vespa.hosted.controller.tenant.Tenant; @@ -165,6 +166,14 @@ public class TenantController { } } + public void updateCloudAccounts(TenantName tenantName, List<CloudAccountInfo> cloudAccounts) { + try (Mutex lock = lock(tenantName)) { + var tenant = require(tenantName); + if (tenant.cloudAccounts().equals(cloudAccounts)) return; // no change + curator.writeTenant(LockedTenant.of(tenant, lock).withCloudAccounts(cloudAccounts).get()); + } + } + /** Deletes the given tenant. */ public void delete(TenantName tenant, Optional<Credentials> credentials, boolean forget) { try (Mutex lock = lock(tenant)) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java index e01da00a27e..d661fa189b9 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java @@ -140,10 +140,11 @@ public class EndpointCertificates { } try (NestedTransaction transaction = new NestedTransaction()) { curator.removeUnassignedCertificate(candidate.get(), transaction); - curator.writeAssignedCertificate(new AssignedCertificate(application, instanceName, candidate.get().certificate()), + EndpointCertificate certificate = candidate.get().certificate().withLastRequested(clock.instant().getEpochSecond()); + curator.writeAssignedCertificate(new AssignedCertificate(application, instanceName, certificate), transaction); transaction.commit(); - return candidate.get().certificate(); + return certificate; } } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainer.java index 70eeb2b9f6c..ed383175cc3 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainer.java @@ -69,7 +69,7 @@ public class CertificatePoolMaintainer extends ControllerMaintainer { // Create metric for available certificates in the pool as a fraction of configured size int poolSize = certPoolSize.value(); long available = certificatePool.stream().filter(c -> c.state() == UnassignedCertificate.State.ready).count(); - metric.set(ControllerMetrics.CERTIFICATE_POOL_AVAILABLE.baseName(), (poolSize > 0 ? (available/poolSize) : 1.0), metric.createContext(Map.of())); + metric.set(ControllerMetrics.CERTIFICATE_POOL_AVAILABLE.baseName(), (poolSize > 0 ? ((double)available/poolSize) : 1.0), metric.createContext(Map.of())); if (certificatePool.size() < poolSize) { provisionRandomizedCertificate(); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CloudAccountVerifier.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CloudAccountVerifier.java new file mode 100644 index 00000000000..f0fc8985bdf --- /dev/null +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CloudAccountVerifier.java @@ -0,0 +1,55 @@ +package com.yahoo.vespa.hosted.controller.maintenance; + +import com.yahoo.config.provision.SystemName; +import com.yahoo.vespa.hosted.controller.Controller; +import com.yahoo.vespa.hosted.controller.tenant.CloudAccountInfo; +import com.yahoo.vespa.hosted.controller.tenant.Tenant; + +import java.time.Duration; +import java.util.List; +import java.util.Set; +import java.util.logging.Logger; + +import static java.util.logging.Level.WARNING; + +/** + * Verifies the cloud accounts that may be used by a given user have applied the enclave template + * and extracts the version of the applied template. + * + * All maintainers that operate on external cloud accounts should use the list on the Tenant instance + * maintained by this class rather than the cloud-accounts feature flag. + * + * The template version can be used to determine if new features can be enabled for the cloud account. + * + * @author freva + */ +public class CloudAccountVerifier extends ControllerMaintainer { + + private static final Logger logger = Logger.getLogger(CloudAccountVerifier.class.getName()); + + CloudAccountVerifier(Controller controller, Duration interval) { + super(controller, interval, null, Set.of(SystemName.PublicCd, SystemName.Public)); + } + + @Override + protected double maintain() { + int attempts = 0, failures = 0; + for (Tenant tenant : controller().tenants().asList()) { + try { + attempts++; + List<CloudAccountInfo> cloudAccountInfos = controller().applications().accountsOf(tenant.name()).stream() + .flatMap(account -> controller().serviceRegistry() + .archiveService() + .getEnclaveTemplateVersion(account) + .map(version -> new CloudAccountInfo(account, version)) + .stream()) + .toList(); + controller().tenants().updateCloudAccounts(tenant.name(), cloudAccountInfos); + } catch (RuntimeException e) { + logger.log(WARNING, "Failed to verify cloud accounts for tenant " + tenant.name(), e); + failures++; + } + } + return asSuccessFactorDeviation(attempts, failures); + } +} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java index 6fae732df0a..3dcd8457da6 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java @@ -85,6 +85,7 @@ public class ControllerMaintenance extends AbstractComponent { maintainers.add(new EnclaveAccessMaintainer(controller, intervals.defaultInterval)); maintainers.add(new CertificatePoolMaintainer(controller, metric, intervals.certificatePoolMaintainer)); maintainers.add(new BillingReportMaintainer(controller, intervals.billingReportMaintainer)); + maintainers.add(new CloudAccountVerifier(controller, intervals.cloudAccountVerifier)); } public Upgrader upgrader() { return upgrader; } @@ -147,6 +148,7 @@ public class ControllerMaintenance extends AbstractComponent { private final Duration meteringMonitorMaintainer; private final Duration certificatePoolMaintainer; private final Duration billingReportMaintainer; + private final Duration cloudAccountVerifier; public Intervals(SystemName system) { this.system = Objects.requireNonNull(system); @@ -184,6 +186,7 @@ public class ControllerMaintenance extends AbstractComponent { this.meteringMonitorMaintainer = duration(30, MINUTES); this.certificatePoolMaintainer = duration(15, MINUTES); this.billingReportMaintainer = duration(60, MINUTES); + this.cloudAccountVerifier = duration(10, MINUTES); } private Duration duration(long amount, TemporalUnit unit) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/EnclaveAccessMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/EnclaveAccessMaintainer.java index 5218da91c46..6c1c4daa1bb 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/EnclaveAccessMaintainer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/EnclaveAccessMaintainer.java @@ -33,7 +33,7 @@ public class EnclaveAccessMaintainer extends ControllerMaintainer { private Set<CloudAccount> externalAccounts() { Set<CloudAccount> accounts = new HashSet<>(); for (Tenant tenant : controller().tenants().asList()) - accounts.addAll(controller().applications().accountsOf(tenant.name())); + tenant.cloudAccounts().forEach(accountInfo -> accounts.add(accountInfo.cloudAccount())); return accounts; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java index a25aa9797ba..dc9c4650191 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/CuratorDb.java @@ -7,6 +7,7 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.concurrent.UncheckedTimeoutException; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.config.provision.ClusterSpec.Id; import com.yahoo.config.provision.HostName; import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.TenantName; @@ -602,7 +603,7 @@ public class CuratorDb { public List<DnsChallenge> readDnsChallenges(DeploymentId id) { return curator.getChildren(dnsChallengePath(id)).stream() - .map(cluster -> readDnsChallenge(new ClusterId(id, ClusterSpec.Id.from(cluster)))) + .map(cluster -> readDnsChallenge(new ClusterId(id, Id.from(cluster)))) .toList(); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java index e3d61c81667..760fb9b0366 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializer.java @@ -3,6 +3,8 @@ package com.yahoo.vespa.hosted.controller.persistence; import com.google.common.collect.BiMap; import com.google.common.collect.ImmutableBiMap; +import com.yahoo.component.Version; +import com.yahoo.config.provision.CloudAccount; import com.yahoo.config.provision.TenantName; import com.yahoo.security.KeyUtils; import com.yahoo.slime.ArrayTraverser; @@ -20,6 +22,7 @@ import com.yahoo.vespa.hosted.controller.api.role.SimplePrincipal; import com.yahoo.vespa.hosted.controller.tenant.ArchiveAccess; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; import com.yahoo.vespa.hosted.controller.tenant.BillingReference; +import com.yahoo.vespa.hosted.controller.tenant.CloudAccountInfo; import com.yahoo.vespa.hosted.controller.tenant.CloudTenant; import com.yahoo.vespa.hosted.controller.tenant.DeletedTenant; import com.yahoo.vespa.hosted.controller.tenant.Email; @@ -85,6 +88,9 @@ public class TenantSerializer { private static final String invalidateUserSessionsBeforeField = "invalidateUserSessionsBefore"; private static final String tenantRolesLastMaintainedField = "tenantRolesLastMaintained"; private static final String billingReferenceField = "billingReference"; + private static final String cloudAccountsField = "cloudAccounts"; + private static final String accountField = "account"; + private static final String templateVersionField = "templateVersion"; private static final String awsIdField = "awsId"; private static final String roleField = "role"; @@ -97,6 +103,7 @@ public class TenantSerializer { tenantObject.setLong(createdAtField, tenant.createdAt().toEpochMilli()); toSlime(tenant.lastLoginInfo(), tenantObject.setObject(lastLoginInfoField)); tenantObject.setLong(tenantRolesLastMaintainedField, tenant.tenantRolesLastMaintained().toEpochMilli()); + cloudAccountsToSlime(tenant.cloudAccounts(), tenantObject.setArray(cloudAccountsField)); switch (tenant.type()) { case athenz: toSlime((AthenzTenant) tenant, tenantObject); break; @@ -162,6 +169,14 @@ public class TenantSerializer { } } + private void cloudAccountsToSlime(List<CloudAccountInfo> cloudAccounts, Cursor cloudAccountsObject) { + cloudAccounts.forEach(cloudAccountInfo -> { + Cursor object = cloudAccountsObject.addObject(); + object.setString(accountField, cloudAccountInfo.cloudAccount().account()); + object.setString(templateVersionField, cloudAccountInfo.templateVersion().toFullString()); + }); + } + public Tenant tenantFrom(Slime slime) { Inspector tenantObject = slime.get(); Tenant.Type type = typeOf(tenantObject.field(typeField).asString()); @@ -183,7 +198,8 @@ public class TenantSerializer { Instant createdAt = SlimeUtils.instant(tenantObject.field(createdAtField)); LastLoginInfo lastLoginInfo = lastLoginInfoFromSlime(tenantObject.field(lastLoginInfoField)); Instant tenantRolesLastMaintained = SlimeUtils.instant(tenantObject.field(tenantRolesLastMaintainedField)); - return new AthenzTenant(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained); + List<CloudAccountInfo> cloudAccountInfos = cloudAccountsFromSlime(tenantObject.field(cloudAccountsField)); + return new AthenzTenant(name, domain, property, propertyId, contact, createdAt, lastLoginInfo, tenantRolesLastMaintained, cloudAccountInfos); } private CloudTenant cloudTenantFrom(Inspector tenantObject) { @@ -197,8 +213,9 @@ public class TenantSerializer { ArchiveAccess archiveAccess = archiveAccessFromSlime(tenantObject); Optional<Instant> invalidateUserSessionsBefore = SlimeUtils.optionalInstant(tenantObject.field(invalidateUserSessionsBeforeField)); Instant tenantRolesLastMaintained = SlimeUtils.instant(tenantObject.field(tenantRolesLastMaintainedField)); + List<CloudAccountInfo> cloudAccountInfos = cloudAccountsFromSlime(tenantObject.field(cloudAccountsField)); Optional<BillingReference> billingReference = billingReferenceFrom(tenantObject.field(billingReferenceField)); - return new CloudTenant(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, billingReference); + return new CloudTenant(name, createdAt, lastLoginInfo, creator, developerKeys, info, tenantSecretStores, archiveAccess, invalidateUserSessionsBefore, tenantRolesLastMaintained, cloudAccountInfos, billingReference); } private DeletedTenant deletedTenantFrom(Inspector tenantObject) { @@ -284,6 +301,14 @@ public class TenantSerializer { return new LastLoginInfo(lastLoginByUserLevel); } + private List<CloudAccountInfo> cloudAccountsFromSlime(Inspector cloudAccountsObject) { + return SlimeUtils.entriesStream(cloudAccountsObject) + .map(inspector -> new CloudAccountInfo( + CloudAccount.from(inspector.field(accountField).asString()), + Version.fromString(inspector.field(templateVersionField).asString()))) + .toList(); + } + void toSlime(TenantInfo info, Cursor parentCursor) { if (info.isEmpty()) return; Cursor infoCursor = parentCursor.setObject("info"); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index 46c81fc073f..16d862a66ef 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -2915,6 +2915,15 @@ public class ApplicationApiHandler extends AuditLoggingRequestHandler { } } tenantMetaDataToSlime(tenant, applications, object.setObject("metaData")); + + if (!tenant.cloudAccounts().isEmpty()) { + Cursor cloudAccounts = object.setArray("cloudAccounts"); + tenant.cloudAccounts().forEach(accountInfo -> { + Cursor accountObject = cloudAccounts.addObject(); + accountObject.setString("cloudAccount", accountInfo.cloudAccount().value()); + accountObject.setString("templateVersion", accountInfo.templateVersion().toFullString()); + }); + } } private void toSlime(ArchiveAccess archiveAccess, Cursor object) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/RoutingPolicies.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/RoutingPolicies.java index de25161c461..a608da1d6a9 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/RoutingPolicies.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/RoutingPolicies.java @@ -416,7 +416,7 @@ public class RoutingPolicies { private void setPrivateDns(Endpoint endpoint, LoadBalancer loadBalancer, DeploymentId deploymentId) { if (loadBalancer.service().isEmpty()) return; - // TODO(mpolden): Why is this done? Consider creating private DNS for all auth methods + // TODO(mpolden): Model one service for each endpoint (type), to allow private endpoints with tokens. boolean skipBasedOnAuthMethod = switch (endpoint.authMethod()) { case token -> true; case mtls -> false; @@ -436,10 +436,18 @@ public class RoutingPolicies { }); } + /** Deletes all DNS challenges, and corresponding TXT records, for the given deployment. */ + public void removeDnsChallenges(DeploymentId deploymentId) { + try (Mutex lock = db.lockNameServiceQueue()) { + db.readDnsChallenges(deploymentId).forEach(this::removeDnsChallenge); + } + } + /** Returns true iff. the given deployment has no incomplete DNS challenges, or throws (and cleans up) on errors. */ public boolean processDnsChallenges(DeploymentId deploymentId) { try (Mutex lock = db.lockNameServiceQueue()) { List<DnsChallenge> challenges = new ArrayList<>(db.readDnsChallenges(deploymentId)); + challenges.removeIf(challenge -> challenge.state() == ChallengeState.done); Set<RecordName> pendingRequests = controller.curator().readNameServiceQueue().requests().stream() .map(NameServiceRequest::name) .collect(Collectors.toSet()); @@ -450,14 +458,8 @@ public class RoutingPolicies { challenge = challenge.withState(ChallengeState.ready); } ChallengeState state = controller.serviceRegistry().vpcEndpointService().process(challenge); - if (state == ChallengeState.done) { - removeDnsChallenge(challenge); - return true; - } - else { - db.writeDnsChallenge(challenge.withState(state)); - return false; - } + db.writeDnsChallenge(challenge.withState(state)); + return state == ChallengeState.done; }); return challenges.isEmpty(); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/DeploymentRoutingContext.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/DeploymentRoutingContext.java index df0226176a2..99f60735f6e 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/DeploymentRoutingContext.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/routing/context/DeploymentRoutingContext.java @@ -57,6 +57,7 @@ public abstract class DeploymentRoutingContext implements RoutingContext { /** Deactivate routing configuration for the deployment in this context, using given deployment spec */ public final void deactivate(DeploymentSpec deploymentSpec) { routing.policies().refresh(deployment, deploymentSpec, EndpointList.EMPTY); + routing.policies().removeDnsChallenges(deployment); } /** Routing method of this context */ diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificatesTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificatesTest.java index 1cb43453918..a6d3b435dcb 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificatesTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificatesTest.java @@ -306,6 +306,9 @@ public class EndpointCertificatesTest { fail("Expected exception as certificate is not ready"); } catch (IllegalArgumentException ignored) {} + // Advance clock to verify last requested time + clock.advance(Duration.ofDays(3)); + // Certificate is assigned from pool instead. The previously assigned certificate will eventually be cleaned up // by EndpointCertificateMaintainer { // prod @@ -315,6 +318,7 @@ public class EndpointCertificatesTest { assertEquals(certId, cert.get().randomizedId().get()); assertEquals(certId, tester.curator().readAssignedCertificate(TenantAndApplicationId.from(instance.id()), Optional.empty()).get().certificate().randomizedId().get(), "Certificate is assigned at application-level"); assertTrue(tester.controller().curator().readUnassignedCertificate(certId).isEmpty(), "Certificate is removed from pool"); + assertEquals(clock.instant().getEpochSecond(), cert.get().lastRequested()); } { // dev @@ -325,6 +329,7 @@ public class EndpointCertificatesTest { assertEquals(certId, cert.get().randomizedId().get()); assertEquals(certId, tester.curator().readAssignedCertificate(instance.id()).get().certificate().randomizedId().get(), "Certificate is assigned at instance-level"); assertTrue(tester.controller().curator().readUnassignedCertificate(certId).isEmpty(), "Certificate is removed from pool"); + assertEquals(clock.instant().getEpochSecond(), cert.get().lastRequested()); } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainerTest.java index 88c5ae9ff06..4257261b09b 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainerTest.java @@ -53,12 +53,4 @@ public class CertificatePoolMaintainerTest { assertEquals(0.0, maintainer.maintain(), 0.0000001); assertEquals(n, tester.curator().readUnassignedCertificates().size()); } - - void old_unassigned_certs_are_refreshed() { - tester.flagSource().withIntFlag(PermanentFlags.CERT_POOL_SIZE.id(), 1); - assertNumCerts(1); - EndpointCertificateProviderMock endpointCertificateProvider = (EndpointCertificateProviderMock) tester.controller().serviceRegistry().endpointCertificateProvider(); - var request = endpointCertificateProvider.listCertificates().get(0); - - } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/EnclaveAccessMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/EnclaveAccessMaintainerTest.java index 5bfac2866ce..1e1079a3314 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/EnclaveAccessMaintainerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/EnclaveAccessMaintainerTest.java @@ -21,17 +21,20 @@ class EnclaveAccessMaintainerTest { void test() { ControllerTester tester = new ControllerTester(); MockEnclaveAccessService amis = tester.serviceRegistry().enclaveAccessService(); - EnclaveAccessMaintainer sharer = new EnclaveAccessMaintainer(tester.controller(), Duration.ofMinutes(1)); + EnclaveAccessMaintainer sharer = new EnclaveAccessMaintainer(tester.controller(), Duration.ofHours(1)); + CloudAccountVerifier accountVerifier = new CloudAccountVerifier(tester.controller(), Duration.ofHours(1)); assertEquals(Set.of(), amis.currentAccounts()); assertEquals(1, sharer.maintain()); assertEquals(Set.of(), amis.currentAccounts()); tester.createTenant("tanten"); + accountVerifier.maintain(); assertEquals(1, sharer.maintain()); assertEquals(Set.of(), amis.currentAccounts()); tester.flagSource().withListFlag(PermanentFlags.CLOUD_ACCOUNTS.id(), List.of("123123123123", "321321321321"), String.class); + accountVerifier.maintain(); assertEquals(1, sharer.maintain()); assertEquals(Set.of(CloudAccount.from("aws:123123123123"), CloudAccount.from("aws:321321321321")), amis.currentAccounts()); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/notification/NotificationsDbTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/notification/NotificationsDbTest.java index bdbbc4b293f..228a61cebc6 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/notification/NotificationsDbTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/notification/NotificationsDbTest.java @@ -10,7 +10,6 @@ import com.yahoo.config.provision.zone.ZoneId; import com.yahoo.path.Path; import com.yahoo.test.ManualClock; import com.yahoo.vespa.flags.FlagSource; -import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.flags.InMemoryFlagSource; import com.yahoo.vespa.flags.PermanentFlags; import com.yahoo.vespa.hosted.controller.api.application.v4.model.ClusterMetrics; @@ -69,6 +68,7 @@ public class NotificationsDbTest { new ArchiveAccess(), Optional.empty(), Instant.EPOCH, + List.of(), Optional.empty()); private static final List<Notification> notifications = List.of( notification(1001, Type.deployment, Level.error, NotificationSource.from(tenant), "tenant msg"), diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/notification/NotifierTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/notification/NotifierTest.java index ef1d9cd92e3..15524e2748c 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/notification/NotifierTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/notification/NotifierTest.java @@ -6,7 +6,6 @@ import com.yahoo.config.provision.ApplicationName; import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.TenantName; -import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.flags.InMemoryFlagSource; import com.yahoo.vespa.flags.PermanentFlags; import com.yahoo.vespa.hosted.controller.api.integration.stubs.MockMailer; @@ -47,6 +46,7 @@ public class NotifierTest { new ArchiveAccess(), Optional.empty(), Instant.EPOCH, + List.of(), Optional.empty()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java index dd7afa314ea..4369675ba3e 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/persistence/TenantSerializerTest.java @@ -2,6 +2,8 @@ package com.yahoo.vespa.hosted.controller.persistence;// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. import com.google.common.collect.ImmutableBiMap; +import com.yahoo.component.Version; +import com.yahoo.config.provision.CloudAccount; import com.yahoo.config.provision.TenantName; import com.yahoo.security.KeyUtils; import com.yahoo.slime.Cursor; @@ -16,6 +18,7 @@ import com.yahoo.vespa.hosted.controller.api.role.SimplePrincipal; import com.yahoo.vespa.hosted.controller.tenant.ArchiveAccess; import com.yahoo.vespa.hosted.controller.tenant.AthenzTenant; import com.yahoo.vespa.hosted.controller.tenant.BillingReference; +import com.yahoo.vespa.hosted.controller.tenant.CloudAccountInfo; import com.yahoo.vespa.hosted.controller.tenant.CloudTenant; import com.yahoo.vespa.hosted.controller.tenant.DeletedTenant; import com.yahoo.vespa.hosted.controller.tenant.Email; @@ -91,7 +94,8 @@ public class TenantSerializerTest { Optional.of(contact()), Instant.EPOCH, lastLoginInfo(321L, 654L, 987L), - Instant.EPOCH); + Instant.EPOCH, + List.of()); AthenzTenant serialized = (AthenzTenant) serializer.tenantFrom(serializer.toSlime(tenant)); assertEquals(tenant.contact(), serialized.contact()); } @@ -109,6 +113,7 @@ public class TenantSerializerTest { new ArchiveAccess(), Optional.empty(), Instant.EPOCH, + List.of(), Optional.empty()); CloudTenant serialized = (CloudTenant) serializer.tenantFrom(serializer.toSlime(tenant)); assertEquals(tenant.name(), serialized.name()); @@ -133,6 +138,7 @@ public class TenantSerializerTest { new ArchiveAccess().withAWSRole("arn:aws:iam::123456789012:role/my-role"), Optional.of(Instant.ofEpochMilli(1234567)), Instant.EPOCH, + List.of(), Optional.empty()); CloudTenant serialized = (CloudTenant) serializer.tenantFrom(serializer.toSlime(tenant)); assertEquals(tenant.info(), serialized.info()); @@ -185,6 +191,8 @@ public class TenantSerializerTest { new ArchiveAccess().withAWSRole("arn:aws:iam::123456789012:role/my-role").withGCPMember("user:foo@example.com"), Optional.empty(), Instant.EPOCH, + List.of(new CloudAccountInfo(CloudAccount.from("aws:123456789012"), Version.fromString("1.2.3")), + new CloudAccountInfo(CloudAccount.from("gcp:my-project"), Version.fromString("3.2.1"))), Optional.empty()); CloudTenant serialized = (CloudTenant) serializer.tenantFrom(serializer.toSlime(tenant)); assertEquals(serialized.archiveAccess().awsRole().get(), "arn:aws:iam::123456789012:role/my-role"); @@ -263,7 +271,8 @@ public class TenantSerializerTest { Optional.of(contact()), Instant.EPOCH, lastLoginInfo(321L, 654L, 987L), - Instant.ofEpochMilli(1_000_000)); + Instant.ofEpochMilli(1_000_000), + List.of()); assertEquals(tenant, serializer.tenantFrom(serializer.toSlime(tenant))); } @@ -281,6 +290,7 @@ public class TenantSerializerTest { new ArchiveAccess().withAWSRole("arn:aws:iam::123456789012:role/my-role").withGCPMember("user:foo@example.com"), Optional.empty(), Instant.EPOCH, + List.of(), Optional.of(reference)); var slime = serializer.toSlime(tenant); var deserialized = serializer.tenantFrom(slime); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiCloudTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiCloudTest.java index 4eb6e080737..3b74fea2b9c 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiCloudTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiCloudTest.java @@ -5,6 +5,7 @@ import ai.vespa.hosted.api.MultiPartStreamer; import com.yahoo.component.Version; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ApplicationName; +import com.yahoo.config.provision.CloudAccount; import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.TenantName; import com.yahoo.restapi.RestApiException; @@ -26,12 +27,14 @@ import com.yahoo.vespa.hosted.controller.restapi.ControllerContainerCloudTest; import com.yahoo.vespa.hosted.controller.security.Auth0Credentials; import com.yahoo.vespa.hosted.controller.security.CloudTenantSpec; import com.yahoo.vespa.hosted.controller.security.Credentials; +import com.yahoo.vespa.hosted.controller.tenant.CloudAccountInfo; import com.yahoo.vespa.hosted.controller.tenant.CloudTenant; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.io.File; import java.util.Collections; +import java.util.List; import java.util.Optional; import java.util.Set; @@ -369,10 +372,10 @@ public class ApplicationApiCloudTest extends ControllerContainerCloudTest { new DeploymentTester(wrapped).newDeploymentContext(ApplicationId.from(tenantName, applicationName, InstanceName.defaultName())) .submit() .deploy(); + tester.controller().tenants().updateCloudAccounts(tenantName, List.of(new CloudAccountInfo(CloudAccount.from("aws:123456789012"), new Version(1, 2, 4)))); tester.assertResponse(request("/application/v4/tenant/scoober", GET).roles(Role.reader(tenantName)), - (response) -> assertFalse(response.getBodyAsString().contains("archiveAccessRole")), - 200); + new File("tenant-cloud.json")); tester.assertResponse(request("/application/v4/tenant/scoober/archive-access/aws", PUT) .data("{\"role\":\"arn:aws:iam::123456789012:role/my-role\"}").roles(Role.administrator(tenantName)), diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java index ab70dfd6073..6b377e2069b 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java @@ -1372,7 +1372,7 @@ public class ApplicationApiTest extends ControllerContainerTest { // Create legacy tenant name containing underscores tester.controller().curator().writeTenant(new AthenzTenant(TenantName.from("my_tenant"), ATHENZ_TENANT_DOMAIN, - new Property("property1"), Optional.empty(), Optional.empty(), Instant.EPOCH, LastLoginInfo.EMPTY, Instant.EPOCH)); + new Property("property1"), Optional.empty(), Optional.empty(), Instant.EPOCH, LastLoginInfo.EMPTY, Instant.EPOCH, List.of())); // POST (add) a Athenz tenant with dashes duplicates existing one with underscores tester.assertResponse(request("/application/v4/tenant/my-tenant", POST) diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-cloud.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-cloud.json new file mode 100644 index 00000000000..c7258ab3aa6 --- /dev/null +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/responses/tenant-cloud.json @@ -0,0 +1,35 @@ +{ + "tenant": "scoober", + "type": "CLOUD", + "creator": "developer@scoober", + "pemDeveloperKeys": [], + "secretStores": [], + "integrations": { + "aws": { + "tenantRole": "scoober-tenant-role", + "accounts": [] + } + }, + "quota": { + "budgetUsed": 1.304 + }, + "archiveAccess": {}, + "applications": [ + { + "tenant": "scoober", + "application": "albums", + "instance": "default", + "url": "http://localhost:8080/application/v4/tenant/scoober/application/albums/instance/default" + } + ], + "metaData": { + "createdAtMillis": 1600000000000, + "lastSubmissionToProdMillis": 1000 + }, + "cloudAccounts": [ + { + "cloudAccount": "aws:123456789012", + "templateVersion": "1.2.4" + } + ] +} diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json index eb376a95c74..8b76613676c 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/controller/responses/maintenance.json @@ -34,6 +34,9 @@ "name": "ChangeRequestMaintainer" }, { + "name": "CloudAccountVerifier" + }, + { "name": "CloudDatabaseMaintainer" }, { @@ -130,7 +133,5 @@ "name": "VersionStatusUpdater" } ], - "inactive": [ - "DeploymentExpirer" - ] + "inactive": ["DeploymentExpirer"] } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/SignatureFilterTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/SignatureFilterTest.java index 581f9704fc5..001e02e1b16 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/SignatureFilterTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/filter/SignatureFilterTest.java @@ -70,17 +70,7 @@ public class SignatureFilterTest { filter = new SignatureFilter(tester.controller()); signer = new RequestSigner(privateKey, id.serializedForm(), tester.clock()); - tester.curator().writeTenant(new CloudTenant(appId.tenant(), - Instant.EPOCH, - LastLoginInfo.EMPTY, - Optional.empty(), - ImmutableBiMap.of(), - TenantInfo.empty(), - List.of(), - new ArchiveAccess(), - Optional.empty(), - Instant.EPOCH, - Optional.empty())); + tester.curator().writeTenant(CloudTenant.create(appId.tenant(), Instant.EPOCH, null)); tester.curator().writeApplication(new Application(appId, tester.clock().instant())); } @@ -129,6 +119,7 @@ public class SignatureFilterTest { new ArchiveAccess(), Optional.empty(), Instant.EPOCH, + List.of(), Optional.empty())); verifySecurityContext(requestOf(signer.signed(request.copy(), Method.POST, () -> new ByteArrayInputStream(hiBytes)), hiBytes), new SecurityContext(new SimplePrincipal("user"), diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializerTest.java index 779aee73dae..eb3f9daef53 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializerTest.java @@ -63,7 +63,7 @@ public class UserFlagsSerializerTest { "{\"id\":\"int-id\",\"rules\":[{\"value\":456}]}," + // Default from DB "{\"id\":\"jackson-id\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"tenant\"}],\"value\":{\"integer\":456,\"string\":\"xyz\"}},{\"value\":{\"integer\":123,\"string\":\"abc\"}}]}," + // Resolved for email // Resolved for email, but conditions are empty since this user is not authorized for any tenants - "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"application\"}],\"value\":[\"value1\"]},{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"application\"}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + + "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"instance\"}],\"value\":[\"value1\"]},{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"instance\"}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + "{\"id\":\"string-id\",\"rules\":[{\"value\":\"value1\"}]}]}", // resolved for email flagData, Set.of(), false, email1); @@ -72,7 +72,7 @@ public class UserFlagsSerializerTest { "{\"id\":\"int-id\",\"rules\":[{\"value\":456}]}," + // Default from DB "{\"id\":\"jackson-id\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"tenant\",\"values\":[\"tenant1\"]}],\"value\":{\"integer\":456,\"string\":\"xyz\"}},{\"value\":{\"integer\":123,\"string\":\"abc\"}}]}," + // Resolved for email // Resolved for email, but conditions have filtered out tenant2 - "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"application\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\"]}],\"value\":[\"value1\"]},{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"application\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\"]}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + + "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"instance\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\"]}],\"value\":[\"value1\"]},{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"instance\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\"]}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + "{\"id\":\"string-id\",\"rules\":[{\"value\":\"value1\"}]}]}", // resolved for email flagData, Set.of("tenant1"), false, email1); @@ -81,7 +81,7 @@ public class UserFlagsSerializerTest { "{\"id\":\"int-id\",\"rules\":[{\"value\":456}]}," + // Default from DB "{\"id\":\"jackson-id\",\"rules\":[{\"value\":{\"integer\":123,\"string\":\"abc\"}}]}," + // Default from code, no DB values match // Includes last value from DB which is not conditioned on email and the default from code - "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"application\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\",\"tenant2:music:default\"]}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + + "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"instance\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\",\"tenant2:music:default\"]}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + "{\"id\":\"string-id\",\"rules\":[{\"value\":\"default value\"}]}]}", // Default from code flagData, Set.of(), true, "operator@domain.tld"); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/RoutingPoliciesTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/RoutingPoliciesTest.java index 630de5137bb..1b2fa956763 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/RoutingPoliciesTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/RoutingPoliciesTest.java @@ -598,21 +598,28 @@ public class RoutingPoliciesTest { app.deploy(); - // TXT records are cleaned up as we go—the last challenge is the last to go here, and we must flush it ourselves. + // TXT records are cleaned up when deployments are deactivated. + // The last challenge is the last to go here, and we must flush it ourselves. assertEquals(List.of("a.t.aws-us-east-33a.vespa.oath.cloud", "challenge--a.t.aws-us-east-33a.vespa.oath.cloud"), tester.recordNames()); app.flushDnsUpdates(); assertEquals(Set.of(new Record(Type.CNAME, RecordName.from("a.t.aws-us-east-33a.vespa.oath.cloud"), - RecordData.from("lb-0--t.a.default--prod.aws-us-east-33a."))), + RecordData.from("lb-0--t.a.default--prod.aws-us-east-33a.")), + new Record(Type.TXT, + RecordName.from("challenge--a.t.aws-us-east-33a.vespa.oath.cloud"), + RecordData.from("system"))), tester.controllerTester().nameService().records()); + tester.controllerTester().controller().applications().deactivate(app.instanceId(), zone3); + app.flushDnsUpdates(); + assertEquals(Set.of(), + tester.controllerTester().nameService().records()); + // Deployment fails because challenge is not answered (immediately). tester.tester.controllerTester().serviceRegistry().vpcEndpointService().outcomes .put(RecordName.from("challenge--a.t.aws-us-east-33a.vespa.oath.cloud"), ChallengeState.running); - - // Deployment fails because challenge is not answered (immediately). assertEquals("Status of run 2 of production-aws-us-east-33a for t.a ==> expected: <succeeded> but was: <unfinished>", assertThrows(AssertionError.class, () -> app.submit(appPackage).deploy()) diff --git a/dependency-versions/pom.xml b/dependency-versions/pom.xml index 90fe48ab0a5..120c6104793 100644 --- a/dependency-versions/pom.xml +++ b/dependency-versions/pom.xml @@ -153,7 +153,7 @@ <maven-plugin-api.vespa.version>${maven-core.vespa.version}</maven-plugin-api.vespa.version> <maven-plugin-tools.vespa.version>3.9.0</maven-plugin-tools.vespa.version> <maven-resources-plugin.vespa.version>3.3.1</maven-resources-plugin.vespa.version> - <maven-shade-plugin.vespa.version>3.5.0</maven-shade-plugin.vespa.version> + <maven-shade-plugin.vespa.version>3.5.1</maven-shade-plugin.vespa.version> <maven-site-plugin.vespa.version>3.12.1</maven-site-plugin.vespa.version> <maven-source-plugin.vespa.version>3.3.0</maven-source-plugin.vespa.version> <properties-maven-plugin.vespa.version>1.2.0</properties-maven-plugin.vespa.version> diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index e5b76bedecd..88198e6f00e 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -62,6 +62,7 @@ public class Flags { " latency-amortized-over-requests, latency-amortized-over-time", "Takes effect at redeployment (requires restart)", INSTANCE_ID); + public static final UnboundStringFlag SUMMARY_DECODE_POLICY = defineStringFlag( "summary-decode-policy", "eager", List.of("baldersheim"), "2023-03-30", "2023-12-31", @@ -314,7 +315,7 @@ public class Flags { INSTANCE_ID); public static final UnboundBooleanFlag ENABLE_THE_ONE_THAT_SHOULD_NOT_BE_NAMED = defineFeatureFlag( - "enable-the-one-that-should-not-be-named", false, List.of("hmusum"), "2023-05-08", "2023-10-01", + "enable-the-one-that-should-not-be-named", false, List.of("hmusum"), "2023-05-08", "2023-11-01", "Whether to enable the one program that should not be named", "Takes effect at next host-admin tick"); @@ -340,16 +341,10 @@ public class Flags { public static final UnboundBooleanFlag WRITE_CONFIG_SERVER_SESSION_DATA_AS_ONE_BLOB = defineFeatureFlag( "write-config-server-session-data-as-blob", false, - List.of("hmusum"), "2023-07-19", "2023-10-01", + List.of("hmusum"), "2023-07-19", "2023-11-01", "Whether to write config server session data in one blob or as individual paths", "Takes effect immediately"); - public static final UnboundBooleanFlag READ_CONFIG_SERVER_SESSION_DATA_AS_ONE_BLOB = defineFeatureFlag( - "read-config-server-session-data-as-blob", false, - List.of("hmusum"), "2023-07-19", "2023-10-01", - "Whether to read config server session data from session data blob or from individual paths", - "Takes effect immediately"); - public static final UnboundBooleanFlag MORE_WIREGUARD = defineFeatureFlag( "more-wireguard", false, List.of("andreer"), "2023-08-21", "2023-10-14", @@ -371,12 +366,6 @@ public class Flags { "Takes effect on next host provisioning / run of host-admin", HOSTNAME, CLOUD_ACCOUNT); - public static final UnboundBooleanFlag WRITE_APPLICATION_DATA_AS_JSON = defineFeatureFlag( - "write-application-data-as-json", true, - List.of("hmusum"), "2023-08-27", "2023-10-01", - "Whether to write application data (active session id, last deployed session id etc. ) as json", - "Takes effect immediately"); - public static final UnboundIntFlag MIN_EXCLUSIVE_ADVERTISED_MEMORY_GB = defineIntFlag( "min-exclusive-advertised-memory-gb", 8, List.of("freva"), "2023-09-08", "2023-11-01", @@ -413,6 +402,13 @@ public class Flags { "Takes effect at redeployment", INSTANCE_ID); + public static final UnboundStringFlag UNKNOWN_CONFIG_DEFINITION = defineStringFlag( + "unknown-config-definition", "log", + List.of("hmusum"), "2023-09-25", "2023-11-01", + "How to handle user config referencing unknown config definitions. Valid values are log, warn, fail", + "Takes effect at redeployment", + INSTANCE_ID); + /** WARNING: public for testing: All flags should be defined in {@link Flags}. */ public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, List<String> owners, String createdAt, String expiresAt, String description, diff --git a/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java b/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java index f867daac245..8fb48c8a82f 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java @@ -23,7 +23,7 @@ public class DimensionHelper { serializedDimensions.put(FetchVector.Dimension.CONSOLE_USER_EMAIL, List.of("console-user-email")); serializedDimensions.put(FetchVector.Dimension.ENVIRONMENT, List.of("environment")); serializedDimensions.put(FetchVector.Dimension.HOSTNAME, List.of("hostname")); - serializedDimensions.put(FetchVector.Dimension.INSTANCE_ID, List.of("application", "instance")); + serializedDimensions.put(FetchVector.Dimension.INSTANCE_ID, List.of("instance", "application")); serializedDimensions.put(FetchVector.Dimension.NODE_TYPE, List.of("node-type")); serializedDimensions.put(FetchVector.Dimension.SYSTEM, List.of("system")); serializedDimensions.put(FetchVector.Dimension.TENANT_ID, List.of("tenant")); diff --git a/maven-plugins/allowed-maven-dependencies.txt b/maven-plugins/allowed-maven-dependencies.txt index e1a1adf3b4d..6853632ea40 100644 --- a/maven-plugins/allowed-maven-dependencies.txt +++ b/maven-plugins/allowed-maven-dependencies.txt @@ -35,7 +35,7 @@ org.apache.maven:maven-settings-builder:3.9.4 org.apache.maven.enforcer:enforcer-api:3.4.1 org.apache.maven.enforcer:enforcer-rules:3.4.1 org.apache.maven.plugin-tools:maven-plugin-annotations:3.9.0 -org.apache.maven.plugins:maven-shade-plugin:3.5.0 +org.apache.maven.plugins:maven-shade-plugin:3.5.1 org.apache.maven.resolver:maven-resolver-api:1.9.14 org.apache.maven.resolver:maven-resolver-impl:1.9.14 org.apache.maven.resolver:maven-resolver-named-locks:1.9.14 @@ -63,7 +63,7 @@ org.ow2.asm:asm-tree:9.5 org.slf4j:slf4j-api:1.7.36 org.tukaani:xz:1.9 org.twdata.maven:mojo-executor:2.4.0 -org.vafer:jdependency:2.8.0 +org.vafer:jdependency:2.9.0 #[test-only] # Contains dependencies that are used exclusively in 'test' scope diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index aafb9877c27..4bb7bcc9225 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -27,7 +27,7 @@ import java.util.Arrays; import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST; /** - * A ColBERT embedder implementation that maps text to multiple vectors, one vector per subword id. + * A ColBERT embedder implementation that maps text to multiple vectors, one vector per token subword id. * This embedder uses a HuggingFace tokenizer to produce a token sequence that is then input to a transformer model. * * See col-bert-embedder.def for configurable parameters. @@ -60,10 +60,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { attentionMaskName = config.transformerAttentionMask(); outputName = config.transformerOutput(); maxTransformerTokens = config.transformerMaxTokens(); - if(config.maxDocumentTokens() > maxTransformerTokens) - throw new IllegalArgumentException("maxDocumentTokens must be less than or equal to transformerMaxTokens"); - maxDocumentTokens = config.maxDocumentTokens(); - maxQueryTokens = config.maxQueryTokens(); + maxDocumentTokens = Math.min(config.maxDocumentTokens(), maxTransformerTokens); + maxQueryTokens = Math.min(config.maxQueryTokens(), maxTransformerTokens); startSequenceToken = config.transformerStartSequenceToken(); endSequenceToken = config.transformerEndSequenceToken(); maskSequenceToken = config.transformerMaskToken(); @@ -75,7 +73,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { .setPadding(false); var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath); if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) { - // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration + // Force truncation + // to max length accepted by model if tokenizer.json contains no valid truncation configuration int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens() ? info.maxLength() : config.transformerMaxTokens(); @@ -115,8 +114,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { @Override public Tensor embed(String text, Context context, TensorType tensorType) { if(!verifyTensorType(tensorType)) { - throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination." + - "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType.toString()); + throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination. " + + "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType); } if (context.getDestination().startsWith("query")) { return embedQuery(text, context, tensorType); @@ -152,6 +151,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { inputIds.add(Q_TOKEN_ID); inputIds.addAll(ids); inputIds.add(endSequenceToken); + int length = inputIds.size(); int padding = maxQueryTokens - length; @@ -177,12 +177,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { throw new IllegalArgumentException("Token dimensionality does not" + " match indexed dimensionality of " + dims); } - Tensor.Builder builder = Tensor.Builder.of(tensorType); - for (int token = 0; token < result.shape()[0]; token++) - for (int d = 0; d < result.shape()[1]; d++) - builder.cell(TensorAddress.of(token, d), result.get(TensorAddress.of(token, d))); + Tensor resultTensor = toFloatTensor(result, tensorType, inputIds.size()); runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); - return builder.build(); + return resultTensor; } protected Tensor embedDocument(String text, Context context, TensorType tensorType) { @@ -193,7 +190,6 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { List<Long> ids = encoding.ids().stream().filter(token -> !PUNCTUATION_TOKEN_IDS.contains(token)).toList(); - ; if (ids.size() > maxDocumentTokens - 3) ids = ids.subList(0, maxDocumentTokens - 3); @@ -216,29 +212,29 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { Tensor tokenEmbeddings = outputs.get(outputName); IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); Tensor contextualEmbeddings; + int retainedTokens = inputIds.size() -1; //Do not retain last PAD if(tensorType.valueType() == TensorType.Value.INT8) { - contextualEmbeddings = toBitTensor(result, tensorType); + contextualEmbeddings = toBitTensor(result, tensorType, retainedTokens); } else { - contextualEmbeddings = toFloatTensor(result, tensorType); + contextualEmbeddings = toFloatTensor(result, tensorType, retainedTokens); } - runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); return contextualEmbeddings; } - public static Tensor toFloatTensor(IndexedTensor result, TensorType type) { + public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { int size = type.indexedSubtype().dimensions().size(); if (size != 1) throw new IllegalArgumentException("Indexed tensor must have one dimension"); - int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue(); - int resultDim = (int)result.shape()[1]; - if(resultDim != dims) { - throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDim - + " + dimensions into tensor with " + dims); + int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int resultDimensionality = (int)result.shape()[1]; + if(resultDimensionality != wantedDimensionality) { + throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality + + " + dimensions into tensor with " + wantedDimensionality); } Tensor.Builder builder = Tensor.Builder.of(type); - for (int token = 0; token < result.shape()[0]; token++) { - for (int d = 0; d < result.shape()[1]; d++) { + for (int token = 0; token < nTokens; token++) { + for (int d = 0; d < resultDimensionality; d++) { var value = result.get(TensorAddress.of(token, d)); builder.cell(TensorAddress.of(token,d),value); } @@ -246,21 +242,21 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - public static Tensor toBitTensor(IndexedTensor result, TensorType type) { + public static Tensor toBitTensor(IndexedTensor result, TensorType type, int nTokens) { if (type.valueType() != TensorType.Value.INT8) throw new IllegalArgumentException("Only a int8 tensor type can be" + " the destination of bit packing"); int size = type.indexedSubtype().dimensions().size(); if (size != 1) throw new IllegalArgumentException("Indexed tensor must have one dimension"); - int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue(); - int resultDim = (int)result.shape()[1]; - if(resultDim/8 != dims) { - throw new IllegalArgumentException("Not possible to pack " + resultDim - + " + dimensions into " + dims); + int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int resultDimensionality = (int)result.shape()[1]; + if(resultDimensionality/8 != wantedDimensionality) { + throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + + " + dimensions into " + wantedDimensionality + " dimensions"); } Tensor.Builder builder = Tensor.Builder.of(type); - for (int token = 0; token < result.shape()[0]; token++) { + for (int token = 0; token < nTokens; token++) { BitSet bitSet = new BitSet(8); int key = 0; for (int d = 0; d < result.shape()[1]; d++) { diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index 8516f6e6689..4e398f7245d 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -31,7 +31,7 @@ public class ColBertEmbedderTest { "[1, 1, 1, 1, 1, 1, 1, 1]" + "]", TensorType.fromSpec("tensor<int8>(dt{},x[1])"), - "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}" + "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}", 6 ); assertPackedRight( "" + @@ -41,7 +41,7 @@ public class ColBertEmbedderTest { "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" + "]", TensorType.fromSpec("tensor<int8>(dt{},x[2])"), - "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}" + "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2 ); } @@ -75,18 +75,18 @@ public class ColBertEmbedderTest { } String text = sb.toString(); Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); - assertEquals(512*128,fullFloat.size()); + assertEquals(511*128,fullFloat.size()); Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext); assertEquals(32*128,query.size()); Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext); - assertEquals(512*16,binaryRep.size()); + assertEquals(511*16,binaryRep.size()); Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext); - // 4 tokens, 16 bytes each = 64 bytes - //because of CLS, special, sequence, SEP - assertEquals(4*16,shortDoc.size());; + // 3 tokens, 16 bytes each = 48 bytes + //CLS [unused1] sequence + assertEquals(3*16,shortDoc.size());; } static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) { @@ -100,8 +100,8 @@ public class ColBertEmbedderTest { return result; } - static void assertPackedRight(String numbers, TensorType destination,String expected) { - Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination); + static void assertPackedRight(String numbers, TensorType destination,String expected, int size) { + Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination, size); assertEquals(expected,packed.toString()); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java index 0300d7e92ff..d902fb7b3c4 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java @@ -9,6 +9,7 @@ import com.yahoo.config.provision.DockerImage; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.WireguardKey; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.vespa.hosted.node.admin.task.util.file.DiskSize; import java.net.URI; @@ -73,9 +74,7 @@ public class NodeSpec { private final List<TrustStoreItem> trustStore; - private final Optional<WireguardKey> wireguardPubkey; - - private final Optional<Instant> wireguardKeyTimestamp; + private final Optional<WireguardKeyWithTimestamp> wireguardKeyWithTimestamp; private final boolean wantToRebuild; @@ -112,8 +111,7 @@ public class NodeSpec { Optional<URI> archiveUri, Optional<ApplicationId> exclusiveTo, List<TrustStoreItem> trustStore, - Optional<WireguardKey> wireguardPubkey, - Optional<Instant> wireguardKeyTimestamp, + Optional<WireguardKeyWithTimestamp> wireguardPubkey, boolean wantToRebuild) { if (state == NodeState.active) { @@ -157,8 +155,7 @@ public class NodeSpec { this.archiveUri = Objects.requireNonNull(archiveUri); this.exclusiveTo = Objects.requireNonNull(exclusiveTo); this.trustStore = Objects.requireNonNull(trustStore); - this.wireguardPubkey = Objects.requireNonNull(wireguardPubkey); - this.wireguardKeyTimestamp = Objects.requireNonNull(wireguardKeyTimestamp); + this.wireguardKeyWithTimestamp = Objects.requireNonNull(wireguardPubkey); this.wantToRebuild = wantToRebuild; } @@ -313,9 +310,7 @@ public class NodeSpec { return trustStore; } - public Optional<WireguardKey> wireguardPubkey() { return wireguardPubkey; } - - public Optional<Instant> wireguardKeyTimestamp() { return wireguardKeyTimestamp; } + public Optional<WireguardKeyWithTimestamp> wireguardKeyWithTimestamp() { return wireguardKeyWithTimestamp; } public boolean wantToRebuild() { return wantToRebuild; @@ -358,8 +353,7 @@ public class NodeSpec { Objects.equals(archiveUri, that.archiveUri) && Objects.equals(exclusiveTo, that.exclusiveTo) && Objects.equals(trustStore, that.trustStore) && - Objects.equals(wireguardPubkey, that.wireguardPubkey) && - Objects.equals(wireguardKeyTimestamp, that.wireguardKeyTimestamp) && + Objects.equals(wireguardKeyWithTimestamp, that.wireguardKeyWithTimestamp) && Objects.equals(wantToRebuild, that.wantToRebuild); } @@ -398,8 +392,7 @@ public class NodeSpec { archiveUri, exclusiveTo, trustStore, - wireguardPubkey, - wireguardKeyTimestamp, + wireguardKeyWithTimestamp, wantToRebuild); } @@ -438,8 +431,7 @@ public class NodeSpec { + " archiveUri=" + archiveUri + " exclusiveTo=" + exclusiveTo + " trustStore=" + trustStore - + " wireguardPubkey=" + wireguardPubkey - + " wireguardKeyTimestamp=" + wireguardKeyTimestamp + + " wireguardPubkey=" + wireguardKeyWithTimestamp + " wantToRebuild=" + wantToRebuild + " }"; } @@ -477,8 +469,7 @@ public class NodeSpec { private Optional<URI> archiveUri = Optional.empty(); private Optional<ApplicationId> exclusiveTo = Optional.empty(); private List<TrustStoreItem> trustStore = List.of(); - private Optional<WireguardKey> wireguardPubkey = Optional.empty(); - private Optional<Instant> wireguardKeyTimestamp = Optional.empty(); + private Optional<WireguardKeyWithTimestamp> wireguardPubkey = Optional.empty(); private boolean wantToRebuild = false; public Builder() {} @@ -514,8 +505,7 @@ public class NodeSpec { node.archiveUri.ifPresent(this::archiveUri); node.exclusiveTo.ifPresent(this::exclusiveTo); trustStore(node.trustStore); - node.wireguardPubkey.ifPresent(this::wireguardPubkey); - node.wireguardKeyTimestamp.ifPresent(this::wireguardKeyTimestamp); + node.wireguardKeyWithTimestamp.ifPresent(this::wireguardKeyWithTimestamp); wantToRebuild(node.wantToRebuild); } @@ -704,13 +694,13 @@ public class NodeSpec { return this; } - public Builder wireguardPubkey(WireguardKey wireguardPubKey) { - this.wireguardPubkey = Optional.of(wireguardPubKey); + public Builder wireguardPubkey(WireguardKey wireguardPubkey) { + this.wireguardPubkey = Optional.of(new WireguardKeyWithTimestamp(wireguardPubkey, Instant.EPOCH)); return this; } - public Builder wireguardKeyTimestamp(Instant wireguardKeyTimestamp) { - this.wireguardKeyTimestamp = Optional.of(wireguardKeyTimestamp); + public Builder wireguardKeyWithTimestamp(WireguardKeyWithTimestamp wireguardPubKey) { + this.wireguardPubkey = Optional.of(wireguardPubKey); return this; } @@ -846,7 +836,7 @@ public class NodeSpec { wantedFirmwareCheck, currentFirmwareCheck, modelName, resources, realResources, ipAddresses, additionalIpAddresses, reports, events, parentHostname, archiveUri, exclusiveTo, trustStore, - wireguardPubkey, wireguardKeyTimestamp, wantToRebuild); + wireguardPubkey, wantToRebuild); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java index a9cc2d698e9..17d3b51398f 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java @@ -11,6 +11,7 @@ import com.yahoo.config.provision.HostName; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.WireguardKey; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.config.provision.host.FlavorOverrides; import com.yahoo.vespa.hosted.node.admin.configserver.ConfigServerApi; import com.yahoo.vespa.hosted.node.admin.configserver.HttpException; @@ -139,26 +140,28 @@ public class RealNodeRepository implements NodeRepository { return response.nodes.stream() .mapMulti((NodeRepositoryNode node, Consumer<WireguardPeer> consumer) -> { - if (node.wireguardPubkey == null || node.wireguardPubkey.isEmpty()) return; - List<VersionedIpAddress> ipAddresses = node.ipAddresses.stream() - .map(InetAddresses::forString) - .filter(address -> !address.isLoopbackAddress() && !address.isLinkLocalAddress() && !address.isSiteLocalAddress()) - .map(VersionedIpAddress::from) - .toList(); - if (ipAddresses.isEmpty()) return; + var keyWithTimestamp = createWireguardKeyWithTimestamp(node.wireguardKeyWithTimestamp, + node.wireguardPubkey, + node.wireguardKeyTimestamp); + if (keyWithTimestamp == null) return; - // Unbox to prevent NPE - long keyTimestamp = node.wireguardKeyTimestamp == null ? 0L : node.wireguardKeyTimestamp; + List<VersionedIpAddress> ipAddresses = getIpAddresses(node); + if (ipAddresses.isEmpty()) return; - consumer.accept(new WireguardPeer(HostName.of(node.hostname), - ipAddresses, - WireguardKey.from(node.wireguardPubkey), - Instant.ofEpochMilli(keyTimestamp))); + consumer.accept(new WireguardPeer(HostName.of(node.hostname), ipAddresses, keyWithTimestamp)); }) .sorted() .toList(); } + private static List<VersionedIpAddress> getIpAddresses(NodeRepositoryNode node) { + return node.ipAddresses.stream() + .map(InetAddresses::forString) + .filter(address -> !address.isLoopbackAddress() && !address.isLinkLocalAddress() && !address.isSiteLocalAddress()) + .map(VersionedIpAddress::from) + .toList(); + } + @Override public List<WireguardPeer> getConfigserverPeers() { GetWireguardResponse response = configServerApi.get("/nodes/v2/wireguard", GetWireguardResponse.class); @@ -246,8 +249,9 @@ public class RealNodeRepository implements NodeRepository { Optional.ofNullable(node.archiveUri).map(URI::create), Optional.ofNullable(node.exclusiveTo).map(ApplicationId::fromSerializedForm), trustStore, - Optional.ofNullable(node.wireguardPubkey).map(WireguardKey::from), - Optional.ofNullable(node.wireguardKeyTimestamp).map(Instant::ofEpochMilli), + Optional.ofNullable(createWireguardKeyWithTimestamp(node.wireguardKeyWithTimestamp, + node.wireguardPubkey, + node.wireguardKeyTimestamp)), node.wantToRebuild); } @@ -364,20 +368,39 @@ public class RealNodeRepository implements NodeRepository { node.trustStore = nodeAttributes.getTrustStore().stream() .map(item -> new NodeRepositoryNode.TrustStoreItem(item.fingerprint(), item.expiry().toEpochMilli())) .toList(); - node.wireguardPubkey = nodeAttributes.getWireguardPubkey().map(WireguardKey::value).orElse(null); + // This is used for patching, and timestamp must only be set on the server side, hence sending EPOCH. + node.wireguardKeyWithTimestamp = nodeAttributes.getWireguardPubkey() + .map(key -> new NodeRepositoryNode.WireguardKeyWithTimestamp(key.value(), 0L)) + .orElse(null); Map<String, JsonNode> reports = nodeAttributes.getReports(); node.reports = reports == null || reports.isEmpty() ? null : new TreeMap<>(reports); + // TODO wg: remove when all nodes are using new key+timestamp format + node.wireguardPubkey = nodeAttributes.getWireguardPubkey().map(WireguardKey::value).orElse(null); return node; } private static WireguardPeer createConfigserverPeer(GetWireguardResponse.Configserver configServer) { - // Unbox to prevent NPE - long keyTimestamp = configServer.wireguardKeyTimestamp == null ? 0L : configServer.wireguardKeyTimestamp; - return new WireguardPeer(HostName.of(configServer.hostname), configServer.ipAddresses.stream().map(VersionedIpAddress::from).toList(), - WireguardKey.from(configServer.wireguardPubkey), - Instant.ofEpochMilli(keyTimestamp)); + createWireguardKeyWithTimestamp(configServer.wireguardKeyWithTimestamp, + configServer.wireguardPubkey, + configServer.wireguardKeyTimestamp)); + } + + private static WireguardKeyWithTimestamp createWireguardKeyWithTimestamp(NodeRepositoryNode.WireguardKeyWithTimestamp wirguardJson, + String oldKeyJson, Long oldTimestampJson) { + if (wirguardJson != null && wirguardJson.key != null && ! wirguardJson.key.isEmpty()) { + return new WireguardKeyWithTimestamp(WireguardKey.from(wirguardJson.key), + Instant.ofEpochMilli(wirguardJson.timestamp)); + // TODO wg: remove when all nodes are using new key+timestamp format + } else if (oldKeyJson != null) { + var timestamp = oldTimestampJson != null ? oldTimestampJson : 0L; + return new WireguardKeyWithTimestamp(WireguardKey.from(oldKeyJson), + Instant.ofEpochMilli(timestamp)); + // TODO END + } else return null; + } + } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/bindings/GetWireguardResponse.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/bindings/GetWireguardResponse.java index dcbf4cc163f..47903795ef7 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/bindings/GetWireguardResponse.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/bindings/GetWireguardResponse.java @@ -27,27 +27,23 @@ public class GetWireguardResponse { public static class Configserver { @JsonProperty("hostname") - public final String hostname; + public String hostname; @JsonProperty("ipAddresses") - public final List<String> ipAddresses; + public List<String> ipAddresses; + + @JsonProperty("wireguard") + public NodeRepositoryNode.WireguardKeyWithTimestamp wireguardKeyWithTimestamp; - @JsonProperty("wireguardPubkey") - public final String wireguardPubkey; + // TODO wg: remove when all nodes use new key+timestamp format + @JsonProperty("wireguardPubkey") + @JsonInclude(JsonInclude.Include.NON_EMPTY) + public String wireguardPubkey; @JsonProperty("wireguardKeyTimestamp") - public final Long wireguardKeyTimestamp; - - @JsonCreator - public Configserver(@JsonProperty("hostname") String hostname, - @JsonProperty("ipAddresses") List<String> ipAddresses, - @JsonProperty("wireguardPubkey") String wireguardPubkey, - @JsonProperty("wireguardKeyTimestamp") Long wireguardKeyTimestamp) { - this.hostname = hostname; - this.ipAddresses = ipAddresses; - this.wireguardPubkey = wireguardPubkey; - this.wireguardKeyTimestamp = wireguardKeyTimestamp; - } + @JsonInclude(JsonInclude.Include.NON_EMPTY) + public Long wireguardKeyTimestamp; + } } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/bindings/NodeRepositoryNode.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/bindings/NodeRepositoryNode.java index 3d0d052a877..35ca757ebbe 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/bindings/NodeRepositoryNode.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/bindings/NodeRepositoryNode.java @@ -92,6 +92,10 @@ public class NodeRepositoryNode { @JsonProperty("trustStore") @JsonInclude(JsonInclude.Include.NON_EMPTY) public List<TrustStoreItem> trustStore; + @JsonProperty("wireguard") + public WireguardKeyWithTimestamp wireguardKeyWithTimestamp; + + // TODO wg: remove separate key and timestamp when all nodes use new keyWithTimestamp @JsonProperty("wireguardPubkey") @JsonInclude(JsonInclude.Include.NON_EMPTY) public String wireguardPubkey; @@ -141,13 +145,25 @@ public class NodeRepositoryNode { ", exclusiveTo='" + exclusiveTo + '\'' + ", history=" + history + ", trustStore=" + trustStore + - ", wireguardPubkey=" + wireguardPubkey + - ", wireguardKeyTimestamp=" + wireguardKeyTimestamp + + ", wireguard=" + wireguardKeyTimestamp + ", reports=" + reports + '}'; } @JsonIgnoreProperties(ignoreUnknown = true) + public static class WireguardKeyWithTimestamp { + @JsonProperty("key") + public String key; + @JsonProperty("timestamp") + public long timestamp; + + public WireguardKeyWithTimestamp(@JsonProperty("key") String key, @JsonProperty("timestamp") long timestamp) { + this.key = key; + this.timestamp = timestamp; + } + } + + @JsonIgnoreProperties(ignoreUnknown = true) public static class Owner { @JsonProperty("tenant") public String tenant; diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/wireguard/WireguardPeer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/wireguard/WireguardPeer.java index b5428f57f08..e5ab9a1ce31 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/wireguard/WireguardPeer.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/wireguard/WireguardPeer.java @@ -1,10 +1,9 @@ package com.yahoo.vespa.hosted.node.admin.wireguard; import com.yahoo.config.provision.HostName; -import com.yahoo.config.provision.WireguardKey; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.vespa.hosted.node.admin.task.util.network.VersionedIpAddress; -import java.time.Instant; import java.util.List; /** @@ -15,8 +14,7 @@ import java.util.List; */ public record WireguardPeer(HostName hostname, List<VersionedIpAddress> ipAddresses, - WireguardKey publicKey, - Instant wireguardKeyTimestamp) implements Comparable<WireguardPeer> { + WireguardKeyWithTimestamp keyWithTimestamp) implements Comparable<WireguardPeer> { public WireguardPeer { if (ipAddresses.isEmpty()) throw new IllegalArgumentException("No IP addresses for peer node " + hostname.value()); diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java index 98e65d03f2f..ee3eac22d02 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java @@ -9,6 +9,7 @@ import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.WireguardKey; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.config.provision.host.FlavorOverrides; import com.yahoo.vespa.hosted.node.admin.configserver.ConfigServerApi; import com.yahoo.vespa.hosted.node.admin.configserver.ConfigServerApiImpl; @@ -140,6 +141,7 @@ public class RealNodeRepositoryTest { var dockerImage = "registry.example.com/repo/image-1:6.2.3"; var wireguardKey = WireguardKey.from("111122223333444455556666777788889999000042c="); var wireguardKeyTimestamp = Instant.ofEpochMilli(123L); // Instant from clock in MockNodeRepository + var keyWithTimestamp = new WireguardKeyWithTimestamp(wireguardKey, wireguardKeyTimestamp); nodeRepositoryApi.updateNodeAttributes( hostname, @@ -151,8 +153,7 @@ public class RealNodeRepositoryTest { NodeSpec hostSpec = nodeRepositoryApi.getOptionalNode(hostname).orElseThrow(); assertEquals(1, hostSpec.currentRestartGeneration().orElseThrow()); assertEquals(dockerImage, hostSpec.currentDockerImage().orElseThrow().asString()); - assertEquals(wireguardKey.value(), hostSpec.wireguardPubkey().orElseThrow().value()); - assertEquals(wireguardKeyTimestamp, hostSpec.wireguardKeyTimestamp().orElseThrow()); + assertEquals(keyWithTimestamp, hostSpec.wireguardKeyWithTimestamp().orElseThrow()); } @Test @@ -215,7 +216,7 @@ public class RealNodeRepositoryTest { assertWireguardPeer(cfgPeers.get(0), "cfg1.yahoo.com", "::201:1", "lololololololololololololololololololololoo=", - Instant.ofEpochMilli(456L)); + 456L); //// Exclave nodes //// @@ -227,16 +228,17 @@ public class RealNodeRepositoryTest { assertWireguardPeer(exclavePeers.get(0), "dockerhost2.yahoo.com", "::101:1", "000011112222333344445555666677778888999900c=", - Instant.ofEpochMilli(123L)); + 123L); } private void assertWireguardPeer(WireguardPeer peer, String hostname, String ipv6, - String publicKey, Instant keyTimestamp) { + String publicKey, long keyTimestamp) { assertEquals(hostname, peer.hostname().value()); assertEquals(1, peer.ipAddresses().size()); assertIp(peer.ipAddresses().get(0), ipv6, 6); - assertEquals(publicKey, peer.publicKey().value()); - assertEquals(keyTimestamp, peer.wireguardKeyTimestamp()); + var expectedKeyWithTimestamp = new WireguardKeyWithTimestamp(WireguardKey.from(publicKey), + Instant.ofEpochMilli(keyTimestamp)); + assertEquals(expectedKeyWithTimestamp, peer.keyWithTimestamp()); } private void assertIp(VersionedIpAddress ip, String expectedIp, int expectedVersion) { diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/wireguard/WireguardPeerTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/wireguard/WireguardPeerTest.java index cd76b221c9e..6ee896e3db6 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/wireguard/WireguardPeerTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/wireguard/WireguardPeerTest.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.node.admin.wireguard; import com.yahoo.config.provision.HostName; import com.yahoo.config.provision.WireguardKey; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.vespa.hosted.node.admin.task.util.network.VersionedIpAddress; import org.junit.jupiter.api.Test; @@ -31,6 +32,7 @@ public class WireguardPeerTest { private static WireguardPeer peer(String hostname) { return new WireguardPeer(HostName.of(hostname), List.of(VersionedIpAddress.from("::1:1")), - WireguardKey.generateRandomForTesting(), Instant.EPOCH); + new WireguardKeyWithTimestamp(WireguardKey.generateRandomForTesting(), Instant.EPOCH)); } + } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java index 24159b88a9b..d5e891a33c7 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/Node.java @@ -10,7 +10,7 @@ import com.yahoo.config.provision.Flavor; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.TenantName; -import com.yahoo.config.provision.WireguardKey; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.config.provision.Zone; import com.yahoo.vespa.hosted.provision.lb.LoadBalancers; import com.yahoo.vespa.hosted.provision.node.Agent; @@ -64,8 +64,7 @@ public final class Node implements Nodelike { private final CloudAccount cloudAccount; /** Only set for configservers and exclave nodes */ - private final Optional<WireguardKey> wireguardPubKey; - private final Optional<Instant> wireguardKeyTimestamp; + private final Optional<WireguardKeyWithTimestamp> wireguardPubKey; /** Record of the last event of each type happening to this node */ private final History history; @@ -96,8 +95,8 @@ public final class Node implements Nodelike { NodeType type, Reports reports, Optional<String> modelName, Optional<TenantName> reservedTo, Optional<ApplicationId> exclusiveToApplicationId, Optional<Duration> hostTTL, Optional<Instant> hostEmptyAt, Optional<ClusterSpec.Type> exclusiveToClusterType, Optional<String> switchHostname, - List<TrustStoreItem> trustStoreItems, CloudAccount cloudAccount, Optional<WireguardKey> wireguardPubKey, - Optional<Instant> wireguardKeyTimestamp) { + List<TrustStoreItem> trustStoreItems, CloudAccount cloudAccount, + Optional<WireguardKeyWithTimestamp> wireguardPubKey) { this.id = Objects.requireNonNull(id, "A node must have an ID"); this.extraId = Objects.requireNonNull(extraId, "Extra ID cannot be null"); this.hostname = requireNonEmptyString(hostname, "A node must have a hostname"); @@ -120,7 +119,6 @@ public final class Node implements Nodelike { this.trustStoreItems = Objects.requireNonNull(trustStoreItems).stream().distinct().toList(); this.cloudAccount = Objects.requireNonNull(cloudAccount); this.wireguardPubKey = Objects.requireNonNull(wireguardPubKey); - this.wireguardKeyTimestamp = Objects.requireNonNull(wireguardKeyTimestamp); if (state == State.active) requireNonEmpty(ipConfig.primary(), "Active node " + hostname + " must have at least one valid IP address"); @@ -264,15 +262,10 @@ public final class Node implements Nodelike { } /** Returns the wireguard public key of this node. Only relevant for enclave nodes. */ - public Optional<WireguardKey> wireguardPubKey() { + public Optional<WireguardKeyWithTimestamp> wireguardPubKey() { return wireguardPubKey; } - /** Returns the timestamp of the wireguard key of this node. Only relevant for enclave nodes. */ - public Optional<Instant> wireguardKeyTimestamp() { - return wireguardKeyTimestamp; - } - /** * Returns a copy of this where wantToFail is set to true and history is updated to reflect this. */ @@ -367,16 +360,14 @@ public final class Node implements Nodelike { public Node with(Status status) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } /** Returns a node with the type assigned to the given value */ public Node with(NodeType type) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } /** Returns a node with the flavor assigned to the given value */ @@ -385,40 +376,35 @@ public final class Node implements Nodelike { History updateHistory = history.with(new History.Event(History.Event.Type.resized, agent, instant)); return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, updateHistory, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } /** Returns a copy of this with the reboot generation set to generation */ public Node withReboot(Generation generation) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status.withReboot(generation), state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } /** Returns a copy of this with given id set */ public Node withId(String id) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } /** Returns a copy of this with model name set to given value */ public Node withModelName(String modelName) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, Optional.of(modelName), reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } /** Returns a copy of this with model name cleared */ public Node withoutModelName() { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, Optional.empty(), reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } /** Returns a copy of this with a history record saying it was detected to be down at this instant */ @@ -460,24 +446,21 @@ public final class Node implements Nodelike { public Node with(Allocation allocation) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, Optional.of(allocation), history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } /** Returns a copy of this node with IP config set to the given value. */ public Node with(IP.Config ipConfig) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } /** Returns a copy of this node with the parent hostname assigned to the given value. */ public Node withParentHostname(String parentHostname) { return new Node(id, extraId, ipConfig, hostname, Optional.of(parentHostname), flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } public Node withReservedTo(TenantName tenant) { @@ -485,73 +468,59 @@ public final class Node implements Nodelike { throw new IllegalArgumentException("Only host nodes can be reserved, " + hostname + " has type " + type); return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, Optional.of(tenant), exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } /** Returns a copy of this node which is not reserved to a tenant */ public Node withoutReservedTo() { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, Optional.empty(), exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } public Node withExclusiveToApplicationId(ApplicationId exclusiveTo) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, Optional.ofNullable(exclusiveTo), hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } public Node withExtraId(Optional<String> extraId) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } public Node withHostTTL(Duration hostTTL) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, Optional.ofNullable(hostTTL), hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } public Node withHostEmptyAt(Instant hostEmptyAt) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, Optional.ofNullable(hostEmptyAt), - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } public Node withExclusiveToClusterType(ClusterSpec.Type exclusiveTo) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - Optional.ofNullable(exclusiveTo), switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + Optional.ofNullable(exclusiveTo), switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } - public Node withWireguardPubkey(WireguardKey wireguardPubkey) { + public Node withWireguardPubkey(WireguardKeyWithTimestamp wireguardPubkey) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, Optional.ofNullable(wireguardPubkey), - wireguardKeyTimestamp); - } - - public Node withWireguardKeyTimestamp(Instant wireguardKeyTimestamp) { - return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, - type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - Optional.ofNullable(wireguardKeyTimestamp)); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, + Optional.ofNullable(wireguardPubkey)); } /** Returns a copy of this node with switch hostname set to given value */ public Node withSwitchHostname(String switchHostname) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, Optional.ofNullable(switchHostname), trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, Optional.ofNullable(switchHostname), trustStoreItems, cloudAccount, + wireguardPubKey); } /** Returns a copy of this node with switch hostname unset */ @@ -604,22 +573,19 @@ public final class Node implements Nodelike { public Node with(History history) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } public Node with(Reports reports) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } public Node with(List<TrustStoreItem> trustStoreItems) { return new Node(id, extraId, ipConfig, hostname, parentHostname, flavor, status, state, allocation, history, type, reports, modelName, reservedTo, exclusiveToApplicationId, hostTTL, hostEmptyAt, - exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey, - wireguardKeyTimestamp); + exclusiveToClusterType, switchHostname, trustStoreItems, cloudAccount, wireguardPubKey); } private static Optional<String> requireNonEmptyString(Optional<String> value, String message) { @@ -767,8 +733,7 @@ public final class Node implements Nodelike { private History history; private List<TrustStoreItem> trustStoreItems; private CloudAccount cloudAccount = CloudAccount.empty; - private WireguardKey wireguardPubKey; - private Instant wireguardKeyTimestamp; + private WireguardKeyWithTimestamp wireguardPubKey; private Builder(String id, String hostname, Flavor flavor, State state, NodeType type) { this.id = id; @@ -858,16 +823,11 @@ public final class Node implements Nodelike { return this; } - public Builder wireguardPubKey(WireguardKey wireguardPubKey) { + public Builder wireguardKey(WireguardKeyWithTimestamp wireguardPubKey) { this.wireguardPubKey = wireguardPubKey; return this; } - public Builder wireguardKeyTimestamp(Instant wireguardKeyTimestamp) { - this.wireguardKeyTimestamp = wireguardKeyTimestamp; - return this; - } - public Node build() { return new Node(id, Optional.empty(), Optional.ofNullable(ipConfig).orElse(IP.Config.EMPTY), hostname, Optional.ofNullable(parentHostname), flavor, Optional.ofNullable(status).orElseGet(Status::initial), state, Optional.ofNullable(allocation), @@ -875,7 +835,7 @@ public final class Node implements Nodelike { Optional.ofNullable(modelName), Optional.ofNullable(reservedTo), Optional.ofNullable(exclusiveToApplicationId), Optional.ofNullable(hostTTL), Optional.ofNullable(hostEmptyAt), Optional.ofNullable(exclusiveToClusterType), Optional.ofNullable(switchHostname), Optional.ofNullable(trustStoreItems).orElseGet(List::of), cloudAccount, - Optional.ofNullable(wireguardPubKey), Optional.ofNullable(wireguardKeyTimestamp)); + Optional.ofNullable(wireguardPubKey)); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerInstance.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerInstance.java index e228d31384c..f42d1ce9bd3 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerInstance.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerInstance.java @@ -21,7 +21,8 @@ import java.util.Set; public class LoadBalancerInstance { private final Optional<DomainName> hostname; - private final Optional<String> ipAddress; + private final Optional<String> ip4Address; + private final Optional<String> ip6Address; private final Optional<DnsZone> dnsZone; private final Set<Integer> ports; private final Set<String> networks; @@ -30,11 +31,12 @@ public class LoadBalancerInstance { private final List<PrivateServiceId> serviceIds; private final CloudAccount cloudAccount; - public LoadBalancerInstance(Optional<DomainName> hostname, Optional<String> ipAddress, + public LoadBalancerInstance(Optional<DomainName> hostname, Optional<String> ip4Address, Optional<String> ip6Address, Optional<DnsZone> dnsZone, Set<Integer> ports, Set<String> networks, Set<Real> reals, ZoneEndpoint settings, List<PrivateServiceId> serviceIds, CloudAccount cloudAccount) { this.hostname = Objects.requireNonNull(hostname, "hostname must be non-null"); - this.ipAddress = Objects.requireNonNull(ipAddress, "ip must be non-null"); + this.ip4Address = Objects.requireNonNull(ip4Address, "ip4Address must be non-null"); + this.ip6Address = Objects.requireNonNull(ip6Address, "ip6Address must be non-null"); this.dnsZone = Objects.requireNonNull(dnsZone, "dnsZone must be non-null"); this.ports = ImmutableSortedSet.copyOf(requirePorts(ports)); this.networks = ImmutableSortedSet.copyOf(Objects.requireNonNull(networks, "networks must be non-null")); @@ -43,9 +45,9 @@ public class LoadBalancerInstance { this.serviceIds = List.copyOf(Objects.requireNonNull(serviceIds, "private service id must be non-null")); this.cloudAccount = Objects.requireNonNull(cloudAccount, "cloudAccount must be non-null"); - if (hostname.isEmpty() == ipAddress.isEmpty()) { - throw new IllegalArgumentException("Exactly 1 of hostname=%s and ipAddress=%s must be set".formatted( - hostname.map(DomainName::value).orElse("<empty>"), ipAddress.orElse("<empty>"))); + if (hostname.isEmpty() == ip4Address.isEmpty()) { + throw new IllegalArgumentException("Exactly 1 of hostname=%s and ip4Address=%s must be set".formatted( + hostname.map(DomainName::value).orElse("<empty>"), ip4Address.orElse("<empty>"))); } } @@ -54,9 +56,14 @@ public class LoadBalancerInstance { return hostname; } - /** IP address of this (public) load balancer */ - public Optional<String> ipAddress() { - return ipAddress; + /** IPv4 address of this (public) load balancer */ + public Optional<String> ip4Address() { + return ip4Address; + } + + /** IPv6 address of this (public) load balancer */ + public Optional<String> ip6Address() { + return ip6Address; } /** ID of the DNS zone associated with this */ @@ -114,7 +121,7 @@ public class LoadBalancerInstance { public LoadBalancerInstance with(Set<Real> reals, ZoneEndpoint settings, Optional<PrivateServiceId> serviceId) { List<PrivateServiceId> ids = new ArrayList<>(serviceIds); serviceId.filter(id -> ! ids.contains(id)).ifPresent(ids::add); - return new LoadBalancerInstance(hostname, ipAddress, dnsZone, ports, networks, + return new LoadBalancerInstance(hostname, ip4Address, ip6Address, dnsZone, ports, networks, reals, settings, ids, cloudAccount); } @@ -123,7 +130,7 @@ public class LoadBalancerInstance { public LoadBalancerInstance withServiceIds(List<PrivateServiceId> serviceIds) { List<PrivateServiceId> ids = new ArrayList<>(serviceIds); for (PrivateServiceId id : this.serviceIds) if ( ! ids.contains(id)) ids.add(id); - return new LoadBalancerInstance(hostname, ipAddress, dnsZone, ports, networks, reals, settings, ids, cloudAccount); + return new LoadBalancerInstance(hostname, ip4Address, ip6Address, dnsZone, ports, networks, reals, settings, ids, cloudAccount); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java index a79766a577d..c79ccc2aece 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java @@ -57,6 +57,7 @@ public class LoadBalancerServiceMock implements LoadBalancerService { var instance = new LoadBalancerInstance( Optional.of(DomainName.of("lb-" + spec.application().toShortString() + "-" + spec.cluster().value())), Optional.empty(), + Optional.empty(), Optional.of(new DnsZone("zone-id-1")), Collections.singleton(4443), ImmutableSet.of("10.2.3.0/24", "10.4.5.0/24"), diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java index e49d1b302cf..073662b39fe 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java @@ -45,6 +45,7 @@ public class SharedLoadBalancerService implements LoadBalancerService { return new LoadBalancerInstance(Optional.of(DomainName.of(vipHostname)), Optional.empty(), Optional.empty(), + Optional.empty(), Set.of(4443), Set.of(), spec.reals(), diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java index 3c3868bfeb8..e4e08e5a15c 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java @@ -47,6 +47,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; import static com.yahoo.stream.CustomCollectors.toLinkedMap; import static java.util.stream.Collectors.collectingAndThen; @@ -222,7 +223,7 @@ public class CuratorDb { node.type(), node.reports(), node.modelName(), node.reservedTo(), node.exclusiveToApplicationId(), node.hostTTL(), node.hostEmptyAt(), node.exclusiveToClusterType(), node.switchHostname(), node.trustedCertificates(), - node.cloudAccount(), node.wireguardPubKey(), node.wireguardKeyTimestamp()); + node.cloudAccount(), node.wireguardPubKey()); curatorTransaction.add(createOrSet(nodePath(newNode), nodeSerializer.toJson(newNode))); writtenNodes.add(newNode); } @@ -456,7 +457,12 @@ public class CuratorDb { transaction.onCommitted(() -> { for (var lb : loadBalancers) { if (lb.state() == fromState) continue; - Optional<String> target = lb.instance().flatMap(instance -> instance.hostname().map(DomainName::value).or(instance::ipAddress)); + Optional<String> target = lb.instance() + .flatMap(instance -> instance.hostname() + .map(DomainName::value) + .or(() -> Optional.of(Stream.concat(instance.ip4Address().stream(), + instance.ip6Address().stream()) + .collect(Collectors.joining(","))))); if (fromState == null) { log.log(Level.INFO, () -> "Creating " + lb.id() + target.map(t -> " (" + t + ")").orElse("") + " in " + lb.state()); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java index b85d96c6b54..d329676f842 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java @@ -45,6 +45,7 @@ public class LoadBalancerSerializer { private static final String idField = "id"; private static final String hostnameField = "hostname"; private static final String lbIpAddressField = "ipAddress"; + private static final String lbIp6AddressField = "ip6Address"; private static final String stateField = "state"; private static final String changedAtField = "changedAt"; private static final String dnsZoneField = "dnsZone"; @@ -69,7 +70,8 @@ public class LoadBalancerSerializer { root.setString(idField, loadBalancer.id().serializedForm()); loadBalancer.instance().flatMap(LoadBalancerInstance::hostname).ifPresent(hostname -> root.setString(hostnameField, hostname.value())); - loadBalancer.instance().flatMap(LoadBalancerInstance::ipAddress).ifPresent(ip -> root.setString(lbIpAddressField, ip)); + loadBalancer.instance().flatMap(LoadBalancerInstance::ip4Address).ifPresent(ip -> root.setString(lbIpAddressField, ip)); + loadBalancer.instance().flatMap(LoadBalancerInstance::ip6Address).ifPresent(ip -> root.setString(lbIp6AddressField, ip)); root.setString(stateField, asString(loadBalancer.state())); root.setLong(changedAtField, loadBalancer.changedAt().toEpochMilli()); loadBalancer.instance().flatMap(LoadBalancerInstance::dnsZone).ifPresent(dnsZone -> root.setString(dnsZoneField, dnsZone.id())); @@ -123,7 +125,8 @@ public class LoadBalancerSerializer { object.field(networksField).traverse((ArrayTraverser) (i, network) -> networks.add(network.asString())); Optional<DomainName> hostname = optionalString(object.field(hostnameField), Function.identity()).filter(s -> !s.isEmpty()).map(DomainName::of); - Optional<String> ipAddress = optionalString(object.field(lbIpAddressField), Function.identity()).filter(s -> !s.isEmpty()); + Optional<String> ip4Address = optionalString(object.field(lbIpAddressField), Function.identity()).filter(s -> !s.isEmpty()); + Optional<String> ip6Address = optionalString(object.field(lbIp6AddressField), Function.identity()).filter(s -> !s.isEmpty()); Optional<DnsZone> dnsZone = optionalString(object.field(dnsZoneField), DnsZone::new); ZoneEndpoint settings = zoneEndpoint(object.field(settingsField)); Optional<PrivateServiceId> serviceId = optionalString(object.field(serviceIdField), PrivateServiceId::of); @@ -131,9 +134,9 @@ public class LoadBalancerSerializer { object.field(serviceIdsField).traverse((ArrayTraverser) (__, serviceIdObject) -> serviceIds.add(PrivateServiceId.of(serviceIdObject.asString()))); if (serviceIds.isEmpty()) serviceId.ifPresent(serviceIds::add); // TODO: remove after winter vacation '23 CloudAccount cloudAccount = optionalString(object.field(cloudAccountField), CloudAccount::from).orElse(CloudAccount.empty); - Optional<LoadBalancerInstance> instance = hostname.isEmpty() && ipAddress.isEmpty() + Optional<LoadBalancerInstance> instance = hostname.isEmpty() && ip4Address.isEmpty() && ip6Address.isEmpty() ? Optional.empty() - : Optional.of(new LoadBalancerInstance(hostname, ipAddress, dnsZone, ports, networks, reals, settings, serviceIds, cloudAccount)); + : Optional.of(new LoadBalancerInstance(hostname, ip4Address, ip6Address, dnsZone, ports, networks, reals, settings, serviceIds, cloudAccount)); return new LoadBalancer(LoadBalancerId.fromSerializedForm(object.field(idField).asString()), instance, diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java index 870e678a250..73531d650d5 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/NodeSerializer.java @@ -16,6 +16,7 @@ import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.WireguardKey; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.config.provision.host.FlavorOverrides; import com.yahoo.config.provision.serialization.NetworkPortsSerializer; import com.yahoo.slime.ArrayTraverser; @@ -188,8 +189,10 @@ public class NodeSerializer { if (!node.cloudAccount().isUnspecified()) { object.setString(cloudAccountKey, node.cloudAccount().value()); } - node.wireguardPubKey().ifPresent(pubKey -> object.setString(wireguardPubKeyKey, pubKey.value())); - node.wireguardKeyTimestamp().ifPresent(timestamp -> object.setLong(wireguardKeyTimestampKey, timestamp.toEpochMilli())); + node.wireguardPubKey().ifPresent(pubKey -> { + object.setString(wireguardPubKeyKey, pubKey.key().value()); + object.setLong(wireguardKeyTimestampKey, pubKey.timestamp().toEpochMilli()); + }); } private void toSlime(Flavor flavor, Cursor object) { @@ -284,8 +287,7 @@ public class NodeSerializer { SlimeUtils.optionalString(object.field(switchHostnameKey)), trustedCertificatesFromSlime(object), SlimeUtils.optionalString(object.field(cloudAccountKey)).map(CloudAccount::from).orElse(CloudAccount.empty), - SlimeUtils.optionalString(object.field(wireguardPubKeyKey)).map(WireguardKey::from), - SlimeUtils.optionalInstant(object.field(wireguardKeyTimestampKey))); + wireguardKeyWithTimestampFromSlime(object.field(wireguardPubKeyKey), object.field(wireguardKeyTimestampKey))); } private Status statusFromSlime(Inspector object) { @@ -397,6 +399,13 @@ public class NodeSerializer { .toList(); } + private Optional<WireguardKeyWithTimestamp> wireguardKeyWithTimestampFromSlime(Inspector keyObject, Inspector timestampObject) { + if ( ! keyObject.valid()) return Optional.empty(); + return SlimeUtils.optionalString(keyObject).map( + key -> new WireguardKeyWithTimestamp(WireguardKey.from(key), + SlimeUtils.optionalInstant(timestampObject).orElse(null))); + } + // ----------------- Enum <-> string mappings ---------------------------------------- /** Returns the event type, or null if this event type should be ignored */ diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/LoadBalancersResponse.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/LoadBalancersResponse.java index 09f947503f6..20aa7d8181e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/LoadBalancersResponse.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/LoadBalancersResponse.java @@ -57,7 +57,8 @@ public class LoadBalancersResponse extends SlimeJsonResponse { lbObject.setString("instance", lb.id().application().instance().value()); lbObject.setString("cluster", lb.id().cluster().value()); lb.instance().flatMap(LoadBalancerInstance::hostname).ifPresent(hostname -> lbObject.setString("hostname", hostname.value())); - lb.instance().flatMap(LoadBalancerInstance::ipAddress).ifPresent(ipAddress -> lbObject.setString("ipAddress", ipAddress)); + lb.instance().flatMap(LoadBalancerInstance::ip4Address).ifPresent(ip -> lbObject.setString("ipAddress", ip)); + lb.instance().flatMap(LoadBalancerInstance::ip6Address).ifPresent(ip -> lbObject.setString("ip6Address", ip)); lb.instance().flatMap(LoadBalancerInstance::dnsZone).ifPresent(dnsZone -> lbObject.setString("dnsZone", dnsZone.id())); Cursor networkArray = lbObject.setArray("networks"); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java index 9f1ab3dc3d5..cad034e01aa 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodePatcher.java @@ -11,6 +11,7 @@ import com.yahoo.config.provision.NodeFlavors; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.WireguardKey; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; import com.yahoo.slime.ObjectTraverser; @@ -108,7 +109,8 @@ public class NodePatcher { "reports", "trustStore", "vespaVersion", - "wireguardPubkey")); + "wireguardPubkey", // TODO wg: remove when all nodes use new key+timestamp format + "wireguard")); if (!disallowedFields.isEmpty()) { throw new IllegalArgumentException("Patching fields not supported: " + disallowedFields); } @@ -271,9 +273,13 @@ public class NodePatcher { return value.type() == Type.NIX ? node.withoutSwitchHostname() : node.withSwitchHostname(value.asString()); case "trustStore": return nodeWithTrustStore(node, value); - case "wireguardPubkey": - return node.withWireguardPubkey(SlimeUtils.optionalString(value).map(WireguardKey::new).orElse(null)) - .withWireguardKeyTimestamp(clock.instant()); + case "wireguard": + // This is where we set the key timestamp. + var key = SlimeUtils.optionalString(value.field("key")).map(WireguardKey::new).orElse(null); + return node.withWireguardPubkey(new WireguardKeyWithTimestamp(key, clock.instant())); + case "wireguardPubkey": // TODO wg: remove when all nodes use new key+timestamp format + var oldKey = SlimeUtils.optionalString(value).map(WireguardKey::new).orElse(null); + return node.withWireguardPubkey(new WireguardKeyWithTimestamp(oldKey, clock.instant())); default: throw new IllegalArgumentException("Could not apply field '" + name + "' on a node: No such modifiable field"); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodesResponse.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodesResponse.java index a8f526544d7..05bb0a27d69 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodesResponse.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodesResponse.java @@ -8,6 +8,7 @@ import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.DockerImage; import com.yahoo.config.provision.HostName; import com.yahoo.config.provision.NodeResources; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.config.provision.serialization.NetworkPortsSerializer; import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.restapi.SlimeJsonResponse; @@ -192,8 +193,13 @@ class NodesResponse extends SlimeJsonResponse { if (!node.cloudAccount().isUnspecified()) { object.setString("cloudAccount", node.cloudAccount().value()); } - node.wireguardPubKey().ifPresent(key -> object.setString("wireguardPubkey", key.value())); - node.wireguardKeyTimestamp().ifPresent(timestamp -> object.setLong("wireguardKeyTimestamp", timestamp.toEpochMilli())); + node.wireguardPubKey().ifPresent(key -> toSlime(key, object.setObject("wireguard"))); + + // TODO wg: remove when all nodes have upgraded to new key+timestamp format + node.wireguardPubKey().ifPresent(key -> { + object.setString("wireguardPubkey", key.key().value()); + object.setLong("wireguardKeyTimestamp", key.timestamp().toEpochMilli()); + }); } private Version resolveVersionFlag(StringFlag flag, Node node, Allocation allocation) { @@ -237,6 +243,11 @@ class NodesResponse extends SlimeJsonResponse { } } + static void toSlime(WireguardKeyWithTimestamp keyWithTimestamp, Cursor object) { + object.setString("key", keyWithTimestamp.key().value()); + object.setLong("timestamp", keyWithTimestamp.timestamp().toEpochMilli()); + } + private Optional<DockerImage> currentContainerImage(Node node) { if (node.status().containerImage().isPresent()) { return node.status().containerImage(); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/WireguardResponse.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/WireguardResponse.java index 16e85dfa48a..e29c4f1b87a 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/WireguardResponse.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/WireguardResponse.java @@ -1,7 +1,7 @@ package com.yahoo.vespa.hosted.provision.restapi; import com.yahoo.config.provision.NodeType; -import com.yahoo.config.provision.WireguardKey; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.restapi.SlimeJsonResponse; import com.yahoo.slime.Cursor; import com.yahoo.vespa.hosted.provision.Node; @@ -10,9 +10,9 @@ import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.node.IP; import java.net.InetAddress; -import java.time.Instant; import java.util.List; -import java.util.Optional; + +import static com.yahoo.vespa.hosted.provision.restapi.NodesResponse.toSlime; /** * A response containing the wireguard peer config for each configserver that has a public key. @@ -36,17 +36,20 @@ public class WireguardResponse extends SlimeJsonResponse { .toList(); if (ipAddresses.isEmpty()) continue; - addConfigserver(cfgArray.addObject(), cfg.hostname(), cfg.wireguardPubKey().get(), - cfg.wireguardKeyTimestamp(), ipAddresses); + addConfigserver(cfgArray.addObject(), cfg.hostname(), cfg.wireguardPubKey().get(), ipAddresses); } } - private void addConfigserver(Cursor cfgEntry, String hostname, WireguardKey key, Optional<Instant> keyTimestamp, + private void addConfigserver(Cursor cfgEntry, String hostname, WireguardKeyWithTimestamp keyWithTimestamp, List<String> ipAddresses) { cfgEntry.setString("hostname", hostname); - cfgEntry.setString("wireguardPubkey", key.value()); - cfgEntry.setLong("wireguardKeyTimestamp", keyTimestamp.orElse(Instant.EPOCH).toEpochMilli()); + + // TODO wg: remove when all nodes are using new key+timestamp format + cfgEntry.setString("wireguardPubkey", keyWithTimestamp.key().value()); + cfgEntry.setLong("wireguardKeyTimestamp", keyWithTimestamp.timestamp().toEpochMilli()); + NodesResponse.ipAddressesToSlime(ipAddresses, cfgEntry.setArray("ipAddresses")); + toSlime(keyWithTimestamp, cfgEntry.setObject("wireguard")); } private static boolean isPublicIp(String ipAddress) { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java index 72225763381..2fb549acc11 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/testutils/MockNodeRepository.java @@ -21,6 +21,7 @@ import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.WireguardKey; +import com.yahoo.config.provision.WireguardKeyWithTimestamp; import com.yahoo.config.provision.Zone; import com.yahoo.config.provision.ZoneEndpoint; import com.yahoo.config.provision.ZoneEndpoint.AccessType; @@ -161,8 +162,8 @@ public class MockNodeRepository extends NodeRepository { // Emulate host in tenant account nodes.add(Node.create("dockerhost2", ipConfig(101, 1, 3), "dockerhost2.yahoo.com", flavors.getFlavorOrThrow("large"), NodeType.host) - .wireguardPubKey(WireguardKey.from("000011112222333344445555666677778888999900c=")) - .wireguardKeyTimestamp(Instant.ofEpochMilli(123L)) + .wireguardKey(new WireguardKeyWithTimestamp(WireguardKey.from("000011112222333344445555666677778888999900c="), + Instant.ofEpochMilli(123L))) .cloudAccount(tenantAccount).build()); nodes.add(Node.create("dockerhost3", ipConfig(102, 1, 3), "dockerhost3.yahoo.com", flavors.getFlavorOrThrow("large"), NodeType.host).cloudAccount(defaultCloudAccount).build()); @@ -176,8 +177,8 @@ public class MockNodeRepository extends NodeRepository { // Config servers nodes.add(Node.create("cfg1", ipConfig(201), "cfg1.yahoo.com", flavors.getFlavorOrThrow("default"), NodeType.config) .cloudAccount(defaultCloudAccount) - .wireguardPubKey(WireguardKey.from("lololololololololololololololololololololoo=")) - .wireguardKeyTimestamp(Instant.ofEpochMilli(456L)) + .wireguardKey(new WireguardKeyWithTimestamp(WireguardKey.from("lololololololololololololololololololololoo="), + Instant.ofEpochMilli(456L))) .build()); nodes.add(Node.create("cfg2", ipConfig(202), "cfg2.yahoo.com", flavors.getFlavorOrThrow("default"), NodeType.config) .cloudAccount(defaultCloudAccount) diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java index 6dc681ae5c8..b5257e23d9e 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java @@ -40,6 +40,7 @@ public class LoadBalancerSerializerTest { Optional.of(new LoadBalancerInstance( Optional.of(DomainName.of("lb-host")), Optional.empty(), + Optional.empty(), Optional.of(new DnsZone("zone-id-1")), Set.of(4080, 4443), Set.of("10.2.3.4/24"), @@ -73,6 +74,7 @@ public class LoadBalancerSerializerTest { Optional.of(new LoadBalancerInstance( Optional.empty(), Optional.of("1.2.3.4"), + Optional.of("fd00::1"), Optional.of(new DnsZone("zone-id-1")), Set.of(4443), Set.of("10.2.3.4/24", "12.3.2.1/30"), @@ -86,6 +88,8 @@ public class LoadBalancerSerializerTest { var serialized = LoadBalancerSerializer.fromJson(LoadBalancerSerializer.toJson(loadBalancer)); assertEquals(loadBalancer.id(), serialized.id()); assertEquals(loadBalancer.instance().get().hostname(), serialized.instance().get().hostname()); + assertEquals(loadBalancer.instance().get().ip4Address(), serialized.instance().get().ip4Address()); + assertEquals(loadBalancer.instance().get().ip6Address(), serialized.instance().get().ip6Address()); assertEquals(loadBalancer.instance().get().dnsZone(), serialized.instance().get().dnsZone()); assertEquals(loadBalancer.instance().get().ports(), serialized.instance().get().ports()); assertEquals(loadBalancer.instance().get().networks(), serialized.instance().get().networks()); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/cfg1.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/cfg1.json index 928e91861a2..54a0e7e9757 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/cfg1.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/cfg1.json @@ -119,6 +119,10 @@ ], "additionalIpAddresses": [], "cloudAccount": "aws:111222333444", - "wireguardPubkey":"lololololololololololololololololololololoo=", + "wireguard": { + "key": "lololololololololololololololololololololoo=", + "timestamp": 456 + }, + "wireguardPubkey": "lololololololololololololololololololololoo=", "wireguardKeyTimestamp": 456 } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/docker-node2.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/docker-node2.json index 72b5483d849..d3f1a8082ae 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/docker-node2.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/docker-node2.json @@ -117,6 +117,10 @@ "ipAddresses": ["127.0.101.1", "::101:1"], "additionalIpAddresses": ["::101:2", "::101:3", "::101:4"], "cloudAccount": "aws:777888999000", + "wireguard": { + "key": "000011112222333344445555666677778888999900c=", + "timestamp": 123 + }, "wireguardPubkey": "000011112222333344445555666677778888999900c=", "wireguardKeyTimestamp": 123 } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/node4-wg.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/node4-wg.json index d0d6df71fc1..404cf9a9a80 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/node4-wg.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/node4-wg.json @@ -118,6 +118,10 @@ "ipAddresses": ["127.0.4.1", "::4:1"], "additionalIpAddresses": [], "cloudAccount": "aws:111222333444", + "wireguard": { + "key": "lololololololololololololololololololololoo=", + "timestamp": 123 + }, "wireguardPubkey": "lololololololololololololololololololololoo=", "wireguardKeyTimestamp": 123 } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/wireguard.json b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/wireguard.json index 7bee06adc87..8e9af7f680f 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/wireguard.json +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/restapi/responses/wireguard.json @@ -4,7 +4,11 @@ "hostname": "cfg1.yahoo.com", "wireguardPubkey": "lololololololololololololololololololololoo=", "wireguardKeyTimestamp":456, - "ipAddresses": ["::201:1"] + "ipAddresses": ["::201:1"], + "wireguard": { + "key": "lololololololololololololololololololololoo=", + "timestamp": 456 + } } ] } diff --git a/parent/pom.xml b/parent/pom.xml index 50222402b68..4a84496d9da 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -317,7 +317,7 @@ --> <groupId>org.openrewrite.maven</groupId> <artifactId>rewrite-maven-plugin</artifactId> - <version>5.5.2</version> + <version>5.6.0</version> <configuration> <activeRecipes> <recipe>org.openrewrite.java.testing.junit5.JUnit5BestPractices</recipe> diff --git a/searchcore/src/tests/proton/matching/query_test.cpp b/searchcore/src/tests/proton/matching/query_test.cpp index 575a52d01fb..d098bdde8b6 100644 --- a/searchcore/src/tests/proton/matching/query_test.cpp +++ b/searchcore/src/tests/proton/matching/query_test.cpp @@ -711,7 +711,7 @@ void Test::requireThatQueryGluesEverythingTogether() { EXPECT_EQUAL(1u, md->getNumTermFields()); query.optimize(); - query.fetchPostings(); + query.fetchPostings(requestContext.getDoom()); SearchIterator::UP search = query.createSearch(*md); ASSERT_TRUE(search.get()); } @@ -744,7 +744,7 @@ void checkQueryAddsLocation(const string &loc_in, const string &loc_out) { MatchData::UP md = mdl.createMatchData(); EXPECT_EQUAL(2u, md->getNumTermFields()); - query.fetchPostings(); + query.fetchPostings(requestContext.getDoom()); SearchIterator::UP search = query.createSearch(*md); ASSERT_TRUE(search.get()); if (!EXPECT_NOT_EQUAL(string::npos, search->asString().find(loc_out))) { @@ -966,7 +966,7 @@ Test::requireThatWhiteListBlueprintCanBeUsed() MatchData::UP md = mdl.createMatchData(); query.optimize(); - query.fetchPostings(); + query.fetchPostings(requestContext.getDoom()); SearchIterator::UP search = query.createSearch(*md); SimpleResult exp = SimpleResult().addHit(1).addHit(5).addHit(7).addHit(11); SimpleResult act; diff --git a/searchcore/src/vespa/searchcore/proton/docsummary/docsumcontext.cpp b/searchcore/src/vespa/searchcore/proton/docsummary/docsumcontext.cpp index be1c8941f65..e1820ece0e3 100644 --- a/searchcore/src/vespa/searchcore/proton/docsummary/docsumcontext.cpp +++ b/searchcore/src/vespa/searchcore/proton/docsummary/docsumcontext.cpp @@ -52,22 +52,11 @@ DocsumContext::initState() _docsumState._args.initFromDocsumRequest(req); _docsumState._docsumbuf.clear(); _docsumState._docsumbuf.reserve(req.hits.size()); - for (uint32_t i = 0; i < req.hits.size(); i++) { - _docsumState._docsumbuf.push_back(req.hits[i].docid); + for (const auto & hit : req.hits) { + _docsumState._docsumbuf.push_back(hit.docid); } } -namespace { - -vespalib::Slime::Params -makeSlimeParams(size_t chunkSize) { - Slime::Params params; - params.setChunkSize(chunkSize); - return params; -} - -} - vespalib::Slime::UP DocsumContext::createSlimeReply() { @@ -75,11 +64,11 @@ DocsumContext::createSlimeReply() _docsumState._args.get_fields()); _docsumWriter.initState(_attrMgr, _docsumState, rci); const size_t estimatedChunkSize(std::min(0x200000ul, _docsumState._docsumbuf.size()*0x400ul)); - vespalib::Slime::UP response(std::make_unique<vespalib::Slime>(makeSlimeParams(estimatedChunkSize))); + auto response = std::make_unique<vespalib::Slime>(Slime::Params(estimatedChunkSize)); Cursor & root = response->setObject(); Cursor & array = root.setArray(DOCSUMS); const Symbol docsumSym = response->insert(DOCSUM); - _docsumState._omit_summary_features = (rci.res_class != nullptr) ? rci.res_class->omit_summary_features() : true; + _docsumState._omit_summary_features = (rci.res_class == nullptr) || rci.res_class->omit_summary_features(); uint32_t num_ok(0); for (uint32_t docId : _docsumState._docsumbuf) { if (_request.expired() ) { break; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/attribute_limiter.cpp b/searchcore/src/vespa/searchcore/proton/matching/attribute_limiter.cpp index 1528b327747..b8027bff04a 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/attribute_limiter.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/attribute_limiter.cpp @@ -6,6 +6,7 @@ #include <vespa/searchlib/fef/matchdatalayout.h> #include <vespa/searchlib/queryeval/searchable.h> #include <vespa/searchlib/queryeval/blueprint.h> +#include <vespa/searchlib/queryeval/irequestcontext.h> #include <vespa/searchlib/query/tree/range.h> #include <vespa/searchlib/query/tree/simplequery.h> @@ -98,7 +99,7 @@ AttributeLimiter::create_search(size_t want_hits, size_t max_group_size, bool st FieldSpecList field; // single field API is protected field.add(FieldSpec(_attribute_name, my_field_id, my_handle)); _blueprint = _searchable_attributes.createBlueprint(_requestContext, field, node); - _blueprint->fetchPostings(ExecuteInfo::create(strictSearch)); + _blueprint->fetchPostings(ExecuteInfo::create(strictSearch, &_requestContext.getDoom())); _estimatedHits.store(_blueprint->getState().estimate().estHits, std::memory_order_relaxed); _blueprint->freeze(); } diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp b/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp index 5ae671b88cb..758ef35ebc9 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp @@ -201,9 +201,9 @@ MatchToolsFactory(QueryLimiter & queryLimiter, trace.addEvent(5, "Optimize query execution plan"); _query.optimize(); trace.addEvent(4, "Perform dictionary lookups and posting lists initialization"); - _query.fetchPostings(); + _query.fetchPostings(_requestContext.getDoom()); if (is_search) { - _query.handle_global_filter(searchContext.getDocIdLimit(), + _query.handle_global_filter(_requestContext.getDoom(), searchContext.getDocIdLimit(), _attribute_blueprint_params.global_filter_lower_limit, _attribute_blueprint_params.global_filter_upper_limit, thread_bundle, trace); diff --git a/searchcore/src/vespa/searchcore/proton/matching/query.cpp b/searchcore/src/vespa/searchcore/proton/matching/query.cpp index d0738f1857f..22f6ec9cc88 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/query.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/query.cpp @@ -247,13 +247,14 @@ Query::optimize() } void -Query::fetchPostings() +Query::fetchPostings(const vespalib::Doom & doom) { - _blueprint->fetchPostings(search::queryeval::ExecuteInfo::create(true, 1.0)); + _blueprint->fetchPostings(search::queryeval::ExecuteInfo::create(true, &doom)); } void -Query::handle_global_filter(uint32_t docid_limit, double global_filter_lower_limit, double global_filter_upper_limit, +Query::handle_global_filter(const vespalib::Doom & doom, uint32_t docid_limit, + double global_filter_lower_limit, double global_filter_upper_limit, vespalib::ThreadBundle &thread_bundle, search::engine::Trace& trace) { if (!handle_global_filter(*_blueprint, docid_limit, global_filter_lower_limit, global_filter_upper_limit, thread_bundle, &trace)) { @@ -264,7 +265,7 @@ Query::handle_global_filter(uint32_t docid_limit, double global_filter_lower_lim _blueprint = Blueprint::optimize(std::move(_blueprint)); LOG(debug, "blueprint after handle_global_filter:\n%s\n", _blueprint->asString().c_str()); // strictness may change if optimized order changed: - fetchPostings(); + fetchPostings(doom); } bool diff --git a/searchcore/src/vespa/searchcore/proton/matching/query.h b/searchcore/src/vespa/searchcore/proton/matching/query.h index b0299307e92..1a3136042a7 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/query.h +++ b/searchcore/src/vespa/searchcore/proton/matching/query.h @@ -98,9 +98,10 @@ public: * test to verify the original query without optimization. **/ void optimize(); - void fetchPostings(); + void fetchPostings(const vespalib::Doom & doom); - void handle_global_filter(uint32_t docid_limit, double global_filter_lower_limit, double global_filter_upper_limit, + void handle_global_filter(const vespalib::Doom & doom, uint32_t docid_limit, + double global_filter_lower_limit, double global_filter_upper_limit, vespalib::ThreadBundle &thread_bundle, search::engine::Trace& trace); /** diff --git a/searchlib/src/tests/attribute/dfa_fuzzy_matcher/dfa_fuzzy_matcher_test.cpp b/searchlib/src/tests/attribute/dfa_fuzzy_matcher/dfa_fuzzy_matcher_test.cpp index c7bd0e917f3..c2a39779061 100644 --- a/searchlib/src/tests/attribute/dfa_fuzzy_matcher/dfa_fuzzy_matcher_test.cpp +++ b/searchlib/src/tests/attribute/dfa_fuzzy_matcher/dfa_fuzzy_matcher_test.cpp @@ -8,6 +8,7 @@ #include <vespa/vespalib/fuzzy/levenshtein_dfa.h> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/util/time.h> +#include <vespa/vespalib/text/utf8.h> #include <filesystem> #include <fstream> #include <iostream> @@ -26,13 +27,24 @@ using namespace search::attribute; using namespace search; using vespalib::FuzzyMatcher; using vespalib::datastore::AtomicEntryRef; +using vespalib::datastore::EntryRef; using vespalib::fuzzy::LevenshteinDfa; +using vespalib::Utf8Reader; +using vespalib::Utf8Writer; using StringEnumStore = EnumStoreT<const char*>; using DictionaryEntry = std::pair<std::string, size_t>; using RawDictionary = std::vector<DictionaryEntry>; using StringVector = std::vector<std::string>; +namespace { + +const char* char_from_u8(const char8_t* p) { + return reinterpret_cast<const char*>(p); +} + +} + RawDictionary read_dictionary() { @@ -109,11 +121,11 @@ struct MatchStats { template <bool collect_matches> void -brute_force_fuzzy_match_in_dictionary(std::string_view target, const StringEnumStore& store, MatchStats& stats, StringVector& matched_words) +brute_force_fuzzy_match_in_dictionary(std::string_view target, const StringEnumStore& store, uint32_t prefix_size, bool cased, MatchStats& stats, StringVector& matched_words) { auto view = store.get_dictionary().get_posting_dictionary().getFrozenView(); vespalib::Timer timer; - FuzzyMatcher matcher(target, 2, 0, false); + FuzzyMatcher matcher(target, 2, prefix_size, cased); auto itr = view.begin(); size_t matches = 0; size_t seeks = 0; @@ -133,15 +145,33 @@ brute_force_fuzzy_match_in_dictionary(std::string_view target, const StringEnumS template <bool collect_matches> void -dfa_fuzzy_match_in_dictionary(std::string_view target, const StringEnumStore& store, MatchStats& stats, StringVector& matched_words) +dfa_fuzzy_match_in_dictionary(std::string_view target, const StringEnumStore& store, uint32_t prefix_size, bool cased, MatchStats& stats, StringVector& matched_words) { auto view = store.get_dictionary().get_posting_dictionary().getFrozenView(); vespalib::Timer timer; - DfaFuzzyMatcher matcher(target, 2, false, LevenshteinDfa::DfaType::Explicit); - auto itr = view.begin(); + DfaFuzzyMatcher matcher(target, 2, prefix_size, cased, LevenshteinDfa::DfaType::Explicit); + Utf8Reader reader(vespalib::stringref(target.data(), target.size())); + std::string target_copy; + Utf8Writer<std::string> writer(target_copy); + for (size_t pos = 0; pos < prefix_size && reader.hasMore(); ++pos) { + auto code_point = reader.getChar(); + writer.putChar(code_point); + } + auto prefix_cmp = store.make_folded_comparator_prefix(target_copy.c_str()); + auto itr = prefix_size > 0 ? view.lowerBound(AtomicEntryRef(), prefix_cmp) : view.begin(); + auto itr_end = itr; + if (itr_end.valid()) { + if (prefix_size > 0) { + if (!prefix_cmp.less(EntryRef(), itr_end.getKey().load_relaxed())) { + itr_end.seekPast(AtomicEntryRef(), prefix_cmp); + } + } else { + itr_end.end(); + } + } size_t matches = 0; size_t seeks = 0; - while (itr.valid()) { + while (itr != itr_end) { auto word = store.get_value(itr.getKey().load_relaxed()); if (matcher.is_match(word, itr, store.get_data_store())) { ++itr; @@ -156,10 +186,58 @@ dfa_fuzzy_match_in_dictionary(std::string_view target, const StringEnumStore& st stats.add_sample(matches, seeks, timer.elapsed()); } -struct DfaFuzzyMatcherTest : public ::testing::Test { +template <bool collect_matches> +void +dfa_fuzzy_match_in_dictionary_no_skip(std::string_view target, const StringEnumStore& store, uint32_t prefix_size, bool cased, MatchStats& stats, StringVector& matched_words) +{ + auto view = store.get_dictionary().get_posting_dictionary().getFrozenView(); + vespalib::Timer timer; + DfaFuzzyMatcher matcher(target, 2, prefix_size, cased, LevenshteinDfa::DfaType::Explicit); + auto itr = view.begin(); + size_t matches = 0; + size_t seeks = 0; + for (;itr.valid(); ++itr) { + auto word = store.get_value(itr.getKey().load_relaxed()); + if (matcher.is_match(word)) { + ++matches; + if (collect_matches) { + matched_words.push_back(word); + } + } else { + ++seeks; + } + } + stats.add_sample(matches, seeks, timer.elapsed()); +} + +struct TestParam +{ + vespalib::string _name; + bool _cased; + + TestParam(vespalib::string name, bool cased) + : _name(std::move(name)), + _cased(cased) + { + } + TestParam(const TestParam&); + ~TestParam(); +}; + +TestParam::TestParam(const TestParam&) = default; + +TestParam::~TestParam() = default; + +std::ostream& operator<<(std::ostream& os, const TestParam& param) +{ + os << param._name; + return os; +} + +struct DfaFuzzyMatcherTest : public ::testing::TestWithParam<TestParam> { StringEnumStore store; DfaFuzzyMatcherTest() - : store(true, DictionaryConfig(DictionaryConfig::Type::BTREE, DictionaryConfig::Match::UNCASED)) + : store(true, DictionaryConfig(DictionaryConfig::Type::BTREE, GetParam()._cased ? DictionaryConfig::Match::CASED : DictionaryConfig::Match::UNCASED)) {} void populate_dictionary(const StringVector& words) { auto updater = store.make_batch_updater(); @@ -170,18 +248,31 @@ struct DfaFuzzyMatcherTest : public ::testing::Test { updater.commit(); store.freeze_dictionary(); } - void expect_matches(std::string_view target, const StringVector& exp_matches) { + void expect_prefix_matches(std::string_view target, uint32_t prefix_size, const StringVector& exp_matches) { MatchStats stats; StringVector brute_force_matches; StringVector dfa_matches; - brute_force_fuzzy_match_in_dictionary<true>(target, store, stats, brute_force_matches); - dfa_fuzzy_match_in_dictionary<true>(target, store, stats, dfa_matches); + StringVector dfa_no_skip_matches; + bool cased = GetParam()._cased; + SCOPED_TRACE(target); + brute_force_fuzzy_match_in_dictionary<true>(target, store, prefix_size, cased, stats, brute_force_matches); + dfa_fuzzy_match_in_dictionary<true>(target, store, prefix_size, cased, stats, dfa_matches); + dfa_fuzzy_match_in_dictionary_no_skip<true>(target, store, prefix_size, cased, stats, dfa_no_skip_matches); EXPECT_EQ(exp_matches, brute_force_matches); EXPECT_EQ(exp_matches, dfa_matches); + EXPECT_EQ(exp_matches, dfa_no_skip_matches); + } + void expect_matches(std::string_view target, const StringVector& exp_matches) { + expect_prefix_matches(target, 0, exp_matches); } }; -TEST_F(DfaFuzzyMatcherTest, fuzzy_match_in_dictionary) +INSTANTIATE_TEST_SUITE_P(DfaFuzzyMatcherMultiTest, + DfaFuzzyMatcherTest, + testing::Values(TestParam("uncased", false), TestParam("cased", true)), + testing::PrintToStringParamName()); + +TEST_P(DfaFuzzyMatcherTest, fuzzy_match_in_dictionary) { StringVector words = { "board", "boat", "bob", "door", "food", "foot", "football", "foothill", "for", "forbid", "force", "ford", "forearm", "forecast", "forest" }; @@ -194,23 +285,67 @@ TEST_F(DfaFuzzyMatcherTest, fuzzy_match_in_dictionary) expect_matches("forcecast", {"forecast"}); } +TEST_P(DfaFuzzyMatcherTest, fuzzy_match_in_dictionary_with_prefix_size) +{ + bool cased = GetParam()._cased; + StringVector words = { "board", "boat", "bob", "door", "food", "foot", "football", "foothill", + "for", "forbid", "force", "ford", "forearm", "forecast", "forest", "H", "HA", "h", "ha", char_from_u8(u8"Ørn"), char_from_u8(u8"øre"), char_from_u8(u8"Ås"), char_from_u8(u8"ås")}; + populate_dictionary(words); + expect_prefix_matches("a", 1, {}); + expect_prefix_matches("b", 1, {"bob"}); + expect_prefix_matches("board", 1, {"board", "boat"}); + expect_prefix_matches("c", 1, {}); + expect_prefix_matches("food", 1, {"food", "foot", "for", "ford"}); + expect_prefix_matches("food", 2, {"food", "foot", "for", "ford"}); + expect_prefix_matches("food", 3, {"food", "foot"}); + expect_prefix_matches("foothill", 1, {"football", "foothill"}); + expect_prefix_matches("for", 1, {"food", "foot", "for", "force", "ford"}); + expect_prefix_matches("for", 2, {"food", "foot", "for", "force", "ford"}); + expect_prefix_matches("for", 3, {"for", "force", "ford"}); + expect_prefix_matches("force", 1, {"for", "force", "ford"}); + expect_prefix_matches("forcecast", 1, {"forecast"}); + expect_prefix_matches("forcecast", 4, {}); + expect_prefix_matches("z", 1, {}); + if (cased) { + expect_prefix_matches("h", 1, {"h", "ha"}); + expect_prefix_matches(char_from_u8(u8"Ø"), 1, {char_from_u8(u8"Ørn")}); + expect_prefix_matches(char_from_u8(u8"ø"), 1, {char_from_u8(u8"øre")}); + expect_prefix_matches(char_from_u8(u8"å"), 1, {char_from_u8(u8"ås")}); + /* Corner case: prefix length > target length means exact match */ + expect_prefix_matches("h", 2, {"h"}); + } else { + expect_prefix_matches("h", 1, {"H", "h", "HA", "ha"}); + expect_prefix_matches(char_from_u8(u8"ø"), 1, {char_from_u8(u8"øre"), char_from_u8(u8"Ørn")}); + expect_prefix_matches(char_from_u8(u8"å"), 1, {char_from_u8(u8"Ås"), char_from_u8(u8"ås")}); + /* Corner case: prefix length > target length means exact match */ + expect_prefix_matches("h", 2, {"H", "h"}); + } +} + void -benchmark_fuzzy_match_in_dictionary(const StringEnumStore& store, const RawDictionary& dict, size_t words_to_match, bool dfa_algorithm) +benchmark_fuzzy_match_in_dictionary(const StringEnumStore& store, const RawDictionary& dict, size_t words_to_match, bool cased, bool dfa_algorithm) { MatchStats stats; StringVector dummy; for (size_t i = 0; i < std::min(words_to_match, dict.size()); ++i) { const auto& entry = dict[i]; if (dfa_algorithm) { - dfa_fuzzy_match_in_dictionary<false>(entry.first, store, stats, dummy); + dfa_fuzzy_match_in_dictionary<false>(entry.first, store, 0, cased, stats, dummy); } else { - brute_force_fuzzy_match_in_dictionary<false>(entry.first, store, stats, dummy); + brute_force_fuzzy_match_in_dictionary<false>(entry.first, store, 0, cased, stats, dummy); } } std::cout << (dfa_algorithm ? "DFA:" : "Brute force:") << " samples=" << stats.samples << ", avg_matches=" << stats.avg_matches() << ", avg_seeks=" << stats.avg_seeks() << ", avg_elapsed_ms=" << stats.avg_elapsed_ms() << std::endl; } -TEST_F(DfaFuzzyMatcherTest, benchmark_fuzzy_match_in_dictionary) +using DfaFuzzyMatcherBenchmarkTest = DfaFuzzyMatcherTest; + +INSTANTIATE_TEST_SUITE_P(DfaFuzzyMatcherBenchmarkMultiTest, + DfaFuzzyMatcherBenchmarkTest, + testing::Values(TestParam("uncased", false)), + testing::PrintToStringParamName()); + +TEST_P(DfaFuzzyMatcherBenchmarkTest, benchmark_fuzzy_match_in_dictionary) { if (!benchmarking_enabled()) { GTEST_SKIP() << "benchmarking not enabled"; @@ -219,8 +354,9 @@ TEST_F(DfaFuzzyMatcherTest, benchmark_fuzzy_match_in_dictionary) populate_dictionary(to_string_vector(dict)); std::cout << "Unique words: " << store.get_num_uniques() << std::endl; sort_by_freq(dict); - benchmark_fuzzy_match_in_dictionary(store, dict, dfa_words_to_match, true); - benchmark_fuzzy_match_in_dictionary(store, dict, brute_force_words_to_match, false); + bool cased = GetParam()._cased; + benchmark_fuzzy_match_in_dictionary(store, dict, dfa_words_to_match, cased, true); + benchmark_fuzzy_match_in_dictionary(store, dict, brute_force_words_to_match, cased, false); } int diff --git a/searchlib/src/tests/attribute/document_weight_or_filter_search/document_weight_or_filter_search_test.cpp b/searchlib/src/tests/attribute/document_weight_or_filter_search/document_weight_or_filter_search_test.cpp index b9c70d76934..1fd9dde09c7 100644 --- a/searchlib/src/tests/attribute/document_weight_or_filter_search/document_weight_or_filter_search_test.cpp +++ b/searchlib/src/tests/attribute/document_weight_or_filter_search/document_weight_or_filter_search_test.cpp @@ -24,14 +24,14 @@ class DocumentWeightOrFilterSearchTest : public ::testing::Test { uint32_t _range_end; public: DocumentWeightOrFilterSearchTest(); - ~DocumentWeightOrFilterSearchTest(); + ~DocumentWeightOrFilterSearchTest() override; void inc_generation(); size_t num_trees() const { return _trees.size(); } Iterator get_tree(size_t idx) const { if (idx < _trees.size()) { return _postings.beginFrozen(_trees[idx]); } else { - return Iterator(); + return {}; } } void ensure_tree(size_t idx) { @@ -39,13 +39,13 @@ public: _trees.resize(idx + 1); } } - void add_tree(size_t idx, std::vector<uint32_t> keys) { + void add_tree(size_t idx, const std::vector<uint32_t>& keys) { ensure_tree(idx); std::vector<KeyData> adds; std::vector<uint32_t> removes; adds.reserve(keys.size()); for (auto& key : keys) { - adds.emplace_back(KeyData(key, 1)); + adds.emplace_back(key, 1); } _postings.apply(_trees[idx], adds.data(), adds.data() + adds.size(), removes.data(), removes.data() + removes.size()); } @@ -67,7 +67,7 @@ public: return result; }; - std::vector<uint32_t> eval_daat(SearchIterator &iterator) { + std::vector<uint32_t> eval_daat(SearchIterator &iterator) const { std::vector<uint32_t> result; uint32_t doc_id = _range_start; while (doc_id < _range_end) { @@ -81,7 +81,7 @@ public: return result; } - std::vector<uint32_t> frombv(const BitVector &bv) { + std::vector<uint32_t> frombv(const BitVector &bv) const { std::vector<uint32_t> result; uint32_t doc_id = _range_start; doc_id = bv.getNextTrueBit(doc_id); @@ -93,7 +93,7 @@ public: return result; } - std::unique_ptr<BitVector> tobv(std::vector<uint32_t> values) { + std::unique_ptr<BitVector> tobv(const std::vector<uint32_t> & values) const { auto bv = BitVector::create(_range_start, _range_end); for (auto value : values) { bv->setBit(value); @@ -102,7 +102,7 @@ public: return bv; } - void expect_result(std::vector<uint32_t> exp, std::vector<uint32_t> act) + static void expect_result(const std::vector<uint32_t> & exp, const std::vector<uint32_t> & act) { EXPECT_EQ(exp, act); } @@ -227,7 +227,7 @@ public: } _test.inc_generation(); } - ~Verifier() { + ~Verifier() override { for (uint32_t tree_id = 0; tree_id < _test.num_trees(); ++tree_id) { _test.clear_tree(tree_id); } diff --git a/searchlib/src/tests/attribute/searchable/attribute_weighted_set_blueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attribute_weighted_set_blueprint_test.cpp index d7a854e0afc..6c6f05fd5e2 100644 --- a/searchlib/src/tests/attribute/searchable/attribute_weighted_set_blueprint_test.cpp +++ b/searchlib/src/tests/attribute/searchable/attribute_weighted_set_blueprint_test.cpp @@ -29,14 +29,14 @@ using namespace search::attribute::test; namespace { void -setupAttributeManager(MockAttributeManager &manager) +setupAttributeManager(MockAttributeManager &manager, bool isFilter) { AttributeVector::DocId docId; { - AttributeVector::SP attr_sp = AttributeFactory::createAttribute("integer", Config(BasicType("int64"))); + AttributeVector::SP attr_sp = AttributeFactory::createAttribute("integer", Config(BasicType("int64")).setIsFilter(isFilter)); manager.addAttribute(attr_sp); - IntegerAttribute *attr = (IntegerAttribute*)(attr_sp.get()); + auto *attr = (IntegerAttribute*)(attr_sp.get()); for (size_t i = 1; i < 10; ++i) { attr->addDoc(docId); assert(i == docId); @@ -45,10 +45,10 @@ setupAttributeManager(MockAttributeManager &manager) } } { - AttributeVector::SP attr_sp = AttributeFactory::createAttribute("string", Config(BasicType("string"))); + AttributeVector::SP attr_sp = AttributeFactory::createAttribute("string", Config(BasicType("string")).setIsFilter(isFilter)); manager.addAttribute(attr_sp); - StringAttribute *attr = (StringAttribute*)(attr_sp.get()); + auto *attr = (StringAttribute*)(attr_sp.get()); for (size_t i = 1; i < 10; ++i) { attr->addDoc(docId); assert(i == docId); @@ -58,9 +58,9 @@ setupAttributeManager(MockAttributeManager &manager) } { AttributeVector::SP attr_sp = AttributeFactory::createAttribute( - "multi", Config(BasicType("int64"), search::attribute::CollectionType("array"))); + "multi", Config(BasicType("int64"), search::attribute::CollectionType("array")).setIsFilter(isFilter)); manager.addAttribute(attr_sp); - IntegerAttribute *attr = (IntegerAttribute*)(attr_sp.get()); + auto *attr = (IntegerAttribute*)(attr_sp.get()); for (size_t i = 1; i < 10; ++i) { attr->addDoc(docId); assert(i == docId); @@ -78,35 +78,43 @@ struct WS { TermFieldHandle handle; std::vector<std::pair<std::string, uint32_t> > tokens; - WS(IAttributeManager & manager) : attribute_manager(manager), layout(), handle(layout.allocTermField(fieldId)), tokens() { + explicit WS(IAttributeManager & manager) + : attribute_manager(manager), + layout(), handle(layout.allocTermField(fieldId)), + tokens() + { MatchData::UP tmp = layout.createMatchData(); ASSERT_TRUE(tmp->resolveTermField(handle)->getFieldId() == fieldId); } WS &add(const std::string &token, uint32_t weight) { - tokens.push_back(std::make_pair(token, weight)); + tokens.emplace_back(token, weight); return *this; } Node::UP createNode() const { - SimpleWeightedSetTerm *node = new SimpleWeightedSetTerm(tokens.size(), "view", 0, Weight(0)); - for (size_t i = 0; i < tokens.size(); ++i) { - node->addTerm(tokens[i].first, Weight(tokens[i].second)); + auto *node = new SimpleWeightedSetTerm(tokens.size(), "view", 0, Weight(0)); + for (const auto & token : tokens) { + node->addTerm(token.first, Weight(token.second)); } return Node::UP(node); } - bool isGenericSearch(Searchable &searchable, const std::string &field, bool strict) const { + SearchIterator::UP + createSearch(Searchable &searchable, const std::string &field, bool strict) const { AttributeContext ac(attribute_manager); FakeRequestContext requestContext(&ac); MatchData::UP md = layout.createMatchData(); Node::UP node = createNode(); FieldSpecList fields; - fields.add(FieldSpec(field, fieldId, handle)); + fields.add(FieldSpec(field, fieldId, handle, ac.getAttribute(field)->getIsFilter())); queryeval::Blueprint::UP bp = searchable.createBlueprint(requestContext, fields, *node); bp->fetchPostings(queryeval::ExecuteInfo::create(strict)); SearchIterator::UP sb = bp->createSearch(*md, strict); - return (dynamic_cast<WeightedSetTermSearch*>(sb.get()) != 0); + return sb; + } + bool isWeightedSetTermSearch(Searchable &searchable, const std::string &field, bool strict) const { + return dynamic_cast<WeightedSetTermSearch *>(createSearch(searchable, field, strict).get()) != nullptr; } FakeResult search(Searchable &searchable, const std::string &field, bool strict) const { @@ -140,23 +148,58 @@ struct WS { } // namespace <unnamed> +void test_tokens(bool isFilter, const std::vector<uint32_t> & docs) { + MockAttributeManager manager; + setupAttributeManager(manager, isFilter); + AttributeBlueprintFactory adapter; + + FakeResult expect = FakeResult(); + WS ws = WS(manager); + for (uint32_t doc : docs) { + auto docS = vespalib::stringify(doc); + int32_t weight = doc * 10; + expect.doc(doc).weight(weight).pos(0); + ws.add(docS, weight); + } + + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "integer", true)); + EXPECT_TRUE(!ws.isWeightedSetTermSearch(adapter, "integer", false)); + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "string", true)); + EXPECT_TRUE(!ws.isWeightedSetTermSearch(adapter, "string", false)); + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "multi", true)); + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "multi", false)); + + EXPECT_EQUAL(expect, ws.search(adapter, "integer", true)); + EXPECT_EQUAL(expect, ws.search(adapter, "integer", false)); + EXPECT_EQUAL(expect, ws.search(adapter, "string", true)); + EXPECT_EQUAL(expect, ws.search(adapter, "string", false)); + EXPECT_EQUAL(expect, ws.search(adapter, "multi", true)); + EXPECT_EQUAL(expect, ws.search(adapter, "multi", false)); +} TEST("attribute_weighted_set_test") { + test_tokens(false, {3, 5, 7}); + test_tokens(true, {3, 5, 7}); + test_tokens(false, {3}); +} + +TEST("attribute_weighted_set_single_token_filter_lifted_out") { MockAttributeManager manager; - setupAttributeManager(manager); + setupAttributeManager(manager, true); AttributeBlueprintFactory adapter; - FakeResult expect = FakeResult() - .doc(3).elem(0).weight(30).pos(0) - .doc(5).elem(0).weight(50).pos(0) - .doc(7).elem(0).weight(70).pos(0); - WS ws = WS(manager).add("7", 70).add("5", 50).add("3", 30); - - EXPECT_TRUE(ws.isGenericSearch(adapter, "integer", true)); - EXPECT_TRUE(!ws.isGenericSearch(adapter, "integer", false)); - EXPECT_TRUE(ws.isGenericSearch(adapter, "string", true)); - EXPECT_TRUE(!ws.isGenericSearch(adapter, "string", false)); - EXPECT_TRUE(ws.isGenericSearch(adapter, "multi", true)); - EXPECT_TRUE(ws.isGenericSearch(adapter, "multi", false)); + FakeResult expect = FakeResult().doc(3).elem(0).weight(30).pos(0); + WS ws = WS(manager).add("3", 30); + + EXPECT_EQUAL("search::FilterAttributeIteratorStrict<search::attribute::SingleNumericSearchContext<long, search::attribute::NumericMatcher<long> > >", + ws.createSearch(adapter, "integer", true)->getClassName()); + EXPECT_EQUAL("search::FilterAttributeIteratorT<search::attribute::SingleNumericSearchContext<long, search::attribute::NumericMatcher<long> > >", + ws.createSearch(adapter, "integer", false)->getClassName()); + EXPECT_EQUAL("search::FilterAttributeIteratorStrict<search::attribute::SingleEnumSearchContext<char const*, search::attribute::StringSearchContext> >", + ws.createSearch(adapter, "string", true)->getClassName()); + EXPECT_EQUAL("search::FilterAttributeIteratorT<search::attribute::SingleEnumSearchContext<char const*, search::attribute::StringSearchContext> >", + ws.createSearch(adapter, "string", false)->getClassName()); + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "multi", true)); + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "multi", false)); EXPECT_EQUAL(expect, ws.search(adapter, "integer", true)); EXPECT_EQUAL(expect, ws.search(adapter, "integer", false)); diff --git a/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp b/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp index e9d0f8cb736..52329f31ba7 100644 --- a/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp +++ b/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp @@ -389,8 +389,8 @@ testSingleValue(Attribute & svsa, Config &cfg) TEST("testSingleValue") { EXPECT_EQUAL(24u, sizeof(SearchContext)); - EXPECT_EQUAL(40u, sizeof(StringSearchHelper)); - EXPECT_EQUAL(96u, sizeof(attribute::SingleStringEnumSearchContext)); + EXPECT_EQUAL(48u, sizeof(StringSearchHelper)); + EXPECT_EQUAL(104u, sizeof(attribute::SingleStringEnumSearchContext)); { Config cfg(BasicType::STRING, CollectionType::SINGLE); SingleValueStringAttribute svsa("svsa", cfg); diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp index ef0fd56840a..c617db871a7 100644 --- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp @@ -128,9 +128,9 @@ TEST("test And propagates updated histestimate") { const RememberExecuteInfo & child = dynamic_cast<const RememberExecuteInfo &>(bp.getChild(i)); EXPECT_EQUAL((i == 0), child.executeInfo.isStrict()); } - EXPECT_EQUAL(1.0, dynamic_cast<const RememberExecuteInfo &>(bp.getChild(0)).executeInfo.hitRate()); - EXPECT_EQUAL(1.0/250, dynamic_cast<const RememberExecuteInfo &>(bp.getChild(1)).executeInfo.hitRate()); - EXPECT_EQUAL(1.0/(250*25), dynamic_cast<const RememberExecuteInfo &>(bp.getChild(2)).executeInfo.hitRate()); + EXPECT_EQUAL(1.0f, dynamic_cast<const RememberExecuteInfo &>(bp.getChild(0)).executeInfo.hitRate()); + EXPECT_EQUAL(1.0f/250, dynamic_cast<const RememberExecuteInfo &>(bp.getChild(1)).executeInfo.hitRate()); + EXPECT_EQUAL(1.0f/(250*25), dynamic_cast<const RememberExecuteInfo &>(bp.getChild(2)).executeInfo.hitRate()); } TEST("test And Blueprint") { diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 1519bb14554..71ea2a67299 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -337,10 +337,7 @@ public: if (tfmda.size() == 1) { // search in exactly one field fef::TermFieldMatchData &tfmd = *tfmda[0]; - return search::common::create_location_iterator(tfmd, - _attribute.getNumDocs(), - strict, - _location); + return common::create_location_iterator(tfmd, _attribute.getNumDocs(), strict, _location); } else { LOG(debug, "wrong size tfmda: %zu (fallback to old location iterator)\n", tfmda.size()); } @@ -485,6 +482,9 @@ DirectWeightedSetBlueprint<SearchType>::createLeafSearch(const TermFieldMatchDat _attr.create(r.posting_idx, iterators); } bool field_is_filter = getState().fields()[0].isFilter(); + if (field_is_filter && tfmda[0]->isNotNeeded()) { + return attribute::DocumentWeightOrFilterSearch::create(std::move(iterators)); + } return SearchType::create(*tfmda[0], field_is_filter, _weights, std::move(iterators)); } diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp index 108128eeb39..94c560a0dae 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp @@ -30,7 +30,7 @@ protected: const attribute::IAttributeVector &attribute() const { return _attr; } public: - UseAttr(const attribute::IAttributeVector & attr) + explicit UseAttr(const attribute::IAttributeVector & attr) : _attr(attr) {} }; @@ -40,7 +40,7 @@ class UseStringEnum : public UseAttr { public: using TokenT = uint32_t; - UseStringEnum(const IAttributeVector & attr) + explicit UseStringEnum(const IAttributeVector & attr) : UseAttr(attr) {} auto mapToken(const ISearchContext &context) const { return attribute().findFoldedEnums(context.queryTerm()->getTerm()); @@ -56,7 +56,7 @@ class UseInteger : public UseAttr { public: using TokenT = uint64_t; - UseInteger(const IAttributeVector & attr) : UseAttr(attr) {} + explicit UseInteger(const IAttributeVector & attr) : UseAttr(attr) {} std::vector<int64_t> mapToken(const ISearchContext &context) const { std::vector<int64_t> result; Int64Range range(context.getAsIntegerTerm()); @@ -157,6 +157,10 @@ AttributeWeightedSetBlueprint::createLeafSearch(const fef::TermFieldMatchDataArr assert(tfmda.size() == 1); assert(getState().numFields() == 1); fef::TermFieldMatchData &tfmd = *tfmda[0]; + bool field_is_filter = getState().fields()[0].isFilter(); + if (field_is_filter && (_contexts.size() == 1)) { + return _contexts[0]->createIterator(&tfmd, strict); + } if (strict) { // use generic weighted set search fef::MatchDataLayout layout; auto handle = layout.allocTermField(tfmd.getFieldId()); @@ -167,7 +171,6 @@ AttributeWeightedSetBlueprint::createLeafSearch(const fef::TermFieldMatchDataArr // TODO: pass ownership with unique_ptr children[i] = _contexts[i]->createIterator(child_tfmd, true).release(); } - bool field_is_filter = getState().fields()[0].isFilter(); return queryeval::WeightedSetTermSearch::create(children, tfmd, field_is_filter, _weights, std::move(match_data)); } else { // use attribute filter optimization bool isString = (_attr.isStringType() && _attr.hasEnum()); @@ -182,18 +185,16 @@ AttributeWeightedSetBlueprint::createLeafSearch(const fef::TermFieldMatchDataArr } queryeval::SearchIterator::UP -AttributeWeightedSetBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) const +AttributeWeightedSetBlueprint::createFilterSearch(bool strict, FilterConstraint) const { - (void) constraint; std::vector<std::unique_ptr<queryeval::SearchIterator>> children; children.reserve(_contexts.size()); for (auto& context : _contexts) { - auto wrapper = std::make_unique<search::queryeval::FilterWrapper>(1); + auto wrapper = std::make_unique<queryeval::FilterWrapper>(1); wrapper->wrap(context->createIterator(wrapper->tfmda()[0], strict)); children.emplace_back(std::move(wrapper)); } - search::queryeval::UnpackInfo unpack_info; - return search::queryeval::OrSearch::create(std::move(children), strict, unpack_info); + return queryeval::OrSearch::create(std::move(children), strict, queryeval::UnpackInfo()); } void diff --git a/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.cpp b/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.cpp index 580c34bd5d0..b16fdc12a9a 100644 --- a/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.cpp +++ b/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.cpp @@ -1,17 +1,98 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "dfa_fuzzy_matcher.h" +#include <vespa/vespalib/text/utf8.h> +#include <vespa/vespalib/text/lowercase.h> using vespalib::fuzzy::LevenshteinDfa; +using vespalib::LowerCase; +using vespalib::Utf8Reader; +using vespalib::Utf8ReaderForZTS; namespace search::attribute { -DfaFuzzyMatcher::DfaFuzzyMatcher(std::string_view target, uint8_t max_edits, bool cased, LevenshteinDfa::DfaType dfa_type) - : _dfa(vespalib::fuzzy::LevenshteinDfa::build(target, max_edits, (cased ? LevenshteinDfa::Casing::Cased : LevenshteinDfa::Casing::Uncased), dfa_type)), - _successor() +namespace { + +std::vector<uint32_t> +extract_prefix(std::string_view target, uint32_t prefix_size, bool cased) +{ + std::vector<uint32_t> result; + result.reserve(prefix_size); + Utf8Reader reader(vespalib::stringref(target.data(), target.size())); + for (size_t pos = 0; pos < prefix_size && reader.hasMore(); ++pos) { + uint32_t code_point = reader.getChar(); + if (!cased) { + code_point = LowerCase::convert(code_point); + } + result.emplace_back(code_point); + } + return result; +} + +std::string_view +extract_suffix(std::string_view target, uint32_t prefix_size) { + Utf8Reader reader(vespalib::stringref(target.data(), target.size())); + for (size_t pos = 0; pos < prefix_size && reader.hasMore(); ++pos) { + (void) reader.getChar(); + } + std::string_view result = target; + result.remove_prefix(reader.getPos()); + return result; +} + +} + +DfaFuzzyMatcher::DfaFuzzyMatcher(std::string_view target, uint8_t max_edits, uint32_t prefix_size, bool cased, LevenshteinDfa::DfaType dfa_type) + : _dfa(vespalib::fuzzy::LevenshteinDfa::build(extract_suffix(target, prefix_size), max_edits, (cased ? LevenshteinDfa::Casing::Cased : LevenshteinDfa::Casing::Uncased), dfa_type)), + _successor(), + _prefix(extract_prefix(target, prefix_size, cased)), + _successor_suffix(), + _prefix_size(prefix_size), + _cased(cased) +{ + _successor = _prefix; } DfaFuzzyMatcher::~DfaFuzzyMatcher() = default; +const char* +DfaFuzzyMatcher::skip_prefix(const char* word) const +{ + Utf8ReaderForZTS reader(word); + size_t pos = 0; + for (; pos < _prefix.size() && reader.hasMore(); ++pos) { + (void) reader.getChar(); + } + assert(pos == _prefix.size()); + return reader.get_current_ptr(); +} + +bool +DfaFuzzyMatcher::is_match(const char* word) const +{ + if (_prefix_size > 0) { + Utf8ReaderForZTS reader(word); + size_t pos = 0; + for (; pos < _prefix.size() && reader.hasMore(); ++pos) { + uint32_t code_point = reader.getChar(); + if (!_cased) { + code_point = LowerCase::convert(code_point); + } + if (code_point != _prefix[pos]) { + break; + } + } + if (!reader.hasMore() && pos == _prefix.size() && pos < _prefix_size) { + return true; + } + if (pos != _prefix_size) { + return false; + } + word = reader.get_current_ptr(); + } + auto match = _dfa.match(word); + return match.matches(); +} + } diff --git a/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h b/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h index fcba13f85a4..7116b4d8662 100644 --- a/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h +++ b/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h @@ -5,6 +5,7 @@ #include "dfa_string_comparator.h" #include <vespa/vespalib/datastore/atomic_entry_ref.h> #include <vespa/vespalib/fuzzy/levenshtein_dfa.h> +#include <iostream> namespace search::attribute { @@ -17,22 +18,54 @@ namespace search::attribute { class DfaFuzzyMatcher { private: vespalib::fuzzy::LevenshteinDfa _dfa; - std::vector<uint32_t> _successor; + std::vector<uint32_t> _successor; + std::vector<uint32_t> _prefix; + std::vector<uint32_t> _successor_suffix; + uint32_t _prefix_size; + bool _cased; + const char* skip_prefix(const char* word) const; public: - DfaFuzzyMatcher(std::string_view target, uint8_t max_edits, bool cased, vespalib::fuzzy::LevenshteinDfa::DfaType dfa_type); + DfaFuzzyMatcher(std::string_view target, uint8_t max_edits, uint32_t prefix_size, bool cased, vespalib::fuzzy::LevenshteinDfa::DfaType dfa_type); ~DfaFuzzyMatcher(); + bool is_match(const char *word) const; + + /* + * If prefix size is nonzero then this variant of is_match() + * should only be called with words that starts with the extracted + * prefix of the target word. + * + * Caller must position iterator at right location using lower bound + * functionality in the dictionary. + */ template <typename DictionaryConstIteratorType> bool is_match(const char* word, DictionaryConstIteratorType& itr, const DfaStringComparator::DataStoreType& data_store) { - auto match = _dfa.match(word, _successor); - if (match.matches()) { - return true; + if (_prefix_size > 0) { + word = skip_prefix(word); + if (_prefix.size() < _prefix_size) { + if (*word == '\0') { + return true; + } + _successor.resize(_prefix.size()); + _successor.emplace_back(1); + } else { + auto match = _dfa.match(word, _successor_suffix); + if (match.matches()) { + return true; + } + _successor.resize(_prefix.size()); + _successor.insert(_successor.end(), _successor_suffix.begin(), _successor_suffix.end()); + } } else { - DfaStringComparator cmp(data_store, _successor); - itr.seek(vespalib::datastore::AtomicEntryRef(), cmp); - return false; + auto match = _dfa.match(word, _successor); + if (match.matches()) { + return true; + } } + DfaStringComparator cmp(data_store, _successor); + itr.seek(vespalib::datastore::AtomicEntryRef(), cmp); + return false; } }; diff --git a/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.cpp b/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.cpp index 3c0bae00047..c840c5cbc91 100644 --- a/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.cpp +++ b/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.cpp @@ -10,9 +10,10 @@ namespace search::attribute { class DocumentWeightOrFilterSearchImpl : public DocumentWeightOrFilterSearch { AttributeIteratorPack _children; + void seek_all(uint32_t docId); public: - DocumentWeightOrFilterSearchImpl(AttributeIteratorPack&& children); - ~DocumentWeightOrFilterSearchImpl(); + explicit DocumentWeightOrFilterSearchImpl(AttributeIteratorPack&& children); + ~DocumentWeightOrFilterSearchImpl() override; void doSeek(uint32_t docId) override; @@ -32,6 +33,7 @@ public: } std::unique_ptr<BitVector> get_hits(uint32_t begin_id) override { + seek_all(getDocId()); return _children.get_hits(begin_id, getEndId()); } @@ -47,17 +49,29 @@ DocumentWeightOrFilterSearchImpl::DocumentWeightOrFilterSearchImpl(AttributeIter DocumentWeightOrFilterSearchImpl::~DocumentWeightOrFilterSearchImpl() = default; void +DocumentWeightOrFilterSearchImpl::seek_all(uint32_t docId) { + for (uint16_t i = 0; i < _children.size(); ++i) { + uint32_t next = _children.get_docid(i); + if (next < docId) { + _children.seek(i, docId); + } + } +} + +void DocumentWeightOrFilterSearchImpl::doSeek(uint32_t docId) { - if (_children.get_docid(0) < docId) { - _children.seek(0, docId); - } - uint32_t min_doc_id = _children.get_docid(0); - for (uint16_t i = 1; i < _children.size(); ++i) { - if (_children.get_docid(i) < docId) { - _children.seek(i, docId); + uint32_t min_doc_id = endDocId; + for (uint16_t i = 0; i < _children.size(); ++i) { + uint32_t next = _children.get_docid(i); + if (next < docId) { + next = _children.seek(i, docId); + } + if (next == docId) { + setDocId(next); + return; } - min_doc_id = std::min(min_doc_id, _children.get_docid(i)); + min_doc_id = std::min(min_doc_id, next); } setDocId(min_doc_id); } @@ -67,12 +81,14 @@ DocumentWeightOrFilterSearchImpl::doUnpack(uint32_t) { } -std::unique_ptr<search::queryeval::SearchIterator> +std::unique_ptr<queryeval::SearchIterator> DocumentWeightOrFilterSearch::create(std::vector<DocumentWeightIterator>&& children) { if (children.empty()) { return std::make_unique<queryeval::EmptySearch>(); } else { + std::sort(children.begin(), children.end(), + [](const auto & a, const auto & b) { return a.size() > b.size(); }); return std::make_unique<DocumentWeightOrFilterSearchImpl>(AttributeIteratorPack(std::move(children))); } } diff --git a/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.h b/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.h index 62be883ab52..c601856573f 100644 --- a/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.h +++ b/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.h @@ -9,15 +9,15 @@ namespace search::attribute { * Filter iterator on top of document weight iterators with OR semantics used during * calculation of global filter for weighted set terms, wand terms and dot product terms. */ -class DocumentWeightOrFilterSearch : public search::queryeval::SearchIterator +class DocumentWeightOrFilterSearch : public queryeval::SearchIterator { protected: DocumentWeightOrFilterSearch() - : search::queryeval::SearchIterator() + : queryeval::SearchIterator() { } public: - static std::unique_ptr<search::queryeval::SearchIterator> create(std::vector<DocumentWeightIterator>&& children); + static std::unique_ptr<queryeval::SearchIterator> create(std::vector<DocumentWeightIterator>&& children); }; } diff --git a/searchlib/src/vespa/searchlib/attribute/iterator_pack.cpp b/searchlib/src/vespa/searchlib/attribute/iterator_pack.cpp index 147f56d6d47..ab06fc270bd 100644 --- a/searchlib/src/vespa/searchlib/attribute/iterator_pack.cpp +++ b/searchlib/src/vespa/searchlib/attribute/iterator_pack.cpp @@ -17,9 +17,9 @@ AttributeIteratorPack::or_hits_into(BitVector &result, uint32_t begin_id) { for (size_t i = 0; i < size(); ++i) { uint32_t docId = get_docid(i); if (begin_id > docId) { - seek(i, begin_id); + docId = seek(i, begin_id); } - for (docId = get_docid(i); docId < result.size(); docId = next(i)) { + for (uint32_t limit = result.size(); docId < limit; docId = next(i)) { result.setBit(docId); } } diff --git a/searchlib/src/vespa/searchlib/attribute/iterator_pack.h b/searchlib/src/vespa/searchlib/attribute/iterator_pack.h index e042aab5eae..1753a3d0c2d 100644 --- a/searchlib/src/vespa/searchlib/attribute/iterator_pack.h +++ b/searchlib/src/vespa/searchlib/attribute/iterator_pack.h @@ -41,7 +41,7 @@ public: std::unique_ptr<BitVector> get_hits(uint32_t begin_id, uint32_t end_id); void or_hits_into(BitVector &result, uint32_t begin_id); - size_t size() const { return _children.size(); } + size_t size() const noexcept { return _children.size(); } void initRange(uint32_t begin, uint32_t end) { (void) end; for (auto &child: _children) { diff --git a/searchlib/src/vespa/searchlib/attribute/postinglistsearchcontext.h b/searchlib/src/vespa/searchlib/attribute/postinglistsearchcontext.h index 05ccedb39ec..6f148d1d5ba 100644 --- a/searchlib/src/vespa/searchlib/attribute/postinglistsearchcontext.h +++ b/searchlib/src/vespa/searchlib/attribute/postinglistsearchcontext.h @@ -320,11 +320,7 @@ StringPostingSearchContext<BaseSC, AttrT, DataT>::use_dictionary_entry(PostingLi ++it; return false; } else if (this->isFuzzy()) { - if (this->getFuzzyMatcher().isMatch(_enumStore.get_value(it.getKey().load_acquire()))) { - return true; - } - ++it; - return false; + return this->is_fuzzy_match(_enumStore.get_value(it.getKey().load_acquire()), it, _enumStore.get_data_store()); } return true; } diff --git a/searchlib/src/vespa/searchlib/attribute/string_matcher.h b/searchlib/src/vespa/searchlib/attribute/string_matcher.h index 05089e1251a..09ba813cefe 100644 --- a/searchlib/src/vespa/searchlib/attribute/string_matcher.h +++ b/searchlib/src/vespa/searchlib/attribute/string_matcher.h @@ -32,6 +32,11 @@ protected: const vespalib::Regex& getRegex() const { return _helper.getRegex(); } const vespalib::FuzzyMatcher& getFuzzyMatcher() const { return _helper.getFuzzyMatcher(); } const QueryTermUCS4* get_query_term_ptr() const noexcept { return _query_term.get(); } + + template <typename DictionaryConstIteratorType> + bool is_fuzzy_match(const char* word, DictionaryConstIteratorType& itr, const DfaStringComparator::DataStoreType& data_store) const { + return _helper.is_fuzzy_match(word, itr, data_store); + } }; } diff --git a/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp b/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp index 1efe39667b8..aec317926f1 100644 --- a/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp +++ b/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp @@ -1,6 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "string_search_helper.h" +#include "dfa_fuzzy_matcher.h" +#include "i_enum_store_dictionary.h" #include <vespa/searchlib/query/query_term_ucs4.h> #include <vespa/vespalib/text/lowercase.h> #include <vespa/vespalib/text/utf8.h> @@ -12,6 +14,7 @@ namespace search::attribute { StringSearchHelper::StringSearchHelper(QueryTermUCS4 & term, bool cased, vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm) : _regex(), _fuzzyMatcher(), + _dfa_fuzzy_matcher(), _term(), _termLen(), _isPrefix(term.isPrefix()), @@ -24,12 +27,20 @@ StringSearchHelper::StringSearchHelper(QueryTermUCS4 & term, bool cased, vespali ? vespalib::Regex::from_pattern(term.getTerm(), vespalib::Regex::Options::None) : vespalib::Regex::from_pattern(term.getTerm(), vespalib::Regex::Options::IgnoreCase); } else if (isFuzzy()) { - (void) fuzzy_matching_algorithm; - // TODO: Select implementation based on algorithm. _fuzzyMatcher = std::make_unique<vespalib::FuzzyMatcher>(term.getTerm(), term.getFuzzyMaxEditDistance(), term.getFuzzyPrefixLength(), isCased()); + using FMA = vespalib::FuzzyMatchingAlgorithm; + using LDT = vespalib::fuzzy::LevenshteinDfa::DfaType; + if ((fuzzy_matching_algorithm != FMA::BruteForce) && + (term.getFuzzyMaxEditDistance() <= 2)) { + _dfa_fuzzy_matcher = std::make_unique<DfaFuzzyMatcher>(term.getTerm(), + term.getFuzzyMaxEditDistance(), + term.getFuzzyPrefixLength(), + isCased(), + (fuzzy_matching_algorithm == FMA::DfaImplicit) ? LDT::Implicit : LDT::Explicit); + } } else if (isCased()) { _term = term.getTerm(); _termLen = strlen(_term); @@ -48,7 +59,7 @@ StringSearchHelper::isMatch(const char *src) const noexcept { return getRegex().valid() && getRegex().partial_match(std::string_view(src)); } if (__builtin_expect(isFuzzy(), false)) { - return getFuzzyMatcher().isMatch(src); + return _dfa_fuzzy_matcher ? _dfa_fuzzy_matcher->is_match(src) : getFuzzyMatcher().isMatch(src); } if (__builtin_expect(isCased(), false)) { int res = strncmp(_term, src, _termLen); @@ -67,4 +78,27 @@ StringSearchHelper::isMatch(const char *src) const noexcept { return (_ucs4[j] == 0 && (val == 0 || isPrefix())); } +template <typename DictionaryConstIteratorType> +bool +StringSearchHelper::is_fuzzy_match(const char* word, DictionaryConstIteratorType& itr, const DfaStringComparator::DataStoreType& data_store) const +{ + if (_dfa_fuzzy_matcher) { + return _dfa_fuzzy_matcher->is_match(word, itr, data_store); + } else { + if (_fuzzyMatcher->isMatch(word)) { + return true; + } + ++itr; + return false; + } +} + +template +bool +StringSearchHelper::is_fuzzy_match(const char*, EnumPostingTree::ConstIterator&, const DfaStringComparator::DataStoreType&) const; + +template +bool +StringSearchHelper::is_fuzzy_match(const char*, EnumTree::ConstIterator&, const DfaStringComparator::DataStoreType&) const; + } diff --git a/searchlib/src/vespa/searchlib/attribute/string_search_helper.h b/searchlib/src/vespa/searchlib/attribute/string_search_helper.h index 0e7a116a874..e59291e24a7 100644 --- a/searchlib/src/vespa/searchlib/attribute/string_search_helper.h +++ b/searchlib/src/vespa/searchlib/attribute/string_search_helper.h @@ -2,6 +2,7 @@ #pragma once +#include "dfa_string_comparator.h" #include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> #include <vespa/vespalib/regex/regex.h> @@ -10,6 +11,8 @@ namespace search { class QueryTermUCS4; } namespace search::attribute { +class DfaFuzzyMatcher; + /** * Helper class for search context when scanning string fields * It handles different search settings like prefix, regex and cased/uncased. @@ -29,11 +32,16 @@ public: bool isCased() const noexcept { return _isCased; } bool isFuzzy() const noexcept { return _isFuzzy; } const vespalib::Regex & getRegex() const noexcept { return _regex; } - const FuzzyMatcher & getFuzzyMatcher() const noexcept { return *_fuzzyMatcher; } + const FuzzyMatcher& getFuzzyMatcher() const noexcept { return *_fuzzyMatcher; } + + template <typename DictionaryConstIteratorType> + bool is_fuzzy_match(const char* word, DictionaryConstIteratorType& itr, const DfaStringComparator::DataStoreType& data_store) const; + private: using ucs4_t = uint32_t; vespalib::Regex _regex; std::unique_ptr<FuzzyMatcher> _fuzzyMatcher; + std::unique_ptr<DfaFuzzyMatcher> _dfa_fuzzy_matcher; std::unique_ptr<ucs4_t[]> _ucs4; const char * _term; uint32_t _termLen; // measured in bytes diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp index 3f6085ef7ff..94d1a4917fd 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp @@ -621,7 +621,7 @@ IntermediateBlueprint::fetchPostings(const ExecuteInfo &execInfo) double nextHitRate = execInfo.hitRate(); for (size_t i = 0; i < _children.size(); ++i) { Blueprint & child = *_children[i]; - child.fetchPostings(ExecuteInfo::create(execInfo.isStrict() && inheritStrict(i), nextHitRate)); + child.fetchPostings(ExecuteInfo::create(execInfo.isStrict() && inheritStrict(i), nextHitRate, execInfo.getDoom())); nextHitRate = computeNextHitRate(child, nextHitRate); } } diff --git a/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.cpp index 795f5f1424a..4322cafb5c8 100644 --- a/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.cpp @@ -66,7 +66,7 @@ DotProductBlueprint::createFilterSearch(bool strict, FilterConstraint constraint void DotProductBlueprint::fetchPostings(const ExecuteInfo &execInfo) { - ExecuteInfo childInfo = ExecuteInfo::create(true, execInfo.hitRate()); + ExecuteInfo childInfo = ExecuteInfo::create(true, execInfo); for (size_t i = 0; i < _terms.size(); ++i) { _terms[i]->fetchPostings(childInfo); } diff --git a/searchlib/src/vespa/searchlib/queryeval/executeinfo.cpp b/searchlib/src/vespa/searchlib/queryeval/executeinfo.cpp index 604e20d2262..e5d20f047f5 100644 --- a/searchlib/src/vespa/searchlib/queryeval/executeinfo.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/executeinfo.cpp @@ -4,17 +4,7 @@ namespace search::queryeval { -const ExecuteInfo ExecuteInfo::TRUE(true, 1.0); -const ExecuteInfo ExecuteInfo::FALSE(false, 1.0); - -ExecuteInfo -ExecuteInfo::create(bool strict) { - return create(strict, 1.0); -} - -ExecuteInfo -ExecuteInfo::create(bool strict, double hitRate) { - return ExecuteInfo(strict, hitRate); -} +const ExecuteInfo ExecuteInfo::TRUE(true, 1.0, nullptr); +const ExecuteInfo ExecuteInfo::FALSE(false, 1.0, nullptr); } diff --git a/searchlib/src/vespa/searchlib/queryeval/executeinfo.h b/searchlib/src/vespa/searchlib/queryeval/executeinfo.h index e161b2bdab7..2dd34284bef 100644 --- a/searchlib/src/vespa/searchlib/queryeval/executeinfo.h +++ b/searchlib/src/vespa/searchlib/queryeval/executeinfo.h @@ -1,8 +1,9 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -// Copyright 2019 Oath inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once +#include <vespa/vespalib/util/doom.h> + namespace search::queryeval { /** @@ -11,20 +12,37 @@ namespace search::queryeval { */ class ExecuteInfo { public: - ExecuteInfo() : ExecuteInfo(false, 1.0) { } - bool isStrict() const { return _strict; } - double hitRate() const { return _hitRate; } + ExecuteInfo() noexcept : ExecuteInfo(false, 1.0F, nullptr) { } + bool isStrict() const noexcept { return _strict; } + float hitRate() const noexcept { return _hitRate; } + bool soft_doom() const noexcept { return _doom && _doom->soft_doom(); } + const vespalib::Doom * getDoom() const { return _doom; } static const ExecuteInfo TRUE; static const ExecuteInfo FALSE; - static ExecuteInfo create(bool strict); - static ExecuteInfo create(bool strict, double HitRate); + static ExecuteInfo create(bool strict, const ExecuteInfo & org) noexcept { + return {strict, org._hitRate, org.getDoom()}; + } + static ExecuteInfo create(bool strict, const vespalib::Doom * doom) noexcept { + return create(strict, 1.0F, doom); + } + static ExecuteInfo create(bool strict, float hitRate, const vespalib::Doom * doom) noexcept { + return {strict, hitRate, doom}; + } + static ExecuteInfo create(bool strict) noexcept { + return create(strict, 1.0F); + } + static ExecuteInfo create(bool strict, float hitRate) noexcept { + return create(strict, hitRate, nullptr); + } private: - ExecuteInfo(bool strict, double hitRate_in) - : _hitRate(hitRate_in), + ExecuteInfo(bool strict, float hitRate_in, const vespalib::Doom * doom) noexcept + : _doom(doom), + _hitRate(hitRate_in), _strict(strict) { } - double _hitRate; - bool _strict; + const vespalib::Doom * _doom; + float _hitRate; + bool _strict; }; } diff --git a/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp index 9c3910b20f9..16461487525 100644 --- a/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/same_element_blueprint.cpp @@ -57,7 +57,7 @@ void SameElementBlueprint::fetchPostings(const ExecuteInfo &execInfo) { for (size_t i = 0; i < _terms.size(); ++i) { - _terms[i]->fetchPostings(ExecuteInfo::create(execInfo.isStrict() && (i == 0), execInfo.hitRate())); + _terms[i]->fetchPostings(ExecuteInfo::create(execInfo.isStrict() && (i == 0), execInfo)); } } diff --git a/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.cpp index 48a09f099a6..eb6241a99d5 100644 --- a/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.cpp @@ -4,7 +4,6 @@ #include "wand_parts.h" #include "parallel_weak_and_search.h" #include <vespa/searchlib/queryeval/field_spec.hpp> -#include <vespa/searchlib/queryeval/emptysearch.h> #include <vespa/searchlib/queryeval/searchiterator.h> #include <vespa/searchlib/fef/termfieldmatchdata.h> #include <vespa/vespalib/objects/visit.hpp> @@ -77,10 +76,10 @@ ParallelWeakAndBlueprint::createLeafSearch(const search::fef::TermFieldMatchData const State &childState = _terms[i]->getState(); assert(childState.numFields() == 1); // TODO: pass ownership with unique_ptr - terms.push_back(wand::Term(_terms[i]->createSearch(*childrenMatchData, true).release(), - _weights[i], - childState.estimate().estHits, - childState.field(0).resolve(*childrenMatchData))); + terms.emplace_back(_terms[i]->createSearch(*childrenMatchData, true).release(), + _weights[i], + childState.estimate().estHits, + childState.field(0).resolve(*childrenMatchData)); } return SearchIterator::UP (ParallelWeakAndSearch::create(terms, @@ -101,9 +100,9 @@ ParallelWeakAndBlueprint::createFilterSearch(bool strict, FilterConstraint const void ParallelWeakAndBlueprint::fetchPostings(const ExecuteInfo & execInfo) { - ExecuteInfo childInfo = ExecuteInfo::create(true, execInfo.hitRate()); - for (size_t i = 0; i < _terms.size(); ++i) { - _terms[i]->fetchPostings(childInfo); + ExecuteInfo childInfo = ExecuteInfo::create(true, execInfo); + for (const auto & _term : _terms) { + _term->fetchPostings(childInfo); } } diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.cpp index 4e06f170253..97f6bc2e6f8 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.cpp @@ -33,10 +33,8 @@ WeightedSetTermMatchingElementsSearch::WeightedSetTermMatchingElementsSearch(con _search() { _tfmda.add(&_tfmd); - auto generic_search = bp.createLeafSearch(_tfmda, false); - auto weighted_set_term_search = dynamic_cast<WeightedSetTermSearch *>(generic_search.get()); - generic_search.release(); - _search.reset(weighted_set_term_search); + _search.reset(static_cast<WeightedSetTermSearch *>(bp.createLeafSearch(_tfmda, false).release())); + } WeightedSetTermMatchingElementsSearch::~WeightedSetTermMatchingElementsSearch() = default; @@ -120,16 +118,16 @@ WeightedSetTermBlueprint::create_matching_elements_search(const MatchingElements if (fields.has_field(_children_field.getName())) { return std::make_unique<WeightedSetTermMatchingElementsSearch>(*this, _children_field.getName(), _terms); } else { - return std::unique_ptr<MatchingElementsSearch>(); + return {}; } } void WeightedSetTermBlueprint::fetchPostings(const ExecuteInfo &execInfo) { - ExecuteInfo childInfo = ExecuteInfo::create(true, execInfo.hitRate()); - for (size_t i = 0; i < _terms.size(); ++i) { - _terms[i]->fetchPostings(childInfo); + ExecuteInfo childInfo = ExecuteInfo::create(true, execInfo); + for (const auto & _term : _terms) { + _term->fetchPostings(childInfo); } } diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h index 0e3c82444d7..9c8d6d88329 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h @@ -18,7 +18,7 @@ class WeightedSetTermBlueprint : public ComplexLeafBlueprint std::vector<Blueprint::UP> _terms; public: - WeightedSetTermBlueprint(const FieldSpec &field); + explicit WeightedSetTermBlueprint(const FieldSpec &field); WeightedSetTermBlueprint(const WeightedSetTermBlueprint &) = delete; WeightedSetTermBlueprint &operator=(const WeightedSetTermBlueprint &) = delete; ~WeightedSetTermBlueprint() override; diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp index ee3978705cf..8478a0d3c35 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp @@ -21,7 +21,7 @@ private: struct CmpDocId { const uint32_t *termPos; - CmpDocId(const uint32_t *tp) : termPos(tp) {} + explicit CmpDocId(const uint32_t *tp) : termPos(tp) {} bool operator()(const ref_t &a, const ref_t &b) const { return (termPos[a] < termPos[b]); } @@ -29,7 +29,7 @@ private: struct CmpWeight { const int32_t *weight; - CmpWeight(const int32_t *w) : weight(w) {} + explicit CmpWeight(const int32_t *w) : weight(w) {} bool operator()(const ref_t &a, const ref_t &b) const { return (weight[a] > weight[b]); } @@ -61,7 +61,7 @@ private: } public: - WeightedSetTermSearchImpl(search::fef::TermFieldMatchData &tmd, + WeightedSetTermSearchImpl(fef::TermFieldMatchData &tmd, bool field_is_filter, const std::vector<int32_t> &weights, IteratorPack &&iteratorPack) @@ -180,7 +180,7 @@ WeightedSetTermSearch::create(const std::vector<SearchIterator *> &children, //----------------------------------------------------------------------------- SearchIterator::UP -WeightedSetTermSearch::create(search::fef::TermFieldMatchData &tmd, +WeightedSetTermSearch::create(fef::TermFieldMatchData &tmd, bool field_is_filter, const std::vector<int32_t> &weights, std::vector<DocumentWeightIterator> &&iterators) diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h index e3e12c27f28..b30d3bc3301 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.h @@ -10,12 +10,9 @@ #include <memory> #include <vector> -namespace search { -namespace fef { -class TermFieldMatchData; -} // namespace fef +namespace search::fef { class TermFieldMatchData; } -namespace queryeval { +namespace search::queryeval { class Blueprint; @@ -26,7 +23,7 @@ class Blueprint; class WeightedSetTermSearch : public SearchIterator { protected: - WeightedSetTermSearch() {} + WeightedSetTermSearch() = default; public: // TODO: pass ownership with unique_ptr @@ -47,6 +44,4 @@ public: virtual void find_matching_elements(uint32_t docid, const std::vector<std::unique_ptr<Blueprint>> &child_blueprints, std::vector<uint32_t> &dst) = 0; }; -} // namespace search::queryeval -} // namespace search - +} diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index c1d7e17b457..56dcd9abdf6 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -96,6 +96,7 @@ vespa_define_module( src/tests/fileheader src/tests/floatingpointtype src/tests/fuzzy + src/tests/fuzzy/table_dfa src/tests/gencnt src/tests/growablebytebuffer src/tests/guard diff --git a/vespalib/src/tests/fuzzy/levenshtein_dfa_test.cpp b/vespalib/src/tests/fuzzy/levenshtein_dfa_test.cpp index c235cb99509..02df6edc370 100644 --- a/vespalib/src/tests/fuzzy/levenshtein_dfa_test.cpp +++ b/vespalib/src/tests/fuzzy/levenshtein_dfa_test.cpp @@ -82,7 +82,8 @@ INSTANTIATE_TEST_SUITE_P(AllCasingAndDfaTypes, Combine(Values(LevenshteinDfa::Casing::Uncased, LevenshteinDfa::Casing::Cased), Values(LevenshteinDfa::DfaType::Explicit, - LevenshteinDfa::DfaType::Implicit)), + LevenshteinDfa::DfaType::Implicit, + LevenshteinDfa::DfaType::Table)), LevenshteinDfaTest::stringify_params); // Same as existing non-DFA Levenshtein tests, but with some added instantiations @@ -233,7 +234,8 @@ struct LevenshteinDfaCasingTest : TestWithParam<LevenshteinDfa::DfaType> { INSTANTIATE_TEST_SUITE_P(AllDfaTypes, LevenshteinDfaCasingTest, Values(LevenshteinDfa::DfaType::Explicit, - LevenshteinDfa::DfaType::Implicit), + LevenshteinDfa::DfaType::Implicit, + LevenshteinDfa::DfaType::Table), PrintToStringParamName()); TEST_P(LevenshteinDfaCasingTest, uncased_edge_cases_have_correct_edit_distance) { @@ -315,7 +317,8 @@ INSTANTIATE_TEST_SUITE_P(SupportedMaxEdits, Combine(Values(LevenshteinDfa::Casing::Uncased, LevenshteinDfa::Casing::Cased), Values(LevenshteinDfa::DfaType::Explicit, - LevenshteinDfa::DfaType::Implicit), + LevenshteinDfa::DfaType::Implicit, + LevenshteinDfa::DfaType::Table), Values(1, 2)), LevenshteinDfaSuccessorTest::stringify_params); diff --git a/vespalib/src/tests/fuzzy/table_dfa/CMakeLists.txt b/vespalib/src/tests/fuzzy/table_dfa/CMakeLists.txt new file mode 100644 index 00000000000..1017ac99564 --- /dev/null +++ b/vespalib/src/tests/fuzzy/table_dfa/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_fuzzy_table_dfa_test_app TEST + SOURCES + table_dfa_test.cpp + DEPENDS + vespalib + GTest::GTest + ) +vespa_add_test(NAME vespalib_fuzzy_table_dfa_test_app COMMAND vespalib_fuzzy_table_dfa_test_app) diff --git a/vespalib/src/tests/fuzzy/table_dfa/table_dfa_test.cpp b/vespalib/src/tests/fuzzy/table_dfa/table_dfa_test.cpp new file mode 100644 index 00000000000..7782a39c3c7 --- /dev/null +++ b/vespalib/src/tests/fuzzy/table_dfa/table_dfa_test.cpp @@ -0,0 +1,298 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/fuzzy/table_dfa.hpp> +#include <vespa/vespalib/gtest/gtest.h> +#include <set> + +using namespace ::testing; +using namespace vespalib::fuzzy; + +// test/experiment with low-level concepts underlying the construction +// of the tables used in the table-driven dfa implementation. + +TEST(TableDfaTest, position) { + Position pos1 = Position::start(); + EXPECT_EQ(pos1.index, 0); + EXPECT_EQ(pos1.edits, 0); + Position pos2(2, 3); + EXPECT_EQ(pos2.index, 2); + EXPECT_EQ(pos2.edits, 3); +} + +TEST(TableDfaTest, position_equality) { + Position pos1(0, 0); + Position pos2(0, 1); + Position pos3(1, 0); + EXPECT_TRUE(pos1 == pos1); + EXPECT_FALSE(pos1 == pos2); + EXPECT_FALSE(pos1 == pos2); +} + +TEST(TableDfaTest, position_sort_order) { + std::vector<Position> list; + list.emplace_back(0,1); + list.emplace_back(0,0); + list.emplace_back(1,0); + list.emplace_back(1,1); + std::sort(list.begin(), list.end()); + EXPECT_EQ(list[0].index, 0); + EXPECT_EQ(list[0].edits, 0); + EXPECT_EQ(list[1].index, 1); + EXPECT_EQ(list[1].edits, 0); + EXPECT_EQ(list[2].index, 0); + EXPECT_EQ(list[2].edits, 1); + EXPECT_EQ(list[3].index, 1); + EXPECT_EQ(list[3].edits, 1); +} + +TEST(TableDfaTest, position_subsumption) { + Position pos1(0, 0); + Position pos2(0, 1); + Position pos3(0, 2); + + Position pos4(1, 0); + Position pos5(1, 1); + Position pos6(1, 2); + + Position pos7(2, 0); + Position pos8(2, 1); + Position pos9(2, 2); + + EXPECT_FALSE(pos1.subsumes(pos1)); + EXPECT_TRUE(pos1.subsumes(pos2)); + EXPECT_TRUE(pos1.subsumes(pos3)); + EXPECT_FALSE(pos1.subsumes(pos4)); + EXPECT_TRUE(pos1.subsumes(pos5)); + EXPECT_TRUE(pos1.subsumes(pos6)); + EXPECT_FALSE(pos1.subsumes(pos7)); + EXPECT_FALSE(pos1.subsumes(pos8)); + EXPECT_TRUE(pos1.subsumes(pos9)); + + EXPECT_FALSE(pos5.subsumes(pos1)); + EXPECT_FALSE(pos5.subsumes(pos2)); + EXPECT_TRUE(pos5.subsumes(pos3)); + EXPECT_FALSE(pos5.subsumes(pos4)); + EXPECT_FALSE(pos5.subsumes(pos5)); + EXPECT_TRUE(pos5.subsumes(pos6)); + EXPECT_FALSE(pos5.subsumes(pos7)); + EXPECT_FALSE(pos5.subsumes(pos8)); + EXPECT_TRUE(pos5.subsumes(pos9)); +} + +TEST(TableDfaTest, position_materialization) { + EXPECT_EQ(Position(1,1).materialize(0).index, 0); + EXPECT_EQ(Position(1,1).materialize(1).index, 1); + EXPECT_EQ(Position(1,1).materialize(2).index, 2); + EXPECT_EQ(Position(1,1).materialize(0).edits, 2); + EXPECT_EQ(Position(1,1).materialize(1).edits, 1); + EXPECT_EQ(Position(1,1).materialize(2).edits, 2); +} + +TEST(TableDfaTest, position_to_string) { + Position pos1(0, 0); + Position pos2(1, 2); + Position pos3(2, 3); + EXPECT_EQ(pos1.to_string(), fmt("0#0")); + EXPECT_EQ(pos2.to_string(), fmt("1#2")); + EXPECT_EQ(pos3.to_string(), fmt("2#3")); +} + +TEST(TableDfaTest, state_creation_reorder) { + EXPECT_EQ(State::create<5>({{0,1},{2,0}}).to_string(), fmt("{2#0,0#1}")); + EXPECT_EQ(State::create<5>({{2,0},{0,0}}).to_string(), fmt("{0#0,2#0}")); +} + +TEST(TableDfaTest, state_creation_duplicate_removal) { + EXPECT_EQ(State::create<5>({{0,0},{0,0},{2,1},{2,1}}).to_string(), fmt("{0#0,2#1}")); +} + +TEST(TableDfaTest, state_creation_edit_cutoff) { + EXPECT_EQ(State::create<2>({{0,0},{5,2},{10,3}}).to_string(), fmt("{0#0,5#2}")); +} + +TEST(TableDfaTest, state_creation_subsumption_collapsing) { + EXPECT_EQ(State::create<2>({{0,0},{1,1}}).to_string(), fmt("{0#0}")); + EXPECT_EQ(State::create<2>({{0,1},{1,0}}).to_string(), fmt("{1#0}")); + EXPECT_EQ(State::create<2>({{0,0},{2,2}}).to_string(), fmt("{0#0}")); + EXPECT_EQ(State::create<2>({{0,2},{2,0}}).to_string(), fmt("{2#0}")); +} + +TEST(TableDfaTest, state_normalization) { + auto state1 = State::create<2>({{2,1},{3,1}}); + auto state2 = State::create<2>({{5,0},{3,1}}); + EXPECT_EQ(state1.to_string(), fmt("{2#1,3#1}")); + EXPECT_EQ(state2.to_string(), fmt("{5#0,3#1}")); + EXPECT_EQ(state1.normalize(), 2); + EXPECT_EQ(state2.normalize(), 3); + EXPECT_EQ(state1.to_string(), fmt("{0#1,1#1}")); + EXPECT_EQ(state2.to_string(), fmt("{2#0,0#1}")); +} + +TEST(TableDfaTest, state_repo) { + StateRepo repo; + EXPECT_EQ(repo.state_to_idx(State::failed()), 0); + EXPECT_EQ(repo.state_to_idx(State::start()), 1); + EXPECT_EQ(repo.state_to_idx(State::create<2>({{0,0},{1,0}})), 2); + EXPECT_EQ(repo.state_to_idx(State::create<2>({{0,0},{2,1}})), 3); + EXPECT_EQ(repo.state_to_idx(State::create<2>({{0,0},{1,0}})), 2); + EXPECT_EQ(repo.state_to_idx(State::create<2>({{0,0},{2,1}})), 3); + EXPECT_EQ(repo.size(), 4); + EXPECT_EQ(repo.idx_to_state(0).to_string(), fmt("{}")); + EXPECT_EQ(repo.idx_to_state(1).to_string(), fmt("{0#0}")); + EXPECT_EQ(repo.idx_to_state(2).to_string(), fmt("{0#0,1#0}")); + EXPECT_EQ(repo.idx_to_state(3).to_string(), fmt("{0#0,2#1}")); +} + +TEST(TableDfaTest, expand_bits) { + auto yes = expand_bits<2>(0x1f); + auto no = expand_bits<2>(0x00); + auto odd = expand_bits<2>(0x0a); + auto even = expand_bits<2>(0x15); + ASSERT_EQ(yes.size(), 5); + ASSERT_EQ(no.size(), 5); + ASSERT_EQ(odd.size(), 5); + ASSERT_EQ(even.size(), 5); + for (size_t i = 0; i < 5; ++i) { + EXPECT_TRUE(yes[i]); + EXPECT_FALSE(no[i]); + EXPECT_EQ(odd[i], bool(i % 2 == 1)); + EXPECT_EQ(even[i], bool(i % 2 == 0)); + } +} + +TEST(TableDfaTest, format_bits) { + EXPECT_EQ(format_vector(expand_bits<1>(0)), fmt("[0,0,0]")); + EXPECT_EQ(format_vector(expand_bits<1>(7)), fmt("[1,1,1]")); + EXPECT_EQ(format_vector(expand_bits<1>(5)), fmt("[1,0,1]")); + EXPECT_EQ(format_vector(expand_bits<1>(2)), fmt("[0,1,0]")); + EXPECT_EQ(format_vector(expand_bits<2>(31)), fmt("[1,1,1,1,1]")); + EXPECT_EQ(format_vector(expand_bits<2>(21)), fmt("[1,0,1,0,1]")); + EXPECT_EQ(format_vector(expand_bits<2>(31), true), fmt("11111")); + EXPECT_EQ(format_vector(expand_bits<2>(21), true), fmt("10101")); +} + +template <uint8_t N> +void list_states(bool count_only = false) { + auto repo = make_state_repo<N>(); + EXPECT_EQ(num_states<N>(), repo.size()); + fprintf(stderr, "max_edits: %u, number of states: %zu\n", N, repo.size()); + if (!count_only) { + for (uint32_t i = 0; i < repo.size(); ++i) { + fprintf(stderr, " state %u: %s\n", i, repo.idx_to_state(i).to_string().c_str()); + } + } +} + +TEST(TableDfaTest, list_states_for_max_edits_1) { list_states<1>(); } +TEST(TableDfaTest, list_states_for_max_edits_2) { list_states<2>(); } +TEST(TableDfaTest, count_states_for_max_edits_3) { list_states<3>(true); } + +template <uint8_t N> +void list_edits() { + auto repo = make_state_repo<N>(); + fprintf(stderr, + "per state, listing the minimal number of edits needed\n" + "to reach offsets at and beyond its minimal boundary\n"); + for (uint32_t i = 0; i < repo.size(); ++i) { + const State &state = repo.idx_to_state(i); + fprintf(stderr, "%-23s : %s\n", state.to_string().c_str(), + format_vector(state.make_edit_vector<N>()).c_str()); + } +} + +TEST(TableDfaTest, list_edits_at_input_end_for_max_edits_1) { list_edits<1>(); } +TEST(TableDfaTest, list_edits_at_input_end_for_max_edits_2) { list_edits<2>(); } + +template <uint8_t N> +void list_transitions() { + auto repo = make_state_repo<N>(); + for (uint32_t idx = 0; idx < repo.size(); ++idx) { + const State &state = repo.idx_to_state(idx); + for (uint32_t i = 0; i < num_transitions<N>(); ++i) { + auto bits = expand_bits<N>(i); + State new_state = state.next<N>(bits); + uint32_t step = new_state.normalize(); + uint32_t new_idx = repo.state_to_idx(new_state); + ASSERT_LT(new_idx, repo.size()); + fprintf(stderr, "%u:%s,i --%s--> %u:%s,%s\n", idx, state.to_string().c_str(), + format_vector(bits).c_str(), new_idx, new_state.to_string().c_str(), + (step == 0) ? "i" : fmt("i+%u", step).c_str()); + } + } +} + +TEST(TableDfaTest, list_transitions_for_max_edits_1) { list_transitions<1>(); } + +// Simulate all possible ways we can approach the end of the word we +// are matching. Verify that no transition taken can produce a state +// with a minimal boundary that exceeds the boundary of the word +// itself. Verifying this will enable us to not care about word size +// while simulating the dfa. +template <uint8_t N> +void verify_word_end_boundary() { + auto repo = make_state_repo<N>(); + using StateSet = std::set<uint32_t>; + std::vector<StateSet> active(window_size<N>() + 1); + for (size_t i = 1; i < repo.size(); ++i) { + active[0].insert(i); + } + EXPECT_EQ(active.size(), window_size<N>() + 1); + EXPECT_EQ(active[0].size(), repo.size() - 1); + fprintf(stderr, "verifying word end for max edits %u\n", N); + uint32_t edge_shape = 0; + for (uint32_t active_idx = 0; active_idx < active.size(); ++active_idx) { + fprintf(stderr, " edge shape: %s, max step: %zu, active_states: %zu\n", + format_vector(expand_bits<N>(edge_shape)).c_str(), active.size() - active_idx - 1, active[active_idx].size()); + for (uint32_t idx: active[active_idx]) { + const State &state = repo.idx_to_state(idx); + for (uint32_t i = 0; i < num_transitions<N>(); ++i) { + if ((i & edge_shape) == 0) { + State new_state = state.next<N>(expand_bits<N>(i)); + uint32_t step = new_state.normalize(); + uint32_t new_idx = repo.state_to_idx(new_state); + ASSERT_LT(new_idx, repo.size()); + if (new_idx != 0) { + ASSERT_GT(active.size(), active_idx + step); + active[active_idx + step].insert(new_idx); + } + } + } + } + edge_shape = (edge_shape << 1) + 1; + } + EXPECT_EQ(edge_shape, (1 << (window_size<N>() + 1)) - 1); + while (!active.back().empty()) { + fprintf(stderr, " residue states after word end: %zu\n", active.back().size()); + StateSet residue; + for (uint32_t idx: active.back()) { + const State &state = repo.idx_to_state(idx); + State new_state = state.next<N>(expand_bits<N>(0)); + uint32_t step = new_state.normalize(); + uint32_t new_idx = repo.state_to_idx(new_state); + ASSERT_LT(new_idx, repo.size()); + ASSERT_EQ(step, 0); + if (new_idx != 0) { + residue.insert(new_idx); + } + } + active.back() = std::move(residue); + } +} + +TEST(TableDfaTest, minimal_boundary_will_never_exceed_word_end_with_max_edits_1) { + verify_word_end_boundary<1>(); +} + +TEST(TableDfaTest, minimal_boundary_will_never_exceed_word_end_with_max_edits_2) { + verify_word_end_boundary<2>(); +} + +TEST(TableDfaTest, graphviz_for_food_with_max_edits_1) { + auto dfa = LevenshteinDfa::build("food", 1, LevenshteinDfa::Casing::Cased, LevenshteinDfa::DfaType::Table); + std::ostringstream out; + dfa.dump_as_graphviz(out); + fprintf(stderr, "memory usage: %zu\n", dfa.memory_usage()); + fprintf(stderr, "%s", out.str().c_str()); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/memorydatastore/memorydatastore.cpp b/vespalib/src/tests/memorydatastore/memorydatastore.cpp index 1d49b0af91b..649bd45a541 100644 --- a/vespalib/src/tests/memorydatastore/memorydatastore.cpp +++ b/vespalib/src/tests/memorydatastore/memorydatastore.cpp @@ -6,17 +6,7 @@ using namespace vespalib; -class MemoryDataStoreTest : public vespalib::TestApp -{ -private: - void testMemoryDataStore(); - void testVariableSizeVector(); -public: - int Main() override; -}; - -void -MemoryDataStoreTest::testMemoryDataStore() +TEST("testMemoryDataStore") { MemoryDataStore s(alloc::Alloc::alloc(256)); std::vector<MemoryDataStore::Reference> v; @@ -28,45 +18,9 @@ MemoryDataStoreTest::testMemoryDataStore() v.push_back(s.push_back("mumbo", 5)); EXPECT_EQUAL(52ul, v.size()); EXPECT_NOT_EQUAL(static_cast<const char *>(v[50].data()) + 5, v[51].data()); - for (size_t i(0); i < v.size(); i++) { - EXPECT_EQUAL(0, memcmp("mumbo", v[i].data(), 5)); + for (auto & i : v) { + EXPECT_EQUAL(0, memcmp("mumbo", i.data(), 5)); } } -void -MemoryDataStoreTest::testVariableSizeVector() -{ - VariableSizeVector v(20000, 5*20000); - for (size_t i(0); i < 10000; i++) { - asciistream os; - os << i; - v.push_back(os.str().data(), os.str().size()); - } - for (size_t i(0); i < v.size(); i++) { - asciistream os; - os << i; - EXPECT_EQUAL(os.str().size(), v[i].size()); - EXPECT_EQUAL(0, memcmp(os.str().data(), v[i].data(), os.str().size())); - } - size_t i(0); - for (auto it(v.begin()), mt(v.end()); it != mt; it++, i++) { - asciistream os; - os << i; - EXPECT_EQUAL(os.str().size(), it->size()); - EXPECT_EQUAL(0, memcmp(os.str().data(), (*it).data(), os.str().size())); - } - -} - -int -MemoryDataStoreTest::Main() -{ - TEST_INIT("data_test"); - testMemoryDataStore(); - testVariableSizeVector(); - - TEST_DONE(); -} - -TEST_APPHOOK(MemoryDataStoreTest); - +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/tests/stllike/hash_test.cpp b/vespalib/src/tests/stllike/hash_test.cpp index ae27d2dc58b..dc6e48100a9 100644 --- a/vespalib/src/tests/stllike/hash_test.cpp +++ b/vespalib/src/tests/stllike/hash_test.cpp @@ -494,6 +494,14 @@ TEST("test hash set initializer list - empty") EXPECT_EQUAL(0u, s.size()); } +TEST("empty hash_set can be looked up") +{ + IntHashSet s; + EXPECT_EQUAL(0u, s.size()); + EXPECT_EQUAL(1u, s.capacity()); + EXPECT_TRUE(s.find(1) == s.end()); +} + TEST("test hash set initializer list - 1 element") { IntHashSet s = {1}; diff --git a/vespalib/src/vespa/vespalib/btree/btreeiterator.h b/vespalib/src/vespa/vespalib/btree/btreeiterator.h index 2418da18c23..9480e5880e0 100644 --- a/vespalib/src/vespa/vespalib/btree/btreeiterator.h +++ b/vespalib/src/vespa/vespalib/btree/btreeiterator.h @@ -302,22 +302,22 @@ public: /** * Get key at current iterator location. */ - const KeyType & getKey() const { return _leaf.getKey(); } + const KeyType & getKey() const noexcept { return _leaf.getKey(); } /** * Get data at current iterator location. */ - const DataType & getData() const { return _leaf.getData(); } + const DataType & getData() const noexcept { return _leaf.getData(); } /** * Check if iterator is at a valid element, i.e. not at end. */ - bool valid() const { return _leaf.valid(); } + bool valid() const noexcept{ return _leaf.valid(); } /** * Return the number of elements in the tree. */ - size_t size() const; + size_t size() const noexcept; /** @@ -333,7 +333,7 @@ public: /** * Return if the tree has data or not (e.g. keys and data or only keys). */ - static bool hasData() { return LeafNodeType::hasData(); } + static bool hasData() noexcept { return LeafNodeType::hasData(); } /** * Move the iterator directly to end. Used by findHelper method in BTree. diff --git a/vespalib/src/vespa/vespalib/btree/btreeiterator.hpp b/vespalib/src/vespa/vespalib/btree/btreeiterator.hpp index b7927feaa1a..d6dda0047ce 100644 --- a/vespalib/src/vespa/vespalib/btree/btreeiterator.hpp +++ b/vespalib/src/vespa/vespalib/btree/btreeiterator.hpp @@ -387,15 +387,13 @@ position(uint32_t levels) const res += inode->validLeaves(); for (uint32_t c = elem.getIdx(); c < slots; ++c) { BTreeNode::Ref node = inode->getChild(c); - const InternalNodeType *jnode = - _allocator->mapInternalRef(node); + const InternalNodeType *jnode = _allocator->mapInternalRef(node); res -= jnode->validLeaves(); } } else { for (uint32_t c = 0; c < elem.getIdx(); ++c) { BTreeNode::Ref node = inode->getChild(c); - const InternalNodeType *jnode = - _allocator->mapInternalRef(node); + const InternalNodeType *jnode = _allocator->mapInternalRef(node); res += jnode->validLeaves(); } } @@ -484,7 +482,7 @@ template <typename KeyT, typename DataT, typename AggrT, uint32_t INTERNAL_SLOTS, uint32_t LEAF_SLOTS, uint32_t PATH_SIZE> size_t BTreeIteratorBase<KeyT, DataT, AggrT, INTERNAL_SLOTS, LEAF_SLOTS, PATH_SIZE>:: -size() const +size() const noexcept { if (_pathSize > 0) { return _path[_pathSize - 1].getNode()->validLeaves(); diff --git a/vespalib/src/vespa/vespalib/btree/btreenode.h b/vespalib/src/vespa/vespalib/btree/btreenode.h index 0a77a0b4685..4931021d771 100644 --- a/vespalib/src/vespa/vespalib/btree/btreenode.h +++ b/vespalib/src/vespa/vespalib/btree/btreenode.h @@ -67,14 +67,14 @@ public: using Ref = datastore::EntryRef; using ChildRef = datastore::AtomicEntryRef; - bool isLeaf() const { return _level == 0u; } - bool getFrozen() const { return _isFrozen; } - void freeze() { _isFrozen = true; } - void unFreeze() { _isFrozen = false; } - void setLevel(uint8_t level) { _level = level; } - uint32_t getLevel() const { return _level; } - uint32_t validSlots() const { return _validSlots; } - void setValidSlots(uint16_t validSlots_) { _validSlots = validSlots_; } + bool isLeaf() const noexcept { return _level == 0u; } + bool getFrozen() const noexcept { return _isFrozen; } + void freeze() noexcept { _isFrozen = true; } + void unFreeze() noexcept { _isFrozen = false; } + void setLevel(uint8_t level) noexcept { _level = level; } + uint32_t getLevel() const noexcept { return _level; } + uint32_t validSlots() const noexcept { return _validSlots; } + void setValidSlots(uint16_t validSlots_) noexcept { _validSlots = validSlots_; } }; @@ -358,7 +358,7 @@ public: void insert(uint32_t idx, const KeyT & key, BTreeNode::Ref child) { insert(idx, key, BTreeNode::ChildRef(child)); } - uint32_t validLeaves() const { return _validLeaves; } + uint32_t validLeaves() const noexcept { return _validLeaves; } void setValidLeaves(uint32_t newValidLeaves) { _validLeaves = newValidLeaves; } void incValidLeaves(uint32_t delta) { _validLeaves += delta; } void decValidLeaves(uint32_t delta) { _validLeaves -= delta; } diff --git a/vespalib/src/vespa/vespalib/data/memorydatastore.cpp b/vespalib/src/vespa/vespalib/data/memorydatastore.cpp index 354787690c2..6d483e6ff4e 100644 --- a/vespalib/src/vespa/vespalib/data/memorydatastore.cpp +++ b/vespalib/src/vespa/vespalib/data/memorydatastore.cpp @@ -41,21 +41,4 @@ MemoryDataStore::push_back(const void * data, const size_t sz) return ref; } -VariableSizeVector::VariableSizeVector(size_t initialCount, size_t initialBufferSize) - : _vector(), - _store(Alloc::alloc(initialBufferSize)) -{ - _vector.reserve(initialCount); -} - -VariableSizeVector::~VariableSizeVector() = default; - -VariableSizeVector::Reference -VariableSizeVector::push_back(const void * data, const size_t sz) -{ - MemoryDataStore::Reference ptr(_store.push_back(data, sz)); - _vector.push_back(Reference(ptr.data(), sz)); - return _vector.back(); -} - } // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/data/memorydatastore.h b/vespalib/src/vespa/vespalib/data/memorydatastore.h index 7022eb88051..a0280454a91 100644 --- a/vespalib/src/vespa/vespalib/data/memorydatastore.h +++ b/vespalib/src/vespa/vespalib/data/memorydatastore.h @@ -44,83 +44,5 @@ private: std::mutex * _lock; }; -class VariableSizeVector -{ -public: - class Reference { - public: - Reference(void * data_, size_t sz) noexcept : _data(data_), _sz(sz) { } - void * data() noexcept { return _data; } - const char * c_str() const noexcept { return static_cast<const char *>(_data); } - size_t size() const noexcept { return _sz; } - private: - void * _data; - size_t _sz; - }; - class iterator { - public: - iterator(vespalib::Array<Reference> & v, size_t index) noexcept : _vector(&v), _index(index) {} - Reference & operator * () const noexcept { return (*_vector)[_index]; } - Reference * operator -> () const noexcept { return &(*_vector)[_index]; } - iterator & operator ++ () noexcept { - _index++; - return *this; - } - iterator operator ++ (int) noexcept { - iterator prev = *this; - ++(*this); - return prev; - } - bool operator==(const iterator& rhs) const noexcept { return (_index == rhs._index); } - bool operator!=(const iterator& rhs) const noexcept { return (_index != rhs._index); } - private: - vespalib::Array<Reference> * _vector; - size_t _index; - }; - class const_iterator { - public: - const_iterator(const vespalib::Array<Reference> & v, size_t index) noexcept : _vector(&v), _index(index) {} - const Reference & operator * () const noexcept { return (*_vector)[_index]; } - const Reference * operator -> () const noexcept { return &(*_vector)[_index]; } - const_iterator & operator ++ () noexcept { - _index++; - return *this; - } - const_iterator operator ++ (int) noexcept { - const_iterator prev = *this; - ++(*this); - return prev; - } - bool operator==(const const_iterator& rhs) const noexcept { return (_index == rhs._index); } - bool operator!=(const const_iterator& rhs) const noexcept { return (_index != rhs._index); } - private: - const vespalib::Array<Reference> * _vector; - size_t _index; - }; - VariableSizeVector(const VariableSizeVector &) = delete; - VariableSizeVector & operator = (const VariableSizeVector &) = delete; - VariableSizeVector(size_t initialCount, size_t initialBufferSize); - ~VariableSizeVector(); - iterator begin() noexcept { return iterator(_vector, 0); } - iterator end() noexcept { return iterator(_vector, size()); } - const_iterator begin() const noexcept { return const_iterator(_vector, 0); } - const_iterator end() const noexcept { return const_iterator(_vector, size()); } - Reference push_back(const void * data, const size_t sz); - Reference operator [] (uint32_t index) const noexcept { return _vector[index]; } - size_t size() const noexcept { return _vector.size(); } - bool empty() const noexcept { return _vector.empty(); } - void swap(VariableSizeVector & rhs) noexcept { - _vector.swap(rhs._vector); - _store.swap(rhs._store); - } - void clear() { - _vector.clear(); - _store.clear(); - } -private: - vespalib::Array<Reference> _vector; - MemoryDataStore _store; -}; - } // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/data/slime/named_symbol_lookup.h b/vespalib/src/vespa/vespalib/data/slime/named_symbol_lookup.h index 44dbf05c9da..fb4bc3943ae 100644 --- a/vespalib/src/vespa/vespalib/data/slime/named_symbol_lookup.h +++ b/vespalib/src/vespa/vespalib/data/slime/named_symbol_lookup.h @@ -20,7 +20,7 @@ private: const Memory &_name; public: - NamedSymbolLookup(const SymbolTable &table, const Memory &name) + NamedSymbolLookup(const SymbolTable &table, const Memory &name) noexcept : _table(table), _name(name) {} Symbol lookup() const override; }; diff --git a/vespalib/src/vespa/vespalib/data/slime/slime.cpp b/vespalib/src/vespa/vespalib/data/slime/slime.cpp index d6fac44b360..5a9ec54584d 100644 --- a/vespalib/src/vespa/vespalib/data/slime/slime.cpp +++ b/vespalib/src/vespa/vespalib/data/slime/slime.cpp @@ -6,16 +6,6 @@ namespace vespalib { -Slime::Params::Params() : Params(std::make_unique<SymbolTable>()) { } -Slime::Params::Params(std::unique_ptr<SymbolTable> symbols) noexcept : _symbols(std::move(symbols)), _chunkSize(4096) { } -Slime::Params::Params(Params &&) noexcept = default; -Slime::Params::~Params() = default; - -std::unique_ptr<slime::SymbolTable> -Slime::Params::detachSymbols() { - return std::move(_symbols); -} - Slime::Slime(Params params) : _names(params.detachSymbols()), _stash(std::make_unique<Stash>(params.getChunkSize())), @@ -33,26 +23,6 @@ Slime::reclaimSymbols(Slime &&rhs) { return std::move(rhs._names); } -size_t -Slime::symbols() const noexcept { - return _names->symbols(); -} - -Memory -Slime::inspect(Symbol symbol) const { - return _names->inspect(symbol); -} - -slime::Symbol -Slime::insert(Memory name) { - return _names->insert(name); -} - -slime::Symbol -Slime::lookup(Memory name) const { - return _names->lookup(name); -} - bool operator == (const Slime & a, const Slime & b) noexcept { return a.get() == b.get(); diff --git a/vespalib/src/vespa/vespalib/data/slime/slime.h b/vespalib/src/vespa/vespalib/data/slime/slime.h index a426f906563..4b789838009 100644 --- a/vespalib/src/vespa/vespalib/data/slime/slime.h +++ b/vespalib/src/vespa/vespalib/data/slime/slime.h @@ -21,6 +21,7 @@ #include "symbol.h" #include "symbol_inserter.h" #include "symbol_lookup.h" +#include "symbol_table.h" #include "type.h" #include "value.h" #include "value_factory.h" @@ -51,32 +52,30 @@ private: using Cursor = slime::Cursor; using Inspector = slime::Inspector; - std::unique_ptr<SymbolTable> _names; - std::unique_ptr<Stash> _stash; - RootValue _root; + std::unique_ptr<SymbolTable> _names; + std::unique_ptr<Stash> _stash; + RootValue _root; public: using UP = std::unique_ptr<Slime>; class Params { private: - std::unique_ptr<SymbolTable> _symbols; - size_t _chunkSize; + std::unique_ptr<SymbolTable> _symbols; + size_t _chunkSize; public: - Params(); - explicit Params(std::unique_ptr<SymbolTable> symbols) noexcept; - Params(Params &&) noexcept; - ~Params(); - Params & setChunkSize(size_t chunkSize) { - _chunkSize = chunkSize; - return *this; - } - size_t getChunkSize() const { return _chunkSize; } - std::unique_ptr<SymbolTable> detachSymbols(); + Params() : Params(4096) {} + explicit Params(size_t chunkSize) : _symbols(std::make_unique<SymbolTable>()), _chunkSize(chunkSize) {} + explicit Params(std::unique_ptr<SymbolTable> symbols) noexcept : _symbols(std::move(symbols)), _chunkSize(4096) {} + Params(Params &&) noexcept = default; + ~Params() = default; + size_t getChunkSize() const noexcept { return _chunkSize; } + std::unique_ptr<SymbolTable> detachSymbols() noexcept { return std::move(_symbols); } }; /** * Construct an initially empty Slime object. **/ - explicit Slime(Params params = Params()); + explicit Slime() : Slime(Params()) {} + explicit Slime(Params params); ~Slime(); @@ -88,13 +87,13 @@ public: static std::unique_ptr<SymbolTable> reclaimSymbols(Slime &&rhs); - size_t symbols() const noexcept; + size_t symbols() const noexcept { return _names->symbols(); } - Memory inspect(Symbol symbol) const; + Memory inspect(Symbol symbol) const { return _names->inspect(symbol); } - Symbol insert(Memory name); + Symbol insert(Memory name) { return _names->insert(name); } - Symbol lookup(Memory name) const; + Symbol lookup(Memory name) const { return _names->lookup(name); } Cursor &get() noexcept { return _root.get(); } diff --git a/vespalib/src/vespa/vespalib/data/slime/symbol.h b/vespalib/src/vespa/vespalib/data/slime/symbol.h index 3bce727fad9..a60a49fda27 100644 --- a/vespalib/src/vespa/vespalib/data/slime/symbol.h +++ b/vespalib/src/vespa/vespalib/data/slime/symbol.h @@ -19,8 +19,8 @@ private: public: Symbol() noexcept : _value(UNDEFINED) {} Symbol(uint32_t v) noexcept : _value(v) {} - bool undefined() const { return (_value == UNDEFINED); } - uint32_t getValue() const { return _value; } + bool undefined() const noexcept { return (_value == UNDEFINED); } + uint32_t getValue() const noexcept { return _value; } bool operator<(const Symbol &rhs) const noexcept { return (_value < rhs._value); } bool operator==(const Symbol &rhs) const noexcept { return (_value == rhs._value); } }; diff --git a/vespalib/src/vespa/vespalib/data/slime/symbol_table.cpp b/vespalib/src/vespa/vespalib/data/slime/symbol_table.cpp index a3313516c64..dffe35707fc 100644 --- a/vespalib/src/vespa/vespalib/data/slime/symbol_table.cpp +++ b/vespalib/src/vespa/vespalib/data/slime/symbol_table.cpp @@ -5,10 +5,13 @@ namespace vespalib::slime { -SymbolTable::SymbolTable(size_t expectedNumSymbols) : - _symbols(3*expectedNumSymbols), - _names(expectedNumSymbols, expectedNumSymbols*16) -{ } +SymbolTable::SymbolTable(size_t expectedNumSymbols) + : _stash(), + _symbols(3*expectedNumSymbols), + _names() +{ + _names.reserve(expectedNumSymbols); +} SymbolTable::~SymbolTable() = default; @@ -16,6 +19,7 @@ void SymbolTable::clear() { _names.clear(); _symbols.clear(); + _stash.clear(); } Symbol @@ -23,17 +27,21 @@ SymbolTable::insert(const Memory &name) { SymbolMap::const_iterator pos = _symbols.find(name); if (pos == _symbols.end()) { Symbol symbol(_names.size()); - SymbolVector::Reference r(_names.push_back(name.data, name.size)); - _symbols.insert(std::make_pair(Memory(r.c_str(), r.size()), symbol)); + char *buf = _stash.alloc(name.size); + memcpy(buf, name.data, name.size); + Memory backed(buf, name.size); + _names.push_back(backed); + _symbols.insert(std::make_pair(backed, symbol)); return symbol; } return pos->second; } + Symbol SymbolTable::lookup(const Memory &name) const { SymbolMap::const_iterator pos = _symbols.find(name); if (pos == _symbols.end()) { - return Symbol(); + return {}; } return pos->second; } diff --git a/vespalib/src/vespa/vespalib/data/slime/symbol_table.h b/vespalib/src/vespa/vespalib/data/slime/symbol_table.h index c5f3cf12fd6..0eae65cead0 100644 --- a/vespalib/src/vespa/vespalib/data/slime/symbol_table.h +++ b/vespalib/src/vespa/vespalib/data/slime/symbol_table.h @@ -4,8 +4,8 @@ #include "symbol.h" #include <vespa/vespalib/data/memory.h> +#include <vespa/vespalib/util/stash.h> #include <vespa/vespalib/stllike/hash_map.h> -#include <vespa/vespalib/data/memorydatastore.h> namespace vespalib::slime { @@ -21,21 +21,24 @@ private: } }; using SymbolMap = hash_map<Memory, Symbol, hasher>; - using SymbolVector = VariableSizeVector; + using SymbolVector = std::vector<Memory>; + Stash _stash; SymbolMap _symbols; SymbolVector _names; public: using UP = std::unique_ptr<SymbolTable>; - SymbolTable(size_t expectedNumSymbols=16); + SymbolTable() : SymbolTable(16) {} + explicit SymbolTable(size_t expectedNumSymbols); + SymbolTable(SymbolTable &&) noexcept = default; + SymbolTable & operator=(SymbolTable &&) noexcept = default; ~SymbolTable(); size_t symbols() const noexcept { return _names.size(); } Memory inspect(const Symbol &symbol) const { if (symbol.getValue() > _names.size()) { return Memory(); } - SymbolVector::Reference r(_names[symbol.getValue()]); - return Memory(r.c_str(), r.size()); + return _names[symbol.getValue()]; } Symbol insert(const Memory &name); Symbol lookup(const Memory &name) const; diff --git a/vespalib/src/vespa/vespalib/fuzzy/CMakeLists.txt b/vespalib/src/vespa/vespalib/fuzzy/CMakeLists.txt index 5e8d29980cd..8ccef84d969 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/fuzzy/CMakeLists.txt @@ -7,6 +7,7 @@ vespa_add_library(vespalib_vespalib_fuzzy OBJECT implicit_levenshtein_dfa.cpp levenshtein_dfa.cpp levenshtein_distance.cpp + table_dfa.cpp unicode_utils.cpp DEPENDS ) diff --git a/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.cpp b/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.cpp index 1caae408176..5f6d0ae9956 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.cpp +++ b/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "explicit_levenshtein_dfa.h" #include "implicit_levenshtein_dfa.h" +#include "table_dfa.h" #include "levenshtein_dfa.h" #include "unicode_utils.h" #include <vespa/vespalib/util/stringfmt.h> @@ -53,14 +54,19 @@ LevenshteinDfa LevenshteinDfa::build(std::string_view target_string, uint8_t max } else { // max_edits == 2 return LevenshteinDfa(std::make_unique<ImplicitLevenshteinDfa<FixedMaxEditDistanceTraits<2>>>(std::move(target_string_u32), is_cased)); } - } else { // DfaType::Explicit + } else if(dfa_type == DfaType::Explicit) { if (max_edits == 1) { return ExplicitLevenshteinDfaBuilder<FixedMaxEditDistanceTraits<1>>(std::move(target_string_u32), is_cased).build_dfa(); } else { // max_edits == 2 return ExplicitLevenshteinDfaBuilder<FixedMaxEditDistanceTraits<2>>(std::move(target_string_u32), is_cased).build_dfa(); } + } else { // DfaType::Table + if (max_edits == 1) { + return LevenshteinDfa(std::make_unique<TableDfa<1>>(std::move(target_string_u32), is_cased)); + } else { // max_edits == 2 + return LevenshteinDfa(std::make_unique<TableDfa<2>>(std::move(target_string_u32), is_cased)); + } } - } LevenshteinDfa LevenshteinDfa::build(std::string_view target_string, uint8_t max_edits, Casing casing) { @@ -87,9 +93,11 @@ std::ostream& operator<<(std::ostream& os, const LevenshteinDfa::MatchResult& mo std::ostream& operator<<(std::ostream& os, LevenshteinDfa::DfaType dt) { if (dt == LevenshteinDfa::DfaType::Implicit) { os << "Implicit"; - } else { - assert(dt == LevenshteinDfa::DfaType::Explicit); + } else if (dt == LevenshteinDfa::DfaType::Explicit) { os << "Explicit"; + } else { + assert(dt == LevenshteinDfa::DfaType::Table); + os << "Table"; } return os; } diff --git a/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.h b/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.h index c6ca06d4de3..6c2724fbe79 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.h +++ b/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.h @@ -58,8 +58,8 @@ namespace vespalib::fuzzy { * ====== Unicode support ====== * * Matching and successor generation is fully Unicode-aware. All input strings are expected - * to be in UTF-8, and the generated successor is also encoded as UTF-8 (with some caveats; - * see the documentation for match()). + * to be in UTF-8, and the generated successor is encoded as UTF-8 (with some caveats; see + * the documentation for match()) or UTF-32, depending on the chosen `match()` overload. * * Internally, matching is done on UTF-32 code points and the DFA itself is built around * UTF-32. This is unlike Lucene, which converts a UTF-32 DFA to an equivalent UTF-8 DFA. @@ -159,7 +159,7 @@ public: /** * Attempts to match the source string `source` with the target string this DFA was - * built with, emitting a successor string on mismatch if `successor_out` != nullptr. + * built with. * * `source` must not contain any null UTF-8 chars. * @@ -181,11 +181,14 @@ public: * * See `match(source)` for semantics of returned MatchResult. * + * In the case of a _match_, the contents of `successor_out` is unspecified. It may be + * preemptively modified as part of the matching loop itself. + * * In the case of a _mismatch_, the following holds: * - * - `successor_out` is modified to contain the next (in byte-wise ordering) possible - * _matching_ string S so that there exists no other matching string S' that is - * greater than `source` but smaller than S. + * - `successor_out` contains the next (in byte-wise ordering) possible _matching_ + * string S so that there exists no other matching string S' that is greater than + * `source` but smaller than S. * - `successor_out` contains UTF-8 bytes that are within what UTF-8 can legally * encode in bitwise form, but the _code points_ they encode may not be valid. * In particular, surrogate pair ranges and U+10FFFF+1 may be encoded, neither of @@ -203,12 +206,8 @@ public: * is what is passed to the DFA match() function. * * Memory allocation: - * This function does not directly or indirectly allocate any heap memory if either: - * - * - the input string is within the max edit distance, or - * - `successor_out` is nullptr, or - * - `successor_out` has sufficient capacity to hold the generated successor - * + * This function does not directly or indirectly allocate any heap memory if the + * `successor_out` string provided is large enough to fit any generated successor. * By reusing the successor string across many calls, this therefore amortizes memory * allocations down to near zero per invocation. */ @@ -220,7 +219,8 @@ public: * internally, and is therefore expected to be more efficient. * * The code point ordering of the UTF-32 successor string is identical to that its UTF-8 - * equivalent. + * equivalent. This includes the special cases where the successor may contain code points + * outside the legal Unicode range. */ [[nodiscard]] MatchResult match(std::string_view source, std::vector<uint32_t>& successor_out) const; @@ -231,7 +231,8 @@ public: enum class DfaType { Implicit, - Explicit + Explicit, + Table }; /** diff --git a/vespalib/src/vespa/vespalib/fuzzy/match_algorithm.hpp b/vespalib/src/vespa/vespalib/fuzzy/match_algorithm.hpp index fb5ec32abc7..654889d87bf 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/match_algorithm.hpp +++ b/vespalib/src/vespa/vespalib/fuzzy/match_algorithm.hpp @@ -136,39 +136,37 @@ struct MatchAlgorithm { * that has nothing in common with the source altogether. * Example: "gp" -> "hfood" (+1 char value case) * - * Performance note: - * Both the input and successor output strings are in UTF-8 format. To avoid doing - * duplicate work, we keep track of the byte length of the string prefix that will be - * part of the successor and simply copy it verbatim instead of building the string - * from converted UTF-32 -> UTF-8 chars as we go. This optimization cannot be used - * when one or more of the prefix characters have been lowercase-transformed. + * Note for cased vs. uncased matching: when uncased matching is specified, we always + * match "as if" both the target and source strings are lowercased. This means that + * successor strings are generated based on this form, _not_ on the original form. + * Example: uncased matching for target "food" with input "FOXX". This generates the + * successor "foyd" (and _not_ "FOyd"), as the latter would imply a completely different + * ordering when compared byte-wise against an implicitly lowercased dictionary. * * TODO let matcher know if source string is pre-normalized (i.e. lowercased). - * TODO consider opportunistically appending prefix as we go instead of only when needed. */ template <DfaMatcher Matcher, typename SuccessorT> static MatchResult match(const Matcher& matcher, std::string_view source, SuccessorT& successor_out) { + successor_out.clear(); // TODO allow for preserving existing prefix + using StateType = typename Matcher::StateType; Utf8Reader u8_reader(source.data(), source.size()); - uint32_t n_prefix_u8_bytes = 0; + uint32_t n_prefix_chars = 0; uint32_t char_after_prefix = 0; StateType last_state_with_higher_out = StateType{}; - bool can_use_raw_prefix = true; StateType state = matcher.start(); while (u8_reader.hasMore()) { - const auto u8_pos_before_char = u8_reader.getPos(); - const uint32_t raw_mch = u8_reader.getChar(); - const uint32_t mch = normalized_match_char(raw_mch, matcher.is_cased()); - if (raw_mch != mch) { - can_use_raw_prefix = false; // FIXME this is pessimistic; considers entire string, not just prefix - } + const auto pos_before_char = static_cast<uint32_t>(successor_out.size()); + const uint32_t raw_mch = u8_reader.getChar(); + const uint32_t mch = normalized_match_char(raw_mch, matcher.is_cased()); + append_utf32_char(successor_out, mch); if (matcher.has_higher_out_edge(state, mch)) { last_state_with_higher_out = state; - n_prefix_u8_bytes = u8_pos_before_char; + n_prefix_chars = pos_before_char; char_after_prefix = mch; } auto maybe_next = matcher.match_input(state, mch); @@ -176,8 +174,7 @@ struct MatchAlgorithm { state = maybe_next; } else { // Can never match; find the successor - emit_successor_prefix(successor_out, source, n_prefix_u8_bytes, - matcher.is_cased() || can_use_raw_prefix); + successor_out.resize(n_prefix_chars); // Always <= successor_out.size() assert(matcher.valid_state(last_state_with_higher_out)); backtrack_and_emit_greater_suffix(matcher, last_state_with_higher_out, char_after_prefix, successor_out); @@ -188,8 +185,7 @@ struct MatchAlgorithm { if (edits <= max_edits()) { return MatchResult::make_match(max_edits(), edits); } - emit_successor_prefix(successor_out, source, source.size(), - matcher.is_cased() || can_use_raw_prefix); + // Successor prefix already filled, just need to emit the suffix emit_smallest_matching_suffix(matcher, state, successor_out); return MatchResult::make_mismatch(max_edits()); } @@ -320,48 +316,6 @@ struct MatchAlgorithm { } } - template <typename T> - static constexpr bool has_8bit_value_type() noexcept { - return sizeof(typename T::value_type) == 1; - } - - /** - * The successor prefix is the prefix of the source string up to (but not including) the - * point where we emit a lexicographically higher character. Ideally we can just copy the - * UTF-8 bytes verbatim from the source into the successor. This is possible when one of - * the following holds: - * - * - DFA uses Cased (i.e. exact) matching, or - * - DFA uses Uncased, but none of the characters in the prefix triggered a lowercase - * transform. This means the prefix is already as-if lowercased, and we can copy it - * verbatim. - * - * In the case that we can't copy verbatim, we currently have to explicitly normalize the - * prefix by converting it to its lowercased form. - * - * Example: Uncased matching for target "food" with input "FOXX". This generates the - * successor "foyd" (and _not_ "FOyd"), as the latter would imply a completely different - * ordering when compared byte-wise against an implicitly lowercased dictionary. - */ - template <typename SuccessorT> - static void emit_successor_prefix(SuccessorT& successor_out, std::string_view source, - uint32_t n_prefix_u8_bytes, bool emit_raw_prefix_u8_bytes) - { - // TODO redesign prefix output wiring - if constexpr (has_8bit_value_type<SuccessorT>()) { - if (emit_raw_prefix_u8_bytes) { - successor_out = source.substr(0, n_prefix_u8_bytes); - return; - } - } - // TODO avoid duplicate work...! :I - successor_out.clear(); - Utf8Reader u8_reader(source.data(), source.size()); - while (u8_reader.getPos() < n_prefix_u8_bytes) { - append_utf32_char(successor_out, LowerCase::convert(u8_reader.getChar())); - } - } - static uint32_t normalized_match_char(uint32_t in_ch, bool is_cased) noexcept { return (is_cased ? in_ch : LowerCase::convert(in_ch)); } diff --git a/vespalib/src/vespa/vespalib/fuzzy/sparse_state.h b/vespalib/src/vespa/vespalib/fuzzy/sparse_state.h index d20cfc07a9a..dfec0bac4a8 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/sparse_state.h +++ b/vespalib/src/vespa/vespalib/fuzzy/sparse_state.h @@ -112,11 +112,11 @@ std::ostream& operator<<(std::ostream& os, const FixedSparseState<MaxEdits>& s) if (i != 0) { os << ", "; } - for (size_t j = last_idx; j < s.indices[i]; ++j) { + for (size_t j = last_idx; j < s.index(i); ++j) { os << "-, "; } - last_idx = s.indices[i] + 1; - os << static_cast<uint32_t>(s.costs[i]); + last_idx = s.index(i) + 1; + os << static_cast<uint32_t>(s.cost(i)); } os << "]"; return os; diff --git a/vespalib/src/vespa/vespalib/fuzzy/table_dfa.cpp b/vespalib/src/vespa/vespalib/fuzzy/table_dfa.cpp new file mode 100644 index 00000000000..943349818fb --- /dev/null +++ b/vespalib/src/vespa/vespalib/fuzzy/table_dfa.cpp @@ -0,0 +1,10 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "table_dfa.hpp" + +namespace vespalib::fuzzy { + +template class TableDfa<1>; +template class TableDfa<2>; + +} diff --git a/vespalib/src/vespa/vespalib/fuzzy/table_dfa.h b/vespalib/src/vespa/vespalib/fuzzy/table_dfa.h new file mode 100644 index 00000000000..dc511a42d04 --- /dev/null +++ b/vespalib/src/vespa/vespalib/fuzzy/table_dfa.h @@ -0,0 +1,64 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "levenshtein_dfa.h" +#include <vector> + +namespace vespalib::fuzzy { + +/** + * This implementation is based on the paper 'Fast string correction + * with Levenshtein automata' from 2002 by Klaus U. Schulz and Stoyan + * Mihov. + * + * Given the maximal distance N, a generic parameterized transition + * table is calculated up-front. When a specific word is given, a + * simple lookup structure is created to enumerate the possible + * characteristic vectors for each position in the given + * word. Together, these structures can be used to simulate the + * traversal of a hypothetical Levenshtein dfa that will never be + * created. + * + * Approaching the end of the word is handled by padding the + * characteristic vectors with 0 bits for everything after the word + * ends. In addition, a unit test verifies that there is no possible + * sequence of events that leads to the minimal boundary of the state + * exceeding the boundary of the word itself. This means that the + * simulated dfa can be stepped freely without checking for word size. + **/ +template <uint8_t N> +class TableDfa final : public LevenshteinDfa::Impl +{ +public: + // characteristic vector for a specific input value indicating how + // it matches the window starting at the minimal boundary. + struct CV { + uint32_t input; + uint32_t match; + CV() noexcept : input(0), match(0) {} + }; + static constexpr size_t window_size() { return 2 * N + 1; } + struct Lookup { + std::array<CV, window_size()> list; + Lookup() noexcept : list() {} + }; + +private: + const void *_tfa; + const std::vector<Lookup> _lookup; + const bool _is_cased; + + static std::vector<Lookup> make_lookup(const std::vector<uint32_t> &str); + +public: + using MatchResult = LevenshteinDfa::MatchResult; + TableDfa(std::vector<uint32_t> str, bool is_cased); + ~TableDfa() override; + [[nodiscard]] MatchResult match(std::string_view source) const override; + [[nodiscard]] MatchResult match(std::string_view source, std::string& successor_out) const override; + [[nodiscard]] MatchResult match(std::string_view source, std::vector<uint32_t>& successor_out) const override; + [[nodiscard]] size_t memory_usage() const noexcept override; + void dump_as_graphviz(std::ostream& os) const override; +}; + +} diff --git a/vespalib/src/vespa/vespalib/fuzzy/table_dfa.hpp b/vespalib/src/vespa/vespalib/fuzzy/table_dfa.hpp new file mode 100644 index 00000000000..0693343007e --- /dev/null +++ b/vespalib/src/vespa/vespalib/fuzzy/table_dfa.hpp @@ -0,0 +1,497 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "table_dfa.h" +#include "match_algorithm.hpp" +#include <vespa/vespalib/util/stringfmt.h> +#include <cassert> +#include <stdexcept> +#include <algorithm> +#include <map> + +namespace vespalib::fuzzy { + +namespace { + +using vespalib::make_string_short::fmt; + +// It is useful to know the number of states compile time to be able +// to pack lookup tables better. +template <uint8_t N> constexpr size_t num_states(); +template <> constexpr size_t num_states<1>() { return 6; } +template <> constexpr size_t num_states<2>() { return 31; } +template <> constexpr size_t num_states<3>() { return 197; } + +template <uint8_t N> constexpr size_t window_size() { return 2 * N + 1; } +template <uint8_t N> constexpr size_t num_transitions() { return 1 << window_size<N>(); } + + +auto diff(auto a, auto b) { return (a > b) ? (a - b) : (b - a); } + +// A Position combines an index into a word being matched with the +// number of edits needed to get there. This maps directly onto a +// specific state in the NFA used to match a word. Note that the sort +// order prefers low edits over low indexs. This is to ensure that a +// position that subsumes another position will always sort before it. +struct Position { + uint32_t index; + uint32_t edits; + Position(uint32_t index_in, uint32_t edits_in) noexcept + : index(index_in), edits(edits_in) {} + static Position start() noexcept { return Position(0,0); } + bool subsumes(const Position &rhs) const noexcept { + if (edits >= rhs.edits) { + return false; + } + return diff(index, rhs.index) <= (rhs.edits - edits); + } + Position materialize(uint32_t target_index) const noexcept { + return Position(target_index, edits + diff(index, target_index)); + } + bool operator==(const Position &rhs) const noexcept { + return (index == rhs.index) && (edits == rhs.edits); + } + bool operator<(const Position &rhs) const noexcept { + return std::tie(edits,index) < std::tie(rhs.edits, rhs.index); + } + template <uint8_t N> + void add_elementary_transitions(const std::vector<bool> &bits, std::vector<Position> &dst) const { + assert(bits.size() > index); + if (!bits[index]) { + dst.emplace_back(index, edits + 1); + dst.emplace_back(index + 1, edits + 1); + } + for (uint32_t e = 0; (edits + e) <= N; ++e) { + assert(bits.size() > (index + e)); + if (bits[index + e]) { + dst.emplace_back(index + e + 1, edits + e); + } + } + } + vespalib::string to_string() const { return fmt("%u#%u", index, edits); } +}; + +// A State is a collection of different Positions that do not subsume +// each other. If the minimal boundary of a state is larger than 0, it +// can be lifted from the state in a normalizing operation that will +// renumber the position indexes such that the minimal boundary of the +// state becomes 0. This is to allow parameterized states where the +// general progress of matching the string (minimal boundary of +// non-normalized state) is untangled from the local competing edit +// alternatives (normalized state). +struct State { + std::vector<Position> list; + State() noexcept : list() {} + static State failed() noexcept { return State(); } + static State start() { + State result; + result.list.push_back(Position::start()); + return result; + } + bool operator<(const State &rhs) const { + return list < rhs.list; + } + uint32_t minimal_boundary() const noexcept { + if (list.empty()) { + return 0; + } + uint32_t min = list[0].index; + for (size_t i = 1; i < list.size(); ++i) { + min = std::min(min, list[i].index); + } + return min; + } + uint32_t normalize() { + uint32_t min = minimal_boundary(); + if (min > 0) { + for (auto &entry: list) { + entry.index -= min; + } + } + return min; + } + template <uint8_t N> + static State create(std::vector<Position> list_in) { + State result; + auto want = [&result](Position pos) { + if (pos.edits > N) { + return false; + } + for (const auto &old_pos: result.list) { + if (old_pos == pos || old_pos.subsumes(pos)) { + return false; + } + } + return true; + }; + std::sort(list_in.begin(), list_in.end()); + for (const auto &pos: list_in) { + if (want(pos)) { + result.list.push_back(pos); + } + } + return result; + } + template <uint8_t N> + State next(const std::vector<bool> &bits) const { + std::vector<Position> tmp; + for (const auto &pos: list) { + pos.add_elementary_transitions<N>(bits, tmp); + } + return create<N>(std::move(tmp)); + } + template <uint8_t N> + std::vector<uint8_t> make_edit_vector() const { + std::vector<uint8_t> result(window_size<N>(), N + 1); + for (const auto &pos: list) { + for (uint32_t i = 0; i < window_size<N>(); ++i) { + result[i] = std::min(result[i], uint8_t(pos.materialize(i).edits)); + } + } + return result; + } + vespalib::string to_string() const { + vespalib::string result = "{"; + for (size_t i = 0; i < list.size(); ++i) { + if (i > 0) { + result.append(","); + } + result.append(list[i].to_string()); + } + result.append("}"); + return result; + } +}; + +// Keeps track of unique states, assigning an integer value to each +// state. Only states with minimal boundary 0 is allowed to be +// inserted into a state repo. Each repo is seeded with the empty +// state (0) and the start state (1). An assigned integer value can be +// mapped back into the originating state. +struct StateRepo { + using Map = std::map<State,uint32_t>; + using Ref = Map::iterator; + Map seen; + std::vector<Ref> refs; + StateRepo() noexcept : seen(), refs() { + auto failed_idx = state_to_idx(State::failed()); + auto start_idx = state_to_idx(State::start()); + assert(failed_idx == 0); + assert(start_idx == 1); + } + ~StateRepo(); + size_t size() const { return seen.size(); } + uint32_t state_to_idx(const State &state) { + assert(state.minimal_boundary() == 0); + uint32_t next = refs.size(); + auto [pos, inserted] = seen.emplace(state, next); + if (inserted) { + refs.push_back(pos); + } + assert(seen.size() == refs.size()); + return pos->second; + } + const State &idx_to_state(uint32_t idx) const { + assert(idx < refs.size()); + return refs[idx]->first; + } +}; +StateRepo::~StateRepo() = default; + +template <uint8_t N> +std::vector<bool> expand_bits(uint32_t value) { + static_assert(N < 10); + std::vector<bool> result(window_size<N>()); + uint32_t look_for = num_transitions<N>(); + assert(value < look_for); + for (size_t i = 0; i < result.size(); ++i) { + look_for >>= 1; + result[i] = (value & look_for); + } + return result; +} + +template <uint8_t N> +[[maybe_unused]] StateRepo make_state_repo() { + StateRepo repo; + for (uint32_t idx = 0; idx < repo.size(); ++idx) { + const State &state = repo.idx_to_state(idx); + for (uint32_t i = 0; i < num_transitions<N>(); ++i) { + State new_state = state.next<N>(expand_bits<N>(i)); + (void) new_state.normalize(); + (void) repo.state_to_idx(new_state); + } + } + return repo; +} + +// this is the result of our efforts +template <uint8_t N> +struct Tfa { + struct Entry { + uint8_t step; + uint8_t state; + }; + + // what happens when following a transition from a state? + std::array<std::array<Entry,num_transitions<N>()>,num_states<N>()> table; + + // how many edits did we use to match the target word? + std::array<std::array<uint8_t,window_size<N>()>,num_states<N>()> edits; +}; + +template <uint8_t N> +std::unique_ptr<Tfa<N>> make_tfa() { + auto tfa = std::make_unique<Tfa<N>>(); + StateRepo repo; + uint32_t state_idx = 0; + for (; state_idx < repo.size(); ++state_idx) { + const State &state = repo.idx_to_state(state_idx); + for (uint32_t i = 0; i < num_transitions<N>(); ++i) { + State new_state = state.next<N>(expand_bits<N>(i)); + uint32_t step = new_state.normalize(); + uint32_t new_state_idx = repo.state_to_idx(new_state); + assert(step < 256); + assert(new_state_idx < 256); + tfa->table[state_idx][i].step = step; + tfa->table[state_idx][i].state = new_state_idx; + } + auto edits = state.make_edit_vector<N>(); + assert(edits.size() == window_size<N>()); + for (uint32_t i = 0; i < window_size<N>(); ++i) { + tfa->edits[state_idx][i] = edits[i]; + } + } + assert(repo.size() == num_states<N>()); + assert(state_idx == num_states<N>()); + return tfa; +} + +template <uint8_t N> +const Tfa<N> *get_tfa() { + static std::unique_ptr<Tfa<N>> tfa = make_tfa<N>(); + return tfa.get(); +} + +template <typename T> +vespalib::string format_vector(const std::vector<T> &vector, bool compact = false) { + vespalib::string str = compact ? "" : "["; + for (size_t i = 0; i < vector.size(); ++i) { + if (i > 0 && !compact) { + str.append(","); + } + str.append(fmt("%u", uint32_t(vector[i]))); + } + if (!compact) { + str.append("]"); + } + return str; +} + +template <uint8_t N> +struct TableMatcher { + struct S { + uint32_t index; + uint32_t state; + // needed by dfa matcher concept (should use std::declval instead) + constexpr S() noexcept : index(0), state(0) {} + constexpr S(uint32_t i, uint32_t s) noexcept : index(i), state(s) {} + S next(const Tfa<N> *tfa, uint32_t bits) noexcept { + auto entry = tfa->table[state][bits]; + return S(index + entry.step, entry.state); + } + constexpr bool is_valid_edge(const Tfa<N> *tfa, uint32_t bits) const noexcept { + return tfa->table[state][bits].state != 0; + } + }; + using StateType = S; + using StateParamType = StateType; + using EdgeType = uint32_t; + + const Tfa<N> *tfa; + const TableDfa<N>::Lookup *lookup; + const uint32_t end; + const bool cased; + + TableMatcher(const Tfa<N> *tfa_in, const TableDfa<N>::Lookup *lookup_in, uint32_t end_in, bool cased_in) + noexcept : tfa(tfa_in), lookup(lookup_in), end(end_in), cased(cased_in) {} + + bool is_cased() const noexcept { return cased; } + static constexpr S start() noexcept { return S(0, 1); } + + uint8_t match_edit_distance(S s) const noexcept { + uint32_t leap = (end - s.index); + return (leap < window_size<N>()) ? tfa->edits[s.state][leap] : N + 1; + } + bool is_match(S s) const noexcept { return match_edit_distance(s) <= N; } + + static constexpr bool can_match(S s) noexcept { return (s.state != 0); } + static constexpr bool valid_state(S s) noexcept { return (s.state != 0); } + + S match_wildcard(S s) const noexcept { return s.next(tfa, 0); } + S match_input(S s, uint32_t c) const noexcept { + const auto *slice = lookup[s.index].list.data(); + for (size_t i = 0; i < window_size<N>() && slice[i].input != 0; ++i) { + if (slice[i].input == c) { + return s.next(tfa, slice[i].match); + } + } + return match_wildcard(s); + } + + bool has_higher_out_edge(S s, uint32_t c) const noexcept { + if (s.is_valid_edge(tfa, 0)) { + return true; + } + const auto *slice = lookup[s.index].list.data(); + for (size_t i = 0; i < window_size<N>() && slice[i].input > c; ++i) { + if (s.is_valid_edge(tfa, slice[i].match)) { + return true; + } + } + return false; + } + + bool has_exact_explicit_out_edge(S s, uint32_t c) const noexcept { + const auto *slice = lookup[s.index].list.data(); + for (size_t i = 0; i < window_size<N>() && slice[i].input >= c; ++i) { + if (slice[i].input == c) { + return s.is_valid_edge(tfa, slice[i].match); + } + } + return false; + } + + uint32_t lowest_higher_explicit_out_edge(S s, uint32_t c) const noexcept { + const auto *slice = lookup[s.index].list.data(); + size_t i = window_size<N>(); + while (i-- > 0) { + if (slice[i].input > c && s.is_valid_edge(tfa, slice[i].match)) { + return slice[i].input; + } + } + return 0; + } + + uint32_t smallest_explicit_out_edge(S s) const noexcept { + const auto *slice = lookup[s.index].list.data(); + size_t i = window_size<N>(); + while (i-- > 0) { + if (slice[i].input != 0 && s.is_valid_edge(tfa, slice[i].match)) { + return slice[i].input; + } + } + return 0; + } + + static constexpr bool valid_edge(uint32_t c) noexcept { return c != 0; } + static constexpr uint32_t edge_to_u32char(uint32_t c) noexcept { return c; } + S edge_to_state(S s, uint32_t c) const noexcept { return match_input(s, c); } + + static constexpr bool implies_exact_match_suffix(S) noexcept { return false; } + static constexpr void emit_exact_match_suffix(S, std::vector<uint32_t> &) {} // not called + static constexpr void emit_exact_match_suffix(S, std::string &) {} // not called +}; + +} // unnamed + +template <uint8_t N> +auto +TableDfa<N>::make_lookup(const std::vector<uint32_t> &str)->std::vector<Lookup> +{ + std::vector<Lookup> result(str.size() + 1); + auto have_already = [&](uint32_t c, size_t i)noexcept{ + for (size_t j = 0; j < window_size(); ++j) { + if (result[i].list[j].input == c) { + return true; + } + } + return false; + }; + auto make_vector = [&](uint32_t c, size_t i)noexcept{ + uint32_t bits = 0; + for (size_t j = 0; j < window_size(); ++j) { + bool found = ((i + j) < str.size()) && (str[i+j] == c); + bits = (bits << 1) + found; + } + return bits; + }; + for (size_t i = 0; i < str.size(); ++i) { + for (size_t j = 0; j < window_size(); ++j) { + assert(result[i].list[j].input == 0); + assert(result[i].list[j].match == 0); + if ((i + j) < str.size()) { + uint32_t c = str[i + j]; + if (!have_already(c, i)) { + result[i].list[j].input = c; + result[i].list[j].match = make_vector(c, i); + } + } + } + std::sort(result[i].list.begin(), result[i].list.end(), + [](const auto &a, const auto &b){ return a.input > b.input; }); + } + return result; +} + +template <uint8_t N> +TableDfa<N>::TableDfa(std::vector<uint32_t> str, bool is_cased) + : _tfa(get_tfa<N>()), + _lookup(make_lookup(str)), + _is_cased(is_cased) +{ +} + +template <uint8_t N> +TableDfa<N>::~TableDfa() = default; + +template <uint8_t N> +LevenshteinDfa::MatchResult +TableDfa<N>::match(std::string_view u8str) const +{ + TableMatcher matcher(static_cast<const Tfa<N>*>(_tfa), _lookup.data(), _lookup.size() - 1, _is_cased); + return MatchAlgorithm<N>::match(matcher, u8str); +} + +template <uint8_t N> +LevenshteinDfa::MatchResult +TableDfa<N>::match(std::string_view u8str, std::string& successor_out) const +{ + TableMatcher matcher(static_cast<const Tfa<N>*>(_tfa), _lookup.data(), _lookup.size() - 1, _is_cased); + return MatchAlgorithm<N>::match(matcher, u8str, successor_out); +} + +template <uint8_t N> +LevenshteinDfa::MatchResult +TableDfa<N>::match(std::string_view u8str, std::vector<uint32_t>& successor_out) const +{ + TableMatcher matcher(static_cast<const Tfa<N>*>(_tfa), _lookup.data(), _lookup.size() - 1, _is_cased); + return MatchAlgorithm<N>::match(matcher, u8str, successor_out); +} + +template <uint8_t N> +size_t +TableDfa<N>::memory_usage() const noexcept +{ + return _lookup.size() * sizeof(Lookup); +} + +template <uint8_t N> +void +TableDfa<N>::dump_as_graphviz(std::ostream &os) const +{ + os << std::dec << "digraph table_dfa {\n"; + for (size_t i = 0; i < _lookup.size(); ++i) { + for (size_t j = 0; j < window_size(); ++j) { + if (_lookup[i].list[j].input != 0) { + std::string as_utf8; + append_utf32_char(as_utf8, _lookup[i].list[j].input); + os << " x" << i << " -> " << _lookup[i].list[j].match << " [label=\"" << as_utf8 << "\"];\n"; + } + } + os << " x" << i << " -> 0 [label=\"*\"];\n"; + } + os << "}\n"; +} + +} diff --git a/vespalib/src/vespa/vespalib/stllike/hashtable.h b/vespalib/src/vespa/vespalib/stllike/hashtable.h index e290d2f626c..fa88bb038b4 100644 --- a/vespalib/src/vespa/vespalib/stllike/hashtable.h +++ b/vespalib/src/vespa/vespalib/stllike/hashtable.h @@ -62,7 +62,7 @@ public: class prime_modulator { public: - prime_modulator(next_t sizeOfHashTable) noexcept : _modulo(sizeOfHashTable) { } + explicit prime_modulator(next_t sizeOfHashTable) noexcept : _modulo(sizeOfHashTable) { } next_t modulo(next_t hash) const noexcept { return hash % _modulo; } next_t getTableSize() const noexcept { return _modulo; } static next_t selectHashTableSize(size_t sz) { return hashtable_base::getModuloStl(sz); } @@ -76,7 +76,7 @@ public: class and_modulator { public: - and_modulator(next_t sizeOfHashTable) noexcept : _mask(sizeOfHashTable-1) { } + explicit and_modulator(next_t sizeOfHashTable) noexcept : _mask(sizeOfHashTable-1) { } next_t modulo(next_t hash) const noexcept { return hash & _mask; } next_t getTableSize() const noexcept { return _mask + 1; } static next_t selectHashTableSize(size_t sz) noexcept { return hashtable_base::getModuloSimple(sz); } @@ -198,7 +198,7 @@ public: using pointer = Value*; using iterator_category = std::forward_iterator_tag; - constexpr iterator(hashtable * hash) noexcept : _current(0), _hashTable(hash) { + constexpr explicit iterator(hashtable * hash) noexcept : _current(0), _hashTable(hash) { if (! _hashTable->_nodes[_current].valid()) { advanceToNextValidHash(); } @@ -242,7 +242,7 @@ public: using pointer = const Value*; using iterator_category = std::forward_iterator_tag; - constexpr const_iterator(const hashtable * hash) noexcept : _current(0), _hashTable(hash) { + constexpr explicit const_iterator(const hashtable * hash) noexcept : _current(0), _hashTable(hash) { if (! _hashTable->_nodes[_current].valid()) { advanceToNextValidHash(); } @@ -282,7 +282,7 @@ public: hashtable & operator = (hashtable &&) noexcept = default; hashtable(const hashtable &); hashtable & operator = (const hashtable &); - hashtable(size_t reservedSpace); + explicit hashtable(size_t reservedSpace); hashtable(size_t reservedSpace, const Hash & hasher, const Equal & equal); virtual ~hashtable(); constexpr iterator begin() noexcept { return iterator(this); } diff --git a/vespalib/src/vespa/vespalib/stllike/hashtable.hpp b/vespalib/src/vespa/vespalib/stllike/hashtable.hpp index 6d2d397a887..040e421f68c 100644 --- a/vespalib/src/vespa/vespalib/stllike/hashtable.hpp +++ b/vespalib/src/vespa/vespalib/stllike/hashtable.hpp @@ -53,8 +53,7 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::hashtable(size_t rese _nodes(createStore<NodeStore>(reservedSpace, _modulator.getTableSize())), _hasher(hasher), _equal(equal) -{ -} +{ } template< typename Key, typename Value, typename Hash, typename Equal, typename KeyExtract, typename Modulator > hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::hashtable(const hashtable &) = default; @@ -130,7 +129,7 @@ hashtable<Key, Value, Hash, Equal, KeyExtract, Modulator>::insert_internal(V && _count++; return insert_result(iterator(this, h), true); } - return insert_internal_cold(std::move(node), h); + return insert_internal_cold(std::forward<V>(node), h); } template< typename Key, typename Value, typename Hash, typename Equal, typename KeyExtract, typename Modulator > diff --git a/vespalib/src/vespa/vespalib/text/utf8.h b/vespalib/src/vespa/vespalib/text/utf8.h index 99e3f8cfe13..489b16b1ed4 100644 --- a/vespalib/src/vespa/vespalib/text/utf8.h +++ b/vespalib/src/vespa/vespalib/text/utf8.h @@ -321,6 +321,7 @@ public: return i; } + const char* get_current_ptr() const noexcept { return _p; } }; diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java index 80646dc5607..1bed85b1c02 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java @@ -166,19 +166,19 @@ public class Curator extends AbstractComponent implements AutoCloseable { private void addLoggingListener() { curatorFramework.getConnectionStateListenable().addListener((curatorFramework, connectionState) -> { switch (connectionState) { - case SUSPENDED: LOG.info("ZK connection state change: SUSPENDED"); break; - case RECONNECTED: LOG.info("ZK connection state change: RECONNECTED"); break; - case LOST: LOG.warning("ZK connection state change: LOST"); break; + case SUSPENDED -> LOG.info("ZK connection state change: SUSPENDED"); + case RECONNECTED -> LOG.info("ZK connection state change: RECONNECTED"); + case LOST -> LOG.warning("ZK connection state change: LOST"); } }); } - public CompletionWaiter getCompletionWaiter(Path waiterPath, String id, Duration waitForAll) { - return CuratorCompletionWaiter.create(this, waiterPath, id, waitForAll); + public CompletionWaiter getCompletionWaiter(Path barrierPath, String id, Duration waitForAll) { + return CuratorCompletionWaiter.create(this, barrierPath, id, waitForAll); } - public CompletionWaiter createCompletionWaiter(Path waiterPath, String id, Duration waitForAll) { - return CuratorCompletionWaiter.createAndInitialize(this, waiterPath, id, waitForAll); + public CompletionWaiter createCompletionWaiter(Path barrierPath, String id, Duration waitForAll) { + return CuratorCompletionWaiter.createAndInitialize(this, barrierPath, id, waitForAll); } /** Creates a listenable cache which keeps in sync with changes to all the immediate children of a path */ diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/CuratorCompletionWaiter.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/CuratorCompletionWaiter.java index 7d918baaf54..9a8b9b5bf60 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/CuratorCompletionWaiter.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/CuratorCompletionWaiter.java @@ -2,6 +2,8 @@ package com.yahoo.vespa.curator; import com.yahoo.path.Path; +import com.yahoo.vespa.curator.transaction.CuratorOperations; +import com.yahoo.vespa.curator.transaction.CuratorTransaction; import java.time.Clock; import java.time.Duration; @@ -120,11 +122,13 @@ class CuratorCompletionWaiter implements CompletionWaiter { return new CuratorCompletionWaiter(curator, barrierPath, id, Clock.systemUTC(), waitForAll); } - public static CompletionWaiter createAndInitialize(Curator curator, Path waiterPath, String id, Duration waitForAll) { - curator.delete(waiterPath); - curator.createAtomically(waiterPath); + public static CompletionWaiter createAndInitialize(Curator curator, Path barrierPath, String id, Duration waitForAll) { + // Note: Should be done atomically, but unable to that when path may not exist before delete + // and create should be able to create any missing parent paths + curator.delete(barrierPath); + curator.create(barrierPath); - return new CuratorCompletionWaiter(curator, waiterPath, id, Clock.systemUTC(), waitForAll); + return new CuratorCompletionWaiter(curator, barrierPath, id, Clock.systemUTC(), waitForAll); } private int barrierMemberCount() { diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCurator.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCurator.java index 592b9fc2a05..e1376fb154b 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCurator.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCurator.java @@ -82,12 +82,12 @@ public class MockCurator extends Curator { } @Override - public CompletionWaiter getCompletionWaiter(Path parentPath, String id, Duration waitForAll) { + public CompletionWaiter getCompletionWaiter(Path barrierPath, String id, Duration waitForAll) { return mockFramework().createCompletionWaiter(); } @Override - public CompletionWaiter createCompletionWaiter(Path waiterPath, String id, Duration waitForAll) { + public CompletionWaiter createCompletionWaiter(Path barrierPath, String id, Duration waitForAll) { return mockFramework().createCompletionWaiter(); } |