diff options
82 files changed, 1333 insertions, 1352 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java index e9eb15592a2..b55b9322463 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java @@ -125,7 +125,6 @@ public interface ModelContext { @ModelFeatureFlag(owners = {"baldersheim"}, removeAfter="7.last") default boolean skipCommunicationManagerThread() { return true; } @ModelFeatureFlag(owners = {"baldersheim"}, removeAfter="7.last") default boolean skipMbusRequestThread() { return true; } @ModelFeatureFlag(owners = {"baldersheim"}, removeAfter="7.last") default boolean skipMbusReplyThread() { return true; } - @ModelFeatureFlag(owners = {"arnej"}, removeAfter="7.last") default boolean useQrserverServiceName() { return true; } @ModelFeatureFlag(owners = {"arnej"}, removeAfter="7.last") default boolean avoidRenamingSummaryFeatures() { return false; } } diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java index 2b55b1f1d10..66a23c79fbb 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java @@ -88,6 +88,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea private boolean allowMoreThanOneContentGroupDown = false; private boolean enableConditionalPutRemoveWriteRepair = false; private List<DataplaneToken> dataplaneTokens; + private boolean enableDataplaneProxy; @Override public ModelContext.FeatureFlags featureFlags() { return this; } @Override public boolean multitenant() { return multitenant; } @@ -148,6 +149,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea @Override public boolean allowMoreThanOneContentGroupDown(ClusterSpec.Id id) { return allowMoreThanOneContentGroupDown; } @Override public boolean enableConditionalPutRemoveWriteRepair() { return enableConditionalPutRemoveWriteRepair; } @Override public List<DataplaneToken> dataplaneTokens() { return dataplaneTokens; } + @Override public boolean enableDataplaneProxy() { return enableDataplaneProxy; } public TestProperties sharedStringRepoNoReclaim(boolean sharedStringRepoNoReclaim) { this.sharedStringRepoNoReclaim = sharedStringRepoNoReclaim; @@ -393,6 +395,11 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea return this; } + public TestProperties setEnableDataplaneProxy(boolean enable) { + this.enableDataplaneProxy = enable; + return this; + } + public static class Spec implements ConfigServerSpec { private final String hostName; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomHandlerBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomHandlerBuilder.java index ed53a1d2267..9b5a1429cb7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomHandlerBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomHandlerBuilder.java @@ -15,7 +15,7 @@ import com.yahoo.vespa.model.container.component.UserBindingPattern; import com.yahoo.vespa.model.container.xml.BundleInstantiationSpecificationBuilder; import org.w3c.dom.Element; -import java.util.OptionalInt; +import java.util.Collection; import java.util.Set; import static com.yahoo.vespa.model.container.ApplicationContainerCluster.METRICS_V2_HANDLER_BINDING_1; @@ -38,12 +38,9 @@ public class DomHandlerBuilder extends VespaDomBuilder.DomConfigProducerBuilderB VIP_HANDLER_BINDING); private final ApplicationContainerCluster cluster; - private final OptionalInt portBindingOverride; + private final Set<Integer> portBindingOverride; - public DomHandlerBuilder(ApplicationContainerCluster cluster) { - this(cluster, OptionalInt.empty()); - } - public DomHandlerBuilder(ApplicationContainerCluster cluster, OptionalInt portBindingOverride) { + public DomHandlerBuilder(ApplicationContainerCluster cluster, Set<Integer> portBindingOverride) { this.cluster = cluster; this.portBindingOverride = portBindingOverride; } @@ -51,23 +48,24 @@ public class DomHandlerBuilder extends VespaDomBuilder.DomConfigProducerBuilderB @Override protected Handler doBuild(DeployState deployState, TreeConfigProducer<AnyConfigProducer> parent, Element handlerElement) { Handler handler = createHandler(handlerElement); - OptionalInt port = portBindingOverride.isPresent() && deployState.isHosted() && deployState.featureFlags().useRestrictedDataPlaneBindings() - ? portBindingOverride - : OptionalInt.empty(); + var ports = deployState.isHosted() && deployState.featureFlags().useRestrictedDataPlaneBindings() + ? portBindingOverride : Set.<Integer>of(); - for (Element binding : XML.getChildren(handlerElement, "binding")) - addServerBinding(handler, userBindingPattern(XML.getValue(binding), port), deployState.getDeployLogger()); + for (Element xmlBinding : XML.getChildren(handlerElement, "binding")) + for (var binding : userBindingPattern(XML.getValue(xmlBinding), ports)) + addServerBinding(handler, binding, deployState.getDeployLogger()); DomComponentBuilder.addChildren(deployState, parent, handlerElement, handler); return handler; } - private static UserBindingPattern userBindingPattern(String path, OptionalInt port) { + private static Collection<UserBindingPattern> userBindingPattern(String path, Set<Integer> portBindingOverride) { UserBindingPattern bindingPattern = UserBindingPattern.fromPattern(path); - return port.isPresent() - ? bindingPattern.withPort(port.getAsInt()) - : bindingPattern; + if (portBindingOverride.isEmpty()) return Set.of(bindingPattern); + return portBindingOverride.stream() + .map(bindingPattern::withPort) + .toList(); } Handler createHandler(Element handlerElement) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java b/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java index 8163c268d09..a5a567b18f8 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/clients/ContainerDocumentApi.java @@ -14,7 +14,8 @@ import com.yahoo.vespa.model.container.component.UserBindingPattern; import java.nio.file.Path; import java.util.Collection; import java.util.Collections; -import java.util.OptionalInt; +import java.util.List; +import java.util.Set; /** * @author Einar M R Rosenvinge @@ -28,7 +29,7 @@ public class ContainerDocumentApi { private final boolean ignoreUndefinedFields; - public ContainerDocumentApi(ContainerCluster<?> cluster, HandlerOptions handlerOptions, boolean ignoreUndefinedFields, OptionalInt portOverride) { + public ContainerDocumentApi(ContainerCluster<?> cluster, HandlerOptions handlerOptions, boolean ignoreUndefinedFields, Set<Integer> portOverride) { this.ignoreUndefinedFields = ignoreUndefinedFields; addRestApiHandler(cluster, handlerOptions, portOverride); addFeedHandler(cluster, handlerOptions, portOverride); @@ -39,7 +40,7 @@ public class ContainerDocumentApi { c.addPlatformBundle(VESPACLIENT_CONTAINER_BUNDLE); } - private static void addFeedHandler(ContainerCluster<?> cluster, HandlerOptions handlerOptions, OptionalInt portOverride) { + private static void addFeedHandler(ContainerCluster<?> cluster, HandlerOptions handlerOptions, Set<Integer> portOverride) { String bindingSuffix = ContainerCluster.RESERVED_URI_PREFIX + "/feedapi"; var executor = new Threadpool("feedapi-handler", handlerOptions.feedApiThreadpoolOptions); var handler = newVespaClientHandler("com.yahoo.vespa.http.server.FeedHandler", @@ -48,7 +49,7 @@ public class ContainerDocumentApi { } - private static void addRestApiHandler(ContainerCluster<?> cluster, HandlerOptions handlerOptions, OptionalInt portOverride) { + private static void addRestApiHandler(ContainerCluster<?> cluster, HandlerOptions handlerOptions, Set<Integer> portOverride) { var handler = newVespaClientHandler("com.yahoo.document.restapi.resource.DocumentV1ApiHandler", DOCUMENT_V1_PREFIX + "/*", handlerOptions, null, portOverride); cluster.addComponent(handler); @@ -65,34 +66,34 @@ public class ContainerDocumentApi { String bindingSuffix, HandlerOptions handlerOptions, Threadpool executor, - OptionalInt portOverride) { + Set<Integer> portOverride) { Handler handler = createHandler(componentId, executor); if (handlerOptions.bindings.isEmpty()) { - handler.addServerBindings( - bindingPattern(bindingSuffix, portOverride), - bindingPattern(bindingSuffix + '/', portOverride)); + handler.addServerBindings(bindingPattern(bindingSuffix, portOverride)); + handler.addServerBindings(bindingPattern(bindingSuffix + '/', portOverride)); } else { for (String rootBinding : handlerOptions.bindings) { String pathWithoutLeadingSlash = bindingSuffix.substring(1); - handler.addServerBindings( - userBindingPattern(rootBinding + pathWithoutLeadingSlash, portOverride), - userBindingPattern(rootBinding + pathWithoutLeadingSlash + '/', portOverride)); + handler.addServerBindings(userBindingPattern(rootBinding + pathWithoutLeadingSlash, portOverride)); + handler.addServerBindings(userBindingPattern(rootBinding + pathWithoutLeadingSlash + '/', portOverride)); } } return handler; } - private static BindingPattern bindingPattern(String path, OptionalInt port) { - return port.isPresent() - ? SystemBindingPattern.fromHttpPortAndPath(Integer.toString(port.getAsInt()), path) - : SystemBindingPattern.fromHttpPath(path); + private static List<BindingPattern> bindingPattern(String path, Set<Integer> ports) { + if (ports.isEmpty()) return List.of(SystemBindingPattern.fromHttpPath(path)); + return ports.stream() + .map(p -> (BindingPattern)SystemBindingPattern.fromHttpPortAndPath(p, path)) + .toList(); } - private static UserBindingPattern userBindingPattern(String path, OptionalInt port) { + private static List<BindingPattern> userBindingPattern(String path, Set<Integer> ports) { UserBindingPattern bindingPattern = UserBindingPattern.fromPattern(path); - return port.isPresent() - ? bindingPattern.withPort(port.getAsInt()) - : bindingPattern; + if (ports.isEmpty()) return List.of(bindingPattern); + return ports.stream() + .map(p -> (BindingPattern)bindingPattern.withPort(p)) + .toList(); } private static Handler createHandler(String className, Threadpool executor) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/DataplaneProxy.java b/config-model/src/main/java/com/yahoo/vespa/model/container/DataplaneProxy.java index fe7d9581e46..13aa65909bd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/DataplaneProxy.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/DataplaneProxy.java @@ -7,20 +7,23 @@ import com.yahoo.vespa.model.container.component.SimpleComponent; public class DataplaneProxy extends SimpleComponent implements DataplaneProxyConfig.Producer { - private final Integer port; + private final int mtlsPort; + private final int tokenPort; private final String serverCertificate; private final String serverKey; - public DataplaneProxy(Integer port, String serverCertificate, String serverKey) { + public DataplaneProxy(int mtlsPort, int tokenPort, String serverCertificate, String serverKey) { super(DataplaneProxyConfigurator.class.getName()); - this.port = port; + this.mtlsPort = mtlsPort; + this.tokenPort = tokenPort; this.serverCertificate = serverCertificate; this.serverKey = serverKey; } @Override public void getConfig(DataplaneProxyConfig.Builder builder) { - builder.port(port); + builder.mtlsPort(mtlsPort); + builder.tokenPort(tokenPort); builder.serverCertificate(serverCertificate); builder.serverKey(serverKey); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java index 9f2bfe9251b..31031aa5bf2 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/Handler.java @@ -7,6 +7,7 @@ import com.yahoo.vespa.model.container.ContainerThreadpool; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.LinkedHashSet; import java.util.List; @@ -51,6 +52,8 @@ public class Handler extends Component<Component<?, ?>, ComponentModel> { serverBindings.addAll(Arrays.asList(bindings)); } + public void addServerBindings(Collection<BindingPattern> bps) { serverBindings.addAll(bps); } + public void removeServerBinding(BindingPattern binding) { serverBindings.remove(binding); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SystemBindingPattern.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SystemBindingPattern.java index 606557670a5..0fb3ec389e0 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SystemBindingPattern.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SystemBindingPattern.java @@ -15,6 +15,7 @@ public class SystemBindingPattern extends BindingPattern { public static SystemBindingPattern fromPattern(String binding) { return new SystemBindingPattern(binding);} public static SystemBindingPattern fromHttpPortAndPath(String port, String path) { return new SystemBindingPattern("http", "*", port, path); } public static SystemBindingPattern fromHttpPortAndPath(int port, String path) { return new SystemBindingPattern("http", "*", Integer.toString(port), path); } + public SystemBindingPattern withPort(int port) { return new SystemBindingPattern(scheme(), host(), Integer.toString(port), path()); } @Override public String toString() { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ConnectorFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ConnectorFactory.java index 697cfc95039..4929c09d561 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ConnectorFactory.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ConnectorFactory.java @@ -8,6 +8,7 @@ import com.yahoo.vespa.model.container.component.SimpleComponent; import com.yahoo.vespa.model.container.http.ssl.DefaultSslProvider; import com.yahoo.vespa.model.container.http.ssl.SslProvider; +import java.util.List; import java.util.Optional; /** @@ -40,6 +41,9 @@ public class ConnectorFactory extends SimpleComponent implements ConnectorConfig public void getConfig(ConnectorConfig.Builder connectorBuilder) { connectorBuilder.listenPort(listenPort); connectorBuilder.name(name); + connectorBuilder.accessLog(new ConnectorConfig.AccessLog.Builder() + .remoteAddressHeaders(List.of("x-forwarded-for")) + .remotePortHeaders(List.of("X-Forwarded-Port"))); sslProviderComponent.amendConnectorConfig(connectorBuilder); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/JettyHttpServer.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/JettyHttpServer.java index 6a2d9685a33..0388230fa6a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/JettyHttpServer.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/JettyHttpServer.java @@ -63,17 +63,8 @@ public class JettyHttpServer extends SimpleComponent implements ServerConfig.Pro .searchHandlerPaths(List.of("/search")) ); if (isHostedVespa) { - // Proxy-protocol v1/v2 is used in hosted Vespa for remote address/port - builder.accessLog(new ServerConfig.AccessLog.Builder() - .remoteAddressHeaders(List.of()) - .remotePortHeaders(List.of())); - // Enable connection log hosted Vespa builder.connectionLog(new ServerConfig.ConnectionLog.Builder().enabled(true)); - } else { - builder.accessLog(new ServerConfig.AccessLog.Builder() - .remoteAddressHeaders(List.of("x-forwarded-for")) - .remotePortHeaders(List.of("X-Forwarded-Port"))); } configureJettyThreadpool(builder); builder.stopTimeout(300); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/CloudSslProvider.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/CloudSslProvider.java index 5fa893e9599..ab163719aac 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/CloudSslProvider.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/CloudSslProvider.java @@ -2,8 +2,6 @@ package com.yahoo.vespa.model.container.http.ssl; import com.yahoo.jdisc.http.ConnectorConfig; -import com.yahoo.jdisc.http.ssl.impl.CloudSslContextProvider; -import com.yahoo.jdisc.http.ssl.impl.ConfiguredSslContextFactoryProvider; import java.util.Optional; @@ -16,18 +14,15 @@ import static com.yahoo.jdisc.http.ConnectorConfig.Ssl.ClientAuth; * @author andreer */ public class CloudSslProvider extends SslProvider { - public static final String COMPONENT_ID_PREFIX = "configured-ssl-provider@"; - public static final String MTLSONLY_COMPONENT_CLASS = ConfiguredSslContextFactoryProvider.class.getName(); - public static final String TOKEN_COMPONENT_CLASS = CloudSslContextProvider.class.getName(); - private final String privateKey; private final String certificate; private final String caCertificatePath; private final String caCertificate; private final ClientAuth.Enum clientAuthentication; - public CloudSslProvider(String servername, String privateKey, String certificate, String caCertificatePath, String caCertificate, ClientAuth.Enum clientAuthentication, boolean enableTokenSupport) { - super(COMPONENT_ID_PREFIX, servername, componentClass(enableTokenSupport), null); + public CloudSslProvider(String servername, String privateKey, String certificate, String caCertificatePath, + String caCertificate, ClientAuth.Enum clientAuthentication, boolean enableTokenSupport) { + super("cloud-ssl-provider@", servername, componentClass(enableTokenSupport), null); this.privateKey = privateKey; this.certificate = certificate; this.caCertificatePath = caCertificatePath; @@ -36,7 +31,9 @@ public class CloudSslProvider extends SslProvider { } private static String componentClass(boolean enableTokenSupport) { - return enableTokenSupport ? TOKEN_COMPONENT_CLASS : MTLSONLY_COMPONENT_CLASS; + return enableTokenSupport + ? "com.yahoo.jdisc.http.ssl.impl.CloudTokenSslContextProvider" + : "com.yahoo.jdisc.http.ssl.impl.ConfiguredSslContextFactoryProvider"; } @Override diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/HostedSslConnectorFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/HostedSslConnectorFactory.java index 5bf348e5bb5..cebe08288f6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/HostedSslConnectorFactory.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/http/ssl/HostedSslConnectorFactory.java @@ -3,11 +3,11 @@ package com.yahoo.vespa.model.container.http.ssl; import com.yahoo.config.model.api.EndpointCertificateSecrets; import com.yahoo.jdisc.http.ConnectorConfig; -import com.yahoo.jdisc.http.ConnectorConfig.Ssl.ClientAuth; import com.yahoo.security.tls.TlsContext; import com.yahoo.vespa.model.container.http.ConnectorFactory; import java.time.Duration; +import java.util.ArrayList; import java.util.Collection; import java.util.List; @@ -18,96 +18,90 @@ import java.util.List; */ public class HostedSslConnectorFactory extends ConnectorFactory { - private static final List<String> INSECURE_WHITELISTED_PATHS = List.of("/status.html"); - private static final String DEFAULT_HOSTED_TRUSTSTORE = "/opt/yahoo/share/ssl/certs/athenz_certificate_bundle.pem"; - - private final boolean enforceClientAuth; - private final boolean enforceHandshakeClientAuth; - private final Collection<String> tlsCiphersOverride; - private final boolean enableProxyProtocolMixedMode; + private final SslClientAuth clientAuth; + private final List<String> tlsCiphersOverride; + private final boolean proxyProtocolEnabled; + private final boolean proxyProtocolMixedMode; private final Duration endpointConnectionTtl; + private final List<String> remoteAddressHeaders; + private final List<String> remotePortHeaders; - /** - * Create connector factory that uses a certificate provided by the config-model / configserver and default hosted Vespa truststore. - */ - public static HostedSslConnectorFactory withProvidedCertificate( - String serverName, EndpointCertificateSecrets endpointCertificateSecrets, boolean enforceHandshakeClientAuth, - Collection<String> tlsCiphersOverride, boolean enableProxyProtocolMixedMode, int port, - Duration endpointConnectionTtl, boolean enableTokenSupport) { - CloudSslProvider sslProvider = createConfiguredDirectSslProvider( - serverName, endpointCertificateSecrets, DEFAULT_HOSTED_TRUSTSTORE, /*tlsCaCertificates*/null, enforceHandshakeClientAuth, enableTokenSupport); - return new HostedSslConnectorFactory(sslProvider, false, enforceHandshakeClientAuth, tlsCiphersOverride, - enableProxyProtocolMixedMode, port, endpointConnectionTtl); - } - - /** - * Create connector factory that uses a certificate provided by the config-model / configserver and a truststore configured by the application. - */ - public static HostedSslConnectorFactory withProvidedCertificateAndTruststore( - String serverName, EndpointCertificateSecrets endpointCertificateSecrets, String tlsCaCertificates, - Collection<String> tlsCiphersOverride, boolean enableProxyProtocolMixedMode, int port, - Duration endpointConnectionTtl, boolean enableTokenSupport) { - CloudSslProvider sslProvider = createConfiguredDirectSslProvider( - serverName, endpointCertificateSecrets, /*tlsCaCertificatesPath*/null, tlsCaCertificates, false, enableTokenSupport); - return new HostedSslConnectorFactory(sslProvider, true, false, tlsCiphersOverride, enableProxyProtocolMixedMode, - port, endpointConnectionTtl); - } + public static Builder builder(String name, int listenPort) { return new Builder(name, listenPort); } - /** - * Create connector factory that uses the default certificate and truststore provided by Vespa (through Vespa-global TLS configuration). - */ - public static HostedSslConnectorFactory withDefaultCertificateAndTruststore(String serverName, Collection<String> tlsCiphersOverride, - boolean enableProxyProtocolMixedMode, int port, - Duration endpointConnectionTtl) { - return new HostedSslConnectorFactory(new DefaultSslProvider(serverName), true, false, tlsCiphersOverride, - enableProxyProtocolMixedMode, port, endpointConnectionTtl); + private HostedSslConnectorFactory(Builder builder) { + super(new ConnectorFactory.Builder("tls"+builder.port, builder.port).sslProvider(createSslProvider(builder))); + this.clientAuth = builder.clientAuth; + this.tlsCiphersOverride = List.copyOf(builder.tlsCiphersOverride); + this.proxyProtocolEnabled = builder.proxyProtocolEnabled; + this.proxyProtocolMixedMode = builder.proxyProtocolMixedMode; + this.endpointConnectionTtl = builder.endpointConnectionTtl; + this.remoteAddressHeaders = List.copyOf(builder.remoteAddressHeaders); + this.remotePortHeaders = List.copyOf(builder.remotePortHeaders); } - private HostedSslConnectorFactory(SslProvider sslProvider, boolean enforceClientAuth, - boolean enforceHandshakeClientAuth, Collection<String> tlsCiphersOverride, - boolean enableProxyProtocolMixedMode, int port, Duration endpointConnectionTtl) { - super(new Builder("tls"+port, port).sslProvider(sslProvider)); - this.enforceClientAuth = enforceClientAuth; - this.enforceHandshakeClientAuth = enforceHandshakeClientAuth; - this.tlsCiphersOverride = tlsCiphersOverride; - this.enableProxyProtocolMixedMode = enableProxyProtocolMixedMode; - this.endpointConnectionTtl = endpointConnectionTtl; - } - - private static CloudSslProvider createConfiguredDirectSslProvider( - String serverName, EndpointCertificateSecrets endpointCertificateSecrets, String tlsCaCertificatesPath, String tlsCaCertificates, boolean enforceHandshakeClientAuth, boolean enableTokenSupport) { - var clientAuthentication = enforceHandshakeClientAuth ? ClientAuth.Enum.NEED_AUTH : ClientAuth.Enum.WANT_AUTH; + private static SslProvider createSslProvider(Builder builder) { + if (builder.endpointCertificate == null) return new DefaultSslProvider(builder.name); + var sslClientAuth = builder.clientAuth == SslClientAuth.NEED + ? ConnectorConfig.Ssl.ClientAuth.Enum.NEED_AUTH : ConnectorConfig.Ssl.ClientAuth.Enum.WANT_AUTH; return new CloudSslProvider( - serverName, - endpointCertificateSecrets.key(), - endpointCertificateSecrets.certificate(), - tlsCaCertificatesPath, - tlsCaCertificates, - clientAuthentication, - enableTokenSupport); + builder.name, builder.endpointCertificate.key(), builder.endpointCertificate.certificate(), + builder.tlsCaCertificatesPath, builder.tlsCaCertificatesPem, sslClientAuth, builder.tokenEndpoint); } @Override public void getConfig(ConnectorConfig.Builder connectorBuilder) { super.getConfig(connectorBuilder); - if (! enforceHandshakeClientAuth) { - connectorBuilder - .tlsClientAuthEnforcer(new ConnectorConfig.TlsClientAuthEnforcer.Builder() - .pathWhitelist(INSECURE_WHITELISTED_PATHS) - .enable(enforceClientAuth)); + if (clientAuth == SslClientAuth.WANT_WITH_ENFORCER) { + connectorBuilder.tlsClientAuthEnforcer( + new ConnectorConfig.TlsClientAuthEnforcer.Builder() + .pathWhitelist(List.of("/status.html")).enable(true)); } // Disables TLSv1.3 as it causes some browsers to prompt user for client certificate (when connector has 'want' auth) connectorBuilder.ssl.enabledProtocols(List.of("TLSv1.2")); - if (!tlsCiphersOverride.isEmpty()) { connectorBuilder.ssl.enabledCipherSuites(tlsCiphersOverride.stream().sorted().toList()); } else { connectorBuilder.ssl.enabledCipherSuites(TlsContext.ALLOWED_CIPHER_SUITES.stream().sorted().toList()); } - connectorBuilder - .proxyProtocol(new ConnectorConfig.ProxyProtocol.Builder().enabled(true).mixedMode(enableProxyProtocolMixedMode)) + .proxyProtocol(new ConnectorConfig.ProxyProtocol.Builder() + .enabled(proxyProtocolEnabled).mixedMode(proxyProtocolMixedMode)) .idleTimeout(Duration.ofSeconds(30).toSeconds()) - .maxConnectionLife(endpointConnectionTtl != null ? endpointConnectionTtl.toSeconds() : 0); + .maxConnectionLife(endpointConnectionTtl != null ? endpointConnectionTtl.toSeconds() : 0) + .accessLog(new ConnectorConfig.AccessLog.Builder() + .remoteAddressHeaders(remoteAddressHeaders) + .remotePortHeaders(remotePortHeaders)); + + } + + public enum SslClientAuth { WANT, NEED, WANT_WITH_ENFORCER } + public static class Builder { + final String name; + final int port; + final List<String> remoteAddressHeaders = new ArrayList<>(); + final List<String> remotePortHeaders = new ArrayList<>(); + SslClientAuth clientAuth; + List<String> tlsCiphersOverride = List.of(); + boolean proxyProtocolEnabled; + boolean proxyProtocolMixedMode; + Duration endpointConnectionTtl; + EndpointCertificateSecrets endpointCertificate; + String tlsCaCertificatesPem; + String tlsCaCertificatesPath; + boolean tokenEndpoint; + + private Builder(String name, int port) { this.name = name; this.port = port; } + public Builder clientAuth(SslClientAuth auth) { clientAuth = auth; return this; } + public Builder endpointConnectionTtl(Duration ttl) { endpointConnectionTtl = ttl; return this; } + public Builder tlsCiphersOverride(Collection<String> ciphers) { tlsCiphersOverride = List.copyOf(ciphers); return this; } + public Builder proxyProtocol(boolean enabled, boolean mixedMode) { proxyProtocolEnabled = enabled; proxyProtocolMixedMode = mixedMode; return this; } + public Builder endpointCertificate(EndpointCertificateSecrets cert) { this.endpointCertificate = cert; return this; } + public Builder tlsCaCertificatesPath(String path) { this.tlsCaCertificatesPath = path; return this; } + public Builder tlsCaCertificatesPem(String pem) { this.tlsCaCertificatesPem = pem; return this; } + public Builder tokenEndpoint(boolean enable) { this.tokenEndpoint = enable; return this; } + public Builder remoteAddressHeader(String header) { this.remoteAddressHeaders.add(header); return this; } + public Builder remotePortHeader(String header) { this.remotePortHeaders.add(header); return this; } + + public HostedSslConnectorFactory build() { return new HostedSslConnectorFactory(this); } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChains.java b/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChains.java index 330e1f96dc7..b05466d54ab 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChains.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/processing/ProcessingChains.java @@ -6,6 +6,8 @@ import com.yahoo.vespa.model.container.component.BindingPattern; import com.yahoo.vespa.model.container.component.SystemBindingPattern; import com.yahoo.vespa.model.container.component.chain.Chains; +import java.util.List; + /** * Root config producer for processing * @@ -13,7 +15,7 @@ import com.yahoo.vespa.model.container.component.chain.Chains; */ public class ProcessingChains extends Chains<ProcessingChain> { - public static final BindingPattern[] defaultBindings = new BindingPattern[]{SystemBindingPattern.fromHttpPath("/processing/*")}; + public static final List<BindingPattern> defaultBindings = List.of(SystemBindingPattern.fromHttpPath("/processing/*")); public ProcessingChains(TreeConfigProducer<? super Chains> parent, String subId) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java index efa5ee01506..2d3d76e9d0e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilter.java @@ -5,7 +5,6 @@ import com.yahoo.component.ComponentSpecification; import com.yahoo.component.chain.dependencies.Dependencies; import com.yahoo.component.chain.model.ChainedComponentModel; import com.yahoo.config.model.deploy.DeployState; -import com.yahoo.config.provision.DataplaneToken; import com.yahoo.container.bundle.BundleInstantiationSpecification; import com.yahoo.jdisc.http.filter.security.cloud.config.CloudDataPlaneFilterConfig; import com.yahoo.security.X509CertificateUtils; @@ -13,7 +12,6 @@ import com.yahoo.vespa.model.container.ApplicationContainerCluster; import com.yahoo.vespa.model.container.http.Client; import com.yahoo.vespa.model.container.http.Filter; -import java.time.Instant; import java.util.Collection; import java.util.List; @@ -24,15 +22,11 @@ class CloudDataPlaneFilter extends Filter implements CloudDataPlaneFilterConfig. private final Collection<Client> clients; private final boolean clientsLegacyMode; - private final String tokenContext; CloudDataPlaneFilter(ApplicationContainerCluster cluster, DeployState state) { super(model()); this.clients = List.copyOf(cluster.getClients()); this.clientsLegacyMode = cluster.clientsLegacyMode(); - // Token domain must be identical to the domain used for generating the tokens - this.tokenContext = "Vespa Cloud tenant data plane:%s" - .formatted(state.getProperties().applicationId().tenant().value()); } private static ChainedComponentModel model() { @@ -48,24 +42,15 @@ class CloudDataPlaneFilter extends Filter implements CloudDataPlaneFilterConfig. builder.legacyMode(true); } else { var clientsCfg = clients.stream() + .filter(c -> !c.certificates().isEmpty()) .map(x -> new CloudDataPlaneFilterConfig.Clients.Builder() .id(x.id()) .certificates(x.certificates().stream().map(X509CertificateUtils::toPem).toList()) - .tokens(tokensConfig(x.tokens())) .permissions(x.permissions())) .toList(); - builder.clients(clientsCfg).legacyMode(false).tokenContext(tokenContext); + builder.clients(clientsCfg).legacyMode(false); } } - private static List<CloudDataPlaneFilterConfig.Clients.Tokens.Builder> tokensConfig(Collection<DataplaneToken> tokens) { - return tokens.stream() - .map(token -> new CloudDataPlaneFilterConfig.Clients.Tokens.Builder() - .id(token.tokenId()) - .fingerprints(token.versions().stream().map(DataplaneToken.Version::fingerprint).toList()) - .checkAccessHashes(token.versions().stream().map(DataplaneToken.Version::checkAccessHash).toList()) - .expirations(token.versions().stream().map(v -> v.expiration().map(Instant::toString).orElse("<none>")).toList())) - .toList(); - } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilter.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilter.java new file mode 100644 index 00000000000..a6f6d0a36ba --- /dev/null +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilter.java @@ -0,0 +1,62 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.container.xml; + +import com.yahoo.component.ComponentSpecification; +import com.yahoo.component.chain.dependencies.Dependencies; +import com.yahoo.component.chain.model.ChainedComponentModel; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.config.provision.DataplaneToken; +import com.yahoo.container.bundle.BundleInstantiationSpecification; +import com.yahoo.jdisc.http.filter.security.cloud.config.CloudTokenDataPlaneFilterConfig; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; +import com.yahoo.vespa.model.container.http.Client; +import com.yahoo.vespa.model.container.http.Filter; + +import java.time.Instant; +import java.util.Collection; +import java.util.List; + +class CloudTokenDataPlaneFilter extends Filter implements CloudTokenDataPlaneFilterConfig.Producer { + private final Collection<Client> clients; + private final String tokenContext; + + CloudTokenDataPlaneFilter(ApplicationContainerCluster cluster, DeployState state) { + super(model()); + this.clients = List.copyOf(cluster.getClients()); + // Token domain must be identical to the domain used for generating the tokens + this.tokenContext = "Vespa Cloud tenant data plane:%s" + .formatted(state.getProperties().applicationId().tenant().value()); + } + + private static ChainedComponentModel model() { + return new ChainedComponentModel( + new BundleInstantiationSpecification( + new ComponentSpecification("com.yahoo.jdisc.http.filter.security.cloud.CloudTokenDataPlaneFilter"), + null, + new ComponentSpecification("jdisc-security-filters")), + Dependencies.emptyDependencies()); + } + + @Override + public void getConfig(CloudTokenDataPlaneFilterConfig.Builder builder) { + var clientsCfg = clients.stream() + .filter(c -> !c.tokens().isEmpty()) + .map(x -> new CloudTokenDataPlaneFilterConfig.Clients.Builder() + .id(x.id()) + .tokens(tokensConfig(x.tokens())) + .permissions(x.permissions())) + .toList(); + builder.clients(clientsCfg).tokenContext(tokenContext); + } + + private static List<CloudTokenDataPlaneFilterConfig.Clients.Tokens.Builder> tokensConfig(Collection<DataplaneToken> tokens) { + return tokens.stream() + .map(token -> new CloudTokenDataPlaneFilterConfig.Clients.Tokens.Builder() + .id(token.tokenId()) + .fingerprints(token.versions().stream().map(DataplaneToken.Version::fingerprint).toList()) + .checkAccessHashes(token.versions().stream().map(DataplaneToken.Version::checkAccessHash).toList()) + .expirations(token.versions().stream().map(v -> v.expiration().map(Instant::toString).orElse("<none>")).toList())) + .toList(); + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java index 00feb0a1c76..1036a615bb5 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ContainerModelBuilder.java @@ -16,7 +16,6 @@ import com.yahoo.config.model.ConfigModelContext; import com.yahoo.config.model.api.ApplicationClusterEndpoint; import com.yahoo.config.model.api.ConfigServerSpec; import com.yahoo.config.model.api.ContainerEndpoint; -import com.yahoo.config.model.api.EndpointCertificateSecrets; import com.yahoo.config.model.api.TenantSecretStore; import com.yahoo.config.model.application.provider.IncludeDirs; import com.yahoo.config.model.builder.xml.ConfigModelBuilder; @@ -95,6 +94,7 @@ import com.yahoo.vespa.model.container.http.Http; import com.yahoo.vespa.model.container.http.HttpFilterChain; import com.yahoo.vespa.model.container.http.JettyHttpServer; import com.yahoo.vespa.model.container.http.ssl.HostedSslConnectorFactory; +import com.yahoo.vespa.model.container.http.ssl.HostedSslConnectorFactory.SslClientAuth; import com.yahoo.vespa.model.container.http.xml.HttpBuilder; import com.yahoo.vespa.model.container.processing.ProcessingChains; import com.yahoo.vespa.model.container.search.ContainerSearch; @@ -109,7 +109,6 @@ import java.io.IOException; import java.io.Reader; import java.net.URI; import java.security.cert.X509Certificate; -import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -140,9 +139,6 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { // Default path to vip status file for container in Hosted Vespa. static final String HOSTED_VESPA_STATUS_FILE = Defaults.getDefaults().underVespaHome("var/vespa/load-balancer/status.html"); - // Data plane port for hosted Vespa - public static final int HOSTED_VESPA_DATAPLANE_PORT = 4443; - //Path to vip status file for container in Hosted Vespa. Only used if set, else use HOSTED_VESPA_STATUS_FILE private static final String HOSTED_VESPA_STATUS_FILE_SETTING = "VESPA_LB_STATUS_FILE"; @@ -462,15 +458,16 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { addHostedImplicitHttpIfNotPresent(deployState, cluster); addHostedImplicitAccessControlIfNotPresent(deployState, cluster); addDefaultConnectorHostedFilterBinding(cluster); - addAdditionalHostedConnector(deployState, cluster); + addCloudMtlsConnector(deployState, cluster); addCloudDataPlaneFilter(deployState, cluster); + addCloudTokenSupport(deployState, cluster); } } private static void addCloudDataPlaneFilter(DeployState deployState, ApplicationContainerCluster cluster) { if (!deployState.isHosted() || !deployState.zone().system().isPublic()) return; - var dataplanePort = getDataplanePort(deployState); + var dataplanePort = getMtlsDataplanePort(deployState); // Setup secure filter chain var secureChain = new HttpFilterChain("cloud-data-plane-secure", HttpFilterChain.Type.SYSTEM); secureChain.addInnerComponent(new CloudDataPlaneFilter(cluster, deployState)); @@ -600,61 +597,83 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { .ifPresent(accessControl -> accessControl.configureDefaultHostedConnector(cluster.getHttp())); ; } - private void addAdditionalHostedConnector(DeployState deployState, ApplicationContainerCluster cluster) { + private void addCloudMtlsConnector(DeployState state, ApplicationContainerCluster cluster) { JettyHttpServer server = cluster.getHttp().getHttpServer().get(); String serverName = server.getComponentId().getName(); // If the deployment contains certificate/private key reference, setup TLS port - HostedSslConnectorFactory connectorFactory; - Collection<String> tlsCiphersOverride = deployState.getProperties().tlsCiphersOverride(); - boolean proxyProtocolMixedMode = deployState.getProperties().featureFlags().enableProxyProtocolMixedMode(); - Duration endpointConnectionTtl = deployState.getProperties().endpointConnectionTtl(); - var port = getDataplanePort(deployState); - if (deployState.endpointCertificateSecrets().isPresent()) { - boolean authorizeClient = deployState.zone().system().isPublic(); + var builder = HostedSslConnectorFactory.builder(serverName, getMtlsDataplanePort(state)) + .proxyProtocol(true, state.getProperties().featureFlags().enableProxyProtocolMixedMode()) + .tlsCiphersOverride(state.getProperties().tlsCiphersOverride()) + .endpointConnectionTtl(state.getProperties().endpointConnectionTtl()); + var endpointCert = state.endpointCertificateSecrets().orElse(null); + if (endpointCert != null) { + builder.endpointCertificate(endpointCert); + boolean isPublic = state.zone().system().isPublic(); List<X509Certificate> clientCertificates = getClientCertificates(cluster); - if (authorizeClient && clientCertificates.isEmpty()) { - throw new IllegalArgumentException("Client certificate authority security/clients.pem is missing - " + - "see: https://cloud.vespa.ai/en/security/guide#data-plane"); - } - EndpointCertificateSecrets endpointCertificateSecrets = deployState.endpointCertificateSecrets().get(); - - boolean enforceHandshakeClientAuth = cluster.getHttp().getAccessControl() - .map(accessControl -> accessControl.clientAuthentication) - .map(clientAuth -> clientAuth == AccessControl.ClientAuthentication.need) - .orElse(false); - - boolean enableTokenSupport = deployState.featureFlags().enableDataplaneProxy() - && cluster.getClients().stream().anyMatch(c -> !c.tokens().isEmpty()); - - // Set up component to generate proxy cert if token support is enabled - if (enableTokenSupport) { - cluster.addSimpleComponent(DataplaneProxyCredentials.class); - cluster.addSimpleComponent(DataplaneProxyService.class); - - var dataplaneProxy = new DataplaneProxy( - getDataplanePort(deployState), - endpointCertificateSecrets.certificate(), - endpointCertificateSecrets.key()); - cluster.addComponent(dataplaneProxy); + if (isPublic) { + if (clientCertificates.isEmpty()) + throw new IllegalArgumentException("Client certificate authority security/clients.pem is missing - " + + "see: https://cloud.vespa.ai/en/security/guide#data-plane"); + builder.tlsCaCertificatesPem(X509CertificateUtils.toPem(clientCertificates)) + .clientAuth(SslClientAuth.WANT_WITH_ENFORCER); + } else { + builder.tlsCaCertificatesPath("/opt/yahoo/share/ssl/certs/athenz_certificate_bundle.pem"); + var needAuth = cluster.getHttp().getAccessControl() + .map(accessControl -> accessControl.clientAuthentication) + .map(clientAuth -> clientAuth == AccessControl.ClientAuthentication.need) + .orElse(false); + builder.clientAuth(needAuth ? SslClientAuth.NEED : SslClientAuth.WANT); } - - connectorFactory = authorizeClient - ? HostedSslConnectorFactory.withProvidedCertificateAndTruststore( - serverName, endpointCertificateSecrets, X509CertificateUtils.toPem(clientCertificates), - tlsCiphersOverride, proxyProtocolMixedMode, port, endpointConnectionTtl, enableTokenSupport) - : HostedSslConnectorFactory.withProvidedCertificate( - serverName, endpointCertificateSecrets, enforceHandshakeClientAuth, tlsCiphersOverride, - proxyProtocolMixedMode, port, endpointConnectionTtl, enableTokenSupport); } else { - connectorFactory = HostedSslConnectorFactory.withDefaultCertificateAndTruststore( - serverName, tlsCiphersOverride, proxyProtocolMixedMode, port, - endpointConnectionTtl); + builder.clientAuth(SslClientAuth.WANT_WITH_ENFORCER); } + var connectorFactory = builder.build(); cluster.getHttp().getAccessControl().ifPresent(accessControl -> accessControl.configureHostedConnector(connectorFactory)); server.addConnector(connectorFactory); } + private void addCloudTokenSupport(DeployState state, ApplicationContainerCluster cluster) { + var server = cluster.getHttp().getHttpServer().get(); + boolean enableTokenSupport = state.isHosted() && state.zone().system().isPublic() + && state.featureFlags().enableDataplaneProxy() + && cluster.getClients().stream().anyMatch(c -> !c.tokens().isEmpty()); + if (!enableTokenSupport) return; + var endpointCert = state.endpointCertificateSecrets().orElseThrow(); + int tokenPort = getTokenDataplanePort(state).orElseThrow(); + + // Set up component to generate proxy cert if token support is enabled + cluster.addSimpleComponent(DataplaneProxyCredentials.class); + cluster.addSimpleComponent(DataplaneProxyService.class); + var dataplaneProxy = new DataplaneProxy( + getMtlsDataplanePort(state), + tokenPort, + endpointCert.certificate(), + endpointCert.key()); + cluster.addComponent(dataplaneProxy); + + // Setup dedicated connector + var connector = HostedSslConnectorFactory.builder(server.getComponentId().getName()+"-token", tokenPort) + .tokenEndpoint(true) + .proxyProtocol(false, false) + .endpointCertificate(endpointCert) + .remoteAddressHeader("X-Forwarded-For") + .remotePortHeader("X-Forwarded-Port") + .clientAuth(SslClientAuth.NEED) + .build(); + server.addConnector(connector); + + // Setup token filter chain + var tokenChain = new HttpFilterChain("cloud-token-data-plane-secure", HttpFilterChain.Type.SYSTEM); + tokenChain.addInnerComponent(new CloudTokenDataPlaneFilter(cluster, state)); + cluster.getHttp().getFilterChains().add(tokenChain); + + // Set as default filter for token port + cluster.getHttp().getHttpServer().orElseThrow().getConnectorFactories().stream() + .filter(c -> c.getListenPort() == tokenPort).findAny().orElseThrow() + .setDefaultRequestFilterChain(tokenChain.getComponentId()); + } + // Returns the client certificates of the clients defined for an application cluster private List<X509Certificate> getClientCertificates(ApplicationContainerCluster cluster) { return cluster.getClients() @@ -814,7 +833,7 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { } private void addUserHandlers(DeployState deployState, ApplicationContainerCluster cluster, Element spec, ConfigModelContext context) { - OptionalInt portBindingOverride = isHostedTenantApplication(context) ? OptionalInt.of(getDataplanePort(deployState)) : OptionalInt.empty(); + var portBindingOverride = isHostedTenantApplication(context) ? getDataplanePorts(deployState) : Set.<Integer>of(); for (Element component: XML.getChildren(spec, "handler")) { cluster.addComponent( new DomHandlerBuilder(cluster, portBindingOverride).build(deployState, cluster, component)); @@ -1103,12 +1122,12 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { } private void addSearchHandler(DeployState deployState, ApplicationContainerCluster cluster, Element searchElement, ConfigModelContext context) { - BindingPattern bindingPattern = SearchHandler.DEFAULT_BINDING; + var bindingPatterns = List.<BindingPattern>of(SearchHandler.DEFAULT_BINDING); if (isHostedTenantApplication(context) && deployState.featureFlags().useRestrictedDataPlaneBindings()) { - bindingPattern = SearchHandler.bindingPattern(Optional.of(Integer.toString(getDataplanePort(deployState)))); + bindingPatterns = SearchHandler.bindingPattern(getDataplanePorts(deployState)); } SearchHandler searchHandler = new SearchHandler(cluster, - serverBindings(deployState, context, searchElement, bindingPattern), + serverBindings(deployState, context, searchElement, bindingPatterns), ContainerThreadpool.UserOptions.fromXml(searchElement).orElse(null)); cluster.addComponent(searchHandler); @@ -1116,41 +1135,43 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { searchHandler.addComponent(Component.fromClassAndBundle(SearchHandler.EXECUTION_FACTORY, PlatformBundles.SEARCH_AND_DOCPROC_BUNDLE)); } - private List<BindingPattern> serverBindings(DeployState deployState, ConfigModelContext context, Element searchElement, BindingPattern... defaultBindings) { + private List<BindingPattern> serverBindings(DeployState deployState, ConfigModelContext context, Element searchElement, Collection<BindingPattern> defaultBindings) { List<Element> bindings = XML.getChildren(searchElement, "binding"); if (bindings.isEmpty()) - return List.of(defaultBindings); + return List.copyOf(defaultBindings); return toBindingList(deployState, context, bindings); } private List<BindingPattern> toBindingList(DeployState deployState, ConfigModelContext context, List<Element> bindingElements) { List<BindingPattern> result = new ArrayList<>(); - OptionalInt portOverride = isHostedTenantApplication(context) && deployState.featureFlags().useRestrictedDataPlaneBindings() ? OptionalInt.of(getDataplanePort(deployState)) : OptionalInt.empty(); + var portOverride = isHostedTenantApplication(context) && deployState.featureFlags().useRestrictedDataPlaneBindings() ? getDataplanePorts(deployState) : Set.<Integer>of(); for (Element element: bindingElements) { String text = element.getTextContent().trim(); if (!text.isEmpty()) - result.add(userBindingPattern(text, portOverride)); + result.addAll(userBindingPattern(text, portOverride)); } return result; } - private static UserBindingPattern userBindingPattern(String path, OptionalInt portOverride) { + private static Collection<UserBindingPattern> userBindingPattern(String path, Set<Integer> portBindingOverride) { UserBindingPattern bindingPattern = UserBindingPattern.fromPattern(path); - return portOverride.isPresent() - ? bindingPattern.withPort(portOverride.getAsInt()) - : bindingPattern; + if (portBindingOverride.isEmpty()) return Set.of(bindingPattern); + return portBindingOverride.stream() + .map(bindingPattern::withPort) + .toList(); } + private ContainerDocumentApi buildDocumentApi(DeployState deployState, ApplicationContainerCluster cluster, Element spec, ConfigModelContext context) { Element documentApiElement = XML.getChild(spec, "document-api"); if (documentApiElement == null) return null; ContainerDocumentApi.HandlerOptions documentApiOptions = DocumentApiOptionsBuilder.build(documentApiElement); Element ignoreUndefinedFields = XML.getChild(documentApiElement, "ignore-undefined-fields"); - OptionalInt portBindingOverride = deployState.featureFlags().useRestrictedDataPlaneBindings() && isHostedTenantApplication(context) - ? OptionalInt.of(getDataplanePort(deployState)) - : OptionalInt.empty(); + var portBindingOverride = deployState.featureFlags().useRestrictedDataPlaneBindings() && isHostedTenantApplication(context) + ? getDataplanePorts(deployState) + : Set.<Integer>of(); return new ContainerDocumentApi(cluster, documentApiOptions, "true".equals(XML.getValue(ignoreUndefinedFields)), portBindingOverride); } @@ -1410,8 +1431,18 @@ public class ContainerModelBuilder extends ConfigModelBuilder<ContainerModel> { } - private static int getDataplanePort(DeployState deployState) { - return deployState.featureFlags().enableDataplaneProxy() ? 8443 : HOSTED_VESPA_DATAPLANE_PORT; + private static Set<Integer> getDataplanePorts(DeployState ds) { + var tokenPort = getTokenDataplanePort(ds); + var mtlsPort = getMtlsDataplanePort(ds); + return tokenPort.isPresent() ? Set.of(mtlsPort, tokenPort.getAsInt()) : Set.of(mtlsPort); + } + + private static int getMtlsDataplanePort(DeployState ds) { + return ds.featureFlags().enableDataplaneProxy() ? 8443 : 4443; + } + + private static OptionalInt getTokenDataplanePort(DeployState ds) { + return ds.featureFlags().enableDataplaneProxy() ? OptionalInt.of(8444) : OptionalInt.empty(); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java index 96f653bf793..f3e02adff6b 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java @@ -38,9 +38,18 @@ public class ModelIdResolver { models.put("multilingual-e5-base", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/model.onnx"); models.put("multilingual-e5-base-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json"); + models.put("multilingual-e5-small", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small/model.onnx"); + models.put("multilingual-e5-small-vocab", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small/tokenizer.json"); + + models.put("multilingual-e5-small-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small-cpu-friendly/model.onnx"); + models.put("multilingual-e5-small-vocab-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/multilingual-e5-small-cpu-friendly/tokenizer.json"); + models.put("e5-small-v2", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2/model.onnx"); models.put("e5-small-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2/tokenizer.json"); + models.put("e5-small-v2-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2-cpu-friendly/model.onnx"); + models.put("e5-small-v2-vocab-cpu-friendly", "https://data.vespa.oath.cloud/onnx_models/e5-small-v2-cpu-friendly/tokenizer.json"); + models.put("e5-base-v2", "https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx"); models.put("e5-base-v2-vocab", "https://data.vespa.oath.cloud/onnx_models/e5-base-v2/tokenizer.json"); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/SearchHandler.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/SearchHandler.java index ebb22b2b73b..6cfef153fee 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/SearchHandler.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/SearchHandler.java @@ -10,8 +10,8 @@ import com.yahoo.vespa.model.container.component.SystemBindingPattern; import com.yahoo.vespa.model.container.component.chain.ProcessingHandler; import com.yahoo.vespa.model.container.search.searchchain.SearchChains; +import java.util.Collection; import java.util.List; -import java.util.Optional; import static com.yahoo.container.bundle.BundleInstantiationSpecification.fromSearchAndDocproc; @@ -28,7 +28,7 @@ class SearchHandler extends ProcessingHandler<SearchChains> { static final String EXECUTION_FACTORY_CLASSNAME = EXECUTION_FACTORY.getName(); static final BundleInstantiationSpecification HANDLER_SPEC = fromSearchAndDocproc(HANDLER_CLASSNAME); - static final BindingPattern DEFAULT_BINDING = bindingPattern(Optional.empty()); + static final BindingPattern DEFAULT_BINDING = SystemBindingPattern.fromHttpPath("/search/*"); SearchHandler(ApplicationContainerCluster cluster, List<BindingPattern> bindings, @@ -37,12 +37,11 @@ class SearchHandler extends ProcessingHandler<SearchChains> { bindings.forEach(this::addServerBindings); } - static BindingPattern bindingPattern(Optional<String> port) { - String path = "/search/*"; - return port - .filter(s -> !s.isBlank()) - .map(s -> SystemBindingPattern.fromHttpPortAndPath(s, path)) - .orElseGet(() -> SystemBindingPattern.fromHttpPath(path)); + static List<BindingPattern> bindingPattern(Collection<Integer> ports) { + if (ports.isEmpty()) return List.of(DEFAULT_BINDING); + return ports.stream() + .map(s -> (BindingPattern)SystemBindingPattern.fromHttpPortAndPath(s, DEFAULT_BINDING.path())) + .toList(); } private static class Threadpool extends ContainerThreadpool { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/AccessControlTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/AccessControlTest.java index bbc73e848d3..697d2d422e8 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/AccessControlTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/AccessControlTest.java @@ -16,7 +16,6 @@ import com.yahoo.jdisc.http.ConnectorConfig; import com.yahoo.path.Path; import com.yahoo.security.X509CertificateUtils; import com.yahoo.security.tls.TlsContext; -import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.model.container.ApplicationContainer; import com.yahoo.vespa.model.container.http.AccessControl; import com.yahoo.vespa.model.container.http.ConnectorFactory; @@ -37,10 +36,15 @@ import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; +import static com.yahoo.jdisc.http.ConnectorConfig.Ssl.ClientAuth.Enum.WANT_AUTH; import static com.yahoo.vespa.defaults.Defaults.getDefaults; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; /** * @author gjoranv @@ -280,7 +284,8 @@ public class AccessControlTest extends ContainerModelBuilderTestBase { new TestProperties() .setAthenzDomain(tenantDomain) .setHostedVespa(true) - .allowDisableMtls(true)) + .allowDisableMtls(true) + .setEndpointCertificateSecrets(Optional.of(new EndpointCertificateSecrets("CERT", "KEY")))) .build(); Http http = createModelAndGetHttp(state, " <http>", @@ -290,6 +295,13 @@ public class AccessControlTest extends ContainerModelBuilderTestBase { " </http>"); assertTrue(http.getAccessControl().isPresent()); assertEquals(AccessControl.ClientAuthentication.want, http.getAccessControl().get().clientAuthentication); + var tlsPort = http.getHttpServer().get().getConnectorFactories().stream() + .filter(connectorFactory -> connectorFactory.getListenPort() == 4443).findFirst().orElseThrow(); + var builder = new ConnectorConfig.Builder(); + tlsPort.getConfig(builder); + var connectorConfig = new ConnectorConfig(builder); + assertFalse(connectorConfig.tlsClientAuthEnforcer().enable()); + assertEquals(WANT_AUTH, connectorConfig.ssl().clientAuth()); } @Test @@ -497,7 +509,7 @@ public class AccessControlTest extends ContainerModelBuilderTestBase { ConnectorConfig connectorConfig = new ConnectorConfig(builder); assertTrue(connectorConfig.ssl().enabled()); - assertEquals(ConnectorConfig.Ssl.ClientAuth.Enum.WANT_AUTH, connectorConfig.ssl().clientAuth()); + assertEquals(WANT_AUTH, connectorConfig.ssl().clientAuth()); assertEquals("CERT", connectorConfig.ssl().certificate()); assertEquals("KEY", connectorConfig.ssl().privateKey()); assertEquals(4443, connectorConfig.listenPort()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilterTest.java index 02ff7b8a03f..94d92b355f9 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudDataPlaneFilterTest.java @@ -6,7 +6,6 @@ import com.yahoo.config.model.builder.xml.test.DomBuilderTest; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.test.MockApplicationPackage; -import com.yahoo.config.provision.DataplaneToken; import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.SystemName; @@ -35,17 +34,14 @@ import java.nio.file.Files; import java.nio.file.Path; import java.security.KeyPair; import java.security.cert.X509Certificate; -import java.time.Duration; import java.time.Instant; import java.time.temporal.ChronoUnit; -import java.util.Collection; import java.util.List; import java.util.Optional; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertIterableEquals; -import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -89,7 +85,6 @@ public class CloudDataPlaneFilterTest extends ContainerModelBuilderTestBase { CloudDataPlaneFilterConfig.Clients client = clients.get(0); assertEquals("foo", client.id()); assertIterableEquals(List.of("read", "write"), client.permissions()); - assertTrue(client.tokens().isEmpty()); assertIterableEquals(List.of(X509CertificateUtils.toPem(certificate)), client.certificates()); ConnectorConfig connectorConfig = connectorConfig(); @@ -123,43 +118,6 @@ public class CloudDataPlaneFilterTest extends ContainerModelBuilderTestBase { } @Test - void generates_correct_config_for_tokens() throws IOException { - var certFile = securityFolder.resolve("foo.pem"); - var clusterElem = DomBuilderTest.parse( - """ - <container version='1.0'> - <clients> - <client id="foo" permissions="read,write"> - <certificate file="%s"/> - </client> - <client id="bar" permissions="read"> - <token id="my-token"/> - </client> - </clients> - </container> - """ - .formatted(applicationFolder.toPath().relativize(certFile).toString())); - createCertificate(certFile); - buildModel(clusterElem); - - var cfg = root.getConfig(CloudDataPlaneFilterConfig.class, cloudDataPlaneFilterConfigId); - var tokenClient = cfg.clients().stream().filter(c -> c.id().equals("bar")).findAny().orElse(null); - assertNotNull(tokenClient); - assertEquals(List.of("read"), tokenClient.permissions()); - assertTrue(tokenClient.certificates().isEmpty()); - var expectedTokenCfg = tokenConfig( - "my-token", List.of("myfingerprint1", "myfingerprint2"), List.of("myaccesshash1", "myaccesshash2"), - List.of("<none>", "2243-10-17T00:00:00Z")); - assertEquals(List.of(expectedTokenCfg), tokenClient.tokens()); - } - - private static CloudDataPlaneFilterConfig.Clients.Tokens tokenConfig( - String id, Collection<String> fingerprints, Collection<String> accessCheckHashes, Collection<String> expirations) { - return new CloudDataPlaneFilterConfig.Clients.Tokens.Builder() - .id(id).fingerprints(fingerprints).checkAccessHashes(accessCheckHashes).expirations(expirations).build(); - } - - @Test public void it_rejects_files_without_certificates() throws IOException { Path certFile = securityFolder.resolve("foo.pem"); Element clusterElem = DomBuilderTest.parse( @@ -231,9 +189,6 @@ public class CloudDataPlaneFilterTest extends ContainerModelBuilderTestBase { .properties( new TestProperties() .setEndpointCertificateSecrets(Optional.of(new EndpointCertificateSecrets("CERT", "KEY"))) - .setDataplaneTokens(List.of(new DataplaneToken("my-token", List.of( - new DataplaneToken.Version("myfingerprint1", "myaccesshash1", Optional.empty()), - new DataplaneToken.Version("myfingerprint2", "myaccesshash2", Optional.of(Instant.EPOCH.plus(Duration.ofDays(100000)))))))) .setHostedVespa(true)) .zone(new Zone(SystemName.PublicCd, Environment.dev, RegionName.defaultName())) .build(); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilterTest.java new file mode 100644 index 00000000000..15e1d61c951 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/CloudTokenDataPlaneFilterTest.java @@ -0,0 +1,105 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.model.container.xml; + +import com.yahoo.config.model.api.EndpointCertificateSecrets; +import com.yahoo.config.model.builder.xml.test.DomBuilderTest; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.config.model.test.MockApplicationPackage; +import com.yahoo.config.provision.DataplaneToken; +import com.yahoo.config.provision.Environment; +import com.yahoo.config.provision.RegionName; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.Zone; +import com.yahoo.jdisc.http.filter.security.cloud.config.CloudTokenDataPlaneFilterConfig; +import com.yahoo.vespa.model.container.ContainerModel; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.w3c.dom.Element; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.time.Instant; +import java.util.Collection; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.vespa.model.container.xml.CloudDataPlaneFilterTest.createCertificate; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class CloudTokenDataPlaneFilterTest extends ContainerModelBuilderTestBase { + + @TempDir + public File applicationFolder; + + Path securityFolder; + private static final String filterConfigId = "container/filters/chain/cloud-token-data-plane-secure/component/" + + "com.yahoo.jdisc.http.filter.security.cloud.CloudTokenDataPlaneFilter"; + + @BeforeEach + public void setup() throws IOException { + securityFolder = applicationFolder.toPath().resolve("security"); + Files.createDirectories(securityFolder); + } + + @Test + void generates_correct_config_for_tokens() throws IOException { + var certFile = securityFolder.resolve("foo.pem"); + var clusterElem = DomBuilderTest.parse( + """ + <container version='1.0'> + <clients> + <client id="foo" permissions="read,write"> + <certificate file="%s"/> + </client> + <client id="bar" permissions="read"> + <token id="my-token"/> + </client> + </clients> + </container> + """ + .formatted(applicationFolder.toPath().relativize(certFile).toString())); + createCertificate(certFile); + buildModel(clusterElem); + + var cfg = root.getConfig(CloudTokenDataPlaneFilterConfig.class, filterConfigId); + var tokenClient = cfg.clients().stream().filter(c -> c.id().equals("bar")).findAny().orElse(null); + assertNotNull(tokenClient); + assertEquals(List.of("read"), tokenClient.permissions()); + var expectedTokenCfg = tokenConfig( + "my-token", List.of("myfingerprint1", "myfingerprint2"), List.of("myaccesshash1", "myaccesshash2"), + List.of("<none>", "2243-10-17T00:00:00Z")); + assertEquals(List.of(expectedTokenCfg), tokenClient.tokens()); + } + + private static CloudTokenDataPlaneFilterConfig.Clients.Tokens tokenConfig( + String id, Collection<String> fingerprints, Collection<String> accessCheckHashes, Collection<String> expirations) { + return new CloudTokenDataPlaneFilterConfig.Clients.Tokens.Builder() + .id(id).fingerprints(fingerprints).checkAccessHashes(accessCheckHashes).expirations(expirations).build(); + } + + public List<ContainerModel> buildModel(Element... clusterElem) { + var applicationPackage = new MockApplicationPackage.Builder() + .withRoot(applicationFolder) + .build(); + + DeployState state = new DeployState.Builder() + .applicationPackage(applicationPackage) + .properties( + new TestProperties() + .setEnableDataplaneProxy(true) + .setEndpointCertificateSecrets(Optional.of(new EndpointCertificateSecrets("CERT", "KEY"))) + .setDataplaneTokens(List.of(new DataplaneToken("my-token", List.of( + new DataplaneToken.Version("myfingerprint1", "myaccesshash1", Optional.empty()), + new DataplaneToken.Version("myfingerprint2", "myaccesshash2", Optional.of(Instant.EPOCH.plus(Duration.ofDays(100000)))))))) + .setHostedVespa(true)) + .zone(new Zone(SystemName.PublicCd, Environment.dev, RegionName.defaultName())) + .build(); + return createModel(root, state, null, clusterElem); + } +} diff --git a/configdefinitions/src/vespa/CMakeLists.txt b/configdefinitions/src/vespa/CMakeLists.txt index 85fc1158afe..29ed0f53421 100644 --- a/configdefinitions/src/vespa/CMakeLists.txt +++ b/configdefinitions/src/vespa/CMakeLists.txt @@ -89,3 +89,4 @@ install_config_definition(hugging-face-embedder.def embedding.huggingface.huggin install_config_definition(hugging-face-tokenizer.def language.huggingface.config.hugging-face-tokenizer.def) install_config_definition(bert-base-embedder.def embedding.bert-base-embedder.def) install_config_definition(cloud-data-plane-filter.def jdisc.http.filter.security.cloud.config.cloud-data-plane-filter.def) +install_config_definition(cloud-token-data-plane-filter.def jdisc.http.filter.security.cloud.config.cloud-token-data-plane-filter.def) diff --git a/configdefinitions/src/vespa/cloud-data-plane-filter.def b/configdefinitions/src/vespa/cloud-data-plane-filter.def index d73c5a49c81..47478a28039 100644 --- a/configdefinitions/src/vespa/cloud-data-plane-filter.def +++ b/configdefinitions/src/vespa/cloud-data-plane-filter.def @@ -2,11 +2,6 @@ namespace=jdisc.http.filter.security.cloud.config legacyMode bool default=false -tokenContext string default="" clients[].id string clients[].permissions[] string clients[].certificates[] string -clients[].tokens[].id string -clients[].tokens[].fingerprints[] string -clients[].tokens[].checkAccessHashes[] string -clients[].tokens[].expirations[] string diff --git a/configdefinitions/src/vespa/cloud-token-data-plane-filter.def b/configdefinitions/src/vespa/cloud-token-data-plane-filter.def new file mode 100644 index 00000000000..3219ae4fa48 --- /dev/null +++ b/configdefinitions/src/vespa/cloud-token-data-plane-filter.def @@ -0,0 +1,10 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +namespace=jdisc.http.filter.security.cloud.config + +tokenContext string default="" +clients[].id string +clients[].permissions[] string +clients[].tokens[].id string +clients[].tokens[].fingerprints[] string +clients[].tokens[].checkAccessHashes[] string +clients[].tokens[].expirations[] string diff --git a/configdefinitions/src/vespa/dataplane-proxy.def b/configdefinitions/src/vespa/dataplane-proxy.def index 9ce3e4b4b7b..dd1d734a91c 100644 --- a/configdefinitions/src/vespa/dataplane-proxy.def +++ b/configdefinitions/src/vespa/dataplane-proxy.def @@ -2,7 +2,8 @@ namespace=cloud.config # The port Jdisc will be listening on -port int +tokenPort int +mtlsPort int # Server certificate and key to be used when creating server socket serverCertificate string diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index dac881cf5ee..e35520d5381 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -188,7 +188,6 @@ public class ModelContextImpl implements ModelContext { private final boolean useV8GeoPositions; private final int maxCompactBuffers; private final List<String> ignoredHttpUserAgents; - private final boolean useQrserverServiceName; private final boolean avoidRenamingSummaryFeatures; private final Architecture adminClusterArchitecture; private final boolean enableProxyProtocolMixedMode; @@ -238,7 +237,6 @@ public class ModelContextImpl implements ModelContext { this.useV8GeoPositions = flagValue(source, appId, version, Flags.USE_V8_GEO_POSITIONS); this.maxCompactBuffers = flagValue(source, appId, version, Flags.MAX_COMPACT_BUFFERS); this.ignoredHttpUserAgents = flagValue(source, appId, version, PermanentFlags.IGNORED_HTTP_USER_AGENTS); - this.useQrserverServiceName = flagValue(source, appId, version, Flags.USE_QRSERVER_SERVICE_NAME); this.avoidRenamingSummaryFeatures = flagValue(source, appId, version, Flags.AVOID_RENAMING_SUMMARY_FEATURES); this.adminClusterArchitecture = Architecture.valueOf(flagValue(source, appId, version, PermanentFlags.ADMIN_CLUSTER_NODE_ARCHITECTURE)); this.enableProxyProtocolMixedMode = flagValue(source, appId, version, Flags.ENABLE_PROXY_PROTOCOL_MIXED_MODE); @@ -295,7 +293,6 @@ public class ModelContextImpl implements ModelContext { @Override public boolean useV8GeoPositions() { return useV8GeoPositions; } @Override public int maxCompactBuffers() { return maxCompactBuffers; } @Override public List<String> ignoredHttpUserAgents() { return ignoredHttpUserAgents; } - @Override public boolean useQrserverServiceName() { return useQrserverServiceName; } @Override public boolean avoidRenamingSummaryFeatures() { return avoidRenamingSummaryFeatures; } @Override public Architecture adminClusterArchitecture() { return adminClusterArchitecture; } @Override public boolean enableProxyProtocolMixedMode() { return enableProxyProtocolMixedMode; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java index b2762b2a3d4..b33b21691af 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationApiHandler.java @@ -98,8 +98,8 @@ public class ApplicationApiHandler extends SessionHandler { "Unable to parse multipart in deploy from tenant '" + tenantName.value() + "': " + Exceptions.toMessageString(e)); var message = "Deploy request from '" + tenantName.value() + "' contains invalid data: " + e.getMessage(); - log.log(FINE, message + ", parts: " + parts, e); - throw new BadRequestException("Deploy request from '" + tenantName.value() + "' contains invalid data: " + e.getMessage()); + log.log(INFO, message + ", parts: " + parts, e); + throw new BadRequestException(message); } } else { prepareParams = PrepareParams.fromHttpRequest(request, tenantName, zookeeperBarrierTimeout); diff --git a/container-core/abi-spec.json b/container-core/abi-spec.json index 757afeb64e2..6d7e3c86351 100644 --- a/container-core/abi-spec.json +++ b/container-core/abi-spec.json @@ -1027,6 +1027,45 @@ ], "fields" : [ ] }, + "com.yahoo.jdisc.http.ConnectorConfig$AccessLog$Builder" : { + "superClass" : "java.lang.Object", + "interfaces" : [ + "com.yahoo.config.ConfigBuilder" + ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void <init>()", + "public void <init>(com.yahoo.jdisc.http.ConnectorConfig$AccessLog)", + "public com.yahoo.jdisc.http.ConnectorConfig$AccessLog$Builder remoteAddressHeaders(java.lang.String)", + "public com.yahoo.jdisc.http.ConnectorConfig$AccessLog$Builder remoteAddressHeaders(java.util.Collection)", + "public com.yahoo.jdisc.http.ConnectorConfig$AccessLog$Builder remotePortHeaders(java.lang.String)", + "public com.yahoo.jdisc.http.ConnectorConfig$AccessLog$Builder remotePortHeaders(java.util.Collection)", + "public com.yahoo.jdisc.http.ConnectorConfig$AccessLog build()" + ], + "fields" : [ + "public java.util.List remoteAddressHeaders", + "public java.util.List remotePortHeaders" + ] + }, + "com.yahoo.jdisc.http.ConnectorConfig$AccessLog" : { + "superClass" : "com.yahoo.config.InnerNode", + "interfaces" : [ ], + "attributes" : [ + "public", + "final" + ], + "methods" : [ + "public void <init>(com.yahoo.jdisc.http.ConnectorConfig$AccessLog$Builder)", + "public java.util.List remoteAddressHeaders()", + "public java.lang.String remoteAddressHeaders(int)", + "public java.util.List remotePortHeaders()", + "public java.lang.String remotePortHeaders(int)" + ], + "fields" : [ ] + }, "com.yahoo.jdisc.http.ConnectorConfig$Builder" : { "superClass" : "java.lang.Object", "interfaces" : [ @@ -1069,6 +1108,8 @@ "public com.yahoo.jdisc.http.ConnectorConfig$Builder http2(java.util.function.Consumer)", "public com.yahoo.jdisc.http.ConnectorConfig$Builder serverName(com.yahoo.jdisc.http.ConnectorConfig$ServerName$Builder)", "public com.yahoo.jdisc.http.ConnectorConfig$Builder serverName(java.util.function.Consumer)", + "public com.yahoo.jdisc.http.ConnectorConfig$Builder accessLog(com.yahoo.jdisc.http.ConnectorConfig$AccessLog$Builder)", + "public com.yahoo.jdisc.http.ConnectorConfig$Builder accessLog(java.util.function.Consumer)", "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", "public final java.lang.String getDefMd5()", "public final java.lang.String getDefName()", @@ -1084,7 +1125,8 @@ "public com.yahoo.jdisc.http.ConnectorConfig$HealthCheckProxy$Builder healthCheckProxy", "public com.yahoo.jdisc.http.ConnectorConfig$ProxyProtocol$Builder proxyProtocol", "public com.yahoo.jdisc.http.ConnectorConfig$Http2$Builder http2", - "public com.yahoo.jdisc.http.ConnectorConfig$ServerName$Builder serverName" + "public com.yahoo.jdisc.http.ConnectorConfig$ServerName$Builder serverName", + "public com.yahoo.jdisc.http.ConnectorConfig$AccessLog$Builder accessLog" ] }, "com.yahoo.jdisc.http.ConnectorConfig$HealthCheckProxy$Builder" : { @@ -1438,7 +1480,8 @@ "public double maxConnectionLife()", "public boolean http2Enabled()", "public com.yahoo.jdisc.http.ConnectorConfig$Http2 http2()", - "public com.yahoo.jdisc.http.ConnectorConfig$ServerName serverName()" + "public com.yahoo.jdisc.http.ConnectorConfig$ServerName serverName()", + "public com.yahoo.jdisc.http.ConnectorConfig$AccessLog accessLog()" ], "fields" : [ "public static final java.lang.String CONFIG_DEF_MD5", @@ -1771,45 +1814,6 @@ ], "fields" : [ ] }, - "com.yahoo.jdisc.http.ServerConfig$AccessLog$Builder" : { - "superClass" : "java.lang.Object", - "interfaces" : [ - "com.yahoo.config.ConfigBuilder" - ], - "attributes" : [ - "public", - "final" - ], - "methods" : [ - "public void <init>()", - "public void <init>(com.yahoo.jdisc.http.ServerConfig$AccessLog)", - "public com.yahoo.jdisc.http.ServerConfig$AccessLog$Builder remoteAddressHeaders(java.lang.String)", - "public com.yahoo.jdisc.http.ServerConfig$AccessLog$Builder remoteAddressHeaders(java.util.Collection)", - "public com.yahoo.jdisc.http.ServerConfig$AccessLog$Builder remotePortHeaders(java.lang.String)", - "public com.yahoo.jdisc.http.ServerConfig$AccessLog$Builder remotePortHeaders(java.util.Collection)", - "public com.yahoo.jdisc.http.ServerConfig$AccessLog build()" - ], - "fields" : [ - "public java.util.List remoteAddressHeaders", - "public java.util.List remotePortHeaders" - ] - }, - "com.yahoo.jdisc.http.ServerConfig$AccessLog" : { - "superClass" : "com.yahoo.config.InnerNode", - "interfaces" : [ ], - "attributes" : [ - "public", - "final" - ], - "methods" : [ - "public void <init>(com.yahoo.jdisc.http.ServerConfig$AccessLog$Builder)", - "public java.util.List remoteAddressHeaders()", - "public java.lang.String remoteAddressHeaders(int)", - "public java.util.List remotePortHeaders()", - "public java.lang.String remotePortHeaders(int)" - ], - "fields" : [ ] - }, "com.yahoo.jdisc.http.ServerConfig$Builder" : { "superClass" : "java.lang.Object", "interfaces" : [ @@ -1839,8 +1843,6 @@ "public com.yahoo.jdisc.http.ServerConfig$Builder jmx(java.util.function.Consumer)", "public com.yahoo.jdisc.http.ServerConfig$Builder metric(com.yahoo.jdisc.http.ServerConfig$Metric$Builder)", "public com.yahoo.jdisc.http.ServerConfig$Builder metric(java.util.function.Consumer)", - "public com.yahoo.jdisc.http.ServerConfig$Builder accessLog(com.yahoo.jdisc.http.ServerConfig$AccessLog$Builder)", - "public com.yahoo.jdisc.http.ServerConfig$Builder accessLog(java.util.function.Consumer)", "public com.yahoo.jdisc.http.ServerConfig$Builder connectionLog(com.yahoo.jdisc.http.ServerConfig$ConnectionLog$Builder)", "public com.yahoo.jdisc.http.ServerConfig$Builder connectionLog(java.util.function.Consumer)", "public final boolean dispatchGetConfig(com.yahoo.config.ConfigInstance$Producer)", @@ -1856,7 +1858,6 @@ "public java.util.List defaultFilters", "public com.yahoo.jdisc.http.ServerConfig$Jmx$Builder jmx", "public com.yahoo.jdisc.http.ServerConfig$Metric$Builder metric", - "public com.yahoo.jdisc.http.ServerConfig$AccessLog$Builder accessLog", "public com.yahoo.jdisc.http.ServerConfig$ConnectionLog$Builder connectionLog" ] }, @@ -2070,7 +2071,6 @@ "public double stopTimeout()", "public com.yahoo.jdisc.http.ServerConfig$Jmx jmx()", "public com.yahoo.jdisc.http.ServerConfig$Metric metric()", - "public com.yahoo.jdisc.http.ServerConfig$AccessLog accessLog()", "public com.yahoo.jdisc.http.ServerConfig$ConnectionLog connectionLog()" ], "fields" : [ diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/AccessLogRequestLog.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/AccessLogRequestLog.java index 5b51eeee7d6..7a305c23ba3 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/AccessLogRequestLog.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/AccessLogRequestLog.java @@ -7,8 +7,6 @@ import com.yahoo.container.logging.AccessLogEntry; import com.yahoo.container.logging.RequestLog; import com.yahoo.container.logging.RequestLogEntry; import com.yahoo.jdisc.http.HttpRequest; -import com.yahoo.jdisc.http.ServerConfig; -import jakarta.servlet.http.HttpServletRequest; import org.eclipse.jetty.http2.HTTP2Stream; import org.eclipse.jetty.http2.server.HttpTransportOverHTTP2; import org.eclipse.jetty.server.HttpChannel; @@ -27,6 +25,7 @@ import java.util.function.BiConsumer; import java.util.logging.Level; import java.util.logging.Logger; +import static com.yahoo.jdisc.http.server.jetty.RequestUtils.getConnector; import static com.yahoo.jdisc.http.server.jetty.RequestUtils.getConnectorLocalPort; /** @@ -44,13 +43,9 @@ class AccessLogRequestLog extends AbstractLifeCycle implements org.eclipse.jetty private static final List<String> LOGGED_REQUEST_HEADERS = List.of("Vespa-Client-Version"); private final RequestLog requestLog; - private final List<String> remoteAddressHeaders; - private final List<String> remotePortHeaders; - AccessLogRequestLog(RequestLog requestLog, ServerConfig.AccessLog config) { + AccessLogRequestLog(RequestLog requestLog) { this.requestLog = requestLog; - this.remoteAddressHeaders = config.remoteAddressHeaders(); - this.remotePortHeaders = config.remotePortHeaders(); } @Override @@ -144,16 +139,16 @@ class AccessLogRequestLog extends AbstractLifeCycle implements org.eclipse.jetty } } - private String getRemoteAddress(HttpServletRequest request) { - for (String header : remoteAddressHeaders) { + private String getRemoteAddress(Request request) { + for (String header : getConnector(request).connectorConfig().accessLog().remoteAddressHeaders()) { String value = request.getHeader(header); if (value != null) return value; } return request.getRemoteAddr(); } - private int getRemotePort(HttpServletRequest request) { - for (String header : remotePortHeaders) { + private int getRemotePort(Request request) { + for (String header : getConnector(request).connectorConfig().accessLog().remotePortHeaders()) { String value = request.getHeader(header); if (value != null) { OptionalInt maybePort = parsePort(value); diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java index 3ebb65e7979..7d84ee6f8a3 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyHttpServer.java @@ -68,7 +68,7 @@ public class JettyHttpServer extends AbstractServerProvider { server = new Server(); server.setStopTimeout((long)(serverConfig.stopTimeout() * 1000.0)); - server.setRequestLog(new AccessLogRequestLog(requestLog, serverConfig.accessLog())); + server.setRequestLog(new AccessLogRequestLog(requestLog)); setupJmx(server, serverConfig); configureJettyThreadpool(server, serverConfig); JettyConnectionLogger connectionLogger = new JettyConnectionLogger(serverConfig.connectionLog(), connectionLog); diff --git a/container-core/src/main/java/com/yahoo/jdisc/http/ssl/impl/CloudSslContextProvider.java b/container-core/src/main/java/com/yahoo/jdisc/http/ssl/impl/CloudTokenSslContextProvider.java index cdfd4aa938e..fe71d1b24c6 100644 --- a/container-core/src/main/java/com/yahoo/jdisc/http/ssl/impl/CloudSslContextProvider.java +++ b/container-core/src/main/java/com/yahoo/jdisc/http/ssl/impl/CloudTokenSslContextProvider.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jdisc.http.ssl.impl; +import com.yahoo.component.annotation.Inject; import com.yahoo.jdisc.http.ConnectorConfig; import com.yahoo.jdisc.http.server.jetty.DataplaneProxyCredentials; @@ -14,29 +15,23 @@ import java.util.Optional; * * @author mortent */ -public class CloudSslContextProvider extends ConfiguredSslContextFactoryProvider { +public class CloudTokenSslContextProvider extends ConfiguredSslContextFactoryProvider { private final DataplaneProxyCredentials dataplaneProxyCredentials; - public CloudSslContextProvider(ConnectorConfig connectorConfig, DataplaneProxyCredentials dataplaneProxyCredentials) { + @Inject + public CloudTokenSslContextProvider(ConnectorConfig connectorConfig, + DataplaneProxyCredentials dataplaneProxyCredentials) { super(connectorConfig); this.dataplaneProxyCredentials = dataplaneProxyCredentials; } @Override Optional<String> getCaCertificates(ConnectorConfig.Ssl sslConfig) { - String proxyCert; try { - proxyCert = Files.readString(dataplaneProxyCredentials.certificateFile(), StandardCharsets.UTF_8); + return Optional.of(Files.readString(dataplaneProxyCredentials.certificateFile(), StandardCharsets.UTF_8)); } catch (IOException e) { throw new IllegalArgumentException("Dataplane proxy certificate not available", e); } - if (!sslConfig.caCertificate().isBlank()) { - return Optional.of(sslConfig.caCertificate() + "\n" + proxyCert); - } else if (!sslConfig.caCertificateFile().isBlank()) { - return Optional.of(readToString(sslConfig.caCertificateFile()) + "\n" + proxyCert); - } else { - return Optional.of(proxyCert); - } } } diff --git a/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def b/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def index 3c01012fd9e..5a2bad63682 100644 --- a/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def +++ b/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.connector.def @@ -138,3 +138,9 @@ serverName.fallback string default="" # The list of accepted server names. Empty list to accept any. Elements follows format of 'serverName.default'. serverName.allowed[] string + +# HTTP request headers that contain remote address +accessLog.remoteAddressHeaders[] string + +# HTTP request headers that contain remote port +accessLog.remotePortHeaders[] string diff --git a/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def b/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def index c15cb6b2cc4..a85641f61e9 100644 --- a/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def +++ b/container-core/src/main/resources/configdefinitions/jdisc.http.jdisc.http.server.def @@ -52,11 +52,5 @@ metric.searchHandlerPaths[] string # User-agent names to ignore wrt statistics (crawlers etc) metric.ignoredUserAgents[] string -# HTTP request headers that contain remote address -accessLog.remoteAddressHeaders[] string - -# HTTP request headers that contain remote port -accessLog.remotePortHeaders[] string - # Whether to enable jdisc connection log connectionLog.enabled bool default=false diff --git a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/AccessLogRequestLogTest.java b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/AccessLogRequestLogTest.java index 766c7918882..122db0f765d 100644 --- a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/AccessLogRequestLogTest.java +++ b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/AccessLogRequestLogTest.java @@ -4,7 +4,6 @@ package com.yahoo.jdisc.http.server.jetty; import com.yahoo.container.logging.AccessLogEntry; import com.yahoo.container.logging.RequestLog; import com.yahoo.container.logging.RequestLogEntry; -import com.yahoo.jdisc.http.ServerConfig; import org.eclipse.jetty.server.Request; import org.eclipse.jetty.server.Response; import org.junit.jupiter.api.Test; @@ -117,11 +116,7 @@ public class AccessLogRequestLogTest { } private void doAccessLoggingOfRequest(RequestLog requestLog, Request jettyRequest) { - ServerConfig.AccessLog config = new ServerConfig.AccessLog( - new ServerConfig.AccessLog.Builder() - .remoteAddressHeaders(List.of("x-forwarded-for", "y-ra")) - .remotePortHeaders(List.of("X-Forwarded-Port", "y-rp"))); - new AccessLogRequestLog(requestLog, config).log(jettyRequest, createResponseMock()); + new AccessLogRequestLog(requestLog).log(jettyRequest, createResponseMock()); } private static JettyMockRequestBuilder createRequestBuilder() { diff --git a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/JettyMockRequestBuilder.java b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/JettyMockRequestBuilder.java index e62825fc2a8..8b13f30bcd7 100644 --- a/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/JettyMockRequestBuilder.java +++ b/container-core/src/test/java/com/yahoo/jdisc/http/server/jetty/JettyMockRequestBuilder.java @@ -85,7 +85,11 @@ public class JettyMockRequestBuilder { HttpChannel channel = mock(HttpChannel.class); HttpConnection connection = mock(HttpConnection.class); JDiscServerConnector connector = mock(JDiscServerConnector.class); - when(connector.connectorConfig()).thenReturn(new ConnectorConfig(new ConnectorConfig.Builder().listenPort(localPort))); + when(connector.connectorConfig()).thenReturn(new ConnectorConfig( + new ConnectorConfig.Builder().listenPort(localPort) + .accessLog(new ConnectorConfig.AccessLog.Builder() + .remoteAddressHeaders(List.of("x-forwarded-for", "y-ra")) + .remotePortHeaders(List.of("X-Forwarded-Port", "y-rp"))))); when(connector.getLocalPort()).thenReturn(localPort); when(connection.getCreatedTimeStamp()).thenReturn(System.currentTimeMillis()); when(connection.getConnector()).thenReturn(connector); diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java b/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java index 47050168b80..74e6954e1e1 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/DataplaneProxyService.java @@ -103,7 +103,8 @@ public class DataplaneProxyService extends AbstractComponent { proxyCredentialsKey, serverCertificateFile, serverKeyFile, - config.port(), + config.mtlsPort(), + config.tokenPort(), root )); if (configChanged && state == NginxState.RUNNING) { @@ -191,7 +192,8 @@ public class DataplaneProxyService extends AbstractComponent { Path clientKey, Path serverCert, Path serverKey, - int vespaPort, + int vespaMtlsPort, + int vespaTokenPort, Path root) { try { @@ -200,7 +202,8 @@ public class DataplaneProxyService extends AbstractComponent { nginxTemplate = replace(nginxTemplate, "client_key", clientKey.toString()); nginxTemplate = replace(nginxTemplate, "server_cert", serverCert.toString()); nginxTemplate = replace(nginxTemplate, "server_key", serverKey.toString()); - nginxTemplate = replace(nginxTemplate, "vespa_port", Integer.toString(vespaPort)); + nginxTemplate = replace(nginxTemplate, "vespa_mtls_port", Integer.toString(vespaMtlsPort)); + nginxTemplate = replace(nginxTemplate, "vespa_token_port", Integer.toString(vespaTokenPort)); nginxTemplate = replace(nginxTemplate, "prefix", root.toString()); // TODO: verify that all template vars have been expanded diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java index 351890e2a3a..893a527e631 100644 --- a/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java +++ b/container-disc/src/test/java/com/yahoo/container/jdisc/DataplaneProxyServiceTest.java @@ -168,7 +168,8 @@ public class DataplaneProxyServiceTest { private DataplaneProxyConfig proxyConfig() { X509CertificateWithKey selfSigned = X509CertificateUtils.createSelfSigned("cn=test", Duration.ofMinutes(10)); return new DataplaneProxyConfig.Builder() - .port(1234) + .mtlsPort(1234) + .tokenPort(1235) .serverCertificate(X509CertificateUtils.toPem(selfSigned.certificate())) .serverKey(KeyUtils.toPem(selfSigned.privateKey())) .build(); diff --git a/document/src/tests/serialization/vespadocumentserializer_test.cpp b/document/src/tests/serialization/vespadocumentserializer_test.cpp index e91e38e0fe4..1839005d720 100644 --- a/document/src/tests/serialization/vespadocumentserializer_test.cpp +++ b/document/src/tests/serialization/vespadocumentserializer_test.cpp @@ -46,6 +46,7 @@ #include <vespa/vespalib/testkit/testapp.h> #include <vespa/document/base/exceptions.h> #include <vespa/vespalib/util/compressionconfig.h> +#include <filesystem> using vespalib::File; using vespalib::Slime; @@ -706,8 +707,8 @@ void checkDeserialization(const string &name, std::unique_ptr<Slime> slime) { PredicateFieldValue value(std::move(slime)); serializeToFile(value, data_dir + name + "__cpp.new"); - vespalib::rename(data_dir + name + "__cpp.new", - data_dir + name + "__cpp"); + std::filesystem::rename(std::filesystem::path(data_dir + name + "__cpp.new"), + std::filesystem::path(data_dir + name + "__cpp")); deserializeAndCheck(data_dir + name + "__cpp", value); deserializeAndCheck(data_dir + name + "__java", value); @@ -836,8 +837,8 @@ void checkDeserialization(const string &name, std::unique_ptr<vespalib::eval::Va value = std::move(tensor); } serializeToFile(value, data_dir + name + "__cpp.new"); - vespalib::rename(data_dir + name + "__cpp.new", - data_dir + name + "__cpp"); + std::filesystem::rename(std::filesystem::path(data_dir + name + "__cpp.new"), + std::filesystem::path(data_dir + name + "__cpp")); deserializeAndCheck(data_dir + name + "__cpp", value); deserializeAndCheck(data_dir + name + "__java", value); @@ -965,8 +966,8 @@ struct RefFixture { const string field_name = "ref_field"; serializeToFile(value, data_dir + file_base_name + "__cpp.new", fixed_repo.getDocumentTypeRepo(), ref_doc_type, field_name); - vespalib::rename(data_dir + file_base_name + "__cpp.new", - data_dir + file_base_name + "__cpp"); + std::filesystem::rename(std::filesystem::path(data_dir + file_base_name + "__cpp.new"), + std::filesystem::path(data_dir + file_base_name + "__cpp")); deserializeAndCheck(data_dir + file_base_name + "__cpp", value, fixed_repo, field_name); diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/ClientPrincipal.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/ClientPrincipal.java new file mode 100644 index 00000000000..bfb9bb920db --- /dev/null +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/ClientPrincipal.java @@ -0,0 +1,30 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.jdisc.http.filter.security.cloud; + +import com.yahoo.jdisc.http.filter.DiscFilterRequest; + +import java.security.Principal; +import java.util.Set; +import java.util.logging.Logger; + +/** + * @author bjorncs + */ +record ClientPrincipal(Set<String> ids, Set<Permission> permissions) implements Principal { + + private static final Logger log = Logger.getLogger(ClientPrincipal.class.getName()); + + ClientPrincipal { ids = Set.copyOf(ids); permissions = Set.copyOf(permissions); } + @Override public String getName() { + return "ids=%s,permissions=%s".formatted(ids, permissions.stream().map(Permission::asString).toList()); + } + + static void attachToRequest(DiscFilterRequest req, Set<String> ids, Set<Permission> permissions) { + var p = new ClientPrincipal(ids, permissions); + req.setUserPrincipal(p); + log.fine(() -> "Client with ids=%s, permissions=%s" + .formatted(ids, permissions.stream().map(Permission::asString).toList())); + } +} + diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilter.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilter.java index 2dc80fc9d2b..379973cd8cf 100644 --- a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilter.java +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilter.java @@ -2,41 +2,25 @@ package com.yahoo.jdisc.http.filter.security.cloud; import com.yahoo.component.annotation.Inject; -import com.yahoo.component.provider.ComponentRegistry; -import com.yahoo.container.jdisc.AclMapping; -import com.yahoo.container.jdisc.RequestHandlerSpec; -import com.yahoo.container.jdisc.RequestView; -import com.yahoo.container.logging.AccessLogEntry; import com.yahoo.jdisc.Response; import com.yahoo.jdisc.http.filter.DiscFilterRequest; import com.yahoo.jdisc.http.filter.security.base.JsonSecurityRequestFilterBase; import com.yahoo.jdisc.http.filter.security.cloud.config.CloudDataPlaneFilterConfig; -import com.yahoo.jdisc.http.server.jetty.DataplaneProxyCredentials; import com.yahoo.security.X509CertificateUtils; -import com.yahoo.security.token.Token; -import com.yahoo.security.token.TokenCheckHash; -import com.yahoo.security.token.TokenDomain; -import com.yahoo.security.token.TokenFingerprint; -import java.security.Principal; import java.security.cert.X509Certificate; -import java.time.Clock; -import java.time.Instant; import java.util.ArrayList; import java.util.EnumSet; -import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.TreeSet; import java.util.logging.Logger; -import java.util.stream.Collectors; -import static com.yahoo.jdisc.http.filter.security.cloud.CloudDataPlaneFilter.Permission.READ; -import static com.yahoo.jdisc.http.filter.security.cloud.CloudDataPlaneFilter.Permission.WRITE; -import static com.yahoo.jdisc.http.server.jetty.AccessLoggingRequestHandler.CONTEXT_KEY_ACCESS_LOG_ENTRY; +import static com.yahoo.jdisc.http.filter.security.cloud.Permission.READ; +import static com.yahoo.jdisc.http.filter.security.cloud.Permission.WRITE; + /** * Data plane filter for Cloud @@ -50,91 +34,49 @@ import static com.yahoo.jdisc.http.server.jetty.AccessLoggingRequestHandler.CONT public class CloudDataPlaneFilter extends JsonSecurityRequestFilterBase { private static final Logger log = Logger.getLogger(CloudDataPlaneFilter.class.getName()); - static final int CHECK_HASH_BYTES = 32; private final boolean legacyMode; private final List<Client> allowedClients; - private final TokenDomain tokenDomain; - private final Clock clock; @Inject - public CloudDataPlaneFilter(CloudDataPlaneFilterConfig cfg, - ComponentRegistry<DataplaneProxyCredentials> optionalReverseProxy) { - this(cfg, reverseProxyCert(optionalReverseProxy).orElse(null), Clock.systemUTC()); - } - - CloudDataPlaneFilter(CloudDataPlaneFilterConfig cfg, X509Certificate reverseProxyCert, Clock clock) { + public CloudDataPlaneFilter(CloudDataPlaneFilterConfig cfg) { this.legacyMode = cfg.legacyMode(); - this.tokenDomain = TokenDomain.of(cfg.tokenContext()); - this.clock = clock; if (legacyMode) { allowedClients = List.of(); log.fine(() -> "Legacy mode enabled"); } else { - allowedClients = parseClients(cfg, reverseProxyCert, clock); + allowedClients = parseClients(cfg); } } - private static Optional<X509Certificate> reverseProxyCert( - ComponentRegistry<DataplaneProxyCredentials> optionalReverseProxy) { - return optionalReverseProxy.allComponents().stream().findAny().map(DataplaneProxyCredentials::certificate); - } - - private static List<Client> parseClients(CloudDataPlaneFilterConfig cfg, X509Certificate reverseProxyCert, Clock clock) { - var now = clock.instant(); + private static List<Client> parseClients(CloudDataPlaneFilterConfig cfg) { Set<String> ids = new HashSet<>(); List<Client> clients = new ArrayList<>(cfg.clients().size()); - boolean hasClientRequiringCertificate = false; if (cfg.clients().isEmpty()) throw new IllegalArgumentException("Empty clients configuration"); for (var c : cfg.clients()) { if (ids.contains(c.id())) throw new IllegalArgumentException("Clients definition has duplicate id '%s'".formatted(c.id())); - if (!c.certificates().isEmpty() && !c.tokens().isEmpty()) - throw new IllegalArgumentException("Client '%s' has both certificate and token configured".formatted(c.id())); - if (c.certificates().isEmpty() && c.tokens().isEmpty()) - throw new IllegalArgumentException("Client '%s' has neither certificate nor token configured".formatted(c.id())); - if (!c.tokens().isEmpty() && reverseProxyCert == null) - throw new IllegalArgumentException( - "Client '%s' has token configured but reverse proxy certificate is missing".formatted(c.id())); + if (c.certificates().isEmpty()) + throw new IllegalArgumentException("Client '%s' has no certificate configured".formatted(c.id())); ids.add(c.id()); - EnumSet<Permission> permissions = c.permissions().stream().map(Permission::of) - .collect(Collectors.toCollection(() -> EnumSet.noneOf(Permission.class))); - if (!c.certificates().isEmpty()) { - List<X509Certificate> certs; - try { - certs = c.certificates().stream() - .flatMap(pem -> X509CertificateUtils.certificateListFromPem(pem).stream()).toList(); - } catch (Exception e) { - throw new IllegalArgumentException( - "Client '%s' contains invalid X.509 certificate PEM: %s".formatted(c.id(), e.toString()), e); - } - if (certs.isEmpty()) throw new IllegalArgumentException( - "Client '%s' certificate PEM contains no valid X.509 entries".formatted(c.id())); - clients.add(new Client(c.id(), permissions, certs, Map.of())); - hasClientRequiringCertificate = true; - } else { - var tokens = new HashMap<TokenCheckHash, TokenVersion>(); - for (var token : c.tokens()) { - for (int version = 0; version < token.checkAccessHashes().size(); version++) { - var tokenVersion = TokenVersion.of( - token.id(), token.fingerprints().get(version), token.checkAccessHashes().get(version), - token.expirations().get(version)); - tokens.put(tokenVersion.accessHash(), tokenVersion); - } - } - // Add reverse proxy certificate as required certificate for client definition - clients.add(new Client(c.id(), permissions, List.of(reverseProxyCert), tokens)); + List<X509Certificate> certs; + try { + certs = c.certificates().stream() + .flatMap(pem -> X509CertificateUtils.certificateListFromPem(pem).stream()).toList(); + } catch (Exception e) { + throw new IllegalArgumentException( + "Client '%s' contains invalid X.509 certificate PEM: %s".formatted(c.id(), e.toString()), e); } + if (certs.isEmpty()) throw new IllegalArgumentException( + "Client '%s' certificate PEM contains no valid X.509 entries".formatted(c.id())); + clients.add(new Client(c.id(), Permission.setOf(c.permissions()), certs)); } - if (!hasClientRequiringCertificate) - throw new IllegalArgumentException("At least one client must require a certificate"); log.fine(() -> "Configured clients with ids %s".formatted(ids)); return clients; } @Override protected Optional<ErrorResponse> filter(DiscFilterRequest req) { - var now = clock.instant(); var certs = req.getClientCertificateChain(); log.fine(() -> "Certificate chain contains %d elements".formatted(certs.size())); if (certs.isEmpty()) { @@ -143,109 +85,28 @@ public class CloudDataPlaneFilter extends JsonSecurityRequestFilterBase { } if (legacyMode) { log.fine("Legacy mode validation complete"); - req.setUserPrincipal(new ClientPrincipal(Set.of(), Set.of(READ, WRITE))); + ClientPrincipal.attachToRequest(req, Set.of(), Set.of(READ, WRITE)); return Optional.empty(); } - RequestView view = req.asRequestView(); - var permission = Optional.ofNullable((RequestHandlerSpec) req.getAttribute(RequestHandlerSpec.ATTRIBUTE_NAME)) - .or(() -> Optional.of(RequestHandlerSpec.DEFAULT_INSTANCE)) - .flatMap(spec -> { - var action = spec.aclMapping().get(view); - var maybePermission = Permission.of(action); - if (maybePermission.isEmpty()) log.fine(() -> "Unknown action '%s'".formatted(action)); - return maybePermission; - }).orElse(null); - if (permission == null) { - log.fine(() -> "No valid permission mapping defined for %s @ '%s'".formatted(view.method(), view.uri())); - return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Forbidden")); - } + var permission = Permission.getRequiredPermission(req).orElse(null); + if (permission == null) return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Forbidden")); var clientCert = certs.get(0); - var requestTokenHash = requestTokenHash(req).orElse(null); var clientIds = new TreeSet<String>(); var permissions = new TreeSet<Permission>(); - var matchedTokens = new HashSet<TokenVersion>(); for (Client c : allowedClients) { if (!c.permissions().contains(permission)) continue; if (!c.certificates().contains(clientCert)) continue; - if (!c.tokens().isEmpty()) { - if (requestTokenHash == null) continue; - var matchedToken = c.tokens().get(requestTokenHash); - if (matchedToken == null) continue; - var expiration = matchedToken.expiration().orElse(null); - if (expiration != null && now.isAfter(expiration)) continue; - matchedTokens.add(matchedToken); - } clientIds.add(c.id()); permissions.addAll(c.permissions()); } - if (matchedTokens.size() > 1) { - log.warning("Multiple tokens matched for request %s" - .formatted(matchedTokens.stream().map(TokenVersion::id).toList())); - return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Forbidden")); - } - var matchedToken = matchedTokens.stream().findAny().orElse(null); - if (matchedToken != null) { - addAccessLogEntry(req, "token.id", matchedToken.id()); - addAccessLogEntry(req, "token.hash", matchedToken.fingerprint().toDelimitedHexString()); - addAccessLogEntry(req, "token.exp", matchedToken.expiration().map(Instant::toString).orElse("<none>")); - } - log.fine(() -> "Client with ids=%s, permissions=%s" - .formatted(clientIds, permissions.stream().map(Permission::asString).toList())); if (clientIds.isEmpty()) return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Forbidden")); - req.setUserPrincipal(new ClientPrincipal(clientIds, permissions)); + ClientPrincipal.attachToRequest(req, clientIds, permissions); return Optional.empty(); } - private Optional<TokenCheckHash> requestTokenHash(DiscFilterRequest req) { - return Optional.ofNullable(req.getHeader("Authorization")) - .filter(h -> h.startsWith("Bearer ")) - .map(t -> t.substring("Bearer ".length()).trim()) - .map(t -> TokenCheckHash.of(Token.of(tokenDomain, t), CHECK_HASH_BYTES)); - } - - private static void addAccessLogEntry(DiscFilterRequest req, String key, String value) { - ((AccessLogEntry) req.getAttribute(CONTEXT_KEY_ACCESS_LOG_ENTRY)).addKeyValue(key, value); - } - - public record ClientPrincipal(Set<String> ids, Set<Permission> permissions) implements Principal { - public ClientPrincipal { ids = Set.copyOf(ids); permissions = Set.copyOf(permissions); } - @Override public String getName() { - return "ids=%s,permissions=%s".formatted(ids, permissions.stream().map(Permission::asString).toList()); - } - } - - enum Permission { READ, WRITE; - String asString() { - return switch (this) { - case READ -> "read"; - case WRITE -> "write"; - }; - } - static Permission of(String v) { - return switch (v) { - case "read" -> READ; - case "write" -> WRITE; - default -> throw new IllegalArgumentException("Invalid permission '%s'".formatted(v)); - }; - } - static Optional<Permission> of(AclMapping.Action a) { - if (a.equals(AclMapping.Action.READ)) return Optional.of(READ); - if (a.equals(AclMapping.Action.WRITE)) return Optional.of(WRITE); - return Optional.empty(); - } - } - - private record TokenVersion(String id, TokenFingerprint fingerprint, TokenCheckHash accessHash, Optional<Instant> expiration) { - static TokenVersion of(String id, String fingerprint, String accessHash, String expiration) { - return new TokenVersion(id, TokenFingerprint.ofHex(fingerprint), TokenCheckHash.ofHex(accessHash), - expiration.equals("<none>") ? Optional.empty() : Optional.of(Instant.parse(expiration))); - } - } - - private record Client(String id, EnumSet<Permission> permissions, List<X509Certificate> certificates, - Map<TokenCheckHash, TokenVersion> tokens) { + private record Client(String id, EnumSet<Permission> permissions, List<X509Certificate> certificates) { Client { - permissions = EnumSet.copyOf(permissions); certificates = List.copyOf(certificates); tokens = Map.copyOf(tokens); + permissions = EnumSet.copyOf(permissions); certificates = List.copyOf(certificates); } } } diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilter.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilter.java new file mode 100644 index 00000000000..6597f10198d --- /dev/null +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilter.java @@ -0,0 +1,146 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.filter.security.cloud; + +import com.yahoo.component.annotation.Inject; +import com.yahoo.container.logging.AccessLogEntry; +import com.yahoo.jdisc.Response; +import com.yahoo.jdisc.http.filter.DiscFilterRequest; +import com.yahoo.jdisc.http.filter.security.base.JsonSecurityRequestFilterBase; +import com.yahoo.jdisc.http.filter.security.cloud.config.CloudTokenDataPlaneFilterConfig; +import com.yahoo.security.token.Token; +import com.yahoo.security.token.TokenCheckHash; +import com.yahoo.security.token.TokenDomain; +import com.yahoo.security.token.TokenFingerprint; + +import java.time.Clock; +import java.time.Instant; +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.TreeSet; +import java.util.logging.Logger; + +import static com.yahoo.jdisc.http.server.jetty.AccessLoggingRequestHandler.CONTEXT_KEY_ACCESS_LOG_ENTRY; + +/** + * Token data plane filter for Cloud + * + * @author bjorncs + */ +public class CloudTokenDataPlaneFilter extends JsonSecurityRequestFilterBase { + + private static final Logger log = Logger.getLogger(CloudTokenDataPlaneFilter.class.getName()); + static final int CHECK_HASH_BYTES = 32; + + private final List<Client> allowedClients; + private final TokenDomain tokenDomain; + private final Clock clock; + + @Inject + public CloudTokenDataPlaneFilter(CloudTokenDataPlaneFilterConfig cfg) { + this(cfg, Clock.systemUTC()); + } + + CloudTokenDataPlaneFilter(CloudTokenDataPlaneFilterConfig cfg, Clock clock) { + this.tokenDomain = TokenDomain.of(cfg.tokenContext()); + this.clock = clock; + this.allowedClients = parseClients(cfg); + } + + private static List<Client> parseClients(CloudTokenDataPlaneFilterConfig cfg) { + Set<String> ids = new HashSet<>(); + List<Client> clients = new ArrayList<>(cfg.clients().size()); + if (cfg.clients().isEmpty()) throw new IllegalArgumentException("Empty clients configuration"); + for (var c : cfg.clients()) { + if (ids.contains(c.id())) + throw new IllegalArgumentException("Clients definition has duplicate id '%s'".formatted(c.id())); + if (c.tokens().isEmpty()) + throw new IllegalArgumentException("Client '%s' has no tokens configured".formatted(c.id())); + ids.add(c.id()); + var tokens = new HashMap<TokenCheckHash, TokenVersion>(); + for (var token : c.tokens()) { + for (int version = 0; version < token.checkAccessHashes().size(); version++) { + var tokenVersion = TokenVersion.of( + token.id(), token.fingerprints().get(version), token.checkAccessHashes().get(version), + token.expirations().get(version)); + tokens.put(tokenVersion.accessHash(), tokenVersion); + } + } + clients.add(new Client(c.id(), Permission.setOf(c.permissions()), tokens)); + } + log.fine(() -> "Configured clients with ids %s".formatted(ids)); + return List.copyOf(clients); + } + + @Override + protected Optional<ErrorResponse> filter(DiscFilterRequest req) { + var now = clock.instant(); + var bearerToken = requestBearerToken(req).orElse(null); + if (bearerToken == null) { + log.fine("Missing bearer token"); + return Optional.of(new ErrorResponse(Response.Status.UNAUTHORIZED, "Unauthorized")); + } + var permission = Permission.getRequiredPermission(req).orElse(null); + if (permission == null) return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Forbidden")); + var requestTokenHash = requestTokenHash(bearerToken); + var clientIds = new TreeSet<String>(); + var permissions = EnumSet.noneOf(Permission.class); + var matchedTokens = new HashSet<TokenVersion>(); + for (Client c : allowedClients) { + if (!c.permissions().contains(permission)) continue; + var matchedToken = c.tokens().get(requestTokenHash); + if (matchedToken == null) continue; + var expiration = matchedToken.expiration().orElse(null); + if (expiration != null && now.isAfter(expiration)) continue; + matchedTokens.add(matchedToken); + clientIds.add(c.id()); + permissions.addAll(c.permissions()); + } + if (clientIds.isEmpty()) return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Forbidden")); + if (matchedTokens.size() > 1) { + log.warning("Multiple tokens matched for request %s" + .formatted(matchedTokens.stream().map(TokenVersion::id).toList())); + return Optional.of(new ErrorResponse(Response.Status.FORBIDDEN, "Forbidden")); + } + var matchedToken = matchedTokens.stream().findAny().get(); + addAccessLogEntry(req, "token.id", matchedToken.id()); + addAccessLogEntry(req, "token.hash", matchedToken.fingerprint().toDelimitedHexString()); + addAccessLogEntry(req, "token.exp", matchedToken.expiration().map(Instant::toString).orElse("<none>")); + ClientPrincipal.attachToRequest(req, clientIds, permissions); + return Optional.empty(); + } + + private TokenCheckHash requestTokenHash(String bearerToken) { + return TokenCheckHash.of(Token.of(tokenDomain, bearerToken), CHECK_HASH_BYTES); + } + + private static Optional<String> requestBearerToken(DiscFilterRequest req) { + return Optional.ofNullable(req.getHeader("Authorization")) + .filter(h -> h.startsWith("Bearer ")) + .map(t -> t.substring("Bearer ".length()).trim()) + .filter(t -> !t.isBlank()); + + } + + private static void addAccessLogEntry(DiscFilterRequest req, String key, String value) { + ((AccessLogEntry) req.getAttribute(CONTEXT_KEY_ACCESS_LOG_ENTRY)).addKeyValue(key, value); + } + + private record TokenVersion(String id, TokenFingerprint fingerprint, TokenCheckHash accessHash, Optional<Instant> expiration) { + static TokenVersion of(String id, String fingerprint, String accessHash, String expiration) { + return new TokenVersion(id, TokenFingerprint.ofHex(fingerprint), TokenCheckHash.ofHex(accessHash), + expiration.equals("<none>") ? Optional.empty() : Optional.of(Instant.parse(expiration))); + } + } + + private record Client(String id, EnumSet<Permission> permissions, Map<TokenCheckHash, TokenVersion> tokens) { + Client { + permissions = EnumSet.copyOf(permissions); tokens = Map.copyOf(tokens); + } + } +} diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/Permission.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/Permission.java new file mode 100644 index 00000000000..4bab83f8576 --- /dev/null +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/Permission.java @@ -0,0 +1,63 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.jdisc.http.filter.security.cloud; + +import com.yahoo.container.jdisc.AclMapping; +import com.yahoo.container.jdisc.RequestHandlerSpec; +import com.yahoo.container.jdisc.RequestView; +import com.yahoo.jdisc.http.filter.DiscFilterRequest; + +import java.util.Collection; +import java.util.EnumSet; +import java.util.Optional; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +/** + * @author bjorncs + */ +enum Permission { + READ, WRITE; + + private static final Logger log = Logger.getLogger(Permission.class.getName()); + + String asString() { + return switch (this) { + case READ -> "read"; + case WRITE -> "write"; + }; + } + + static Permission of(String v) { + return switch (v) { + case "read" -> READ; + case "write" -> WRITE; + default -> throw new IllegalArgumentException("Invalid permission '%s'".formatted(v)); + }; + } + + static EnumSet<Permission> setOf(Collection<String> v) { + return v.stream().map(Permission::of).collect(Collectors.toCollection(() -> EnumSet.noneOf(Permission.class))); + } + + static Optional<Permission> getRequiredPermission(DiscFilterRequest req) { + RequestView view = req.asRequestView(); + var result = Optional.ofNullable((RequestHandlerSpec) req.getAttribute(RequestHandlerSpec.ATTRIBUTE_NAME)) + .or(() -> Optional.of(RequestHandlerSpec.DEFAULT_INSTANCE)) + .flatMap(spec -> { + var action = spec.aclMapping().get(view); + var maybePermission = Permission.of(action); + if (maybePermission.isEmpty()) log.fine(() -> "Unknown action '%s'".formatted(action)); + return maybePermission; + }); + if (result.isEmpty()) + log.fine(() -> "No valid permission mapping defined for %s @ '%s'".formatted(view.method(), view.uri())); + return result; + } + + static Optional<Permission> of(AclMapping.Action a) { + if (a.equals(AclMapping.Action.READ)) return Optional.of(READ); + if (a.equals(AclMapping.Action.WRITE)) return Optional.of(WRITE); + return Optional.empty(); + } +} diff --git a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilterTest.java b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilterTest.java index d9daf8b6f46..8d2fd1f569e 100644 --- a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilterTest.java +++ b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudDataPlaneFilterTest.java @@ -5,35 +5,24 @@ import com.yahoo.container.jdisc.AclMapping.Action; import com.yahoo.container.jdisc.HttpMethodAclMapping; import com.yahoo.container.jdisc.RequestHandlerSpec; import com.yahoo.container.jdisc.RequestHandlerTestDriver.MockResponseHandler; -import com.yahoo.container.logging.AccessLogEntry; import com.yahoo.jdisc.http.HttpRequest.Method; -import com.yahoo.jdisc.http.filter.security.cloud.CloudDataPlaneFilter.ClientPrincipal; import com.yahoo.jdisc.http.filter.security.cloud.config.CloudDataPlaneFilterConfig; import com.yahoo.jdisc.http.filter.util.FilterTestUtils; import com.yahoo.security.KeyUtils; import com.yahoo.security.X509CertificateBuilder; import com.yahoo.security.X509CertificateUtils; -import com.yahoo.security.token.Token; -import com.yahoo.security.token.TokenCheckHash; -import com.yahoo.security.token.TokenDomain; -import com.yahoo.security.token.TokenGenerator; -import com.yahoo.test.ManualClock; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import javax.security.auth.x500.X500Principal; import java.math.BigInteger; import java.security.cert.X509Certificate; -import java.time.Duration; -import java.time.Instant; import java.util.List; import java.util.Set; import static com.yahoo.jdisc.Response.Status.FORBIDDEN; import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED; -import static com.yahoo.jdisc.http.filter.security.cloud.CloudDataPlaneFilter.CHECK_HASH_BYTES; -import static com.yahoo.jdisc.http.filter.security.cloud.CloudDataPlaneFilter.Permission.READ; -import static com.yahoo.jdisc.http.filter.security.cloud.CloudDataPlaneFilter.Permission.WRITE; +import static com.yahoo.jdisc.http.filter.security.cloud.Permission.READ; +import static com.yahoo.jdisc.http.filter.security.cloud.Permission.WRITE; import static com.yahoo.security.KeyAlgorithm.EC; import static com.yahoo.security.SignatureAlgorithm.SHA256_WITH_ECDSA; import static java.time.Instant.EPOCH; @@ -50,21 +39,8 @@ class CloudDataPlaneFilterTest { private static final X509Certificate FEED_CERT = certificate("my-feed-client"); private static final X509Certificate SEARCH_CERT = certificate("my-search-client"); private static final X509Certificate LEGACY_CLIENT_CERT = certificate("my-legacy-client"); - private static final X509Certificate REVERSE_PROXY_CERT = certificate("nginx"); private static final String FEED_CLIENT_ID = "feed-client"; private static final String MTLS_SEARCH_CLIENT_ID = "mtls-search-client"; - private static final String TOKEN_SEARCH_CLIENT = "token-search-client"; - private static final String TOKEN_CONTEXT = "my-token-context"; - private static final String TOKEN_ID = "my-token-id"; - private static final Instant TOKEN_EXPIRATION = EPOCH.plus(Duration.ofDays(1)); - private static final Token VALID_TOKEN = - TokenGenerator.generateToken(TokenDomain.of(TOKEN_CONTEXT), "vespa_token_", CHECK_HASH_BYTES); - private static final Token UNKNOWN_TOKEN = - TokenGenerator.generateToken(TokenDomain.of(TOKEN_CONTEXT), "vespa_token_", CHECK_HASH_BYTES); - - private ManualClock clock; - - @BeforeEach void resetClock() { clock = new ManualClock(EPOCH); } @Test void accepts_any_trusted_client_certificate_in_legacy_mode() { @@ -144,137 +120,13 @@ class CloudDataPlaneFilterTest { assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); } - @Test - void accepts_reverse_proxy_with_token() { - var entry = new AccessLogEntry(); - var req = FilterTestUtils.newRequestBuilder() - .withMethod(Method.GET) - .withAccessLogEntry(entry) - .withClientCertificate(REVERSE_PROXY_CERT) - .withHeader("Authorization", "Bearer " + VALID_TOKEN.secretTokenString()) - .build(); - var responseHandler = new MockResponseHandler(); - newFilterWithClientsConfig().filter(req, responseHandler); - assertNull(responseHandler.getResponse()); - assertEquals(new ClientPrincipal(Set.of(TOKEN_SEARCH_CLIENT), Set.of(READ)), req.getUserPrincipal()); - assertEquals(TOKEN_ID, entry.getKeyValues().get("token.id").get(0)); - assertEquals(VALID_TOKEN.fingerprint().toDelimitedHexString(), entry.getKeyValues().get("token.hash").get(0)); - assertEquals(TOKEN_EXPIRATION.toString(), entry.getKeyValues().get("token.exp").get(0)); - } - - @Test - void fails_for_reverse_proxy_with_token_wrong_permission() { - var req = FilterTestUtils.newRequestBuilder() - .withMethod(Method.POST) - .withClientCertificate(REVERSE_PROXY_CERT) - .withHeader("Authorization", "Bearer " + VALID_TOKEN.secretTokenString()) - .build(); - var responseHandler = new MockResponseHandler(); - newFilterWithClientsConfig().filter(req, responseHandler); - assertNotNull(responseHandler.getResponse()); - assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); - } - - @Test - void fails_for_reverse_proxy_without_token() { - var req = FilterTestUtils.newRequestBuilder() - .withMethod(Method.GET) - .withClientCertificate(REVERSE_PROXY_CERT) - .build(); - var responseHandler = new MockResponseHandler(); - newFilterWithClientsConfig().filter(req, responseHandler); - assertNotNull(responseHandler.getResponse()); - assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); - } - - @Test - void fails_for_reverse_proxy_with_unknown_token() { - var req = FilterTestUtils.newRequestBuilder() - .withMethod(Method.GET) - .withClientCertificate(REVERSE_PROXY_CERT) - .withHeader("Authorization", "Bearer " + UNKNOWN_TOKEN.secretTokenString()) - .build(); - var responseHandler = new MockResponseHandler(); - newFilterWithClientsConfig().filter(req, responseHandler); - assertNotNull(responseHandler.getResponse()); - assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); - } - - @Test - void fails_for_missing_certificate_with_token() { - var req = FilterTestUtils.newRequestBuilder() - .withMethod(Method.GET) - .withHeader("Authorization", "Bearer " + VALID_TOKEN.secretTokenString()) - .build(); - var responseHandler = new MockResponseHandler(); - newFilterWithClientsConfig().filter(req, responseHandler); - assertNotNull(responseHandler.getResponse()); - assertEquals(UNAUTHORIZED, responseHandler.getResponse().getStatus()); - } - - @Test - void fails_for_unknown_certificate_with_token() { - var req = FilterTestUtils.newRequestBuilder() - .withMethod(Method.GET) - .withClientCertificate(LEGACY_CLIENT_CERT) - .withHeader("Authorization", "Bearer " + VALID_TOKEN.secretTokenString()) - .build(); - var responseHandler = new MockResponseHandler(); - newFilterWithClientsConfig().filter(req, responseHandler); - assertNotNull(responseHandler.getResponse()); - assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); - } - - @Test - void certificate_has_precedence_over_token() { - var req = FilterTestUtils.newRequestBuilder() - .withMethod(Method.POST) - .withClientCertificate(FEED_CERT) - .withHeader("Authorization", "Bearer " + VALID_TOKEN.secretTokenString()) - .build(); - var responseHandler = new MockResponseHandler(); - newFilterWithClientsConfig().filter(req, responseHandler); - assertNull(responseHandler.getResponse()); - assertEquals(new ClientPrincipal(Set.of(FEED_CLIENT_ID), Set.of(WRITE)), req.getUserPrincipal()); - } - - @Test - void fails_for_expired_token() { - var entry = new AccessLogEntry(); - var req = FilterTestUtils.newRequestBuilder() - .withMethod(Method.GET) - .withAccessLogEntry(entry) - .withClientCertificate(REVERSE_PROXY_CERT) - .withHeader("Authorization", "Bearer " + VALID_TOKEN.secretTokenString()) - .build(); - var filter = newFilterWithClientsConfig(); - - var responseHandler = new MockResponseHandler(); - filter.filter(req, responseHandler); - assertNull(responseHandler.getResponse()); - - clock.advance(Duration.ofDays(1)); - responseHandler = new MockResponseHandler(); - filter.filter(req, responseHandler); - assertNull(responseHandler.getResponse()); - - clock.advance(Duration.ofMillis(1)); - responseHandler = new MockResponseHandler(); - filter.filter(req, responseHandler); - assertNotNull(responseHandler.getResponse()); - assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); - } - private CloudDataPlaneFilter newFilterWithLegacyMode() { - return new CloudDataPlaneFilter( - new CloudDataPlaneFilterConfig.Builder() - .legacyMode(true).build(), (X509Certificate) null, clock); + return new CloudDataPlaneFilter(new CloudDataPlaneFilterConfig.Builder().legacyMode(true).build()); } private CloudDataPlaneFilter newFilterWithClientsConfig() { return new CloudDataPlaneFilter( new CloudDataPlaneFilterConfig.Builder() - .tokenContext(TOKEN_CONTEXT) .clients(List.of( new CloudDataPlaneFilterConfig.Clients.Builder() .certificates(X509CertificateUtils.toPem(FEED_CERT)) @@ -283,18 +135,8 @@ class CloudDataPlaneFilterTest { new CloudDataPlaneFilterConfig.Clients.Builder() .certificates(X509CertificateUtils.toPem(SEARCH_CERT)) .permissions(READ.asString()) - .id(MTLS_SEARCH_CLIENT_ID), - new CloudDataPlaneFilterConfig.Clients.Builder() - .tokens(new CloudDataPlaneFilterConfig.Clients.Tokens.Builder() - .id(TOKEN_ID) - .checkAccessHashes(TokenCheckHash.of(VALID_TOKEN, 32).toHexString()) - .fingerprints(VALID_TOKEN.fingerprint().toDelimitedHexString()) - .expirations(TOKEN_EXPIRATION.toString())) - .permissions(READ.asString()) - .id(TOKEN_SEARCH_CLIENT))) - .build(), - REVERSE_PROXY_CERT, - clock); + .id(MTLS_SEARCH_CLIENT_ID))) + .build()); } private static X509Certificate certificate(String name) { diff --git a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilterTest.java b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilterTest.java new file mode 100644 index 00000000000..a34d2eb67c3 --- /dev/null +++ b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilterTest.java @@ -0,0 +1,194 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.filter.security.cloud; + +import com.yahoo.container.jdisc.AclMapping.Action; +import com.yahoo.container.jdisc.HttpMethodAclMapping; +import com.yahoo.container.jdisc.RequestHandlerSpec; +import com.yahoo.container.jdisc.RequestHandlerTestDriver.MockResponseHandler; +import com.yahoo.container.logging.AccessLogEntry; +import com.yahoo.jdisc.http.HttpRequest.Method; +import com.yahoo.jdisc.http.filter.security.cloud.config.CloudTokenDataPlaneFilterConfig; +import com.yahoo.jdisc.http.filter.util.FilterTestUtils; +import com.yahoo.security.token.Token; +import com.yahoo.security.token.TokenCheckHash; +import com.yahoo.security.token.TokenDomain; +import com.yahoo.security.token.TokenGenerator; +import com.yahoo.test.ManualClock; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.time.Duration; +import java.time.Instant; +import java.util.List; +import java.util.Set; + +import static com.yahoo.jdisc.Response.Status.FORBIDDEN; +import static com.yahoo.jdisc.Response.Status.UNAUTHORIZED; +import static com.yahoo.jdisc.http.filter.security.cloud.CloudTokenDataPlaneFilter.CHECK_HASH_BYTES; +import static com.yahoo.jdisc.http.filter.security.cloud.Permission.READ; +import static com.yahoo.jdisc.http.filter.security.cloud.Permission.WRITE; +import static java.time.Instant.EPOCH; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +/** + * @author bjorncs + */ +class CloudTokenDataPlaneFilterTest { + + private static final String TOKEN_SEARCH_CLIENT = "token-search-client"; + private static final String TOKEN_FEED_CLIENT = "token-feed-client"; + private static final String TOKEN_CONTEXT = "my-token-context"; + private static final String READ_TOKEN_ID = "my-read-token-id"; + private static final String WRITE_TOKEN_ID = "my-write-token-id"; + private static final Instant TOKEN_EXPIRATION = EPOCH.plus(Duration.ofDays(1)); + private static final Token READ_TOKEN = + TokenGenerator.generateToken(TokenDomain.of(TOKEN_CONTEXT), "vespa_token_", CHECK_HASH_BYTES); + private static final Token WRITE_TOKEN = + TokenGenerator.generateToken(TokenDomain.of(TOKEN_CONTEXT), "vespa_token_", CHECK_HASH_BYTES); + private static final Token UNKNOWN_TOKEN = + TokenGenerator.generateToken(TokenDomain.of(TOKEN_CONTEXT), "vespa_token_", CHECK_HASH_BYTES); + private ManualClock clock; + + @BeforeEach void resetClock() { clock = new ManualClock(EPOCH); } + + @Test + void supports_handler_with_custom_request_spec() { + // Spec that maps POST as action 'read' + var spec = RequestHandlerSpec.builder() + .withAclMapping(HttpMethodAclMapping.standard() + .override(Method.POST, Action.READ).build()) + .build(); + var req = FilterTestUtils.newRequestBuilder() + .withMethod(Method.POST) + .withHeader("Authorization", "Bearer " + READ_TOKEN.secretTokenString()) + .withAttribute(RequestHandlerSpec.ATTRIBUTE_NAME, spec) + .build(); + var responseHandler = new MockResponseHandler(); + newFilterWithClientsConfig().filter(req, responseHandler); + assertNull(responseHandler.getResponse()); + assertEquals(new ClientPrincipal(Set.of(TOKEN_SEARCH_CLIENT), Set.of(READ)), req.getUserPrincipal()); + } + + @Test + void fails_on_handler_with_custom_request_spec_with_invalid_action() { + var spec = RequestHandlerSpec.builder() + .withAclMapping(HttpMethodAclMapping.standard() + .override(Method.GET, Action.custom("custom")).build()) + .build(); + var req = FilterTestUtils.newRequestBuilder() + .withMethod(Method.GET) + .withHeader("Authorization", "Bearer " + READ_TOKEN.secretTokenString()) + .withAttribute(RequestHandlerSpec.ATTRIBUTE_NAME, spec) + .build(); + var responseHandler = new MockResponseHandler(); + newFilterWithClientsConfig().filter(req, responseHandler); + assertNotNull(responseHandler.getResponse()); + assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); + } + + @Test + void accepts_valid_token() { + var entry = new AccessLogEntry(); + var req = FilterTestUtils.newRequestBuilder() + .withMethod(Method.GET) + .withAccessLogEntry(entry) + .withHeader("Authorization", "Bearer " + READ_TOKEN.secretTokenString()) + .build(); + var responseHandler = new MockResponseHandler(); + newFilterWithClientsConfig().filter(req, responseHandler); + assertNull(responseHandler.getResponse()); + assertEquals(new ClientPrincipal(Set.of(TOKEN_SEARCH_CLIENT), Set.of(READ)), req.getUserPrincipal()); + assertEquals(READ_TOKEN_ID, entry.getKeyValues().get("token.id").get(0)); + assertEquals(READ_TOKEN.fingerprint().toDelimitedHexString(), entry.getKeyValues().get("token.hash").get(0)); + assertEquals(TOKEN_EXPIRATION.toString(), entry.getKeyValues().get("token.exp").get(0)); + } + + @Test + void fails_for_token_with_invalid_permission() { + var req = FilterTestUtils.newRequestBuilder() + .withMethod(Method.GET) + .withHeader("Authorization", "Bearer " + WRITE_TOKEN.secretTokenString()) + .build(); + var responseHandler = new MockResponseHandler(); + newFilterWithClientsConfig().filter(req, responseHandler); + assertNotNull(responseHandler.getResponse()); + assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); + } + + @Test + void fails_for_missing_token() { + var req = FilterTestUtils.newRequestBuilder() + .withMethod(Method.GET) + .build(); + var responseHandler = new MockResponseHandler(); + newFilterWithClientsConfig().filter(req, responseHandler); + assertNotNull(responseHandler.getResponse()); + assertEquals(UNAUTHORIZED, responseHandler.getResponse().getStatus()); + } + + @Test + void fails_for_unknown_token() { + var req = FilterTestUtils.newRequestBuilder() + .withMethod(Method.GET) + .withHeader("Authorization", "Bearer " + UNKNOWN_TOKEN.secretTokenString()) + .build(); + var responseHandler = new MockResponseHandler(); + newFilterWithClientsConfig().filter(req, responseHandler); + assertNotNull(responseHandler.getResponse()); + assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); + } + + @Test + void fails_for_expired_token() { + var entry = new AccessLogEntry(); + var req = FilterTestUtils.newRequestBuilder() + .withMethod(Method.GET) + .withAccessLogEntry(entry) + .withHeader("Authorization", "Bearer " + READ_TOKEN.secretTokenString()) + .build(); + var filter = newFilterWithClientsConfig(); + + var responseHandler = new MockResponseHandler(); + filter.filter(req, responseHandler); + assertNull(responseHandler.getResponse()); + + clock.advance(Duration.ofDays(1)); + responseHandler = new MockResponseHandler(); + filter.filter(req, responseHandler); + assertNull(responseHandler.getResponse()); + + clock.advance(Duration.ofMillis(1)); + responseHandler = new MockResponseHandler(); + filter.filter(req, responseHandler); + assertNotNull(responseHandler.getResponse()); + assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); + } + + private CloudTokenDataPlaneFilter newFilterWithClientsConfig() { + return new CloudTokenDataPlaneFilter( + new CloudTokenDataPlaneFilterConfig.Builder() + .tokenContext(TOKEN_CONTEXT) + .clients(List.of( + new CloudTokenDataPlaneFilterConfig.Clients.Builder() + .tokens(new CloudTokenDataPlaneFilterConfig.Clients.Tokens.Builder() + .id(READ_TOKEN_ID) + .checkAccessHashes(TokenCheckHash.of(READ_TOKEN, 32).toHexString()) + .fingerprints(READ_TOKEN.fingerprint().toDelimitedHexString()) + .expirations(TOKEN_EXPIRATION.toString())) + .permissions(READ.asString()) + .id(TOKEN_SEARCH_CLIENT), + new CloudTokenDataPlaneFilterConfig.Clients.Builder() + .tokens(new CloudTokenDataPlaneFilterConfig.Clients.Tokens.Builder() + .id(WRITE_TOKEN_ID) + .checkAccessHashes(TokenCheckHash.of(WRITE_TOKEN, 32).toHexString()) + .fingerprints(WRITE_TOKEN.fingerprint().toDelimitedHexString()) + .expirations(TOKEN_EXPIRATION.toString())) + .permissions(WRITE.asString()) + .id(TOKEN_FEED_CLIENT))) + .build(), + clock); + } + +} diff --git a/screwdriver.yaml b/screwdriver.yaml index 19374a436d5..849f4a01328 100644 --- a/screwdriver.yaml +++ b/screwdriver.yaml @@ -426,10 +426,8 @@ jobs: fi publish-cli-release: + requires: [publish-release] image: homebrew/brew:latest - annotations: - # Run once an hour, in the hours 7-15 UTC, Monday-Thursday - screwdriver.cd/buildPeriodically: H 7-15 * * 1-4 secrets: - HOMEBREW_GITHUB_API_TOKEN - GH_TOKEN diff --git a/searchcore/src/tests/proton/documentdb/documentdb_test.cpp b/searchcore/src/tests/proton/documentdb/documentdb_test.cpp index 47cbde152ef..85e610c092a 100644 --- a/searchcore/src/tests/proton/documentdb/documentdb_test.cpp +++ b/searchcore/src/tests/proton/documentdb/documentdb_test.cpp @@ -33,7 +33,6 @@ #include <vespa/searchlib/index/dummyfileheadercontext.h> #include <vespa/searchlib/transactionlog/translogserver.h> #include <vespa/vespalib/data/slime/slime.h> -#include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/vespalib/util/size_literals.h> @@ -333,11 +332,7 @@ TEST("require that resume after interrupted save config works") std::cout << "Best config serial is " << best_config_snapshot.syncToken << std::endl; auto old_config_subdir = config_subdir(best_config_snapshot.syncToken); auto new_config_subdir = config_subdir(serialNum + 1); - std::filesystem::create_directories(std::filesystem::path(new_config_subdir)); - auto config_files = vespalib::listDirectory(old_config_subdir); - for (auto &config_file : config_files) { - vespalib::copy(old_config_subdir + "/" + config_file, new_config_subdir + "/" + config_file, false, false); - } + std::filesystem::copy(std::filesystem::path(old_config_subdir), std::filesystem::path(new_config_subdir)); info.addSnapshot({true, serialNum + 1, new_config_subdir.substr(new_config_subdir.rfind('/') + 1)}); info.save(); } diff --git a/searchcore/src/vespa/searchcore/proton/server/proton_disk_layout.cpp b/searchcore/src/vespa/searchcore/proton/server/proton_disk_layout.cpp index 57a3d21652b..aecb1eec262 100644 --- a/searchcore/src/vespa/searchcore/proton/server/proton_disk_layout.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/proton_disk_layout.cpp @@ -83,7 +83,9 @@ ProtonDiskLayout::remove(const DocTypeName &docTypeName) vespalib::string name(docTypeName.toString()); vespalib::string normalDir(documentsDir + "/" + name); vespalib::string removedDir(documentsDir + "/" + getRemovedName(name)); - vespalib::rename(normalDir, removedDir, false, false); + if (std::filesystem::exists(std::filesystem::path(normalDir))) { + std::filesystem::rename(std::filesystem::path(normalDir), std::filesystem::path(removedDir)); + } vespalib::File::sync(documentsDir); TransLogClient tlc(_transport, _tlsSpec); if (!tlc.remove(name)) { diff --git a/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp b/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp index 8f2f8f2e96b..e468560f4ec 100644 --- a/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp +++ b/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp @@ -67,7 +67,7 @@ struct LeafProxy : SimpleLeafBlueprint { } LeafProxy(std::unique_ptr<Blueprint> child_in) : SimpleLeafBlueprint(), child(std::move(child_in)) { init(); } - LeafProxy(const FieldSpec &field, std::unique_ptr<Blueprint> child_in) + LeafProxy(FieldSpecBase field, std::unique_ptr<Blueprint> child_in) : SimpleLeafBlueprint(field), child(std::move(child_in)) { init(); } SearchIteratorUP createLeafSearch(const TermFieldMatchDataArray &, bool) const override { abort(); } SearchIteratorUP createFilterSearch(bool strict, Constraint constraint) const override { diff --git a/searchlib/src/tests/queryeval/weighted_set_term/weighted_set_term_test.cpp b/searchlib/src/tests/queryeval/weighted_set_term/weighted_set_term_test.cpp index f93aa537625..9d0dd05e3e3 100644 --- a/searchlib/src/tests/queryeval/weighted_set_term/weighted_set_term_test.cpp +++ b/searchlib/src/tests/queryeval/weighted_set_term/weighted_set_term_test.cpp @@ -3,7 +3,6 @@ #include <vespa/vespalib/testkit/testapp.h> #include <vespa/searchlib/queryeval/weighted_set_term_search.h> -#include <vespa/searchlib/fef/fef.h> #include <vespa/searchlib/query/tree/simplequery.h> #include <vespa/searchlib/queryeval/field_spec.h> #include <vespa/searchlib/queryeval/blueprint.h> @@ -282,7 +281,7 @@ TEST("verify search iterator conformance with document weight iterator children" struct VerifyMatchData { struct MyBlueprint : search::queryeval::SimpleLeafBlueprint { VerifyMatchData &vmd; - MyBlueprint(VerifyMatchData &vmd_in, const FieldSpec & spec_in) + MyBlueprint(VerifyMatchData &vmd_in, FieldSpecBase spec_in) : SimpleLeafBlueprint(spec_in), vmd(vmd_in) {} [[nodiscard]] SearchIterator::UP createLeafSearch(const fef::TermFieldMatchDataArray &tfmda, bool) const override { EXPECT_EQUAL(tfmda.size(), 1u); @@ -301,7 +300,7 @@ struct VerifyMatchData { }; size_t child_cnt = 0; TermFieldMatchData *child_tfmd = nullptr; - search::queryeval::Blueprint::UP create(const FieldSpec &spec) { + search::queryeval::Blueprint::UP create(FieldSpecBase spec) { return std::make_unique<MyBlueprint>(*this, spec); } }; diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index ba791444dea..6cb5dbf7889 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -74,6 +74,7 @@ using search::queryeval::ComplexLeafBlueprint; using search::queryeval::CreateBlueprintVisitorHelper; using search::queryeval::DotProductBlueprint; using search::queryeval::FieldSpec; +using search::queryeval::FieldSpecBase; using search::queryeval::FieldSpecBaseList; using search::queryeval::FilterWrapper; using search::queryeval::IRequestContext; @@ -129,27 +130,11 @@ private: Type _type; public: - AttributeFieldBlueprint(const FieldSpec &field, const IAttributeVector &attribute, - const string &query_stack, const SearchContextParams ¶ms) - : AttributeFieldBlueprint(field, attribute, QueryTermDecoder::decodeTerm(query_stack), params) - { } - AttributeFieldBlueprint(const FieldSpec &field, const IAttributeVector &attribute, - QueryTermSimple::UP term, const SearchContextParams ¶ms) - : SimpleLeafBlueprint(field), - _attr(attribute), - _query_term(term->getTermString()), - _search_context(attribute.createSearchContext(std::move(term), params)), - _type(OTHER) - { - uint32_t estHits = _search_context->approximateHits(); - HitEstimate estimate(estHits, estHits == 0); - setEstimate(estimate); - if (attribute.isFloatingPointType()) { - _type = FLOAT; - } else if (attribute.isIntegerType()) { - _type = INT; - } - } + AttributeFieldBlueprint(FieldSpecBase field, const IAttributeVector &attribute, + const string &query_stack, const SearchContextParams ¶ms); + AttributeFieldBlueprint(FieldSpecBase field, const IAttributeVector &attribute, + QueryTermSimple::UP term, const SearchContextParams ¶ms); + ~AttributeFieldBlueprint() override; SearchIteratorUP createLeafSearch(const TermFieldMatchDataArray &tfmda, bool strict) const override { assert(tfmda.size() == 1); @@ -181,6 +166,31 @@ public: bool getRange(vespalib::string &from, vespalib::string &to) const override; }; +AttributeFieldBlueprint::~AttributeFieldBlueprint() = default; + +AttributeFieldBlueprint::AttributeFieldBlueprint(FieldSpecBase field, const IAttributeVector &attribute, + const string &query_stack, const SearchContextParams ¶ms) + : AttributeFieldBlueprint(field, attribute, QueryTermDecoder::decodeTerm(query_stack), params) +{ } + +AttributeFieldBlueprint::AttributeFieldBlueprint(FieldSpecBase field, const IAttributeVector &attribute, + QueryTermSimple::UP term, const SearchContextParams ¶ms) + : SimpleLeafBlueprint(field), + _attr(attribute), + _query_term(term->getTermString()), + _search_context(attribute.createSearchContext(std::move(term), params)), + _type(OTHER) +{ + uint32_t estHits = _search_context->approximateHits(); + HitEstimate estimate(estHits, estHits == 0); + setEstimate(estimate); + if (attribute.isFloatingPointType()) { + _type = FLOAT; + } else if (attribute.isIntegerType()) { + _type = INT; + } +} + vespalib::string get_type(const IAttributeVector& attr) { @@ -866,7 +876,7 @@ CreateBlueprintVisitor::createShallowWeightedSet(WS *bp, MultiTerm &n, const Fie bp->reserve(n.getNumTerms()); Blueprint::HitEstimate estimate; for (uint32_t i(0); i < n.getNumTerms(); i++) { - FieldSpec childfs = bp->getNextChildField(fs); + FieldSpecBase childfs = bp->getNextChildField(fs); auto term = n.getAsString(i); bp->addTerm(std::make_unique<AttributeFieldBlueprint>(childfs, _attr, extractTerm(term.first, isInteger), scParams.useBitVector(childfs.isFilter())), term.second.percent(), estimate); } diff --git a/searchlib/src/vespa/searchlib/common/bitword.h b/searchlib/src/vespa/searchlib/common/bitword.h index cba6cf7723f..06071e4cae8 100644 --- a/searchlib/src/vespa/searchlib/common/bitword.h +++ b/searchlib/src/vespa/searchlib/common/bitword.h @@ -12,16 +12,16 @@ class BitWord { public: using Word = uint64_t; using Index = uint32_t; - static Word checkTab(Index index) { return _checkTab[bitNum(index)]; } - static Word startBits(Index index) { return (std::numeric_limits<Word>::max() >> 1) >> (WordLen - 1 - bitNum(index)); } + static Word checkTab(Index index) noexcept { return _checkTab[bitNum(index)]; } + static constexpr Word startBits(Index index) noexcept { return (std::numeric_limits<Word>::max() >> 1) >> (WordLen - 1 - bitNum(index)); } static constexpr size_t WordLen = sizeof(Word)*8; - static uint8_t bitNum(Index idx) { return (idx % WordLen); } - static Word endBits(Index index) { return (std::numeric_limits<Word>::max() - 1) << bitNum(index); } - static Word allBits() { return std::numeric_limits<Word>::max(); } - static Index wordNum(Index idx) { return idx >> numWordBits(); } - static Word mask(Index idx) { return Word(1) << bitNum(idx); } - static constexpr uint8_t size_bits(uint8_t n) { return (n > 1) ? (1 + size_bits(n >> 1)) : 0; } - static uint8_t numWordBits() { return size_bits(WordLen); } + static constexpr uint8_t bitNum(Index idx) noexcept { return (idx % WordLen); } + static constexpr Word endBits(Index index) noexcept { return (std::numeric_limits<Word>::max() - 1) << bitNum(index); } + static constexpr Word allBits() noexcept { return std::numeric_limits<Word>::max(); } + static constexpr Index wordNum(Index idx) noexcept { return idx >> numWordBits(); } + static constexpr Word mask(Index idx) noexcept { return Word(1) << bitNum(idx); } + static constexpr uint8_t size_bits(uint8_t n) noexcept { return (n > 1) ? (1 + size_bits(n >> 1)) : 0; } + static constexpr uint8_t numWordBits() noexcept { return size_bits(WordLen); } private: static Word _checkTab[WordLen]; diff --git a/searchlib/src/vespa/searchlib/common/matching_elements_fields.h b/searchlib/src/vespa/searchlib/common/matching_elements_fields.h index cae28276eef..e4d61f8dedc 100644 --- a/searchlib/src/vespa/searchlib/common/matching_elements_fields.h +++ b/searchlib/src/vespa/searchlib/common/matching_elements_fields.h @@ -28,7 +28,7 @@ public: _fields.insert(field_name); } void add_mapping(const vespalib::string &field_name, - const vespalib::string &struct_field_name) { + const vespalib::string &struct_field_name) { _fields.insert(field_name); _struct_fields[struct_field_name] = field_name; } diff --git a/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.cpp b/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.cpp index bb44eaa0f3d..0719e4511be 100644 --- a/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.cpp @@ -77,11 +77,12 @@ void CreateBlueprintVisitorHelper::createWeightedSet(std::unique_ptr<WS> bp, NODE &n) { bp->reserve(n.getNumTerms()); Blueprint::HitEstimate estimate; + FieldSpec childField(_field); for (size_t i = 0; i < n.getNumTerms(); ++i) { auto term = n.getAsString(i); query::SimpleStringTerm node(term.first, n.getView(), 0, term.second); // TODO Temporary - FieldSpec field = bp->getNextChildField(_field); - bp->addTerm(_searchable.createBlueprint(_requestContext, field, node), term.second.percent(), estimate); + childField.setBase(bp->getNextChildField(_field)); + bp->addTerm(_searchable.createBlueprint(_requestContext, childField, node), term.second.percent(), estimate); } bp->complete(estimate); setResult(std::move(bp)); diff --git a/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.cpp index 3e85ae4d00a..795f5f1424a 100644 --- a/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.cpp @@ -16,12 +16,6 @@ DotProductBlueprint::DotProductBlueprint(const FieldSpec &field) DotProductBlueprint::~DotProductBlueprint() = default; -FieldSpec -DotProductBlueprint::getNextChildField(const FieldSpec &outer) -{ - return FieldSpec(outer.getName(), outer.getFieldId(), _layout.allocTermField(outer.getFieldId()), false); -} - void DotProductBlueprint::reserve(size_t num_children) { _weights.reserve(num_children); diff --git a/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.h index 18770691350..2704d76d3db 100644 --- a/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/dot_product_blueprint.h @@ -22,7 +22,9 @@ public: ~DotProductBlueprint() override; // used by create visitor - FieldSpec getNextChildField(const FieldSpec &outer); + FieldSpecBase getNextChildField(FieldSpecBase parent) { + return {parent.getFieldId(), _layout.allocTermField(parent.getFieldId()), false}; + } // used by create visitor void reserve(size_t num_children); diff --git a/searchlib/src/vespa/searchlib/queryeval/field_spec.h b/searchlib/src/vespa/searchlib/queryeval/field_spec.h index fd925fdf4ff..3fe43597602 100644 --- a/searchlib/src/vespa/searchlib/queryeval/field_spec.h +++ b/searchlib/src/vespa/searchlib/queryeval/field_spec.h @@ -44,9 +44,16 @@ public: : FieldSpecBase(fieldId, handle, isFilter_), _name(name) {} + FieldSpec(const vespalib::string & name, FieldSpecBase base) + : FieldSpecBase(base), + _name(name) + {} ~FieldSpec(); - const vespalib::string & getName() const { return _name; } + void setBase(FieldSpecBase base) { + static_cast<FieldSpecBase &>(*this) = base; + } + const vespalib::string & getName() const noexcept { return _name; } private: vespalib::string _name; // field name }; diff --git a/searchlib/src/vespa/searchlib/queryeval/multibitvectoriterator.cpp b/searchlib/src/vespa/searchlib/queryeval/multibitvectoriterator.cpp index 4203a6361b2..8b8db0f293a 100644 --- a/searchlib/src/vespa/searchlib/queryeval/multibitvectoriterator.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/multibitvectoriterator.cpp @@ -29,12 +29,12 @@ public: memset(_lastWords, 0, sizeof(_lastWords)); } protected: - void updateLastValue(uint32_t docId); - void strictSeek(uint32_t docId); + void updateLastValue(uint32_t docId) noexcept; + void strictSeek(uint32_t docId) noexcept; private: void doSeek(uint32_t docId) override; Trinary is_strict() const override { return Trinary::False; } - bool acceptExtraFilter() const override { return Update::isAnd(); } + bool acceptExtraFilter() const noexcept final { return Update::isAnd(); } Update _update; const IAccelrated & _accel; alignas(64) Word _lastWords[8]; @@ -42,7 +42,7 @@ private: }; template<typename Update> -class MultiBitVectorIteratorStrict : public MultiBitVectorIterator<Update> +class MultiBitVectorIteratorStrict final : public MultiBitVectorIterator<Update> { public: explicit MultiBitVectorIteratorStrict(MultiSearch::Children children) @@ -55,36 +55,36 @@ private: struct And { using Word = BitWord::Word; - void operator () (const IAccelrated & accel, size_t offset, const std::vector<std::pair<const void *, bool>> & src, void *dest) { + void operator () (const IAccelrated & accel, size_t offset, const std::vector<std::pair<const void *, bool>> & src, void *dest) noexcept { accel.and64(offset, src, dest); } - static bool isAnd() { return true; } + static bool isAnd() noexcept { return true; } }; struct Or { using Word = BitWord::Word; - void operator () (const IAccelrated & accel, size_t offset, const std::vector<std::pair<const void *, bool>> & src, void *dest) { + void operator () (const IAccelrated & accel, size_t offset, const std::vector<std::pair<const void *, bool>> & src, void *dest) noexcept { accel.or64(offset, src, dest); } - static bool isAnd() { return false; } + static bool isAnd() noexcept { return false; } }; template<typename Update> -void MultiBitVectorIterator<Update>::updateLastValue(uint32_t docId) +void MultiBitVectorIterator<Update>::updateLastValue(uint32_t docId) noexcept { if (docId >= _lastMaxDocIdLimit) { if (__builtin_expect(docId >= _numDocs, false)) { setAtEnd(); return; } - const uint32_t index(wordNum(docId)); + const uint32_t index(BitWord::wordNum(docId)); if (docId >= _lastMaxDocIdLimitRequireFetch) { uint32_t baseIndex = index & ~(NumWordsInBatch - 1); _update(_accel, baseIndex*sizeof(Word), _bvs, _lastWords); - _lastMaxDocIdLimitRequireFetch = (baseIndex + NumWordsInBatch) * WordLen; + _lastMaxDocIdLimitRequireFetch = (baseIndex + NumWordsInBatch) * BitWord::WordLen; } _lastValue = _lastWords[index % NumWordsInBatch]; - _lastMaxDocIdLimit = (index + 1) * WordLen; + _lastMaxDocIdLimit = (index + 1) * BitWord::WordLen; } } @@ -94,7 +94,7 @@ MultiBitVectorIterator<Update>::doSeek(uint32_t docId) { updateLastValue(docId); if (__builtin_expect( ! isAtEnd(), true)) { - if (_lastValue & mask(docId)) { + if (_lastValue & BitWord::mask(docId)) { setDocId(docId); } } @@ -102,13 +102,13 @@ MultiBitVectorIterator<Update>::doSeek(uint32_t docId) template<typename Update> void -MultiBitVectorIterator<Update>::strictSeek(uint32_t docId) +MultiBitVectorIterator<Update>::strictSeek(uint32_t docId) noexcept { - for (updateLastValue(docId), _lastValue = _lastValue & checkTab(docId); + for (updateLastValue(docId), _lastValue = _lastValue & BitWord::checkTab(docId); (_lastValue == 0) && __builtin_expect(! isAtEnd(), true); updateLastValue(_lastMaxDocIdLimit)); if (__builtin_expect(!isAtEnd(), true)) { - docId = _lastMaxDocIdLimit - WordLen + vespalib::Optimized::lsbIdx(_lastValue); + docId = _lastMaxDocIdLimit - BitWord::WordLen + vespalib::Optimized::lsbIdx(_lastValue); if (__builtin_expect(docId >= _numDocs, false)) { setAtEnd(); } else { diff --git a/searchlib/src/vespa/searchlib/queryeval/multibitvectoriterator.h b/searchlib/src/vespa/searchlib/queryeval/multibitvectoriterator.h index d75a9ddd357..d99e439af0b 100644 --- a/searchlib/src/vespa/searchlib/queryeval/multibitvectoriterator.h +++ b/searchlib/src/vespa/searchlib/queryeval/multibitvectoriterator.h @@ -8,7 +8,7 @@ namespace search::queryeval { -class MultiBitVectorIteratorBase : public MultiSearch, protected BitWord +class MultiBitVectorIteratorBase : public MultiSearch { public: ~MultiBitVectorIteratorBase() override; @@ -20,7 +20,8 @@ public: */ static SearchIterator::UP optimize(SearchIterator::UP parent); protected: - MultiBitVectorIteratorBase(Children hildren); + using Word = BitWord::Word; + MultiBitVectorIteratorBase(Children children); using MetaWord = std::pair<const void *, bool>; uint32_t _numDocs; @@ -29,7 +30,7 @@ protected: Word _lastValue; // Last value computed std::vector<MetaWord> _bvs; private: - virtual bool acceptExtraFilter() const = 0; + virtual bool acceptExtraFilter() const noexcept = 0; UP andWith(UP filter, uint32_t estimate) override; void doUnpack(uint32_t docid) override; static SearchIterator::UP optimizeMultiSearch(SearchIterator::UP parent); diff --git a/searchlib/src/vespa/searchlib/queryeval/searchiterator.h b/searchlib/src/vespa/searchlib/queryeval/searchiterator.h index b58d7ceed43..4ab066727af 100644 --- a/searchlib/src/vespa/searchlib/queryeval/searchiterator.h +++ b/searchlib/src/vespa/searchlib/queryeval/searchiterator.h @@ -56,7 +56,7 @@ protected: * * @param id docid for hit **/ - void setDocId(uint32_t id) { _docid = id; } + void setDocId(uint32_t id) noexcept { _docid = id; } /** * Used to adjust the end of the legal docid range. @@ -64,13 +64,13 @@ protected: * * @param end_id the first docid outside the legal iterator range */ - void setEndId(uint32_t end_id) { _endid = end_id; } + void setEndId(uint32_t end_id) noexcept { _endid = end_id; } /** * Will terminate the iterator by setting it past the end. * Further calls to isAtEnd() will then return true. */ - void setAtEnd() { _docid = search::endDocId; } + void setAtEnd() noexcept { _docid = search::endDocId; } public: using Trinary=vespalib::Trinary; @@ -180,7 +180,7 @@ public: /** * The constructor sets the current document id to @ref beginId. **/ - SearchIterator() : _docid(0), _endid(0) { } + SearchIterator() noexcept : _docid(0), _endid(0) { } SearchIterator(const SearchIterator &) = delete; SearchIterator &operator=(const SearchIterator &) = delete; @@ -192,15 +192,15 @@ public: * * @return constant **/ - static uint32_t beginId() { return beginDocId; } + static uint32_t beginId() noexcept { return beginDocId; } /** * Tell if the iterator has reached the end. * * @return true if the iterator has reached its end. **/ - bool isAtEnd() const { return isAtEnd(_docid); } - bool isAtEnd(uint32_t docid) const { + bool isAtEnd() const noexcept { return isAtEnd(_docid); } + bool isAtEnd(uint32_t docid) const noexcept { if (__builtin_expect(docid >= _endid, false)) { return true; } @@ -214,9 +214,9 @@ public: * * @return current document id **/ - uint32_t getDocId() const { return _docid; } + uint32_t getDocId() const noexcept { return _docid; } - uint32_t getEndId() const { return _endid; } + uint32_t getEndId() const noexcept { return _endid; } /** * Check if the given document id is a hit. If it is a hit, the diff --git a/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.cpp index e303e0b16d9..48a09f099a6 100644 --- a/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.cpp @@ -12,12 +12,11 @@ namespace search::queryeval { -ParallelWeakAndBlueprint::ParallelWeakAndBlueprint(const FieldSpec &field, +ParallelWeakAndBlueprint::ParallelWeakAndBlueprint(FieldSpecBase field, uint32_t scoresToTrack, score_t scoreThreshold, double thresholdBoostFactor) : ComplexLeafBlueprint(field), - _field(field), _scores(scoresToTrack), _scoreThreshold(scoreThreshold), _thresholdBoostFactor(thresholdBoostFactor), @@ -28,13 +27,12 @@ ParallelWeakAndBlueprint::ParallelWeakAndBlueprint(const FieldSpec &field, { } -ParallelWeakAndBlueprint::ParallelWeakAndBlueprint(const FieldSpec &field, +ParallelWeakAndBlueprint::ParallelWeakAndBlueprint(FieldSpecBase field, uint32_t scoresToTrack, score_t scoreThreshold, double thresholdBoostFactor, uint32_t scoresAdjustFrequency) : ComplexLeafBlueprint(field), - _field(field), _scores(scoresToTrack), _scoreThreshold(scoreThreshold), _thresholdBoostFactor(thresholdBoostFactor), @@ -47,12 +45,6 @@ ParallelWeakAndBlueprint::ParallelWeakAndBlueprint(const FieldSpec &field, ParallelWeakAndBlueprint::~ParallelWeakAndBlueprint() = default; -FieldSpec -ParallelWeakAndBlueprint::getNextChildField(const FieldSpec &outer) -{ - return FieldSpec(outer.getName(), outer.getFieldId(), _layout.allocTermField(outer.getFieldId()), false); -} - void ParallelWeakAndBlueprint::reserve(size_t num_children) { _weights.reserve(num_children); diff --git a/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.h index cb4d44f4497..a8d066ee689 100644 --- a/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/wand/parallel_weak_and_blueprint.h @@ -21,7 +21,6 @@ class ParallelWeakAndBlueprint : public ComplexLeafBlueprint private: using score_t = wand::score_t; - const FieldSpec _field; mutable SharedWeakAndPriorityQueue _scores; const wand::score_t _scoreThreshold; double _thresholdBoostFactor; @@ -30,15 +29,14 @@ private: std::vector<int32_t> _weights; std::vector<Blueprint::UP> _terms; - ParallelWeakAndBlueprint(const ParallelWeakAndBlueprint &); - ParallelWeakAndBlueprint &operator=(const ParallelWeakAndBlueprint &); - public: - ParallelWeakAndBlueprint(const FieldSpec &field, + ParallelWeakAndBlueprint(const ParallelWeakAndBlueprint &) = delete; + ParallelWeakAndBlueprint &operator=(const ParallelWeakAndBlueprint &) = delete; + ParallelWeakAndBlueprint(FieldSpecBase field, uint32_t scoresToTrack, score_t scoreThreshold, double thresholdBoostFactor); - ParallelWeakAndBlueprint(const FieldSpec &field, + ParallelWeakAndBlueprint(FieldSpecBase field, uint32_t scoresToTrack, score_t scoreThreshold, double thresholdBoostFactor, @@ -52,7 +50,9 @@ public: double getThresholdBoostFactor() const { return _thresholdBoostFactor; } // Used by create visitor - FieldSpec getNextChildField(const FieldSpec &outer); + FieldSpecBase getNextChildField(FieldSpecBase parent) { + return {parent.getFieldId(), _layout.allocTermField(parent.getFieldId()), false}; + } // Used by create visitor void reserve(size_t num_children); diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h index b40ab421890..0e3c82444d7 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h @@ -26,7 +26,7 @@ public: // used by create visitor // matches signature in dot product blueprint for common blueprint // building code. Hands out the same field spec to all children. - FieldSpec getNextChildField(const FieldSpec &) { return _children_field; } + FieldSpecBase getNextChildField(FieldSpecBase) { return _children_field; } // used by create visitor void reserve(size_t num_children); @@ -39,9 +39,6 @@ public: SearchIteratorUP createFilterSearch(bool strict, FilterConstraint constraint) const override; std::unique_ptr<MatchingElementsSearch> create_matching_elements_search(const MatchingElementsFields &fields) const override; void visitMembers(vespalib::ObjectVisitor &visitor) const override; - const vespalib::string &field_name() const { return _children_field.getName(); } - const std::vector<Blueprint::UP> &get_terms() const { return _terms; } - private: void fetchPostings(const ExecuteInfo &execInfo) override; }; diff --git a/security-utils/src/main/java/com/yahoo/security/SideChannelSafe.java b/security-utils/src/main/java/com/yahoo/security/SideChannelSafe.java index bd085f6f624..3a46891085f 100644 --- a/security-utils/src/main/java/com/yahoo/security/SideChannelSafe.java +++ b/security-utils/src/main/java/com/yahoo/security/SideChannelSafe.java @@ -1,6 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.security; +import org.bouncycastle.util.Arrays; + /** * Utility functions for comparing the contents of arrays without leaking information about the * data contained within them via timing side-channels. This is done by avoiding any branches @@ -13,18 +15,11 @@ package com.yahoo.security; public class SideChannelSafe { /** - * @return true iff all bytes in the array are zero. An empty array always returns false - * since it technically can't contain any zeros at all. + * @return true iff all bytes in the array are zero. An empty array always returns true + * to be in line with BouncyCastle semantics. */ public static boolean allZeros(byte[] buf) { - if (buf.length == 0) { - return false; - } - byte accu = 0; - for (byte b : buf) { - accu |= b; - } - return (accu == 0); + return Arrays.areAllZeroes(buf, 0, buf.length); } /** @@ -32,23 +27,14 @@ public class SideChannelSafe { * about the contents of either of the arrays. * * <strong>Important:</strong> the <em>length</em> of the arrays is not considered secret, and - * will be leaked if arrays of differing sizes are given. + * <em>may</em> be leaked if arrays of differing sizes are given. * * @param lhs first array of bytes to compare * @param rhs second array of bytes to compare * @return true iff both arrays have the same size and are element-wise identical */ public static boolean arraysEqual(byte[] lhs, byte[] rhs) { - if (lhs.length != rhs.length) { - return false; - } - // Only use constant time bitwise ops. `accu` will be non-zero if at least one bit - // differed in any byte compared between the two arrays. - byte accu = 0; - for (int i = 0; i < lhs.length; ++i) { - accu |= (byte)(lhs[i] ^ rhs[i]); - } - return (accu == 0); + return Arrays.constantTimeAreEqual(lhs, rhs); } } diff --git a/security-utils/src/main/java/com/yahoo/security/token/TokenCheckHash.java b/security-utils/src/main/java/com/yahoo/security/token/TokenCheckHash.java index 2ff47081784..b67b120ba7b 100644 --- a/security-utils/src/main/java/com/yahoo/security/token/TokenCheckHash.java +++ b/security-utils/src/main/java/com/yahoo/security/token/TokenCheckHash.java @@ -1,6 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.security.token; +import com.yahoo.security.SideChannelSafe; + import java.util.Arrays; import static com.yahoo.security.ArrayUtils.hex; @@ -18,8 +20,9 @@ public record TokenCheckHash(byte[] hashBytes) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; TokenCheckHash tokenCheckHash = (TokenCheckHash) o; - // We don't consider token hashes secret data, so no harm in data-dependent equals() - return Arrays.equals(hashBytes, tokenCheckHash.hashBytes); + // Although not considered secret information, avoid leaking the contents of + // the check-hashes themselves via timing channels. + return SideChannelSafe.arraysEqual(hashBytes, tokenCheckHash.hashBytes); } @Override diff --git a/security-utils/src/test/java/com/yahoo/security/SideChannelSafeTest.java b/security-utils/src/test/java/com/yahoo/security/SideChannelSafeTest.java index 7a66ed6eb7f..9731bbffc38 100644 --- a/security-utils/src/test/java/com/yahoo/security/SideChannelSafeTest.java +++ b/security-utils/src/test/java/com/yahoo/security/SideChannelSafeTest.java @@ -14,7 +14,7 @@ public class SideChannelSafeTest { @Test void all_zeros_checks_length_and_array_contents() { - assertFalse(SideChannelSafe.allZeros(new byte[0])); + assertTrue(SideChannelSafe.allZeros(new byte[0])); assertFalse(SideChannelSafe.allZeros(new byte[]{ 1 })); assertTrue(SideChannelSafe.allZeros(new byte[]{ 0 })); assertFalse(SideChannelSafe.allZeros(new byte[]{ 0, 0, 127, 0 })); diff --git a/storage/src/vespa/storage/storageserver/statemanager.cpp b/storage/src/vespa/storage/storageserver/statemanager.cpp index 654fe0e1f5d..c228229e4ef 100644 --- a/storage/src/vespa/storage/storageserver/statemanager.cpp +++ b/storage/src/vespa/storage/storageserver/statemanager.cpp @@ -17,6 +17,7 @@ #include <vespa/vespalib/util/exceptions.h> #include <vespa/vespalib/util/string_escape.h> #include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/time.h> #include <fstream> #include <vespa/log/log.h> @@ -68,6 +69,10 @@ StateManager::StateManager(StorageComponentRegister& compReg, _threadLock(), _systemStateHistory(), _systemStateHistorySize(50), + _start_time(vespalib::steady_clock::now()), + _health_ping_time(), + _health_ping_warn_interval(5min), + _health_ping_warn_time(_start_time + _health_ping_warn_interval), _hostInfo(std::move(hostInfo)), _controllers_observed_explicit_node_state(), _noThreadTestMode(testMode), @@ -391,6 +396,8 @@ StateManager::onGetNodeState(const api::GetNodeStateCommand::SP& cmd) std::shared_ptr<api::GetNodeStateReply> reply; { std::unique_lock guard(_stateLock); + _health_ping_time = vespalib::steady_clock::now(); + _health_ping_warn_time = _health_ping_time.value() + _health_ping_warn_interval; const bool is_up_to_date = (_controllers_observed_explicit_node_state.find(cmd->getSourceIndex()) != _controllers_observed_explicit_node_state.end()); if ((cmd->getExpectedState() != nullptr) @@ -479,6 +486,28 @@ StateManager::run(framework::ThreadHandle& thread) } void +StateManager::warn_on_missing_health_ping() +{ + vespalib::steady_time now(vespalib::steady_clock::now()); + std::optional<vespalib::steady_time> health_ping_time; + { + std::lock_guard lock(_stateLock); + if (now <= _health_ping_warn_time) { + return; + } + health_ping_time = _health_ping_time; + _health_ping_warn_time = now + _health_ping_warn_interval; + } + if (health_ping_time.has_value()) { + vespalib::duration duration = now - health_ping_time.value(); + LOG(warning, "Last health ping from cluster controller was %1.1f seconds ago", vespalib::to_s(duration)); + } else { + vespalib::duration duration = now - _start_time; + LOG(warning, "No health pings from cluster controller since startup %1.1f seconds ago", vespalib::to_s(duration)); + } +} + +void StateManager::tick() { bool almost_immediate_replies = _requested_almost_immediate_node_state_replies.load(std::memory_order_relaxed); if (almost_immediate_replies) { @@ -487,6 +516,7 @@ StateManager::tick() { } else { sendGetNodeStateReplies(_component.getClock().getMonotonicTime()); } + warn_on_missing_health_ping(); } bool diff --git a/storage/src/vespa/storage/storageserver/statemanager.h b/storage/src/vespa/storage/storageserver/statemanager.h index 0b9a47c2515..3b1291b1c3f 100644 --- a/storage/src/vespa/storage/storageserver/statemanager.h +++ b/storage/src/vespa/storage/storageserver/statemanager.h @@ -65,6 +65,10 @@ class StateManager : public NodeStateUpdater, std::condition_variable _threadCond; std::deque<TimeSysStatePair> _systemStateHistory; uint32_t _systemStateHistorySize; + const vespalib::steady_time _start_time; + std::optional<vespalib::steady_time> _health_ping_time; + vespalib::duration _health_ping_warn_interval; + vespalib::steady_time _health_ping_warn_time; std::unique_ptr<HostInfo> _hostInfo; std::unique_ptr<framework::Thread> _thread; // Controllers that have observed a GetNodeState response sent _after_ @@ -84,6 +88,7 @@ public: void onClose() override; void tick(); + void warn_on_missing_health_ping(); void print(std::ostream& out, bool verbose, const std::string& indent) const override; void reportHtmlStatus(std::ostream&, const framework::HttpUrlPath&) const override; diff --git a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp index 4d425d9dedd..f408910d8c2 100644 --- a/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp +++ b/streamingvisitors/src/tests/rank_processor/rank_processor_test.cpp @@ -18,6 +18,7 @@ using search::query::SimpleQueryNodeTypes; using search::query::StackDumpCreator; using search::streaming::NearestNeighborQueryNode; using search::streaming::Query; +using search::streaming::QueryTerm; using streaming::RankProcessor; using streaming::QueryTermData; using streaming::QueryTermDataFactory; @@ -34,6 +35,7 @@ protected: ~RankProcessorTest() override; void build_query(QueryBuilder<SimpleQueryNodeTypes> &builder); + void test_unpack_match_data_for_term_node(bool interleaved_features); }; RankProcessorTest::RankProcessorTest() @@ -55,6 +57,63 @@ RankProcessorTest::build_query(QueryBuilder<SimpleQueryNodeTypes> &builder) _query_wrapper = std::make_unique<QueryWrapper>(*_query); } +void +RankProcessorTest::test_unpack_match_data_for_term_node(bool interleaved_features) +{ + QueryBuilder<SimpleQueryNodeTypes> builder; + constexpr int32_t id = 42; + constexpr int32_t weight = 1; + builder.addStringTerm("term", "field", id, Weight(weight)); + build_query(builder); + auto& term_list = _query_wrapper->getTermList(); + EXPECT_EQ(1u, term_list.size()); + auto node = dynamic_cast<QueryTerm*>(term_list.front().getTerm()); + EXPECT_NE(nullptr, node); + auto& qtd = static_cast<QueryTermData &>(node->getQueryItem()); + auto& td = qtd.getTermData(); + constexpr TermFieldHandle handle = 27; + constexpr uint32_t field_id = 12; + constexpr uint32_t mock_num_occs = 2; + constexpr uint32_t mock_field_length = 101; + td.addField(field_id).setHandle(handle); + node->resizeFieldId(field_id); + auto md = MatchData::makeTestInstance(handle + 1, handle + 1); + auto tfmd = md->resolveTermField(handle); + tfmd->setNeedInterleavedFeatures(interleaved_features); + auto invalid_id = TermFieldMatchData::invalidId(); + EXPECT_EQ(invalid_id, tfmd->getDocId()); + RankProcessor::unpack_match_data(1, *md, *_query_wrapper); + EXPECT_EQ(invalid_id, tfmd->getDocId()); + node->add(0, field_id, 0, 1); + auto& field_info = node->getFieldInfo(field_id); + field_info.setHitCount(mock_num_occs); + field_info.setFieldLength(mock_field_length); + RankProcessor::unpack_match_data(2, *md, *_query_wrapper); + EXPECT_EQ(2, tfmd->getDocId()); + if (interleaved_features) { + EXPECT_EQ(mock_num_occs, tfmd->getNumOccs()); + EXPECT_EQ(mock_field_length, tfmd->getFieldLength()); + } else { + EXPECT_EQ(0, tfmd->getNumOccs()); + EXPECT_EQ(0, tfmd->getFieldLength()); + } + EXPECT_EQ(1, tfmd->size()); + node->reset(); + RankProcessor::unpack_match_data(3, *md, *_query_wrapper); + EXPECT_EQ(2, tfmd->getDocId()); +} + + +TEST_F(RankProcessorTest, unpack_normal_match_data_for_term_node) +{ + test_unpack_match_data_for_term_node(false); +} + +TEST_F(RankProcessorTest, unpack_interleaved_match_data_for_term_node) +{ + test_unpack_match_data_for_term_node(true); +} + class MockRawScoreCalculator : public search::streaming::NearestNeighborQueryNode::RawScoreCalculator { public: double to_raw_score(double distance) override { return distance * 2; } diff --git a/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h b/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h index c5dc442e424..8084d776efe 100644 --- a/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h +++ b/streamingvisitors/src/vespa/searchvisitor/queryenvironment.h @@ -55,7 +55,7 @@ public: // inherit documentation virtual const search::attribute::IAttributeContext & getAttributeContext() const override { return *_attrCtx; } - double get_average_field_length(const vespalib::string &) const override { return 1.0; } + double get_average_field_length(const vespalib::string &) const override { return 100.0; } // inherit documentation virtual const search::fef::IIndexEnvironment & getIndexEnvironment() const override { return _indexEnv; } diff --git a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp index 78d72102fe9..17056d9d4b7 100644 --- a/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/rankprocessor.cpp @@ -6,6 +6,7 @@ #include <vespa/searchlib/fef/simpletermfielddata.h> #include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> #include <vespa/vsm/vsm/fieldsearchspec.h> +#include <algorithm> #include <cmath> #include <vespa/log/log.h> LOG_SETUP(".searchvisitor.rankprocessor"); @@ -50,6 +51,11 @@ getFeature(const RankProgram &rankProgram) { return resolver.resolve(0); } +uint16_t +cap_16_bits(uint32_t value) { + return std::min(value, static_cast<uint32_t>(std::numeric_limits<uint16_t>::max())); +} + } void @@ -284,6 +290,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap uint32_t lastFieldId = -1; TermFieldMatchData *tmd = nullptr; uint32_t fieldLen = search::fef::FieldPositionsIterator::UNKNOWN_LENGTH; + uint32_t num_occs = 0; // optimize for hitlist giving all hits for a single field in one chunk for (const Hit & hit : hitList) { @@ -292,6 +299,7 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap // reset to notfound/unknown values tmd = nullptr; fieldLen = search::fef::FieldPositionsIterator::UNKNOWN_LENGTH; + num_occs = 0; // setup for new field that had a hit const ITermFieldData *tfd = td.lookupField(fieldId); @@ -306,11 +314,15 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap // find fieldLen for new field if (isPhrase) { if (fieldId < term.getParent()->getFieldInfoSize()) { - fieldLen = term.getParent()->getFieldInfo(fieldId).getFieldLength(); + auto& field_info = term.getParent()->getFieldInfo(fieldId); + fieldLen = field_info.getFieldLength(); + num_occs = field_info.getHitCount(); } } else { if (fieldId < term.getTerm()->getFieldInfoSize()) { - fieldLen = term.getTerm()->getFieldInfo(fieldId).getFieldLength(); + auto& field_info = term.getTerm()->getFieldInfo(fieldId); + fieldLen = field_info.getFieldLength(); + num_occs = field_info.getHitCount(); } } lastFieldId = fieldId; @@ -322,6 +334,10 @@ RankProcessor::unpack_match_data(uint32_t docid, MatchData &matchData, QueryWrap tmd->appendPosition(pos); LOG(debug, "Append elemId(%u),position(%u), weight(%d), tfmd.weight(%d)", pos.getElementId(), pos.getPosition(), pos.getElementWeight(), tmd->getWeight()); + if (tmd->needs_interleaved_features()) { + tmd->setFieldLength(cap_16_bits(fieldLen)); + tmd->setNumOccs(cap_16_bits(num_occs)); + } } } } diff --git a/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/JettyCluster.java b/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/JettyCluster.java index cd7a4e6222e..1d9fd9d1805 100644 --- a/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/JettyCluster.java +++ b/vespa-feed-client/src/main/java/ai/vespa/feed/client/impl/JettyCluster.java @@ -14,6 +14,7 @@ import org.eclipse.jetty.client.api.Request; import org.eclipse.jetty.client.api.Response; import org.eclipse.jetty.client.api.Result; import org.eclipse.jetty.client.dynamic.HttpClientTransportDynamic; +import org.eclipse.jetty.client.http.HttpClientConnectionFactory; import org.eclipse.jetty.client.util.BufferingResponseListener; import org.eclipse.jetty.client.util.BytesRequestContent; import org.eclipse.jetty.http.HttpField; @@ -145,8 +146,10 @@ class JettyCluster implements Cluster { int initialWindow = Integer.MAX_VALUE; h2Client.setInitialSessionRecvWindow(initialWindow); h2Client.setInitialStreamRecvWindow(initialWindow); + // Need HTTP/1.1 for tunnel using CONNECT method + ClientConnectionFactory.Info h1 = HttpClientConnectionFactory.HTTP11; ClientConnectionFactory.Info http2 = new ClientConnectionFactoryOverHTTP2.HTTP2(h2Client); - HttpClientTransportDynamic transport = new HttpClientTransportDynamic(connector, http2); + HttpClientTransportDynamic transport = new HttpClientTransportDynamic(connector, http2, h1); int connectionsPerEndpoint = b.connectionsPerEndpoint; transport.setConnectionPoolFactory(dest -> { MultiplexConnectionPool pool = new MultiplexConnectionPool( @@ -171,6 +174,7 @@ class JettyCluster implements Cluster { private static void addProxyConfiguration(FeedClientBuilderImpl b, HttpClient httpClient) throws IOException { Origin.Address address = new Origin.Address(b.proxy.getHost(), b.proxy.getPort()); + Map<String, Supplier<String>> proxyHeadersCopy = new TreeMap<>(b.proxyRequestHeaders); if (b.proxy.getScheme().equals("https")) { SslContextFactory.Client proxySslCtxFactory = new SslContextFactory.Client(); if (b.proxyHostnameVerifier != null) { @@ -182,19 +186,25 @@ class JettyCluster implements Cluster { try { proxySslCtxFactory.start(); } catch (Exception e) { throw new IOException(e); } httpClient.getProxyConfiguration().addProxy( new HttpProxy(address, proxySslCtxFactory, new Origin.Protocol(Collections.singletonList("h2"), false))); - } else { - httpClient.getProxyConfiguration().addProxy( - new HttpProxy(address, false, new Origin.Protocol(Collections.singletonList("h2c"), false))); - } - Map<String, Supplier<String>> proxyHeadersCopy = new TreeMap<>(b.proxyRequestHeaders); - URI proxyUri = URI.create(endpointUri(b.proxy)); - if (!proxyHeadersCopy.isEmpty()) { + URI proxyUri = URI.create(endpointUri(b.proxy)); httpClient.getAuthenticationStore().addAuthenticationResult(new Authentication.Result() { @Override public URI getURI() { return proxyUri; } @Override public void apply(Request r) { r.headers(hs -> proxyHeadersCopy.forEach((k, v) -> hs.add(k, v.get()))); } }); + } else { + // Assume insecure proxy uses HTTP/1.1 + httpClient.getProxyConfiguration().addProxy( + new HttpProxy(address, false, new Origin.Protocol(Collections.singletonList("http/1.1"), false))); + // Bug in Jetty cause authentication result to be ignored for HTTP/1.1 CONNECT requests + httpClient.getRequestListeners().add(new Request.Listener() { + @Override + public void onHeaders(Request r) { + if (HttpMethod.CONNECT.is(r.getMethod())) + r.headers(hs -> proxyHeadersCopy.forEach((k, v) -> hs.add(k, v.get()))); + } + }); } } diff --git a/vespalib/src/tests/io/fileutil/fileutiltest.cpp b/vespalib/src/tests/io/fileutil/fileutiltest.cpp index 4eb700fd4ed..337c9052a66 100644 --- a/vespalib/src/tests/io/fileutil/fileutiltest.cpp +++ b/vespalib/src/tests/io/fileutil/fileutiltest.cpp @@ -34,7 +34,7 @@ TEST("require that vespalib::File::open works") { // Opening non-existing file for reading should fail. try{ - unlink("myfile"); // Just in case + std::filesystem::remove(std::filesystem::path("myfile")); // Just in case File f("myfile"); f.open(File::READONLY); TEST_FATAL("Opening non-existing file for reading should fail."); @@ -155,7 +155,7 @@ TEST("require that vespalib::File::isOpen works") TEST("require that vespalib::File::stat works") { - unlink("myfile"); + std::filesystem::remove(std::filesystem::path("myfile")); std::filesystem::remove_all(std::filesystem::path("mydir")); EXPECT_EQUAL(false, fileExists("myfile")); EXPECT_EQUAL(false, fileExists("mydir")); @@ -188,7 +188,7 @@ TEST("require that vespalib::File::stat works") TEST("require that vespalib::File::resize works") { - unlink("myfile"); + std::filesystem::remove(std::filesystem::path("myfile")); File f("myfile"); f.open(File::CREATE, false); f.write("foobar", 6, 0); @@ -206,162 +206,6 @@ TEST("require that vespalib::File::resize works") EXPECT_EQUAL(std::string("foo"), std::string(&vec[0], 3)); } -TEST("require that vespalib::unlink works") -{ - // Fails on directory - try{ - std::filesystem::create_directory(std::filesystem::path("mydir")); - unlink("mydir"); - TEST_FATAL("Should work on directories."); - } catch (IoException& e) { - //std::cerr << e.what() << "\n"; -#ifdef __APPLE__ - EXPECT_EQUAL(IoException::NO_PERMISSION, e.getType()); -#else - EXPECT_EQUAL(IoException::ILLEGAL_PATH, e.getType()); -#endif - } - // Works for file - { - { - File f("myfile"); - f.open(File::CREATE); - f.write("foo", 3, 0); - } - ASSERT_TRUE(fileExists("myfile")); - ASSERT_TRUE(unlink("myfile")); - ASSERT_TRUE(!fileExists("myfile")); - ASSERT_TRUE(!unlink("myfile")); - } -} - -TEST("require that vespalib::rename works") -{ - std::filesystem::remove_all(std::filesystem::path("mydir")); - File f("myfile"); - f.open(File::CREATE | File::TRUNC); - f.write("Hello World!\n", 13, 0); - f.close(); - // Renaming to non-existing dir doesn't work - try{ - rename("myfile", "mydir/otherfile"); - TEST_FATAL("This shouldn't work when mydir doesn't exist"); - } catch (IoException& e) { - //std::cerr << e.what() << "\n"; - EXPECT_EQUAL(IoException::NOT_FOUND, e.getType()); - } - // Renaming to non-existing dir works if autocreating dirs - { - ASSERT_TRUE(rename("myfile", "mydir/otherfile", true, true)); - ASSERT_TRUE(!fileExists("myfile")); - ASSERT_TRUE(fileExists("mydir/otherfile")); - - File f2("mydir/otherfile"); - f2.open(File::READONLY); - std::vector<char> vec(20, ' '); - size_t read = f2.read(&vec[0], 20, 0); - EXPECT_EQUAL(13u, read); - EXPECT_EQUAL(std::string("Hello World!\n"), std::string(&vec[0], 13)); - } - // Renaming non-existing returns false - ASSERT_TRUE(!rename("myfile", "mydir/otherfile", true)); - // Rename to overwrite works - { - f.open(File::CREATE | File::TRUNC); - f.write("Bah\n", 4, 0); - f.close(); - ASSERT_TRUE(rename("myfile", "mydir/otherfile", true, true)); - - File f2("mydir/otherfile"); - f2.open(File::READONLY); - std::vector<char> vec(20, ' '); - size_t read = f2.read(&vec[0], 20, 0); - EXPECT_EQUAL(4u, read); - EXPECT_EQUAL(std::string("Bah\n"), std::string(&vec[0], 4)); - } - // Overwriting directory fails (does not put inside dir) - try{ - std::filesystem::create_directory(std::filesystem::path("mydir")); - f.open(File::CREATE | File::TRUNC); - f.write("Bah\n", 4, 0); - f.close(); - ASSERT_TRUE(rename("myfile", "mydir")); - } catch (IoException& e) { - //std::cerr << e.what() << "\n"; - EXPECT_EQUAL(IoException::ILLEGAL_PATH, e.getType()); - } - // Moving directory works - { - ASSERT_TRUE(isDirectory("mydir")); - std::filesystem::remove_all(std::filesystem::path("myotherdir")); - ASSERT_TRUE(rename("mydir", "myotherdir")); - ASSERT_TRUE(isDirectory("myotherdir")); - ASSERT_TRUE(!isDirectory("mydir")); - ASSERT_TRUE(!rename("mydir", "myotherdir")); - } - // Overwriting directory fails - try{ - File f2("mydir/yetanotherfile"); - f2.open(File::CREATE, true); - f2.write("foo", 3, 0); - f2.open(File::READONLY); - f2.close(); - rename("mydir", "myotherdir"); - TEST_FATAL("Should fail trying to overwrite directory"); - } catch (IoException& e) { - //std::cerr << e.what() << "\n"; - EXPECT_TRUE((IoException::DIRECTORY_HAVE_CONTENT == e.getType()) || - (IoException::ALREADY_EXISTS == e.getType())); - } -} - -TEST("require that vespalib::copy works") -{ - std::filesystem::remove_all(std::filesystem::path("mydir")); - File f("myfile"); - f.open(File::CREATE | File::TRUNC); - - MallocAutoPtr buffer = getAlignedBuffer(5000); - memset(buffer.get(), 0, 5000); - strncpy(static_cast<char*>(buffer.get()), "Hello World!\n", 14); - f.write(buffer.get(), 4_Ki, 0); - f.close(); - std::cerr << "Simple copy\n"; - // Simple copy works (4096b dividable file) - copy("myfile", "targetfile"); - ASSERT_TRUE(system("diff myfile targetfile") == 0); - std::cerr << "Overwriting\n"; - // Overwriting works (may not be able to use direct IO writing on all - // systems, so will always use cached IO) - { - f.open(File::CREATE | File::TRUNC); - f.write("Bah\n", 4, 0); - f.close(); - - ASSERT_TRUE(system("diff myfile targetfile > /dev/null") != 0); - copy("myfile", "targetfile"); - ASSERT_TRUE(system("diff myfile targetfile > /dev/null") == 0); - } - // Fails if target is directory - try{ - std::filesystem::create_directory(std::filesystem::path("mydir")); - copy("myfile", "mydir"); - TEST_FATAL("Should fail trying to overwrite directory"); - } catch (IoException& e) { - //std::cerr << e.what() << "\n"; - EXPECT_EQUAL(IoException::ILLEGAL_PATH, e.getType()); - } - // Fails if source is directory - try{ - std::filesystem::create_directory(std::filesystem::path("mydir")); - copy("mydir", "myfile"); - TEST_FATAL("Should fail trying to copy directory"); - } catch (IoException& e) { - //std::cerr << e.what() << "\n"; - EXPECT_EQUAL(IoException::ILLEGAL_PATH, e.getType()); - } -} - TEST("require that copy constructor and assignment for vespalib::File works") { // Copy file not opened. @@ -403,72 +247,10 @@ TEST("require that copy constructor and assignment for vespalib::File works") } } -TEST("require that vespalib::symlink works") -{ - // Target exists - { - std::filesystem::remove_all(std::filesystem::path("mydir")); - std::filesystem::create_directory(std::filesystem::path("mydir")); - - File f("mydir/myfile"); - f.open(File::CREATE | File::TRUNC); - f.write("Hello World!\n", 13, 0); - f.close(); - - symlink("myfile", "mydir/linkyfile"); - EXPECT_TRUE(fileExists("mydir/linkyfile")); - - File f2("mydir/linkyfile"); - f2.open(File::READONLY); - std::vector<char> vec(20, ' '); - size_t read = f2.read(&vec[0], 20, 0); - EXPECT_EQUAL(13u, read); - EXPECT_EQUAL(std::string("Hello World!\n"), std::string(&vec[0], 13)); - } - - // POSIX symlink() fails - { - std::filesystem::remove_all(std::filesystem::path("mydir")); - std::filesystem::create_directories(std::filesystem::path("mydir/a")); - std::filesystem::create_directory(std::filesystem::path("mydir/b")); - try { - // Link already exists - symlink("a", "mydir/b"); - TEST_FATAL("Exception not thrown on already existing link"); - } catch (IoException& e) { - EXPECT_EQUAL(IoException::ALREADY_EXISTS, e.getType()); - } - } - - { - std::filesystem::remove_all(std::filesystem::path("mydir")); - std::filesystem::create_directory(std::filesystem::path("mydir")); - - File f("mydir/myfile"); - f.open(File::CREATE | File::TRUNC); - f.write("Hello World!\n", 13, 0); - f.close(); - } - - // readLink success - { - symlink("myfile", "mydir/linkyfile"); - EXPECT_EQUAL("myfile", readLink("mydir/linkyfile")); - } - // readLink failure - { - try { - readLink("no/such/link"); - } catch (IoException& e) { - EXPECT_EQUAL(IoException::NOT_FOUND, e.getType()); - } - } -} - TEST("require that we can read all data written to file") { // Write text into a file. - unlink("myfile"); + std::filesystem::remove(std::filesystem::path("myfile")); File fileForWriting("myfile"); fileForWriting.open(File::CREATE); vespalib::string text = "This is some text. "; diff --git a/vespalib/src/tests/net/tls/auto_reloading_tls_crypto_engine/auto_reloading_tls_crypto_engine_test.cpp b/vespalib/src/tests/net/tls/auto_reloading_tls_crypto_engine/auto_reloading_tls_crypto_engine_test.cpp index 6662b2a4e41..62614f5d811 100644 --- a/vespalib/src/tests/net/tls/auto_reloading_tls_crypto_engine/auto_reloading_tls_crypto_engine_test.cpp +++ b/vespalib/src/tests/net/tls/auto_reloading_tls_crypto_engine/auto_reloading_tls_crypto_engine_test.cpp @@ -9,6 +9,7 @@ #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/vespalib/testkit/time_bomb.h> #include <openssl/ssl.h> +#include <filesystem> using namespace vespalib; using namespace vespalib::net::tls; @@ -118,7 +119,7 @@ TEST_FF("Config reloading transitively loads updated files", Fixture(50ms), Time ASSERT_EQUAL(cert1_pem, current_certs); write_file("test_cert.pem.tmp", cert2_pem); - rename("test_cert.pem.tmp", "test_cert.pem", false, false); // We expect this to be an atomic rename under the hood + std::filesystem::rename(std::filesystem::path("test_cert.pem.tmp"), std::filesystem::path("test_cert.pem")); // We expect this to be an atomic rename under the hood current_certs = f1.current_cert_chain(); while (current_certs != cert2_pem) { @@ -140,7 +141,7 @@ TEST_FF("Config reload failure increments failure statistic", Fixture(50ms), Tim auto before = ConfigStatistics::get().snapshot(); write_file("test_cert.pem.tmp", "Broken file oh no :("); - rename("test_cert.pem.tmp", "test_cert.pem", false, false); + std::filesystem::rename(std::filesystem::path("test_cert.pem.tmp"), std::filesystem::path("test_cert.pem")); while (ConfigStatistics::get().snapshot().subtract(before).failed_config_reloads == 0) { std::this_thread::sleep_for(10ms); diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp index 590223ed13a..1f3b2d29744 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.cpp @@ -6,32 +6,32 @@ namespace vespalib::hwaccelrated { size_t -Avx2Accelrator::populationCount(const uint64_t *a, size_t sz) const { +Avx2Accelrator::populationCount(const uint64_t *a, size_t sz) const noexcept { return helper::populationCount(a, sz); } double -Avx2Accelrator::squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const { +Avx2Accelrator::squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept { return helper::squaredEuclideanDistance(a, b, sz); } double -Avx2Accelrator::squaredEuclideanDistance(const float * a, const float * b, size_t sz) const { +Avx2Accelrator::squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept { return avx::euclideanDistanceSelectAlignment<float, 32>(a, b, sz); } double -Avx2Accelrator::squaredEuclideanDistance(const double * a, const double * b, size_t sz) const { +Avx2Accelrator::squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept { return avx::euclideanDistanceSelectAlignment<double, 32>(a, b, sz); } void -Avx2Accelrator::and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const { +Avx2Accelrator::and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept { helper::andChunks<32u, 2u>(offset, src, dest); } void -Avx2Accelrator::or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const { +Avx2Accelrator::or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept { helper::orChunks<32u, 2u>(offset, src, dest); } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h index 2949e81fd36..1166282fbe8 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx2.h @@ -12,12 +12,12 @@ namespace vespalib::hwaccelrated { class Avx2Accelrator : public GenericAccelrator { public: - size_t populationCount(const uint64_t *a, size_t sz) const override; - double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const override; - double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const override; - double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const override; - void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const override; - void or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const override; + size_t populationCount(const uint64_t *a, size_t sz) const noexcept override; + double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept override; + double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept override; + double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept override; + void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; + void or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; }; } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp index 5878165bb6d..1648f80c9e9 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.cpp @@ -6,44 +6,42 @@ namespace vespalib:: hwaccelrated { float -Avx512Accelrator::dotProduct(const float * af, const float * bf, size_t sz) const -{ +Avx512Accelrator::dotProduct(const float * af, const float * bf, size_t sz) const noexcept { return avx::dotProductSelectAlignment<float, 64>(af, bf, sz); } double -Avx512Accelrator::dotProduct(const double * af, const double * bf, size_t sz) const -{ +Avx512Accelrator::dotProduct(const double * af, const double * bf, size_t sz) const noexcept { return avx::dotProductSelectAlignment<double, 64>(af, bf, sz); } size_t -Avx512Accelrator::populationCount(const uint64_t *a, size_t sz) const { +Avx512Accelrator::populationCount(const uint64_t *a, size_t sz) const noexcept { return helper::populationCount(a, sz); } double -Avx512Accelrator::squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const { +Avx512Accelrator::squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept { return helper::squaredEuclideanDistance(a, b, sz); } double -Avx512Accelrator::squaredEuclideanDistance(const float * a, const float * b, size_t sz) const { +Avx512Accelrator::squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept { return avx::euclideanDistanceSelectAlignment<float, 64>(a, b, sz); } double -Avx512Accelrator::squaredEuclideanDistance(const double * a, const double * b, size_t sz) const { +Avx512Accelrator::squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept { return avx::euclideanDistanceSelectAlignment<double, 64>(a, b, sz); } void -Avx512Accelrator::and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const { +Avx512Accelrator::and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept { helper::andChunks<64, 1>(offset, src, dest); } void -Avx512Accelrator::or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const { +Avx512Accelrator::or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept { helper::orChunks<64, 1>(offset, src, dest); } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h index 4989f72e698..3dc207e8ade 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avx512.h @@ -12,14 +12,14 @@ namespace vespalib::hwaccelrated { class Avx512Accelrator : public Avx2Accelrator { public: - float dotProduct(const float * a, const float * b, size_t sz) const override; - double dotProduct(const double * a, const double * b, size_t sz) const override; - size_t populationCount(const uint64_t *a, size_t sz) const override; - double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const override; - double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const override; - double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const override; - void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const override; - void or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const override; + float dotProduct(const float * a, const float * b, size_t sz) const noexcept override; + double dotProduct(const double * a, const double * b, size_t sz) const noexcept override; + size_t populationCount(const uint64_t *a, size_t sz) const noexcept override; + double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept override; + double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept override; + double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept override; + void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; + void or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; }; } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/avxprivate.hpp b/vespalib/src/vespa/vespalib/hwaccelrated/avxprivate.hpp index 3bdbb7a81ff..e1cea280d0c 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/avxprivate.hpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/avxprivate.hpp @@ -8,12 +8,12 @@ namespace vespalib::hwaccelrated::avx { namespace { -inline bool validAlignment(const void * p, const size_t align) { +inline bool validAlignment(const void * p, const size_t align) noexcept { return (reinterpret_cast<uint64_t>(p) & (align-1)) == 0; } template <typename T, typename V> -T sumT(const V & v) { +T sumT(const V & v) noexcept { T sum(0); for (size_t i(0); i < (sizeof(V)/sizeof(T)); i++) { sum += v[i]; @@ -22,7 +22,7 @@ T sumT(const V & v) { } template <typename T, size_t C> -T sumR(const T * v) { +T sumR(const T * v) noexcept { if (C == 1) { return v[0]; } else if (C == 2) { @@ -33,10 +33,10 @@ T sumR(const T * v) { } template <typename T, size_t VLEN, unsigned AlignA, unsigned AlignB, size_t VectorsPerChunk> -static T computeDotProduct(const T * af, const T * bf, size_t sz) __attribute__((noinline)); +static T computeDotProduct(const T * af, const T * bf, size_t sz) noexcept __attribute__((noinline)); template <typename T, size_t VLEN, unsigned AlignA, unsigned AlignB, size_t VectorsPerChunk> -T computeDotProduct(const T * af, const T * bf, size_t sz) +T computeDotProduct(const T * af, const T * bf, size_t sz) noexcept { constexpr const size_t ChunkSize = VLEN*VectorsPerChunk/sizeof(T); typedef T V __attribute__ ((vector_size (VLEN))); @@ -65,10 +65,10 @@ T computeDotProduct(const T * af, const T * bf, size_t sz) } template <typename T, size_t VLEN, size_t VectorsPerChunk=4> -VESPA_DLL_LOCAL T dotProductSelectAlignment(const T * af, const T * bf, size_t sz); +VESPA_DLL_LOCAL T dotProductSelectAlignment(const T * af, const T * bf, size_t sz) noexcept; template <typename T, size_t VLEN, size_t VectorsPerChunk> -T dotProductSelectAlignment(const T * af, const T * bf, size_t sz) +T dotProductSelectAlignment(const T * af, const T * bf, size_t sz) noexcept { if (validAlignment(af, VLEN)) { if (validAlignment(bf, VLEN)) { @@ -87,7 +87,7 @@ T dotProductSelectAlignment(const T * af, const T * bf, size_t sz) template <typename T, unsigned VLEN, unsigned AlignA, unsigned AlignB> double -euclideanDistanceT(const T * af, const T * bf, size_t sz) +euclideanDistanceT(const T * af, const T * bf, size_t sz) noexcept { constexpr unsigned VectorsPerChunk = 4; constexpr unsigned ChunkSize = VLEN*VectorsPerChunk/sizeof(T); @@ -115,7 +115,7 @@ euclideanDistanceT(const T * af, const T * bf, size_t sz) } template <typename T, unsigned VLEN> -double euclideanDistanceSelectAlignment(const T * af, const T * bf, size_t sz) +double euclideanDistanceSelectAlignment(const T * af, const T * bf, size_t sz) noexcept { constexpr unsigned ALIGN = 32; if (validAlignment(af, ALIGN)) { diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp b/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp index 13946fa3398..dcc1189e35d 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp +++ b/vespalib/src/vespa/vespalib/hwaccelrated/generic.cpp @@ -10,7 +10,7 @@ namespace { template <typename ACCUM, typename T, size_t UNROLL> ACCUM -multiplyAdd(const T * a, const T * b, size_t sz) +multiplyAdd(const T * a, const T * b, size_t sz) noexcept { ACCUM partial[UNROLL]; for (size_t i(0); i < UNROLL; i++) { @@ -34,7 +34,7 @@ multiplyAdd(const T * a, const T * b, size_t sz) template <typename T, size_t UNROLL> double -squaredEuclideanDistanceT(const T * a, const T * b, size_t sz) +squaredEuclideanDistanceT(const T * a, const T * b, size_t sz) noexcept { T partial[UNROLL]; for (size_t i(0); i < UNROLL; i++) { @@ -60,7 +60,7 @@ squaredEuclideanDistanceT(const T * a, const T * b, size_t sz) template<size_t UNROLL, typename Operation> void -bitOperation(Operation operation, void * aOrg, const void * bOrg, size_t bytes) { +bitOperation(Operation operation, void * aOrg, const void * bOrg, size_t bytes) noexcept { const size_t sz(bytes/sizeof(uint64_t)); { @@ -87,59 +87,59 @@ bitOperation(Operation operation, void * aOrg, const void * bOrg, size_t bytes) } float -GenericAccelrator::dotProduct(const float * a, const float * b, size_t sz) const +GenericAccelrator::dotProduct(const float * a, const float * b, size_t sz) const noexcept { return cblas_sdot(sz, a, 1, b, 1); } double -GenericAccelrator::dotProduct(const double * a, const double * b, size_t sz) const +GenericAccelrator::dotProduct(const double * a, const double * b, size_t sz) const noexcept { return cblas_ddot(sz, a, 1, b, 1); } int64_t -GenericAccelrator::dotProduct(const int8_t * a, const int8_t * b, size_t sz) const +GenericAccelrator::dotProduct(const int8_t * a, const int8_t * b, size_t sz) const noexcept { return multiplyAdd<int64_t, int8_t, 8>(a, b, sz); } int64_t -GenericAccelrator::dotProduct(const int16_t * a, const int16_t * b, size_t sz) const +GenericAccelrator::dotProduct(const int16_t * a, const int16_t * b, size_t sz) const noexcept { return multiplyAdd<int64_t, int16_t, 8>(a, b, sz); } int64_t -GenericAccelrator::dotProduct(const int32_t * a, const int32_t * b, size_t sz) const +GenericAccelrator::dotProduct(const int32_t * a, const int32_t * b, size_t sz) const noexcept { return multiplyAdd<int64_t, int32_t, 8>(a, b, sz); } long long -GenericAccelrator::dotProduct(const int64_t * a, const int64_t * b, size_t sz) const +GenericAccelrator::dotProduct(const int64_t * a, const int64_t * b, size_t sz) const noexcept { return multiplyAdd<long long, int64_t, 8>(a, b, sz); } void -GenericAccelrator::orBit(void * aOrg, const void * bOrg, size_t bytes) const +GenericAccelrator::orBit(void * aOrg, const void * bOrg, size_t bytes) const noexcept { bitOperation<8>([](uint64_t a, uint64_t b) { return a | b; }, aOrg, bOrg, bytes); } void -GenericAccelrator::andBit(void * aOrg, const void * bOrg, size_t bytes) const +GenericAccelrator::andBit(void * aOrg, const void * bOrg, size_t bytes) const noexcept { bitOperation<8>([](uint64_t a, uint64_t b) { return a & b; }, aOrg, bOrg, bytes); } void -GenericAccelrator::andNotBit(void * aOrg, const void * bOrg, size_t bytes) const +GenericAccelrator::andNotBit(void * aOrg, const void * bOrg, size_t bytes) const noexcept { bitOperation<8>([](uint64_t a, uint64_t b) { return a & ~b; }, aOrg, bOrg, bytes); } void -GenericAccelrator::notBit(void * aOrg, size_t bytes) const +GenericAccelrator::notBit(void * aOrg, size_t bytes) const noexcept { auto a(static_cast<uint64_t *>(aOrg)); const size_t sz(bytes/sizeof(uint64_t)); @@ -153,32 +153,32 @@ GenericAccelrator::notBit(void * aOrg, size_t bytes) const } size_t -GenericAccelrator::populationCount(const uint64_t *a, size_t sz) const { +GenericAccelrator::populationCount(const uint64_t *a, size_t sz) const noexcept { return helper::populationCount(a, sz); } double -GenericAccelrator::squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const { +GenericAccelrator::squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept { return helper::squaredEuclideanDistance(a, b, sz); } double -GenericAccelrator::squaredEuclideanDistance(const float * a, const float * b, size_t sz) const { +GenericAccelrator::squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept { return squaredEuclideanDistanceT<float, 2>(a, b, sz); } double -GenericAccelrator::squaredEuclideanDistance(const double * a, const double * b, size_t sz) const { +GenericAccelrator::squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept { return squaredEuclideanDistanceT<double, 2>(a, b, sz); } void -GenericAccelrator::and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const { +GenericAccelrator::and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept { helper::andChunks<16, 4>(offset, src, dest); } void -GenericAccelrator::or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const { +GenericAccelrator::or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept { helper::orChunks<16,4>(offset, src, dest); } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/generic.h b/vespalib/src/vespa/vespalib/hwaccelrated/generic.h index 315e807da07..13c347df80f 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/generic.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/generic.h @@ -12,22 +12,22 @@ namespace vespalib::hwaccelrated { class GenericAccelrator : public IAccelrated { public: - float dotProduct(const float * a, const float * b, size_t sz) const override; - double dotProduct(const double * a, const double * b, size_t sz) const override; - int64_t dotProduct(const int8_t * a, const int8_t * b, size_t sz) const override; - int64_t dotProduct(const int16_t * a, const int16_t * b, size_t sz) const override; - int64_t dotProduct(const int32_t * a, const int32_t * b, size_t sz) const override; - long long dotProduct(const int64_t * a, const int64_t * b, size_t sz) const override; - void orBit(void * a, const void * b, size_t bytes) const override; - void andBit(void * a, const void * b, size_t bytes) const override; - void andNotBit(void * a, const void * b, size_t bytes) const override; - void notBit(void * a, size_t bytes) const override; - size_t populationCount(const uint64_t *a, size_t sz) const override; - double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const override; - double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const override; - double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const override; - void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const override; - void or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const override; + float dotProduct(const float * a, const float * b, size_t sz) const noexcept override; + double dotProduct(const double * a, const double * b, size_t sz) const noexcept override; + int64_t dotProduct(const int8_t * a, const int8_t * b, size_t sz) const noexcept override; + int64_t dotProduct(const int16_t * a, const int16_t * b, size_t sz) const noexcept override; + int64_t dotProduct(const int32_t * a, const int32_t * b, size_t sz) const noexcept override; + long long dotProduct(const int64_t * a, const int64_t * b, size_t sz) const noexcept override; + void orBit(void * a, const void * b, size_t bytes) const noexcept override; + void andBit(void * a, const void * b, size_t bytes) const noexcept override; + void andNotBit(void * a, const void * b, size_t bytes) const noexcept override; + void notBit(void * a, size_t bytes) const noexcept override; + size_t populationCount(const uint64_t *a, size_t sz) const noexcept override; + double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept override; + double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept override; + double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept override; + void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; + void or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept override; }; } diff --git a/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h b/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h index 73740858a41..c9b1f7cd45c 100644 --- a/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h +++ b/vespalib/src/vespa/vespalib/hwaccelrated/iaccelrated.h @@ -17,24 +17,24 @@ class IAccelrated public: virtual ~IAccelrated() = default; using UP = std::unique_ptr<IAccelrated>; - virtual float dotProduct(const float * a, const float * b, size_t sz) const = 0; - virtual double dotProduct(const double * a, const double * b, size_t sz) const = 0; - virtual int64_t dotProduct(const int8_t * a, const int8_t * b, size_t sz) const = 0; - virtual int64_t dotProduct(const int16_t * a, const int16_t * b, size_t sz) const = 0; - virtual int64_t dotProduct(const int32_t * a, const int32_t * b, size_t sz) const = 0; - virtual long long dotProduct(const int64_t * a, const int64_t * b, size_t sz) const = 0; - virtual void orBit(void * a, const void * b, size_t bytes) const = 0; - virtual void andBit(void * a, const void * b, size_t bytes) const = 0; - virtual void andNotBit(void * a, const void * b, size_t bytes) const = 0; - virtual void notBit(void * a, size_t bytes) const = 0; - virtual size_t populationCount(const uint64_t *a, size_t sz) const = 0; - virtual double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const = 0; - virtual double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const = 0; - virtual double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const = 0; + virtual float dotProduct(const float * a, const float * b, size_t sz) const noexcept = 0; + virtual double dotProduct(const double * a, const double * b, size_t sz) const noexcept = 0; + virtual int64_t dotProduct(const int8_t * a, const int8_t * b, size_t sz) const noexcept = 0; + virtual int64_t dotProduct(const int16_t * a, const int16_t * b, size_t sz) const noexcept = 0; + virtual int64_t dotProduct(const int32_t * a, const int32_t * b, size_t sz) const noexcept = 0; + virtual long long dotProduct(const int64_t * a, const int64_t * b, size_t sz) const noexcept = 0; + virtual void orBit(void * a, const void * b, size_t bytes) const noexcept = 0; + virtual void andBit(void * a, const void * b, size_t bytes) const noexcept = 0; + virtual void andNotBit(void * a, const void * b, size_t bytes) const noexcept = 0; + virtual void notBit(void * a, size_t bytes) const noexcept = 0; + virtual size_t populationCount(const uint64_t *a, size_t sz) const noexcept = 0; + virtual double squaredEuclideanDistance(const int8_t * a, const int8_t * b, size_t sz) const noexcept = 0; + virtual double squaredEuclideanDistance(const float * a, const float * b, size_t sz) const noexcept = 0; + virtual double squaredEuclideanDistance(const double * a, const double * b, size_t sz) const noexcept = 0; // AND 64 bytes from multiple, optionally inverted sources - virtual void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const = 0; + virtual void and64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept = 0; // OR 64 bytes from multiple, optionally inverted sources - virtual void or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const = 0; + virtual void or64(size_t offset, const std::vector<std::pair<const void *, bool>> &src, void *dest) const noexcept = 0; static const IAccelrated & getAccelerator() __attribute__((noinline)); }; diff --git a/vespalib/src/vespa/vespalib/io/fileutil.cpp b/vespalib/src/vespa/vespalib/io/fileutil.cpp index ff39e56f000..1ff2d3434f7 100644 --- a/vespalib/src/vespa/vespalib/io/fileutil.cpp +++ b/vespalib/src/vespa/vespalib/io/fileutil.cpp @@ -433,7 +433,7 @@ bool File::unlink() { close(); - return vespalib::unlink(_filename); + return std::filesystem::remove(std::filesystem::path(_filename)); } string @@ -449,34 +449,6 @@ getCurrentDirectory() } void -symlink(const string & oldPath, const string & newPath) -{ - if (::symlink(oldPath.c_str(), newPath.c_str())) { - asciistream ss; - const int err = errno; - ss << "symlink(" << oldPath << ", " << newPath - << "): Failed, errno(" << err << "): " - << safeStrerror(err); - throw IoException(ss.str(), IoException::getErrorType(err), VESPA_STRLOC); - } -} - -string -readLink(const string & path) -{ - char buf[256]; - ssize_t bytes(::readlink(path.c_str(), buf, sizeof(buf))); - if (bytes < 0) { - asciistream ss; - const int err = errno; - ss << "readlink(" << path << "): Failed, errno(" << err << "): " - << safeStrerror(err); - throw IoException(ss.str(), IoException::getErrorType(err), VESPA_STRLOC); - } - return string(buf, bytes); -} - -void chdir(const string & directory) { if (::chdir(directory.c_str()) != 0) { @@ -507,119 +479,12 @@ fileExists(const string & path) { return (stat(path).get() != 0); } -bool -unlink(const string & filename) -{ - if (::unlink(filename.c_str()) != 0) { - if (errno == ENOENT) { - return false; - } - asciistream ost; - ost << "unlink(" << filename << "): Failed, errno(" << errno << "): " - << safeStrerror(errno); - throw IoException(ost.str(), IoException::getErrorType(errno), VESPA_STRLOC); - } - LOG(debug, "unlink(%s): File deleted.", filename.c_str()); - return true; -} - -bool -rename(const string & frompath, const string & topath, - bool copyDeleteBetweenFilesystems, bool createTargetDirectoryIfMissing) -{ - LOG(spam, "rename(%s, %s): Renaming file%s.", - frompath.c_str(), topath.c_str(), - createTargetDirectoryIfMissing - ? " recursively creating target directory if missing" : ""); - if (::rename(frompath.c_str(), topath.c_str()) != 0) { - if (errno == ENOENT) { - if (!fileExists(frompath)) return false; - if (createTargetDirectoryIfMissing) { - string::size_type pos = topath.rfind('/'); - if (pos != string::npos) { - string path(topath.substr(0, pos)); - std::filesystem::create_directories(std::filesystem::path(path)); - LOG(debug, "rename(%s, %s): Created target directory. Calling recursively.", - frompath.c_str(), topath.c_str()); - return rename(frompath, topath, copyDeleteBetweenFilesystems, false); - } - } else { - asciistream ost; - ost << "rename(" << frompath << ", " << topath - << (copyDeleteBetweenFilesystems ? ", revert to copy" : "") - << (createTargetDirectoryIfMissing - ? ", create missing target" : "") - << "): Failed, target path does not exist."; - throw IoException(ost.str(), IoException::NOT_FOUND, - VESPA_STRLOC); - } - } else if (errno == EXDEV && copyDeleteBetweenFilesystems) { - if (!fileExists(frompath)) { - LOG(debug, "rename(%s, %s): Renaming non-existing file across " - "filesystems returned EXDEV rather than ENOENT.", - frompath.c_str(), topath.c_str()); - return false; - } - LOG(debug, "rename(%s, %s): Cannot rename across filesystems. " - "Copying and deleting instead.", - frompath.c_str(), topath.c_str()); - copy(frompath, topath, createTargetDirectoryIfMissing); - unlink(frompath); - return true; - } - asciistream ost; - ost << "rename(" << frompath << ", " << topath - << (copyDeleteBetweenFilesystems ? ", revert to copy" : "") - << (createTargetDirectoryIfMissing ? ", create missing target" : "") - << "): Failed, errno(" << errno << "): " << safeStrerror(errno); - throw IoException(ost.str(), IoException::getErrorType(errno), - VESPA_STRLOC); - } - LOG(debug, "rename(%s, %s): Renamed.", frompath.c_str(), topath.c_str()); - return true; -} - namespace { - uint32_t bufferSize = 1_Mi; uint32_t diskAlignmentSize = 4_Ki; } -void -copy(const string & frompath, const string & topath, - bool createTargetDirectoryIfMissing, bool useDirectIO) -{ - // Get aligned buffer, so it works with direct IO - LOG(spam, "copy(%s, %s): Copying file%s.", - frompath.c_str(), topath.c_str(), - createTargetDirectoryIfMissing - ? " recursively creating target directory if missing" : ""); - MallocAutoPtr buffer(getAlignedBuffer(bufferSize)); - - File source(frompath); - File target(topath); - source.open(File::READONLY | (useDirectIO ? File::DIRECTIO : 0)); - size_t sourceSize = source.getFileSize(); - if (useDirectIO && sourceSize % diskAlignmentSize != 0) { - LOG(warning, "copy(%s, %s): Cannot use direct IO to write new file, " - "as source file has size %zu, which is not " - "dividable by the disk alignment size of %u.", - frompath.c_str(), topath.c_str(), sourceSize, diskAlignmentSize); - useDirectIO = false; - } - target.open(File::CREATE | File::TRUNC | (useDirectIO ? File::DIRECTIO : 0), - createTargetDirectoryIfMissing); - off_t offset = 0; - for (;;) { - size_t bytesRead = source.read(buffer.get(), bufferSize, offset); - target.write(buffer.get(), bytesRead, offset); - if (bytesRead < bufferSize) break; - offset += bytesRead; - } - LOG(debug, "copy(%s, %s): Completed.", frompath.c_str(), topath.c_str()); -} - DirectoryList listDirectory(const string & path) { diff --git a/vespalib/src/vespa/vespalib/io/fileutil.h b/vespalib/src/vespa/vespalib/io/fileutil.h index 7d9e51532d0..6214bf3e60d 100644 --- a/vespalib/src/vespa/vespalib/io/fileutil.h +++ b/vespalib/src/vespa/vespalib/io/fileutil.h @@ -300,70 +300,6 @@ extern inline bool isSymLink(const vespalib::string & path) { } /** - * Creates a symbolic link named newPath which contains the string oldPath. - * - * IMPORTANT: from the spec: - * "Symbolic links are interpreted at run time as if the contents of the link had - * been substituted into the path being followed to find a file or directory." - * - * This means oldPath is _relative_ to the directory in which newPath resides! - * - * @param oldPath Target of symbolic link. - * @param newPath Relative link to be created. See above note for semantics. - * @throw IoException if we fail to create the symlink. - */ -extern void symlink(const vespalib::string & oldPath, - const vespalib::string & newPath); - -/** - * Read and return the contents of symbolic link at the given path. - * - * @param path Path to symbolic link. - * @return Contents of symbolic link. - * @throw IoException if we cannot read the link. - */ -extern vespalib::string readLink(const vespalib::string & path); - -/** - * Remove the given file. - * - * @param filename name of file. - * @return True if file was removed, false if it did not exist. - * @throw IoException If we failed to unlink the file. - */ -extern bool unlink(const vespalib::string & filename); - -/** - * Rename the file at frompath to topath. - * - * @param frompath old name of file. - * @param topath new name of file. - * - * @param copyDeleteBetweenFilesystems whether a copy-and-delete - * operation should be performed if rename crosses a file system - * boundary, or not. - * - * @param createTargetDirectoryIfMissing whether the target directory - * should be created if it's missing, or not. - * - * @throw IoException If we failed to rename the file. - * @throw std::filesystem::filesystem_error If we failed to create a target directory - * @return True if file was renamed, false if frompath did not exist. - */ -extern bool rename(const vespalib::string & frompath, - const vespalib::string & topath, - bool copyDeleteBetweenFilesystems = true, - bool createTargetDirectoryIfMissing = false); - -/** - * Copies a file to a destination using Direct IO. - */ -extern void copy(const vespalib::string & frompath, - const vespalib::string & topath, - bool createTargetDirectoryIfMissing = false, - bool useDirectIO = true); - -/** * List the contents of the given directory. */ using DirectoryList = std::vector<vespalib::string>; |