diff options
97 files changed, 2075 insertions, 569 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java b/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java index 2b5dadf5512..c791fea3a56 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java @@ -17,7 +17,7 @@ import java.net.URI; * @author mortent */ public class IdentityProvider extends SimpleComponent implements IdentityConfig.Producer { - public static final String CLASS = "com.yahoo.vespa.athenz.identityprovider.client.AthenzIdentityProviderImpl"; + public static final String CLASS = "com.yahoo.vespa.athenz.identityprovider.client.AthenzIdentityProviderProvider"; public static final String BUNDLE = "vespa-athenz"; private final AthenzDomain domain; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 36d34b99223..f74c218a906 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.container.xml; +import com.yahoo.component.ComponentId; import com.yahoo.config.provision.ClusterInfo; import com.yahoo.config.provision.IntRange; import com.yahoo.component.ComponentSpecification; @@ -1169,6 +1170,9 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { ztsUrl, zoneDnsSuffix, zone); + + // Replace AthenzIdentityProviderProvider + cluster.removeComponent(ComponentId.fromString("com.yahoo.container.jdisc.AthenzIdentityProviderProvider")); cluster.addComponent(identityProvider); cluster.getContainers().forEach(container -> { diff --git a/container-disc/abi-spec.json b/container-disc/abi-spec.json index dd681e4124f..75246a77e03 100644 --- a/container-disc/abi-spec.json +++ b/container-disc/abi-spec.json @@ -19,7 +19,8 @@ "public abstract java.util.List getIdentityCertificate()", "public abstract java.security.cert.X509Certificate getRoleCertificate(java.lang.String, java.lang.String)", "public abstract java.security.PrivateKey getPrivateKey()", - "public abstract java.nio.file.Path trustStorePath()" + "public abstract java.nio.file.Path trustStorePath()", + "public abstract void deconstruct()" ], "fields" : [ ] }, diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/AthenzIdentityProviderProvider.java b/container-disc/src/main/java/com/yahoo/container/jdisc/AthenzIdentityProviderProvider.java index f04e2291ee8..9d2e06ed9da 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/AthenzIdentityProviderProvider.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/AthenzIdentityProviderProvider.java @@ -89,6 +89,9 @@ public class AthenzIdentityProviderProvider implements Provider<AthenzIdentityPr public Path trustStorePath() { throw new UnsupportedOperationException(message); } + + @Override + public void deconstruct() {} } } diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java index af5133eceac..46803988b20 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java @@ -24,5 +24,5 @@ public interface AthenzIdentityProvider { X509Certificate getRoleCertificate(String domain, String role); PrivateKey getPrivateKey(); Path trustStorePath(); - + void deconstruct(); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java index 216da0d9045..c83ddbd8045 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java @@ -199,7 +199,7 @@ public class ControllerMaintenance extends AbstractComponent { public SuccessFactorBaseline(SystemName system) { Objects.requireNonNull(system); this.defaultSuccessFactorBaseline = 1.0; - this.deploymentMetricsMaintainerBaseline = 0.95; + this.deploymentMetricsMaintainerBaseline = 0.90; this.trafficFractionUpdater = system.isCd() ? 0.5 : 0.65; } } diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 0a5cd1aacac..d3eef8d24a9 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -48,13 +48,6 @@ public class Flags { private static volatile TreeMap<FlagId, FlagDefinition> flags = new TreeMap<>(); - public static final UnboundBooleanFlag RECONFIGURE_ALB_TARGETS = defineFeatureFlag( - "reconfigure-alb-targets", false, - List.of("bjormel"), "2023-03-24", "2023-04-30", - "Reconfigure ALB targets", - "Takes effect on next config server container start", - ZONE_ID); - public static final UnboundBooleanFlag DROP_CACHES = defineFeatureFlag( "drop-caches", false, List.of("hakonhall", "baldersheim"), "2023-03-06", "2023-06-05", @@ -390,7 +383,7 @@ public class Flags { List.of("olaa"), "2023-04-12", "2023-06-12", "Whether AthenzCredentialsMaintainer in node-admin should create tenant service identity certificate", "Takes effect on next tick", - ZONE_ID, HOSTNAME + ZONE_ID, HOSTNAME, VESPA_VERSION, APPLICATION_ID ); public static final UnboundBooleanFlag ENABLE_CROWDSTRIKE = defineFeatureFlag( @@ -410,6 +403,12 @@ public class Flags { "Takes effect at redeployment", ZONE_ID); + public static final UnboundBooleanFlag NEW_IDDOC_LAYOUT = defineFeatureFlag( + "new_iddoc_layout", false, List.of("tokle", "bjorncs", "olaa"), "2023-04-24", "2023-05-31", + "Whether to use new identity document lauoyt", + "Takes effect on node reboot", + HOSTNAME); + /** WARNING: public for testing: All flags should be defined in {@link Flags}. */ public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, List<String> owners, String createdAt, String expiresAt, String description, diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java index 3fb9c73367d..3ab1fdf211b 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java @@ -1,6 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.maintenance.identity; +import com.yahoo.component.Version; +import com.yahoo.config.provision.ApplicationId; import com.yahoo.security.KeyAlgorithm; import com.yahoo.security.KeyUtils; import com.yahoo.security.Pkcs10Csr; @@ -13,6 +15,7 @@ import com.yahoo.vespa.athenz.client.zts.ZtsClient; import com.yahoo.vespa.athenz.client.zts.ZtsClientException; import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; +import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocumentClient; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.client.CsrGenerator; @@ -76,6 +79,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { private final ServiceIdentityProvider hostIdentityProvider; private final IdentityDocumentClient identityDocumentClient; private final BooleanFlag tenantServiceIdentityFlag; + private final BooleanFlag useNewIdentityDocumentLayout; // Used as an optimization to ensure ZTS is not DDoS'ed on continuously failing refresh attempts private final Map<ContainerName, Instant> lastRefreshAttempt = new ConcurrentHashMap<>(); @@ -97,13 +101,20 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { new AthenzIdentityVerifier(Set.of(configServerInfo.getConfigServerIdentity()))); this.clock = clock; this.tenantServiceIdentityFlag = Flags.NODE_ADMIN_TENANT_SERVICE_REGISTRY.bindTo(flagSource); + this.useNewIdentityDocumentLayout = Flags.NEW_IDDOC_LAYOUT.bindTo(flagSource); } public boolean converge(NodeAgentContext context) { var modified = false; modified |= maintain(context, NODE); + + if (context.zone().getSystemName().isPublic()) + return modified; + if (shouldWriteTenantServiceIdentity(context)) modified |= maintain(context, TENANT); + else + modified |= deleteTenantCredentials(context); return modified; } @@ -114,7 +125,10 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { context.log(logger, Level.FINE, "Checking certificate"); ContainerPath siaDirectory = context.paths().of(CONTAINER_SIA_DIRECTORY, context.users().vespa()); ContainerPath identityDocumentFile = siaDirectory.resolve(identityType.getIdentityDocument()); - AthenzIdentity athenzIdentity = getAthenzIdentity(context, identityType, identityDocumentFile); + Optional<AthenzIdentity> optionalAthenzIdentity = getAthenzIdentity(context, identityType, identityDocumentFile); + if (optionalAthenzIdentity.isEmpty()) + return false; + AthenzIdentity athenzIdentity = optionalAthenzIdentity.get(); ContainerPath privateKeyFile = (ContainerPath) SiaUtils.getPrivateKeyFile(siaDirectory, athenzIdentity); ContainerPath certificateFile = (ContainerPath) SiaUtils.getCertificateFile(siaDirectory, athenzIdentity); if (!Files.exists(privateKeyFile) || !Files.exists(certificateFile) || !Files.exists(identityDocumentFile)) { @@ -130,7 +144,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { Instant now = clock.instant(); Instant expiry = certificate.getNotAfter().toInstant(); var doc = EntityBindingsMapper.readSignedIdentityDocumentFromFile(identityDocumentFile); - if (doc.outdated()) { + if (refreshIdentityDocument(doc, context)) { context.log(logger, "Identity document is outdated (version=%d)", doc.documentVersion()); registerIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, identityType, athenzIdentity); return true; @@ -150,7 +164,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { return false; } else { lastRefreshAttempt.put(context.containerName(), now); - refreshIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, doc, identityType, athenzIdentity); + refreshIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, doc.identityDocument(), identityType, athenzIdentity); return true; } } @@ -161,6 +175,11 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { } } + private boolean refreshIdentityDocument(SignedIdentityDocument signedIdentityDocument, NodeAgentContext context) { + int expectedVersion = documentVersion(context); + return signedIdentityDocument.outdated() || signedIdentityDocument.documentVersion() != expectedVersion; + } + public void clearCredentials(NodeAgentContext context) { FileFinder.files(context.paths().of(CONTAINER_SIA_DIRECTORY)) .deleteRecursively(context); @@ -187,6 +206,23 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { return "node-certificate"; } + private boolean deleteTenantCredentials(NodeAgentContext context) { + var siaDirectory = context.paths().of(CONTAINER_SIA_DIRECTORY, context.users().vespa()); + var identityDocumentFile = siaDirectory.resolve(TENANT.getIdentityDocument()); + if (!Files.exists(identityDocumentFile)) return false; + return getAthenzIdentity(context, TENANT, identityDocumentFile).map(athenzIdentity -> { + var privateKeyFile = (ContainerPath) SiaUtils.getPrivateKeyFile(siaDirectory, athenzIdentity); + var certificateFile = (ContainerPath) SiaUtils.getCertificateFile(siaDirectory, athenzIdentity); + try { + return Files.deleteIfExists(identityDocumentFile) || + Files.deleteIfExists(privateKeyFile) || + Files.deleteIfExists(certificateFile); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }).orElse(false); + } + private boolean shouldRefreshCredentials(Duration age) { return age.compareTo(REFRESH_PERIOD) >= 0; } @@ -200,7 +236,8 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { private void registerIdentity(NodeAgentContext context, ContainerPath privateKeyFile, ContainerPath certificateFile, ContainerPath identityDocumentFile, IdentityType identityType, AthenzIdentity identity) { KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); - SignedIdentityDocument doc = signedIdentityDocument(context, identityType); + SignedIdentityDocument signedDoc = signedIdentityDocument(context, identityType); + IdentityDocument doc = signedDoc.identityDocument(); CsrGenerator csrGenerator = new CsrGenerator(certificateDnsSuffix, doc.providerService().getFullName()); Pkcs10Csr csr = csrGenerator.generateInstanceCsr( identity, doc.providerUniqueId(), doc.ipAddresses(), doc.clusterType(), keyPair); @@ -212,9 +249,9 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { ztsClient.registerInstance( doc.providerService(), identity, - EntityBindingsMapper.toAttestationData(doc), + EntityBindingsMapper.toAttestationData(signedDoc), csr); - EntityBindingsMapper.writeSignedIdentityDocumentToFile(identityDocumentFile, doc); + EntityBindingsMapper.writeSignedIdentityDocumentToFile(identityDocumentFile, signedDoc); writePrivateKeyAndCertificate(privateKeyFile, keyPair.getPrivate(), certificateFile, instanceIdentity.certificate()); context.log(logger, "Instance successfully registered and credentials written to file"); } @@ -223,14 +260,14 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { /** * Return zts url from identity document, fallback to ztsEndpoint */ - private URI ztsEndpoint(SignedIdentityDocument doc) { + private URI ztsEndpoint(IdentityDocument doc) { return Optional.ofNullable(doc.ztsUrl()) .filter(s -> !s.isBlank()) .map(URI::create) .orElse(ztsEndpoint); } private void refreshIdentity(NodeAgentContext context, ContainerPath privateKeyFile, ContainerPath certificateFile, - ContainerPath identityDocumentFile, SignedIdentityDocument doc, IdentityType identityType, AthenzIdentity identity) { + ContainerPath identityDocumentFile, IdentityDocument doc, IdentityType identityType, AthenzIdentity identity) { KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); CsrGenerator csrGenerator = new CsrGenerator(certificateDnsSuffix, doc.providerService().getFullName()); Pkcs10Csr csr = csrGenerator.generateInstanceCsr( @@ -291,32 +328,48 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer { private SignedIdentityDocument signedIdentityDocument(NodeAgentContext context, IdentityType identityType) { return switch (identityType) { - case NODE -> identityDocumentClient.getNodeIdentityDocument(context.hostname().value()); - case TENANT -> identityDocumentClient.getTenantIdentityDocument(context.hostname().value()); + case NODE -> identityDocumentClient.getNodeIdentityDocument(context.hostname().value(), documentVersion(context)); + case TENANT -> identityDocumentClient.getTenantIdentityDocument(context.hostname().value(), documentVersion(context)).get(); }; } - private AthenzIdentity getAthenzIdentity(NodeAgentContext context, IdentityType identityType, ContainerPath identityDocumentFile) { + private Optional<AthenzIdentity> getAthenzIdentity(NodeAgentContext context, IdentityType identityType, ContainerPath identityDocumentFile) { return switch (identityType) { - case NODE -> context.identity(); + case NODE -> Optional.of(context.identity()); case TENANT -> getTenantIdentity(context, identityDocumentFile); }; } - private AthenzIdentity getTenantIdentity(NodeAgentContext context, ContainerPath identityDocumentFile) { + private Optional<AthenzIdentity> getTenantIdentity(NodeAgentContext context, ContainerPath identityDocumentFile) { if (Files.exists(identityDocumentFile)) { - return EntityBindingsMapper.readSignedIdentityDocumentFromFile(identityDocumentFile).serviceIdentity(); + return Optional.of(EntityBindingsMapper.readSignedIdentityDocumentFromFile(identityDocumentFile).identityDocument().serviceIdentity()); } else { - return identityDocumentClient.getTenantIdentityDocument(context.hostname().value()).serviceIdentity(); + return identityDocumentClient.getTenantIdentityDocument(context.hostname().value(), documentVersion(context)) + .map(doc -> doc.identityDocument().serviceIdentity()); } } private boolean shouldWriteTenantServiceIdentity(NodeAgentContext context) { + var version = context.node().currentVespaVersion() + .orElse(context.node().wantedVespaVersion().orElse(Version.emptyVersion)); + var appId = context.node().owner().orElse(ApplicationId.defaultId()); return tenantServiceIdentityFlag - .with(FetchVector.Dimension.HOSTNAME, context.hostname().value()) + .with(FetchVector.Dimension.VESPA_VERSION, version.toFullString()) + .with(FetchVector.Dimension.APPLICATION_ID, appId.serializedForm()) .value(); } + /* + Get the document version to ask for + */ + private int documentVersion(NodeAgentContext context) { + return useNewIdentityDocumentLayout + .with(FetchVector.Dimension.HOSTNAME, context.hostname().value()) + .value() + ? SignedIdentityDocument.DEFAULT_DOCUMENT_VERSION + : SignedIdentityDocument.LEGACY_DEFAULT_DOCUMENT_VERSION; + } + enum IdentityType { NODE("vespa-node-identity-document.json"), TENANT("vespa-tenant-identity-document.json"); diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java index 7c84afc8397..025a04a15d6 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java @@ -506,14 +506,6 @@ public class NodeAgentImpl implements NodeAgent { storageMaintainer.cleanDiskIfFull(context); storageMaintainer.handleCoreDumpsForContainer(context, container, false); - // TODO: this is a workaround for restarting wireguard as early as possible after host-admin has been down. - var runOrdinaryWireguardTasks = true; - if (container.isPresent() && container.get().state().isRunning()) { - Optional<Container> finalContainer = container; - wireguardTasks.forEach(task -> task.converge(context, finalContainer.get().id())); - runOrdinaryWireguardTasks = false; - } - if (downloadImageIfNeeded(context, container)) { context.log(logger, "Waiting for image to download " + context.node().wantedDockerImage().get().asString()); return; @@ -525,16 +517,13 @@ public class NodeAgentImpl implements NodeAgent { containerState = STARTING; container = Optional.of(startContainer(context)); containerState = UNKNOWN; - runOrdinaryWireguardTasks = true; } else { container = Optional.of(updateContainerIfNeeded(context, container.get())); } aclMaintainer.ifPresent(maintainer -> maintainer.converge(context)); - if (runOrdinaryWireguardTasks) { - Optional<Container> finalContainer = container; - wireguardTasks.forEach(task -> task.converge(context, finalContainer.get().id())); - } + final Optional<Container> finalContainer = container; + wireguardTasks.forEach(task -> task.converge(context, finalContainer.get().id())); startServicesIfNeeded(context); resumeNodeIfNeeded(context); if (healthChecker.isPresent()) { diff --git a/searchcore/src/tests/proton/server/shared_threading_service/shared_threading_service_test.cpp b/searchcore/src/tests/proton/server/shared_threading_service/shared_threading_service_test.cpp index 2027ad56768..fe7303692ba 100644 --- a/searchcore/src/tests/proton/server/shared_threading_service/shared_threading_service_test.cpp +++ b/searchcore/src/tests/proton/server/shared_threading_service/shared_threading_service_test.cpp @@ -20,7 +20,6 @@ ProtonConfig make_proton_config(double concurrency, uint32_t indexing_threads = 1) { ProtonConfigBuilder builder; - // This setup requires a minimum of 4 shared threads. builder.documentdb.push_back(ProtonConfig::Documentdb()); builder.documentdb.push_back(ProtonConfig::Documentdb()); builder.flush.maxconcurrent = 1; @@ -48,8 +47,10 @@ expect_field_writer_threads(uint32_t exp_threads, uint32_t cpu_cores, uint32_t i TEST(SharedThreadingServiceConfigTest, shared_threads_are_derived_from_cpu_cores_and_feeding_concurrency) { - expect_shared_threads(4, 1); - expect_shared_threads(4, 6); + expect_shared_threads(2, 1); + expect_shared_threads(2, 4); + expect_shared_threads(3, 5); + expect_shared_threads(3, 6); expect_shared_threads(4, 8); expect_shared_threads(5, 9); expect_shared_threads(5, 10); diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.cpp b/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.cpp index 82e1aa3b57c..430412fc1c0 100644 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.cpp +++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.cpp @@ -137,8 +137,15 @@ SearchContext::getStore() const SearchContext::SearchContext(QueryTermSimple::UP qTerm, const DocumentMetaStore &toBeSearched) : search::attribute::SearchContext(toBeSearched), - _isWord(qTerm->isWord()) + _isWord(qTerm->isWord()), + _docid_limit(toBeSearched.getCommittedDocIdLimit()) { } +uint32_t +SearchContext::get_committed_docid_limit() const noexcept +{ + return _docid_limit; +} + } diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.h b/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.h index ca4b026e2a4..7c88d8f3502 100644 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.h +++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.h @@ -18,6 +18,7 @@ private: bool _isWord; document::GlobalId _gid; + uint32_t _docid_limit; unsigned int approximateHits() const override; int32_t onFind(DocId docId, int32_t elemId, int32_t &weight) const override; @@ -30,6 +31,7 @@ private: public: SearchContext(std::unique_ptr<search::QueryTermSimple> qTerm, const DocumentMetaStore &toBeSearched); + uint32_t get_committed_docid_limit() const noexcept override; }; } diff --git a/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp index 50ef5039e75..c1802b40deb 100644 --- a/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp @@ -28,10 +28,10 @@ namespace { uint32_t derive_shared_threads(const ProtonConfig& cfg, const HwInfo::Cpu& cpu_info) { - uint32_t scaled_cores = (uint32_t)std::ceil(cpu_info.cores() * cfg.feeding.concurrency); + uint32_t scaled_cores = uint32_t(std::ceil(cpu_info.cores() * cfg.feeding.concurrency)); // We need at least 1 guaranteed free worker in order to ensure progress. - return std::max(scaled_cores, (uint32_t)cfg.documentdb.size() + cfg.flush.maxconcurrent + 1); + return std::max(scaled_cores, uint32_t(cfg.flush.maxconcurrent + 1u)); } uint32_t @@ -42,8 +42,8 @@ derive_warmup_threads(const HwInfo::Cpu& cpu_info) { uint32_t derive_field_writer_threads(const ProtonConfig& cfg, const HwInfo::Cpu& cpu_info) { - uint32_t scaled_cores = (size_t)std::ceil(cpu_info.cores() * cfg.feeding.concurrency); - uint32_t field_writer_threads = std::max(scaled_cores, (uint32_t)cfg.indexing.threads); + uint32_t scaled_cores = size_t(std::ceil(cpu_info.cores() * cfg.feeding.concurrency)); + uint32_t field_writer_threads = std::max(scaled_cores, uint32_t(cfg.indexing.threads)); // Originally we used at least 3 threads for writing fields: // - index field inverter // - index field writer diff --git a/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp b/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp index 9e65dfcfc07..e55344aded0 100644 --- a/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp +++ b/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp @@ -240,6 +240,24 @@ TEST_F("original lid range is used by read guard", Fixture) EXPECT_EQUAL(getUndefined<int>(), first_guard->getInt(DocId(10))); } +TEST_F("Original target lid range is used by read guard", Fixture) +{ + reset_with_single_value_reference_mappings<IntegerAttribute, int32_t>( + f, BasicType::INT32, + {}); + EXPECT_EQUAL(11u, f.target_attr->getNumDocs()); + auto first_guard = f.get_imported_attr(); + add_n_docs_with_undefined_values(*f.target_attr, 1); + EXPECT_EQUAL(12u, f.target_attr->getNumDocs()); + auto typed_target_attr = f.template target_attr_as<IntegerAttribute>(); + ASSERT_TRUE(typed_target_attr->update(11, 2345)); + f.target_attr->commit(); + f.map_reference(DocId(8), dummy_gid(11), DocId(11)); + auto second_guard = f.get_imported_attr(); + EXPECT_EQUAL(2345, second_guard->getInt(DocId(8))); + EXPECT_NOT_EQUAL(2345, first_guard->getInt(DocId(8))); +} + struct SingleStringAttrFixture : Fixture { SingleStringAttrFixture() : Fixture() { setup(); diff --git a/searchlib/src/tests/attribute/imported_search_context/imported_search_context_test.cpp b/searchlib/src/tests/attribute/imported_search_context/imported_search_context_test.cpp index 847a992d241..19327245083 100644 --- a/searchlib/src/tests/attribute/imported_search_context/imported_search_context_test.cpp +++ b/searchlib/src/tests/attribute/imported_search_context/imported_search_context_test.cpp @@ -429,6 +429,21 @@ TEST_F("original lid range is used by search context", SingleValueFixture) EXPECT_TRUE(second_ctx->matches(DocId(10))); } +TEST_F("Original target lid range is used by search context", SingleValueFixture) +{ + EXPECT_EQUAL(11u, f.target_attr->getNumDocs()); + auto first_ctx = f.create_context(word_term("2345")); + add_n_docs_with_undefined_values(*f.target_attr, 1); + EXPECT_EQUAL(12u, f.target_attr->getNumDocs()); + auto typed_target_attr = f.template target_attr_as<IntegerAttribute>(); + ASSERT_TRUE(typed_target_attr->update(11, 2345)); + f.target_attr->commit(); + f.map_reference(DocId(8), dummy_gid(11), DocId(11)); + auto second_ctx = f.create_context(word_term("2345")); + EXPECT_FALSE(first_ctx->matches(DocId(8))); + EXPECT_TRUE(second_ctx->matches(DocId(8))); +} + // Note: this uses an underlying string attribute, as queryTerm() does not seem to // implemented at all for (single) numeric attributes. Intentional? TEST_F("queryTerm() returns term context was created with", WsetValueFixture) { diff --git a/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp b/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp index 2804e3f74e4..5ba90d2b077 100644 --- a/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp +++ b/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp @@ -390,7 +390,7 @@ TEST("testSingleValue") { EXPECT_EQUAL(24u, sizeof(SearchContext)); EXPECT_EQUAL(32u, sizeof(StringSearchHelper)); - EXPECT_EQUAL(80u, sizeof(attribute::SingleStringEnumSearchContext)); + EXPECT_EQUAL(88u, sizeof(attribute::SingleStringEnumSearchContext)); { Config cfg(BasicType::STRING, CollectionType::SINGLE); SingleValueStringAttribute svsa("svsa", cfg); diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index 2c202d9131b..210f32af15e 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -814,7 +814,7 @@ TEST("test_nearest_neighbor_query_node") constexpr uint32_t target_num_hits = 100; constexpr bool allow_approximate = false; constexpr uint32_t explore_additional_hits = 800; - constexpr double raw_score = 0.5; + constexpr double distance = 0.5; builder.add_nearest_neighbor_term("qtensor", "field", id, Weight(weight), target_num_hits, allow_approximate, explore_additional_hits, distance_threshold); auto build_node = builder.build(); auto stack_dump = StackDumpCreator::create(*build_node); @@ -830,14 +830,14 @@ TEST("test_nearest_neighbor_query_node") EXPECT_EQUAL(id, static_cast<int32_t>(node->uniqueId())); EXPECT_EQUAL(weight, node->weight().percent()); EXPECT_EQUAL(distance_threshold, node->get_distance_threshold()); - EXPECT_FALSE(node->get_raw_score().has_value()); + EXPECT_FALSE(node->get_distance().has_value()); EXPECT_FALSE(node->evaluate()); - node->set_raw_score(raw_score); - EXPECT_TRUE(node->get_raw_score().has_value()); - EXPECT_EQUAL(raw_score, node->get_raw_score().value()); + node->set_distance(distance); + EXPECT_TRUE(node->get_distance().has_value()); + EXPECT_EQUAL(distance, node->get_distance().value()); EXPECT_TRUE(node->evaluate()); node->reset(); - EXPECT_FALSE(node->get_raw_score().has_value()); + EXPECT_FALSE(node->get_distance().has_value()); EXPECT_FALSE(node->evaluate()); } diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index ae283f3f2b2..9b8ad0d26ce 100644 --- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -18,14 +18,17 @@ using search::attribute::DistanceMetric; template <typename T> TypedCells t(const std::vector<T> &v) { return TypedCells(v); } -void verify_geo_miles(const DistanceFunction *dist_fun, - const std::vector<double> &p1, +void verify_geo_miles(const std::vector<double> &p1, const std::vector<double> &p2, double exp_miles) { + static GeoDistanceFunctionFactory dff; TypedCells t1(p1); TypedCells t2(p2); - double abstract_distance = dist_fun->calc(t1, t2); + auto dist_fun = dff.for_query_vector(t1); + double abstract_distance = dist_fun->calc(t2); + EXPECT_EQ(dff.for_insertion_vector(t1)->calc(t2), abstract_distance); + EXPECT_FLOAT_EQ(dff.for_query_vector(t2)->calc(t1), abstract_distance); double raw_score = dist_fun->to_rawscore(abstract_distance); double km = ((1.0/raw_score)-1.0); double d_miles = km / 1.609344; @@ -69,6 +72,8 @@ double computeEuclideanChecked(TypedCells a, TypedCells b) { return result; } +namespace { const double sq_root_half = std::sqrt(0.5); } + TEST(DistanceFunctionsTest, euclidean_gives_expected_score) { auto ct = vespalib::eval::CellType::DOUBLE; @@ -79,7 +84,7 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score) std::vector<double> p1{1.0, 0.0, 0.0}; std::vector<double> p2{0.0, 1.0, 0.0}; std::vector<double> p3{0.0, 0.0, 1.0}; - std::vector<double> p4{0.5, 0.5, 0.707107}; + std::vector<double> p4{0.5, 0.5, sq_root_half}; std::vector<double> p5{0.0,-1.0, 0.0}; std::vector<double> p6{1.0, 2.0, 2.0}; @@ -179,7 +184,7 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score) std::vector<double> p1{1.0, 0.0, 0.0}; std::vector<double> p2{0.0, 1.0, 0.0}; std::vector<double> p3{0.0, 0.0, 1.0}; - std::vector<double> p4{0.5, 0.5, 0.707107}; + std::vector<double> p4{0.5, 0.5, sq_root_half}; std::vector<double> p5{0.0,-1.0, 0.0}; std::vector<double> p6{1.0, 2.0, 2.0}; @@ -207,7 +212,7 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score) EXPECT_DOUBLE_EQ(threshold, 0.5); double a34 = computeAngularChecked(t(p3), t(p4)); - EXPECT_FLOAT_EQ(a34, (1.0 - 0.707107)); + EXPECT_FLOAT_EQ(a34, (1.0 - sq_root_half)); EXPECT_FLOAT_EQ(angular->to_rawscore(a34), 1.0/(1.0 + pi/4)); threshold = angular->convert_threshold(pi/4); EXPECT_FLOAT_EQ(threshold, a34); @@ -257,6 +262,89 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score) EXPECT_DOUBLE_EQ(a66, computeAngularChecked(t(iv6), t(iv6))); } +double computePrenormalizedAngularChecked(TypedCells a, TypedCells b) { + static PrenormalizedAngularDistanceFunctionFactory<float> flt_dff; + static PrenormalizedAngularDistanceFunctionFactory<double> dbl_dff; + auto d_n = dbl_dff.for_query_vector(a); + auto d_f = flt_dff.for_query_vector(a); + auto d_r = dbl_dff.for_query_vector(b); + auto d_i = dbl_dff.for_insertion_vector(a); + // normal: + double result = d_n->calc(b); + // insert is exactly same: + EXPECT_EQ(d_i->calc(b), result); + // note: for this distance, reverse is not necessarily equal, + // since we normalize based on length of LHS only + EXPECT_FLOAT_EQ(d_r->calc(a), result); + // float factory: + EXPECT_FLOAT_EQ(d_f->calc(b), result); + double closeness_n = d_n->to_rawscore(result); + double closeness_f = d_f->to_rawscore(result); + double closeness_r = d_r->to_rawscore(result); + double closeness_i = d_i->to_rawscore(result); + EXPECT_DOUBLE_EQ(closeness_n, closeness_f); + EXPECT_DOUBLE_EQ(closeness_n, closeness_r); + EXPECT_DOUBLE_EQ(closeness_n, closeness_i); + EXPECT_GT(closeness_n, 0.0); + EXPECT_LE(closeness_n, 1.0); + return result; +} + +TEST(DistanceFunctionsTest, prenormalized_angular_gives_expected_score) +{ + std::vector<double> p0{0.0, 0.0, 0.0}; + std::vector<double> p1{1.0, 0.0, 0.0}; + std::vector<double> p2{0.0, 1.0, 0.0}; + std::vector<double> p3{0.0, 0.0, 1.0}; + std::vector<double> p4{0.5, 0.5, sq_root_half}; + std::vector<double> p5{0.0,-1.0, 0.0}; + std::vector<double> p6{1.0, 2.0, 2.0}; + std::vector<double> p7{2.0, -1.0, -2.0}; + std::vector<double> p8{3.0, 0.0, 0.0}; + + PrenormalizedAngularDistanceFunctionFactory<double> dff; + auto pnad = dff.for_query_vector(t(p0)); + + double i12 = computePrenormalizedAngularChecked(t(p1), t(p2)); + double i13 = computePrenormalizedAngularChecked(t(p1), t(p3)); + double i23 = computePrenormalizedAngularChecked(t(p2), t(p3)); + EXPECT_DOUBLE_EQ(i12, 1.0); + EXPECT_DOUBLE_EQ(i13, 1.0); + EXPECT_DOUBLE_EQ(i23, 1.0); + + double i14 = computePrenormalizedAngularChecked(t(p1), t(p4)); + double i24 = computePrenormalizedAngularChecked(t(p2), t(p4)); + EXPECT_DOUBLE_EQ(i14, 0.5); + EXPECT_DOUBLE_EQ(i24, 0.5); + double i34 = computePrenormalizedAngularChecked(t(p3), t(p4)); + EXPECT_FLOAT_EQ(i34, 1.0 - sq_root_half); + + double i25 = computePrenormalizedAngularChecked(t(p2), t(p5)); + EXPECT_DOUBLE_EQ(i25, 2.0); + + double i44 = computePrenormalizedAngularChecked(t(p4), t(p4)); + EXPECT_GE(i44, 0.0); + EXPECT_LT(i44, 0.000001); + + double i66 = computePrenormalizedAngularChecked(t(p6), t(p6)); + EXPECT_GE(i66, 0.0); + EXPECT_LT(i66, 0.000001); + + double i67 = computePrenormalizedAngularChecked(t(p6), t(p7)); + EXPECT_DOUBLE_EQ(i67, 13.0); + double i68 = computePrenormalizedAngularChecked(t(p6), t(p8)); + EXPECT_DOUBLE_EQ(i68, 6.0); + double i78 = computePrenormalizedAngularChecked(t(p7), t(p8)); + EXPECT_DOUBLE_EQ(i78, 3.0); + + double threshold = pnad->convert_threshold(0.25); + EXPECT_DOUBLE_EQ(threshold, 0.25); + threshold = pnad->convert_threshold(0.5); + EXPECT_DOUBLE_EQ(threshold, 0.5); + threshold = pnad->convert_threshold(1.0); + EXPECT_DOUBLE_EQ(threshold, 1.0); +} + TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) { auto ct = vespalib::eval::CellType::DOUBLE; @@ -267,7 +355,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) std::vector<double> p1{1.0, 0.0, 0.0}; std::vector<double> p2{0.0, 1.0, 0.0}; std::vector<double> p3{0.0, 0.0, 1.0}; - std::vector<double> p4{0.5, 0.5, 0.707107}; + std::vector<double> p4{0.5, 0.5, sq_root_half}; std::vector<double> p5{0.0,-1.0, 0.0}; std::vector<double> p6{1.0, 2.0, 2.0}; @@ -283,7 +371,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) EXPECT_DOUBLE_EQ(i14, 0.5); EXPECT_DOUBLE_EQ(i24, 0.5); double i34 = innerproduct->calc(t(p3), t(p4)); - EXPECT_FLOAT_EQ(i34, 1.0 - 0.707107); + EXPECT_FLOAT_EQ(i34, 1.0 - sq_root_half); double i25 = innerproduct->calc(t(p2), t(p5)); EXPECT_DOUBLE_EQ(i25, 2.0); @@ -292,6 +380,10 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) EXPECT_GE(i44, 0.0); EXPECT_LT(i44, 0.000001); + double i66 = innerproduct->calc(t(p6), t(p6)); + EXPECT_GE(i66, 0.0); + EXPECT_LT(i66, 0.000001); + double threshold = innerproduct->convert_threshold(0.25); EXPECT_DOUBLE_EQ(threshold, 0.25); threshold = innerproduct->convert_threshold(0.5); @@ -302,6 +394,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) TEST(DistanceFunctionsTest, hamming_gives_expected_score) { + static HammingDistanceFunctionFactory<Int8Float> dff; auto ct = vespalib::eval::CellType::DOUBLE; auto hamming = make_distance_function(DistanceMetric::Hamming, ct); @@ -318,6 +411,9 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) double h0 = hamming->calc(t(p), t(p)); EXPECT_EQ(h0, 0.0); EXPECT_EQ(hamming->to_rawscore(h0), 1.0); + auto dist_fun = dff.for_query_vector(t(p)); + EXPECT_EQ(dist_fun->calc(t(p)), 0.0); + EXPECT_EQ(dist_fun->to_rawscore(h0), 1.0); } double d12 = hamming->calc(t(points[1]), t(points[2])); EXPECT_EQ(d12, 3.0); @@ -350,13 +446,12 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) std::vector<Int8Float> bytes_b = { 1, 2, 2, 4, 8, 16, 32, 65, -128, 0, 1, 0, 4, 8, 16, 32, 64, -128, 0, 1, -1 }; // expect diff: 1 2 1 1 7 EXPECT_EQ(hamming->calc(TypedCells(bytes_a), TypedCells(bytes_b)), 12.0); + auto dist_fun = dff.for_query_vector(TypedCells(bytes_a)); + EXPECT_EQ(dist_fun->calc(TypedCells(bytes_b)), 12.0); } TEST(GeoDegreesTest, gives_expected_score) { - auto ct = vespalib::eval::CellType::DOUBLE; - auto geodeg = make_distance_function(DistanceMetric::GeoDegrees, ct); - std::vector<double> g1_sfo{37.61, -122.38}; std::vector<double> g2_lhr{51.47, -0.46}; std::vector<double> g3_osl{60.20, 11.08}; @@ -367,7 +462,8 @@ TEST(GeoDegreesTest, gives_expected_score) std::vector<double> g8_lax{33.94, -118.41}; std::vector<double> g9_jfk{40.64, -73.78}; - double g63_a = geodeg->calc(t(g6_trd), t(g3_osl)); + auto geodeg = GeoDistanceFunctionFactory().for_query_vector(t(g6_trd)); + double g63_a = geodeg->calc(t(g3_osl)); double g63_r = geodeg->to_rawscore(g63_a); double g63_km = ((1.0/g63_r)-1.0); EXPECT_GT(g63_km, 350); @@ -377,96 +473,95 @@ TEST(GeoDegreesTest, gives_expected_score) // Great Circle Mapper for airports using // a more accurate formula - we should agree // with < 1.0% deviation - verify_geo_miles(geodeg.get(), g1_sfo, g1_sfo, 0); - verify_geo_miles(geodeg.get(), g1_sfo, g2_lhr, 5367); - verify_geo_miles(geodeg.get(), g1_sfo, g3_osl, 5196); - verify_geo_miles(geodeg.get(), g1_sfo, g4_gig, 6604); - verify_geo_miles(geodeg.get(), g1_sfo, g5_hkg, 6927); - verify_geo_miles(geodeg.get(), g1_sfo, g6_trd, 5012); - verify_geo_miles(geodeg.get(), g1_sfo, g7_syd, 7417); - verify_geo_miles(geodeg.get(), g1_sfo, g8_lax, 337); - verify_geo_miles(geodeg.get(), g1_sfo, g9_jfk, 2586); - - verify_geo_miles(geodeg.get(), g2_lhr, g1_sfo, 5367); - verify_geo_miles(geodeg.get(), g2_lhr, g2_lhr, 0); - verify_geo_miles(geodeg.get(), g2_lhr, g3_osl, 750); - verify_geo_miles(geodeg.get(), g2_lhr, g4_gig, 5734); - verify_geo_miles(geodeg.get(), g2_lhr, g5_hkg, 5994); - verify_geo_miles(geodeg.get(), g2_lhr, g6_trd, 928); - verify_geo_miles(geodeg.get(), g2_lhr, g7_syd, 10573); - verify_geo_miles(geodeg.get(), g2_lhr, g8_lax, 5456); - verify_geo_miles(geodeg.get(), g2_lhr, g9_jfk, 3451); - - verify_geo_miles(geodeg.get(), g3_osl, g1_sfo, 5196); - verify_geo_miles(geodeg.get(), g3_osl, g2_lhr, 750); - verify_geo_miles(geodeg.get(), g3_osl, g3_osl, 0); - verify_geo_miles(geodeg.get(), g3_osl, g4_gig, 6479); - verify_geo_miles(geodeg.get(), g3_osl, g5_hkg, 5319); - verify_geo_miles(geodeg.get(), g3_osl, g6_trd, 226); - verify_geo_miles(geodeg.get(), g3_osl, g7_syd, 9888); - verify_geo_miles(geodeg.get(), g3_osl, g8_lax, 5345); - verify_geo_miles(geodeg.get(), g3_osl, g9_jfk, 3687); - - verify_geo_miles(geodeg.get(), g4_gig, g1_sfo, 6604); - verify_geo_miles(geodeg.get(), g4_gig, g2_lhr, 5734); - verify_geo_miles(geodeg.get(), g4_gig, g3_osl, 6479); - verify_geo_miles(geodeg.get(), g4_gig, g4_gig, 0); - verify_geo_miles(geodeg.get(), g4_gig, g5_hkg, 10989); - verify_geo_miles(geodeg.get(), g4_gig, g6_trd, 6623); - verify_geo_miles(geodeg.get(), g4_gig, g7_syd, 8414); - verify_geo_miles(geodeg.get(), g4_gig, g8_lax, 6294); - verify_geo_miles(geodeg.get(), g4_gig, g9_jfk, 4786); - - verify_geo_miles(geodeg.get(), g5_hkg, g1_sfo, 6927); - verify_geo_miles(geodeg.get(), g5_hkg, g2_lhr, 5994); - verify_geo_miles(geodeg.get(), g5_hkg, g3_osl, 5319); - verify_geo_miles(geodeg.get(), g5_hkg, g4_gig, 10989); - verify_geo_miles(geodeg.get(), g5_hkg, g5_hkg, 0); - verify_geo_miles(geodeg.get(), g5_hkg, g6_trd, 5240); - verify_geo_miles(geodeg.get(), g5_hkg, g7_syd, 4581); - verify_geo_miles(geodeg.get(), g5_hkg, g8_lax, 7260); - verify_geo_miles(geodeg.get(), g5_hkg, g9_jfk, 8072); - - verify_geo_miles(geodeg.get(), g6_trd, g1_sfo, 5012); - verify_geo_miles(geodeg.get(), g6_trd, g2_lhr, 928); - verify_geo_miles(geodeg.get(), g6_trd, g3_osl, 226); - verify_geo_miles(geodeg.get(), g6_trd, g4_gig, 6623); - verify_geo_miles(geodeg.get(), g6_trd, g5_hkg, 5240); - verify_geo_miles(geodeg.get(), g6_trd, g6_trd, 0); - verify_geo_miles(geodeg.get(), g6_trd, g7_syd, 9782); - verify_geo_miles(geodeg.get(), g6_trd, g8_lax, 5171); - verify_geo_miles(geodeg.get(), g6_trd, g9_jfk, 3611); - - verify_geo_miles(geodeg.get(), g7_syd, g1_sfo, 7417); - verify_geo_miles(geodeg.get(), g7_syd, g2_lhr, 10573); - verify_geo_miles(geodeg.get(), g7_syd, g3_osl, 9888); - verify_geo_miles(geodeg.get(), g7_syd, g4_gig, 8414); - verify_geo_miles(geodeg.get(), g7_syd, g5_hkg, 4581); - verify_geo_miles(geodeg.get(), g7_syd, g6_trd, 9782); - verify_geo_miles(geodeg.get(), g7_syd, g7_syd, 0); - verify_geo_miles(geodeg.get(), g7_syd, g8_lax, 7488); - verify_geo_miles(geodeg.get(), g7_syd, g9_jfk, 9950); - - verify_geo_miles(geodeg.get(), g8_lax, g1_sfo, 337); - verify_geo_miles(geodeg.get(), g8_lax, g2_lhr, 5456); - verify_geo_miles(geodeg.get(), g8_lax, g3_osl, 5345); - verify_geo_miles(geodeg.get(), g8_lax, g4_gig, 6294); - verify_geo_miles(geodeg.get(), g8_lax, g5_hkg, 7260); - verify_geo_miles(geodeg.get(), g8_lax, g6_trd, 5171); - verify_geo_miles(geodeg.get(), g8_lax, g7_syd, 7488); - verify_geo_miles(geodeg.get(), g8_lax, g8_lax, 0); - verify_geo_miles(geodeg.get(), g8_lax, g9_jfk, 2475); - - verify_geo_miles(geodeg.get(), g9_jfk, g1_sfo, 2586); - verify_geo_miles(geodeg.get(), g9_jfk, g2_lhr, 3451); - verify_geo_miles(geodeg.get(), g9_jfk, g3_osl, 3687); - verify_geo_miles(geodeg.get(), g9_jfk, g4_gig, 4786); - verify_geo_miles(geodeg.get(), g9_jfk, g5_hkg, 8072); - verify_geo_miles(geodeg.get(), g9_jfk, g6_trd, 3611); - verify_geo_miles(geodeg.get(), g9_jfk, g7_syd, 9950); - verify_geo_miles(geodeg.get(), g9_jfk, g8_lax, 2475); - verify_geo_miles(geodeg.get(), g9_jfk, g9_jfk, 0); - + verify_geo_miles(g1_sfo, g1_sfo, 0); + verify_geo_miles(g1_sfo, g2_lhr, 5367); + verify_geo_miles(g1_sfo, g3_osl, 5196); + verify_geo_miles(g1_sfo, g4_gig, 6604); + verify_geo_miles(g1_sfo, g5_hkg, 6927); + verify_geo_miles(g1_sfo, g6_trd, 5012); + verify_geo_miles(g1_sfo, g7_syd, 7417); + verify_geo_miles(g1_sfo, g8_lax, 337); + verify_geo_miles(g1_sfo, g9_jfk, 2586); + + verify_geo_miles(g2_lhr, g1_sfo, 5367); + verify_geo_miles(g2_lhr, g2_lhr, 0); + verify_geo_miles(g2_lhr, g3_osl, 750); + verify_geo_miles(g2_lhr, g4_gig, 5734); + verify_geo_miles(g2_lhr, g5_hkg, 5994); + verify_geo_miles(g2_lhr, g6_trd, 928); + verify_geo_miles(g2_lhr, g7_syd, 10573); + verify_geo_miles(g2_lhr, g8_lax, 5456); + verify_geo_miles(g2_lhr, g9_jfk, 3451); + + verify_geo_miles(g3_osl, g1_sfo, 5196); + verify_geo_miles(g3_osl, g2_lhr, 750); + verify_geo_miles(g3_osl, g3_osl, 0); + verify_geo_miles(g3_osl, g4_gig, 6479); + verify_geo_miles(g3_osl, g5_hkg, 5319); + verify_geo_miles(g3_osl, g6_trd, 226); + verify_geo_miles(g3_osl, g7_syd, 9888); + verify_geo_miles(g3_osl, g8_lax, 5345); + verify_geo_miles(g3_osl, g9_jfk, 3687); + + verify_geo_miles(g4_gig, g1_sfo, 6604); + verify_geo_miles(g4_gig, g2_lhr, 5734); + verify_geo_miles(g4_gig, g3_osl, 6479); + verify_geo_miles(g4_gig, g4_gig, 0); + verify_geo_miles(g4_gig, g5_hkg, 10989); + verify_geo_miles(g4_gig, g6_trd, 6623); + verify_geo_miles(g4_gig, g7_syd, 8414); + verify_geo_miles(g4_gig, g8_lax, 6294); + verify_geo_miles(g4_gig, g9_jfk, 4786); + + verify_geo_miles(g5_hkg, g1_sfo, 6927); + verify_geo_miles(g5_hkg, g2_lhr, 5994); + verify_geo_miles(g5_hkg, g3_osl, 5319); + verify_geo_miles(g5_hkg, g4_gig, 10989); + verify_geo_miles(g5_hkg, g5_hkg, 0); + verify_geo_miles(g5_hkg, g6_trd, 5240); + verify_geo_miles(g5_hkg, g7_syd, 4581); + verify_geo_miles(g5_hkg, g8_lax, 7260); + verify_geo_miles(g5_hkg, g9_jfk, 8072); + + verify_geo_miles(g6_trd, g1_sfo, 5012); + verify_geo_miles(g6_trd, g2_lhr, 928); + verify_geo_miles(g6_trd, g3_osl, 226); + verify_geo_miles(g6_trd, g4_gig, 6623); + verify_geo_miles(g6_trd, g5_hkg, 5240); + verify_geo_miles(g6_trd, g6_trd, 0); + verify_geo_miles(g6_trd, g7_syd, 9782); + verify_geo_miles(g6_trd, g8_lax, 5171); + verify_geo_miles(g6_trd, g9_jfk, 3611); + + verify_geo_miles(g7_syd, g1_sfo, 7417); + verify_geo_miles(g7_syd, g2_lhr, 10573); + verify_geo_miles(g7_syd, g3_osl, 9888); + verify_geo_miles(g7_syd, g4_gig, 8414); + verify_geo_miles(g7_syd, g5_hkg, 4581); + verify_geo_miles(g7_syd, g6_trd, 9782); + verify_geo_miles(g7_syd, g7_syd, 0); + verify_geo_miles(g7_syd, g8_lax, 7488); + verify_geo_miles(g7_syd, g9_jfk, 9950); + + verify_geo_miles(g8_lax, g1_sfo, 337); + verify_geo_miles(g8_lax, g2_lhr, 5456); + verify_geo_miles(g8_lax, g3_osl, 5345); + verify_geo_miles(g8_lax, g4_gig, 6294); + verify_geo_miles(g8_lax, g5_hkg, 7260); + verify_geo_miles(g8_lax, g6_trd, 5171); + verify_geo_miles(g8_lax, g7_syd, 7488); + verify_geo_miles(g8_lax, g8_lax, 0); + verify_geo_miles(g8_lax, g9_jfk, 2475); + + verify_geo_miles(g9_jfk, g1_sfo, 2586); + verify_geo_miles(g9_jfk, g2_lhr, 3451); + verify_geo_miles(g9_jfk, g3_osl, 3687); + verify_geo_miles(g9_jfk, g4_gig, 4786); + verify_geo_miles(g9_jfk, g5_hkg, 8072); + verify_geo_miles(g9_jfk, g6_trd, 3611); + verify_geo_miles(g9_jfk, g7_syd, 9950); + verify_geo_miles(g9_jfk, g8_lax, 2475); + verify_geo_miles(g9_jfk, g9_jfk, 0); } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchcommon/attribute/i_search_context.h b/searchlib/src/vespa/searchcommon/attribute/i_search_context.h index 8867d1b87e4..4657d41a4a0 100644 --- a/searchlib/src/vespa/searchcommon/attribute/i_search_context.h +++ b/searchlib/src/vespa/searchcommon/attribute/i_search_context.h @@ -70,6 +70,11 @@ public: bool matches(DocId docId, int32_t &weight) const { return matches(*this, docId, weight); } bool matches(DocId doc) const { return find(doc, 0) >= 0; } + /* + * Committed docid limit on attribute vector when search context was + * created. + */ + virtual uint32_t get_committed_docid_limit() const noexcept = 0; }; } diff --git a/searchlib/src/vespa/searchlib/attribute/empty_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/empty_search_context.cpp index 7a5d82cd9ba..91bdb45ff19 100644 --- a/searchlib/src/vespa/searchlib/attribute/empty_search_context.cpp +++ b/searchlib/src/vespa/searchlib/attribute/empty_search_context.cpp @@ -30,6 +30,12 @@ EmptySearchContext::approximateHits() const return 0u; } +uint32_t +EmptySearchContext::get_committed_docid_limit() const noexcept +{ + return 0u; +} + std::unique_ptr<queryeval::SearchIterator> EmptySearchContext::createIterator(fef::TermFieldMatchData*, bool) { diff --git a/searchlib/src/vespa/searchlib/attribute/empty_search_context.h b/searchlib/src/vespa/searchlib/attribute/empty_search_context.h index ae6f6d76edf..133e540d87f 100644 --- a/searchlib/src/vespa/searchlib/attribute/empty_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/empty_search_context.h @@ -19,6 +19,7 @@ class EmptySearchContext : public SearchContext public: EmptySearchContext(const AttributeVector& attr) noexcept; ~EmptySearchContext(); + uint32_t get_committed_docid_limit() const noexcept override; }; } diff --git a/searchlib/src/vespa/searchlib/attribute/enumhintsearchcontext.h b/searchlib/src/vespa/searchlib/attribute/enumhintsearchcontext.h index 0342976ffd6..86ffa1c8ab0 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumhintsearchcontext.h +++ b/searchlib/src/vespa/searchlib/attribute/enumhintsearchcontext.h @@ -41,6 +41,7 @@ protected: void fetchPostings(const queryeval::ExecuteInfo & execInfo) override; unsigned int approximateHits() const override; + uint32_t get_committed_docid_limit() const noexcept { return _docIdLimit; } }; } diff --git a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp index a1a5e9f7894..b50a3720ff8 100644 --- a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp +++ b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp @@ -17,12 +17,14 @@ ImportedAttributeVectorReadGuard::ImportedAttributeVectorReadGuard(std::shared_p _target_document_meta_store_read_guard(std::move(targetMetaStoreReadGuard)), _imported_attribute(imported_attribute), _targetLids(), + _target_docid_limit(0u), _reference_attribute_guard(imported_attribute.getReferenceAttribute()), _target_attribute_guard(imported_attribute.getTargetAttribute()->makeReadGuard(stableEnumGuard)), _reference_attribute(*imported_attribute.getReferenceAttribute()), _target_attribute(*_target_attribute_guard->attribute()) { _targetLids = _reference_attribute.getTargetLids(); + _target_docid_limit = _target_attribute.getCommittedDocIdLimit(); } ImportedAttributeVectorReadGuard::~ImportedAttributeVectorReadGuard() = default; diff --git a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h index cb48399f688..1297acad9b8 100644 --- a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h +++ b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h @@ -95,6 +95,7 @@ private: std::shared_ptr<MetaStoreReadGuard> _target_document_meta_store_read_guard; const ImportedAttributeVector &_imported_attribute; TargetLids _targetLids; + uint32_t _target_docid_limit; AttributeGuard _reference_attribute_guard; std::unique_ptr<attribute::AttributeReadGuard> _target_attribute_guard; const ReferenceAttribute &_reference_attribute; @@ -103,7 +104,9 @@ protected: uint32_t getTargetLid(uint32_t lid) const { // Check range to avoid reading memory beyond end of mapping array - return lid < _targetLids.size() ? _targetLids[lid].load_acquire() : 0u; + uint32_t target_lid = lid < _targetLids.size() ? _targetLids[lid].load_acquire() : 0u; + // Check target range + return target_lid < _target_docid_limit ? target_lid : 0u; } long onSerializeForAscendingSort(DocId doc, void * serTo, long available, diff --git a/searchlib/src/vespa/searchlib/attribute/imported_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/imported_search_context.cpp index 1e8adc3922e..3d308b82b04 100644 --- a/searchlib/src/vespa/searchlib/attribute/imported_search_context.cpp +++ b/searchlib/src/vespa/searchlib/attribute/imported_search_context.cpp @@ -43,6 +43,7 @@ ImportedSearchContext::ImportedSearchContext( _target_attribute(target_attribute), _target_search_context(_target_attribute.createSearchContext(std::move(term), params)), _targetLids(_reference_attribute.getTargetLids()), + _target_docid_limit(_target_search_context->get_committed_docid_limit()), _merger(_reference_attribute.getCommittedDocIdLimit()), _params(params), _zero_hits(false) @@ -327,4 +328,10 @@ const vespalib::string& ImportedSearchContext::attributeName() const { return _imported_attribute.getName(); } +uint32_t +ImportedSearchContext::get_committed_docid_limit() const noexcept +{ + return _targetLids.size(); +} + } diff --git a/searchlib/src/vespa/searchlib/attribute/imported_search_context.h b/searchlib/src/vespa/searchlib/attribute/imported_search_context.h index d9c09d8c645..d6b6d09e8fc 100644 --- a/searchlib/src/vespa/searchlib/attribute/imported_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/imported_search_context.h @@ -39,6 +39,7 @@ class ImportedSearchContext : public ISearchContext { const IAttributeVector &_target_attribute; std::unique_ptr<ISearchContext> _target_search_context; TargetLids _targetLids; + uint32_t _target_docid_limit; PostingListMerger<int32_t> _merger; SearchContextParams _params; mutable std::atomic<bool> _zero_hits; @@ -47,7 +48,9 @@ class ImportedSearchContext : public ISearchContext { uint32_t getTargetLid(uint32_t lid) const { // Check range to avoid reading memory beyond end of mapping array - return lid < _targetLids.size() ? _targetLids[lid].load_acquire() : 0u; + uint32_t target_lid = lid < _targetLids.size() ? _targetLids[lid].load_acquire() : 0u; + // Check target range + return target_lid < _target_docid_limit ? target_lid : 0u; } void makeMergedPostings(bool isFilter); @@ -90,6 +93,7 @@ public: const ISearchContext &target_search_context() const noexcept { return *_target_search_context; } + uint32_t get_committed_docid_limit() const noexcept override; }; } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.h b/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.h index 5b393d8bdb2..161c6799787 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.h @@ -59,6 +59,7 @@ public: std::unique_ptr<queryeval::SearchIterator> createFilterIterator(fef::TermFieldMatchData* matchData, bool strict) override; + uint32_t get_committed_docid_limit() const noexcept override; }; } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.hpp index e7901199e50..15abcf6f0d9 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.hpp @@ -33,4 +33,11 @@ MultiEnumSearchContext<T, BaseSC, M>::createFilterIterator(fef::TermFieldMatchDa : std::make_unique<AttributeIteratorT<MultiEnumSearchContext>>(*this, matchData); } +template <typename T, typename BaseSC, typename M> +uint32_t +MultiEnumSearchContext<T, BaseSC, M>::get_committed_docid_limit() const noexcept +{ + return _mv_mapping_read_view.get_committed_docid_limit(); +} + } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.h b/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.h index b2c76a120f9..23e56e23af9 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.h @@ -54,6 +54,7 @@ public: std::unique_ptr<queryeval::SearchIterator> createFilterIterator(fef::TermFieldMatchData* matchData, bool strict) override; + uint32_t get_committed_docid_limit() const noexcept override; }; } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.hpp index 15b851215f8..7e1fd1aeb5a 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.hpp @@ -33,4 +33,11 @@ MultiNumericSearchContext<T, M>::createFilterIterator(fef::TermFieldMatchData* m : std::make_unique<AttributeIteratorT<MultiNumericSearchContext<T, M>>>(*this, matchData); } +template <typename T, typename M> +uint32_t +MultiNumericSearchContext<T, M>::get_committed_docid_limit() const noexcept +{ + return _mv_mapping_read_view.get_committed_docid_limit(); +} + } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h index 41138ff0890..609989208c3 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h @@ -33,6 +33,7 @@ public: } vespalib::ConstArrayRef<ElemT> get(uint32_t doc_id) const { return _store->get(_indices[doc_id].load_acquire()); } bool valid() const noexcept { return _store != nullptr; } + uint32_t get_committed_docid_limit() const noexcept { return _indices.size(); } }; } diff --git a/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp b/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp index b701a6fd08f..9343dafe917 100644 --- a/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp +++ b/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp @@ -454,12 +454,14 @@ class ReferenceSearchContext : public attribute::SearchContext { private: const ReferenceAttribute& _ref_attr; GlobalId _term; + uint32_t _docid_limit; public: ReferenceSearchContext(const ReferenceAttribute& ref_attr, const GlobalId& term) : attribute::SearchContext(ref_attr), _ref_attr(ref_attr), - _term(term) + _term(term), + _docid_limit(ref_attr.getCommittedDocIdLimit()) { } bool valid() const override { @@ -480,8 +482,15 @@ public: int32_t weight; return onFind(docId, elementId, weight); } + uint32_t get_committed_docid_limit() const noexcept override; }; +uint32_t +ReferenceSearchContext::get_committed_docid_limit() const noexcept +{ + return _docid_limit; +} + } std::unique_ptr<attribute::SearchContext> diff --git a/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.h index 83d6c696117..f6a2f94dedb 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.h @@ -17,7 +17,9 @@ class SingleEnumSearchContext : public BaseSC { protected: using DocId = ISearchContext::DocId; - const vespalib::datastore::AtomicEntryRef* _enum_indices; + using AtomicEntryRef = vespalib::datastore::AtomicEntryRef; + using EnumIndices = vespalib::ConstArrayRef<AtomicEntryRef>; + EnumIndices _enum_indices; const EnumStoreT<T>& _enum_store; int32_t onFind(DocId docId, int32_t elemId, int32_t & weight) const final { @@ -29,7 +31,7 @@ protected: } public: - SingleEnumSearchContext(typename BaseSC::MatcherType&& matcher, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<T>& enum_store); + SingleEnumSearchContext(typename BaseSC::MatcherType&& matcher, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<T>& enum_store); int32_t find(DocId docId, int32_t elemId, int32_t & weight) const { if ( elemId != 0) return -1; @@ -46,6 +48,7 @@ public: std::unique_ptr<queryeval::SearchIterator> createFilterIterator(fef::TermFieldMatchData* matchData, bool strict) override; + uint32_t get_committed_docid_limit() const noexcept override; }; } diff --git a/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.hpp index a415c301f9c..6b6cf480d6a 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.hpp +++ b/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.hpp @@ -9,7 +9,7 @@ namespace search::attribute { template <typename T, typename BaseSC> -SingleEnumSearchContext<T, BaseSC>::SingleEnumSearchContext(typename BaseSC::MatcherType&& matcher, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<T>& enum_store) +SingleEnumSearchContext<T, BaseSC>::SingleEnumSearchContext(typename BaseSC::MatcherType&& matcher, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<T>& enum_store) : BaseSC(toBeSearched, std::move(matcher)), _enum_indices(enum_indices), _enum_store(enum_store) @@ -33,4 +33,11 @@ SingleEnumSearchContext<T, BaseSC>::createFilterIterator(fef::TermFieldMatchData : std::make_unique<AttributeIteratorT<SingleEnumSearchContext>>(*this, matchData); } +template <typename T, typename BaseSC> +uint32_t +SingleEnumSearchContext<T, BaseSC>::get_committed_docid_limit() const noexcept +{ + return _enum_indices.size(); +} + } diff --git a/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.h index 86283f59283..fd3f4c03a8a 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.h @@ -16,7 +16,9 @@ template <typename T> class SingleNumericEnumSearchContext : public SingleEnumSearchContext<T, NumericSearchContext<NumericRangeMatcher<T>>> { public: - SingleNumericEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<T>& enum_store); + using AtomicEntryRef = vespalib::datastore::AtomicEntryRef; + using EnumIndices = vespalib::ConstArrayRef<AtomicEntryRef>; + SingleNumericEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<T>& enum_store); }; } diff --git a/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.hpp index f4e049cb6f1..c0818d4d18a 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.hpp +++ b/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.hpp @@ -8,7 +8,7 @@ namespace search::attribute { template <typename T> -SingleNumericEnumSearchContext<T>::SingleNumericEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<T>& enum_store) +SingleNumericEnumSearchContext<T>::SingleNumericEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<T>& enum_store) : SingleEnumSearchContext<T, NumericSearchContext<NumericRangeMatcher<T>>>(NumericRangeMatcher<T>(*qTerm, true), toBeSearched, enum_indices, enum_store) { } diff --git a/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.h index 5f6925f7f4d..6362c69cdac 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.h @@ -3,6 +3,7 @@ #pragma once #include "numeric_search_context.h" +#include <vespa/vespalib/util/arrayref.h> #include <vespa/vespalib/util/atomic.h> namespace search::attribute { @@ -16,7 +17,7 @@ class SingleNumericSearchContext final : public NumericSearchContext<M> { private: using DocId = ISearchContext::DocId; - const T* _data; + vespalib::ConstArrayRef<T> _data; int32_t onFind(DocId docId, int32_t elemId, int32_t& weight) const override { return find(docId, elemId, weight); @@ -27,7 +28,7 @@ private: } public: - SingleNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const T* data); + SingleNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, vespalib::ConstArrayRef<T> data); int32_t find(DocId docId, int32_t elemId, int32_t& weight) const { if ( elemId != 0) return -1; const T v = vespalib::atomic::load_ref_relaxed(_data[docId]); @@ -43,6 +44,7 @@ public: std::unique_ptr<queryeval::SearchIterator> createFilterIterator(fef::TermFieldMatchData* matchData, bool strict) override; + uint32_t get_committed_docid_limit() const noexcept override; }; } diff --git a/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.hpp index 75d3da9de7f..b40b1336e6f 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.hpp +++ b/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.hpp @@ -9,7 +9,7 @@ namespace search::attribute { template <typename T, typename M> -SingleNumericSearchContext<T, M>::SingleNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const T* data) +SingleNumericSearchContext<T, M>::SingleNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, vespalib::ConstArrayRef<T> data) : NumericSearchContext<M>(toBeSearched, *qTerm, true), _data(data) { @@ -32,4 +32,11 @@ SingleNumericSearchContext<T, M>::createFilterIterator(fef::TermFieldMatchData* : std::make_unique<AttributeIteratorT<SingleNumericSearchContext<T, M>>>(*this, matchData); } +template <typename T, typename M> +uint32_t +SingleNumericSearchContext<T, M>::get_committed_docid_limit() const noexcept +{ + return _data.size(); +} + } diff --git a/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.cpp index 5eeef7cd61a..074435809cc 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.cpp +++ b/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.cpp @@ -6,13 +6,14 @@ namespace search::attribute { -SingleSmallNumericSearchContext::SingleSmallNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const Word* word_data, Word value_mask, uint32_t value_shift_shift, uint32_t value_shift_mask, uint32_t word_shift) +SingleSmallNumericSearchContext::SingleSmallNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const Word* word_data, Word value_mask, uint32_t value_shift_shift, uint32_t value_shift_mask, uint32_t word_shift, uint32_t docid_limit) : NumericSearchContext<NumericRangeMatcher<T>>(toBeSearched, *qTerm, false), _wordData(word_data), _valueMask(value_mask), _valueShiftShift(value_shift_shift), _valueShiftMask(value_shift_mask), - _wordShift(word_shift) + _wordShift(word_shift), + _docid_limit(docid_limit) { } @@ -32,4 +33,10 @@ SingleSmallNumericSearchContext::createFilterIterator(fef::TermFieldMatchData* m : std::make_unique<AttributeIteratorT<SingleSmallNumericSearchContext>>(*this, matchData); } +uint32_t +SingleSmallNumericSearchContext::get_committed_docid_limit() const noexcept +{ + return _docid_limit; +} + } diff --git a/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.h index 46ed02b3eca..a42c8b9b29c 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.h @@ -22,6 +22,7 @@ private: uint32_t _valueShiftShift; uint32_t _valueShiftMask; uint32_t _wordShift; + uint32_t _docid_limit; int32_t onFind(DocId docId, int32_t elementId, int32_t & weight) const override { return find(docId, elementId, weight); @@ -32,7 +33,7 @@ private: } public: - SingleSmallNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const Word* word_data, Word value_mask, uint32_t value_shift_shift, uint32_t value_shift_mask, uint32_t word_shift); + SingleSmallNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const Word* word_data, Word value_mask, uint32_t value_shift_shift, uint32_t value_shift_mask, uint32_t word_shift, uint32_t docid_limit); int32_t find(DocId docId, int32_t elemId, int32_t & weight) const { if ( elemId != 0) return -1; @@ -53,6 +54,7 @@ public: std::unique_ptr<queryeval::SearchIterator> createFilterIterator(fef::TermFieldMatchData* matchData, bool strict) override; + uint32_t get_committed_docid_limit() const noexcept override; }; } diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp index 70023b27802..2d1748cefa5 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp +++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp @@ -5,10 +5,10 @@ namespace search::attribute { -SingleStringEnumHintSearchContext::SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<const char*>& enum_store, uint32_t doc_id_limit, uint64_t num_values) +SingleStringEnumHintSearchContext::SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store, uint64_t num_values) : SingleStringEnumSearchContext(std::move(qTerm), cased, toBeSearched, enum_indices, enum_store), EnumHintSearchContext(enum_store.get_dictionary(), - doc_id_limit, num_values) + enum_indices.size(), num_values) { setup_enum_hint_sc(enum_store, *this); } diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h index f9d44454cd0..f157bf17a71 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h @@ -16,7 +16,7 @@ class SingleStringEnumHintSearchContext : public SingleStringEnumSearchContext, public EnumHintSearchContext { public: - SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<const char*>& enum_store, uint32_t doc_id_limit, uint64_t num_values); + SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store, uint64_t num_values); ~SingleStringEnumHintSearchContext() override; }; diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp index cba1d207501..8d23eaf7af0 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp +++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp @@ -6,7 +6,7 @@ namespace search::attribute { -SingleStringEnumSearchContext::SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<const char*>& enum_store) +SingleStringEnumSearchContext::SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store) : SingleEnumSearchContext<const char*, StringSearchContext>(StringMatcher(std::move(qTerm), cased), toBeSearched, enum_indices, enum_store) { } diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h index 6a9ed38b4ea..b8014b1b0e3 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h @@ -14,7 +14,7 @@ namespace search::attribute { class SingleStringEnumSearchContext : public SingleEnumSearchContext<const char*, StringSearchContext> { public: - SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<const char*>& enum_store); + SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store); SingleStringEnumSearchContext(SingleStringEnumSearchContext&&) noexcept; ~SingleStringEnumSearchContext() override; }; diff --git a/searchlib/src/vespa/searchlib/attribute/singleboolattribute.cpp b/searchlib/src/vespa/searchlib/attribute/singleboolattribute.cpp index 15fc819300c..87b7049b9b7 100644 --- a/searchlib/src/vespa/searchlib/attribute/singleboolattribute.cpp +++ b/searchlib/src/vespa/searchlib/attribute/singleboolattribute.cpp @@ -132,6 +132,7 @@ public: void fetchPostings(const queryeval::ExecuteInfo &execInfo) override; std::unique_ptr<queryeval::SearchIterator> createPostingIterator(fef::TermFieldMatchData *matchData, bool strict) override; unsigned int approximateHits() const override; + uint32_t get_committed_docid_limit() const noexcept override; }; BitVectorSearchContext::BitVectorSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const SingleBoolAttribute & attr) @@ -177,6 +178,12 @@ BitVectorSearchContext::approximateHits() const { : 0; } +uint32_t +BitVectorSearchContext::get_committed_docid_limit() const noexcept +{ + return _doc_id_limit; +} + } std::unique_ptr<attribute::SearchContext> diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp index c75ee0aacb5..606c7a92ef5 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp @@ -164,7 +164,7 @@ SingleValueNumericAttribute<B>::getSearch(QueryTermSimple::UP qTerm, { (void) params; QueryTermSimple::RangeResult<T> res = qTerm->getRange<T>(); - const T* data = &_data.acquire_elem_ref(0); + auto data = _data.make_read_view(this->getCommittedDocIdLimit()); if (res.isEqual()) { return std::make_unique<attribute::SingleNumericSearchContext<T, attribute::NumericMatcher<T>>>(std::move(qTerm), *this, data); } else { diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp index b840a0516b2..e459d3d9c9c 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp @@ -160,7 +160,8 @@ SingleValueNumericEnumAttribute<B>::getSearch(QueryTermSimple::UP qTerm, const attribute::SearchContextParams & params) const { (void) params; - return std::make_unique<attribute::SingleNumericEnumSearchContext<T>>(std::move(qTerm), *this, &this->_enumIndices.acquire_elem_ref(0), this->_enumStore); + auto docid_limit = this->getCommittedDocIdLimit(); + return std::make_unique<attribute::SingleNumericEnumSearchContext<T>>(std::move(qTerm), *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore); } } diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp index e353d03a9e8..a4b9abb084a 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp @@ -143,7 +143,8 @@ SingleValueNumericPostingAttribute<B>::getSearch(QueryTermSimple::UP qTerm, { using BaseSC = attribute::SingleNumericEnumSearchContext<T>; using SC = attribute::NumericPostingSearchContext<BaseSC, SelfType, vespalib::btree::BTreeNoLeafData>; - BaseSC base_sc(std::move(qTerm), *this, &this->_enumIndices.acquire_elem_ref(0), this->_enumStore); + auto docid_limit = this->getCommittedDocIdLimit(); + BaseSC base_sc(std::move(qTerm), *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore); return std::make_unique<SC>(std::move(base_sc), params, *this); } diff --git a/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.cpp b/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.cpp index 13bf2f932e8..3c1621ac244 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.cpp +++ b/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.cpp @@ -170,7 +170,8 @@ std::unique_ptr<attribute::SearchContext> SingleValueSmallNumericAttribute::getSearch(std::unique_ptr<QueryTermSimple> qTerm, const attribute::SearchContextParams &) const { - return std::make_unique<attribute::SingleSmallNumericSearchContext>(std::move(qTerm), *this, &_wordData.acquire_elem_ref(0), _valueMask, _valueShiftShift, _valueShiftMask, _wordShift); + auto docid_limit = getCommittedDocIdLimit(); + return std::make_unique<attribute::SingleSmallNumericSearchContext>(std::move(qTerm), *this, &_wordData.acquire_elem_ref(0), _valueMask, _valueShiftShift, _valueShiftMask, _wordShift, docid_limit); } void diff --git a/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp index 69fe6435a03..c3f5c295260 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp @@ -46,7 +46,8 @@ SingleValueStringAttributeT<B>::getSearch(QueryTermSimpleUP qTerm, const attribute::SearchContextParams &) const { bool cased = this->get_match_is_cased(); - return std::make_unique<attribute::SingleStringEnumHintSearchContext>(std::move(qTerm), cased, *this, &this->_enumIndices.acquire_elem_ref(0), this->_enumStore, this->getCommittedDocIdLimit(), this->getStatus().getNumValues()); + auto docid_limit = this->getCommittedDocIdLimit(); + return std::make_unique<attribute::SingleStringEnumHintSearchContext>(std::move(qTerm), cased, *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore, this->getStatus().getNumValues()); } } diff --git a/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp index 5b5214f6d3e..60847636baa 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp @@ -145,7 +145,8 @@ SingleValueStringPostingAttributeT<B>::getSearch(QueryTermSimpleUP qTerm, using BaseSC = attribute::SingleStringEnumSearchContext; using SC = attribute::StringPostingSearchContext<BaseSC, SelfType, vespalib::btree::BTreeNoLeafData>; bool cased = this->get_match_is_cased(); - BaseSC base_sc(std::move(qTerm), cased, *this, &this->_enumIndices.acquire_elem_ref(0), this->_enumStore); + auto docid_limit = this->getCommittedDocIdLimit(); + BaseSC base_sc(std::move(qTerm), cased, *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore); return std::make_unique<SC>(std::move(base_sc), params.useBitVector(), *this); diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h index 46b89fdfeb4..9bef389a278 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h @@ -109,7 +109,7 @@ public: uint32_t getArity() const { return _currArity; } uint32_t getNearDistance() const { return _extraIntArg1; } - uint32_t getTargetNumHits() const { return _extraIntArg1; } + uint32_t getTargetHits() const { return _extraIntArg1; } double getDistanceThreshold() const { return _extraDoubleArg4; } double getScoreThreshold() const { return _extraDoubleArg4; } double getThresholdBoostFactor() const { return _extraDoubleArg5; } diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp index d1c37cd6dcd..b2d8a0ee4be 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp @@ -1,15 +1,21 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "nearest_neighbor_query_node.h" +#include <cassert> namespace search::streaming { -NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold) - : QueryTerm(std::move(resultBase), term, index, Type::NEAREST_NEIGHBOR), +NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, + const string& query_tensor_name, const string& field_name, + uint32_t target_hits, double distance_threshold, + int32_t unique_id, search::query::Weight weight) + : QueryTerm(std::move(resultBase), query_tensor_name, field_name, Type::NEAREST_NEIGHBOR), + _target_hits(target_hits), _distance_threshold(distance_threshold), - _raw_score() + _distance(), + _calc() { - setUniqueId(id); + setUniqueId(unique_id); setWeight(weight); } @@ -18,13 +24,13 @@ NearestNeighborQueryNode::~NearestNeighborQueryNode() = default; bool NearestNeighborQueryNode::evaluate() const { - return _raw_score.has_value(); + return _distance.has_value(); } void NearestNeighborQueryNode::reset() { - _raw_score.reset(); + _distance.reset(); } NearestNeighborQueryNode* @@ -33,4 +39,14 @@ NearestNeighborQueryNode::as_nearest_neighbor_query_node() noexcept return this; } +std::optional<double> +NearestNeighborQueryNode::get_raw_score() const +{ + if (_distance.has_value()) { + assert(_calc != nullptr); + return _calc->to_raw_score(_distance.value()); + } + return std::nullopt; +} + } diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h index 0beb130c53d..c66364b0c52 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h +++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h @@ -8,16 +8,34 @@ namespace search::streaming { /* - * Nearest neighbor query node. + * Nearest neighbor query node for streaming search. */ class NearestNeighborQueryNode: public QueryTerm { +public: + class RawScoreCalculator { + public: + virtual ~RawScoreCalculator() = default; + /** + * Convert the given distance to a raw score. + * + * This is used during unpacking, and also signals that the entire document was a match. + */ + virtual double to_raw_score(double distance) = 0; + }; + private: + uint32_t _target_hits; double _distance_threshold; - // When this value is set it also indicates a match - std::optional<double> _raw_score; + // When this value is set it also indicates a match for this query node. + std::optional<double> _distance; + RawScoreCalculator* _calc; + public: - NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold); + NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, + const string& query_tensor_name, const string& field_name, + uint32_t target_hits, double distance_threshold, + int32_t unique_id, search::query::Weight weight); NearestNeighborQueryNode(const NearestNeighborQueryNode &) = delete; NearestNeighborQueryNode & operator = (const NearestNeighborQueryNode &) = delete; NearestNeighborQueryNode(NearestNeighborQueryNode &&) = delete; @@ -27,9 +45,13 @@ public: void reset() override; NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept override; const vespalib::string& get_query_tensor_name() const { return getTermString(); } + uint32_t get_target_hits() const { return _target_hits; } double get_distance_threshold() const { return _distance_threshold; } - void set_raw_score(double value) { _raw_score = value; } - const std::optional<double>& get_raw_score() const noexcept { return _raw_score; } + void set_raw_score_calc(RawScoreCalculator* calc_in) { _calc = calc_in; } + void set_distance(double value) { _distance = value; } + const std::optional<double>& get_distance() const { return _distance; } + // This is used during unpacking, and also signals to the RawScoreCalculator that the entire document was a match. + std::optional<double> get_raw_score() const; }; } diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index 226cb92c894..84344831cbc 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -200,15 +200,17 @@ QueryNode::build_nearest_neighbor_query_node(const QueryNodeResultFactory& facto { vespalib::stringref query_tensor_name = query_rep.getTerm(); vespalib::stringref field_name = query_rep.getIndexName(); - int32_t id = query_rep.getUniqueId(); - search::query::Weight weight = query_rep.GetWeight(); + int32_t unique_id = query_rep.getUniqueId(); + auto weight = query_rep.GetWeight(); + uint32_t target_hits = query_rep.getTargetHits(); double distance_threshold = query_rep.getDistanceThreshold(); return std::make_unique<NearestNeighborQueryNode>(factory.create(), query_tensor_name, field_name, - id, - weight, - distance_threshold); + target_hits, + distance_threshold, + unique_id, + weight); } } diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h index 90bd87979c7..a552a650704 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h @@ -89,7 +89,7 @@ private: pureTermView = view; } else if (type == ParseItem::ITEM_WEAK_AND) { vespalib::stringref view = queryStack.getIndexName(); - uint32_t targetNumHits = queryStack.getTargetNumHits(); + uint32_t targetNumHits = queryStack.getTargetHits(); builder.addWeakAnd(arity, targetNumHits, view); pureTermView = view; } else if (type == ParseItem::ITEM_EQUIV) { @@ -134,7 +134,7 @@ private: vespalib::stringref view = queryStack.getIndexName(); int32_t id = queryStack.getUniqueId(); Weight weight = queryStack.GetWeight(); - uint32_t targetNumHits = queryStack.getTargetNumHits(); + uint32_t targetNumHits = queryStack.getTargetHits(); double scoreThreshold = queryStack.getScoreThreshold(); double thresholdBoostFactor = queryStack.getThresholdBoostFactor(); auto & wand = builder.addWandTerm(arity, view, id, weight, targetNumHits, scoreThreshold, thresholdBoostFactor); @@ -146,7 +146,7 @@ private: } else if (type == ParseItem::ITEM_NEAREST_NEIGHBOR) { vespalib::stringref query_tensor_name = queryStack.getTerm(); vespalib::stringref field_name = queryStack.getIndexName(); - uint32_t target_num_hits = queryStack.getTargetNumHits(); + uint32_t target_num_hits = queryStack.getTargetHits(); int32_t id = queryStack.getUniqueId(); Weight weight = queryStack.GetWeight(); bool allow_approximate = queryStack.getAllowApproximate(); diff --git a/searchlib/src/vespa/searchlib/queryeval/leaf_blueprints.cpp b/searchlib/src/vespa/searchlib/queryeval/leaf_blueprints.cpp index c50c6ec49f5..86f520c8711 100644 --- a/searchlib/src/vespa/searchlib/queryeval/leaf_blueprints.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/leaf_blueprints.cpp @@ -129,8 +129,16 @@ struct FakeContext : attribute::ISearchContext { DoubleRange getAsDoubleTerm() const override { abort(); } const QueryTermUCS4 * queryTerm() const override { abort(); } const vespalib::string &attributeName() const override { return name; } + uint32_t get_committed_docid_limit() const noexcept override; }; +uint32_t +FakeContext::get_committed_docid_limit() const noexcept +{ + auto& documents = result.inspect(); + return documents.empty() ? 0 : (documents.back().docId + 1); +} + } SearchIterator::UP diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt index 1783e0da1dd..2e874ffa4ae 100644 --- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt @@ -30,6 +30,7 @@ vespa_add_library(searchlib_tensor OBJECT large_subspaces_buffer_type.cpp nearest_neighbor_index.cpp nearest_neighbor_index_saver.cpp + prenormalized_angular_distance.cpp serialized_fast_value_attribute.cpp serialized_tensor_ref.cpp small_subspaces_buffer_type.cpp diff --git a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp index 85eac76728c..a7ae02bb9f4 100644 --- a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp @@ -61,20 +61,19 @@ private: double _lhs_norm_sq; public: BoundAngularDistance(const vespalib::eval::TypedCells& lhs) - : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()), - _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), + : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), _tmpSpace(lhs.size), _lhs(_tmpSpace.storeLhs(lhs)) { - auto a = &_lhs[0]; + auto a = _lhs.data(); _lhs_norm_sq = _computer.dotProduct(a, a, lhs.size); } double calc(const vespalib::eval::TypedCells& rhs) const override { size_t sz = _lhs.size(); vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs); assert(sz == rhs_vector.size()); - auto a = &_lhs[0]; - auto b = &rhs_vector[0]; + auto a = _lhs.data(); + auto b = rhs_vector.data(); double b_norm_sq = _computer.dotProduct(b, b, sz); double squared_norms = _lhs_norm_sq * b_norm_sq; double dot_product = _computer.dotProduct(a, b, sz); diff --git a/searchlib/src/vespa/searchlib/tensor/angular_distance.h b/searchlib/src/vespa/searchlib/tensor/angular_distance.h index 97f692d05a2..bba83576153 100644 --- a/searchlib/src/vespa/searchlib/tensor/angular_distance.h +++ b/searchlib/src/vespa/searchlib/tensor/angular_distance.h @@ -60,8 +60,8 @@ public: auto rhs_vector = rhs.typify<FloatType>(); size_t sz = lhs_vector.size(); assert(sz == rhs_vector.size()); - auto a = &lhs_vector[0]; - auto b = &rhs_vector[0]; + auto a = lhs_vector.data(); + auto b = rhs_vector.data(); double a_norm_sq = _computer.dotProduct(a, a, sz); double b_norm_sq = _computer.dotProduct(b, b, sz); double squared_norms = a_norm_sq * b_norm_sq; diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h index 5d602a52227..c072d6de8e5 100644 --- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h +++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h @@ -20,20 +20,13 @@ namespace search::tensor { * mutable temporary storage. */ class BoundDistanceFunction : public DistanceConverter { -private: - vespalib::eval::CellType _expect_cell_type; public: using UP = std::unique_ptr<BoundDistanceFunction>; - BoundDistanceFunction(vespalib::eval::CellType expected) : _expect_cell_type(expected) {} + BoundDistanceFunction() = default; virtual ~BoundDistanceFunction() = default; - // input vectors will be converted to this cell type: - vespalib::eval::CellType expected_cell_type() const { - return _expect_cell_type; - } - // calculate internal distance (comparable) virtual double calc(const vespalib::eval::TypedCells& rhs) const = 0; diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp index cca492ef212..c088d498f0f 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp @@ -55,8 +55,7 @@ class SimpleBoundDistanceFunction : public BoundDistanceFunction { public: SimpleBoundDistanceFunction(const vespalib::eval::TypedCells& lhs, const DistanceFunction &df) - : BoundDistanceFunction(lhs.type), - _lhs(lhs), + : _lhs(lhs), _df(df) {} @@ -94,21 +93,35 @@ std::unique_ptr<DistanceFunctionFactory> make_distance_function_factory(search::attribute::DistanceMetric variant, vespalib::eval::CellType cell_type) { - if (variant == DistanceMetric::Angular) { - if (cell_type == CellType::DOUBLE) { - return std::make_unique<AngularDistanceFunctionFactory<double>>(); - } - return std::make_unique<AngularDistanceFunctionFactory<float>>(); - } - if (variant == DistanceMetric::Euclidean) { - switch (cell_type) { - case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>(); - case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>(); - default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>(); - } + switch (variant) { + case DistanceMetric::Angular: + switch (cell_type) { + case CellType::DOUBLE: return std::make_unique<AngularDistanceFunctionFactory<double>>(); + default: return std::make_unique<AngularDistanceFunctionFactory<float>>(); + } + case DistanceMetric::Euclidean: + switch (cell_type) { + case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>(); + case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>(); + default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>(); + } + case DistanceMetric::InnerProduct: + case DistanceMetric::PrenormalizedAngular: + switch (cell_type) { + case CellType::DOUBLE: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<double>>(); + default: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<float>>(); + } + case DistanceMetric::GeoDegrees: + return std::make_unique<GeoDistanceFunctionFactory>(); + case DistanceMetric::Hamming: + switch (cell_type) { + case CellType::DOUBLE: return std::make_unique<HammingDistanceFunctionFactory<double>>(); + case CellType::INT8: return std::make_unique<HammingDistanceFunctionFactory<vespalib::eval::Int8Float>>(); + default: return std::make_unique<HammingDistanceFunctionFactory<float>>(); + } } - auto df = make_distance_function(variant, cell_type); - return std::make_unique<SimpleDistanceFunctionFactory>(std::move(df)); + // not reached: + return {}; } } diff --git a/searchlib/src/vespa/searchlib/tensor/distance_functions.h b/searchlib/src/vespa/searchlib/tensor/distance_functions.h index b28cc2bda46..2300dba2db1 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_functions.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_functions.h @@ -8,3 +8,4 @@ #include "geo_degrees_distance.h" #include "hamming_distance.h" #include "inner_product_distance.h" +#include "prenormalized_angular_distance.h" diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp index 92d4e7af406..7995c87d055 100644 --- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp @@ -62,8 +62,7 @@ private: static const int8_t *cast(const Int8Float * p) { return reinterpret_cast<const int8_t *>(p); } public: BoundEuclideanDistance(const vespalib::eval::TypedCells& lhs) - : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()), - _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), + : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), _tmpSpace(lhs.size), _lhs_vector(_tmpSpace.storeLhs(lhs)) {} @@ -71,8 +70,8 @@ public: size_t sz = _lhs_vector.size(); vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs); assert(sz == rhs_vector.size()); - auto a = &_lhs_vector[0]; - auto b = &rhs_vector[0]; + auto a = _lhs_vector.data(); + auto b = rhs_vector.data(); return _computer.squaredEuclideanDistance(cast(a), cast(b), sz); } double convert_threshold(double threshold) const override { diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp index bcce75da3ab..38ba8205c90 100644 --- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "geo_degrees_distance.h" +#include "temporary_vector_store.h" using vespalib::typify_invoke; using vespalib::eval::TypifyCellType; @@ -27,11 +28,11 @@ struct CalcGeoDegrees { double lat_diff = lat_A - lat_B; double lon_diff = lon_A - lon_B; - + // haversines of differences: double hav_lat = GeoDegreesDistance::hav(lat_diff); double hav_lon = GeoDegreesDistance::hav(lon_diff); - + // haversine of central angle between the two points: double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon; return hav_central_angle; @@ -42,9 +43,63 @@ struct CalcGeoDegrees { double GeoDegreesDistance::calc(const vespalib::eval::TypedCells& lhs, - const vespalib::eval::TypedCells& rhs) const + const vespalib::eval::TypedCells& rhs) const { return typify_invoke<2,TypifyCellType,CalcGeoDegrees>(lhs.type, rhs.type, lhs, rhs); } +using vespalib::eval::TypedCells; + +class BoundGeoDistance : public BoundDistanceFunction { +private: + mutable TemporaryVectorStore<double> _tmpSpace; + const vespalib::ConstArrayRef<double> _lh_vector; + static GeoDegreesDistance _g_d_helper; +public: + BoundGeoDistance(const vespalib::eval::TypedCells& lhs) + : _tmpSpace(lhs.size), + _lh_vector(_tmpSpace.storeLhs(lhs)) + {} + double calc(const vespalib::eval::TypedCells& rhs) const override { + vespalib::ConstArrayRef<double> rhs_vector = _tmpSpace.convertRhs(rhs); + assert(2 == _lh_vector.size()); + assert(2 == rhs_vector.size()); + // convert to radians: + double lat_A = _lh_vector[0] * GeoDegreesDistance::degrees_to_radians; + double lat_B = rhs_vector[0] * GeoDegreesDistance::degrees_to_radians; + double lon_A = _lh_vector[1] * GeoDegreesDistance::degrees_to_radians; + double lon_B = rhs_vector[1] * GeoDegreesDistance::degrees_to_radians; + + double lat_diff = lat_A - lat_B; + double lon_diff = lon_A - lon_B; + + // haversines of differences: + double hav_lat = GeoDegreesDistance::hav(lat_diff); + double hav_lon = GeoDegreesDistance::hav(lon_diff); + + // haversine of central angle between the two points: + double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon; + return hav_central_angle; + } + double convert_threshold(double threshold) const override { + return _g_d_helper.convert_threshold(threshold); + } + double to_rawscore(double distance) const override { + return _g_d_helper.to_rawscore(distance); + } + double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override { + return calc(rhs); + } +}; + +BoundDistanceFunction::UP +GeoDistanceFunctionFactory::for_query_vector(const vespalib::eval::TypedCells& lhs) { + return std::make_unique<BoundGeoDistance>(lhs); +} + +BoundDistanceFunction::UP +GeoDistanceFunctionFactory::for_insertion_vector(const vespalib::eval::TypedCells& lhs) { + return std::make_unique<BoundGeoDistance>(lhs); +} + } diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h index 46feee19119..4522bc03c9e 100644 --- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h +++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h @@ -3,6 +3,7 @@ #pragma once #include "distance_function.h" +#include "distance_function_factory.h" #include <vespa/eval/eval/typed_cells.h> #include <vespa/vespalib/hwaccelrated/iaccelrated.h> #include <vespa/vespalib/util/typify.h> @@ -50,4 +51,11 @@ public: } }; +class GeoDistanceFunctionFactory : public DistanceFunctionFactory { +public: + GeoDistanceFunctionFactory() : DistanceFunctionFactory(vespalib::eval::CellType::DOUBLE) {} + BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override; + BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override; +}; + } diff --git a/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp b/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp index 43596478a6f..f4f6842715f 100644 --- a/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "hamming_distance.h" +#include "temporary_vector_store.h" #include <vespa/vespalib/util/binary_hamming_distance.h> using vespalib::typify_invoke; @@ -52,4 +53,63 @@ HammingDistance::calc_with_limit(const vespalib::eval::TypedCells& lhs, return calc(lhs, rhs); } +using vespalib::eval::Int8Float; + +template<typename FloatType> +class BoundHammingDistance : public BoundDistanceFunction { +private: + mutable TemporaryVectorStore<FloatType> _tmpSpace; + const vespalib::ConstArrayRef<FloatType> _lhs_vector; +public: + BoundHammingDistance(const vespalib::eval::TypedCells& lhs) + : _tmpSpace(lhs.size), + _lhs_vector(_tmpSpace.storeLhs(lhs)) + {} + double calc(const vespalib::eval::TypedCells& rhs) const override { + size_t sz = _lhs_vector.size(); + vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs); + assert(sz == rhs_vector.size()); + auto a = _lhs_vector.data(); + auto b = rhs_vector.data(); + if constexpr (std::is_same<Int8Float, FloatType>::value) { + return (double) vespalib::binary_hamming_distance(a, b, sz); + } else { + size_t sum = 0; + for (size_t i = 0; i < sz; ++i) { + sum += (_lhs_vector[i] == rhs_vector[i]) ? 0 : 1; + } + return (double)sum; + } + } + double convert_threshold(double threshold) const override { + return threshold; + } + double to_rawscore(double distance) const override { + double score = 1.0 / (1.0 + distance); + return score; + } + double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override { + // consider optimizing: + return calc(rhs); + } +}; + +template <typename FloatType> +BoundDistanceFunction::UP +HammingDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) { + using DFT = BoundHammingDistance<FloatType>; + return std::make_unique<DFT>(lhs); +} + +template <typename FloatType> +BoundDistanceFunction::UP +HammingDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) { + using DFT = BoundHammingDistance<FloatType>; + return std::make_unique<DFT>(lhs); +} + +template class HammingDistanceFunctionFactory<Int8Float>; +template class HammingDistanceFunctionFactory<float>; +template class HammingDistanceFunctionFactory<double>; + } diff --git a/searchlib/src/vespa/searchlib/tensor/hamming_distance.h b/searchlib/src/vespa/searchlib/tensor/hamming_distance.h index c64fc5b532d..23c855eb137 100644 --- a/searchlib/src/vespa/searchlib/tensor/hamming_distance.h +++ b/searchlib/src/vespa/searchlib/tensor/hamming_distance.h @@ -3,6 +3,7 @@ #pragma once #include "distance_function.h" +#include "distance_function_factory.h" #include <vespa/eval/eval/typed_cells.h> #include <vespa/vespalib/util/typify.h> #include <cmath> @@ -29,4 +30,14 @@ public: double calc_with_limit(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs, double) const override; }; +template <typename FloatType> +class HammingDistanceFunctionFactory : public DistanceFunctionFactory { +public: + HammingDistanceFunctionFactory() + : DistanceFunctionFactory(vespalib::eval::get_cell_type<FloatType>()) + {} + BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override; + BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override; +}; + } diff --git a/searchlib/src/vespa/searchlib/tensor/inner_product_distance.h b/searchlib/src/vespa/searchlib/tensor/inner_product_distance.h index 8ba14580885..135bb186fd4 100644 --- a/searchlib/src/vespa/searchlib/tensor/inner_product_distance.h +++ b/searchlib/src/vespa/searchlib/tensor/inner_product_distance.h @@ -54,7 +54,7 @@ public: auto rhs_vector = rhs.typify<FloatType>(); size_t sz = lhs_vector.size(); assert(sz == rhs_vector.size()); - double score = 1.0 - _computer.dotProduct(&lhs_vector[0], &rhs_vector[0], sz); + double score = 1.0 - _computer.dotProduct(lhs_vector.data(), rhs_vector.data(), sz); return std::max(0.0, score); } private: diff --git a/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp new file mode 100644 index 00000000000..292edc1259d --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp @@ -0,0 +1,81 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "prenormalized_angular_distance.h" +#include "temporary_vector_store.h" + +using vespalib::typify_invoke; +using vespalib::eval::TypifyCellType; + +namespace search::tensor { + +template<typename FloatType> +class BoundPrenormalizedAngularDistance : public BoundDistanceFunction { +private: + const vespalib::hwaccelrated::IAccelrated & _computer; + mutable TemporaryVectorStore<FloatType> _tmpSpace; + const vespalib::ConstArrayRef<FloatType> _lhs; + double _lhs_norm_sq; +public: + BoundPrenormalizedAngularDistance(const vespalib::eval::TypedCells& lhs) + : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()), + _tmpSpace(lhs.size), + _lhs(_tmpSpace.storeLhs(lhs)) + { + auto a = _lhs.data(); + _lhs_norm_sq = _computer.dotProduct(a, a, lhs.size); + if (_lhs_norm_sq <= 0.0) { + _lhs_norm_sq = 1.0; + } + } + double calc(const vespalib::eval::TypedCells& rhs) const override { + size_t sz = _lhs.size(); + vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs); + assert(sz == rhs_vector.size()); + auto a = _lhs.data(); + auto b = rhs_vector.data(); + double dot_product = _computer.dotProduct(a, b, sz); + double distance = _lhs_norm_sq - dot_product; + return distance; + } + double convert_threshold(double threshold) const override { + double cosine_similarity = 1.0 - threshold; + double dot_product = cosine_similarity * _lhs_norm_sq; + double distance = _lhs_norm_sq - dot_product; + return distance; + } + double to_rawscore(double distance) const override { + double dot_product = _lhs_norm_sq - distance; + double cosine_similarity = dot_product / _lhs_norm_sq; + // should be in in range [-1,1] but roundoff may cause problems: + cosine_similarity = std::min(1.0, cosine_similarity); + cosine_similarity = std::max(-1.0, cosine_similarity); + double cosine_distance = 1.0 - cosine_similarity; // in range [0,2] + double score = 1.0 / (1.0 + cosine_distance); + return score; + } + double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override { + return calc(rhs); + } +}; + +template class BoundPrenormalizedAngularDistance<float>; +template class BoundPrenormalizedAngularDistance<double>; + +template <typename FloatType> +BoundDistanceFunction::UP +PrenormalizedAngularDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) { + using DFT = BoundPrenormalizedAngularDistance<FloatType>; + return std::make_unique<DFT>(lhs); +} + +template <typename FloatType> +BoundDistanceFunction::UP +PrenormalizedAngularDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) { + using DFT = BoundPrenormalizedAngularDistance<FloatType>; + return std::make_unique<DFT>(lhs); +} + +template class PrenormalizedAngularDistanceFunctionFactory<float>; +template class PrenormalizedAngularDistanceFunctionFactory<double>; + +} diff --git a/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.h b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.h new file mode 100644 index 00000000000..88953a236e7 --- /dev/null +++ b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.h @@ -0,0 +1,27 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "distance_function.h" +#include "bound_distance_function.h" +#include "distance_function_factory.h" +#include <vespa/eval/eval/typed_cells.h> +#include <vespa/vespalib/hwaccelrated/iaccelrated.h> + +namespace search::tensor { + +/** + * Calculates inner-product "distance" between vectors with assumed norm 1. + * Should give same ordering as Angular distance, but is less expensive. + */ +template <typename FloatType> +class PrenormalizedAngularDistanceFunctionFactory : public DistanceFunctionFactory { +public: + PrenormalizedAngularDistanceFunctionFactory() + : DistanceFunctionFactory(vespalib::eval::get_cell_type<FloatType>()) + {} + BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override; + BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override; +}; + +} diff --git a/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp b/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp index 43c77398be8..b64d477fd4c 100644 --- a/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp +++ b/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp @@ -31,9 +31,11 @@ struct MockQuery { std::vector<std::unique_ptr<NearestNeighborQueryNode>> nodes; QueryTermList term_list; MockQuery& add(const vespalib::string& query_tensor_name, + uint32_t target_hits, double distance_threshold) { std::unique_ptr<QueryNodeResultBase> base; - auto node = std::make_unique<NearestNeighborQueryNode>(std::move(base), query_tensor_name, "my_tensor_field", 7, search::query::Weight(11), distance_threshold); + auto node = std::make_unique<NearestNeighborQueryNode>(std::move(base), query_tensor_name, "my_tensor_field", + target_hits, distance_threshold, 7, search::query::Weight(100)); nodes.push_back(std::move(node)); term_list.push_back(nodes.back().get()); return *this; @@ -90,34 +92,71 @@ public: query.reset(); searcher.onValue(fv); } + void expect_match(const vespalib::string& spec_expr, double exp_square_distance, const NearestNeighborQueryNode& node) { + match(spec_expr); + expect_match(exp_square_distance, node); + } void expect_match(double exp_square_distance, const NearestNeighborQueryNode& node) { double exp_raw_score = dist_func.to_rawscore(exp_square_distance); EXPECT_TRUE(node.evaluate()); + EXPECT_DOUBLE_EQ(exp_square_distance, node.get_distance().value()); EXPECT_DOUBLE_EQ(exp_raw_score, node.get_raw_score().value()); } + void expect_not_match(const vespalib::string& spec_expr, const NearestNeighborQueryNode& node) { + match(spec_expr); + EXPECT_FALSE(node.evaluate()); + } }; -TEST_F(NearestNeighborSearcherTest, raw_score_calculated_with_distance_threshold) +TEST_F(NearestNeighborSearcherTest, distance_heap_keeps_the_best_target_hits) { - query.add("qt1", 3); + query.add("qt1", 2, 100.0); + const auto& node = query.get(0); set_query_tensor("qt1", "tensor(x[2]):[1,3]"); prepare(); - match("tensor(x[2]):[1,5]"); - expect_match((5-3)*(5-3), query.get(0)); + expect_match("tensor(x[2]):[1,7]", (7-3)*(7-3), node); + expect_match("tensor(x[2]):[1,9]", (9-3)*(9-3), node); - match("tensor(x[2]):[1,6]"); - expect_match((6-3)*(6-3), query.get(0)); + // The distance limit is now (9-3)*(9-3) = 36, so this is not good enough. + expect_not_match("tensor(x[2]):[1,10]", node); + + expect_match("tensor(x[2]):[1,5]", (5-3)*(5-3), node); + + // The distance limit is now (7-3)*(7-3) = 16, so this is not good enough. + expect_not_match("tensor(x[2]):[1,8]", node); + + // This is not considered a document match as get_raw_score() is not called, + // and the distance heap is not updated. + match("tensor(x[2]):[1,4]"); + EXPECT_EQ(1, node.get_distance().value()); + EXPECT_TRUE(node.evaluate()); + + // The distance limit is still (7-3)*(7-3) = 16, so this is in fact good enough. + expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node); + + // The distance limit is (6-3)*(6-3) = 4, and a similar distance is a match. + expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node); +} + +TEST_F(NearestNeighborSearcherTest, raw_score_calculated_with_distance_threshold) +{ + query.add("qt1", 10, 3.0); + const auto& node = query.get(0); + set_query_tensor("qt1", "tensor(x[2]):[1,3]"); + prepare(); + + expect_match("tensor(x[2]):[1,5]", (5-3)*(5-3), node); + expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node); - match("tensor(x[2]):[1,7]"); // This is not a match since ((7-3)*(7-3) = 16) is larger than the internal distance threshold of (3*3 = 9). - EXPECT_FALSE(query.get(0).evaluate()); + expect_not_match("tensor(x[2]):[1,7]", node); } TEST_F(NearestNeighborSearcherTest, raw_score_calculated_for_two_query_operators) { - query.add("qt1", 3); - query.add("qt2", 4); + query.add("qt1", 10, 3.0); + query.add("qt2", 10, 4.0); set_query_tensor("qt1", "tensor(x[2]):[1,3]"); set_query_tensor("qt2", "tensor(x[2]):[1,4]"); prepare(); diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp index 9f3f3d770e4..4d425d9dedd 100644 --- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -55,6 +55,10 @@ RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder) _query_wrapper = std::make_unique<QueryWrapper>(*_query); } +class MockRawScoreCalculator : public search::streaming::NearestNeighborQueryNode::RawScoreCalculator { +public: + double to_raw_score(double distance) override { return distance * 2; } +}; TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) { @@ -71,6 +75,8 @@ TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) EXPECT_EQ(1u, term_list.size()); auto node = dynamic_cast<NearestNeighborQueryNode*>(term_list.front().getTerm()); EXPECT_NE(nullptr, node); + MockRawScoreCalculator calc; + node->set_raw_score_calc(&calc); auto& qtd = static_cast<QueryTermData &>(node->getQueryItem()); auto& td = qtd.getTermData(); constexpr TermFieldHandle handle = 27; @@ -82,11 +88,11 @@ TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node) EXPECT_EQ(invalid_id, tfmd->getDocId()); RankProcessor::unpack_match_data(1, *md, *_query_wrapper); EXPECT_EQ(invalid_id, tfmd->getDocId()); - constexpr double raw_score = 1.5; - node->set_raw_score(raw_score); + constexpr double distance = 1.5; + node->set_distance(distance); RankProcessor::unpack_match_data(2, *md, *_query_wrapper); EXPECT_EQ(2, tfmd->getDocId()); - EXPECT_EQ(raw_score, tfmd->getRawScore()); + EXPECT_EQ(distance * 2, tfmd->getRawScore()); node->reset(); RankProcessor::unpack_match_data(3, *md, *_query_wrapper); EXPECT_EQ(2, tfmd->getDocId()); diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 01b21edc1ba..3ce137bffe5 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -241,7 +241,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap for (QueryWrapper::Term & term: query.getTermList()) { auto nn_node = term.getTerm()->as_nearest_neighbor_query_node(); if (nn_node != nullptr) { - auto& raw_score = nn_node->get_raw_score(); + auto raw_score = nn_node->get_raw_score(); if (raw_score.has_value()) { auto& qtd = static_cast<QueryTermData &>(term.getTerm()->getQueryItem()); auto& td = qtd.getTermData(); diff --git a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp index f064760e55d..db4ee12438e 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp @@ -48,8 +48,17 @@ NearestNeighborFieldSearcher::NodeAndCalc::NodeAndCalc(search::streaming::Neares std::unique_ptr<search::tensor::DistanceCalculator> calc_in) : node(node_in), calc(std::move(calc_in)), - distance_threshold(calc->function().convert_threshold(node->get_distance_threshold())) + heap(node->get_target_hits()) { + node->set_raw_score_calc(this); + heap.set_distance_threshold(calc->function().convert_threshold(node->get_distance_threshold())); +} + +double +NearestNeighborFieldSearcher::NodeAndCalc::to_raw_score(double distance) +{ + heap.used(distance); + return calc->function().to_rawscore(distance); } NearestNeighborFieldSearcher::NearestNeighborFieldSearcher(FieldIdT fid, @@ -100,7 +109,7 @@ NearestNeighborFieldSearcher::prepare(search::streaming::QueryTermList& qtl, } try { auto calc = DistanceCalculator::make_with_validation(*_attr, *tensor_value); - _calcs.emplace_back(nn_term, std::move(calc)); + _calcs.push_back(std::make_unique<NodeAndCalc>(nn_term, std::move(calc))); } catch (const vespalib::IllegalArgumentException& ex) { vespalib::Issue::report("Could not create DistanceCalculator for NearestNeighborQueryNode(%s, %s): %s", nn_term->index().c_str(), nn_term->get_query_tensor_name().c_str(), ex.what()); @@ -116,10 +125,10 @@ NearestNeighborFieldSearcher::onValue(const document::FieldValue& fv) if (tfv && tfv->getAsTensorPtr()) { _attr->add(*tfv->getAsTensorPtr(), 1); for (auto& elem : _calcs) { - double distance = elem.calc->calc_with_limit(scratch_docid, elem.distance_threshold); - if (distance <= elem.distance_threshold) { - double score = elem.calc->function().to_rawscore(distance); - elem.node->set_raw_score(score); + double distance_limit = elem->heap.distanceLimit(); + double distance = elem->calc->calc_with_limit(scratch_docid, distance_limit); + if (distance <= distance_limit) { + elem->node->set_distance(distance); } } } diff --git a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h index ba39b91c677..d5d751cd637 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h +++ b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h @@ -5,6 +5,8 @@ #include "fieldsearcher.h" #include <vespa/eval/eval/value_type.h> #include <vespa/searchcommon/attribute/distance_metric.h> +#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> +#include <vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h> #include <vespa/searchlib/tensor/distance_calculator.h> #include <vespa/searchlib/tensor/tensor_ext_attribute.h> @@ -14,8 +16,6 @@ namespace search::tensor { class TensorExtAttribute; } -namespace search::streaming { class NearestNeighborQueryNode; } - namespace vsm { /** @@ -26,16 +26,19 @@ namespace vsm { */ class NearestNeighborFieldSearcher : public FieldSearcher { private: - struct NodeAndCalc { + class NodeAndCalc : search::streaming::NearestNeighborQueryNode::RawScoreCalculator { + public: search::streaming::NearestNeighborQueryNode* node; std::unique_ptr<search::tensor::DistanceCalculator> calc; - double distance_threshold; + search::queryeval::NearestNeighborDistanceHeap heap; NodeAndCalc(search::streaming::NearestNeighborQueryNode* node_in, std::unique_ptr<search::tensor::DistanceCalculator> calc_in); + + double to_raw_score(double distance) override; }; search::attribute::DistanceMetric _metric; std::unique_ptr<search::tensor::TensorExtAttribute> _attr; - std::vector<NodeAndCalc> _calcs; + std::vector<std::unique_ptr<NodeAndCalc>> _calcs; public: NearestNeighborFieldSearcher(FieldIdT fid, diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/DefaultSignedIdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/DefaultSignedIdentityDocument.java new file mode 100644 index 00000000000..c2ab22f4921 --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/DefaultSignedIdentityDocument.java @@ -0,0 +1,14 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.identityprovider.api; + +public record DefaultSignedIdentityDocument(String signature, int signingKeyVersion, int documentVersion, + String data, IdentityDocument identityDocument) implements SignedIdentityDocument { + + public DefaultSignedIdentityDocument { + identityDocument = EntityBindingsMapper.fromIdentityDocumentData(data); + } + + public DefaultSignedIdentityDocument(String signature, int signingKeyVersion, int documentVersion, String data) { + this(signature,signingKeyVersion,documentVersion, data, null); + } +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java index 2d77d2ceda1..a695e10a29c 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java @@ -6,8 +6,12 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule; import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzService; +import com.yahoo.vespa.athenz.identityprovider.api.bindings.DefaultSignedIdentityDocumentEntity; +import com.yahoo.vespa.athenz.identityprovider.api.bindings.IdentityDocumentEntity; +import com.yahoo.vespa.athenz.identityprovider.api.bindings.LegacySignedIdentityDocumentEntity; import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocumentEntity; import com.yahoo.vespa.athenz.utils.AthenzIdentities; +import com.yahoo.yolean.Exceptions; import java.io.IOException; import java.io.InputStream; @@ -16,6 +20,8 @@ import java.io.UncheckedIOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardCopyOption; +import java.time.Instant; +import java.util.Base64; import java.util.Optional; import static com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId.fromDottedString; @@ -24,6 +30,7 @@ import static com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId. * Utility class for mapping objects model types and their Jackson binding versions. * * @author bjorncs + * @author mortent */ public class EntityBindingsMapper { @@ -48,39 +55,60 @@ public class EntityBindingsMapper { } public static SignedIdentityDocument toSignedIdentityDocument(SignedIdentityDocumentEntity entity) { - return new SignedIdentityDocument( - entity.signature(), - entity.signingKeyVersion(), - fromDottedString(entity.providerUniqueId()), - new AthenzService(entity.providerService()), - entity.documentVersion(), - entity.configServerHostname(), - entity.instanceHostname(), - entity.createdAt(), - entity.ipAddresses(), - IdentityType.fromId(entity.identityType()), - Optional.ofNullable(entity.clusterType()).map(ClusterType::from).orElse(null), - entity.ztsUrl(), - Optional.ofNullable(entity.serviceIdentity()).map(AthenzIdentities::from).orElse(null), - entity.unknownAttributes()); + if (entity instanceof LegacySignedIdentityDocumentEntity docEntity) { + IdentityDocument doc = new IdentityDocument( + fromDottedString(docEntity.providerUniqueId()), + new AthenzService(docEntity.providerService()), + docEntity.configServerHostname(), + docEntity.instanceHostname(), + docEntity.createdAt(), + docEntity.ipAddresses(), + IdentityType.fromId(docEntity.identityType()), + Optional.ofNullable(docEntity.clusterType()).map(ClusterType::from).orElse(null), + docEntity.ztsUrl(), + Optional.ofNullable(docEntity.serviceIdentity()).map(AthenzIdentities::from).orElse(null), + docEntity.unknownAttributes()); + return new LegacySignedIdentityDocument( + docEntity.signature(), + docEntity.signingKeyVersion(), + entity.documentVersion(), + doc); + } else if (entity instanceof DefaultSignedIdentityDocumentEntity docEntity) { + return new DefaultSignedIdentityDocument(docEntity.signature(), + docEntity.signingKeyVersion(), + docEntity.documentVersion(), + docEntity.data()); + } else { + throw new IllegalArgumentException("Unknown signed identity document type: " + entity.getClass().getName()); + } } public static SignedIdentityDocumentEntity toSignedIdentityDocumentEntity(SignedIdentityDocument model) { - return new SignedIdentityDocumentEntity( - model.signature(), - model.signingKeyVersion(), - model.providerUniqueId().asDottedString(), - model.providerService().getFullName(), - model.documentVersion(), - model.configServerHostname(), - model.instanceHostname(), - model.createdAt(), - model.ipAddresses(), - model.identityType().id(), - Optional.ofNullable(model.clusterType()).map(ClusterType::toConfigValue).orElse(null), - model.ztsUrl(), - Optional.ofNullable(model.serviceIdentity()).map(AthenzIdentity::getFullName).orElse(null), - model.unknownAttributes()); + if (model instanceof LegacySignedIdentityDocument legacyModel) { + IdentityDocument idDoc = legacyModel.identityDocument(); + return new LegacySignedIdentityDocumentEntity( + legacyModel.signature(), + legacyModel.signingKeyVersion(), + idDoc.providerUniqueId().asDottedString(), + idDoc.providerService().getFullName(), + legacyModel.documentVersion(), + idDoc.configServerHostname(), + idDoc.instanceHostname(), + idDoc.createdAt(), + idDoc.ipAddresses(), + idDoc.identityType().id(), + Optional.ofNullable(idDoc.clusterType()).map(ClusterType::toConfigValue).orElse(null), + idDoc.ztsUrl(), + Optional.ofNullable(idDoc.serviceIdentity()).map(AthenzIdentity::getFullName).orElse(null), + idDoc.unknownAttributes()); + } else if (model instanceof DefaultSignedIdentityDocument defaultModel){ + return new DefaultSignedIdentityDocumentEntity(defaultModel.signature(), + defaultModel.signingKeyVersion(), + defaultModel.documentVersion(), + defaultModel.data()); + } else { + throw new IllegalArgumentException("Unsupported model type: " + model.getClass().getName()); + } } public static SignedIdentityDocument readSignedIdentityDocumentFromFile(Path file) { @@ -104,4 +132,40 @@ public class EntityBindingsMapper { } } + public static IdentityDocument fromIdentityDocumentData(String data) { + byte[] decoded = Base64.getDecoder().decode(data); + IdentityDocumentEntity docEntity = Exceptions.uncheck(() -> mapper.readValue(decoded, IdentityDocumentEntity.class)); + return new IdentityDocument( + fromDottedString(docEntity.providerUniqueId()), + new AthenzService(docEntity.providerService()), + docEntity.configServerHostname(), + docEntity.instanceHostname(), + docEntity.createdAt(), + docEntity.ipAddresses(), + IdentityType.fromId(docEntity.identityType()), + Optional.ofNullable(docEntity.clusterType()).map(ClusterType::from).orElse(null), + docEntity.ztsUrl(), + Optional.ofNullable(docEntity.serviceIdentity()).map(AthenzIdentities::from).orElse(null), + docEntity.unknownAttributes()); + } + + public static String toIdentityDocmentData(IdentityDocument identityDocument) { + IdentityDocumentEntity documentEntity = new IdentityDocumentEntity( + identityDocument.providerUniqueId().asDottedString(), + identityDocument.providerService().getFullName(), + identityDocument.configServerHostname(), + identityDocument.instanceHostname(), + identityDocument.createdAt(), + identityDocument.ipAddresses(), + identityDocument.identityType().id(), + Optional.ofNullable(identityDocument.clusterType()).map(ClusterType::toConfigValue).orElse(null), + identityDocument.ztsUrl(), + identityDocument.serviceIdentity().getFullName()); + try { + byte[] bytes = mapper.writeValueAsBytes(documentEntity); + return Base64.getEncoder().encodeToString(bytes); + } catch (JsonProcessingException e) { + throw new RuntimeException("Error during serialization of identity document.", e); + } + } } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java new file mode 100644 index 00000000000..577584db185 --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java @@ -0,0 +1,54 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.identityprovider.api; + +import com.yahoo.vespa.athenz.api.AthenzIdentity; +import com.yahoo.vespa.athenz.api.AthenzService; + +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Represents an unsigned identity document + * @author mortent + */ +public record IdentityDocument(VespaUniqueInstanceId providerUniqueId, AthenzService providerService, String configServerHostname, + String instanceHostname, Instant createdAt, Set<String> ipAddresses, + IdentityType identityType, ClusterType clusterType, String ztsUrl, + AthenzIdentity serviceIdentity, Map<String, Object> unknownAttributes) { + + public IdentityDocument { + ipAddresses = Set.copyOf(ipAddresses); + + Map<String, Object> nonNull = new HashMap<>(); + unknownAttributes.forEach((key, value) -> { + if (value != null) nonNull.put(key, value); + }); + // Map.copyOf() does not allow null values + unknownAttributes = Map.copyOf(nonNull); + } + + public IdentityDocument(VespaUniqueInstanceId providerUniqueId, AthenzService providerService, String configServerHostname, + String instanceHostname, Instant createdAt, Set<String> ipAddresses, + IdentityType identityType, ClusterType clusterType, String ztsUrl, + AthenzIdentity serviceIdentity) { + this(providerUniqueId, providerService, configServerHostname, instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity, Map.of()); + } + + + public IdentityDocument withServiceIdentity(AthenzService athenzService) { + return new IdentityDocument( + this.providerUniqueId, + this.providerService, + this.configServerHostname, + this.instanceHostname, + this.createdAt, + this.ipAddresses, + this.identityType, + this.clusterType, + this.ztsUrl, + athenzService, + this.unknownAttributes); + } +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocumentClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocumentClient.java index 5a0f77ec765..a3c2f0264d3 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocumentClient.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocumentClient.java @@ -1,12 +1,15 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.identityprovider.api; +import java.util.Optional; +import java.util.OptionalInt; + /** * A client that communicates that fetches an identity document. * * @author bjorncs */ public interface IdentityDocumentClient { - SignedIdentityDocument getNodeIdentityDocument(String host); - SignedIdentityDocument getTenantIdentityDocument(String host); + SignedIdentityDocument getNodeIdentityDocument(String host, int documentVersion); + Optional<SignedIdentityDocument> getTenantIdentityDocument(String host, int documentVersion); } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/LegacySignedIdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/LegacySignedIdentityDocument.java new file mode 100644 index 00000000000..220bc72a017 --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/LegacySignedIdentityDocument.java @@ -0,0 +1,6 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.identityprovider.api; + +public record LegacySignedIdentityDocument(String signature, int signingKeyVersion, int documentVersion, + IdentityDocument identityDocument) implements SignedIdentityDocument { +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java index de78d81cd1b..4e3bd8dee91 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java @@ -1,54 +1,20 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.identityprovider.api; -import com.yahoo.vespa.athenz.api.AthenzIdentity; -import com.yahoo.vespa.athenz.api.AthenzService; - -import java.net.URL; -import java.time.Instant; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - /** * A signed identity document. - * The {@link #unknownAttributes()} member provides forward compatibility and ensures any new/unknown fields are kept intact when serialized to JSON. - * * @author bjorncs + * @author mortent */ -public record SignedIdentityDocument(String signature, int signingKeyVersion, VespaUniqueInstanceId providerUniqueId, - AthenzService providerService, int documentVersion, String configServerHostname, - String instanceHostname, Instant createdAt, Set<String> ipAddresses, - IdentityType identityType, ClusterType clusterType, String ztsUrl, - AthenzIdentity serviceIdentity, Map<String, Object> unknownAttributes) { - - public SignedIdentityDocument { - ipAddresses = Set.copyOf(ipAddresses); - - Map<String, Object> nonNull = new HashMap<>(); - unknownAttributes.forEach((key, value) -> { - if (value != null) nonNull.put(key, value); - }); - // Map.copyOf() does not allow null values - unknownAttributes = Map.copyOf(nonNull); - } - - public SignedIdentityDocument(String signature, int signingKeyVersion, VespaUniqueInstanceId providerUniqueId, - AthenzService providerService, int documentVersion, String configServerHostname, - String instanceHostname, Instant createdAt, Set<String> ipAddresses, - IdentityType identityType, ClusterType clusterType, String ztsUrl, AthenzIdentity serviceIdentity) { - this(signature, signingKeyVersion, providerUniqueId, providerService, documentVersion, configServerHostname, - instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity, Map.of()); - } - - public static final int DEFAULT_DOCUMENT_VERSION = 3; - - public boolean outdated() { return documentVersion < DEFAULT_DOCUMENT_VERSION; } +public interface SignedIdentityDocument { - public SignedIdentityDocument withServiceIdentity(AthenzIdentity identity) { - return new SignedIdentityDocument(signature, signingKeyVersion, providerUniqueId, providerService, documentVersion, configServerHostname, instanceHostname, createdAt, - ipAddresses, identityType, clusterType, ztsUrl, identity); - } + int LEGACY_DEFAULT_DOCUMENT_VERSION = 3; + int DEFAULT_DOCUMENT_VERSION = 4; + default boolean outdated() { return documentVersion() < LEGACY_DEFAULT_DOCUMENT_VERSION; } + IdentityDocument identityDocument(); + String signature(); + int signingKeyVersion(); + int documentVersion(); } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/DefaultSignedIdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/DefaultSignedIdentityDocumentEntity.java new file mode 100644 index 00000000000..3aaff011415 --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/DefaultSignedIdentityDocumentEntity.java @@ -0,0 +1,12 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.identityprovider.api.bindings; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public record DefaultSignedIdentityDocumentEntity( + @JsonProperty("signature") String signature, + @JsonProperty("signing-key-version") int signingKeyVersion, + @JsonProperty("document-version") int documentVersion, + @JsonProperty("data") String data) + implements SignedIdentityDocumentEntity { +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java new file mode 100644 index 00000000000..946eacc67eb --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java @@ -0,0 +1,51 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.identityprovider.api.bindings; + +import com.fasterxml.jackson.annotation.JsonAnyGetter; +import com.fasterxml.jackson.annotation.JsonAnySetter; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * @author bjorncs + * @author mortent + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public record IdentityDocumentEntity(String providerUniqueId, String providerService, + String configServerHostname, String instanceHostname, Instant createdAt, Set<String> ipAddresses, + String identityType, String clusterType, String ztsUrl, String serviceIdentity, Map<String, Object> unknownAttributes) { + + @JsonCreator + public IdentityDocumentEntity(@JsonProperty("provider-unique-id") String providerUniqueId, + @JsonProperty("provider-service") String providerService, + @JsonProperty("configserver-hostname") String configServerHostname, + @JsonProperty("instance-hostname") String instanceHostname, + @JsonProperty("created-at") Instant createdAt, + @JsonProperty("ip-addresses") Set<String> ipAddresses, + @JsonProperty("identity-type") String identityType, + @JsonProperty("cluster-type") String clusterType, + @JsonProperty("zts-url") String ztsUrl, + @JsonProperty("service-identity") String serviceIdentity) { + this(providerUniqueId, providerService, configServerHostname, + instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity, new HashMap<>()); + } + + @JsonProperty("provider-unique-id") @Override public String providerUniqueId() { return providerUniqueId; } + @JsonProperty("provider-service") @Override public String providerService() { return providerService; } + @JsonProperty("configserver-hostname") @Override public String configServerHostname() { return configServerHostname; } + @JsonProperty("instance-hostname") @Override public String instanceHostname() { return instanceHostname; } + @JsonProperty("created-at") @Override public Instant createdAt() { return createdAt; } + @JsonProperty("ip-addresses") @Override public Set<String> ipAddresses() { return ipAddresses; } + @JsonProperty("identity-type") @Override public String identityType() { return identityType; } + @JsonProperty("cluster-type") @Override public String clusterType() { return clusterType; } + @JsonProperty("zts-url") @Override public String ztsUrl() { return ztsUrl; } + @JsonProperty("service-identity") @Override public String serviceIdentity() { return serviceIdentity; } + @JsonAnyGetter @Override public Map<String, Object> unknownAttributes() { return unknownAttributes; } + @JsonAnySetter public void set(String name, Object value) { unknownAttributes.put(name, value); } +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/LegacySignedIdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/LegacySignedIdentityDocumentEntity.java new file mode 100644 index 00000000000..e00ab9978f6 --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/LegacySignedIdentityDocumentEntity.java @@ -0,0 +1,57 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.identityprovider.api.bindings; + +import com.fasterxml.jackson.annotation.JsonAnyGetter; +import com.fasterxml.jackson.annotation.JsonAnySetter; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.time.Instant; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * @author bjorncs + */ +@JsonInclude(JsonInclude.Include.NON_NULL) +public record LegacySignedIdentityDocumentEntity ( + String signature, int signingKeyVersion, String providerUniqueId, String providerService, int documentVersion, + String configServerHostname, String instanceHostname, Instant createdAt, Set<String> ipAddresses, + String identityType, String clusterType, String ztsUrl, String serviceIdentity, Map<String, Object> unknownAttributes) implements SignedIdentityDocumentEntity { + + @JsonCreator + public LegacySignedIdentityDocumentEntity(@JsonProperty("signature") String signature, + @JsonProperty("signing-key-version") int signingKeyVersion, + @JsonProperty("provider-unique-id") String providerUniqueId, + @JsonProperty("provider-service") String providerService, + @JsonProperty("document-version") int documentVersion, + @JsonProperty("configserver-hostname") String configServerHostname, + @JsonProperty("instance-hostname") String instanceHostname, + @JsonProperty("created-at") Instant createdAt, + @JsonProperty("ip-addresses") Set<String> ipAddresses, + @JsonProperty("identity-type") String identityType, + @JsonProperty("cluster-type") String clusterType, + @JsonProperty("zts-url") String ztsUrl, + @JsonProperty("service-identity") String serviceIdentity) { + this(signature, signingKeyVersion, providerUniqueId, providerService, documentVersion, configServerHostname, + instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity, new HashMap<>()); + } + + @JsonProperty("signature") @Override public String signature() { return signature; } + @JsonProperty("signing-key-version") @Override public int signingKeyVersion() { return signingKeyVersion; } + @JsonProperty("provider-unique-id") @Override public String providerUniqueId() { return providerUniqueId; } + @JsonProperty("provider-service") @Override public String providerService() { return providerService; } + @JsonProperty("document-version") @Override public int documentVersion() { return documentVersion; } + @JsonProperty("configserver-hostname") @Override public String configServerHostname() { return configServerHostname; } + @JsonProperty("instance-hostname") @Override public String instanceHostname() { return instanceHostname; } + @JsonProperty("created-at") @Override public Instant createdAt() { return createdAt; } + @JsonProperty("ip-addresses") @Override public Set<String> ipAddresses() { return ipAddresses; } + @JsonProperty("identity-type") @Override public String identityType() { return identityType; } + @JsonProperty("cluster-type") @Override public String clusterType() { return clusterType; } + @JsonProperty("zts-url") @Override public String ztsUrl() { return ztsUrl; } + @JsonProperty("service-identity") @Override public String serviceIdentity() { return serviceIdentity; } + @JsonAnyGetter @Override public Map<String, Object> unknownAttributes() { return unknownAttributes; } + @JsonAnySetter public void set(String name, Object value) { unknownAttributes.put(name, value); } +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java index fc0dff3b97b..174c76f7fa9 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java @@ -1,57 +1,77 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.athenz.identityprovider.api.bindings; -import com.fasterxml.jackson.annotation.JsonAnyGetter; -import com.fasterxml.jackson.annotation.JsonAnySetter; -import com.fasterxml.jackson.annotation.JsonCreator; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonProperty; - -import java.time.Instant; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * @author bjorncs - */ -@JsonInclude(JsonInclude.Include.NON_NULL) -public record SignedIdentityDocumentEntity( - String signature, int signingKeyVersion, String providerUniqueId, String providerService, int documentVersion, - String configServerHostname, String instanceHostname, Instant createdAt, Set<String> ipAddresses, - String identityType, String clusterType, String ztsUrl, String serviceIdentity, Map<String, Object> unknownAttributes) { - - @JsonCreator - public SignedIdentityDocumentEntity(@JsonProperty("signature") String signature, - @JsonProperty("signing-key-version") int signingKeyVersion, - @JsonProperty("provider-unique-id") String providerUniqueId, - @JsonProperty("provider-service") String providerService, - @JsonProperty("document-version") int documentVersion, - @JsonProperty("configserver-hostname") String configServerHostname, - @JsonProperty("instance-hostname") String instanceHostname, - @JsonProperty("created-at") Instant createdAt, - @JsonProperty("ip-addresses") Set<String> ipAddresses, - @JsonProperty("identity-type") String identityType, - @JsonProperty("cluster-type") String clusterType, - @JsonProperty("zts-url") String ztsUrl, - @JsonProperty("service-identity") String serviceIdentity) { - this(signature, signingKeyVersion, providerUniqueId, providerService, documentVersion, configServerHostname, - instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity, new HashMap<>()); - } - @JsonProperty("signature") @Override public String signature() { return signature; } - @JsonProperty("signing-key-version") @Override public int signingKeyVersion() { return signingKeyVersion; } - @JsonProperty("provider-unique-id") @Override public String providerUniqueId() { return providerUniqueId; } - @JsonProperty("provider-service") @Override public String providerService() { return providerService; } - @JsonProperty("document-version") @Override public int documentVersion() { return documentVersion; } - @JsonProperty("configserver-hostname") @Override public String configServerHostname() { return configServerHostname; } - @JsonProperty("instance-hostname") @Override public String instanceHostname() { return instanceHostname; } - @JsonProperty("created-at") @Override public Instant createdAt() { return createdAt; } - @JsonProperty("ip-addresses") @Override public Set<String> ipAddresses() { return ipAddresses; } - @JsonProperty("identity-type") @Override public String identityType() { return identityType; } - @JsonProperty("cluster-type") @Override public String clusterType() { return clusterType; } - @JsonProperty("zts-url") @Override public String ztsUrl() { return ztsUrl; } - @JsonProperty("service-identity") @Override public String serviceIdentity() { return serviceIdentity; } - @JsonAnyGetter @Override public Map<String, Object> unknownAttributes() { return unknownAttributes; } - @JsonAnySetter public void set(String name, Object value) { unknownAttributes.put(name, value); } +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonTypeInfo; +import com.fasterxml.jackson.databind.DatabindContext; +import com.fasterxml.jackson.databind.JavaType; +import com.fasterxml.jackson.databind.annotation.JsonTypeIdResolver; +import com.fasterxml.jackson.databind.jsontype.TypeIdResolver; +import com.fasterxml.jackson.databind.type.TypeFactory; +import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; + +import java.io.IOException; +import java.util.Objects; + +@JsonIgnoreProperties(ignoreUnknown = true) +@JsonTypeInfo(use = JsonTypeInfo.Id.CUSTOM, property = "document-version", visible = true) +@JsonTypeIdResolver(SignedIdentityDocumentEntityTypeResolver.class) +public interface SignedIdentityDocumentEntity { + int documentVersion(); } + +class SignedIdentityDocumentEntityTypeResolver implements TypeIdResolver { + JavaType javaType; + + @Override + public void init(JavaType javaType) { + this.javaType = javaType; + } + + @Override + public String idFromValue(Object o) { + return idFromValueAndType(o, o.getClass()); + } + + @Override + public String idFromValueAndType(Object o, Class<?> aClass) { + if (Objects.isNull(o)) { + throw new IllegalArgumentException("Cannot serialize null oject"); + } else { + if (o instanceof SignedIdentityDocumentEntity s) { + return Integer.toString(s.documentVersion()); + } else { + throw new IllegalArgumentException("Cannot serialize class: " + o.getClass()); + } + } + } + + @Override + public String idFromBaseType() { + return idFromValueAndType(null, javaType.getRawClass()); + } + + @Override + public JavaType typeFromId(DatabindContext databindContext, String s) throws IOException { + try { + int version = Integer.parseInt(s); + Class<? extends SignedIdentityDocumentEntity> cls = version <= SignedIdentityDocument.LEGACY_DEFAULT_DOCUMENT_VERSION + ? LegacySignedIdentityDocumentEntity.class + : DefaultSignedIdentityDocumentEntity.class; + return TypeFactory.defaultInstance().constructSpecializedType(javaType,cls); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Unable to deserialize document with version: \"%s\"".formatted(s)); + } + } + + @Override + public String getDescForKnownTypeIds() { + return "Type resolver for SignedIdentityDocumentEntity"; + } + + @Override + public JsonTypeInfo.Id getMechanism() { + return JsonTypeInfo.Id.CUSTOM; + } +}
\ No newline at end of file diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java index cc9d3b2be65..d26386702d5 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java @@ -11,6 +11,7 @@ import com.yahoo.vespa.athenz.client.zts.InstanceIdentity; import com.yahoo.vespa.athenz.client.zts.ZtsClient; import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; +import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocumentClient; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier; @@ -74,7 +75,9 @@ class AthenzCredentialsService { } KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); IdentityDocumentClient identityDocumentClient = createIdentityDocumentClient(); - SignedIdentityDocument document = identityDocumentClient.getTenantIdentityDocument(hostname); + // Use legacy version for now. + SignedIdentityDocument signedDocument = identityDocumentClient.getTenantIdentityDocument(hostname, SignedIdentityDocument.LEGACY_DEFAULT_DOCUMENT_VERSION).orElseThrow(); + IdentityDocument document = signedDocument.identityDocument(); Pkcs10Csr csr = csrGenerator.generateInstanceCsr( tenantIdentity, document.providerUniqueId(), @@ -87,16 +90,17 @@ class AthenzCredentialsService { ztsClient.registerInstance( configserverIdentity, tenantIdentity, - EntityBindingsMapper.toAttestationData(document), + EntityBindingsMapper.toAttestationData(signedDocument), csr); X509Certificate certificate = instanceIdentity.certificate(); - writeCredentialsToDisk(keyPair.getPrivate(), certificate, document); - return new AthenzCredentials(certificate, keyPair, document); + writeCredentialsToDisk(keyPair.getPrivate(), certificate, signedDocument); + return new AthenzCredentials(certificate, keyPair, signedDocument); } } - AthenzCredentials updateCredentials(SignedIdentityDocument document, SSLContext sslContext) { + AthenzCredentials updateCredentials(SignedIdentityDocument signedDocument, SSLContext sslContext) { KeyPair newKeyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA); + IdentityDocument document = signedDocument.identityDocument(); Pkcs10Csr csr = csrGenerator.generateInstanceCsr( tenantIdentity, document.providerUniqueId(), @@ -112,8 +116,8 @@ class AthenzCredentialsService { document.providerUniqueId().asDottedString(), csr); X509Certificate certificate = instanceIdentity.certificate(); - writeCredentialsToDisk(newKeyPair.getPrivate(), certificate, document); - return new AthenzCredentials(certificate, newKeyPair, document); + writeCredentialsToDisk(newKeyPair.getPrivate(), certificate, signedDocument); + return new AthenzCredentials(certificate, newKeyPair, signedDocument); } } diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java index b9f9f3862c2..77aaf17419d 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java @@ -11,7 +11,9 @@ import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider; import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException; import com.yahoo.jdisc.Metric; import com.yahoo.metrics.ContainerMetrics; +import com.yahoo.security.AutoReloadingX509KeyManager; import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyUtils; import com.yahoo.security.MutableX509KeyManager; import com.yahoo.security.Pkcs10Csr; import com.yahoo.security.SslContextBuilder; @@ -24,9 +26,9 @@ import com.yahoo.vespa.athenz.api.ZToken; import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; import com.yahoo.vespa.athenz.client.zts.ZtsClient; import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; -import com.yahoo.vespa.athenz.identity.SiaIdentityProvider; +import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; +import com.yahoo.vespa.athenz.tls.AthenzX509CertificateUtils; import com.yahoo.vespa.athenz.utils.SiaUtils; -import com.yahoo.vespa.defaults.Defaults; import javax.net.ssl.SSLContext; import javax.net.ssl.X509ExtendedKeyManager; @@ -38,7 +40,6 @@ import java.security.cert.X509Certificate; import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -65,7 +66,6 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen // TODO Make some of these values configurable through config. Match requested expiration of register/update requests. // TODO These should match the requested expiration - static final Duration UPDATE_PERIOD = Duration.ofDays(1); static final Duration AWAIT_TERMINTATION_TIMEOUT = Duration.ofSeconds(90); private final static Duration ROLE_SSL_CONTEXT_EXPIRY = Duration.ofHours(2); // TODO CMS expects 10min or less token ttl. Use 10min default until we have configurable expiry @@ -73,20 +73,17 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen // TODO Make path to trust store paths config private static final Path CLIENT_TRUST_STORE = Paths.get("/opt/yahoo/share/ssl/certs/yahoo_certificate_bundle.pem"); - private static final Path ATHENZ_TRUST_STORE = Paths.get("/opt/yahoo/share/ssl/certs/athenz_certificate_bundle.pem"); public static final String CERTIFICATE_EXPIRY_METRIC_NAME = ContainerMetrics.ATHENZ_TENANT_CERT_EXPIRY_SECONDS.baseName(); - private volatile AthenzCredentials credentials; private final Metric metric; private final Path trustStore; - private final AthenzCredentialsService athenzCredentialsService; private final ScheduledExecutorService scheduler; private final Clock clock; private final AthenzService identity; private final URI ztsEndpoint; - private final MutableX509KeyManager identityKeyManager = new MutableX509KeyManager(); + private final AutoReloadingX509KeyManager autoReloadingX509KeyManager; private final SSLContext identitySslContext; private final LoadingCache<AthenzRole, X509Certificate> roleSslCertCache; private final Map<AthenzRole, MutableX509KeyManager> roleKeyManagerCache; @@ -98,40 +95,32 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen @Inject public AthenzIdentityProviderImpl(IdentityConfig config, Metric metric) { - this(config, - metric, - CLIENT_TRUST_STORE, - new AthenzCredentialsService(config, - createNodeIdentityProvider(config), - Defaults.getDefaults().vespaHostname(), - Clock.systemUTC()), - new ScheduledThreadPoolExecutor(1), - Clock.systemUTC()); + this(config, metric, CLIENT_TRUST_STORE, new ScheduledThreadPoolExecutor(1), Clock.systemUTC(), createAutoReloadingX509KeyManager(config)); } // Test only AthenzIdentityProviderImpl(IdentityConfig config, Metric metric, Path trustStore, - AthenzCredentialsService athenzCredentialsService, ScheduledExecutorService scheduler, - Clock clock) { + Clock clock, + AutoReloadingX509KeyManager autoReloadingX509KeyManager) { this.metric = metric; this.trustStore = trustStore; - this.athenzCredentialsService = athenzCredentialsService; this.scheduler = scheduler; this.clock = clock; this.identity = new AthenzService(config.domain(), config.service()); this.ztsEndpoint = URI.create(config.ztsUrl()); - roleSslCertCache = crateAutoReloadableCache(ROLE_SSL_CONTEXT_EXPIRY, this::requestRoleCertificate, this.scheduler); - roleKeyManagerCache = new HashMap<>(); - roleSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken); - domainSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken); - domainSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken); - roleSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken); + this.roleSslCertCache = crateAutoReloadableCache(ROLE_SSL_CONTEXT_EXPIRY, this::requestRoleCertificate, this.scheduler); + this.roleKeyManagerCache = new HashMap<>(); + this.roleSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken); + this.domainSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken); + this.domainSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken); + this.roleSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken); this.csrGenerator = new CsrGenerator(config.athenzDnsSuffix(), config.configserverIdentityName()); - this.identitySslContext = createIdentitySslContext(identityKeyManager, trustStore); - registerInstance(); + this.autoReloadingX509KeyManager = autoReloadingX509KeyManager; + this.identitySslContext = createIdentitySslContext(autoReloadingX509KeyManager, trustStore); + this.scheduler.scheduleAtFixedRate(this::reportMetrics, 0, 5, TimeUnit.MINUTES); } private static <KEY, VALUE> LoadingCache<KEY, VALUE> createCache(Duration expiry, Function<KEY, VALUE> cacheLoader) { @@ -165,16 +154,6 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen .build(); } - private void registerInstance() { - try { - updateIdentityCredentials(this.athenzCredentialsService.registerInstance()); - this.scheduler.scheduleAtFixedRate(this::refreshCertificate, UPDATE_PERIOD.toMinutes(), UPDATE_PERIOD.toMinutes(), TimeUnit.MINUTES); - this.scheduler.scheduleAtFixedRate(this::reportMetrics, 0, 5, TimeUnit.MINUTES); - } catch (Throwable t) { - throw new AthenzIdentityProviderException("Could not retrieve Athenz credentials", t); - } - } - @Override public AthenzService identity() { return identity; @@ -197,13 +176,13 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen @Override public X509CertificateWithKey getIdentityCertificateWithKey() { - AthenzCredentials copy = this.credentials; - return new X509CertificateWithKey(copy.getCertificate(), copy.getKeyPair().getPrivate()); + var copy = this.autoReloadingX509KeyManager.getCurrentCertificateWithKey(); + return new X509CertificateWithKey(copy.certificate(), copy.privateKey()); } - @Override public Path certificatePath() { return athenzCredentialsService.certificatePath(); } + @Override public Path certificatePath() { return SiaUtils.getCertificateFile(identity); } - @Override public Path privateKeyPath() { return athenzCredentialsService.privateKeyPath(); } + @Override public Path privateKeyPath() { return SiaUtils.getPrivateKeyFile(identity); } @Override public SSLContext getRoleSslContext(String domain, String role) { @@ -262,7 +241,7 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen @Override public PrivateKey getPrivateKey() { - return credentials.getKeyPair().getPrivate(); + return autoReloadingX509KeyManager.getPrivateKey(AutoReloadingX509KeyManager.CERTIFICATE_ALIAS); } @Override @@ -272,7 +251,7 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen @Override public List<X509Certificate> getIdentityCertificate() { - return Collections.singletonList(credentials.getCertificate()); + return List.of(autoReloadingX509KeyManager.getCertificateChain(AutoReloadingX509KeyManager.CERTIFICATE_ALIAS)); } @Override @@ -288,19 +267,15 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen } } - private void updateIdentityCredentials(AthenzCredentials credentials) { - this.credentials = credentials; - this.identityKeyManager.updateKeystore( - KeyStoreBuilder.withType(PKCS12) - .withKeyEntry("default", credentials.getKeyPair().getPrivate(), credentials.getCertificate()) - .build(), - new char[0]); - } - private X509Certificate requestRoleCertificate(AthenzRole role) { - var doc = credentials.getIdentityDocument(); + var credentials = autoReloadingX509KeyManager.getCurrentCertificateWithKey(); + var athenzUniqueInstanceId = VespaUniqueInstanceId.fromDottedString( + AthenzX509CertificateUtils.getInstanceId(credentials.certificate()) + .orElseThrow() + ); + var keyPair = KeyUtils.toKeyPair(credentials.privateKey()); Pkcs10Csr csr = csrGenerator.generateRoleCsr( - identity, role, doc.providerUniqueId(), doc.clusterType(), credentials.getKeyPair()); + identity, role, athenzUniqueInstanceId, null, keyPair); try (ZtsClient client = createZtsClient()) { X509Certificate roleCertificate = client.getRoleCertificate(role, csr); updateRoleKeyManager(role, roleCertificate); @@ -313,7 +288,7 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen MutableX509KeyManager keyManager = roleKeyManagerCache.computeIfAbsent(role, r -> new MutableX509KeyManager()); keyManager.updateKeystore( KeyStoreBuilder.withType(PKCS12) - .withKeyEntry("default", credentials.getKeyPair().getPrivate(), certificate) + .withKeyEntry("default", autoReloadingX509KeyManager.getCurrentCertificateWithKey().privateKey(), certificate) .build(), new char[0]); } @@ -346,6 +321,11 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen return new DefaultZtsClient.Builder(ztsEndpoint).withSslContext(getIdentitySslContext()).build(); } + private static AutoReloadingX509KeyManager createAutoReloadingX509KeyManager(IdentityConfig config) { + var tenantIdentity = new AthenzService(config.domain(), config.service()); + return AutoReloadingX509KeyManager.fromPemFiles(SiaUtils.getPrivateKeyFile(tenantIdentity), SiaUtils.getCertificateFile(tenantIdentity)); + } + @Override public void deconstruct() { try { @@ -356,32 +336,13 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen } } - private static SiaIdentityProvider createNodeIdentityProvider(IdentityConfig config) { - return new SiaIdentityProvider( - new AthenzService(config.nodeIdentityName()), SiaUtils.DEFAULT_SIA_DIRECTORY, CLIENT_TRUST_STORE); - } - - private boolean isExpired(AthenzCredentials credentials) { - return clock.instant().isAfter(getExpirationTime(credentials)); - } - - private static Instant getExpirationTime(AthenzCredentials credentials) { - return credentials.getCertificate().getNotAfter().toInstant(); - } - - void refreshCertificate() { - try { - updateIdentityCredentials(isExpired(credentials) - ? athenzCredentialsService.registerInstance() - : athenzCredentialsService.updateCredentials(credentials.getIdentityDocument(), identitySslContext)); - } catch (Throwable t) { - log.log(Level.WARNING, "Failed to update credentials: " + t.getMessage(), t); - } + private static Instant getExpirationTime(X509Certificate certificate) { + return certificate.getNotAfter().toInstant(); } void reportMetrics() { try { - Instant expirationTime = getExpirationTime(credentials); + Instant expirationTime = getExpirationTime(autoReloadingX509KeyManager.getCurrentCertificateWithKey().certificate()); Duration remainingLifetime = Duration.between(clock.instant(), expirationTime); metric.set(CERTIFICATE_EXPIRY_METRIC_NAME, remainingLifetime.getSeconds(), null); } catch (Throwable t) { diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderProvider.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderProvider.java new file mode 100644 index 00000000000..66dad931815 --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderProvider.java @@ -0,0 +1,38 @@ +package com.yahoo.vespa.athenz.identityprovider.client; + +import com.yahoo.container.core.identity.IdentityConfig; +import com.yahoo.container.di.componentgraph.Provider; +import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider; +import com.yahoo.jdisc.Metric; + +import javax.inject.Inject; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; + +/** + * @author olaa + */ +public class AthenzIdentityProviderProvider implements Provider<AthenzIdentityProvider> { + + private final Path NODE_ADMIN_MANAGED_IDENTITY_DOCUMENT = Paths.get("/var/lib/sia/vespa-tenant-identity-document.json"); + private final AthenzIdentityProvider athenzIdentityProvider; + + @Inject + public AthenzIdentityProviderProvider(IdentityConfig config, Metric metric) { + if (Files.exists(NODE_ADMIN_MANAGED_IDENTITY_DOCUMENT)) + athenzIdentityProvider = new AthenzIdentityProviderImpl(config, metric); + else + athenzIdentityProvider = new LegacyAthenzIdentityProviderImpl(config, metric); + } + + @Override + public void deconstruct() { + athenzIdentityProvider.deconstruct(); + } + + @Override + public AthenzIdentityProvider get() { + return athenzIdentityProvider; + } +} diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java index 5b884e3dfb3..f95a3335c24 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.net.URI; import java.time.Duration; +import java.util.Optional; import java.util.function.Supplier; /** @@ -56,16 +57,16 @@ public class DefaultIdentityDocumentClient implements IdentityDocumentClient { } @Override - public SignedIdentityDocument getNodeIdentityDocument(String host) { - return getIdentityDocument(host, "node"); + public SignedIdentityDocument getNodeIdentityDocument(String host, int documentVersion) { + return getIdentityDocument(host, "node", documentVersion).orElseThrow(); } @Override - public SignedIdentityDocument getTenantIdentityDocument(String host) { - return getIdentityDocument(host, "tenant"); + public Optional<SignedIdentityDocument> getTenantIdentityDocument(String host, int documentVersion) { + return getIdentityDocument(host, "tenant", documentVersion); } - private SignedIdentityDocument getIdentityDocument(String host, String type) { + private Optional<SignedIdentityDocument> getIdentityDocument(String host, String type, int documentVersion) { try (CloseableHttpClient client = createHttpClient(sslContextSupplier.get(), hostnameVerifier)) { URI uri = configserverUri @@ -76,13 +77,16 @@ public class DefaultIdentityDocumentClient implements IdentityDocumentClient { .setUri(uri) .addHeader("Connection", "close") .addHeader("Accept", "application/json") + .addParameter("documentVersion", Integer.toString(documentVersion)) .build(); try (CloseableHttpResponse response = client.execute(request)) { String responseContent = EntityUtils.toString(response.getEntity()); int statusCode = response.getStatusLine().getStatusCode(); if (statusCode >= 200 && statusCode <= 299) { SignedIdentityDocumentEntity entity = objectMapper.readValue(responseContent, SignedIdentityDocumentEntity.class); - return EntityBindingsMapper.toSignedIdentityDocument(entity); + return Optional.of(EntityBindingsMapper.toSignedIdentityDocument(entity)); + } else if (statusCode == 404) { + return Optional.empty(); } else { throw new RuntimeException( String.format( diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSigner.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSigner.java index 019f73fc6bf..11b30585933 100644 --- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSigner.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSigner.java @@ -4,7 +4,10 @@ package com.yahoo.vespa.athenz.identityprovider.client; import com.yahoo.security.SignatureUtils; import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzService; +import com.yahoo.vespa.athenz.identityprovider.api.DefaultSignedIdentityDocument; +import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.IdentityType; +import com.yahoo.vespa.athenz.identityprovider.api.LegacySignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; @@ -19,7 +22,7 @@ import java.util.Base64; import java.util.Set; import java.util.TreeSet; -import static com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument.DEFAULT_DOCUMENT_VERSION; +import static com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument.LEGACY_DEFAULT_DOCUMENT_VERSION; import static java.nio.charset.StandardCharsets.UTF_8; /** @@ -29,8 +32,25 @@ import static java.nio.charset.StandardCharsets.UTF_8; */ public class IdentityDocumentSigner { + public String generateSignature(String identityDocumentData, PrivateKey privateKey) { + try { + Signature signer = SignatureUtils.createSigner(privateKey); + signer.initSign(privateKey); + signer.update(identityDocumentData.getBytes(UTF_8)); + byte[] signature = signer.sign(); + return Base64.getEncoder().encodeToString(signature); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } + + public String generateLegacySignature(IdentityDocument doc, PrivateKey privateKey) { + return generateSignature(doc.providerUniqueId(), doc.providerService(), doc.configServerHostname(), + doc.instanceHostname(), doc.createdAt(), doc.ipAddresses(), doc.identityType(), privateKey, doc.serviceIdentity()); + } + // Cluster type is ignored due to old Vespa versions not forwarding unknown fields in signed identity document - public String generateSignature(VespaUniqueInstanceId providerUniqueId, + private String generateSignature(VespaUniqueInstanceId providerUniqueId, AthenzService providerService, String configServerHostname, String instanceHostname, @@ -54,14 +74,32 @@ public class IdentityDocumentSigner { } public boolean hasValidSignature(SignedIdentityDocument doc, PublicKey publicKey) { + if (doc instanceof LegacySignedIdentityDocument signedDoc) { + return validateLegacySignature(signedDoc, publicKey); + } else if (doc instanceof DefaultSignedIdentityDocument signedDoc) { + try { + Signature signer = SignatureUtils.createVerifier(publicKey); + signer.initVerify(publicKey); + signer.update(signedDoc.data().getBytes(UTF_8)); + return signer.verify(Base64.getDecoder().decode(doc.signature())); + } catch (GeneralSecurityException e) { + throw new RuntimeException(e); + } + } else { + throw new IllegalArgumentException("Unknown identity document type: " + doc.getClass().getName()); + } + } + + private boolean validateLegacySignature(SignedIdentityDocument doc, PublicKey publicKey) { try { + IdentityDocument iddoc = doc.identityDocument(); Signature signer = SignatureUtils.createVerifier(publicKey); signer.initVerify(publicKey); writeToSigner( - signer, doc.providerUniqueId(), doc.providerService(), doc.configServerHostname(), - doc.instanceHostname(), doc.createdAt(), doc.ipAddresses(), doc.identityType()); - if (doc.documentVersion() >= DEFAULT_DOCUMENT_VERSION) { - writeToSigner(signer, doc.serviceIdentity()); + signer, iddoc.providerUniqueId(), iddoc.providerService(), iddoc.configServerHostname(), + iddoc.instanceHostname(), iddoc.createdAt(), iddoc.ipAddresses(), iddoc.identityType()); + if (doc.documentVersion() >= LEGACY_DEFAULT_DOCUMENT_VERSION) { + writeToSigner(signer, iddoc.serviceIdentity()); } return signer.verify(Base64.getDecoder().decode(doc.signature())); } catch (GeneralSecurityException e) { diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImpl.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImpl.java new file mode 100644 index 00000000000..d699564a4ee --- /dev/null +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImpl.java @@ -0,0 +1,392 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.identityprovider.client; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.core.identity.IdentityConfig; +import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider; +import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException; +import com.yahoo.jdisc.Metric; +import com.yahoo.metrics.ContainerMetrics; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.MutableX509KeyManager; +import com.yahoo.security.Pkcs10Csr; +import com.yahoo.security.SslContextBuilder; +import com.yahoo.security.X509CertificateWithKey; +import com.yahoo.vespa.athenz.api.AthenzAccessToken; +import com.yahoo.vespa.athenz.api.AthenzDomain; +import com.yahoo.vespa.athenz.api.AthenzRole; +import com.yahoo.vespa.athenz.api.AthenzService; +import com.yahoo.vespa.athenz.api.ZToken; +import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient; +import com.yahoo.vespa.athenz.client.zts.ZtsClient; +import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider; +import com.yahoo.vespa.athenz.identity.SiaIdentityProvider; +import com.yahoo.vespa.athenz.utils.SiaUtils; +import com.yahoo.vespa.defaults.Defaults; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.X509ExtendedKeyManager; +import java.net.URI; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; +import java.util.logging.Level; +import java.util.logging.Logger; + +import static com.yahoo.security.KeyStoreType.PKCS12; + +/** + * A {@link AthenzIdentityProvider} / {@link ServiceIdentityProvider} component that provides the tenant identity. + * + * @author mortent + * @author bjorncs + */ +// This class should probably not implement ServiceIdentityProvider, +// as that interface is intended for providing the node's identity, not the tenant's application identity. +public final class LegacyAthenzIdentityProviderImpl extends AbstractComponent implements AthenzIdentityProvider, ServiceIdentityProvider { + + private static final Logger log = Logger.getLogger(LegacyAthenzIdentityProviderImpl.class.getName()); + + // TODO Make some of these values configurable through config. Match requested expiration of register/update requests. + // TODO These should match the requested expiration + static final Duration UPDATE_PERIOD = Duration.ofDays(1); + static final Duration AWAIT_TERMINTATION_TIMEOUT = Duration.ofSeconds(90); + private final static Duration ROLE_SSL_CONTEXT_EXPIRY = Duration.ofHours(2); + // TODO CMS expects 10min or less token ttl. Use 10min default until we have configurable expiry + private final static Duration ROLE_TOKEN_EXPIRY = Duration.ofMinutes(10); + + // TODO Make path to trust store paths config + private static final Path CLIENT_TRUST_STORE = Paths.get("/opt/yahoo/share/ssl/certs/yahoo_certificate_bundle.pem"); + private static final Path ATHENZ_TRUST_STORE = Paths.get("/opt/yahoo/share/ssl/certs/athenz_certificate_bundle.pem"); + + public static final String CERTIFICATE_EXPIRY_METRIC_NAME = ContainerMetrics.ATHENZ_TENANT_CERT_EXPIRY_SECONDS.baseName(); + + private volatile AthenzCredentials credentials; + private final Metric metric; + private final Path trustStore; + private final AthenzCredentialsService athenzCredentialsService; + private final ScheduledExecutorService scheduler; + private final Clock clock; + private final AthenzService identity; + private final URI ztsEndpoint; + + private final MutableX509KeyManager identityKeyManager = new MutableX509KeyManager(); + private final SSLContext identitySslContext; + private final LoadingCache<AthenzRole, X509Certificate> roleSslCertCache; + private final Map<AthenzRole, MutableX509KeyManager> roleKeyManagerCache; + private final LoadingCache<AthenzRole, ZToken> roleSpecificRoleTokenCache; + private final LoadingCache<AthenzDomain, ZToken> domainSpecificRoleTokenCache; + private final LoadingCache<AthenzDomain, AthenzAccessToken> domainSpecificAccessTokenCache; + private final LoadingCache<List<AthenzRole>, AthenzAccessToken> roleSpecificAccessTokenCache; + private final CsrGenerator csrGenerator; + + @Inject + public LegacyAthenzIdentityProviderImpl(IdentityConfig config, Metric metric) { + this(config, + metric, + CLIENT_TRUST_STORE, + new AthenzCredentialsService(config, + createNodeIdentityProvider(config), + Defaults.getDefaults().vespaHostname(), + Clock.systemUTC()), + new ScheduledThreadPoolExecutor(1), + Clock.systemUTC()); + } + + // Test only + LegacyAthenzIdentityProviderImpl(IdentityConfig config, + Metric metric, + Path trustStore, + AthenzCredentialsService athenzCredentialsService, + ScheduledExecutorService scheduler, + Clock clock) { + this.metric = metric; + this.trustStore = trustStore; + this.athenzCredentialsService = athenzCredentialsService; + this.scheduler = scheduler; + this.clock = clock; + this.identity = new AthenzService(config.domain(), config.service()); + this.ztsEndpoint = URI.create(config.ztsUrl()); + roleSslCertCache = crateAutoReloadableCache(ROLE_SSL_CONTEXT_EXPIRY, this::requestRoleCertificate, this.scheduler); + roleKeyManagerCache = new HashMap<>(); + roleSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken); + domainSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken); + domainSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken); + roleSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken); + this.csrGenerator = new CsrGenerator(config.athenzDnsSuffix(), config.configserverIdentityName()); + this.identitySslContext = createIdentitySslContext(identityKeyManager, trustStore); + registerInstance(); + } + + private static <KEY, VALUE> LoadingCache<KEY, VALUE> createCache(Duration expiry, Function<KEY, VALUE> cacheLoader) { + return CacheBuilder.newBuilder() + .refreshAfterWrite(expiry.dividedBy(2).toMinutes(), TimeUnit.MINUTES) + .expireAfterWrite(expiry.toMinutes(), TimeUnit.MINUTES) + .build(new CacheLoader<KEY, VALUE>() { + @Override + public VALUE load(KEY key) { + return cacheLoader.apply(key); + } + }); + } + + private static <KEY, VALUE> LoadingCache<KEY, VALUE> crateAutoReloadableCache(Duration expiry, Function<KEY, VALUE> cacheLoader, ScheduledExecutorService scheduler) { + LoadingCache<KEY, VALUE> cache = createCache(expiry, cacheLoader); + + // The cache above will reload it's contents if and only if a request for the key is made. Scheduling + // a cache reloader to reload all keys in this cache. + scheduler.scheduleAtFixedRate(() -> { cache.asMap().keySet().forEach(cache::getUnchecked);}, + expiry.dividedBy(4).toMinutes(), + expiry.dividedBy(4).toMinutes(), + TimeUnit.MINUTES); + return cache; + } + + private static SSLContext createIdentitySslContext(X509ExtendedKeyManager keyManager, Path trustStore) { + return new SslContextBuilder() + .withKeyManager(keyManager) + .withTrustStore(trustStore) + .build(); + } + + private void registerInstance() { + try { + updateIdentityCredentials(this.athenzCredentialsService.registerInstance()); + this.scheduler.scheduleAtFixedRate(this::refreshCertificate, UPDATE_PERIOD.toMinutes(), UPDATE_PERIOD.toMinutes(), TimeUnit.MINUTES); + this.scheduler.scheduleAtFixedRate(this::reportMetrics, 0, 5, TimeUnit.MINUTES); + } catch (Throwable t) { + throw new AthenzIdentityProviderException("Could not retrieve Athenz credentials", t); + } + } + + @Override + public AthenzService identity() { + return identity; + } + + @Override + public String domain() { + return identity.getDomain().getName(); + } + + @Override + public String service() { + return identity.getName(); + } + + @Override + public SSLContext getIdentitySslContext() { + return identitySslContext; + } + + @Override + public X509CertificateWithKey getIdentityCertificateWithKey() { + AthenzCredentials copy = this.credentials; + return new X509CertificateWithKey(copy.getCertificate(), copy.getKeyPair().getPrivate()); + } + + @Override public Path certificatePath() { return athenzCredentialsService.certificatePath(); } + + @Override public Path privateKeyPath() { return athenzCredentialsService.privateKeyPath(); } + + @Override + public SSLContext getRoleSslContext(String domain, String role) { + try { + AthenzRole athenzRole = new AthenzRole(new AthenzDomain(domain), role); + // Make sure to request a certificate which triggers creating a new key manager for this role + X509Certificate x509Certificate = getRoleCertificate(athenzRole); + MutableX509KeyManager keyManager = roleKeyManagerCache.get(athenzRole); + return new SslContextBuilder() + .withKeyManager(keyManager) + .withTrustStore(trustStore) + .build(); + } catch (Exception e) { + throw new AthenzIdentityProviderException("Could not retrieve role certificate: " + e.getMessage(), e); + } + } + + @Override + public String getRoleToken(String domain) { + try { + return domainSpecificRoleTokenCache.get(new AthenzDomain(domain)).getRawToken(); + } catch (Exception e) { + throw new AthenzIdentityProviderException("Could not retrieve role token: " + e.getMessage(), e); + } + } + + @Override + public String getRoleToken(String domain, String role) { + try { + return roleSpecificRoleTokenCache.get(new AthenzRole(domain, role)).getRawToken(); + } catch (Exception e) { + throw new AthenzIdentityProviderException("Could not retrieve role token: " + e.getMessage(), e); + } + } + + @Override + public String getAccessToken(String domain) { + try { + return domainSpecificAccessTokenCache.get(new AthenzDomain(domain)).value(); + } catch (Exception e) { + throw new AthenzIdentityProviderException("Could not retrieve access token: " + e.getMessage(), e); + } + } + + @Override + public String getAccessToken(String domain, List<String> roles) { + try { + List<AthenzRole> roleList = roles.stream() + .map(roleName -> new AthenzRole(domain, roleName)) + .toList(); + return roleSpecificAccessTokenCache.get(roleList).value(); + } catch (Exception e) { + throw new AthenzIdentityProviderException("Could not retrieve access token: " + e.getMessage(), e); + } + } + + @Override + public PrivateKey getPrivateKey() { + return credentials.getKeyPair().getPrivate(); + } + + @Override + public Path trustStorePath() { + return trustStore; + } + + @Override + public List<X509Certificate> getIdentityCertificate() { + return Collections.singletonList(credentials.getCertificate()); + } + + @Override + public X509Certificate getRoleCertificate(String domain, String role) { + return getRoleCertificate(new AthenzRole(new AthenzDomain(domain), role)); + } + + private X509Certificate getRoleCertificate(AthenzRole athenzRole) { + try { + return roleSslCertCache.get(athenzRole); + } catch (Exception e) { + throw new AthenzIdentityProviderException("Could not retrieve role certificate: " + e.getMessage(), e); + } + } + + private void updateIdentityCredentials(AthenzCredentials credentials) { + this.credentials = credentials; + this.identityKeyManager.updateKeystore( + KeyStoreBuilder.withType(PKCS12) + .withKeyEntry("default", credentials.getKeyPair().getPrivate(), credentials.getCertificate()) + .build(), + new char[0]); + } + + private X509Certificate requestRoleCertificate(AthenzRole role) { + var doc = credentials.getIdentityDocument().identityDocument(); + Pkcs10Csr csr = csrGenerator.generateRoleCsr( + identity, role, doc.providerUniqueId(), doc.clusterType(), credentials.getKeyPair()); + try (ZtsClient client = createZtsClient()) { + X509Certificate roleCertificate = client.getRoleCertificate(role, csr); + updateRoleKeyManager(role, roleCertificate); + log.info(String.format("Requester role certificate for role %s, expires: %s", role.toResourceNameString(), roleCertificate.getNotAfter().toInstant().toString())); + return roleCertificate; + } + } + + private void updateRoleKeyManager(AthenzRole role, X509Certificate certificate) { + MutableX509KeyManager keyManager = roleKeyManagerCache.computeIfAbsent(role, r -> new MutableX509KeyManager()); + keyManager.updateKeystore( + KeyStoreBuilder.withType(PKCS12) + .withKeyEntry("default", credentials.getKeyPair().getPrivate(), certificate) + .build(), + new char[0]); + } + + private ZToken createRoleToken(AthenzRole athenzRole) { + try (ZtsClient client = createZtsClient()) { + return client.getRoleToken(athenzRole, ROLE_TOKEN_EXPIRY); + } + } + + private ZToken createRoleToken(AthenzDomain domain) { + try (ZtsClient client = createZtsClient()) { + return client.getRoleToken(domain, ROLE_TOKEN_EXPIRY); + } + } + + private AthenzAccessToken createAccessToken(AthenzDomain domain) { + try (ZtsClient client = createZtsClient()) { + return client.getAccessToken(domain); + } + } + + private AthenzAccessToken createAccessToken(List<AthenzRole> roles) { + try (ZtsClient client = createZtsClient()) { + return client.getAccessToken(roles); + } + } + + private DefaultZtsClient createZtsClient() { + return new DefaultZtsClient.Builder(ztsEndpoint).withSslContext(getIdentitySslContext()).build(); + } + + @Override + public void deconstruct() { + try { + scheduler.shutdownNow(); + scheduler.awaitTermination(AWAIT_TERMINTATION_TIMEOUT.getSeconds(), TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + private static SiaIdentityProvider createNodeIdentityProvider(IdentityConfig config) { + return new SiaIdentityProvider( + new AthenzService(config.nodeIdentityName()), SiaUtils.DEFAULT_SIA_DIRECTORY, CLIENT_TRUST_STORE); + } + + private boolean isExpired(AthenzCredentials credentials) { + return clock.instant().isAfter(getExpirationTime(credentials)); + } + + private static Instant getExpirationTime(AthenzCredentials credentials) { + return credentials.getCertificate().getNotAfter().toInstant(); + } + + void refreshCertificate() { + try { + updateIdentityCredentials(isExpired(credentials) + ? athenzCredentialsService.registerInstance() + : athenzCredentialsService.updateCredentials(credentials.getIdentityDocument(), identitySslContext)); + } catch (Throwable t) { + log.log(Level.WARNING, "Failed to update credentials: " + t.getMessage(), t); + } + } + + void reportMetrics() { + try { + Instant expirationTime = getExpirationTime(credentials); + Duration remainingLifetime = Duration.between(clock.instant(), expirationTime); + metric.set(CERTIFICATE_EXPIRY_METRIC_NAME, remainingLifetime.getSeconds(), null); + } catch (Throwable t) { + log.log(Level.WARNING, "Failed to update metrics: " + t.getMessage(), t); + } + } +} + diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapperTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapperTest.java index 2a68f6fd231..513fb4cdbd3 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapperTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapperTest.java @@ -5,8 +5,11 @@ package com.yahoo.vespa.athenz.identityprovider.api; import org.junit.jupiter.api.Test; import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertTrue; /** @@ -15,7 +18,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; class EntityBindingsMapperTest { @Test - public void persists_unknown_json_members() throws IOException { + public void legacy_persists_unknown_json_members() throws IOException { var originalJson = """ { @@ -36,7 +39,8 @@ class EntityBindingsMapperTest { } """; var entity = EntityBindingsMapper.fromString(originalJson); - assertEquals(2, entity.unknownAttributes().size(), entity.unknownAttributes().toString()); + assertInstanceOf(LegacySignedIdentityDocument.class, entity); + assertEquals(2, entity.identityDocument().unknownAttributes().size(), entity.identityDocument().unknownAttributes().toString()); var json = EntityBindingsMapper.toAttestationData(entity); var expectedMemberInJson = "member-in-unknown-object"; @@ -45,4 +49,39 @@ class EntityBindingsMapperTest { assertEquals(EntityBindingsMapper.mapper.readTree(originalJson), EntityBindingsMapper.mapper.readTree(json)); } + @Test + public void reads_unknown_json_members() throws IOException { + var iddoc = """ + { + "provider-unique-id": "0.cluster.instance.app.tenant.us-west-1.test.node", + "provider-service": "domain.service", + "configserver-hostname": "cfg", + "instance-hostname": "host", + "created-at": 12345.0, + "ip-addresses": [], + "identity-type": "node", + "cluster-type": "admin", + "zts-url": "https://zts.url/", + "unknown-string": "string-value", + "unknown-object": { "member-in-unknown-object": 123 } + } + """; + var originalJson = + """ + { + "signature": "sig", + "signing-key-version": 0, + "document-version": 4, + "data": "%s" + } + """.formatted(Base64.getEncoder().encodeToString(iddoc.getBytes(StandardCharsets.UTF_8))); + var entity = EntityBindingsMapper.fromString(originalJson); + assertEquals(2, entity.identityDocument().unknownAttributes().size(), entity.identityDocument().unknownAttributes().toString()); + var json = EntityBindingsMapper.toAttestationData(entity); + + // For the new iddoc format the identity document should be unchanged during serialization/deserialization, + // i.e the signed identity document should be unchanged + assertEquals(EntityBindingsMapper.mapper.readTree(originalJson), EntityBindingsMapper.mapper.readTree(json)); + } + }
\ No newline at end of file diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java index c9d2ea581bb..108da9e0136 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java @@ -2,8 +2,8 @@ package com.yahoo.vespa.athenz.identityprovider.client; import com.yahoo.container.core.identity.IdentityConfig; -import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException; import com.yahoo.jdisc.Metric; +import com.yahoo.security.AutoReloadingX509KeyManager; import com.yahoo.security.KeyAlgorithm; import com.yahoo.security.KeyStoreBuilder; import com.yahoo.security.KeyStoreType; @@ -13,13 +13,13 @@ import com.yahoo.security.Pkcs10Csr; import com.yahoo.security.Pkcs10CsrBuilder; import com.yahoo.security.SignatureAlgorithm; import com.yahoo.security.X509CertificateBuilder; +import com.yahoo.security.X509CertificateWithKey; import com.yahoo.test.ManualClock; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.io.TempDir; import javax.security.auth.x500.X500Principal; - import java.io.File; import java.io.IOException; import java.math.BigInteger; @@ -33,17 +33,12 @@ import java.util.Date; import java.util.concurrent.ScheduledExecutorService; import java.util.function.Supplier; -import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -/** - * @author mortent - * @author bjorncs - */ public class AthenzIdentityProviderImplTest { @TempDir @@ -85,58 +80,25 @@ public class AthenzIdentityProviderImplTest { } @Test - void component_creation_fails_when_credentials_not_found() { - assertThrows(AthenzIdentityProviderException.class, () -> { - AthenzCredentialsService credentialService = mock(AthenzCredentialsService.class); - when(credentialService.registerInstance()) - .thenThrow(new RuntimeException("athenz unavailable")); - - new AthenzIdentityProviderImpl(IDENTITY_CONFIG, mock(Metric.class), trustStoreFile, credentialService, mock(ScheduledExecutorService.class), new ManualClock(Instant.EPOCH)); - }); - } - - @Test - void metrics_updated_on_refresh() { + void certificate_expiry_metric_is_reported() { ManualClock clock = new ManualClock(Instant.EPOCH); Metric metric = mock(Metric.class); - - AthenzCredentialsService athenzCredentialsService = mock(AthenzCredentialsService.class); - + AutoReloadingX509KeyManager keyManager = mock(AutoReloadingX509KeyManager.class); KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); X509Certificate certificate = getCertificate(keyPair, getExpirationSupplier(clock)); + when(keyManager.getCurrentCertificateWithKey()).thenReturn(new X509CertificateWithKey(certificate, keyPair.getPrivate())); - when(athenzCredentialsService.registerInstance()) - .thenReturn(new AthenzCredentials(certificate, keyPair, null)); - - when(athenzCredentialsService.updateCredentials(any(), any())) - .thenThrow(new RuntimeException("#1")) - .thenThrow(new RuntimeException("#2")) - .thenReturn(new AthenzCredentials(certificate, keyPair, null)); - - AthenzIdentityProviderImpl identityProvider = - new AthenzIdentityProviderImpl(IDENTITY_CONFIG, metric, trustStoreFile, athenzCredentialsService, mock(ScheduledExecutorService.class), clock); - + AthenzIdentityProviderImpl identityProvider = new AthenzIdentityProviderImpl(IDENTITY_CONFIG, metric, trustStoreFile, mock(ScheduledExecutorService.class), clock, keyManager); identityProvider.reportMetrics(); verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.getSeconds()), any()); - // Advance 1 day, refresh fails, cert is 1 day old clock.advance(Duration.ofDays(1)); - identityProvider.refreshCertificate(); identityProvider.reportMetrics(); verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.minus(Duration.ofDays(1)).getSeconds()), any()); - // Advance 1 more day, refresh fails, cert is 2 days old clock.advance(Duration.ofDays(1)); - identityProvider.refreshCertificate(); identityProvider.reportMetrics(); verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.minus(Duration.ofDays(2)).getSeconds()), any()); - - // Advance 1 more day, refresh succeds, cert is new - clock.advance(Duration.ofDays(1)); - identityProvider.refreshCertificate(); - identityProvider.reportMetrics(); - verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.getSeconds()), any()); - } private Supplier<Date> getExpirationSupplier(ManualClock clock) { diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java index ff85cb79f02..acb0905700f 100644 --- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java @@ -6,10 +6,13 @@ import com.yahoo.security.KeyUtils; import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzService; import com.yahoo.vespa.athenz.identityprovider.api.ClusterType; +import com.yahoo.vespa.athenz.identityprovider.api.DefaultSignedIdentityDocument; +import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper; +import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.IdentityType; +import com.yahoo.vespa.athenz.identityprovider.api.LegacySignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument; import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId; -import com.yahoo.vespa.athenz.utils.AthenzIdentities; import org.junit.jupiter.api.Test; import java.security.KeyPair; @@ -18,6 +21,7 @@ import java.util.Arrays; import java.util.HashSet; import static com.yahoo.vespa.athenz.identityprovider.api.IdentityType.TENANT; +import static com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument.LEGACY_DEFAULT_DOCUMENT_VERSION; import static com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument.DEFAULT_DOCUMENT_VERSION; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -42,32 +46,53 @@ public class IdentityDocumentSignerTest { private static final AthenzIdentity serviceIdentity = new AthenzService("vespa", "node"); @Test - void generates_and_validates_signature() { + void legacy_generates_and_validates_signature() { IdentityDocumentSigner signer = new IdentityDocumentSigner(); + IdentityDocument identityDocument = new IdentityDocument( + id, providerService, configserverHostname, + instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity); String signature = - signer.generateSignature(id, providerService, configserverHostname, instanceHostname, createdAt, - ipAddresses, identityType, keyPair.getPrivate(), serviceIdentity); + signer.generateLegacySignature(identityDocument, keyPair.getPrivate()); - SignedIdentityDocument signedIdentityDocument = new SignedIdentityDocument( - signature, KEY_VERSION, id, providerService, DEFAULT_DOCUMENT_VERSION, configserverHostname, - instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity); + SignedIdentityDocument signedIdentityDocument = new LegacySignedIdentityDocument( + signature, KEY_VERSION, LEGACY_DEFAULT_DOCUMENT_VERSION, identityDocument); assertTrue(signer.hasValidSignature(signedIdentityDocument, keyPair.getPublic())); } @Test - void ignores_cluster_type_and_zts_url() { + void generates_and_validates_signature() { IdentityDocumentSigner signer = new IdentityDocumentSigner(); + IdentityDocument identityDocument = new IdentityDocument( + id, providerService, configserverHostname, + instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity); + String data = EntityBindingsMapper.toIdentityDocmentData(identityDocument); String signature = - signer.generateSignature(id, providerService, configserverHostname, instanceHostname, createdAt, - ipAddresses, identityType, keyPair.getPrivate(), serviceIdentity); + signer.generateSignature(data, keyPair.getPrivate()); - var docWithoutIgnoredFields = new SignedIdentityDocument( - signature, KEY_VERSION, id, providerService, DEFAULT_DOCUMENT_VERSION, configserverHostname, - instanceHostname, createdAt, ipAddresses, identityType, null, null, serviceIdentity); - var docWithIgnoredFields = new SignedIdentityDocument( - signature, KEY_VERSION, id, providerService, DEFAULT_DOCUMENT_VERSION, configserverHostname, + SignedIdentityDocument signedIdentityDocument = new DefaultSignedIdentityDocument( + signature, KEY_VERSION, DEFAULT_DOCUMENT_VERSION, data); + + assertTrue(signer.hasValidSignature(signedIdentityDocument, keyPair.getPublic())); + } + + @Test + void legacy_ignores_cluster_type_and_zts_url() { + IdentityDocumentSigner signer = new IdentityDocumentSigner(); + IdentityDocument identityDocument = new IdentityDocument( + id, providerService, configserverHostname, instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity); + IdentityDocument withoutIgnoredFields = new IdentityDocument( + id, providerService, configserverHostname, + instanceHostname, createdAt, ipAddresses, identityType, null, null, serviceIdentity); + + String signature = + signer.generateLegacySignature(identityDocument, keyPair.getPrivate()); + + var docWithoutIgnoredFields = new LegacySignedIdentityDocument( + signature, KEY_VERSION, LEGACY_DEFAULT_DOCUMENT_VERSION, withoutIgnoredFields); + var docWithIgnoredFields = new LegacySignedIdentityDocument( + signature, KEY_VERSION, LEGACY_DEFAULT_DOCUMENT_VERSION, identityDocument); assertTrue(signer.hasValidSignature(docWithoutIgnoredFields, keyPair.getPublic())); assertEquals(docWithIgnoredFields.signature(), docWithoutIgnoredFields.signature()); @@ -76,16 +101,15 @@ public class IdentityDocumentSignerTest { @Test void validates_signature_for_new_and_old_versions() { IdentityDocumentSigner signer = new IdentityDocumentSigner(); + IdentityDocument identityDocument = new IdentityDocument( + id, providerService, configserverHostname, + instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity); String signature = - signer.generateSignature(id, providerService, configserverHostname, instanceHostname, createdAt, - ipAddresses, identityType, keyPair.getPrivate(), serviceIdentity); + signer.generateLegacySignature(identityDocument, keyPair.getPrivate()); - SignedIdentityDocument signedIdentityDocument = new SignedIdentityDocument( - signature, KEY_VERSION, id, providerService, DEFAULT_DOCUMENT_VERSION, configserverHostname, - instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity); + SignedIdentityDocument signedIdentityDocument = new LegacySignedIdentityDocument( + signature, KEY_VERSION, LEGACY_DEFAULT_DOCUMENT_VERSION, identityDocument); assertTrue(signer.hasValidSignature(signedIdentityDocument, keyPair.getPublic())); - } - }
\ No newline at end of file diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImplTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImplTest.java new file mode 100644 index 00000000000..75dc42cd4a6 --- /dev/null +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImplTest.java @@ -0,0 +1,160 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.athenz.identityprovider.client; + +import com.yahoo.container.core.identity.IdentityConfig; +import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException; +import com.yahoo.jdisc.Metric; +import com.yahoo.security.KeyAlgorithm; +import com.yahoo.security.KeyStoreBuilder; +import com.yahoo.security.KeyStoreType; +import com.yahoo.security.KeyStoreUtils; +import com.yahoo.security.KeyUtils; +import com.yahoo.security.Pkcs10Csr; +import com.yahoo.security.Pkcs10CsrBuilder; +import com.yahoo.security.SignatureAlgorithm; +import com.yahoo.security.X509CertificateBuilder; +import com.yahoo.test.ManualClock; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import javax.security.auth.x500.X500Principal; + +import java.io.File; +import java.io.IOException; +import java.math.BigInteger; +import java.nio.file.Path; +import java.security.KeyPair; +import java.security.cert.X509Certificate; +import java.time.Duration; +import java.time.Instant; +import java.time.temporal.ChronoUnit; +import java.util.Date; +import java.util.concurrent.ScheduledExecutorService; +import java.util.function.Supplier; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * @author mortent + * @author bjorncs + */ +public class LegacyAthenzIdentityProviderImplTest { + + @TempDir + public File tempDir; + + public static final Duration certificateValidity = Duration.ofDays(30); + + private static final IdentityConfig IDENTITY_CONFIG = + new IdentityConfig(new IdentityConfig.Builder() + .service("tenantService") + .domain("tenantDomain") + .nodeIdentityName("vespa.tenant") + .configserverIdentityName("vespa.configserver") + .loadBalancerAddress("cfg") + .ztsUrl("https:localhost:4443/zts/v1") + .athenzDnsSuffix("dev-us-north-1.vespa.cloud")); + + private final KeyPair caKeypair = KeyUtils.generateKeypair(KeyAlgorithm.EC); + private Path trustStoreFile; + private X509Certificate caCertificate; + + @BeforeEach + public void createTrustStoreFile() throws IOException { + caCertificate = X509CertificateBuilder + .fromKeypair( + caKeypair, + new X500Principal("CN=mydummyca"), + Instant.EPOCH, + Instant.EPOCH.plus(10000, ChronoUnit.DAYS), + SignatureAlgorithm.SHA256_WITH_ECDSA, + BigInteger.ONE) + .build(); + trustStoreFile = File.createTempFile("junit", null, tempDir).toPath(); + KeyStoreUtils.writeKeyStoreToFile( + KeyStoreBuilder.withType(KeyStoreType.JKS) + .withKeyEntry("default", caKeypair.getPrivate(), caCertificate) + .build(), + trustStoreFile); + } + + @Test + void component_creation_fails_when_credentials_not_found() { + assertThrows(AthenzIdentityProviderException.class, () -> { + AthenzCredentialsService credentialService = mock(AthenzCredentialsService.class); + when(credentialService.registerInstance()) + .thenThrow(new RuntimeException("athenz unavailable")); + + new LegacyAthenzIdentityProviderImpl(IDENTITY_CONFIG, mock(Metric.class), trustStoreFile, credentialService, mock(ScheduledExecutorService.class), new ManualClock(Instant.EPOCH)); + }); + } + + @Test + void metrics_updated_on_refresh() { + ManualClock clock = new ManualClock(Instant.EPOCH); + Metric metric = mock(Metric.class); + + AthenzCredentialsService athenzCredentialsService = mock(AthenzCredentialsService.class); + + KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); + X509Certificate certificate = getCertificate(keyPair, getExpirationSupplier(clock)); + + when(athenzCredentialsService.registerInstance()) + .thenReturn(new AthenzCredentials(certificate, keyPair, null)); + + when(athenzCredentialsService.updateCredentials(any(), any())) + .thenThrow(new RuntimeException("#1")) + .thenThrow(new RuntimeException("#2")) + .thenReturn(new AthenzCredentials(certificate, keyPair, null)); + + LegacyAthenzIdentityProviderImpl identityProvider = + new LegacyAthenzIdentityProviderImpl(IDENTITY_CONFIG, metric, trustStoreFile, athenzCredentialsService, mock(ScheduledExecutorService.class), clock); + + identityProvider.reportMetrics(); + verify(metric).set(eq(LegacyAthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.getSeconds()), any()); + + // Advance 1 day, refresh fails, cert is 1 day old + clock.advance(Duration.ofDays(1)); + identityProvider.refreshCertificate(); + identityProvider.reportMetrics(); + verify(metric).set(eq(LegacyAthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.minus(Duration.ofDays(1)).getSeconds()), any()); + + // Advance 1 more day, refresh fails, cert is 2 days old + clock.advance(Duration.ofDays(1)); + identityProvider.refreshCertificate(); + identityProvider.reportMetrics(); + verify(metric).set(eq(LegacyAthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.minus(Duration.ofDays(2)).getSeconds()), any()); + + // Advance 1 more day, refresh succeds, cert is new + clock.advance(Duration.ofDays(1)); + identityProvider.refreshCertificate(); + identityProvider.reportMetrics(); + verify(metric).set(eq(LegacyAthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.getSeconds()), any()); + + } + + private Supplier<Date> getExpirationSupplier(ManualClock clock) { + return () -> new Date(clock.instant().plus(certificateValidity).toEpochMilli()); + } + + private X509Certificate getCertificate(KeyPair keyPair, Supplier<Date> expiry) { + Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(new X500Principal("CN=dummy"), keyPair, SignatureAlgorithm.SHA256_WITH_ECDSA) + .build(); + return X509CertificateBuilder + .fromCsr(csr, + caCertificate.getSubjectX500Principal(), + Instant.EPOCH, + expiry.get().toInstant(), + caKeypair.getPrivate(), + SignatureAlgorithm.SHA256_WITH_ECDSA, + BigInteger.ONE) + .build(); + } + +} diff --git a/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/ApacheCluster.java b/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/ApacheCluster.java index 3192bb4f225..8e7bf59cd0f 100644 --- a/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/ApacheCluster.java +++ b/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/ApacheCluster.java @@ -86,7 +86,8 @@ class ApacheCluster implements Cluster { SimpleHttpRequest request = new SimpleHttpRequest(wrapped.method(), wrapped.path()); request.setScheme(endpoint.url.getScheme()); request.setAuthority(new URIAuthority(endpoint.url.getHost(), portOf(endpoint.url))); - request.setConfig(requestConfig); + long timeoutMillis = wrapped.timeout() == null ? 190_000 : wrapped.timeout().toMillis() * 11 / 10 + 1_000; + request.setConfig(RequestConfig.copy(requestConfig).setResponseTimeout(Timeout.ofMilliseconds(timeoutMillis)).build()); defaultHeaders.forEach(request::setHeader); wrapped.headers().forEach((name, value) -> request.setHeader(name, value.get())); if (wrapped.body() != null) { @@ -104,11 +105,11 @@ class ApacheCluster implements Cluster { @Override public void failed(Exception ex) { vessel.completeExceptionally(ex); } @Override public void cancelled() { vessel.cancel(false); } }); - long timeoutMillis = wrapped.timeout() == null ? 200_000 : wrapped.timeout().toMillis() * 11 / 10 + 1_000; - Future<?> cancellation = timeoutExecutor.schedule(() -> { - future.cancel(true); - vessel.cancel(true); - }, timeoutMillis, TimeUnit.MILLISECONDS); + // We've seen some requests time out, even with a response timeout, + // so we schedule this to be absolutely sure we don't hang (for ever). + Future<?> cancellation = timeoutExecutor.schedule(() -> { future.cancel(true); vessel.cancel(true); }, + timeoutMillis + 10_000, + TimeUnit.MILLISECONDS); vessel.whenComplete((__, ___) -> cancellation.cancel(true)); } catch (Throwable thrown) { @@ -196,8 +197,7 @@ class ApacheCluster implements Cluster { private static RequestConfig createRequestConfig(FeedClientBuilderImpl b) { RequestConfig.Builder builder = RequestConfig.custom() .setConnectTimeout(Timeout.ofSeconds(10)) - .setConnectionRequestTimeout(Timeout.DISABLED) - .setResponseTimeout(Timeout.ofSeconds(190)); + .setConnectionRequestTimeout(Timeout.DISABLED); if (b.proxy != null) builder.setProxy(new HttpHost(b.proxy.getScheme(), b.proxy.getHost(), b.proxy.getPort())); return builder.build(); } |