aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java4
-rw-r--r--container-disc/abi-spec.json3
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/AthenzIdentityProviderProvider.java3
-rw-r--r--container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java2
-rw-r--r--flags/src/main/java/com/yahoo/vespa/flags/Flags.java15
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java85
-rw-r--r--node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java15
-rw-r--r--searchcore/src/tests/proton/server/shared_threading_service/shared_threading_service_test.cpp7
-rw-r--r--searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.cpp9
-rw-r--r--searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.h2
-rw-r--r--searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp8
-rw-r--r--searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp18
-rw-r--r--searchlib/src/tests/attribute/imported_search_context/imported_search_context_test.cpp15
-rw-r--r--searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp2
-rw-r--r--searchlib/src/tests/query/streaming_query_test.cpp12
-rw-r--r--searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp299
-rw-r--r--searchlib/src/vespa/searchcommon/attribute/i_search_context.h5
-rw-r--r--searchlib/src/vespa/searchlib/attribute/empty_search_context.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/attribute/empty_search_context.h1
-rw-r--r--searchlib/src/vespa/searchlib/attribute/enumhintsearchcontext.h1
-rw-r--r--searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp2
-rw-r--r--searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h5
-rw-r--r--searchlib/src/vespa/searchlib/attribute/imported_search_context.cpp7
-rw-r--r--searchlib/src/vespa/searchlib/attribute/imported_search_context.h6
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.h1
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.hpp7
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.h1
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.hpp7
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h1
-rw-r--r--searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp11
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_enum_search_context.h7
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_enum_search_context.hpp9
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.h4
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.hpp2
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.h6
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.hpp9
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.cpp11
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.h4
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h2
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp2
-rw-r--r--searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h2
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singleboolattribute.cpp7
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp2
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp3
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp3
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.cpp3
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp3
-rw-r--r--searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp3
-rw-r--r--searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h2
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp28
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h34
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynode.cpp12
-rw-r--r--searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h6
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/leaf_blueprints.cpp8
-rw-r--r--searchlib/src/vespa/searchlib/tensor/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/angular_distance.cpp9
-rw-r--r--searchlib/src/vespa/searchlib/tensor/angular_distance.h4
-rw-r--r--searchlib/src/vespa/searchlib/tensor/bound_distance_function.h9
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp45
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_functions.h1
-rw-r--r--searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp7
-rw-r--r--searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp61
-rw-r--r--searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h8
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp60
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hamming_distance.h11
-rw-r--r--searchlib/src/vespa/searchlib/tensor/inner_product_distance.h2
-rw-r--r--searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp81
-rw-r--r--searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.h27
-rw-r--r--streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp61
-rw-r--r--streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp12
-rw-r--r--streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp2
-rw-r--r--streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp21
-rw-r--r--streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h13
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/DefaultSignedIdentityDocument.java14
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java124
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java54
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocumentClient.java7
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/LegacySignedIdentityDocument.java6
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java52
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/DefaultSignedIdentityDocumentEntity.java12
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java51
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/LegacySignedIdentityDocumentEntity.java57
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java124
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java18
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java117
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderProvider.java38
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java16
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSigner.java50
-rw-r--r--vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImpl.java392
-rw-r--r--vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapperTest.java43
-rw-r--r--vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java50
-rw-r--r--vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java68
-rw-r--r--vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImplTest.java160
-rw-r--r--vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/ApacheCluster.java16
97 files changed, 2075 insertions, 569 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java b/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java
index 2b5dadf5512..c791fea3a56 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java
@@ -17,7 +17,7 @@ import java.net.URI;
* @author mortent
*/
public class IdentityProvider extends SimpleComponent implements IdentityConfig.Producer {
- public static final String CLASS = "com.yahoo.vespa.athenz.identityprovider.client.AthenzIdentityProviderImpl";
+ public static final String CLASS = "com.yahoo.vespa.athenz.identityprovider.client.AthenzIdentityProviderProvider";
public static final String BUNDLE = "vespa-athenz";
private final AthenzDomain domain;
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
index 36d34b99223..f74c218a906 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.container.xml;
+import com.yahoo.component.ComponentId;
import com.yahoo.config.provision.ClusterInfo;
import com.yahoo.config.provision.IntRange;
import com.yahoo.component.ComponentSpecification;
@@ -1169,6 +1170,9 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> {
ztsUrl,
zoneDnsSuffix,
zone);
+
+ // Replace AthenzIdentityProviderProvider
+ cluster.removeComponent(ComponentId.fromString("com.yahoo.container.jdisc.AthenzIdentityProviderProvider"));
cluster.addComponent(identityProvider);
cluster.getContainers().forEach(container -> {
diff --git a/container-disc/abi-spec.json b/container-disc/abi-spec.json
index dd681e4124f..75246a77e03 100644
--- a/container-disc/abi-spec.json
+++ b/container-disc/abi-spec.json
@@ -19,7 +19,8 @@
"public abstract java.util.List getIdentityCertificate()",
"public abstract java.security.cert.X509Certificate getRoleCertificate(java.lang.String, java.lang.String)",
"public abstract java.security.PrivateKey getPrivateKey()",
- "public abstract java.nio.file.Path trustStorePath()"
+ "public abstract java.nio.file.Path trustStorePath()",
+ "public abstract void deconstruct()"
],
"fields" : [ ]
},
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/AthenzIdentityProviderProvider.java b/container-disc/src/main/java/com/yahoo/container/jdisc/AthenzIdentityProviderProvider.java
index f04e2291ee8..9d2e06ed9da 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/AthenzIdentityProviderProvider.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/AthenzIdentityProviderProvider.java
@@ -89,6 +89,9 @@ public class AthenzIdentityProviderProvider implements Provider<AthenzIdentityPr
public Path trustStorePath() {
throw new UnsupportedOperationException(message);
}
+
+ @Override
+ public void deconstruct() {}
}
}
diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
index af5133eceac..46803988b20 100644
--- a/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
+++ b/container-disc/src/main/java/com/yahoo/container/jdisc/athenz/AthenzIdentityProvider.java
@@ -24,5 +24,5 @@ public interface AthenzIdentityProvider {
X509Certificate getRoleCertificate(String domain, String role);
PrivateKey getPrivateKey();
Path trustStorePath();
-
+ void deconstruct();
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java
index 216da0d9045..c83ddbd8045 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ControllerMaintenance.java
@@ -199,7 +199,7 @@ public class ControllerMaintenance extends AbstractComponent {
public SuccessFactorBaseline(SystemName system) {
Objects.requireNonNull(system);
this.defaultSuccessFactorBaseline = 1.0;
- this.deploymentMetricsMaintainerBaseline = 0.95;
+ this.deploymentMetricsMaintainerBaseline = 0.90;
this.trafficFractionUpdater = system.isCd() ? 0.5 : 0.65;
}
}
diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
index 0a5cd1aacac..d3eef8d24a9 100644
--- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
+++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
@@ -48,13 +48,6 @@ public class Flags {
private static volatile TreeMap<FlagId, FlagDefinition> flags = new TreeMap<>();
- public static final UnboundBooleanFlag RECONFIGURE_ALB_TARGETS = defineFeatureFlag(
- "reconfigure-alb-targets", false,
- List.of("bjormel"), "2023-03-24", "2023-04-30",
- "Reconfigure ALB targets",
- "Takes effect on next config server container start",
- ZONE_ID);
-
public static final UnboundBooleanFlag DROP_CACHES = defineFeatureFlag(
"drop-caches", false,
List.of("hakonhall", "baldersheim"), "2023-03-06", "2023-06-05",
@@ -390,7 +383,7 @@ public class Flags {
List.of("olaa"), "2023-04-12", "2023-06-12",
"Whether AthenzCredentialsMaintainer in node-admin should create tenant service identity certificate",
"Takes effect on next tick",
- ZONE_ID, HOSTNAME
+ ZONE_ID, HOSTNAME, VESPA_VERSION, APPLICATION_ID
);
public static final UnboundBooleanFlag ENABLE_CROWDSTRIKE = defineFeatureFlag(
@@ -410,6 +403,12 @@ public class Flags {
"Takes effect at redeployment",
ZONE_ID);
+ public static final UnboundBooleanFlag NEW_IDDOC_LAYOUT = defineFeatureFlag(
+ "new_iddoc_layout", false, List.of("tokle", "bjorncs", "olaa"), "2023-04-24", "2023-05-31",
+ "Whether to use new identity document lauoyt",
+ "Takes effect on node reboot",
+ HOSTNAME);
+
/** WARNING: public for testing: All flags should be defined in {@link Flags}. */
public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, List<String> owners,
String createdAt, String expiresAt, String description,
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java
index 3fb9c73367d..3ab1fdf211b 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/maintenance/identity/AthenzCredentialsMaintainer.java
@@ -1,6 +1,8 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.node.admin.maintenance.identity;
+import com.yahoo.component.Version;
+import com.yahoo.config.provision.ApplicationId;
import com.yahoo.security.KeyAlgorithm;
import com.yahoo.security.KeyUtils;
import com.yahoo.security.Pkcs10Csr;
@@ -13,6 +15,7 @@ import com.yahoo.vespa.athenz.client.zts.ZtsClient;
import com.yahoo.vespa.athenz.client.zts.ZtsClientException;
import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider;
import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper;
+import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocumentClient;
import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.client.CsrGenerator;
@@ -76,6 +79,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
private final ServiceIdentityProvider hostIdentityProvider;
private final IdentityDocumentClient identityDocumentClient;
private final BooleanFlag tenantServiceIdentityFlag;
+ private final BooleanFlag useNewIdentityDocumentLayout;
// Used as an optimization to ensure ZTS is not DDoS'ed on continuously failing refresh attempts
private final Map<ContainerName, Instant> lastRefreshAttempt = new ConcurrentHashMap<>();
@@ -97,13 +101,20 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
new AthenzIdentityVerifier(Set.of(configServerInfo.getConfigServerIdentity())));
this.clock = clock;
this.tenantServiceIdentityFlag = Flags.NODE_ADMIN_TENANT_SERVICE_REGISTRY.bindTo(flagSource);
+ this.useNewIdentityDocumentLayout = Flags.NEW_IDDOC_LAYOUT.bindTo(flagSource);
}
public boolean converge(NodeAgentContext context) {
var modified = false;
modified |= maintain(context, NODE);
+
+ if (context.zone().getSystemName().isPublic())
+ return modified;
+
if (shouldWriteTenantServiceIdentity(context))
modified |= maintain(context, TENANT);
+ else
+ modified |= deleteTenantCredentials(context);
return modified;
}
@@ -114,7 +125,10 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
context.log(logger, Level.FINE, "Checking certificate");
ContainerPath siaDirectory = context.paths().of(CONTAINER_SIA_DIRECTORY, context.users().vespa());
ContainerPath identityDocumentFile = siaDirectory.resolve(identityType.getIdentityDocument());
- AthenzIdentity athenzIdentity = getAthenzIdentity(context, identityType, identityDocumentFile);
+ Optional<AthenzIdentity> optionalAthenzIdentity = getAthenzIdentity(context, identityType, identityDocumentFile);
+ if (optionalAthenzIdentity.isEmpty())
+ return false;
+ AthenzIdentity athenzIdentity = optionalAthenzIdentity.get();
ContainerPath privateKeyFile = (ContainerPath) SiaUtils.getPrivateKeyFile(siaDirectory, athenzIdentity);
ContainerPath certificateFile = (ContainerPath) SiaUtils.getCertificateFile(siaDirectory, athenzIdentity);
if (!Files.exists(privateKeyFile) || !Files.exists(certificateFile) || !Files.exists(identityDocumentFile)) {
@@ -130,7 +144,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
Instant now = clock.instant();
Instant expiry = certificate.getNotAfter().toInstant();
var doc = EntityBindingsMapper.readSignedIdentityDocumentFromFile(identityDocumentFile);
- if (doc.outdated()) {
+ if (refreshIdentityDocument(doc, context)) {
context.log(logger, "Identity document is outdated (version=%d)", doc.documentVersion());
registerIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, identityType, athenzIdentity);
return true;
@@ -150,7 +164,7 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
return false;
} else {
lastRefreshAttempt.put(context.containerName(), now);
- refreshIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, doc, identityType, athenzIdentity);
+ refreshIdentity(context, privateKeyFile, certificateFile, identityDocumentFile, doc.identityDocument(), identityType, athenzIdentity);
return true;
}
}
@@ -161,6 +175,11 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
}
}
+ private boolean refreshIdentityDocument(SignedIdentityDocument signedIdentityDocument, NodeAgentContext context) {
+ int expectedVersion = documentVersion(context);
+ return signedIdentityDocument.outdated() || signedIdentityDocument.documentVersion() != expectedVersion;
+ }
+
public void clearCredentials(NodeAgentContext context) {
FileFinder.files(context.paths().of(CONTAINER_SIA_DIRECTORY))
.deleteRecursively(context);
@@ -187,6 +206,23 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
return "node-certificate";
}
+ private boolean deleteTenantCredentials(NodeAgentContext context) {
+ var siaDirectory = context.paths().of(CONTAINER_SIA_DIRECTORY, context.users().vespa());
+ var identityDocumentFile = siaDirectory.resolve(TENANT.getIdentityDocument());
+ if (!Files.exists(identityDocumentFile)) return false;
+ return getAthenzIdentity(context, TENANT, identityDocumentFile).map(athenzIdentity -> {
+ var privateKeyFile = (ContainerPath) SiaUtils.getPrivateKeyFile(siaDirectory, athenzIdentity);
+ var certificateFile = (ContainerPath) SiaUtils.getCertificateFile(siaDirectory, athenzIdentity);
+ try {
+ return Files.deleteIfExists(identityDocumentFile) ||
+ Files.deleteIfExists(privateKeyFile) ||
+ Files.deleteIfExists(certificateFile);
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }).orElse(false);
+ }
+
private boolean shouldRefreshCredentials(Duration age) {
return age.compareTo(REFRESH_PERIOD) >= 0;
}
@@ -200,7 +236,8 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
private void registerIdentity(NodeAgentContext context, ContainerPath privateKeyFile, ContainerPath certificateFile, ContainerPath identityDocumentFile, IdentityType identityType, AthenzIdentity identity) {
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA);
- SignedIdentityDocument doc = signedIdentityDocument(context, identityType);
+ SignedIdentityDocument signedDoc = signedIdentityDocument(context, identityType);
+ IdentityDocument doc = signedDoc.identityDocument();
CsrGenerator csrGenerator = new CsrGenerator(certificateDnsSuffix, doc.providerService().getFullName());
Pkcs10Csr csr = csrGenerator.generateInstanceCsr(
identity, doc.providerUniqueId(), doc.ipAddresses(), doc.clusterType(), keyPair);
@@ -212,9 +249,9 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
ztsClient.registerInstance(
doc.providerService(),
identity,
- EntityBindingsMapper.toAttestationData(doc),
+ EntityBindingsMapper.toAttestationData(signedDoc),
csr);
- EntityBindingsMapper.writeSignedIdentityDocumentToFile(identityDocumentFile, doc);
+ EntityBindingsMapper.writeSignedIdentityDocumentToFile(identityDocumentFile, signedDoc);
writePrivateKeyAndCertificate(privateKeyFile, keyPair.getPrivate(), certificateFile, instanceIdentity.certificate());
context.log(logger, "Instance successfully registered and credentials written to file");
}
@@ -223,14 +260,14 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
/**
* Return zts url from identity document, fallback to ztsEndpoint
*/
- private URI ztsEndpoint(SignedIdentityDocument doc) {
+ private URI ztsEndpoint(IdentityDocument doc) {
return Optional.ofNullable(doc.ztsUrl())
.filter(s -> !s.isBlank())
.map(URI::create)
.orElse(ztsEndpoint);
}
private void refreshIdentity(NodeAgentContext context, ContainerPath privateKeyFile, ContainerPath certificateFile,
- ContainerPath identityDocumentFile, SignedIdentityDocument doc, IdentityType identityType, AthenzIdentity identity) {
+ ContainerPath identityDocumentFile, IdentityDocument doc, IdentityType identityType, AthenzIdentity identity) {
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA);
CsrGenerator csrGenerator = new CsrGenerator(certificateDnsSuffix, doc.providerService().getFullName());
Pkcs10Csr csr = csrGenerator.generateInstanceCsr(
@@ -291,32 +328,48 @@ public class AthenzCredentialsMaintainer implements CredentialsMaintainer {
private SignedIdentityDocument signedIdentityDocument(NodeAgentContext context, IdentityType identityType) {
return switch (identityType) {
- case NODE -> identityDocumentClient.getNodeIdentityDocument(context.hostname().value());
- case TENANT -> identityDocumentClient.getTenantIdentityDocument(context.hostname().value());
+ case NODE -> identityDocumentClient.getNodeIdentityDocument(context.hostname().value(), documentVersion(context));
+ case TENANT -> identityDocumentClient.getTenantIdentityDocument(context.hostname().value(), documentVersion(context)).get();
};
}
- private AthenzIdentity getAthenzIdentity(NodeAgentContext context, IdentityType identityType, ContainerPath identityDocumentFile) {
+ private Optional<AthenzIdentity> getAthenzIdentity(NodeAgentContext context, IdentityType identityType, ContainerPath identityDocumentFile) {
return switch (identityType) {
- case NODE -> context.identity();
+ case NODE -> Optional.of(context.identity());
case TENANT -> getTenantIdentity(context, identityDocumentFile);
};
}
- private AthenzIdentity getTenantIdentity(NodeAgentContext context, ContainerPath identityDocumentFile) {
+ private Optional<AthenzIdentity> getTenantIdentity(NodeAgentContext context, ContainerPath identityDocumentFile) {
if (Files.exists(identityDocumentFile)) {
- return EntityBindingsMapper.readSignedIdentityDocumentFromFile(identityDocumentFile).serviceIdentity();
+ return Optional.of(EntityBindingsMapper.readSignedIdentityDocumentFromFile(identityDocumentFile).identityDocument().serviceIdentity());
} else {
- return identityDocumentClient.getTenantIdentityDocument(context.hostname().value()).serviceIdentity();
+ return identityDocumentClient.getTenantIdentityDocument(context.hostname().value(), documentVersion(context))
+ .map(doc -> doc.identityDocument().serviceIdentity());
}
}
private boolean shouldWriteTenantServiceIdentity(NodeAgentContext context) {
+ var version = context.node().currentVespaVersion()
+ .orElse(context.node().wantedVespaVersion().orElse(Version.emptyVersion));
+ var appId = context.node().owner().orElse(ApplicationId.defaultId());
return tenantServiceIdentityFlag
- .with(FetchVector.Dimension.HOSTNAME, context.hostname().value())
+ .with(FetchVector.Dimension.VESPA_VERSION, version.toFullString())
+ .with(FetchVector.Dimension.APPLICATION_ID, appId.serializedForm())
.value();
}
+ /*
+ Get the document version to ask for
+ */
+ private int documentVersion(NodeAgentContext context) {
+ return useNewIdentityDocumentLayout
+ .with(FetchVector.Dimension.HOSTNAME, context.hostname().value())
+ .value()
+ ? SignedIdentityDocument.DEFAULT_DOCUMENT_VERSION
+ : SignedIdentityDocument.LEGACY_DEFAULT_DOCUMENT_VERSION;
+ }
+
enum IdentityType {
NODE("vespa-node-identity-document.json"),
TENANT("vespa-tenant-identity-document.json");
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java
index 7c84afc8397..025a04a15d6 100644
--- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java
+++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java
@@ -506,14 +506,6 @@ public class NodeAgentImpl implements NodeAgent {
storageMaintainer.cleanDiskIfFull(context);
storageMaintainer.handleCoreDumpsForContainer(context, container, false);
- // TODO: this is a workaround for restarting wireguard as early as possible after host-admin has been down.
- var runOrdinaryWireguardTasks = true;
- if (container.isPresent() && container.get().state().isRunning()) {
- Optional<Container> finalContainer = container;
- wireguardTasks.forEach(task -> task.converge(context, finalContainer.get().id()));
- runOrdinaryWireguardTasks = false;
- }
-
if (downloadImageIfNeeded(context, container)) {
context.log(logger, "Waiting for image to download " + context.node().wantedDockerImage().get().asString());
return;
@@ -525,16 +517,13 @@ public class NodeAgentImpl implements NodeAgent {
containerState = STARTING;
container = Optional.of(startContainer(context));
containerState = UNKNOWN;
- runOrdinaryWireguardTasks = true;
} else {
container = Optional.of(updateContainerIfNeeded(context, container.get()));
}
aclMaintainer.ifPresent(maintainer -> maintainer.converge(context));
- if (runOrdinaryWireguardTasks) {
- Optional<Container> finalContainer = container;
- wireguardTasks.forEach(task -> task.converge(context, finalContainer.get().id()));
- }
+ final Optional<Container> finalContainer = container;
+ wireguardTasks.forEach(task -> task.converge(context, finalContainer.get().id()));
startServicesIfNeeded(context);
resumeNodeIfNeeded(context);
if (healthChecker.isPresent()) {
diff --git a/searchcore/src/tests/proton/server/shared_threading_service/shared_threading_service_test.cpp b/searchcore/src/tests/proton/server/shared_threading_service/shared_threading_service_test.cpp
index 2027ad56768..fe7303692ba 100644
--- a/searchcore/src/tests/proton/server/shared_threading_service/shared_threading_service_test.cpp
+++ b/searchcore/src/tests/proton/server/shared_threading_service/shared_threading_service_test.cpp
@@ -20,7 +20,6 @@ ProtonConfig
make_proton_config(double concurrency, uint32_t indexing_threads = 1)
{
ProtonConfigBuilder builder;
- // This setup requires a minimum of 4 shared threads.
builder.documentdb.push_back(ProtonConfig::Documentdb());
builder.documentdb.push_back(ProtonConfig::Documentdb());
builder.flush.maxconcurrent = 1;
@@ -48,8 +47,10 @@ expect_field_writer_threads(uint32_t exp_threads, uint32_t cpu_cores, uint32_t i
TEST(SharedThreadingServiceConfigTest, shared_threads_are_derived_from_cpu_cores_and_feeding_concurrency)
{
- expect_shared_threads(4, 1);
- expect_shared_threads(4, 6);
+ expect_shared_threads(2, 1);
+ expect_shared_threads(2, 4);
+ expect_shared_threads(3, 5);
+ expect_shared_threads(3, 6);
expect_shared_threads(4, 8);
expect_shared_threads(5, 9);
expect_shared_threads(5, 10);
diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.cpp b/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.cpp
index 82e1aa3b57c..430412fc1c0 100644
--- a/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.cpp
+++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.cpp
@@ -137,8 +137,15 @@ SearchContext::getStore() const
SearchContext::SearchContext(QueryTermSimple::UP qTerm, const DocumentMetaStore &toBeSearched)
: search::attribute::SearchContext(toBeSearched),
- _isWord(qTerm->isWord())
+ _isWord(qTerm->isWord()),
+ _docid_limit(toBeSearched.getCommittedDocIdLimit())
{
}
+uint32_t
+SearchContext::get_committed_docid_limit() const noexcept
+{
+ return _docid_limit;
+}
+
}
diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.h b/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.h
index ca4b026e2a4..7c88d8f3502 100644
--- a/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.h
+++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/search_context.h
@@ -18,6 +18,7 @@ private:
bool _isWord;
document::GlobalId _gid;
+ uint32_t _docid_limit;
unsigned int approximateHits() const override;
int32_t onFind(DocId docId, int32_t elemId, int32_t &weight) const override;
@@ -30,6 +31,7 @@ private:
public:
SearchContext(std::unique_ptr<search::QueryTermSimple> qTerm, const DocumentMetaStore &toBeSearched);
+ uint32_t get_committed_docid_limit() const noexcept override;
};
}
diff --git a/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp
index 50ef5039e75..c1802b40deb 100644
--- a/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp
+++ b/searchcore/src/vespa/searchcore/proton/server/shared_threading_service_config.cpp
@@ -28,10 +28,10 @@ namespace {
uint32_t
derive_shared_threads(const ProtonConfig& cfg, const HwInfo::Cpu& cpu_info)
{
- uint32_t scaled_cores = (uint32_t)std::ceil(cpu_info.cores() * cfg.feeding.concurrency);
+ uint32_t scaled_cores = uint32_t(std::ceil(cpu_info.cores() * cfg.feeding.concurrency));
// We need at least 1 guaranteed free worker in order to ensure progress.
- return std::max(scaled_cores, (uint32_t)cfg.documentdb.size() + cfg.flush.maxconcurrent + 1);
+ return std::max(scaled_cores, uint32_t(cfg.flush.maxconcurrent + 1u));
}
uint32_t
@@ -42,8 +42,8 @@ derive_warmup_threads(const HwInfo::Cpu& cpu_info) {
uint32_t
derive_field_writer_threads(const ProtonConfig& cfg, const HwInfo::Cpu& cpu_info)
{
- uint32_t scaled_cores = (size_t)std::ceil(cpu_info.cores() * cfg.feeding.concurrency);
- uint32_t field_writer_threads = std::max(scaled_cores, (uint32_t)cfg.indexing.threads);
+ uint32_t scaled_cores = size_t(std::ceil(cpu_info.cores() * cfg.feeding.concurrency));
+ uint32_t field_writer_threads = std::max(scaled_cores, uint32_t(cfg.indexing.threads));
// Originally we used at least 3 threads for writing fields:
// - index field inverter
// - index field writer
diff --git a/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp b/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp
index 9e65dfcfc07..e55344aded0 100644
--- a/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp
+++ b/searchlib/src/tests/attribute/imported_attribute_vector/imported_attribute_vector_test.cpp
@@ -240,6 +240,24 @@ TEST_F("original lid range is used by read guard", Fixture)
EXPECT_EQUAL(getUndefined<int>(), first_guard->getInt(DocId(10)));
}
+TEST_F("Original target lid range is used by read guard", Fixture)
+{
+ reset_with_single_value_reference_mappings<IntegerAttribute, int32_t>(
+ f, BasicType::INT32,
+ {});
+ EXPECT_EQUAL(11u, f.target_attr->getNumDocs());
+ auto first_guard = f.get_imported_attr();
+ add_n_docs_with_undefined_values(*f.target_attr, 1);
+ EXPECT_EQUAL(12u, f.target_attr->getNumDocs());
+ auto typed_target_attr = f.template target_attr_as<IntegerAttribute>();
+ ASSERT_TRUE(typed_target_attr->update(11, 2345));
+ f.target_attr->commit();
+ f.map_reference(DocId(8), dummy_gid(11), DocId(11));
+ auto second_guard = f.get_imported_attr();
+ EXPECT_EQUAL(2345, second_guard->getInt(DocId(8)));
+ EXPECT_NOT_EQUAL(2345, first_guard->getInt(DocId(8)));
+}
+
struct SingleStringAttrFixture : Fixture {
SingleStringAttrFixture() : Fixture() {
setup();
diff --git a/searchlib/src/tests/attribute/imported_search_context/imported_search_context_test.cpp b/searchlib/src/tests/attribute/imported_search_context/imported_search_context_test.cpp
index 847a992d241..19327245083 100644
--- a/searchlib/src/tests/attribute/imported_search_context/imported_search_context_test.cpp
+++ b/searchlib/src/tests/attribute/imported_search_context/imported_search_context_test.cpp
@@ -429,6 +429,21 @@ TEST_F("original lid range is used by search context", SingleValueFixture)
EXPECT_TRUE(second_ctx->matches(DocId(10)));
}
+TEST_F("Original target lid range is used by search context", SingleValueFixture)
+{
+ EXPECT_EQUAL(11u, f.target_attr->getNumDocs());
+ auto first_ctx = f.create_context(word_term("2345"));
+ add_n_docs_with_undefined_values(*f.target_attr, 1);
+ EXPECT_EQUAL(12u, f.target_attr->getNumDocs());
+ auto typed_target_attr = f.template target_attr_as<IntegerAttribute>();
+ ASSERT_TRUE(typed_target_attr->update(11, 2345));
+ f.target_attr->commit();
+ f.map_reference(DocId(8), dummy_gid(11), DocId(11));
+ auto second_ctx = f.create_context(word_term("2345"));
+ EXPECT_FALSE(first_ctx->matches(DocId(8)));
+ EXPECT_TRUE(second_ctx->matches(DocId(8)));
+}
+
// Note: this uses an underlying string attribute, as queryTerm() does not seem to
// implemented at all for (single) numeric attributes. Intentional?
TEST_F("queryTerm() returns term context was created with", WsetValueFixture) {
diff --git a/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp b/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp
index 2804e3f74e4..5ba90d2b077 100644
--- a/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp
+++ b/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp
@@ -390,7 +390,7 @@ TEST("testSingleValue")
{
EXPECT_EQUAL(24u, sizeof(SearchContext));
EXPECT_EQUAL(32u, sizeof(StringSearchHelper));
- EXPECT_EQUAL(80u, sizeof(attribute::SingleStringEnumSearchContext));
+ EXPECT_EQUAL(88u, sizeof(attribute::SingleStringEnumSearchContext));
{
Config cfg(BasicType::STRING, CollectionType::SINGLE);
SingleValueStringAttribute svsa("svsa", cfg);
diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp
index 2c202d9131b..210f32af15e 100644
--- a/searchlib/src/tests/query/streaming_query_test.cpp
+++ b/searchlib/src/tests/query/streaming_query_test.cpp
@@ -814,7 +814,7 @@ TEST("test_nearest_neighbor_query_node")
constexpr uint32_t target_num_hits = 100;
constexpr bool allow_approximate = false;
constexpr uint32_t explore_additional_hits = 800;
- constexpr double raw_score = 0.5;
+ constexpr double distance = 0.5;
builder.add_nearest_neighbor_term("qtensor", "field", id, Weight(weight), target_num_hits, allow_approximate, explore_additional_hits, distance_threshold);
auto build_node = builder.build();
auto stack_dump = StackDumpCreator::create(*build_node);
@@ -830,14 +830,14 @@ TEST("test_nearest_neighbor_query_node")
EXPECT_EQUAL(id, static_cast<int32_t>(node->uniqueId()));
EXPECT_EQUAL(weight, node->weight().percent());
EXPECT_EQUAL(distance_threshold, node->get_distance_threshold());
- EXPECT_FALSE(node->get_raw_score().has_value());
+ EXPECT_FALSE(node->get_distance().has_value());
EXPECT_FALSE(node->evaluate());
- node->set_raw_score(raw_score);
- EXPECT_TRUE(node->get_raw_score().has_value());
- EXPECT_EQUAL(raw_score, node->get_raw_score().value());
+ node->set_distance(distance);
+ EXPECT_TRUE(node->get_distance().has_value());
+ EXPECT_EQUAL(distance, node->get_distance().value());
EXPECT_TRUE(node->evaluate());
node->reset();
- EXPECT_FALSE(node->get_raw_score().has_value());
+ EXPECT_FALSE(node->get_distance().has_value());
EXPECT_FALSE(node->evaluate());
}
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index ae283f3f2b2..9b8ad0d26ce 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -18,14 +18,17 @@ using search::attribute::DistanceMetric;
template <typename T>
TypedCells t(const std::vector<T> &v) { return TypedCells(v); }
-void verify_geo_miles(const DistanceFunction *dist_fun,
- const std::vector<double> &p1,
+void verify_geo_miles(const std::vector<double> &p1,
const std::vector<double> &p2,
double exp_miles)
{
+ static GeoDistanceFunctionFactory dff;
TypedCells t1(p1);
TypedCells t2(p2);
- double abstract_distance = dist_fun->calc(t1, t2);
+ auto dist_fun = dff.for_query_vector(t1);
+ double abstract_distance = dist_fun->calc(t2);
+ EXPECT_EQ(dff.for_insertion_vector(t1)->calc(t2), abstract_distance);
+ EXPECT_FLOAT_EQ(dff.for_query_vector(t2)->calc(t1), abstract_distance);
double raw_score = dist_fun->to_rawscore(abstract_distance);
double km = ((1.0/raw_score)-1.0);
double d_miles = km / 1.609344;
@@ -69,6 +72,8 @@ double computeEuclideanChecked(TypedCells a, TypedCells b) {
return result;
}
+namespace { const double sq_root_half = std::sqrt(0.5); }
+
TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
{
auto ct = vespalib::eval::CellType::DOUBLE;
@@ -79,7 +84,7 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score)
std::vector<double> p1{1.0, 0.0, 0.0};
std::vector<double> p2{0.0, 1.0, 0.0};
std::vector<double> p3{0.0, 0.0, 1.0};
- std::vector<double> p4{0.5, 0.5, 0.707107};
+ std::vector<double> p4{0.5, 0.5, sq_root_half};
std::vector<double> p5{0.0,-1.0, 0.0};
std::vector<double> p6{1.0, 2.0, 2.0};
@@ -179,7 +184,7 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score)
std::vector<double> p1{1.0, 0.0, 0.0};
std::vector<double> p2{0.0, 1.0, 0.0};
std::vector<double> p3{0.0, 0.0, 1.0};
- std::vector<double> p4{0.5, 0.5, 0.707107};
+ std::vector<double> p4{0.5, 0.5, sq_root_half};
std::vector<double> p5{0.0,-1.0, 0.0};
std::vector<double> p6{1.0, 2.0, 2.0};
@@ -207,7 +212,7 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score)
EXPECT_DOUBLE_EQ(threshold, 0.5);
double a34 = computeAngularChecked(t(p3), t(p4));
- EXPECT_FLOAT_EQ(a34, (1.0 - 0.707107));
+ EXPECT_FLOAT_EQ(a34, (1.0 - sq_root_half));
EXPECT_FLOAT_EQ(angular->to_rawscore(a34), 1.0/(1.0 + pi/4));
threshold = angular->convert_threshold(pi/4);
EXPECT_FLOAT_EQ(threshold, a34);
@@ -257,6 +262,89 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score)
EXPECT_DOUBLE_EQ(a66, computeAngularChecked(t(iv6), t(iv6)));
}
+double computePrenormalizedAngularChecked(TypedCells a, TypedCells b) {
+ static PrenormalizedAngularDistanceFunctionFactory<float> flt_dff;
+ static PrenormalizedAngularDistanceFunctionFactory<double> dbl_dff;
+ auto d_n = dbl_dff.for_query_vector(a);
+ auto d_f = flt_dff.for_query_vector(a);
+ auto d_r = dbl_dff.for_query_vector(b);
+ auto d_i = dbl_dff.for_insertion_vector(a);
+ // normal:
+ double result = d_n->calc(b);
+ // insert is exactly same:
+ EXPECT_EQ(d_i->calc(b), result);
+ // note: for this distance, reverse is not necessarily equal,
+ // since we normalize based on length of LHS only
+ EXPECT_FLOAT_EQ(d_r->calc(a), result);
+ // float factory:
+ EXPECT_FLOAT_EQ(d_f->calc(b), result);
+ double closeness_n = d_n->to_rawscore(result);
+ double closeness_f = d_f->to_rawscore(result);
+ double closeness_r = d_r->to_rawscore(result);
+ double closeness_i = d_i->to_rawscore(result);
+ EXPECT_DOUBLE_EQ(closeness_n, closeness_f);
+ EXPECT_DOUBLE_EQ(closeness_n, closeness_r);
+ EXPECT_DOUBLE_EQ(closeness_n, closeness_i);
+ EXPECT_GT(closeness_n, 0.0);
+ EXPECT_LE(closeness_n, 1.0);
+ return result;
+}
+
+TEST(DistanceFunctionsTest, prenormalized_angular_gives_expected_score)
+{
+ std::vector<double> p0{0.0, 0.0, 0.0};
+ std::vector<double> p1{1.0, 0.0, 0.0};
+ std::vector<double> p2{0.0, 1.0, 0.0};
+ std::vector<double> p3{0.0, 0.0, 1.0};
+ std::vector<double> p4{0.5, 0.5, sq_root_half};
+ std::vector<double> p5{0.0,-1.0, 0.0};
+ std::vector<double> p6{1.0, 2.0, 2.0};
+ std::vector<double> p7{2.0, -1.0, -2.0};
+ std::vector<double> p8{3.0, 0.0, 0.0};
+
+ PrenormalizedAngularDistanceFunctionFactory<double> dff;
+ auto pnad = dff.for_query_vector(t(p0));
+
+ double i12 = computePrenormalizedAngularChecked(t(p1), t(p2));
+ double i13 = computePrenormalizedAngularChecked(t(p1), t(p3));
+ double i23 = computePrenormalizedAngularChecked(t(p2), t(p3));
+ EXPECT_DOUBLE_EQ(i12, 1.0);
+ EXPECT_DOUBLE_EQ(i13, 1.0);
+ EXPECT_DOUBLE_EQ(i23, 1.0);
+
+ double i14 = computePrenormalizedAngularChecked(t(p1), t(p4));
+ double i24 = computePrenormalizedAngularChecked(t(p2), t(p4));
+ EXPECT_DOUBLE_EQ(i14, 0.5);
+ EXPECT_DOUBLE_EQ(i24, 0.5);
+ double i34 = computePrenormalizedAngularChecked(t(p3), t(p4));
+ EXPECT_FLOAT_EQ(i34, 1.0 - sq_root_half);
+
+ double i25 = computePrenormalizedAngularChecked(t(p2), t(p5));
+ EXPECT_DOUBLE_EQ(i25, 2.0);
+
+ double i44 = computePrenormalizedAngularChecked(t(p4), t(p4));
+ EXPECT_GE(i44, 0.0);
+ EXPECT_LT(i44, 0.000001);
+
+ double i66 = computePrenormalizedAngularChecked(t(p6), t(p6));
+ EXPECT_GE(i66, 0.0);
+ EXPECT_LT(i66, 0.000001);
+
+ double i67 = computePrenormalizedAngularChecked(t(p6), t(p7));
+ EXPECT_DOUBLE_EQ(i67, 13.0);
+ double i68 = computePrenormalizedAngularChecked(t(p6), t(p8));
+ EXPECT_DOUBLE_EQ(i68, 6.0);
+ double i78 = computePrenormalizedAngularChecked(t(p7), t(p8));
+ EXPECT_DOUBLE_EQ(i78, 3.0);
+
+ double threshold = pnad->convert_threshold(0.25);
+ EXPECT_DOUBLE_EQ(threshold, 0.25);
+ threshold = pnad->convert_threshold(0.5);
+ EXPECT_DOUBLE_EQ(threshold, 0.5);
+ threshold = pnad->convert_threshold(1.0);
+ EXPECT_DOUBLE_EQ(threshold, 1.0);
+}
+
TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
{
auto ct = vespalib::eval::CellType::DOUBLE;
@@ -267,7 +355,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
std::vector<double> p1{1.0, 0.0, 0.0};
std::vector<double> p2{0.0, 1.0, 0.0};
std::vector<double> p3{0.0, 0.0, 1.0};
- std::vector<double> p4{0.5, 0.5, 0.707107};
+ std::vector<double> p4{0.5, 0.5, sq_root_half};
std::vector<double> p5{0.0,-1.0, 0.0};
std::vector<double> p6{1.0, 2.0, 2.0};
@@ -283,7 +371,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
EXPECT_DOUBLE_EQ(i14, 0.5);
EXPECT_DOUBLE_EQ(i24, 0.5);
double i34 = innerproduct->calc(t(p3), t(p4));
- EXPECT_FLOAT_EQ(i34, 1.0 - 0.707107);
+ EXPECT_FLOAT_EQ(i34, 1.0 - sq_root_half);
double i25 = innerproduct->calc(t(p2), t(p5));
EXPECT_DOUBLE_EQ(i25, 2.0);
@@ -292,6 +380,10 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
EXPECT_GE(i44, 0.0);
EXPECT_LT(i44, 0.000001);
+ double i66 = innerproduct->calc(t(p6), t(p6));
+ EXPECT_GE(i66, 0.0);
+ EXPECT_LT(i66, 0.000001);
+
double threshold = innerproduct->convert_threshold(0.25);
EXPECT_DOUBLE_EQ(threshold, 0.25);
threshold = innerproduct->convert_threshold(0.5);
@@ -302,6 +394,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score)
TEST(DistanceFunctionsTest, hamming_gives_expected_score)
{
+ static HammingDistanceFunctionFactory<Int8Float> dff;
auto ct = vespalib::eval::CellType::DOUBLE;
auto hamming = make_distance_function(DistanceMetric::Hamming, ct);
@@ -318,6 +411,9 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score)
double h0 = hamming->calc(t(p), t(p));
EXPECT_EQ(h0, 0.0);
EXPECT_EQ(hamming->to_rawscore(h0), 1.0);
+ auto dist_fun = dff.for_query_vector(t(p));
+ EXPECT_EQ(dist_fun->calc(t(p)), 0.0);
+ EXPECT_EQ(dist_fun->to_rawscore(h0), 1.0);
}
double d12 = hamming->calc(t(points[1]), t(points[2]));
EXPECT_EQ(d12, 3.0);
@@ -350,13 +446,12 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score)
std::vector<Int8Float> bytes_b = { 1, 2, 2, 4, 8, 16, 32, 65, -128, 0, 1, 0, 4, 8, 16, 32, 64, -128, 0, 1, -1 };
// expect diff: 1 2 1 1 7
EXPECT_EQ(hamming->calc(TypedCells(bytes_a), TypedCells(bytes_b)), 12.0);
+ auto dist_fun = dff.for_query_vector(TypedCells(bytes_a));
+ EXPECT_EQ(dist_fun->calc(TypedCells(bytes_b)), 12.0);
}
TEST(GeoDegreesTest, gives_expected_score)
{
- auto ct = vespalib::eval::CellType::DOUBLE;
- auto geodeg = make_distance_function(DistanceMetric::GeoDegrees, ct);
-
std::vector<double> g1_sfo{37.61, -122.38};
std::vector<double> g2_lhr{51.47, -0.46};
std::vector<double> g3_osl{60.20, 11.08};
@@ -367,7 +462,8 @@ TEST(GeoDegreesTest, gives_expected_score)
std::vector<double> g8_lax{33.94, -118.41};
std::vector<double> g9_jfk{40.64, -73.78};
- double g63_a = geodeg->calc(t(g6_trd), t(g3_osl));
+ auto geodeg = GeoDistanceFunctionFactory().for_query_vector(t(g6_trd));
+ double g63_a = geodeg->calc(t(g3_osl));
double g63_r = geodeg->to_rawscore(g63_a);
double g63_km = ((1.0/g63_r)-1.0);
EXPECT_GT(g63_km, 350);
@@ -377,96 +473,95 @@ TEST(GeoDegreesTest, gives_expected_score)
// Great Circle Mapper for airports using
// a more accurate formula - we should agree
// with < 1.0% deviation
- verify_geo_miles(geodeg.get(), g1_sfo, g1_sfo, 0);
- verify_geo_miles(geodeg.get(), g1_sfo, g2_lhr, 5367);
- verify_geo_miles(geodeg.get(), g1_sfo, g3_osl, 5196);
- verify_geo_miles(geodeg.get(), g1_sfo, g4_gig, 6604);
- verify_geo_miles(geodeg.get(), g1_sfo, g5_hkg, 6927);
- verify_geo_miles(geodeg.get(), g1_sfo, g6_trd, 5012);
- verify_geo_miles(geodeg.get(), g1_sfo, g7_syd, 7417);
- verify_geo_miles(geodeg.get(), g1_sfo, g8_lax, 337);
- verify_geo_miles(geodeg.get(), g1_sfo, g9_jfk, 2586);
-
- verify_geo_miles(geodeg.get(), g2_lhr, g1_sfo, 5367);
- verify_geo_miles(geodeg.get(), g2_lhr, g2_lhr, 0);
- verify_geo_miles(geodeg.get(), g2_lhr, g3_osl, 750);
- verify_geo_miles(geodeg.get(), g2_lhr, g4_gig, 5734);
- verify_geo_miles(geodeg.get(), g2_lhr, g5_hkg, 5994);
- verify_geo_miles(geodeg.get(), g2_lhr, g6_trd, 928);
- verify_geo_miles(geodeg.get(), g2_lhr, g7_syd, 10573);
- verify_geo_miles(geodeg.get(), g2_lhr, g8_lax, 5456);
- verify_geo_miles(geodeg.get(), g2_lhr, g9_jfk, 3451);
-
- verify_geo_miles(geodeg.get(), g3_osl, g1_sfo, 5196);
- verify_geo_miles(geodeg.get(), g3_osl, g2_lhr, 750);
- verify_geo_miles(geodeg.get(), g3_osl, g3_osl, 0);
- verify_geo_miles(geodeg.get(), g3_osl, g4_gig, 6479);
- verify_geo_miles(geodeg.get(), g3_osl, g5_hkg, 5319);
- verify_geo_miles(geodeg.get(), g3_osl, g6_trd, 226);
- verify_geo_miles(geodeg.get(), g3_osl, g7_syd, 9888);
- verify_geo_miles(geodeg.get(), g3_osl, g8_lax, 5345);
- verify_geo_miles(geodeg.get(), g3_osl, g9_jfk, 3687);
-
- verify_geo_miles(geodeg.get(), g4_gig, g1_sfo, 6604);
- verify_geo_miles(geodeg.get(), g4_gig, g2_lhr, 5734);
- verify_geo_miles(geodeg.get(), g4_gig, g3_osl, 6479);
- verify_geo_miles(geodeg.get(), g4_gig, g4_gig, 0);
- verify_geo_miles(geodeg.get(), g4_gig, g5_hkg, 10989);
- verify_geo_miles(geodeg.get(), g4_gig, g6_trd, 6623);
- verify_geo_miles(geodeg.get(), g4_gig, g7_syd, 8414);
- verify_geo_miles(geodeg.get(), g4_gig, g8_lax, 6294);
- verify_geo_miles(geodeg.get(), g4_gig, g9_jfk, 4786);
-
- verify_geo_miles(geodeg.get(), g5_hkg, g1_sfo, 6927);
- verify_geo_miles(geodeg.get(), g5_hkg, g2_lhr, 5994);
- verify_geo_miles(geodeg.get(), g5_hkg, g3_osl, 5319);
- verify_geo_miles(geodeg.get(), g5_hkg, g4_gig, 10989);
- verify_geo_miles(geodeg.get(), g5_hkg, g5_hkg, 0);
- verify_geo_miles(geodeg.get(), g5_hkg, g6_trd, 5240);
- verify_geo_miles(geodeg.get(), g5_hkg, g7_syd, 4581);
- verify_geo_miles(geodeg.get(), g5_hkg, g8_lax, 7260);
- verify_geo_miles(geodeg.get(), g5_hkg, g9_jfk, 8072);
-
- verify_geo_miles(geodeg.get(), g6_trd, g1_sfo, 5012);
- verify_geo_miles(geodeg.get(), g6_trd, g2_lhr, 928);
- verify_geo_miles(geodeg.get(), g6_trd, g3_osl, 226);
- verify_geo_miles(geodeg.get(), g6_trd, g4_gig, 6623);
- verify_geo_miles(geodeg.get(), g6_trd, g5_hkg, 5240);
- verify_geo_miles(geodeg.get(), g6_trd, g6_trd, 0);
- verify_geo_miles(geodeg.get(), g6_trd, g7_syd, 9782);
- verify_geo_miles(geodeg.get(), g6_trd, g8_lax, 5171);
- verify_geo_miles(geodeg.get(), g6_trd, g9_jfk, 3611);
-
- verify_geo_miles(geodeg.get(), g7_syd, g1_sfo, 7417);
- verify_geo_miles(geodeg.get(), g7_syd, g2_lhr, 10573);
- verify_geo_miles(geodeg.get(), g7_syd, g3_osl, 9888);
- verify_geo_miles(geodeg.get(), g7_syd, g4_gig, 8414);
- verify_geo_miles(geodeg.get(), g7_syd, g5_hkg, 4581);
- verify_geo_miles(geodeg.get(), g7_syd, g6_trd, 9782);
- verify_geo_miles(geodeg.get(), g7_syd, g7_syd, 0);
- verify_geo_miles(geodeg.get(), g7_syd, g8_lax, 7488);
- verify_geo_miles(geodeg.get(), g7_syd, g9_jfk, 9950);
-
- verify_geo_miles(geodeg.get(), g8_lax, g1_sfo, 337);
- verify_geo_miles(geodeg.get(), g8_lax, g2_lhr, 5456);
- verify_geo_miles(geodeg.get(), g8_lax, g3_osl, 5345);
- verify_geo_miles(geodeg.get(), g8_lax, g4_gig, 6294);
- verify_geo_miles(geodeg.get(), g8_lax, g5_hkg, 7260);
- verify_geo_miles(geodeg.get(), g8_lax, g6_trd, 5171);
- verify_geo_miles(geodeg.get(), g8_lax, g7_syd, 7488);
- verify_geo_miles(geodeg.get(), g8_lax, g8_lax, 0);
- verify_geo_miles(geodeg.get(), g8_lax, g9_jfk, 2475);
-
- verify_geo_miles(geodeg.get(), g9_jfk, g1_sfo, 2586);
- verify_geo_miles(geodeg.get(), g9_jfk, g2_lhr, 3451);
- verify_geo_miles(geodeg.get(), g9_jfk, g3_osl, 3687);
- verify_geo_miles(geodeg.get(), g9_jfk, g4_gig, 4786);
- verify_geo_miles(geodeg.get(), g9_jfk, g5_hkg, 8072);
- verify_geo_miles(geodeg.get(), g9_jfk, g6_trd, 3611);
- verify_geo_miles(geodeg.get(), g9_jfk, g7_syd, 9950);
- verify_geo_miles(geodeg.get(), g9_jfk, g8_lax, 2475);
- verify_geo_miles(geodeg.get(), g9_jfk, g9_jfk, 0);
-
+ verify_geo_miles(g1_sfo, g1_sfo, 0);
+ verify_geo_miles(g1_sfo, g2_lhr, 5367);
+ verify_geo_miles(g1_sfo, g3_osl, 5196);
+ verify_geo_miles(g1_sfo, g4_gig, 6604);
+ verify_geo_miles(g1_sfo, g5_hkg, 6927);
+ verify_geo_miles(g1_sfo, g6_trd, 5012);
+ verify_geo_miles(g1_sfo, g7_syd, 7417);
+ verify_geo_miles(g1_sfo, g8_lax, 337);
+ verify_geo_miles(g1_sfo, g9_jfk, 2586);
+
+ verify_geo_miles(g2_lhr, g1_sfo, 5367);
+ verify_geo_miles(g2_lhr, g2_lhr, 0);
+ verify_geo_miles(g2_lhr, g3_osl, 750);
+ verify_geo_miles(g2_lhr, g4_gig, 5734);
+ verify_geo_miles(g2_lhr, g5_hkg, 5994);
+ verify_geo_miles(g2_lhr, g6_trd, 928);
+ verify_geo_miles(g2_lhr, g7_syd, 10573);
+ verify_geo_miles(g2_lhr, g8_lax, 5456);
+ verify_geo_miles(g2_lhr, g9_jfk, 3451);
+
+ verify_geo_miles(g3_osl, g1_sfo, 5196);
+ verify_geo_miles(g3_osl, g2_lhr, 750);
+ verify_geo_miles(g3_osl, g3_osl, 0);
+ verify_geo_miles(g3_osl, g4_gig, 6479);
+ verify_geo_miles(g3_osl, g5_hkg, 5319);
+ verify_geo_miles(g3_osl, g6_trd, 226);
+ verify_geo_miles(g3_osl, g7_syd, 9888);
+ verify_geo_miles(g3_osl, g8_lax, 5345);
+ verify_geo_miles(g3_osl, g9_jfk, 3687);
+
+ verify_geo_miles(g4_gig, g1_sfo, 6604);
+ verify_geo_miles(g4_gig, g2_lhr, 5734);
+ verify_geo_miles(g4_gig, g3_osl, 6479);
+ verify_geo_miles(g4_gig, g4_gig, 0);
+ verify_geo_miles(g4_gig, g5_hkg, 10989);
+ verify_geo_miles(g4_gig, g6_trd, 6623);
+ verify_geo_miles(g4_gig, g7_syd, 8414);
+ verify_geo_miles(g4_gig, g8_lax, 6294);
+ verify_geo_miles(g4_gig, g9_jfk, 4786);
+
+ verify_geo_miles(g5_hkg, g1_sfo, 6927);
+ verify_geo_miles(g5_hkg, g2_lhr, 5994);
+ verify_geo_miles(g5_hkg, g3_osl, 5319);
+ verify_geo_miles(g5_hkg, g4_gig, 10989);
+ verify_geo_miles(g5_hkg, g5_hkg, 0);
+ verify_geo_miles(g5_hkg, g6_trd, 5240);
+ verify_geo_miles(g5_hkg, g7_syd, 4581);
+ verify_geo_miles(g5_hkg, g8_lax, 7260);
+ verify_geo_miles(g5_hkg, g9_jfk, 8072);
+
+ verify_geo_miles(g6_trd, g1_sfo, 5012);
+ verify_geo_miles(g6_trd, g2_lhr, 928);
+ verify_geo_miles(g6_trd, g3_osl, 226);
+ verify_geo_miles(g6_trd, g4_gig, 6623);
+ verify_geo_miles(g6_trd, g5_hkg, 5240);
+ verify_geo_miles(g6_trd, g6_trd, 0);
+ verify_geo_miles(g6_trd, g7_syd, 9782);
+ verify_geo_miles(g6_trd, g8_lax, 5171);
+ verify_geo_miles(g6_trd, g9_jfk, 3611);
+
+ verify_geo_miles(g7_syd, g1_sfo, 7417);
+ verify_geo_miles(g7_syd, g2_lhr, 10573);
+ verify_geo_miles(g7_syd, g3_osl, 9888);
+ verify_geo_miles(g7_syd, g4_gig, 8414);
+ verify_geo_miles(g7_syd, g5_hkg, 4581);
+ verify_geo_miles(g7_syd, g6_trd, 9782);
+ verify_geo_miles(g7_syd, g7_syd, 0);
+ verify_geo_miles(g7_syd, g8_lax, 7488);
+ verify_geo_miles(g7_syd, g9_jfk, 9950);
+
+ verify_geo_miles(g8_lax, g1_sfo, 337);
+ verify_geo_miles(g8_lax, g2_lhr, 5456);
+ verify_geo_miles(g8_lax, g3_osl, 5345);
+ verify_geo_miles(g8_lax, g4_gig, 6294);
+ verify_geo_miles(g8_lax, g5_hkg, 7260);
+ verify_geo_miles(g8_lax, g6_trd, 5171);
+ verify_geo_miles(g8_lax, g7_syd, 7488);
+ verify_geo_miles(g8_lax, g8_lax, 0);
+ verify_geo_miles(g8_lax, g9_jfk, 2475);
+
+ verify_geo_miles(g9_jfk, g1_sfo, 2586);
+ verify_geo_miles(g9_jfk, g2_lhr, 3451);
+ verify_geo_miles(g9_jfk, g3_osl, 3687);
+ verify_geo_miles(g9_jfk, g4_gig, 4786);
+ verify_geo_miles(g9_jfk, g5_hkg, 8072);
+ verify_geo_miles(g9_jfk, g6_trd, 3611);
+ verify_geo_miles(g9_jfk, g7_syd, 9950);
+ verify_geo_miles(g9_jfk, g8_lax, 2475);
+ verify_geo_miles(g9_jfk, g9_jfk, 0);
}
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchcommon/attribute/i_search_context.h b/searchlib/src/vespa/searchcommon/attribute/i_search_context.h
index 8867d1b87e4..4657d41a4a0 100644
--- a/searchlib/src/vespa/searchcommon/attribute/i_search_context.h
+++ b/searchlib/src/vespa/searchcommon/attribute/i_search_context.h
@@ -70,6 +70,11 @@ public:
bool matches(DocId docId, int32_t &weight) const { return matches(*this, docId, weight); }
bool matches(DocId doc) const { return find(doc, 0) >= 0; }
+ /*
+ * Committed docid limit on attribute vector when search context was
+ * created.
+ */
+ virtual uint32_t get_committed_docid_limit() const noexcept = 0;
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/empty_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/empty_search_context.cpp
index 7a5d82cd9ba..91bdb45ff19 100644
--- a/searchlib/src/vespa/searchlib/attribute/empty_search_context.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/empty_search_context.cpp
@@ -30,6 +30,12 @@ EmptySearchContext::approximateHits() const
return 0u;
}
+uint32_t
+EmptySearchContext::get_committed_docid_limit() const noexcept
+{
+ return 0u;
+}
+
std::unique_ptr<queryeval::SearchIterator>
EmptySearchContext::createIterator(fef::TermFieldMatchData*, bool)
{
diff --git a/searchlib/src/vespa/searchlib/attribute/empty_search_context.h b/searchlib/src/vespa/searchlib/attribute/empty_search_context.h
index ae6f6d76edf..133e540d87f 100644
--- a/searchlib/src/vespa/searchlib/attribute/empty_search_context.h
+++ b/searchlib/src/vespa/searchlib/attribute/empty_search_context.h
@@ -19,6 +19,7 @@ class EmptySearchContext : public SearchContext
public:
EmptySearchContext(const AttributeVector& attr) noexcept;
~EmptySearchContext();
+ uint32_t get_committed_docid_limit() const noexcept override;
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/enumhintsearchcontext.h b/searchlib/src/vespa/searchlib/attribute/enumhintsearchcontext.h
index 0342976ffd6..86ffa1c8ab0 100644
--- a/searchlib/src/vespa/searchlib/attribute/enumhintsearchcontext.h
+++ b/searchlib/src/vespa/searchlib/attribute/enumhintsearchcontext.h
@@ -41,6 +41,7 @@ protected:
void fetchPostings(const queryeval::ExecuteInfo & execInfo) override;
unsigned int approximateHits() const override;
+ uint32_t get_committed_docid_limit() const noexcept { return _docIdLimit; }
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp
index a1a5e9f7894..b50a3720ff8 100644
--- a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.cpp
@@ -17,12 +17,14 @@ ImportedAttributeVectorReadGuard::ImportedAttributeVectorReadGuard(std::shared_p
_target_document_meta_store_read_guard(std::move(targetMetaStoreReadGuard)),
_imported_attribute(imported_attribute),
_targetLids(),
+ _target_docid_limit(0u),
_reference_attribute_guard(imported_attribute.getReferenceAttribute()),
_target_attribute_guard(imported_attribute.getTargetAttribute()->makeReadGuard(stableEnumGuard)),
_reference_attribute(*imported_attribute.getReferenceAttribute()),
_target_attribute(*_target_attribute_guard->attribute())
{
_targetLids = _reference_attribute.getTargetLids();
+ _target_docid_limit = _target_attribute.getCommittedDocIdLimit();
}
ImportedAttributeVectorReadGuard::~ImportedAttributeVectorReadGuard() = default;
diff --git a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h
index cb48399f688..1297acad9b8 100644
--- a/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h
+++ b/searchlib/src/vespa/searchlib/attribute/imported_attribute_vector_read_guard.h
@@ -95,6 +95,7 @@ private:
std::shared_ptr<MetaStoreReadGuard> _target_document_meta_store_read_guard;
const ImportedAttributeVector &_imported_attribute;
TargetLids _targetLids;
+ uint32_t _target_docid_limit;
AttributeGuard _reference_attribute_guard;
std::unique_ptr<attribute::AttributeReadGuard> _target_attribute_guard;
const ReferenceAttribute &_reference_attribute;
@@ -103,7 +104,9 @@ protected:
uint32_t getTargetLid(uint32_t lid) const {
// Check range to avoid reading memory beyond end of mapping array
- return lid < _targetLids.size() ? _targetLids[lid].load_acquire() : 0u;
+ uint32_t target_lid = lid < _targetLids.size() ? _targetLids[lid].load_acquire() : 0u;
+ // Check target range
+ return target_lid < _target_docid_limit ? target_lid : 0u;
}
long onSerializeForAscendingSort(DocId doc, void * serTo, long available,
diff --git a/searchlib/src/vespa/searchlib/attribute/imported_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/imported_search_context.cpp
index 1e8adc3922e..3d308b82b04 100644
--- a/searchlib/src/vespa/searchlib/attribute/imported_search_context.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/imported_search_context.cpp
@@ -43,6 +43,7 @@ ImportedSearchContext::ImportedSearchContext(
_target_attribute(target_attribute),
_target_search_context(_target_attribute.createSearchContext(std::move(term), params)),
_targetLids(_reference_attribute.getTargetLids()),
+ _target_docid_limit(_target_search_context->get_committed_docid_limit()),
_merger(_reference_attribute.getCommittedDocIdLimit()),
_params(params),
_zero_hits(false)
@@ -327,4 +328,10 @@ const vespalib::string& ImportedSearchContext::attributeName() const {
return _imported_attribute.getName();
}
+uint32_t
+ImportedSearchContext::get_committed_docid_limit() const noexcept
+{
+ return _targetLids.size();
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/attribute/imported_search_context.h b/searchlib/src/vespa/searchlib/attribute/imported_search_context.h
index d9c09d8c645..d6b6d09e8fc 100644
--- a/searchlib/src/vespa/searchlib/attribute/imported_search_context.h
+++ b/searchlib/src/vespa/searchlib/attribute/imported_search_context.h
@@ -39,6 +39,7 @@ class ImportedSearchContext : public ISearchContext {
const IAttributeVector &_target_attribute;
std::unique_ptr<ISearchContext> _target_search_context;
TargetLids _targetLids;
+ uint32_t _target_docid_limit;
PostingListMerger<int32_t> _merger;
SearchContextParams _params;
mutable std::atomic<bool> _zero_hits;
@@ -47,7 +48,9 @@ class ImportedSearchContext : public ISearchContext {
uint32_t getTargetLid(uint32_t lid) const {
// Check range to avoid reading memory beyond end of mapping array
- return lid < _targetLids.size() ? _targetLids[lid].load_acquire() : 0u;
+ uint32_t target_lid = lid < _targetLids.size() ? _targetLids[lid].load_acquire() : 0u;
+ // Check target range
+ return target_lid < _target_docid_limit ? target_lid : 0u;
}
void makeMergedPostings(bool isFilter);
@@ -90,6 +93,7 @@ public:
const ISearchContext &target_search_context() const noexcept {
return *_target_search_context;
}
+ uint32_t get_committed_docid_limit() const noexcept override;
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.h b/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.h
index 5b393d8bdb2..161c6799787 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.h
+++ b/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.h
@@ -59,6 +59,7 @@ public:
std::unique_ptr<queryeval::SearchIterator>
createFilterIterator(fef::TermFieldMatchData* matchData, bool strict) override;
+ uint32_t get_committed_docid_limit() const noexcept override;
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.hpp
index e7901199e50..15abcf6f0d9 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multi_enum_search_context.hpp
@@ -33,4 +33,11 @@ MultiEnumSearchContext<T, BaseSC, M>::createFilterIterator(fef::TermFieldMatchDa
: std::make_unique<AttributeIteratorT<MultiEnumSearchContext>>(*this, matchData);
}
+template <typename T, typename BaseSC, typename M>
+uint32_t
+MultiEnumSearchContext<T, BaseSC, M>::get_committed_docid_limit() const noexcept
+{
+ return _mv_mapping_read_view.get_committed_docid_limit();
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.h b/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.h
index b2c76a120f9..23e56e23af9 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.h
+++ b/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.h
@@ -54,6 +54,7 @@ public:
std::unique_ptr<queryeval::SearchIterator>
createFilterIterator(fef::TermFieldMatchData* matchData, bool strict) override;
+ uint32_t get_committed_docid_limit() const noexcept override;
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.hpp
index 15b851215f8..7e1fd1aeb5a 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multi_numeric_search_context.hpp
@@ -33,4 +33,11 @@ MultiNumericSearchContext<T, M>::createFilterIterator(fef::TermFieldMatchData* m
: std::make_unique<AttributeIteratorT<MultiNumericSearchContext<T, M>>>(*this, matchData);
}
+template <typename T, typename M>
+uint32_t
+MultiNumericSearchContext<T, M>::get_committed_docid_limit() const noexcept
+{
+ return _mv_mapping_read_view.get_committed_docid_limit();
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h
index 41138ff0890..609989208c3 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h
+++ b/searchlib/src/vespa/searchlib/attribute/multi_value_mapping_read_view.h
@@ -33,6 +33,7 @@ public:
}
vespalib::ConstArrayRef<ElemT> get(uint32_t doc_id) const { return _store->get(_indices[doc_id].load_acquire()); }
bool valid() const noexcept { return _store != nullptr; }
+ uint32_t get_committed_docid_limit() const noexcept { return _indices.size(); }
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp b/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp
index b701a6fd08f..9343dafe917 100644
--- a/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/reference_attribute.cpp
@@ -454,12 +454,14 @@ class ReferenceSearchContext : public attribute::SearchContext {
private:
const ReferenceAttribute& _ref_attr;
GlobalId _term;
+ uint32_t _docid_limit;
public:
ReferenceSearchContext(const ReferenceAttribute& ref_attr, const GlobalId& term)
: attribute::SearchContext(ref_attr),
_ref_attr(ref_attr),
- _term(term)
+ _term(term),
+ _docid_limit(ref_attr.getCommittedDocIdLimit())
{
}
bool valid() const override {
@@ -480,8 +482,15 @@ public:
int32_t weight;
return onFind(docId, elementId, weight);
}
+ uint32_t get_committed_docid_limit() const noexcept override;
};
+uint32_t
+ReferenceSearchContext::get_committed_docid_limit() const noexcept
+{
+ return _docid_limit;
+}
+
}
std::unique_ptr<attribute::SearchContext>
diff --git a/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.h
index 83d6c696117..f6a2f94dedb 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.h
+++ b/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.h
@@ -17,7 +17,9 @@ class SingleEnumSearchContext : public BaseSC
{
protected:
using DocId = ISearchContext::DocId;
- const vespalib::datastore::AtomicEntryRef* _enum_indices;
+ using AtomicEntryRef = vespalib::datastore::AtomicEntryRef;
+ using EnumIndices = vespalib::ConstArrayRef<AtomicEntryRef>;
+ EnumIndices _enum_indices;
const EnumStoreT<T>& _enum_store;
int32_t onFind(DocId docId, int32_t elemId, int32_t & weight) const final {
@@ -29,7 +31,7 @@ protected:
}
public:
- SingleEnumSearchContext(typename BaseSC::MatcherType&& matcher, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<T>& enum_store);
+ SingleEnumSearchContext(typename BaseSC::MatcherType&& matcher, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<T>& enum_store);
int32_t find(DocId docId, int32_t elemId, int32_t & weight) const {
if ( elemId != 0) return -1;
@@ -46,6 +48,7 @@ public:
std::unique_ptr<queryeval::SearchIterator>
createFilterIterator(fef::TermFieldMatchData* matchData, bool strict) override;
+ uint32_t get_committed_docid_limit() const noexcept override;
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.hpp
index a415c301f9c..6b6cf480d6a 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/single_enum_search_context.hpp
@@ -9,7 +9,7 @@
namespace search::attribute {
template <typename T, typename BaseSC>
-SingleEnumSearchContext<T, BaseSC>::SingleEnumSearchContext(typename BaseSC::MatcherType&& matcher, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<T>& enum_store)
+SingleEnumSearchContext<T, BaseSC>::SingleEnumSearchContext(typename BaseSC::MatcherType&& matcher, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<T>& enum_store)
: BaseSC(toBeSearched, std::move(matcher)),
_enum_indices(enum_indices),
_enum_store(enum_store)
@@ -33,4 +33,11 @@ SingleEnumSearchContext<T, BaseSC>::createFilterIterator(fef::TermFieldMatchData
: std::make_unique<AttributeIteratorT<SingleEnumSearchContext>>(*this, matchData);
}
+template <typename T, typename BaseSC>
+uint32_t
+SingleEnumSearchContext<T, BaseSC>::get_committed_docid_limit() const noexcept
+{
+ return _enum_indices.size();
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.h
index 86283f59283..fd3f4c03a8a 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.h
+++ b/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.h
@@ -16,7 +16,9 @@ template <typename T>
class SingleNumericEnumSearchContext : public SingleEnumSearchContext<T, NumericSearchContext<NumericRangeMatcher<T>>>
{
public:
- SingleNumericEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<T>& enum_store);
+ using AtomicEntryRef = vespalib::datastore::AtomicEntryRef;
+ using EnumIndices = vespalib::ConstArrayRef<AtomicEntryRef>;
+ SingleNumericEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<T>& enum_store);
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.hpp
index f4e049cb6f1..c0818d4d18a 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/single_numeric_enum_search_context.hpp
@@ -8,7 +8,7 @@
namespace search::attribute {
template <typename T>
-SingleNumericEnumSearchContext<T>::SingleNumericEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<T>& enum_store)
+SingleNumericEnumSearchContext<T>::SingleNumericEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<T>& enum_store)
: SingleEnumSearchContext<T, NumericSearchContext<NumericRangeMatcher<T>>>(NumericRangeMatcher<T>(*qTerm, true), toBeSearched, enum_indices, enum_store)
{
}
diff --git a/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.h
index 5f6925f7f4d..6362c69cdac 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.h
+++ b/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.h
@@ -3,6 +3,7 @@
#pragma once
#include "numeric_search_context.h"
+#include <vespa/vespalib/util/arrayref.h>
#include <vespa/vespalib/util/atomic.h>
namespace search::attribute {
@@ -16,7 +17,7 @@ class SingleNumericSearchContext final : public NumericSearchContext<M>
{
private:
using DocId = ISearchContext::DocId;
- const T* _data;
+ vespalib::ConstArrayRef<T> _data;
int32_t onFind(DocId docId, int32_t elemId, int32_t& weight) const override {
return find(docId, elemId, weight);
@@ -27,7 +28,7 @@ private:
}
public:
- SingleNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const T* data);
+ SingleNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, vespalib::ConstArrayRef<T> data);
int32_t find(DocId docId, int32_t elemId, int32_t& weight) const {
if ( elemId != 0) return -1;
const T v = vespalib::atomic::load_ref_relaxed(_data[docId]);
@@ -43,6 +44,7 @@ public:
std::unique_ptr<queryeval::SearchIterator>
createFilterIterator(fef::TermFieldMatchData* matchData, bool strict) override;
+ uint32_t get_committed_docid_limit() const noexcept override;
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.hpp
index 75d3da9de7f..b40b1336e6f 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/single_numeric_search_context.hpp
@@ -9,7 +9,7 @@
namespace search::attribute {
template <typename T, typename M>
-SingleNumericSearchContext<T, M>::SingleNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const T* data)
+SingleNumericSearchContext<T, M>::SingleNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, vespalib::ConstArrayRef<T> data)
: NumericSearchContext<M>(toBeSearched, *qTerm, true),
_data(data)
{
@@ -32,4 +32,11 @@ SingleNumericSearchContext<T, M>::createFilterIterator(fef::TermFieldMatchData*
: std::make_unique<AttributeIteratorT<SingleNumericSearchContext<T, M>>>(*this, matchData);
}
+template <typename T, typename M>
+uint32_t
+SingleNumericSearchContext<T, M>::get_committed_docid_limit() const noexcept
+{
+ return _data.size();
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.cpp
index 5eeef7cd61a..074435809cc 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.cpp
@@ -6,13 +6,14 @@
namespace search::attribute {
-SingleSmallNumericSearchContext::SingleSmallNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const Word* word_data, Word value_mask, uint32_t value_shift_shift, uint32_t value_shift_mask, uint32_t word_shift)
+SingleSmallNumericSearchContext::SingleSmallNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const Word* word_data, Word value_mask, uint32_t value_shift_shift, uint32_t value_shift_mask, uint32_t word_shift, uint32_t docid_limit)
: NumericSearchContext<NumericRangeMatcher<T>>(toBeSearched, *qTerm, false),
_wordData(word_data),
_valueMask(value_mask),
_valueShiftShift(value_shift_shift),
_valueShiftMask(value_shift_mask),
- _wordShift(word_shift)
+ _wordShift(word_shift),
+ _docid_limit(docid_limit)
{
}
@@ -32,4 +33,10 @@ SingleSmallNumericSearchContext::createFilterIterator(fef::TermFieldMatchData* m
: std::make_unique<AttributeIteratorT<SingleSmallNumericSearchContext>>(*this, matchData);
}
+uint32_t
+SingleSmallNumericSearchContext::get_committed_docid_limit() const noexcept
+{
+ return _docid_limit;
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.h
index 46ed02b3eca..a42c8b9b29c 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.h
+++ b/searchlib/src/vespa/searchlib/attribute/single_small_numeric_search_context.h
@@ -22,6 +22,7 @@ private:
uint32_t _valueShiftShift;
uint32_t _valueShiftMask;
uint32_t _wordShift;
+ uint32_t _docid_limit;
int32_t onFind(DocId docId, int32_t elementId, int32_t & weight) const override {
return find(docId, elementId, weight);
@@ -32,7 +33,7 @@ private:
}
public:
- SingleSmallNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const Word* word_data, Word value_mask, uint32_t value_shift_shift, uint32_t value_shift_mask, uint32_t word_shift);
+ SingleSmallNumericSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const AttributeVector& toBeSearched, const Word* word_data, Word value_mask, uint32_t value_shift_shift, uint32_t value_shift_mask, uint32_t word_shift, uint32_t docid_limit);
int32_t find(DocId docId, int32_t elemId, int32_t & weight) const {
if ( elemId != 0) return -1;
@@ -53,6 +54,7 @@ public:
std::unique_ptr<queryeval::SearchIterator>
createFilterIterator(fef::TermFieldMatchData* matchData, bool strict) override;
+ uint32_t get_committed_docid_limit() const noexcept override;
};
}
diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp
index 70023b27802..2d1748cefa5 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp
@@ -5,10 +5,10 @@
namespace search::attribute {
-SingleStringEnumHintSearchContext::SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<const char*>& enum_store, uint32_t doc_id_limit, uint64_t num_values)
+SingleStringEnumHintSearchContext::SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store, uint64_t num_values)
: SingleStringEnumSearchContext(std::move(qTerm), cased, toBeSearched, enum_indices, enum_store),
EnumHintSearchContext(enum_store.get_dictionary(),
- doc_id_limit, num_values)
+ enum_indices.size(), num_values)
{
setup_enum_hint_sc(enum_store, *this);
}
diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h
index f9d44454cd0..f157bf17a71 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h
+++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h
@@ -16,7 +16,7 @@ class SingleStringEnumHintSearchContext : public SingleStringEnumSearchContext,
public EnumHintSearchContext
{
public:
- SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<const char*>& enum_store, uint32_t doc_id_limit, uint64_t num_values);
+ SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store, uint64_t num_values);
~SingleStringEnumHintSearchContext() override;
};
diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp
index cba1d207501..8d23eaf7af0 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp
@@ -6,7 +6,7 @@
namespace search::attribute {
-SingleStringEnumSearchContext::SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<const char*>& enum_store)
+SingleStringEnumSearchContext::SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store)
: SingleEnumSearchContext<const char*, StringSearchContext>(StringMatcher(std::move(qTerm), cased), toBeSearched, enum_indices, enum_store)
{
}
diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h
index 6a9ed38b4ea..b8014b1b0e3 100644
--- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h
+++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h
@@ -14,7 +14,7 @@ namespace search::attribute {
class SingleStringEnumSearchContext : public SingleEnumSearchContext<const char*, StringSearchContext>
{
public:
- SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, const vespalib::datastore::AtomicEntryRef* enum_indices, const EnumStoreT<const char*>& enum_store);
+ SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store);
SingleStringEnumSearchContext(SingleStringEnumSearchContext&&) noexcept;
~SingleStringEnumSearchContext() override;
};
diff --git a/searchlib/src/vespa/searchlib/attribute/singleboolattribute.cpp b/searchlib/src/vespa/searchlib/attribute/singleboolattribute.cpp
index 15fc819300c..87b7049b9b7 100644
--- a/searchlib/src/vespa/searchlib/attribute/singleboolattribute.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/singleboolattribute.cpp
@@ -132,6 +132,7 @@ public:
void fetchPostings(const queryeval::ExecuteInfo &execInfo) override;
std::unique_ptr<queryeval::SearchIterator> createPostingIterator(fef::TermFieldMatchData *matchData, bool strict) override;
unsigned int approximateHits() const override;
+ uint32_t get_committed_docid_limit() const noexcept override;
};
BitVectorSearchContext::BitVectorSearchContext(std::unique_ptr<QueryTermSimple> qTerm, const SingleBoolAttribute & attr)
@@ -177,6 +178,12 @@ BitVectorSearchContext::approximateHits() const {
: 0;
}
+uint32_t
+BitVectorSearchContext::get_committed_docid_limit() const noexcept
+{
+ return _doc_id_limit;
+}
+
}
std::unique_ptr<attribute::SearchContext>
diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp
index c75ee0aacb5..606c7a92ef5 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlenumericattribute.hpp
@@ -164,7 +164,7 @@ SingleValueNumericAttribute<B>::getSearch(QueryTermSimple::UP qTerm,
{
(void) params;
QueryTermSimple::RangeResult<T> res = qTerm->getRange<T>();
- const T* data = &_data.acquire_elem_ref(0);
+ auto data = _data.make_read_view(this->getCommittedDocIdLimit());
if (res.isEqual()) {
return std::make_unique<attribute::SingleNumericSearchContext<T, attribute::NumericMatcher<T>>>(std::move(qTerm), *this, data);
} else {
diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp
index b840a0516b2..e459d3d9c9c 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlenumericenumattribute.hpp
@@ -160,7 +160,8 @@ SingleValueNumericEnumAttribute<B>::getSearch(QueryTermSimple::UP qTerm,
const attribute::SearchContextParams & params) const
{
(void) params;
- return std::make_unique<attribute::SingleNumericEnumSearchContext<T>>(std::move(qTerm), *this, &this->_enumIndices.acquire_elem_ref(0), this->_enumStore);
+ auto docid_limit = this->getCommittedDocIdLimit();
+ return std::make_unique<attribute::SingleNumericEnumSearchContext<T>>(std::move(qTerm), *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore);
}
}
diff --git a/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp
index e353d03a9e8..a4b9abb084a 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlenumericpostattribute.hpp
@@ -143,7 +143,8 @@ SingleValueNumericPostingAttribute<B>::getSearch(QueryTermSimple::UP qTerm,
{
using BaseSC = attribute::SingleNumericEnumSearchContext<T>;
using SC = attribute::NumericPostingSearchContext<BaseSC, SelfType, vespalib::btree::BTreeNoLeafData>;
- BaseSC base_sc(std::move(qTerm), *this, &this->_enumIndices.acquire_elem_ref(0), this->_enumStore);
+ auto docid_limit = this->getCommittedDocIdLimit();
+ BaseSC base_sc(std::move(qTerm), *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore);
return std::make_unique<SC>(std::move(base_sc), params, *this);
}
diff --git a/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.cpp b/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.cpp
index 13bf2f932e8..3c1621ac244 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlesmallnumericattribute.cpp
@@ -170,7 +170,8 @@ std::unique_ptr<attribute::SearchContext>
SingleValueSmallNumericAttribute::getSearch(std::unique_ptr<QueryTermSimple> qTerm,
const attribute::SearchContextParams &) const
{
- return std::make_unique<attribute::SingleSmallNumericSearchContext>(std::move(qTerm), *this, &_wordData.acquire_elem_ref(0), _valueMask, _valueShiftShift, _valueShiftMask, _wordShift);
+ auto docid_limit = getCommittedDocIdLimit();
+ return std::make_unique<attribute::SingleSmallNumericSearchContext>(std::move(qTerm), *this, &_wordData.acquire_elem_ref(0), _valueMask, _valueShiftShift, _valueShiftMask, _wordShift, docid_limit);
}
void
diff --git a/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp
index 69fe6435a03..c3f5c295260 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp
@@ -46,7 +46,8 @@ SingleValueStringAttributeT<B>::getSearch(QueryTermSimpleUP qTerm,
const attribute::SearchContextParams &) const
{
bool cased = this->get_match_is_cased();
- return std::make_unique<attribute::SingleStringEnumHintSearchContext>(std::move(qTerm), cased, *this, &this->_enumIndices.acquire_elem_ref(0), this->_enumStore, this->getCommittedDocIdLimit(), this->getStatus().getNumValues());
+ auto docid_limit = this->getCommittedDocIdLimit();
+ return std::make_unique<attribute::SingleStringEnumHintSearchContext>(std::move(qTerm), cased, *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore, this->getStatus().getNumValues());
}
}
diff --git a/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp
index 5b5214f6d3e..60847636baa 100644
--- a/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp
@@ -145,7 +145,8 @@ SingleValueStringPostingAttributeT<B>::getSearch(QueryTermSimpleUP qTerm,
using BaseSC = attribute::SingleStringEnumSearchContext;
using SC = attribute::StringPostingSearchContext<BaseSC, SelfType, vespalib::btree::BTreeNoLeafData>;
bool cased = this->get_match_is_cased();
- BaseSC base_sc(std::move(qTerm), cased, *this, &this->_enumIndices.acquire_elem_ref(0), this->_enumStore);
+ auto docid_limit = this->getCommittedDocIdLimit();
+ BaseSC base_sc(std::move(qTerm), cased, *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore);
return std::make_unique<SC>(std::move(base_sc),
params.useBitVector(),
*this);
diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
index 46b89fdfeb4..9bef389a278 100644
--- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
+++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h
@@ -109,7 +109,7 @@ public:
uint32_t getArity() const { return _currArity; }
uint32_t getNearDistance() const { return _extraIntArg1; }
- uint32_t getTargetNumHits() const { return _extraIntArg1; }
+ uint32_t getTargetHits() const { return _extraIntArg1; }
double getDistanceThreshold() const { return _extraDoubleArg4; }
double getScoreThreshold() const { return _extraDoubleArg4; }
double getThresholdBoostFactor() const { return _extraDoubleArg5; }
diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp
index d1c37cd6dcd..b2d8a0ee4be 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.cpp
@@ -1,15 +1,21 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "nearest_neighbor_query_node.h"
+#include <cassert>
namespace search::streaming {
-NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold)
- : QueryTerm(std::move(resultBase), term, index, Type::NEAREST_NEIGHBOR),
+NearestNeighborQueryNode::NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase,
+ const string& query_tensor_name, const string& field_name,
+ uint32_t target_hits, double distance_threshold,
+ int32_t unique_id, search::query::Weight weight)
+ : QueryTerm(std::move(resultBase), query_tensor_name, field_name, Type::NEAREST_NEIGHBOR),
+ _target_hits(target_hits),
_distance_threshold(distance_threshold),
- _raw_score()
+ _distance(),
+ _calc()
{
- setUniqueId(id);
+ setUniqueId(unique_id);
setWeight(weight);
}
@@ -18,13 +24,13 @@ NearestNeighborQueryNode::~NearestNeighborQueryNode() = default;
bool
NearestNeighborQueryNode::evaluate() const
{
- return _raw_score.has_value();
+ return _distance.has_value();
}
void
NearestNeighborQueryNode::reset()
{
- _raw_score.reset();
+ _distance.reset();
}
NearestNeighborQueryNode*
@@ -33,4 +39,14 @@ NearestNeighborQueryNode::as_nearest_neighbor_query_node() noexcept
return this;
}
+std::optional<double>
+NearestNeighborQueryNode::get_raw_score() const
+{
+ if (_distance.has_value()) {
+ assert(_calc != nullptr);
+ return _calc->to_raw_score(_distance.value());
+ }
+ return std::nullopt;
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h
index 0beb130c53d..c66364b0c52 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/nearest_neighbor_query_node.h
@@ -8,16 +8,34 @@
namespace search::streaming {
/*
- * Nearest neighbor query node.
+ * Nearest neighbor query node for streaming search.
*/
class NearestNeighborQueryNode: public QueryTerm {
+public:
+ class RawScoreCalculator {
+ public:
+ virtual ~RawScoreCalculator() = default;
+ /**
+ * Convert the given distance to a raw score.
+ *
+ * This is used during unpacking, and also signals that the entire document was a match.
+ */
+ virtual double to_raw_score(double distance) = 0;
+ };
+
private:
+ uint32_t _target_hits;
double _distance_threshold;
- // When this value is set it also indicates a match
- std::optional<double> _raw_score;
+ // When this value is set it also indicates a match for this query node.
+ std::optional<double> _distance;
+ RawScoreCalculator* _calc;
+
public:
- NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase, const string& term, const string& index, int32_t id, search::query::Weight weight, double distance_threshold);
+ NearestNeighborQueryNode(std::unique_ptr<QueryNodeResultBase> resultBase,
+ const string& query_tensor_name, const string& field_name,
+ uint32_t target_hits, double distance_threshold,
+ int32_t unique_id, search::query::Weight weight);
NearestNeighborQueryNode(const NearestNeighborQueryNode &) = delete;
NearestNeighborQueryNode & operator = (const NearestNeighborQueryNode &) = delete;
NearestNeighborQueryNode(NearestNeighborQueryNode &&) = delete;
@@ -27,9 +45,13 @@ public:
void reset() override;
NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept override;
const vespalib::string& get_query_tensor_name() const { return getTermString(); }
+ uint32_t get_target_hits() const { return _target_hits; }
double get_distance_threshold() const { return _distance_threshold; }
- void set_raw_score(double value) { _raw_score = value; }
- const std::optional<double>& get_raw_score() const noexcept { return _raw_score; }
+ void set_raw_score_calc(RawScoreCalculator* calc_in) { _calc = calc_in; }
+ void set_distance(double value) { _distance = value; }
+ const std::optional<double>& get_distance() const { return _distance; }
+ // This is used during unpacking, and also signals to the RawScoreCalculator that the entire document was a match.
+ std::optional<double> get_raw_score() const;
};
}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
index 226cb92c894..84344831cbc 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
@@ -200,15 +200,17 @@ QueryNode::build_nearest_neighbor_query_node(const QueryNodeResultFactory& facto
{
vespalib::stringref query_tensor_name = query_rep.getTerm();
vespalib::stringref field_name = query_rep.getIndexName();
- int32_t id = query_rep.getUniqueId();
- search::query::Weight weight = query_rep.GetWeight();
+ int32_t unique_id = query_rep.getUniqueId();
+ auto weight = query_rep.GetWeight();
+ uint32_t target_hits = query_rep.getTargetHits();
double distance_threshold = query_rep.getDistanceThreshold();
return std::make_unique<NearestNeighborQueryNode>(factory.create(),
query_tensor_name,
field_name,
- id,
- weight,
- distance_threshold);
+ target_hits,
+ distance_threshold,
+ unique_id,
+ weight);
}
}
diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
index 90bd87979c7..a552a650704 100644
--- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
+++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h
@@ -89,7 +89,7 @@ private:
pureTermView = view;
} else if (type == ParseItem::ITEM_WEAK_AND) {
vespalib::stringref view = queryStack.getIndexName();
- uint32_t targetNumHits = queryStack.getTargetNumHits();
+ uint32_t targetNumHits = queryStack.getTargetHits();
builder.addWeakAnd(arity, targetNumHits, view);
pureTermView = view;
} else if (type == ParseItem::ITEM_EQUIV) {
@@ -134,7 +134,7 @@ private:
vespalib::stringref view = queryStack.getIndexName();
int32_t id = queryStack.getUniqueId();
Weight weight = queryStack.GetWeight();
- uint32_t targetNumHits = queryStack.getTargetNumHits();
+ uint32_t targetNumHits = queryStack.getTargetHits();
double scoreThreshold = queryStack.getScoreThreshold();
double thresholdBoostFactor = queryStack.getThresholdBoostFactor();
auto & wand = builder.addWandTerm(arity, view, id, weight, targetNumHits, scoreThreshold, thresholdBoostFactor);
@@ -146,7 +146,7 @@ private:
} else if (type == ParseItem::ITEM_NEAREST_NEIGHBOR) {
vespalib::stringref query_tensor_name = queryStack.getTerm();
vespalib::stringref field_name = queryStack.getIndexName();
- uint32_t target_num_hits = queryStack.getTargetNumHits();
+ uint32_t target_num_hits = queryStack.getTargetHits();
int32_t id = queryStack.getUniqueId();
Weight weight = queryStack.GetWeight();
bool allow_approximate = queryStack.getAllowApproximate();
diff --git a/searchlib/src/vespa/searchlib/queryeval/leaf_blueprints.cpp b/searchlib/src/vespa/searchlib/queryeval/leaf_blueprints.cpp
index c50c6ec49f5..86f520c8711 100644
--- a/searchlib/src/vespa/searchlib/queryeval/leaf_blueprints.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/leaf_blueprints.cpp
@@ -129,8 +129,16 @@ struct FakeContext : attribute::ISearchContext {
DoubleRange getAsDoubleTerm() const override { abort(); }
const QueryTermUCS4 * queryTerm() const override { abort(); }
const vespalib::string &attributeName() const override { return name; }
+ uint32_t get_committed_docid_limit() const noexcept override;
};
+uint32_t
+FakeContext::get_committed_docid_limit() const noexcept
+{
+ auto& documents = result.inspect();
+ return documents.empty() ? 0 : (documents.back().docId + 1);
+}
+
}
SearchIterator::UP
diff --git a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
index 1783e0da1dd..2e874ffa4ae 100644
--- a/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/tensor/CMakeLists.txt
@@ -30,6 +30,7 @@ vespa_add_library(searchlib_tensor OBJECT
large_subspaces_buffer_type.cpp
nearest_neighbor_index.cpp
nearest_neighbor_index_saver.cpp
+ prenormalized_angular_distance.cpp
serialized_fast_value_attribute.cpp
serialized_tensor_ref.cpp
small_subspaces_buffer_type.cpp
diff --git a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
index 85eac76728c..a7ae02bb9f4 100644
--- a/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/angular_distance.cpp
@@ -61,20 +61,19 @@ private:
double _lhs_norm_sq;
public:
BoundAngularDistance(const vespalib::eval::TypedCells& lhs)
- : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
- _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
+ : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
_tmpSpace(lhs.size),
_lhs(_tmpSpace.storeLhs(lhs))
{
- auto a = &_lhs[0];
+ auto a = _lhs.data();
_lhs_norm_sq = _computer.dotProduct(a, a, lhs.size);
}
double calc(const vespalib::eval::TypedCells& rhs) const override {
size_t sz = _lhs.size();
vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
assert(sz == rhs_vector.size());
- auto a = &_lhs[0];
- auto b = &rhs_vector[0];
+ auto a = _lhs.data();
+ auto b = rhs_vector.data();
double b_norm_sq = _computer.dotProduct(b, b, sz);
double squared_norms = _lhs_norm_sq * b_norm_sq;
double dot_product = _computer.dotProduct(a, b, sz);
diff --git a/searchlib/src/vespa/searchlib/tensor/angular_distance.h b/searchlib/src/vespa/searchlib/tensor/angular_distance.h
index 97f692d05a2..bba83576153 100644
--- a/searchlib/src/vespa/searchlib/tensor/angular_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/angular_distance.h
@@ -60,8 +60,8 @@ public:
auto rhs_vector = rhs.typify<FloatType>();
size_t sz = lhs_vector.size();
assert(sz == rhs_vector.size());
- auto a = &lhs_vector[0];
- auto b = &rhs_vector[0];
+ auto a = lhs_vector.data();
+ auto b = rhs_vector.data();
double a_norm_sq = _computer.dotProduct(a, a, sz);
double b_norm_sq = _computer.dotProduct(b, b, sz);
double squared_norms = a_norm_sq * b_norm_sq;
diff --git a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
index 5d602a52227..c072d6de8e5 100644
--- a/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
+++ b/searchlib/src/vespa/searchlib/tensor/bound_distance_function.h
@@ -20,20 +20,13 @@ namespace search::tensor {
* mutable temporary storage.
*/
class BoundDistanceFunction : public DistanceConverter {
-private:
- vespalib::eval::CellType _expect_cell_type;
public:
using UP = std::unique_ptr<BoundDistanceFunction>;
- BoundDistanceFunction(vespalib::eval::CellType expected) : _expect_cell_type(expected) {}
+ BoundDistanceFunction() = default;
virtual ~BoundDistanceFunction() = default;
- // input vectors will be converted to this cell type:
- vespalib::eval::CellType expected_cell_type() const {
- return _expect_cell_type;
- }
-
// calculate internal distance (comparable)
virtual double calc(const vespalib::eval::TypedCells& rhs) const = 0;
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
index cca492ef212..c088d498f0f 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_function_factory.cpp
@@ -55,8 +55,7 @@ class SimpleBoundDistanceFunction : public BoundDistanceFunction {
public:
SimpleBoundDistanceFunction(const vespalib::eval::TypedCells& lhs,
const DistanceFunction &df)
- : BoundDistanceFunction(lhs.type),
- _lhs(lhs),
+ : _lhs(lhs),
_df(df)
{}
@@ -94,21 +93,35 @@ std::unique_ptr<DistanceFunctionFactory>
make_distance_function_factory(search::attribute::DistanceMetric variant,
vespalib::eval::CellType cell_type)
{
- if (variant == DistanceMetric::Angular) {
- if (cell_type == CellType::DOUBLE) {
- return std::make_unique<AngularDistanceFunctionFactory<double>>();
- }
- return std::make_unique<AngularDistanceFunctionFactory<float>>();
- }
- if (variant == DistanceMetric::Euclidean) {
- switch (cell_type) {
- case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>();
- case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>();
- default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>();
- }
+ switch (variant) {
+ case DistanceMetric::Angular:
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<AngularDistanceFunctionFactory<double>>();
+ default: return std::make_unique<AngularDistanceFunctionFactory<float>>();
+ }
+ case DistanceMetric::Euclidean:
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<EuclideanDistanceFunctionFactory<double>>();
+ case CellType::INT8: return std::make_unique<EuclideanDistanceFunctionFactory<vespalib::eval::Int8Float>>();
+ default: return std::make_unique<EuclideanDistanceFunctionFactory<float>>();
+ }
+ case DistanceMetric::InnerProduct:
+ case DistanceMetric::PrenormalizedAngular:
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<double>>();
+ default: return std::make_unique<PrenormalizedAngularDistanceFunctionFactory<float>>();
+ }
+ case DistanceMetric::GeoDegrees:
+ return std::make_unique<GeoDistanceFunctionFactory>();
+ case DistanceMetric::Hamming:
+ switch (cell_type) {
+ case CellType::DOUBLE: return std::make_unique<HammingDistanceFunctionFactory<double>>();
+ case CellType::INT8: return std::make_unique<HammingDistanceFunctionFactory<vespalib::eval::Int8Float>>();
+ default: return std::make_unique<HammingDistanceFunctionFactory<float>>();
+ }
}
- auto df = make_distance_function(variant, cell_type);
- return std::make_unique<SimpleDistanceFunctionFactory>(std::move(df));
+ // not reached:
+ return {};
}
}
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_functions.h b/searchlib/src/vespa/searchlib/tensor/distance_functions.h
index b28cc2bda46..2300dba2db1 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_functions.h
+++ b/searchlib/src/vespa/searchlib/tensor/distance_functions.h
@@ -8,3 +8,4 @@
#include "geo_degrees_distance.h"
#include "hamming_distance.h"
#include "inner_product_distance.h"
+#include "prenormalized_angular_distance.h"
diff --git a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
index 92d4e7af406..7995c87d055 100644
--- a/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/euclidean_distance.cpp
@@ -62,8 +62,7 @@ private:
static const int8_t *cast(const Int8Float * p) { return reinterpret_cast<const int8_t *>(p); }
public:
BoundEuclideanDistance(const vespalib::eval::TypedCells& lhs)
- : BoundDistanceFunction(vespalib::eval::get_cell_type<FloatType>()),
- _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
+ : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
_tmpSpace(lhs.size),
_lhs_vector(_tmpSpace.storeLhs(lhs))
{}
@@ -71,8 +70,8 @@ public:
size_t sz = _lhs_vector.size();
vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
assert(sz == rhs_vector.size());
- auto a = &_lhs_vector[0];
- auto b = &rhs_vector[0];
+ auto a = _lhs_vector.data();
+ auto b = rhs_vector.data();
return _computer.squaredEuclideanDistance(cast(a), cast(b), sz);
}
double convert_threshold(double threshold) const override {
diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp
index bcce75da3ab..38ba8205c90 100644
--- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "geo_degrees_distance.h"
+#include "temporary_vector_store.h"
using vespalib::typify_invoke;
using vespalib::eval::TypifyCellType;
@@ -27,11 +28,11 @@ struct CalcGeoDegrees {
double lat_diff = lat_A - lat_B;
double lon_diff = lon_A - lon_B;
-
+
// haversines of differences:
double hav_lat = GeoDegreesDistance::hav(lat_diff);
double hav_lon = GeoDegreesDistance::hav(lon_diff);
-
+
// haversine of central angle between the two points:
double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon;
return hav_central_angle;
@@ -42,9 +43,63 @@ struct CalcGeoDegrees {
double
GeoDegreesDistance::calc(const vespalib::eval::TypedCells& lhs,
- const vespalib::eval::TypedCells& rhs) const
+ const vespalib::eval::TypedCells& rhs) const
{
return typify_invoke<2,TypifyCellType,CalcGeoDegrees>(lhs.type, rhs.type, lhs, rhs);
}
+using vespalib::eval::TypedCells;
+
+class BoundGeoDistance : public BoundDistanceFunction {
+private:
+ mutable TemporaryVectorStore<double> _tmpSpace;
+ const vespalib::ConstArrayRef<double> _lh_vector;
+ static GeoDegreesDistance _g_d_helper;
+public:
+ BoundGeoDistance(const vespalib::eval::TypedCells& lhs)
+ : _tmpSpace(lhs.size),
+ _lh_vector(_tmpSpace.storeLhs(lhs))
+ {}
+ double calc(const vespalib::eval::TypedCells& rhs) const override {
+ vespalib::ConstArrayRef<double> rhs_vector = _tmpSpace.convertRhs(rhs);
+ assert(2 == _lh_vector.size());
+ assert(2 == rhs_vector.size());
+ // convert to radians:
+ double lat_A = _lh_vector[0] * GeoDegreesDistance::degrees_to_radians;
+ double lat_B = rhs_vector[0] * GeoDegreesDistance::degrees_to_radians;
+ double lon_A = _lh_vector[1] * GeoDegreesDistance::degrees_to_radians;
+ double lon_B = rhs_vector[1] * GeoDegreesDistance::degrees_to_radians;
+
+ double lat_diff = lat_A - lat_B;
+ double lon_diff = lon_A - lon_B;
+
+ // haversines of differences:
+ double hav_lat = GeoDegreesDistance::hav(lat_diff);
+ double hav_lon = GeoDegreesDistance::hav(lon_diff);
+
+ // haversine of central angle between the two points:
+ double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon;
+ return hav_central_angle;
+ }
+ double convert_threshold(double threshold) const override {
+ return _g_d_helper.convert_threshold(threshold);
+ }
+ double to_rawscore(double distance) const override {
+ return _g_d_helper.to_rawscore(distance);
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override {
+ return calc(rhs);
+ }
+};
+
+BoundDistanceFunction::UP
+GeoDistanceFunctionFactory::for_query_vector(const vespalib::eval::TypedCells& lhs) {
+ return std::make_unique<BoundGeoDistance>(lhs);
+}
+
+BoundDistanceFunction::UP
+GeoDistanceFunctionFactory::for_insertion_vector(const vespalib::eval::TypedCells& lhs) {
+ return std::make_unique<BoundGeoDistance>(lhs);
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h
index 46feee19119..4522bc03c9e 100644
--- a/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/geo_degrees_distance.h
@@ -3,6 +3,7 @@
#pragma once
#include "distance_function.h"
+#include "distance_function_factory.h"
#include <vespa/eval/eval/typed_cells.h>
#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
#include <vespa/vespalib/util/typify.h>
@@ -50,4 +51,11 @@ public:
}
};
+class GeoDistanceFunctionFactory : public DistanceFunctionFactory {
+public:
+ GeoDistanceFunctionFactory() : DistanceFunctionFactory(vespalib::eval::CellType::DOUBLE) {}
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override;
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override;
+};
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp b/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp
index 43596478a6f..f4f6842715f 100644
--- a/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hamming_distance.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "hamming_distance.h"
+#include "temporary_vector_store.h"
#include <vespa/vespalib/util/binary_hamming_distance.h>
using vespalib::typify_invoke;
@@ -52,4 +53,63 @@ HammingDistance::calc_with_limit(const vespalib::eval::TypedCells& lhs,
return calc(lhs, rhs);
}
+using vespalib::eval::Int8Float;
+
+template<typename FloatType>
+class BoundHammingDistance : public BoundDistanceFunction {
+private:
+ mutable TemporaryVectorStore<FloatType> _tmpSpace;
+ const vespalib::ConstArrayRef<FloatType> _lhs_vector;
+public:
+ BoundHammingDistance(const vespalib::eval::TypedCells& lhs)
+ : _tmpSpace(lhs.size),
+ _lhs_vector(_tmpSpace.storeLhs(lhs))
+ {}
+ double calc(const vespalib::eval::TypedCells& rhs) const override {
+ size_t sz = _lhs_vector.size();
+ vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ assert(sz == rhs_vector.size());
+ auto a = _lhs_vector.data();
+ auto b = rhs_vector.data();
+ if constexpr (std::is_same<Int8Float, FloatType>::value) {
+ return (double) vespalib::binary_hamming_distance(a, b, sz);
+ } else {
+ size_t sum = 0;
+ for (size_t i = 0; i < sz; ++i) {
+ sum += (_lhs_vector[i] == rhs_vector[i]) ? 0 : 1;
+ }
+ return (double)sum;
+ }
+ }
+ double convert_threshold(double threshold) const override {
+ return threshold;
+ }
+ double to_rawscore(double distance) const override {
+ double score = 1.0 / (1.0 + distance);
+ return score;
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override {
+ // consider optimizing:
+ return calc(rhs);
+ }
+};
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+HammingDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) {
+ using DFT = BoundHammingDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+HammingDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) {
+ using DFT = BoundHammingDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template class HammingDistanceFunctionFactory<Int8Float>;
+template class HammingDistanceFunctionFactory<float>;
+template class HammingDistanceFunctionFactory<double>;
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/hamming_distance.h b/searchlib/src/vespa/searchlib/tensor/hamming_distance.h
index c64fc5b532d..23c855eb137 100644
--- a/searchlib/src/vespa/searchlib/tensor/hamming_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/hamming_distance.h
@@ -3,6 +3,7 @@
#pragma once
#include "distance_function.h"
+#include "distance_function_factory.h"
#include <vespa/eval/eval/typed_cells.h>
#include <vespa/vespalib/util/typify.h>
#include <cmath>
@@ -29,4 +30,14 @@ public:
double calc_with_limit(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs, double) const override;
};
+template <typename FloatType>
+class HammingDistanceFunctionFactory : public DistanceFunctionFactory {
+public:
+ HammingDistanceFunctionFactory()
+ : DistanceFunctionFactory(vespalib::eval::get_cell_type<FloatType>())
+ {}
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override;
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override;
+};
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/inner_product_distance.h b/searchlib/src/vespa/searchlib/tensor/inner_product_distance.h
index 8ba14580885..135bb186fd4 100644
--- a/searchlib/src/vespa/searchlib/tensor/inner_product_distance.h
+++ b/searchlib/src/vespa/searchlib/tensor/inner_product_distance.h
@@ -54,7 +54,7 @@ public:
auto rhs_vector = rhs.typify<FloatType>();
size_t sz = lhs_vector.size();
assert(sz == rhs_vector.size());
- double score = 1.0 - _computer.dotProduct(&lhs_vector[0], &rhs_vector[0], sz);
+ double score = 1.0 - _computer.dotProduct(lhs_vector.data(), rhs_vector.data(), sz);
return std::max(0.0, score);
}
private:
diff --git a/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp
new file mode 100644
index 00000000000..292edc1259d
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.cpp
@@ -0,0 +1,81 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "prenormalized_angular_distance.h"
+#include "temporary_vector_store.h"
+
+using vespalib::typify_invoke;
+using vespalib::eval::TypifyCellType;
+
+namespace search::tensor {
+
+template<typename FloatType>
+class BoundPrenormalizedAngularDistance : public BoundDistanceFunction {
+private:
+ const vespalib::hwaccelrated::IAccelrated & _computer;
+ mutable TemporaryVectorStore<FloatType> _tmpSpace;
+ const vespalib::ConstArrayRef<FloatType> _lhs;
+ double _lhs_norm_sq;
+public:
+ BoundPrenormalizedAngularDistance(const vespalib::eval::TypedCells& lhs)
+ : _computer(vespalib::hwaccelrated::IAccelrated::getAccelerator()),
+ _tmpSpace(lhs.size),
+ _lhs(_tmpSpace.storeLhs(lhs))
+ {
+ auto a = _lhs.data();
+ _lhs_norm_sq = _computer.dotProduct(a, a, lhs.size);
+ if (_lhs_norm_sq <= 0.0) {
+ _lhs_norm_sq = 1.0;
+ }
+ }
+ double calc(const vespalib::eval::TypedCells& rhs) const override {
+ size_t sz = _lhs.size();
+ vespalib::ConstArrayRef<FloatType> rhs_vector = _tmpSpace.convertRhs(rhs);
+ assert(sz == rhs_vector.size());
+ auto a = _lhs.data();
+ auto b = rhs_vector.data();
+ double dot_product = _computer.dotProduct(a, b, sz);
+ double distance = _lhs_norm_sq - dot_product;
+ return distance;
+ }
+ double convert_threshold(double threshold) const override {
+ double cosine_similarity = 1.0 - threshold;
+ double dot_product = cosine_similarity * _lhs_norm_sq;
+ double distance = _lhs_norm_sq - dot_product;
+ return distance;
+ }
+ double to_rawscore(double distance) const override {
+ double dot_product = _lhs_norm_sq - distance;
+ double cosine_similarity = dot_product / _lhs_norm_sq;
+ // should be in in range [-1,1] but roundoff may cause problems:
+ cosine_similarity = std::min(1.0, cosine_similarity);
+ cosine_similarity = std::max(-1.0, cosine_similarity);
+ double cosine_distance = 1.0 - cosine_similarity; // in range [0,2]
+ double score = 1.0 / (1.0 + cosine_distance);
+ return score;
+ }
+ double calc_with_limit(const vespalib::eval::TypedCells& rhs, double) const override {
+ return calc(rhs);
+ }
+};
+
+template class BoundPrenormalizedAngularDistance<float>;
+template class BoundPrenormalizedAngularDistance<double>;
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+PrenormalizedAngularDistanceFunctionFactory<FloatType>::for_query_vector(const vespalib::eval::TypedCells& lhs) {
+ using DFT = BoundPrenormalizedAngularDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template <typename FloatType>
+BoundDistanceFunction::UP
+PrenormalizedAngularDistanceFunctionFactory<FloatType>::for_insertion_vector(const vespalib::eval::TypedCells& lhs) {
+ using DFT = BoundPrenormalizedAngularDistance<FloatType>;
+ return std::make_unique<DFT>(lhs);
+}
+
+template class PrenormalizedAngularDistanceFunctionFactory<float>;
+template class PrenormalizedAngularDistanceFunctionFactory<double>;
+
+}
diff --git a/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.h b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.h
new file mode 100644
index 00000000000..88953a236e7
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/tensor/prenormalized_angular_distance.h
@@ -0,0 +1,27 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "distance_function.h"
+#include "bound_distance_function.h"
+#include "distance_function_factory.h"
+#include <vespa/eval/eval/typed_cells.h>
+#include <vespa/vespalib/hwaccelrated/iaccelrated.h>
+
+namespace search::tensor {
+
+/**
+ * Calculates inner-product "distance" between vectors with assumed norm 1.
+ * Should give same ordering as Angular distance, but is less expensive.
+ */
+template <typename FloatType>
+class PrenormalizedAngularDistanceFunctionFactory : public DistanceFunctionFactory {
+public:
+ PrenormalizedAngularDistanceFunctionFactory()
+ : DistanceFunctionFactory(vespalib::eval::get_cell_type<FloatType>())
+ {}
+ BoundDistanceFunction::UP for_query_vector(const vespalib::eval::TypedCells& lhs) override;
+ BoundDistanceFunction::UP for_insertion_vector(const vespalib::eval::TypedCells& lhs) override;
+};
+
+}
diff --git a/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp b/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp
index 43c77398be8..b64d477fd4c 100644
--- a/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp
+++ b/streamingvisitors/src/tests/nearest_neighbor_field_searcher/nearest_neighbor_field_searcher_test.cpp
@@ -31,9 +31,11 @@ struct MockQuery {
std::vector<std::unique_ptr<NearestNeighborQueryNode>> nodes;
QueryTermList term_list;
MockQuery& add(const vespalib::string& query_tensor_name,
+ uint32_t target_hits,
double distance_threshold) {
std::unique_ptr<QueryNodeResultBase> base;
- auto node = std::make_unique<NearestNeighborQueryNode>(std::move(base), query_tensor_name, "my_tensor_field", 7, search::query::Weight(11), distance_threshold);
+ auto node = std::make_unique<NearestNeighborQueryNode>(std::move(base), query_tensor_name, "my_tensor_field",
+ target_hits, distance_threshold, 7, search::query::Weight(100));
nodes.push_back(std::move(node));
term_list.push_back(nodes.back().get());
return *this;
@@ -90,34 +92,71 @@ public:
query.reset();
searcher.onValue(fv);
}
+ void expect_match(const vespalib::string& spec_expr, double exp_square_distance, const NearestNeighborQueryNode& node) {
+ match(spec_expr);
+ expect_match(exp_square_distance, node);
+ }
void expect_match(double exp_square_distance, const NearestNeighborQueryNode& node) {
double exp_raw_score = dist_func.to_rawscore(exp_square_distance);
EXPECT_TRUE(node.evaluate());
+ EXPECT_DOUBLE_EQ(exp_square_distance, node.get_distance().value());
EXPECT_DOUBLE_EQ(exp_raw_score, node.get_raw_score().value());
}
+ void expect_not_match(const vespalib::string& spec_expr, const NearestNeighborQueryNode& node) {
+ match(spec_expr);
+ EXPECT_FALSE(node.evaluate());
+ }
};
-TEST_F(NearestNeighborSearcherTest, raw_score_calculated_with_distance_threshold)
+TEST_F(NearestNeighborSearcherTest, distance_heap_keeps_the_best_target_hits)
{
- query.add("qt1", 3);
+ query.add("qt1", 2, 100.0);
+ const auto& node = query.get(0);
set_query_tensor("qt1", "tensor(x[2]):[1,3]");
prepare();
- match("tensor(x[2]):[1,5]");
- expect_match((5-3)*(5-3), query.get(0));
+ expect_match("tensor(x[2]):[1,7]", (7-3)*(7-3), node);
+ expect_match("tensor(x[2]):[1,9]", (9-3)*(9-3), node);
- match("tensor(x[2]):[1,6]");
- expect_match((6-3)*(6-3), query.get(0));
+ // The distance limit is now (9-3)*(9-3) = 36, so this is not good enough.
+ expect_not_match("tensor(x[2]):[1,10]", node);
+
+ expect_match("tensor(x[2]):[1,5]", (5-3)*(5-3), node);
+
+ // The distance limit is now (7-3)*(7-3) = 16, so this is not good enough.
+ expect_not_match("tensor(x[2]):[1,8]", node);
+
+ // This is not considered a document match as get_raw_score() is not called,
+ // and the distance heap is not updated.
+ match("tensor(x[2]):[1,4]");
+ EXPECT_EQ(1, node.get_distance().value());
+ EXPECT_TRUE(node.evaluate());
+
+ // The distance limit is still (7-3)*(7-3) = 16, so this is in fact good enough.
+ expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node);
+
+ // The distance limit is (6-3)*(6-3) = 4, and a similar distance is a match.
+ expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node);
+}
+
+TEST_F(NearestNeighborSearcherTest, raw_score_calculated_with_distance_threshold)
+{
+ query.add("qt1", 10, 3.0);
+ const auto& node = query.get(0);
+ set_query_tensor("qt1", "tensor(x[2]):[1,3]");
+ prepare();
+
+ expect_match("tensor(x[2]):[1,5]", (5-3)*(5-3), node);
+ expect_match("tensor(x[2]):[1,6]", (6-3)*(6-3), node);
- match("tensor(x[2]):[1,7]");
// This is not a match since ((7-3)*(7-3) = 16) is larger than the internal distance threshold of (3*3 = 9).
- EXPECT_FALSE(query.get(0).evaluate());
+ expect_not_match("tensor(x[2]):[1,7]", node);
}
TEST_F(NearestNeighborSearcherTest, raw_score_calculated_for_two_query_operators)
{
- query.add("qt1", 3);
- query.add("qt2", 4);
+ query.add("qt1", 10, 3.0);
+ query.add("qt2", 10, 4.0);
set_query_tensor("qt1", "tensor(x[2]):[1,3]");
set_query_tensor("qt2", "tensor(x[2]):[1,4]");
prepare();
diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
index 9f3f3d770e4..4d425d9dedd 100644
--- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
+++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp
@@ -55,6 +55,10 @@ RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder)
_query_wrapper = std::make_unique<QueryWrapper>(*_query);
}
+class MockRawScoreCalculator : public search::streaming::NearestNeighborQueryNode::RawScoreCalculator {
+public:
+ double to_raw_score(double distance) override { return distance * 2; }
+};
TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node)
{
@@ -71,6 +75,8 @@ TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node)
EXPECT_EQ(1u, term_list.size());
auto node = dynamic_cast<NearestNeighborQueryNode*>(term_list.front().getTerm());
EXPECT_NE(nullptr, node);
+ MockRawScoreCalculator calc;
+ node->set_raw_score_calc(&calc);
auto& qtd = static_cast<QueryTermData &>(node->getQueryItem());
auto& td = qtd.getTermData();
constexpr TermFieldHandle handle = 27;
@@ -82,11 +88,11 @@ TEST_F(RankProcessorTest, unpack_match_data_for_nearest_neighbor_query_node)
EXPECT_EQ(invalid_id, tfmd->getDocId());
RankProcessor::unpack_match_data(1, *md, *_query_wrapper);
EXPECT_EQ(invalid_id, tfmd->getDocId());
- constexpr double raw_score = 1.5;
- node->set_raw_score(raw_score);
+ constexpr double distance = 1.5;
+ node->set_distance(distance);
RankProcessor::unpack_match_data(2, *md, *_query_wrapper);
EXPECT_EQ(2, tfmd->getDocId());
- EXPECT_EQ(raw_score, tfmd->getRawScore());
+ EXPECT_EQ(distance * 2, tfmd->getRawScore());
node->reset();
RankProcessor::unpack_match_data(3, *md, *_query_wrapper);
EXPECT_EQ(2, tfmd->getDocId());
diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
index 01b21edc1ba..3ce137bffe5 100644
--- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
+++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp
@@ -241,7 +241,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap
for (QueryWrapper::Term & term: query.getTermList()) {
auto nn_node = term.getTerm()->as_nearest_neighbor_query_node();
if (nn_node != nullptr) {
- auto& raw_score = nn_node->get_raw_score();
+ auto raw_score = nn_node->get_raw_score();
if (raw_score.has_value()) {
auto& qtd = static_cast<QueryTermData &>(term.getTerm()->getQueryItem());
auto& td = qtd.getTermData();
diff --git a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
index f064760e55d..db4ee12438e 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
+++ b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp
@@ -48,8 +48,17 @@ NearestNeighborFieldSearcher::NodeAndCalc::NodeAndCalc(search::streaming::Neares
std::unique_ptr<search::tensor::DistanceCalculator> calc_in)
: node(node_in),
calc(std::move(calc_in)),
- distance_threshold(calc->function().convert_threshold(node->get_distance_threshold()))
+ heap(node->get_target_hits())
{
+ node->set_raw_score_calc(this);
+ heap.set_distance_threshold(calc->function().convert_threshold(node->get_distance_threshold()));
+}
+
+double
+NearestNeighborFieldSearcher::NodeAndCalc::to_raw_score(double distance)
+{
+ heap.used(distance);
+ return calc->function().to_rawscore(distance);
}
NearestNeighborFieldSearcher::NearestNeighborFieldSearcher(FieldIdT fid,
@@ -100,7 +109,7 @@ NearestNeighborFieldSearcher::prepare(search::streaming::QueryTermList& qtl,
}
try {
auto calc = DistanceCalculator::make_with_validation(*_attr, *tensor_value);
- _calcs.emplace_back(nn_term, std::move(calc));
+ _calcs.push_back(std::make_unique<NodeAndCalc>(nn_term, std::move(calc)));
} catch (const vespalib::IllegalArgumentException& ex) {
vespalib::Issue::report("Could not create DistanceCalculator for NearestNeighborQueryNode(%s, %s): %s",
nn_term->index().c_str(), nn_term->get_query_tensor_name().c_str(), ex.what());
@@ -116,10 +125,10 @@ NearestNeighborFieldSearcher::onValue(const document::FieldValue& fv)
if (tfv && tfv->getAsTensorPtr()) {
_attr->add(*tfv->getAsTensorPtr(), 1);
for (auto& elem : _calcs) {
- double distance = elem.calc->calc_with_limit(scratch_docid, elem.distance_threshold);
- if (distance <= elem.distance_threshold) {
- double score = elem.calc->function().to_rawscore(distance);
- elem.node->set_raw_score(score);
+ double distance_limit = elem->heap.distanceLimit();
+ double distance = elem->calc->calc_with_limit(scratch_docid, distance_limit);
+ if (distance <= distance_limit) {
+ elem->node->set_distance(distance);
}
}
}
diff --git a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h
index ba39b91c677..d5d751cd637 100644
--- a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h
+++ b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.h
@@ -5,6 +5,8 @@
#include "fieldsearcher.h"
#include <vespa/eval/eval/value_type.h>
#include <vespa/searchcommon/attribute/distance_metric.h>
+#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h>
+#include <vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h>
#include <vespa/searchlib/tensor/distance_calculator.h>
#include <vespa/searchlib/tensor/tensor_ext_attribute.h>
@@ -14,8 +16,6 @@ namespace search::tensor {
class TensorExtAttribute;
}
-namespace search::streaming { class NearestNeighborQueryNode; }
-
namespace vsm {
/**
@@ -26,16 +26,19 @@ namespace vsm {
*/
class NearestNeighborFieldSearcher : public FieldSearcher {
private:
- struct NodeAndCalc {
+ class NodeAndCalc : search::streaming::NearestNeighborQueryNode::RawScoreCalculator {
+ public:
search::streaming::NearestNeighborQueryNode* node;
std::unique_ptr<search::tensor::DistanceCalculator> calc;
- double distance_threshold;
+ search::queryeval::NearestNeighborDistanceHeap heap;
NodeAndCalc(search::streaming::NearestNeighborQueryNode* node_in,
std::unique_ptr<search::tensor::DistanceCalculator> calc_in);
+
+ double to_raw_score(double distance) override;
};
search::attribute::DistanceMetric _metric;
std::unique_ptr<search::tensor::TensorExtAttribute> _attr;
- std::vector<NodeAndCalc> _calcs;
+ std::vector<std::unique_ptr<NodeAndCalc>> _calcs;
public:
NearestNeighborFieldSearcher(FieldIdT fid,
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/DefaultSignedIdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/DefaultSignedIdentityDocument.java
new file mode 100644
index 00000000000..c2ab22f4921
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/DefaultSignedIdentityDocument.java
@@ -0,0 +1,14 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.identityprovider.api;
+
+public record DefaultSignedIdentityDocument(String signature, int signingKeyVersion, int documentVersion,
+ String data, IdentityDocument identityDocument) implements SignedIdentityDocument {
+
+ public DefaultSignedIdentityDocument {
+ identityDocument = EntityBindingsMapper.fromIdentityDocumentData(data);
+ }
+
+ public DefaultSignedIdentityDocument(String signature, int signingKeyVersion, int documentVersion, String data) {
+ this(signature,signingKeyVersion,documentVersion, data, null);
+ }
+}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java
index 2d77d2ceda1..a695e10a29c 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapper.java
@@ -6,8 +6,12 @@ import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.datatype.jsr310.JavaTimeModule;
import com.yahoo.vespa.athenz.api.AthenzIdentity;
import com.yahoo.vespa.athenz.api.AthenzService;
+import com.yahoo.vespa.athenz.identityprovider.api.bindings.DefaultSignedIdentityDocumentEntity;
+import com.yahoo.vespa.athenz.identityprovider.api.bindings.IdentityDocumentEntity;
+import com.yahoo.vespa.athenz.identityprovider.api.bindings.LegacySignedIdentityDocumentEntity;
import com.yahoo.vespa.athenz.identityprovider.api.bindings.SignedIdentityDocumentEntity;
import com.yahoo.vespa.athenz.utils.AthenzIdentities;
+import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.io.InputStream;
@@ -16,6 +20,8 @@ import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
+import java.time.Instant;
+import java.util.Base64;
import java.util.Optional;
import static com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId.fromDottedString;
@@ -24,6 +30,7 @@ import static com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId.
* Utility class for mapping objects model types and their Jackson binding versions.
*
* @author bjorncs
+ * @author mortent
*/
public class EntityBindingsMapper {
@@ -48,39 +55,60 @@ public class EntityBindingsMapper {
}
public static SignedIdentityDocument toSignedIdentityDocument(SignedIdentityDocumentEntity entity) {
- return new SignedIdentityDocument(
- entity.signature(),
- entity.signingKeyVersion(),
- fromDottedString(entity.providerUniqueId()),
- new AthenzService(entity.providerService()),
- entity.documentVersion(),
- entity.configServerHostname(),
- entity.instanceHostname(),
- entity.createdAt(),
- entity.ipAddresses(),
- IdentityType.fromId(entity.identityType()),
- Optional.ofNullable(entity.clusterType()).map(ClusterType::from).orElse(null),
- entity.ztsUrl(),
- Optional.ofNullable(entity.serviceIdentity()).map(AthenzIdentities::from).orElse(null),
- entity.unknownAttributes());
+ if (entity instanceof LegacySignedIdentityDocumentEntity docEntity) {
+ IdentityDocument doc = new IdentityDocument(
+ fromDottedString(docEntity.providerUniqueId()),
+ new AthenzService(docEntity.providerService()),
+ docEntity.configServerHostname(),
+ docEntity.instanceHostname(),
+ docEntity.createdAt(),
+ docEntity.ipAddresses(),
+ IdentityType.fromId(docEntity.identityType()),
+ Optional.ofNullable(docEntity.clusterType()).map(ClusterType::from).orElse(null),
+ docEntity.ztsUrl(),
+ Optional.ofNullable(docEntity.serviceIdentity()).map(AthenzIdentities::from).orElse(null),
+ docEntity.unknownAttributes());
+ return new LegacySignedIdentityDocument(
+ docEntity.signature(),
+ docEntity.signingKeyVersion(),
+ entity.documentVersion(),
+ doc);
+ } else if (entity instanceof DefaultSignedIdentityDocumentEntity docEntity) {
+ return new DefaultSignedIdentityDocument(docEntity.signature(),
+ docEntity.signingKeyVersion(),
+ docEntity.documentVersion(),
+ docEntity.data());
+ } else {
+ throw new IllegalArgumentException("Unknown signed identity document type: " + entity.getClass().getName());
+ }
}
public static SignedIdentityDocumentEntity toSignedIdentityDocumentEntity(SignedIdentityDocument model) {
- return new SignedIdentityDocumentEntity(
- model.signature(),
- model.signingKeyVersion(),
- model.providerUniqueId().asDottedString(),
- model.providerService().getFullName(),
- model.documentVersion(),
- model.configServerHostname(),
- model.instanceHostname(),
- model.createdAt(),
- model.ipAddresses(),
- model.identityType().id(),
- Optional.ofNullable(model.clusterType()).map(ClusterType::toConfigValue).orElse(null),
- model.ztsUrl(),
- Optional.ofNullable(model.serviceIdentity()).map(AthenzIdentity::getFullName).orElse(null),
- model.unknownAttributes());
+ if (model instanceof LegacySignedIdentityDocument legacyModel) {
+ IdentityDocument idDoc = legacyModel.identityDocument();
+ return new LegacySignedIdentityDocumentEntity(
+ legacyModel.signature(),
+ legacyModel.signingKeyVersion(),
+ idDoc.providerUniqueId().asDottedString(),
+ idDoc.providerService().getFullName(),
+ legacyModel.documentVersion(),
+ idDoc.configServerHostname(),
+ idDoc.instanceHostname(),
+ idDoc.createdAt(),
+ idDoc.ipAddresses(),
+ idDoc.identityType().id(),
+ Optional.ofNullable(idDoc.clusterType()).map(ClusterType::toConfigValue).orElse(null),
+ idDoc.ztsUrl(),
+ Optional.ofNullable(idDoc.serviceIdentity()).map(AthenzIdentity::getFullName).orElse(null),
+ idDoc.unknownAttributes());
+ } else if (model instanceof DefaultSignedIdentityDocument defaultModel){
+ return new DefaultSignedIdentityDocumentEntity(defaultModel.signature(),
+ defaultModel.signingKeyVersion(),
+ defaultModel.documentVersion(),
+ defaultModel.data());
+ } else {
+ throw new IllegalArgumentException("Unsupported model type: " + model.getClass().getName());
+ }
}
public static SignedIdentityDocument readSignedIdentityDocumentFromFile(Path file) {
@@ -104,4 +132,40 @@ public class EntityBindingsMapper {
}
}
+ public static IdentityDocument fromIdentityDocumentData(String data) {
+ byte[] decoded = Base64.getDecoder().decode(data);
+ IdentityDocumentEntity docEntity = Exceptions.uncheck(() -> mapper.readValue(decoded, IdentityDocumentEntity.class));
+ return new IdentityDocument(
+ fromDottedString(docEntity.providerUniqueId()),
+ new AthenzService(docEntity.providerService()),
+ docEntity.configServerHostname(),
+ docEntity.instanceHostname(),
+ docEntity.createdAt(),
+ docEntity.ipAddresses(),
+ IdentityType.fromId(docEntity.identityType()),
+ Optional.ofNullable(docEntity.clusterType()).map(ClusterType::from).orElse(null),
+ docEntity.ztsUrl(),
+ Optional.ofNullable(docEntity.serviceIdentity()).map(AthenzIdentities::from).orElse(null),
+ docEntity.unknownAttributes());
+ }
+
+ public static String toIdentityDocmentData(IdentityDocument identityDocument) {
+ IdentityDocumentEntity documentEntity = new IdentityDocumentEntity(
+ identityDocument.providerUniqueId().asDottedString(),
+ identityDocument.providerService().getFullName(),
+ identityDocument.configServerHostname(),
+ identityDocument.instanceHostname(),
+ identityDocument.createdAt(),
+ identityDocument.ipAddresses(),
+ identityDocument.identityType().id(),
+ Optional.ofNullable(identityDocument.clusterType()).map(ClusterType::toConfigValue).orElse(null),
+ identityDocument.ztsUrl(),
+ identityDocument.serviceIdentity().getFullName());
+ try {
+ byte[] bytes = mapper.writeValueAsBytes(documentEntity);
+ return Base64.getEncoder().encodeToString(bytes);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException("Error during serialization of identity document.", e);
+ }
+ }
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java
new file mode 100644
index 00000000000..577584db185
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocument.java
@@ -0,0 +1,54 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.identityprovider.api;
+
+import com.yahoo.vespa.athenz.api.AthenzIdentity;
+import com.yahoo.vespa.athenz.api.AthenzService;
+
+import java.time.Instant;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Represents an unsigned identity document
+ * @author mortent
+ */
+public record IdentityDocument(VespaUniqueInstanceId providerUniqueId, AthenzService providerService, String configServerHostname,
+ String instanceHostname, Instant createdAt, Set<String> ipAddresses,
+ IdentityType identityType, ClusterType clusterType, String ztsUrl,
+ AthenzIdentity serviceIdentity, Map<String, Object> unknownAttributes) {
+
+ public IdentityDocument {
+ ipAddresses = Set.copyOf(ipAddresses);
+
+ Map<String, Object> nonNull = new HashMap<>();
+ unknownAttributes.forEach((key, value) -> {
+ if (value != null) nonNull.put(key, value);
+ });
+ // Map.copyOf() does not allow null values
+ unknownAttributes = Map.copyOf(nonNull);
+ }
+
+ public IdentityDocument(VespaUniqueInstanceId providerUniqueId, AthenzService providerService, String configServerHostname,
+ String instanceHostname, Instant createdAt, Set<String> ipAddresses,
+ IdentityType identityType, ClusterType clusterType, String ztsUrl,
+ AthenzIdentity serviceIdentity) {
+ this(providerUniqueId, providerService, configServerHostname, instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity, Map.of());
+ }
+
+
+ public IdentityDocument withServiceIdentity(AthenzService athenzService) {
+ return new IdentityDocument(
+ this.providerUniqueId,
+ this.providerService,
+ this.configServerHostname,
+ this.instanceHostname,
+ this.createdAt,
+ this.ipAddresses,
+ this.identityType,
+ this.clusterType,
+ this.ztsUrl,
+ athenzService,
+ this.unknownAttributes);
+ }
+}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocumentClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocumentClient.java
index 5a0f77ec765..a3c2f0264d3 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocumentClient.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/IdentityDocumentClient.java
@@ -1,12 +1,15 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.athenz.identityprovider.api;
+import java.util.Optional;
+import java.util.OptionalInt;
+
/**
* A client that communicates that fetches an identity document.
*
* @author bjorncs
*/
public interface IdentityDocumentClient {
- SignedIdentityDocument getNodeIdentityDocument(String host);
- SignedIdentityDocument getTenantIdentityDocument(String host);
+ SignedIdentityDocument getNodeIdentityDocument(String host, int documentVersion);
+ Optional<SignedIdentityDocument> getTenantIdentityDocument(String host, int documentVersion);
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/LegacySignedIdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/LegacySignedIdentityDocument.java
new file mode 100644
index 00000000000..220bc72a017
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/LegacySignedIdentityDocument.java
@@ -0,0 +1,6 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.identityprovider.api;
+
+public record LegacySignedIdentityDocument(String signature, int signingKeyVersion, int documentVersion,
+ IdentityDocument identityDocument) implements SignedIdentityDocument {
+}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java
index de78d81cd1b..4e3bd8dee91 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/SignedIdentityDocument.java
@@ -1,54 +1,20 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.athenz.identityprovider.api;
-import com.yahoo.vespa.athenz.api.AthenzIdentity;
-import com.yahoo.vespa.athenz.api.AthenzService;
-
-import java.net.URL;
-import java.time.Instant;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
/**
* A signed identity document.
- * The {@link #unknownAttributes()} member provides forward compatibility and ensures any new/unknown fields are kept intact when serialized to JSON.
- *
* @author bjorncs
+ * @author mortent
*/
-public record SignedIdentityDocument(String signature, int signingKeyVersion, VespaUniqueInstanceId providerUniqueId,
- AthenzService providerService, int documentVersion, String configServerHostname,
- String instanceHostname, Instant createdAt, Set<String> ipAddresses,
- IdentityType identityType, ClusterType clusterType, String ztsUrl,
- AthenzIdentity serviceIdentity, Map<String, Object> unknownAttributes) {
-
- public SignedIdentityDocument {
- ipAddresses = Set.copyOf(ipAddresses);
-
- Map<String, Object> nonNull = new HashMap<>();
- unknownAttributes.forEach((key, value) -> {
- if (value != null) nonNull.put(key, value);
- });
- // Map.copyOf() does not allow null values
- unknownAttributes = Map.copyOf(nonNull);
- }
-
- public SignedIdentityDocument(String signature, int signingKeyVersion, VespaUniqueInstanceId providerUniqueId,
- AthenzService providerService, int documentVersion, String configServerHostname,
- String instanceHostname, Instant createdAt, Set<String> ipAddresses,
- IdentityType identityType, ClusterType clusterType, String ztsUrl, AthenzIdentity serviceIdentity) {
- this(signature, signingKeyVersion, providerUniqueId, providerService, documentVersion, configServerHostname,
- instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity, Map.of());
- }
-
- public static final int DEFAULT_DOCUMENT_VERSION = 3;
-
- public boolean outdated() { return documentVersion < DEFAULT_DOCUMENT_VERSION; }
+public interface SignedIdentityDocument {
- public SignedIdentityDocument withServiceIdentity(AthenzIdentity identity) {
- return new SignedIdentityDocument(signature, signingKeyVersion, providerUniqueId, providerService, documentVersion, configServerHostname, instanceHostname, createdAt,
- ipAddresses, identityType, clusterType, ztsUrl, identity);
- }
+ int LEGACY_DEFAULT_DOCUMENT_VERSION = 3;
+ int DEFAULT_DOCUMENT_VERSION = 4;
+ default boolean outdated() { return documentVersion() < LEGACY_DEFAULT_DOCUMENT_VERSION; }
+ IdentityDocument identityDocument();
+ String signature();
+ int signingKeyVersion();
+ int documentVersion();
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/DefaultSignedIdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/DefaultSignedIdentityDocumentEntity.java
new file mode 100644
index 00000000000..3aaff011415
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/DefaultSignedIdentityDocumentEntity.java
@@ -0,0 +1,12 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.identityprovider.api.bindings;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+public record DefaultSignedIdentityDocumentEntity(
+ @JsonProperty("signature") String signature,
+ @JsonProperty("signing-key-version") int signingKeyVersion,
+ @JsonProperty("document-version") int documentVersion,
+ @JsonProperty("data") String data)
+ implements SignedIdentityDocumentEntity {
+}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java
new file mode 100644
index 00000000000..946eacc67eb
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/IdentityDocumentEntity.java
@@ -0,0 +1,51 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.identityprovider.api.bindings;
+
+import com.fasterxml.jackson.annotation.JsonAnyGetter;
+import com.fasterxml.jackson.annotation.JsonAnySetter;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.time.Instant;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * @author bjorncs
+ * @author mortent
+ */
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public record IdentityDocumentEntity(String providerUniqueId, String providerService,
+ String configServerHostname, String instanceHostname, Instant createdAt, Set<String> ipAddresses,
+ String identityType, String clusterType, String ztsUrl, String serviceIdentity, Map<String, Object> unknownAttributes) {
+
+ @JsonCreator
+ public IdentityDocumentEntity(@JsonProperty("provider-unique-id") String providerUniqueId,
+ @JsonProperty("provider-service") String providerService,
+ @JsonProperty("configserver-hostname") String configServerHostname,
+ @JsonProperty("instance-hostname") String instanceHostname,
+ @JsonProperty("created-at") Instant createdAt,
+ @JsonProperty("ip-addresses") Set<String> ipAddresses,
+ @JsonProperty("identity-type") String identityType,
+ @JsonProperty("cluster-type") String clusterType,
+ @JsonProperty("zts-url") String ztsUrl,
+ @JsonProperty("service-identity") String serviceIdentity) {
+ this(providerUniqueId, providerService, configServerHostname,
+ instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity, new HashMap<>());
+ }
+
+ @JsonProperty("provider-unique-id") @Override public String providerUniqueId() { return providerUniqueId; }
+ @JsonProperty("provider-service") @Override public String providerService() { return providerService; }
+ @JsonProperty("configserver-hostname") @Override public String configServerHostname() { return configServerHostname; }
+ @JsonProperty("instance-hostname") @Override public String instanceHostname() { return instanceHostname; }
+ @JsonProperty("created-at") @Override public Instant createdAt() { return createdAt; }
+ @JsonProperty("ip-addresses") @Override public Set<String> ipAddresses() { return ipAddresses; }
+ @JsonProperty("identity-type") @Override public String identityType() { return identityType; }
+ @JsonProperty("cluster-type") @Override public String clusterType() { return clusterType; }
+ @JsonProperty("zts-url") @Override public String ztsUrl() { return ztsUrl; }
+ @JsonProperty("service-identity") @Override public String serviceIdentity() { return serviceIdentity; }
+ @JsonAnyGetter @Override public Map<String, Object> unknownAttributes() { return unknownAttributes; }
+ @JsonAnySetter public void set(String name, Object value) { unknownAttributes.put(name, value); }
+}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/LegacySignedIdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/LegacySignedIdentityDocumentEntity.java
new file mode 100644
index 00000000000..e00ab9978f6
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/LegacySignedIdentityDocumentEntity.java
@@ -0,0 +1,57 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.identityprovider.api.bindings;
+
+import com.fasterxml.jackson.annotation.JsonAnyGetter;
+import com.fasterxml.jackson.annotation.JsonAnySetter;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.time.Instant;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * @author bjorncs
+ */
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public record LegacySignedIdentityDocumentEntity (
+ String signature, int signingKeyVersion, String providerUniqueId, String providerService, int documentVersion,
+ String configServerHostname, String instanceHostname, Instant createdAt, Set<String> ipAddresses,
+ String identityType, String clusterType, String ztsUrl, String serviceIdentity, Map<String, Object> unknownAttributes) implements SignedIdentityDocumentEntity {
+
+ @JsonCreator
+ public LegacySignedIdentityDocumentEntity(@JsonProperty("signature") String signature,
+ @JsonProperty("signing-key-version") int signingKeyVersion,
+ @JsonProperty("provider-unique-id") String providerUniqueId,
+ @JsonProperty("provider-service") String providerService,
+ @JsonProperty("document-version") int documentVersion,
+ @JsonProperty("configserver-hostname") String configServerHostname,
+ @JsonProperty("instance-hostname") String instanceHostname,
+ @JsonProperty("created-at") Instant createdAt,
+ @JsonProperty("ip-addresses") Set<String> ipAddresses,
+ @JsonProperty("identity-type") String identityType,
+ @JsonProperty("cluster-type") String clusterType,
+ @JsonProperty("zts-url") String ztsUrl,
+ @JsonProperty("service-identity") String serviceIdentity) {
+ this(signature, signingKeyVersion, providerUniqueId, providerService, documentVersion, configServerHostname,
+ instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity, new HashMap<>());
+ }
+
+ @JsonProperty("signature") @Override public String signature() { return signature; }
+ @JsonProperty("signing-key-version") @Override public int signingKeyVersion() { return signingKeyVersion; }
+ @JsonProperty("provider-unique-id") @Override public String providerUniqueId() { return providerUniqueId; }
+ @JsonProperty("provider-service") @Override public String providerService() { return providerService; }
+ @JsonProperty("document-version") @Override public int documentVersion() { return documentVersion; }
+ @JsonProperty("configserver-hostname") @Override public String configServerHostname() { return configServerHostname; }
+ @JsonProperty("instance-hostname") @Override public String instanceHostname() { return instanceHostname; }
+ @JsonProperty("created-at") @Override public Instant createdAt() { return createdAt; }
+ @JsonProperty("ip-addresses") @Override public Set<String> ipAddresses() { return ipAddresses; }
+ @JsonProperty("identity-type") @Override public String identityType() { return identityType; }
+ @JsonProperty("cluster-type") @Override public String clusterType() { return clusterType; }
+ @JsonProperty("zts-url") @Override public String ztsUrl() { return ztsUrl; }
+ @JsonProperty("service-identity") @Override public String serviceIdentity() { return serviceIdentity; }
+ @JsonAnyGetter @Override public Map<String, Object> unknownAttributes() { return unknownAttributes; }
+ @JsonAnySetter public void set(String name, Object value) { unknownAttributes.put(name, value); }
+}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java
index fc0dff3b97b..174c76f7fa9 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/api/bindings/SignedIdentityDocumentEntity.java
@@ -1,57 +1,77 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.athenz.identityprovider.api.bindings;
-import com.fasterxml.jackson.annotation.JsonAnyGetter;
-import com.fasterxml.jackson.annotation.JsonAnySetter;
-import com.fasterxml.jackson.annotation.JsonCreator;
-import com.fasterxml.jackson.annotation.JsonInclude;
-import com.fasterxml.jackson.annotation.JsonProperty;
-
-import java.time.Instant;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * @author bjorncs
- */
-@JsonInclude(JsonInclude.Include.NON_NULL)
-public record SignedIdentityDocumentEntity(
- String signature, int signingKeyVersion, String providerUniqueId, String providerService, int documentVersion,
- String configServerHostname, String instanceHostname, Instant createdAt, Set<String> ipAddresses,
- String identityType, String clusterType, String ztsUrl, String serviceIdentity, Map<String, Object> unknownAttributes) {
-
- @JsonCreator
- public SignedIdentityDocumentEntity(@JsonProperty("signature") String signature,
- @JsonProperty("signing-key-version") int signingKeyVersion,
- @JsonProperty("provider-unique-id") String providerUniqueId,
- @JsonProperty("provider-service") String providerService,
- @JsonProperty("document-version") int documentVersion,
- @JsonProperty("configserver-hostname") String configServerHostname,
- @JsonProperty("instance-hostname") String instanceHostname,
- @JsonProperty("created-at") Instant createdAt,
- @JsonProperty("ip-addresses") Set<String> ipAddresses,
- @JsonProperty("identity-type") String identityType,
- @JsonProperty("cluster-type") String clusterType,
- @JsonProperty("zts-url") String ztsUrl,
- @JsonProperty("service-identity") String serviceIdentity) {
- this(signature, signingKeyVersion, providerUniqueId, providerService, documentVersion, configServerHostname,
- instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity, new HashMap<>());
- }
- @JsonProperty("signature") @Override public String signature() { return signature; }
- @JsonProperty("signing-key-version") @Override public int signingKeyVersion() { return signingKeyVersion; }
- @JsonProperty("provider-unique-id") @Override public String providerUniqueId() { return providerUniqueId; }
- @JsonProperty("provider-service") @Override public String providerService() { return providerService; }
- @JsonProperty("document-version") @Override public int documentVersion() { return documentVersion; }
- @JsonProperty("configserver-hostname") @Override public String configServerHostname() { return configServerHostname; }
- @JsonProperty("instance-hostname") @Override public String instanceHostname() { return instanceHostname; }
- @JsonProperty("created-at") @Override public Instant createdAt() { return createdAt; }
- @JsonProperty("ip-addresses") @Override public Set<String> ipAddresses() { return ipAddresses; }
- @JsonProperty("identity-type") @Override public String identityType() { return identityType; }
- @JsonProperty("cluster-type") @Override public String clusterType() { return clusterType; }
- @JsonProperty("zts-url") @Override public String ztsUrl() { return ztsUrl; }
- @JsonProperty("service-identity") @Override public String serviceIdentity() { return serviceIdentity; }
- @JsonAnyGetter @Override public Map<String, Object> unknownAttributes() { return unknownAttributes; }
- @JsonAnySetter public void set(String name, Object value) { unknownAttributes.put(name, value); }
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import com.fasterxml.jackson.databind.DatabindContext;
+import com.fasterxml.jackson.databind.JavaType;
+import com.fasterxml.jackson.databind.annotation.JsonTypeIdResolver;
+import com.fasterxml.jackson.databind.jsontype.TypeIdResolver;
+import com.fasterxml.jackson.databind.type.TypeFactory;
+import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument;
+
+import java.io.IOException;
+import java.util.Objects;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+@JsonTypeInfo(use = JsonTypeInfo.Id.CUSTOM, property = "document-version", visible = true)
+@JsonTypeIdResolver(SignedIdentityDocumentEntityTypeResolver.class)
+public interface SignedIdentityDocumentEntity {
+ int documentVersion();
}
+
+class SignedIdentityDocumentEntityTypeResolver implements TypeIdResolver {
+ JavaType javaType;
+
+ @Override
+ public void init(JavaType javaType) {
+ this.javaType = javaType;
+ }
+
+ @Override
+ public String idFromValue(Object o) {
+ return idFromValueAndType(o, o.getClass());
+ }
+
+ @Override
+ public String idFromValueAndType(Object o, Class<?> aClass) {
+ if (Objects.isNull(o)) {
+ throw new IllegalArgumentException("Cannot serialize null oject");
+ } else {
+ if (o instanceof SignedIdentityDocumentEntity s) {
+ return Integer.toString(s.documentVersion());
+ } else {
+ throw new IllegalArgumentException("Cannot serialize class: " + o.getClass());
+ }
+ }
+ }
+
+ @Override
+ public String idFromBaseType() {
+ return idFromValueAndType(null, javaType.getRawClass());
+ }
+
+ @Override
+ public JavaType typeFromId(DatabindContext databindContext, String s) throws IOException {
+ try {
+ int version = Integer.parseInt(s);
+ Class<? extends SignedIdentityDocumentEntity> cls = version <= SignedIdentityDocument.LEGACY_DEFAULT_DOCUMENT_VERSION
+ ? LegacySignedIdentityDocumentEntity.class
+ : DefaultSignedIdentityDocumentEntity.class;
+ return TypeFactory.defaultInstance().constructSpecializedType(javaType,cls);
+ } catch (NumberFormatException e) {
+ throw new IllegalArgumentException("Unable to deserialize document with version: \"%s\"".formatted(s));
+ }
+ }
+
+ @Override
+ public String getDescForKnownTypeIds() {
+ return "Type resolver for SignedIdentityDocumentEntity";
+ }
+
+ @Override
+ public JsonTypeInfo.Id getMechanism() {
+ return JsonTypeInfo.Id.CUSTOM;
+ }
+} \ No newline at end of file
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java
index cc9d3b2be65..d26386702d5 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzCredentialsService.java
@@ -11,6 +11,7 @@ import com.yahoo.vespa.athenz.client.zts.InstanceIdentity;
import com.yahoo.vespa.athenz.client.zts.ZtsClient;
import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider;
import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper;
+import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocumentClient;
import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument;
import com.yahoo.vespa.athenz.tls.AthenzIdentityVerifier;
@@ -74,7 +75,9 @@ class AthenzCredentialsService {
}
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA);
IdentityDocumentClient identityDocumentClient = createIdentityDocumentClient();
- SignedIdentityDocument document = identityDocumentClient.getTenantIdentityDocument(hostname);
+ // Use legacy version for now.
+ SignedIdentityDocument signedDocument = identityDocumentClient.getTenantIdentityDocument(hostname, SignedIdentityDocument.LEGACY_DEFAULT_DOCUMENT_VERSION).orElseThrow();
+ IdentityDocument document = signedDocument.identityDocument();
Pkcs10Csr csr = csrGenerator.generateInstanceCsr(
tenantIdentity,
document.providerUniqueId(),
@@ -87,16 +90,17 @@ class AthenzCredentialsService {
ztsClient.registerInstance(
configserverIdentity,
tenantIdentity,
- EntityBindingsMapper.toAttestationData(document),
+ EntityBindingsMapper.toAttestationData(signedDocument),
csr);
X509Certificate certificate = instanceIdentity.certificate();
- writeCredentialsToDisk(keyPair.getPrivate(), certificate, document);
- return new AthenzCredentials(certificate, keyPair, document);
+ writeCredentialsToDisk(keyPair.getPrivate(), certificate, signedDocument);
+ return new AthenzCredentials(certificate, keyPair, signedDocument);
}
}
- AthenzCredentials updateCredentials(SignedIdentityDocument document, SSLContext sslContext) {
+ AthenzCredentials updateCredentials(SignedIdentityDocument signedDocument, SSLContext sslContext) {
KeyPair newKeyPair = KeyUtils.generateKeypair(KeyAlgorithm.RSA);
+ IdentityDocument document = signedDocument.identityDocument();
Pkcs10Csr csr = csrGenerator.generateInstanceCsr(
tenantIdentity,
document.providerUniqueId(),
@@ -112,8 +116,8 @@ class AthenzCredentialsService {
document.providerUniqueId().asDottedString(),
csr);
X509Certificate certificate = instanceIdentity.certificate();
- writeCredentialsToDisk(newKeyPair.getPrivate(), certificate, document);
- return new AthenzCredentials(certificate, newKeyPair, document);
+ writeCredentialsToDisk(newKeyPair.getPrivate(), certificate, signedDocument);
+ return new AthenzCredentials(certificate, newKeyPair, signedDocument);
}
}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java
index b9f9f3862c2..77aaf17419d 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImpl.java
@@ -11,7 +11,9 @@ import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider;
import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException;
import com.yahoo.jdisc.Metric;
import com.yahoo.metrics.ContainerMetrics;
+import com.yahoo.security.AutoReloadingX509KeyManager;
import com.yahoo.security.KeyStoreBuilder;
+import com.yahoo.security.KeyUtils;
import com.yahoo.security.MutableX509KeyManager;
import com.yahoo.security.Pkcs10Csr;
import com.yahoo.security.SslContextBuilder;
@@ -24,9 +26,9 @@ import com.yahoo.vespa.athenz.api.ZToken;
import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient;
import com.yahoo.vespa.athenz.client.zts.ZtsClient;
import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider;
-import com.yahoo.vespa.athenz.identity.SiaIdentityProvider;
+import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId;
+import com.yahoo.vespa.athenz.tls.AthenzX509CertificateUtils;
import com.yahoo.vespa.athenz.utils.SiaUtils;
-import com.yahoo.vespa.defaults.Defaults;
import javax.net.ssl.SSLContext;
import javax.net.ssl.X509ExtendedKeyManager;
@@ -38,7 +40,6 @@ import java.security.cert.X509Certificate;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
-import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -65,7 +66,6 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
// TODO Make some of these values configurable through config. Match requested expiration of register/update requests.
// TODO These should match the requested expiration
- static final Duration UPDATE_PERIOD = Duration.ofDays(1);
static final Duration AWAIT_TERMINTATION_TIMEOUT = Duration.ofSeconds(90);
private final static Duration ROLE_SSL_CONTEXT_EXPIRY = Duration.ofHours(2);
// TODO CMS expects 10min or less token ttl. Use 10min default until we have configurable expiry
@@ -73,20 +73,17 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
// TODO Make path to trust store paths config
private static final Path CLIENT_TRUST_STORE = Paths.get("/opt/yahoo/share/ssl/certs/yahoo_certificate_bundle.pem");
- private static final Path ATHENZ_TRUST_STORE = Paths.get("/opt/yahoo/share/ssl/certs/athenz_certificate_bundle.pem");
public static final String CERTIFICATE_EXPIRY_METRIC_NAME = ContainerMetrics.ATHENZ_TENANT_CERT_EXPIRY_SECONDS.baseName();
- private volatile AthenzCredentials credentials;
private final Metric metric;
private final Path trustStore;
- private final AthenzCredentialsService athenzCredentialsService;
private final ScheduledExecutorService scheduler;
private final Clock clock;
private final AthenzService identity;
private final URI ztsEndpoint;
- private final MutableX509KeyManager identityKeyManager = new MutableX509KeyManager();
+ private final AutoReloadingX509KeyManager autoReloadingX509KeyManager;
private final SSLContext identitySslContext;
private final LoadingCache<AthenzRole, X509Certificate> roleSslCertCache;
private final Map<AthenzRole, MutableX509KeyManager> roleKeyManagerCache;
@@ -98,40 +95,32 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
@Inject
public AthenzIdentityProviderImpl(IdentityConfig config, Metric metric) {
- this(config,
- metric,
- CLIENT_TRUST_STORE,
- new AthenzCredentialsService(config,
- createNodeIdentityProvider(config),
- Defaults.getDefaults().vespaHostname(),
- Clock.systemUTC()),
- new ScheduledThreadPoolExecutor(1),
- Clock.systemUTC());
+ this(config, metric, CLIENT_TRUST_STORE, new ScheduledThreadPoolExecutor(1), Clock.systemUTC(), createAutoReloadingX509KeyManager(config));
}
// Test only
AthenzIdentityProviderImpl(IdentityConfig config,
Metric metric,
Path trustStore,
- AthenzCredentialsService athenzCredentialsService,
ScheduledExecutorService scheduler,
- Clock clock) {
+ Clock clock,
+ AutoReloadingX509KeyManager autoReloadingX509KeyManager) {
this.metric = metric;
this.trustStore = trustStore;
- this.athenzCredentialsService = athenzCredentialsService;
this.scheduler = scheduler;
this.clock = clock;
this.identity = new AthenzService(config.domain(), config.service());
this.ztsEndpoint = URI.create(config.ztsUrl());
- roleSslCertCache = crateAutoReloadableCache(ROLE_SSL_CONTEXT_EXPIRY, this::requestRoleCertificate, this.scheduler);
- roleKeyManagerCache = new HashMap<>();
- roleSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken);
- domainSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken);
- domainSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken);
- roleSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken);
+ this.roleSslCertCache = crateAutoReloadableCache(ROLE_SSL_CONTEXT_EXPIRY, this::requestRoleCertificate, this.scheduler);
+ this.roleKeyManagerCache = new HashMap<>();
+ this.roleSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken);
+ this.domainSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken);
+ this.domainSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken);
+ this.roleSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken);
this.csrGenerator = new CsrGenerator(config.athenzDnsSuffix(), config.configserverIdentityName());
- this.identitySslContext = createIdentitySslContext(identityKeyManager, trustStore);
- registerInstance();
+ this.autoReloadingX509KeyManager = autoReloadingX509KeyManager;
+ this.identitySslContext = createIdentitySslContext(autoReloadingX509KeyManager, trustStore);
+ this.scheduler.scheduleAtFixedRate(this::reportMetrics, 0, 5, TimeUnit.MINUTES);
}
private static <KEY, VALUE> LoadingCache<KEY, VALUE> createCache(Duration expiry, Function<KEY, VALUE> cacheLoader) {
@@ -165,16 +154,6 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
.build();
}
- private void registerInstance() {
- try {
- updateIdentityCredentials(this.athenzCredentialsService.registerInstance());
- this.scheduler.scheduleAtFixedRate(this::refreshCertificate, UPDATE_PERIOD.toMinutes(), UPDATE_PERIOD.toMinutes(), TimeUnit.MINUTES);
- this.scheduler.scheduleAtFixedRate(this::reportMetrics, 0, 5, TimeUnit.MINUTES);
- } catch (Throwable t) {
- throw new AthenzIdentityProviderException("Could not retrieve Athenz credentials", t);
- }
- }
-
@Override
public AthenzService identity() {
return identity;
@@ -197,13 +176,13 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
@Override
public X509CertificateWithKey getIdentityCertificateWithKey() {
- AthenzCredentials copy = this.credentials;
- return new X509CertificateWithKey(copy.getCertificate(), copy.getKeyPair().getPrivate());
+ var copy = this.autoReloadingX509KeyManager.getCurrentCertificateWithKey();
+ return new X509CertificateWithKey(copy.certificate(), copy.privateKey());
}
- @Override public Path certificatePath() { return athenzCredentialsService.certificatePath(); }
+ @Override public Path certificatePath() { return SiaUtils.getCertificateFile(identity); }
- @Override public Path privateKeyPath() { return athenzCredentialsService.privateKeyPath(); }
+ @Override public Path privateKeyPath() { return SiaUtils.getPrivateKeyFile(identity); }
@Override
public SSLContext getRoleSslContext(String domain, String role) {
@@ -262,7 +241,7 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
@Override
public PrivateKey getPrivateKey() {
- return credentials.getKeyPair().getPrivate();
+ return autoReloadingX509KeyManager.getPrivateKey(AutoReloadingX509KeyManager.CERTIFICATE_ALIAS);
}
@Override
@@ -272,7 +251,7 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
@Override
public List<X509Certificate> getIdentityCertificate() {
- return Collections.singletonList(credentials.getCertificate());
+ return List.of(autoReloadingX509KeyManager.getCertificateChain(AutoReloadingX509KeyManager.CERTIFICATE_ALIAS));
}
@Override
@@ -288,19 +267,15 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
}
}
- private void updateIdentityCredentials(AthenzCredentials credentials) {
- this.credentials = credentials;
- this.identityKeyManager.updateKeystore(
- KeyStoreBuilder.withType(PKCS12)
- .withKeyEntry("default", credentials.getKeyPair().getPrivate(), credentials.getCertificate())
- .build(),
- new char[0]);
- }
-
private X509Certificate requestRoleCertificate(AthenzRole role) {
- var doc = credentials.getIdentityDocument();
+ var credentials = autoReloadingX509KeyManager.getCurrentCertificateWithKey();
+ var athenzUniqueInstanceId = VespaUniqueInstanceId.fromDottedString(
+ AthenzX509CertificateUtils.getInstanceId(credentials.certificate())
+ .orElseThrow()
+ );
+ var keyPair = KeyUtils.toKeyPair(credentials.privateKey());
Pkcs10Csr csr = csrGenerator.generateRoleCsr(
- identity, role, doc.providerUniqueId(), doc.clusterType(), credentials.getKeyPair());
+ identity, role, athenzUniqueInstanceId, null, keyPair);
try (ZtsClient client = createZtsClient()) {
X509Certificate roleCertificate = client.getRoleCertificate(role, csr);
updateRoleKeyManager(role, roleCertificate);
@@ -313,7 +288,7 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
MutableX509KeyManager keyManager = roleKeyManagerCache.computeIfAbsent(role, r -> new MutableX509KeyManager());
keyManager.updateKeystore(
KeyStoreBuilder.withType(PKCS12)
- .withKeyEntry("default", credentials.getKeyPair().getPrivate(), certificate)
+ .withKeyEntry("default", autoReloadingX509KeyManager.getCurrentCertificateWithKey().privateKey(), certificate)
.build(),
new char[0]);
}
@@ -346,6 +321,11 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
return new DefaultZtsClient.Builder(ztsEndpoint).withSslContext(getIdentitySslContext()).build();
}
+ private static AutoReloadingX509KeyManager createAutoReloadingX509KeyManager(IdentityConfig config) {
+ var tenantIdentity = new AthenzService(config.domain(), config.service());
+ return AutoReloadingX509KeyManager.fromPemFiles(SiaUtils.getPrivateKeyFile(tenantIdentity), SiaUtils.getCertificateFile(tenantIdentity));
+ }
+
@Override
public void deconstruct() {
try {
@@ -356,32 +336,13 @@ public final class AthenzIdentityProviderImpl extends AbstractComponent implemen
}
}
- private static SiaIdentityProvider createNodeIdentityProvider(IdentityConfig config) {
- return new SiaIdentityProvider(
- new AthenzService(config.nodeIdentityName()), SiaUtils.DEFAULT_SIA_DIRECTORY, CLIENT_TRUST_STORE);
- }
-
- private boolean isExpired(AthenzCredentials credentials) {
- return clock.instant().isAfter(getExpirationTime(credentials));
- }
-
- private static Instant getExpirationTime(AthenzCredentials credentials) {
- return credentials.getCertificate().getNotAfter().toInstant();
- }
-
- void refreshCertificate() {
- try {
- updateIdentityCredentials(isExpired(credentials)
- ? athenzCredentialsService.registerInstance()
- : athenzCredentialsService.updateCredentials(credentials.getIdentityDocument(), identitySslContext));
- } catch (Throwable t) {
- log.log(Level.WARNING, "Failed to update credentials: " + t.getMessage(), t);
- }
+ private static Instant getExpirationTime(X509Certificate certificate) {
+ return certificate.getNotAfter().toInstant();
}
void reportMetrics() {
try {
- Instant expirationTime = getExpirationTime(credentials);
+ Instant expirationTime = getExpirationTime(autoReloadingX509KeyManager.getCurrentCertificateWithKey().certificate());
Duration remainingLifetime = Duration.between(clock.instant(), expirationTime);
metric.set(CERTIFICATE_EXPIRY_METRIC_NAME, remainingLifetime.getSeconds(), null);
} catch (Throwable t) {
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderProvider.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderProvider.java
new file mode 100644
index 00000000000..66dad931815
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderProvider.java
@@ -0,0 +1,38 @@
+package com.yahoo.vespa.athenz.identityprovider.client;
+
+import com.yahoo.container.core.identity.IdentityConfig;
+import com.yahoo.container.di.componentgraph.Provider;
+import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider;
+import com.yahoo.jdisc.Metric;
+
+import javax.inject.Inject;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+
+/**
+ * @author olaa
+ */
+public class AthenzIdentityProviderProvider implements Provider<AthenzIdentityProvider> {
+
+ private final Path NODE_ADMIN_MANAGED_IDENTITY_DOCUMENT = Paths.get("/var/lib/sia/vespa-tenant-identity-document.json");
+ private final AthenzIdentityProvider athenzIdentityProvider;
+
+ @Inject
+ public AthenzIdentityProviderProvider(IdentityConfig config, Metric metric) {
+ if (Files.exists(NODE_ADMIN_MANAGED_IDENTITY_DOCUMENT))
+ athenzIdentityProvider = new AthenzIdentityProviderImpl(config, metric);
+ else
+ athenzIdentityProvider = new LegacyAthenzIdentityProviderImpl(config, metric);
+ }
+
+ @Override
+ public void deconstruct() {
+ athenzIdentityProvider.deconstruct();
+ }
+
+ @Override
+ public AthenzIdentityProvider get() {
+ return athenzIdentityProvider;
+ }
+}
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java
index 5b884e3dfb3..f95a3335c24 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/DefaultIdentityDocumentClient.java
@@ -23,6 +23,7 @@ import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.time.Duration;
+import java.util.Optional;
import java.util.function.Supplier;
/**
@@ -56,16 +57,16 @@ public class DefaultIdentityDocumentClient implements IdentityDocumentClient {
}
@Override
- public SignedIdentityDocument getNodeIdentityDocument(String host) {
- return getIdentityDocument(host, "node");
+ public SignedIdentityDocument getNodeIdentityDocument(String host, int documentVersion) {
+ return getIdentityDocument(host, "node", documentVersion).orElseThrow();
}
@Override
- public SignedIdentityDocument getTenantIdentityDocument(String host) {
- return getIdentityDocument(host, "tenant");
+ public Optional<SignedIdentityDocument> getTenantIdentityDocument(String host, int documentVersion) {
+ return getIdentityDocument(host, "tenant", documentVersion);
}
- private SignedIdentityDocument getIdentityDocument(String host, String type) {
+ private Optional<SignedIdentityDocument> getIdentityDocument(String host, String type, int documentVersion) {
try (CloseableHttpClient client = createHttpClient(sslContextSupplier.get(), hostnameVerifier)) {
URI uri = configserverUri
@@ -76,13 +77,16 @@ public class DefaultIdentityDocumentClient implements IdentityDocumentClient {
.setUri(uri)
.addHeader("Connection", "close")
.addHeader("Accept", "application/json")
+ .addParameter("documentVersion", Integer.toString(documentVersion))
.build();
try (CloseableHttpResponse response = client.execute(request)) {
String responseContent = EntityUtils.toString(response.getEntity());
int statusCode = response.getStatusLine().getStatusCode();
if (statusCode >= 200 && statusCode <= 299) {
SignedIdentityDocumentEntity entity = objectMapper.readValue(responseContent, SignedIdentityDocumentEntity.class);
- return EntityBindingsMapper.toSignedIdentityDocument(entity);
+ return Optional.of(EntityBindingsMapper.toSignedIdentityDocument(entity));
+ } else if (statusCode == 404) {
+ return Optional.empty();
} else {
throw new RuntimeException(
String.format(
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSigner.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSigner.java
index 019f73fc6bf..11b30585933 100644
--- a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSigner.java
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSigner.java
@@ -4,7 +4,10 @@ package com.yahoo.vespa.athenz.identityprovider.client;
import com.yahoo.security.SignatureUtils;
import com.yahoo.vespa.athenz.api.AthenzIdentity;
import com.yahoo.vespa.athenz.api.AthenzService;
+import com.yahoo.vespa.athenz.identityprovider.api.DefaultSignedIdentityDocument;
+import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.IdentityType;
+import com.yahoo.vespa.athenz.identityprovider.api.LegacySignedIdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId;
@@ -19,7 +22,7 @@ import java.util.Base64;
import java.util.Set;
import java.util.TreeSet;
-import static com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument.DEFAULT_DOCUMENT_VERSION;
+import static com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument.LEGACY_DEFAULT_DOCUMENT_VERSION;
import static java.nio.charset.StandardCharsets.UTF_8;
/**
@@ -29,8 +32,25 @@ import static java.nio.charset.StandardCharsets.UTF_8;
*/
public class IdentityDocumentSigner {
+ public String generateSignature(String identityDocumentData, PrivateKey privateKey) {
+ try {
+ Signature signer = SignatureUtils.createSigner(privateKey);
+ signer.initSign(privateKey);
+ signer.update(identityDocumentData.getBytes(UTF_8));
+ byte[] signature = signer.sign();
+ return Base64.getEncoder().encodeToString(signature);
+ } catch (GeneralSecurityException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public String generateLegacySignature(IdentityDocument doc, PrivateKey privateKey) {
+ return generateSignature(doc.providerUniqueId(), doc.providerService(), doc.configServerHostname(),
+ doc.instanceHostname(), doc.createdAt(), doc.ipAddresses(), doc.identityType(), privateKey, doc.serviceIdentity());
+ }
+
// Cluster type is ignored due to old Vespa versions not forwarding unknown fields in signed identity document
- public String generateSignature(VespaUniqueInstanceId providerUniqueId,
+ private String generateSignature(VespaUniqueInstanceId providerUniqueId,
AthenzService providerService,
String configServerHostname,
String instanceHostname,
@@ -54,14 +74,32 @@ public class IdentityDocumentSigner {
}
public boolean hasValidSignature(SignedIdentityDocument doc, PublicKey publicKey) {
+ if (doc instanceof LegacySignedIdentityDocument signedDoc) {
+ return validateLegacySignature(signedDoc, publicKey);
+ } else if (doc instanceof DefaultSignedIdentityDocument signedDoc) {
+ try {
+ Signature signer = SignatureUtils.createVerifier(publicKey);
+ signer.initVerify(publicKey);
+ signer.update(signedDoc.data().getBytes(UTF_8));
+ return signer.verify(Base64.getDecoder().decode(doc.signature()));
+ } catch (GeneralSecurityException e) {
+ throw new RuntimeException(e);
+ }
+ } else {
+ throw new IllegalArgumentException("Unknown identity document type: " + doc.getClass().getName());
+ }
+ }
+
+ private boolean validateLegacySignature(SignedIdentityDocument doc, PublicKey publicKey) {
try {
+ IdentityDocument iddoc = doc.identityDocument();
Signature signer = SignatureUtils.createVerifier(publicKey);
signer.initVerify(publicKey);
writeToSigner(
- signer, doc.providerUniqueId(), doc.providerService(), doc.configServerHostname(),
- doc.instanceHostname(), doc.createdAt(), doc.ipAddresses(), doc.identityType());
- if (doc.documentVersion() >= DEFAULT_DOCUMENT_VERSION) {
- writeToSigner(signer, doc.serviceIdentity());
+ signer, iddoc.providerUniqueId(), iddoc.providerService(), iddoc.configServerHostname(),
+ iddoc.instanceHostname(), iddoc.createdAt(), iddoc.ipAddresses(), iddoc.identityType());
+ if (doc.documentVersion() >= LEGACY_DEFAULT_DOCUMENT_VERSION) {
+ writeToSigner(signer, iddoc.serviceIdentity());
}
return signer.verify(Base64.getDecoder().decode(doc.signature()));
} catch (GeneralSecurityException e) {
diff --git a/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImpl.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImpl.java
new file mode 100644
index 00000000000..d699564a4ee
--- /dev/null
+++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImpl.java
@@ -0,0 +1,392 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.identityprovider.client;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.common.cache.LoadingCache;
+import com.yahoo.component.AbstractComponent;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.container.core.identity.IdentityConfig;
+import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider;
+import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException;
+import com.yahoo.jdisc.Metric;
+import com.yahoo.metrics.ContainerMetrics;
+import com.yahoo.security.KeyStoreBuilder;
+import com.yahoo.security.MutableX509KeyManager;
+import com.yahoo.security.Pkcs10Csr;
+import com.yahoo.security.SslContextBuilder;
+import com.yahoo.security.X509CertificateWithKey;
+import com.yahoo.vespa.athenz.api.AthenzAccessToken;
+import com.yahoo.vespa.athenz.api.AthenzDomain;
+import com.yahoo.vespa.athenz.api.AthenzRole;
+import com.yahoo.vespa.athenz.api.AthenzService;
+import com.yahoo.vespa.athenz.api.ZToken;
+import com.yahoo.vespa.athenz.client.zts.DefaultZtsClient;
+import com.yahoo.vespa.athenz.client.zts.ZtsClient;
+import com.yahoo.vespa.athenz.identity.ServiceIdentityProvider;
+import com.yahoo.vespa.athenz.identity.SiaIdentityProvider;
+import com.yahoo.vespa.athenz.utils.SiaUtils;
+import com.yahoo.vespa.defaults.Defaults;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.X509ExtendedKeyManager;
+import java.net.URI;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.security.PrivateKey;
+import java.security.cert.X509Certificate;
+import java.time.Clock;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+
+import static com.yahoo.security.KeyStoreType.PKCS12;
+
+/**
+ * A {@link AthenzIdentityProvider} / {@link ServiceIdentityProvider} component that provides the tenant identity.
+ *
+ * @author mortent
+ * @author bjorncs
+ */
+// This class should probably not implement ServiceIdentityProvider,
+// as that interface is intended for providing the node's identity, not the tenant's application identity.
+public final class LegacyAthenzIdentityProviderImpl extends AbstractComponent implements AthenzIdentityProvider, ServiceIdentityProvider {
+
+ private static final Logger log = Logger.getLogger(LegacyAthenzIdentityProviderImpl.class.getName());
+
+ // TODO Make some of these values configurable through config. Match requested expiration of register/update requests.
+ // TODO These should match the requested expiration
+ static final Duration UPDATE_PERIOD = Duration.ofDays(1);
+ static final Duration AWAIT_TERMINTATION_TIMEOUT = Duration.ofSeconds(90);
+ private final static Duration ROLE_SSL_CONTEXT_EXPIRY = Duration.ofHours(2);
+ // TODO CMS expects 10min or less token ttl. Use 10min default until we have configurable expiry
+ private final static Duration ROLE_TOKEN_EXPIRY = Duration.ofMinutes(10);
+
+ // TODO Make path to trust store paths config
+ private static final Path CLIENT_TRUST_STORE = Paths.get("/opt/yahoo/share/ssl/certs/yahoo_certificate_bundle.pem");
+ private static final Path ATHENZ_TRUST_STORE = Paths.get("/opt/yahoo/share/ssl/certs/athenz_certificate_bundle.pem");
+
+ public static final String CERTIFICATE_EXPIRY_METRIC_NAME = ContainerMetrics.ATHENZ_TENANT_CERT_EXPIRY_SECONDS.baseName();
+
+ private volatile AthenzCredentials credentials;
+ private final Metric metric;
+ private final Path trustStore;
+ private final AthenzCredentialsService athenzCredentialsService;
+ private final ScheduledExecutorService scheduler;
+ private final Clock clock;
+ private final AthenzService identity;
+ private final URI ztsEndpoint;
+
+ private final MutableX509KeyManager identityKeyManager = new MutableX509KeyManager();
+ private final SSLContext identitySslContext;
+ private final LoadingCache<AthenzRole, X509Certificate> roleSslCertCache;
+ private final Map<AthenzRole, MutableX509KeyManager> roleKeyManagerCache;
+ private final LoadingCache<AthenzRole, ZToken> roleSpecificRoleTokenCache;
+ private final LoadingCache<AthenzDomain, ZToken> domainSpecificRoleTokenCache;
+ private final LoadingCache<AthenzDomain, AthenzAccessToken> domainSpecificAccessTokenCache;
+ private final LoadingCache<List<AthenzRole>, AthenzAccessToken> roleSpecificAccessTokenCache;
+ private final CsrGenerator csrGenerator;
+
+ @Inject
+ public LegacyAthenzIdentityProviderImpl(IdentityConfig config, Metric metric) {
+ this(config,
+ metric,
+ CLIENT_TRUST_STORE,
+ new AthenzCredentialsService(config,
+ createNodeIdentityProvider(config),
+ Defaults.getDefaults().vespaHostname(),
+ Clock.systemUTC()),
+ new ScheduledThreadPoolExecutor(1),
+ Clock.systemUTC());
+ }
+
+ // Test only
+ LegacyAthenzIdentityProviderImpl(IdentityConfig config,
+ Metric metric,
+ Path trustStore,
+ AthenzCredentialsService athenzCredentialsService,
+ ScheduledExecutorService scheduler,
+ Clock clock) {
+ this.metric = metric;
+ this.trustStore = trustStore;
+ this.athenzCredentialsService = athenzCredentialsService;
+ this.scheduler = scheduler;
+ this.clock = clock;
+ this.identity = new AthenzService(config.domain(), config.service());
+ this.ztsEndpoint = URI.create(config.ztsUrl());
+ roleSslCertCache = crateAutoReloadableCache(ROLE_SSL_CONTEXT_EXPIRY, this::requestRoleCertificate, this.scheduler);
+ roleKeyManagerCache = new HashMap<>();
+ roleSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken);
+ domainSpecificRoleTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createRoleToken);
+ domainSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken);
+ roleSpecificAccessTokenCache = createCache(ROLE_TOKEN_EXPIRY, this::createAccessToken);
+ this.csrGenerator = new CsrGenerator(config.athenzDnsSuffix(), config.configserverIdentityName());
+ this.identitySslContext = createIdentitySslContext(identityKeyManager, trustStore);
+ registerInstance();
+ }
+
+ private static <KEY, VALUE> LoadingCache<KEY, VALUE> createCache(Duration expiry, Function<KEY, VALUE> cacheLoader) {
+ return CacheBuilder.newBuilder()
+ .refreshAfterWrite(expiry.dividedBy(2).toMinutes(), TimeUnit.MINUTES)
+ .expireAfterWrite(expiry.toMinutes(), TimeUnit.MINUTES)
+ .build(new CacheLoader<KEY, VALUE>() {
+ @Override
+ public VALUE load(KEY key) {
+ return cacheLoader.apply(key);
+ }
+ });
+ }
+
+ private static <KEY, VALUE> LoadingCache<KEY, VALUE> crateAutoReloadableCache(Duration expiry, Function<KEY, VALUE> cacheLoader, ScheduledExecutorService scheduler) {
+ LoadingCache<KEY, VALUE> cache = createCache(expiry, cacheLoader);
+
+ // The cache above will reload it's contents if and only if a request for the key is made. Scheduling
+ // a cache reloader to reload all keys in this cache.
+ scheduler.scheduleAtFixedRate(() -> { cache.asMap().keySet().forEach(cache::getUnchecked);},
+ expiry.dividedBy(4).toMinutes(),
+ expiry.dividedBy(4).toMinutes(),
+ TimeUnit.MINUTES);
+ return cache;
+ }
+
+ private static SSLContext createIdentitySslContext(X509ExtendedKeyManager keyManager, Path trustStore) {
+ return new SslContextBuilder()
+ .withKeyManager(keyManager)
+ .withTrustStore(trustStore)
+ .build();
+ }
+
+ private void registerInstance() {
+ try {
+ updateIdentityCredentials(this.athenzCredentialsService.registerInstance());
+ this.scheduler.scheduleAtFixedRate(this::refreshCertificate, UPDATE_PERIOD.toMinutes(), UPDATE_PERIOD.toMinutes(), TimeUnit.MINUTES);
+ this.scheduler.scheduleAtFixedRate(this::reportMetrics, 0, 5, TimeUnit.MINUTES);
+ } catch (Throwable t) {
+ throw new AthenzIdentityProviderException("Could not retrieve Athenz credentials", t);
+ }
+ }
+
+ @Override
+ public AthenzService identity() {
+ return identity;
+ }
+
+ @Override
+ public String domain() {
+ return identity.getDomain().getName();
+ }
+
+ @Override
+ public String service() {
+ return identity.getName();
+ }
+
+ @Override
+ public SSLContext getIdentitySslContext() {
+ return identitySslContext;
+ }
+
+ @Override
+ public X509CertificateWithKey getIdentityCertificateWithKey() {
+ AthenzCredentials copy = this.credentials;
+ return new X509CertificateWithKey(copy.getCertificate(), copy.getKeyPair().getPrivate());
+ }
+
+ @Override public Path certificatePath() { return athenzCredentialsService.certificatePath(); }
+
+ @Override public Path privateKeyPath() { return athenzCredentialsService.privateKeyPath(); }
+
+ @Override
+ public SSLContext getRoleSslContext(String domain, String role) {
+ try {
+ AthenzRole athenzRole = new AthenzRole(new AthenzDomain(domain), role);
+ // Make sure to request a certificate which triggers creating a new key manager for this role
+ X509Certificate x509Certificate = getRoleCertificate(athenzRole);
+ MutableX509KeyManager keyManager = roleKeyManagerCache.get(athenzRole);
+ return new SslContextBuilder()
+ .withKeyManager(keyManager)
+ .withTrustStore(trustStore)
+ .build();
+ } catch (Exception e) {
+ throw new AthenzIdentityProviderException("Could not retrieve role certificate: " + e.getMessage(), e);
+ }
+ }
+
+ @Override
+ public String getRoleToken(String domain) {
+ try {
+ return domainSpecificRoleTokenCache.get(new AthenzDomain(domain)).getRawToken();
+ } catch (Exception e) {
+ throw new AthenzIdentityProviderException("Could not retrieve role token: " + e.getMessage(), e);
+ }
+ }
+
+ @Override
+ public String getRoleToken(String domain, String role) {
+ try {
+ return roleSpecificRoleTokenCache.get(new AthenzRole(domain, role)).getRawToken();
+ } catch (Exception e) {
+ throw new AthenzIdentityProviderException("Could not retrieve role token: " + e.getMessage(), e);
+ }
+ }
+
+ @Override
+ public String getAccessToken(String domain) {
+ try {
+ return domainSpecificAccessTokenCache.get(new AthenzDomain(domain)).value();
+ } catch (Exception e) {
+ throw new AthenzIdentityProviderException("Could not retrieve access token: " + e.getMessage(), e);
+ }
+ }
+
+ @Override
+ public String getAccessToken(String domain, List<String> roles) {
+ try {
+ List<AthenzRole> roleList = roles.stream()
+ .map(roleName -> new AthenzRole(domain, roleName))
+ .toList();
+ return roleSpecificAccessTokenCache.get(roleList).value();
+ } catch (Exception e) {
+ throw new AthenzIdentityProviderException("Could not retrieve access token: " + e.getMessage(), e);
+ }
+ }
+
+ @Override
+ public PrivateKey getPrivateKey() {
+ return credentials.getKeyPair().getPrivate();
+ }
+
+ @Override
+ public Path trustStorePath() {
+ return trustStore;
+ }
+
+ @Override
+ public List<X509Certificate> getIdentityCertificate() {
+ return Collections.singletonList(credentials.getCertificate());
+ }
+
+ @Override
+ public X509Certificate getRoleCertificate(String domain, String role) {
+ return getRoleCertificate(new AthenzRole(new AthenzDomain(domain), role));
+ }
+
+ private X509Certificate getRoleCertificate(AthenzRole athenzRole) {
+ try {
+ return roleSslCertCache.get(athenzRole);
+ } catch (Exception e) {
+ throw new AthenzIdentityProviderException("Could not retrieve role certificate: " + e.getMessage(), e);
+ }
+ }
+
+ private void updateIdentityCredentials(AthenzCredentials credentials) {
+ this.credentials = credentials;
+ this.identityKeyManager.updateKeystore(
+ KeyStoreBuilder.withType(PKCS12)
+ .withKeyEntry("default", credentials.getKeyPair().getPrivate(), credentials.getCertificate())
+ .build(),
+ new char[0]);
+ }
+
+ private X509Certificate requestRoleCertificate(AthenzRole role) {
+ var doc = credentials.getIdentityDocument().identityDocument();
+ Pkcs10Csr csr = csrGenerator.generateRoleCsr(
+ identity, role, doc.providerUniqueId(), doc.clusterType(), credentials.getKeyPair());
+ try (ZtsClient client = createZtsClient()) {
+ X509Certificate roleCertificate = client.getRoleCertificate(role, csr);
+ updateRoleKeyManager(role, roleCertificate);
+ log.info(String.format("Requester role certificate for role %s, expires: %s", role.toResourceNameString(), roleCertificate.getNotAfter().toInstant().toString()));
+ return roleCertificate;
+ }
+ }
+
+ private void updateRoleKeyManager(AthenzRole role, X509Certificate certificate) {
+ MutableX509KeyManager keyManager = roleKeyManagerCache.computeIfAbsent(role, r -> new MutableX509KeyManager());
+ keyManager.updateKeystore(
+ KeyStoreBuilder.withType(PKCS12)
+ .withKeyEntry("default", credentials.getKeyPair().getPrivate(), certificate)
+ .build(),
+ new char[0]);
+ }
+
+ private ZToken createRoleToken(AthenzRole athenzRole) {
+ try (ZtsClient client = createZtsClient()) {
+ return client.getRoleToken(athenzRole, ROLE_TOKEN_EXPIRY);
+ }
+ }
+
+ private ZToken createRoleToken(AthenzDomain domain) {
+ try (ZtsClient client = createZtsClient()) {
+ return client.getRoleToken(domain, ROLE_TOKEN_EXPIRY);
+ }
+ }
+
+ private AthenzAccessToken createAccessToken(AthenzDomain domain) {
+ try (ZtsClient client = createZtsClient()) {
+ return client.getAccessToken(domain);
+ }
+ }
+
+ private AthenzAccessToken createAccessToken(List<AthenzRole> roles) {
+ try (ZtsClient client = createZtsClient()) {
+ return client.getAccessToken(roles);
+ }
+ }
+
+ private DefaultZtsClient createZtsClient() {
+ return new DefaultZtsClient.Builder(ztsEndpoint).withSslContext(getIdentitySslContext()).build();
+ }
+
+ @Override
+ public void deconstruct() {
+ try {
+ scheduler.shutdownNow();
+ scheduler.awaitTermination(AWAIT_TERMINTATION_TIMEOUT.getSeconds(), TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private static SiaIdentityProvider createNodeIdentityProvider(IdentityConfig config) {
+ return new SiaIdentityProvider(
+ new AthenzService(config.nodeIdentityName()), SiaUtils.DEFAULT_SIA_DIRECTORY, CLIENT_TRUST_STORE);
+ }
+
+ private boolean isExpired(AthenzCredentials credentials) {
+ return clock.instant().isAfter(getExpirationTime(credentials));
+ }
+
+ private static Instant getExpirationTime(AthenzCredentials credentials) {
+ return credentials.getCertificate().getNotAfter().toInstant();
+ }
+
+ void refreshCertificate() {
+ try {
+ updateIdentityCredentials(isExpired(credentials)
+ ? athenzCredentialsService.registerInstance()
+ : athenzCredentialsService.updateCredentials(credentials.getIdentityDocument(), identitySslContext));
+ } catch (Throwable t) {
+ log.log(Level.WARNING, "Failed to update credentials: " + t.getMessage(), t);
+ }
+ }
+
+ void reportMetrics() {
+ try {
+ Instant expirationTime = getExpirationTime(credentials);
+ Duration remainingLifetime = Duration.between(clock.instant(), expirationTime);
+ metric.set(CERTIFICATE_EXPIRY_METRIC_NAME, remainingLifetime.getSeconds(), null);
+ } catch (Throwable t) {
+ log.log(Level.WARNING, "Failed to update metrics: " + t.getMessage(), t);
+ }
+ }
+}
+
diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapperTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapperTest.java
index 2a68f6fd231..513fb4cdbd3 100644
--- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapperTest.java
+++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/api/EntityBindingsMapperTest.java
@@ -5,8 +5,11 @@ package com.yahoo.vespa.athenz.identityprovider.api;
import org.junit.jupiter.api.Test;
import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertTrue;
/**
@@ -15,7 +18,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
class EntityBindingsMapperTest {
@Test
- public void persists_unknown_json_members() throws IOException {
+ public void legacy_persists_unknown_json_members() throws IOException {
var originalJson =
"""
{
@@ -36,7 +39,8 @@ class EntityBindingsMapperTest {
}
""";
var entity = EntityBindingsMapper.fromString(originalJson);
- assertEquals(2, entity.unknownAttributes().size(), entity.unknownAttributes().toString());
+ assertInstanceOf(LegacySignedIdentityDocument.class, entity);
+ assertEquals(2, entity.identityDocument().unknownAttributes().size(), entity.identityDocument().unknownAttributes().toString());
var json = EntityBindingsMapper.toAttestationData(entity);
var expectedMemberInJson = "member-in-unknown-object";
@@ -45,4 +49,39 @@ class EntityBindingsMapperTest {
assertEquals(EntityBindingsMapper.mapper.readTree(originalJson), EntityBindingsMapper.mapper.readTree(json));
}
+ @Test
+ public void reads_unknown_json_members() throws IOException {
+ var iddoc = """
+ {
+ "provider-unique-id": "0.cluster.instance.app.tenant.us-west-1.test.node",
+ "provider-service": "domain.service",
+ "configserver-hostname": "cfg",
+ "instance-hostname": "host",
+ "created-at": 12345.0,
+ "ip-addresses": [],
+ "identity-type": "node",
+ "cluster-type": "admin",
+ "zts-url": "https://zts.url/",
+ "unknown-string": "string-value",
+ "unknown-object": { "member-in-unknown-object": 123 }
+ }
+ """;
+ var originalJson =
+ """
+ {
+ "signature": "sig",
+ "signing-key-version": 0,
+ "document-version": 4,
+ "data": "%s"
+ }
+ """.formatted(Base64.getEncoder().encodeToString(iddoc.getBytes(StandardCharsets.UTF_8)));
+ var entity = EntityBindingsMapper.fromString(originalJson);
+ assertEquals(2, entity.identityDocument().unknownAttributes().size(), entity.identityDocument().unknownAttributes().toString());
+ var json = EntityBindingsMapper.toAttestationData(entity);
+
+ // For the new iddoc format the identity document should be unchanged during serialization/deserialization,
+ // i.e the signed identity document should be unchanged
+ assertEquals(EntityBindingsMapper.mapper.readTree(originalJson), EntityBindingsMapper.mapper.readTree(json));
+ }
+
} \ No newline at end of file
diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java
index c9d2ea581bb..108da9e0136 100644
--- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java
+++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/AthenzIdentityProviderImplTest.java
@@ -2,8 +2,8 @@
package com.yahoo.vespa.athenz.identityprovider.client;
import com.yahoo.container.core.identity.IdentityConfig;
-import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException;
import com.yahoo.jdisc.Metric;
+import com.yahoo.security.AutoReloadingX509KeyManager;
import com.yahoo.security.KeyAlgorithm;
import com.yahoo.security.KeyStoreBuilder;
import com.yahoo.security.KeyStoreType;
@@ -13,13 +13,13 @@ import com.yahoo.security.Pkcs10Csr;
import com.yahoo.security.Pkcs10CsrBuilder;
import com.yahoo.security.SignatureAlgorithm;
import com.yahoo.security.X509CertificateBuilder;
+import com.yahoo.security.X509CertificateWithKey;
import com.yahoo.test.ManualClock;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
import javax.security.auth.x500.X500Principal;
-
import java.io.File;
import java.io.IOException;
import java.math.BigInteger;
@@ -33,17 +33,12 @@ import java.util.Date;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Supplier;
-import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
-/**
- * @author mortent
- * @author bjorncs
- */
public class AthenzIdentityProviderImplTest {
@TempDir
@@ -85,58 +80,25 @@ public class AthenzIdentityProviderImplTest {
}
@Test
- void component_creation_fails_when_credentials_not_found() {
- assertThrows(AthenzIdentityProviderException.class, () -> {
- AthenzCredentialsService credentialService = mock(AthenzCredentialsService.class);
- when(credentialService.registerInstance())
- .thenThrow(new RuntimeException("athenz unavailable"));
-
- new AthenzIdentityProviderImpl(IDENTITY_CONFIG, mock(Metric.class), trustStoreFile, credentialService, mock(ScheduledExecutorService.class), new ManualClock(Instant.EPOCH));
- });
- }
-
- @Test
- void metrics_updated_on_refresh() {
+ void certificate_expiry_metric_is_reported() {
ManualClock clock = new ManualClock(Instant.EPOCH);
Metric metric = mock(Metric.class);
-
- AthenzCredentialsService athenzCredentialsService = mock(AthenzCredentialsService.class);
-
+ AutoReloadingX509KeyManager keyManager = mock(AutoReloadingX509KeyManager.class);
KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC);
X509Certificate certificate = getCertificate(keyPair, getExpirationSupplier(clock));
+ when(keyManager.getCurrentCertificateWithKey()).thenReturn(new X509CertificateWithKey(certificate, keyPair.getPrivate()));
- when(athenzCredentialsService.registerInstance())
- .thenReturn(new AthenzCredentials(certificate, keyPair, null));
-
- when(athenzCredentialsService.updateCredentials(any(), any()))
- .thenThrow(new RuntimeException("#1"))
- .thenThrow(new RuntimeException("#2"))
- .thenReturn(new AthenzCredentials(certificate, keyPair, null));
-
- AthenzIdentityProviderImpl identityProvider =
- new AthenzIdentityProviderImpl(IDENTITY_CONFIG, metric, trustStoreFile, athenzCredentialsService, mock(ScheduledExecutorService.class), clock);
-
+ AthenzIdentityProviderImpl identityProvider = new AthenzIdentityProviderImpl(IDENTITY_CONFIG, metric, trustStoreFile, mock(ScheduledExecutorService.class), clock, keyManager);
identityProvider.reportMetrics();
verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.getSeconds()), any());
- // Advance 1 day, refresh fails, cert is 1 day old
clock.advance(Duration.ofDays(1));
- identityProvider.refreshCertificate();
identityProvider.reportMetrics();
verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.minus(Duration.ofDays(1)).getSeconds()), any());
- // Advance 1 more day, refresh fails, cert is 2 days old
clock.advance(Duration.ofDays(1));
- identityProvider.refreshCertificate();
identityProvider.reportMetrics();
verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.minus(Duration.ofDays(2)).getSeconds()), any());
-
- // Advance 1 more day, refresh succeds, cert is new
- clock.advance(Duration.ofDays(1));
- identityProvider.refreshCertificate();
- identityProvider.reportMetrics();
- verify(metric).set(eq(AthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.getSeconds()), any());
-
}
private Supplier<Date> getExpirationSupplier(ManualClock clock) {
diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java
index ff85cb79f02..acb0905700f 100644
--- a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java
+++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/IdentityDocumentSignerTest.java
@@ -6,10 +6,13 @@ import com.yahoo.security.KeyUtils;
import com.yahoo.vespa.athenz.api.AthenzIdentity;
import com.yahoo.vespa.athenz.api.AthenzService;
import com.yahoo.vespa.athenz.identityprovider.api.ClusterType;
+import com.yahoo.vespa.athenz.identityprovider.api.DefaultSignedIdentityDocument;
+import com.yahoo.vespa.athenz.identityprovider.api.EntityBindingsMapper;
+import com.yahoo.vespa.athenz.identityprovider.api.IdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.IdentityType;
+import com.yahoo.vespa.athenz.identityprovider.api.LegacySignedIdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument;
import com.yahoo.vespa.athenz.identityprovider.api.VespaUniqueInstanceId;
-import com.yahoo.vespa.athenz.utils.AthenzIdentities;
import org.junit.jupiter.api.Test;
import java.security.KeyPair;
@@ -18,6 +21,7 @@ import java.util.Arrays;
import java.util.HashSet;
import static com.yahoo.vespa.athenz.identityprovider.api.IdentityType.TENANT;
+import static com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument.LEGACY_DEFAULT_DOCUMENT_VERSION;
import static com.yahoo.vespa.athenz.identityprovider.api.SignedIdentityDocument.DEFAULT_DOCUMENT_VERSION;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -42,32 +46,53 @@ public class IdentityDocumentSignerTest {
private static final AthenzIdentity serviceIdentity = new AthenzService("vespa", "node");
@Test
- void generates_and_validates_signature() {
+ void legacy_generates_and_validates_signature() {
IdentityDocumentSigner signer = new IdentityDocumentSigner();
+ IdentityDocument identityDocument = new IdentityDocument(
+ id, providerService, configserverHostname,
+ instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity);
String signature =
- signer.generateSignature(id, providerService, configserverHostname, instanceHostname, createdAt,
- ipAddresses, identityType, keyPair.getPrivate(), serviceIdentity);
+ signer.generateLegacySignature(identityDocument, keyPair.getPrivate());
- SignedIdentityDocument signedIdentityDocument = new SignedIdentityDocument(
- signature, KEY_VERSION, id, providerService, DEFAULT_DOCUMENT_VERSION, configserverHostname,
- instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity);
+ SignedIdentityDocument signedIdentityDocument = new LegacySignedIdentityDocument(
+ signature, KEY_VERSION, LEGACY_DEFAULT_DOCUMENT_VERSION, identityDocument);
assertTrue(signer.hasValidSignature(signedIdentityDocument, keyPair.getPublic()));
}
@Test
- void ignores_cluster_type_and_zts_url() {
+ void generates_and_validates_signature() {
IdentityDocumentSigner signer = new IdentityDocumentSigner();
+ IdentityDocument identityDocument = new IdentityDocument(
+ id, providerService, configserverHostname,
+ instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity);
+ String data = EntityBindingsMapper.toIdentityDocmentData(identityDocument);
String signature =
- signer.generateSignature(id, providerService, configserverHostname, instanceHostname, createdAt,
- ipAddresses, identityType, keyPair.getPrivate(), serviceIdentity);
+ signer.generateSignature(data, keyPair.getPrivate());
- var docWithoutIgnoredFields = new SignedIdentityDocument(
- signature, KEY_VERSION, id, providerService, DEFAULT_DOCUMENT_VERSION, configserverHostname,
- instanceHostname, createdAt, ipAddresses, identityType, null, null, serviceIdentity);
- var docWithIgnoredFields = new SignedIdentityDocument(
- signature, KEY_VERSION, id, providerService, DEFAULT_DOCUMENT_VERSION, configserverHostname,
+ SignedIdentityDocument signedIdentityDocument = new DefaultSignedIdentityDocument(
+ signature, KEY_VERSION, DEFAULT_DOCUMENT_VERSION, data);
+
+ assertTrue(signer.hasValidSignature(signedIdentityDocument, keyPair.getPublic()));
+ }
+
+ @Test
+ void legacy_ignores_cluster_type_and_zts_url() {
+ IdentityDocumentSigner signer = new IdentityDocumentSigner();
+ IdentityDocument identityDocument = new IdentityDocument(
+ id, providerService, configserverHostname,
instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity);
+ IdentityDocument withoutIgnoredFields = new IdentityDocument(
+ id, providerService, configserverHostname,
+ instanceHostname, createdAt, ipAddresses, identityType, null, null, serviceIdentity);
+
+ String signature =
+ signer.generateLegacySignature(identityDocument, keyPair.getPrivate());
+
+ var docWithoutIgnoredFields = new LegacySignedIdentityDocument(
+ signature, KEY_VERSION, LEGACY_DEFAULT_DOCUMENT_VERSION, withoutIgnoredFields);
+ var docWithIgnoredFields = new LegacySignedIdentityDocument(
+ signature, KEY_VERSION, LEGACY_DEFAULT_DOCUMENT_VERSION, identityDocument);
assertTrue(signer.hasValidSignature(docWithoutIgnoredFields, keyPair.getPublic()));
assertEquals(docWithIgnoredFields.signature(), docWithoutIgnoredFields.signature());
@@ -76,16 +101,15 @@ public class IdentityDocumentSignerTest {
@Test
void validates_signature_for_new_and_old_versions() {
IdentityDocumentSigner signer = new IdentityDocumentSigner();
+ IdentityDocument identityDocument = new IdentityDocument(
+ id, providerService, configserverHostname,
+ instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity);
String signature =
- signer.generateSignature(id, providerService, configserverHostname, instanceHostname, createdAt,
- ipAddresses, identityType, keyPair.getPrivate(), serviceIdentity);
+ signer.generateLegacySignature(identityDocument, keyPair.getPrivate());
- SignedIdentityDocument signedIdentityDocument = new SignedIdentityDocument(
- signature, KEY_VERSION, id, providerService, DEFAULT_DOCUMENT_VERSION, configserverHostname,
- instanceHostname, createdAt, ipAddresses, identityType, clusterType, ztsUrl, serviceIdentity);
+ SignedIdentityDocument signedIdentityDocument = new LegacySignedIdentityDocument(
+ signature, KEY_VERSION, LEGACY_DEFAULT_DOCUMENT_VERSION, identityDocument);
assertTrue(signer.hasValidSignature(signedIdentityDocument, keyPair.getPublic()));
-
}
-
} \ No newline at end of file
diff --git a/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImplTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImplTest.java
new file mode 100644
index 00000000000..75dc42cd4a6
--- /dev/null
+++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/client/LegacyAthenzIdentityProviderImplTest.java
@@ -0,0 +1,160 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.athenz.identityprovider.client;
+
+import com.yahoo.container.core.identity.IdentityConfig;
+import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException;
+import com.yahoo.jdisc.Metric;
+import com.yahoo.security.KeyAlgorithm;
+import com.yahoo.security.KeyStoreBuilder;
+import com.yahoo.security.KeyStoreType;
+import com.yahoo.security.KeyStoreUtils;
+import com.yahoo.security.KeyUtils;
+import com.yahoo.security.Pkcs10Csr;
+import com.yahoo.security.Pkcs10CsrBuilder;
+import com.yahoo.security.SignatureAlgorithm;
+import com.yahoo.security.X509CertificateBuilder;
+import com.yahoo.test.ManualClock;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import javax.security.auth.x500.X500Principal;
+
+import java.io.File;
+import java.io.IOException;
+import java.math.BigInteger;
+import java.nio.file.Path;
+import java.security.KeyPair;
+import java.security.cert.X509Certificate;
+import java.time.Duration;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Date;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.function.Supplier;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author mortent
+ * @author bjorncs
+ */
+public class LegacyAthenzIdentityProviderImplTest {
+
+ @TempDir
+ public File tempDir;
+
+ public static final Duration certificateValidity = Duration.ofDays(30);
+
+ private static final IdentityConfig IDENTITY_CONFIG =
+ new IdentityConfig(new IdentityConfig.Builder()
+ .service("tenantService")
+ .domain("tenantDomain")
+ .nodeIdentityName("vespa.tenant")
+ .configserverIdentityName("vespa.configserver")
+ .loadBalancerAddress("cfg")
+ .ztsUrl("https:localhost:4443/zts/v1")
+ .athenzDnsSuffix("dev-us-north-1.vespa.cloud"));
+
+ private final KeyPair caKeypair = KeyUtils.generateKeypair(KeyAlgorithm.EC);
+ private Path trustStoreFile;
+ private X509Certificate caCertificate;
+
+ @BeforeEach
+ public void createTrustStoreFile() throws IOException {
+ caCertificate = X509CertificateBuilder
+ .fromKeypair(
+ caKeypair,
+ new X500Principal("CN=mydummyca"),
+ Instant.EPOCH,
+ Instant.EPOCH.plus(10000, ChronoUnit.DAYS),
+ SignatureAlgorithm.SHA256_WITH_ECDSA,
+ BigInteger.ONE)
+ .build();
+ trustStoreFile = File.createTempFile("junit", null, tempDir).toPath();
+ KeyStoreUtils.writeKeyStoreToFile(
+ KeyStoreBuilder.withType(KeyStoreType.JKS)
+ .withKeyEntry("default", caKeypair.getPrivate(), caCertificate)
+ .build(),
+ trustStoreFile);
+ }
+
+ @Test
+ void component_creation_fails_when_credentials_not_found() {
+ assertThrows(AthenzIdentityProviderException.class, () -> {
+ AthenzCredentialsService credentialService = mock(AthenzCredentialsService.class);
+ when(credentialService.registerInstance())
+ .thenThrow(new RuntimeException("athenz unavailable"));
+
+ new LegacyAthenzIdentityProviderImpl(IDENTITY_CONFIG, mock(Metric.class), trustStoreFile, credentialService, mock(ScheduledExecutorService.class), new ManualClock(Instant.EPOCH));
+ });
+ }
+
+ @Test
+ void metrics_updated_on_refresh() {
+ ManualClock clock = new ManualClock(Instant.EPOCH);
+ Metric metric = mock(Metric.class);
+
+ AthenzCredentialsService athenzCredentialsService = mock(AthenzCredentialsService.class);
+
+ KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC);
+ X509Certificate certificate = getCertificate(keyPair, getExpirationSupplier(clock));
+
+ when(athenzCredentialsService.registerInstance())
+ .thenReturn(new AthenzCredentials(certificate, keyPair, null));
+
+ when(athenzCredentialsService.updateCredentials(any(), any()))
+ .thenThrow(new RuntimeException("#1"))
+ .thenThrow(new RuntimeException("#2"))
+ .thenReturn(new AthenzCredentials(certificate, keyPair, null));
+
+ LegacyAthenzIdentityProviderImpl identityProvider =
+ new LegacyAthenzIdentityProviderImpl(IDENTITY_CONFIG, metric, trustStoreFile, athenzCredentialsService, mock(ScheduledExecutorService.class), clock);
+
+ identityProvider.reportMetrics();
+ verify(metric).set(eq(LegacyAthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.getSeconds()), any());
+
+ // Advance 1 day, refresh fails, cert is 1 day old
+ clock.advance(Duration.ofDays(1));
+ identityProvider.refreshCertificate();
+ identityProvider.reportMetrics();
+ verify(metric).set(eq(LegacyAthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.minus(Duration.ofDays(1)).getSeconds()), any());
+
+ // Advance 1 more day, refresh fails, cert is 2 days old
+ clock.advance(Duration.ofDays(1));
+ identityProvider.refreshCertificate();
+ identityProvider.reportMetrics();
+ verify(metric).set(eq(LegacyAthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.minus(Duration.ofDays(2)).getSeconds()), any());
+
+ // Advance 1 more day, refresh succeds, cert is new
+ clock.advance(Duration.ofDays(1));
+ identityProvider.refreshCertificate();
+ identityProvider.reportMetrics();
+ verify(metric).set(eq(LegacyAthenzIdentityProviderImpl.CERTIFICATE_EXPIRY_METRIC_NAME), eq(certificateValidity.getSeconds()), any());
+
+ }
+
+ private Supplier<Date> getExpirationSupplier(ManualClock clock) {
+ return () -> new Date(clock.instant().plus(certificateValidity).toEpochMilli());
+ }
+
+ private X509Certificate getCertificate(KeyPair keyPair, Supplier<Date> expiry) {
+ Pkcs10Csr csr = Pkcs10CsrBuilder.fromKeypair(new X500Principal("CN=dummy"), keyPair, SignatureAlgorithm.SHA256_WITH_ECDSA)
+ .build();
+ return X509CertificateBuilder
+ .fromCsr(csr,
+ caCertificate.getSubjectX500Principal(),
+ Instant.EPOCH,
+ expiry.get().toInstant(),
+ caKeypair.getPrivate(),
+ SignatureAlgorithm.SHA256_WITH_ECDSA,
+ BigInteger.ONE)
+ .build();
+ }
+
+}
diff --git a/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/ApacheCluster.java b/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/ApacheCluster.java
index 3192bb4f225..8e7bf59cd0f 100644
--- a/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/ApacheCluster.java
+++ b/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/ApacheCluster.java
@@ -86,7 +86,8 @@ class ApacheCluster implements Cluster {
SimpleHttpRequest request = new SimpleHttpRequest(wrapped.method(), wrapped.path());
request.setScheme(endpoint.url.getScheme());
request.setAuthority(new URIAuthority(endpoint.url.getHost(), portOf(endpoint.url)));
- request.setConfig(requestConfig);
+ long timeoutMillis = wrapped.timeout() == null ? 190_000 : wrapped.timeout().toMillis() * 11 / 10 + 1_000;
+ request.setConfig(RequestConfig.copy(requestConfig).setResponseTimeout(Timeout.ofMilliseconds(timeoutMillis)).build());
defaultHeaders.forEach(request::setHeader);
wrapped.headers().forEach((name, value) -> request.setHeader(name, value.get()));
if (wrapped.body() != null) {
@@ -104,11 +105,11 @@ class ApacheCluster implements Cluster {
@Override public void failed(Exception ex) { vessel.completeExceptionally(ex); }
@Override public void cancelled() { vessel.cancel(false); }
});
- long timeoutMillis = wrapped.timeout() == null ? 200_000 : wrapped.timeout().toMillis() * 11 / 10 + 1_000;
- Future<?> cancellation = timeoutExecutor.schedule(() -> {
- future.cancel(true);
- vessel.cancel(true);
- }, timeoutMillis, TimeUnit.MILLISECONDS);
+ // We've seen some requests time out, even with a response timeout,
+ // so we schedule this to be absolutely sure we don't hang (for ever).
+ Future<?> cancellation = timeoutExecutor.schedule(() -> { future.cancel(true); vessel.cancel(true); },
+ timeoutMillis + 10_000,
+ TimeUnit.MILLISECONDS);
vessel.whenComplete((__, ___) -> cancellation.cancel(true));
}
catch (Throwable thrown) {
@@ -196,8 +197,7 @@ class ApacheCluster implements Cluster {
private static RequestConfig createRequestConfig(FeedClientBuilderImpl b) {
RequestConfig.Builder builder = RequestConfig.custom()
.setConnectTimeout(Timeout.ofSeconds(10))
- .setConnectionRequestTimeout(Timeout.DISABLED)
- .setResponseTimeout(Timeout.ofSeconds(190));
+ .setConnectionRequestTimeout(Timeout.DISABLED);
if (b.proxy != null) builder.setProxy(new HttpHost(b.proxy.getScheme(), b.proxy.getHost(), b.proxy.getPort()));
return builder.build();
}